feat: 使用大模型美观输出,等待用户输入之后,大模型输出规范json,再反序列化存入数据库。
This commit is contained in:
parent
1c14be0c33
commit
88732e58c4
@ -21,6 +21,11 @@ from fastapi import APIRouter, HTTPException, Request
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.deps import get_checkpointer, get_store
|
from app.gateway.deps import get_checkpointer, get_store
|
||||||
|
from deerflow.agents.memory.thread_summary import (
|
||||||
|
ThreadMemoryConflictError,
|
||||||
|
apply_thread_memory_summary,
|
||||||
|
render_thread_memory_summary,
|
||||||
|
)
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.agents.memory.thread_storage import delete_thread_memory_data
|
from deerflow.agents.memory.thread_storage import delete_thread_memory_data
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
@ -122,6 +127,27 @@ class ThreadHistoryRequest(BaseModel):
|
|||||||
before: str | None = Field(default=None, description="Cursor for pagination")
|
before: str | None = Field(default=None, description="Cursor for pagination")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMemorySummaryResponse(BaseModel):
|
||||||
|
threadId: str
|
||||||
|
memoryVersion: int
|
||||||
|
summary: str
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMemorySummaryUpdateRequest(BaseModel):
|
||||||
|
summary: str = Field(..., min_length=1, description="User-edited natural language memory summary")
|
||||||
|
memoryVersion: int = Field(..., ge=0, description="Expected memory version for CAS update")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMemoryRecordResponse(BaseModel):
|
||||||
|
threadId: str
|
||||||
|
ownerId: str | None = None
|
||||||
|
user: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
history: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
facts: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
memoryVersion: int = 0
|
||||||
|
lastUpdated: str = ""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -241,14 +267,21 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
|||||||
await checkpointer.adelete_thread(thread_id)
|
await checkpointer.adelete_thread(thread_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
|
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
|
||||||
try:
|
|
||||||
delete_thread_memory_data(thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not delete thread memory for thread %s (not critical)", thread_id)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/memory", response_model=ThreadDeleteResponse)
|
||||||
|
async def delete_thread_memory(thread_id: str) -> ThreadDeleteResponse:
|
||||||
|
"""Delete per-thread memory only (explicit trigger)."""
|
||||||
|
try:
|
||||||
|
delete_thread_memory_data(thread_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to delete thread memory for %s", thread_id)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete thread memory.") from exc
|
||||||
|
return ThreadDeleteResponse(success=True, message=f"Deleted thread memory for {thread_id}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ThreadResponse)
|
@router.post("", response_model=ThreadResponse)
|
||||||
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
|
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
|
||||||
"""Create a new thread.
|
"""Create a new thread.
|
||||||
@ -685,3 +718,27 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
||||||
|
|
||||||
return entries
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/memory-summary", response_model=ThreadMemorySummaryResponse)
|
||||||
|
async def get_thread_memory_summary(thread_id: str) -> ThreadMemorySummaryResponse:
|
||||||
|
"""Render per-thread memory as human-readable text for user inspection/editing."""
|
||||||
|
try:
|
||||||
|
payload = render_thread_memory_summary(thread_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to render thread memory summary for %s", thread_id)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to render thread memory summary.") from exc
|
||||||
|
return ThreadMemorySummaryResponse(**payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/memory-summary", response_model=ThreadMemoryRecordResponse)
|
||||||
|
async def update_thread_memory_summary(thread_id: str, body: ThreadMemorySummaryUpdateRequest) -> ThreadMemoryRecordResponse:
|
||||||
|
"""Apply edited natural-language summary back into structured thread memory."""
|
||||||
|
try:
|
||||||
|
payload = apply_thread_memory_summary(thread_id, body.summary, body.memoryVersion)
|
||||||
|
except ThreadMemoryConflictError as exc:
|
||||||
|
raise HTTPException(status_code=409, detail="Thread memory changed; refresh and retry.") from exc
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to apply thread memory summary for %s", thread_id)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to apply thread memory summary.") from exc
|
||||||
|
return ThreadMemoryRecordResponse(**payload)
|
||||||
|
|||||||
@ -0,0 +1,300 @@
|
|||||||
|
"""Thread memory summary generation and application helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deerflow.agents.memory.thread_prompt import create_empty_thread_memory
|
||||||
|
from deerflow.agents.memory.thread_storage import get_thread_memory_storage
|
||||||
|
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
|
||||||
|
from deerflow.agents.memory.updater import _extract_text
|
||||||
|
from deerflow.config.thread_memory_config import get_thread_memory_config
|
||||||
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SUMMARY_RENDER_PROMPT = """You are an assistant that renders thread memory into natural language.
|
||||||
|
|
||||||
|
Thread memory JSON:
|
||||||
|
<memory_json>
|
||||||
|
{memory_json}
|
||||||
|
</memory_json>
|
||||||
|
|
||||||
|
Task:
|
||||||
|
- Output a concise, human-friendly editable profile summary.
|
||||||
|
- Keep the original language of the memory content where possible.
|
||||||
|
- Cover user profile, history, and key facts.
|
||||||
|
- Return plain text only (no markdown code fences).
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUMMARY_PARSE_PROMPT = """You convert user-edited natural-language memory into a structured patch JSON.
|
||||||
|
|
||||||
|
Current thread memory JSON:
|
||||||
|
<current_memory_json>
|
||||||
|
{current_memory_json}
|
||||||
|
</current_memory_json>
|
||||||
|
|
||||||
|
Edited summary text:
|
||||||
|
<edited_summary>
|
||||||
|
{edited_summary}
|
||||||
|
</edited_summary>
|
||||||
|
|
||||||
|
Return JSON only with this schema (all fields optional):
|
||||||
|
{{
|
||||||
|
"user": {{
|
||||||
|
"workContext": {{"summary": string}},
|
||||||
|
"personalContext": {{"summary": string}},
|
||||||
|
"topOfMind": {{"summary": string}}
|
||||||
|
}},
|
||||||
|
"history": {{
|
||||||
|
"recentMonths": {{"summary": string}},
|
||||||
|
"earlierContext": {{"summary": string}},
|
||||||
|
"longTermBackground": {{"summary": string}}
|
||||||
|
}},
|
||||||
|
"facts": [
|
||||||
|
{{
|
||||||
|
"content": string,
|
||||||
|
"category": "preference"|"knowledge"|"context"|"behavior"|"goal"|"correction",
|
||||||
|
"confidence": number
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMemoryConflictError(RuntimeError):
|
||||||
|
"""Raised when compare-and-swap save fails due to version mismatch."""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_summary_model():
|
||||||
|
config = get_thread_memory_config()
|
||||||
|
return create_chat_model(name=config.model_name, thinking_enabled=False, stream_usage=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_code_fence(text: str) -> str:
|
||||||
|
cleaned = text.strip()
|
||||||
|
if not cleaned.startswith("```"):
|
||||||
|
return cleaned
|
||||||
|
lines = cleaned.split("\n")
|
||||||
|
return "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
cleaned = _strip_code_fence(text)
|
||||||
|
try:
|
||||||
|
parsed = json.loads(cleaned)
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
repaired = _escape_inner_quotes_in_json_strings(cleaned)
|
||||||
|
if repaired != cleaned:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(repaired)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=full_text")
|
||||||
|
return parsed
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(0))
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
candidate = match.group(0)
|
||||||
|
repaired = _escape_inner_quotes_in_json_strings(candidate)
|
||||||
|
if repaired != candidate:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(repaired)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=regex_object")
|
||||||
|
return parsed
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _escape_inner_quotes_in_json_strings(text: str) -> str:
|
||||||
|
"""Heuristically repair unescaped inner double quotes inside JSON strings.
|
||||||
|
|
||||||
|
If a quote appears while inside a string but the next non-space character is
|
||||||
|
not a valid string terminator (comma, object/array close, or key colon), it is
|
||||||
|
treated as content and escaped.
|
||||||
|
"""
|
||||||
|
out: list[str] = []
|
||||||
|
in_string = False
|
||||||
|
escape = False
|
||||||
|
n = len(text)
|
||||||
|
i = 0
|
||||||
|
while i < n:
|
||||||
|
ch = text[i]
|
||||||
|
if not in_string:
|
||||||
|
out.append(ch)
|
||||||
|
if ch == '"':
|
||||||
|
in_string = True
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if escape:
|
||||||
|
out.append(ch)
|
||||||
|
escape = False
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ch == "\\":
|
||||||
|
out.append(ch)
|
||||||
|
escape = True
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ch == '"':
|
||||||
|
j = i + 1
|
||||||
|
while j < n and text[j].isspace():
|
||||||
|
j += 1
|
||||||
|
next_char = text[j] if j < n else ""
|
||||||
|
# Valid JSON string terminators in context:
|
||||||
|
# - key string: :
|
||||||
|
# - value string: , } ]
|
||||||
|
if next_char in {":", ",", "}", "]", ""}:
|
||||||
|
out.append(ch)
|
||||||
|
in_string = False
|
||||||
|
else:
|
||||||
|
out.append('\\"')
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
out.append(ch)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return "".join(out)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_summary_patch(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
merged = {"ownerId": base.get("ownerId"), **create_empty_thread_memory()}
|
||||||
|
merged["user"] = dict(base.get("user", {})) if isinstance(base.get("user"), dict) else merged["user"]
|
||||||
|
merged["history"] = dict(base.get("history", {})) if isinstance(base.get("history"), dict) else merged["history"]
|
||||||
|
merged["facts"] = list(base.get("facts", [])) if isinstance(base.get("facts"), list) else []
|
||||||
|
|
||||||
|
for section_name in ("user", "history"):
|
||||||
|
section_patch = patch.get(section_name, {})
|
||||||
|
if not isinstance(section_patch, dict):
|
||||||
|
continue
|
||||||
|
for key, value in section_patch.items():
|
||||||
|
if key not in merged[section_name] or not isinstance(value, dict):
|
||||||
|
continue
|
||||||
|
summary = value.get("summary")
|
||||||
|
if isinstance(summary, str):
|
||||||
|
merged[section_name][key]["summary"] = summary
|
||||||
|
|
||||||
|
facts_patch = patch.get("facts")
|
||||||
|
if isinstance(facts_patch, list):
|
||||||
|
merged["facts"] = facts_patch
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def render_thread_memory_summary(thread_id: str) -> dict[str, Any]:
|
||||||
|
storage = get_thread_memory_storage()
|
||||||
|
current = storage.load(thread_id)
|
||||||
|
memory = {"ownerId": None, **create_empty_thread_memory()} if current is None else current
|
||||||
|
memory_payload = {
|
||||||
|
"user": memory.get("user", {}),
|
||||||
|
"history": memory.get("history", {}),
|
||||||
|
"facts": memory.get("facts", []),
|
||||||
|
}
|
||||||
|
prompt = SUMMARY_RENDER_PROMPT.format(memory_json=json.dumps(memory_payload, ensure_ascii=False, indent=2))
|
||||||
|
response = _get_summary_model().invoke(prompt)
|
||||||
|
text = _strip_code_fence(_extract_text(response.content))
|
||||||
|
return {
|
||||||
|
"threadId": thread_id,
|
||||||
|
"memoryVersion": int(memory.get("memoryVersion", 0)),
|
||||||
|
"summary": text,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def apply_thread_memory_summary(thread_id: str, edited_summary: str, expected_version: int) -> dict[str, Any]:
|
||||||
|
storage = get_thread_memory_storage()
|
||||||
|
current = storage.load(thread_id)
|
||||||
|
base = {"ownerId": None, **create_empty_thread_memory()} if current is None else current
|
||||||
|
memory_payload = {
|
||||||
|
"user": base.get("user", {}),
|
||||||
|
"history": base.get("history", {}),
|
||||||
|
"facts": base.get("facts", []),
|
||||||
|
}
|
||||||
|
prompt = SUMMARY_PARSE_PROMPT.format(
|
||||||
|
current_memory_json=json.dumps(memory_payload, ensure_ascii=False, indent=2),
|
||||||
|
edited_summary=edited_summary,
|
||||||
|
)
|
||||||
|
response = _get_summary_model().invoke(prompt)
|
||||||
|
raw = _extract_text(response.content)
|
||||||
|
raw_hash = hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||||
|
logger.warning(
|
||||||
|
"THREAD_SUMMARY_DEBUG parse_raw_meta thread=%s raw_length=%d raw_sha256=%s",
|
||||||
|
thread_id,
|
||||||
|
len(raw),
|
||||||
|
raw_hash,
|
||||||
|
)
|
||||||
|
patch = _extract_json_object(raw)
|
||||||
|
if patch is None:
|
||||||
|
cleaned = _strip_code_fence(raw)
|
||||||
|
decode_error = None
|
||||||
|
try:
|
||||||
|
json.loads(cleaned)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
decode_error = exc
|
||||||
|
if decode_error is not None:
|
||||||
|
logger.warning(
|
||||||
|
"THREAD_SUMMARY_DEBUG parse_error thread=%s msg=%s line=%d col=%d pos=%d snippet=%r",
|
||||||
|
thread_id,
|
||||||
|
decode_error.msg,
|
||||||
|
decode_error.lineno,
|
||||||
|
decode_error.colno,
|
||||||
|
decode_error.pos,
|
||||||
|
cleaned[max(0, decode_error.pos - 80): decode_error.pos + 80],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"THREAD_SUMMARY_DEBUG parse_error thread=%s msg=no_json_object_extracted raw_head=%r",
|
||||||
|
thread_id,
|
||||||
|
cleaned[:200],
|
||||||
|
)
|
||||||
|
logger.warning("THREAD_SUMMARY_DEBUG parse_fallback thread=%s", thread_id)
|
||||||
|
patch = {
|
||||||
|
"user": {
|
||||||
|
"topOfMind": {
|
||||||
|
"summary": edited_summary.strip(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"THREAD_SUMMARY_DEBUG parse_success thread=%s patch=%s",
|
||||||
|
thread_id,
|
||||||
|
json.dumps(patch, ensure_ascii=False)[:2000],
|
||||||
|
)
|
||||||
|
merged = _merge_summary_patch(base, patch if isinstance(patch, dict) else {})
|
||||||
|
cleaned = ThreadMemoryUpdater()._scrub_sensitive(merged, thread_id)
|
||||||
|
cleaned["ownerId"] = base.get("ownerId")
|
||||||
|
logger.warning(
|
||||||
|
"THREAD_SUMMARY_DEBUG apply_cleaned thread=%s cleaned=%s",
|
||||||
|
thread_id,
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"user": cleaned.get("user", {}),
|
||||||
|
"history": cleaned.get("history", {}),
|
||||||
|
"facts_count": len(cleaned.get("facts", []) if isinstance(cleaned.get("facts"), list) else []),
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)[:2000],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not storage.save(thread_id, cleaned, expected_version=expected_version):
|
||||||
|
raise ThreadMemoryConflictError(f"Thread memory version conflict for {thread_id}")
|
||||||
|
|
||||||
|
latest = storage.load(thread_id)
|
||||||
|
return latest if latest is not None else {"threadId": thread_id, "memoryVersion": expected_version, **cleaned}
|
||||||
31
backend/tests/test_thread_memory_delete_trigger.py
Normal file
31
backend/tests/test_thread_memory_delete_trigger.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.gateway.routers import threads
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_thread_does_not_delete_thread_memory():
|
||||||
|
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(checkpointer=None, store=None)))
|
||||||
|
with (
|
||||||
|
patch("app.gateway.routers.threads._delete_thread_data", return_value=threads.ThreadDeleteResponse(success=True, message="ok")),
|
||||||
|
patch("app.gateway.routers.threads.get_store", return_value=None),
|
||||||
|
patch("app.gateway.routers.threads.delete_thread_memory_data") as delete_memory,
|
||||||
|
):
|
||||||
|
response = await threads.delete_thread_data("thread-1", request)
|
||||||
|
|
||||||
|
assert response.success is True
|
||||||
|
delete_memory.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_thread_memory_endpoint_calls_cleanup():
|
||||||
|
with patch("app.gateway.routers.threads.delete_thread_memory_data") as delete_memory:
|
||||||
|
response = await threads.delete_thread_memory("thread-1")
|
||||||
|
|
||||||
|
assert response.success is True
|
||||||
|
assert response.message == "Deleted thread memory for thread-1"
|
||||||
|
delete_memory.assert_called_once_with("thread-1")
|
||||||
|
|
||||||
103
backend/tests/test_thread_memory_summary.py
Normal file
103
backend/tests/test_thread_memory_summary.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from deerflow.agents.memory.thread_summary import (
|
||||||
|
ThreadMemoryConflictError,
|
||||||
|
_extract_json_object,
|
||||||
|
apply_thread_memory_summary,
|
||||||
|
render_thread_memory_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_thread_memory_summary_returns_text():
|
||||||
|
fake_storage = type(
|
||||||
|
"S",
|
||||||
|
(),
|
||||||
|
{"load": lambda self, tid: {"threadId": tid, "user": {}, "history": {}, "facts": [], "memoryVersion": 2}},
|
||||||
|
)()
|
||||||
|
fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "用户总结"})()})()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=fake_storage),
|
||||||
|
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model),
|
||||||
|
):
|
||||||
|
result = render_thread_memory_summary("t1")
|
||||||
|
|
||||||
|
assert result["threadId"] == "t1"
|
||||||
|
assert result["memoryVersion"] == 2
|
||||||
|
assert result["summary"] == "用户总结"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_thread_memory_summary_raises_conflict_on_cas_failure():
|
||||||
|
class _Storage:
|
||||||
|
def load(self, _tid):
|
||||||
|
return {"threadId": "t1", "ownerId": None, "user": {}, "history": {}, "facts": [], "memoryVersion": 1}
|
||||||
|
|
||||||
|
def save(self, _tid, _data, expected_version=None):
|
||||||
|
return False
|
||||||
|
|
||||||
|
fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "{}"})()})()
|
||||||
|
fake_updater = type("U", (), {"_scrub_sensitive": lambda self, data, _thread_id: data})()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=_Storage()),
|
||||||
|
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model),
|
||||||
|
patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=fake_updater),
|
||||||
|
):
|
||||||
|
with pytest.raises(ThreadMemoryConflictError):
|
||||||
|
apply_thread_memory_summary("t1", "更新内容", 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_thread_memory_summary_falls_back_when_model_output_is_not_json():
|
||||||
|
class _Storage:
|
||||||
|
def __init__(self):
|
||||||
|
self.saved = None
|
||||||
|
|
||||||
|
def load(self, _tid):
|
||||||
|
if self.saved is not None:
|
||||||
|
return {"threadId": "t1", "memoryVersion": 2, **self.saved}
|
||||||
|
return {
|
||||||
|
"threadId": "t1",
|
||||||
|
"ownerId": None,
|
||||||
|
"user": {"topOfMind": {"summary": ""}},
|
||||||
|
"history": {},
|
||||||
|
"facts": [],
|
||||||
|
"memoryVersion": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
def save(self, _tid, data, expected_version=None):
|
||||||
|
self.saved = data
|
||||||
|
return True
|
||||||
|
|
||||||
|
storage = _Storage()
|
||||||
|
fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "这是自然语言,不是JSON"})()})()
|
||||||
|
fake_updater = type("U", (), {"_scrub_sensitive": lambda self, data, _thread_id: data})()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=storage),
|
||||||
|
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model),
|
||||||
|
patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=fake_updater),
|
||||||
|
):
|
||||||
|
result = apply_thread_memory_summary("t1", "我最近在做线程记忆功能", 1)
|
||||||
|
|
||||||
|
assert storage.saved is not None
|
||||||
|
assert storage.saved["user"]["topOfMind"]["summary"] == "我最近在做线程记忆功能"
|
||||||
|
assert result["user"]["topOfMind"]["summary"] == "我最近在做线程记忆功能"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_json_object_repairs_inner_unescaped_quotes():
|
||||||
|
raw = """
|
||||||
|
{
|
||||||
|
"user": {
|
||||||
|
"topOfMind": {
|
||||||
|
"summary": "反感“作为 AI"这种句式,认为回答不用寒暄直接说重点。"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"history": {},
|
||||||
|
"facts": []
|
||||||
|
}
|
||||||
|
""".strip()
|
||||||
|
parsed = _extract_json_object(raw)
|
||||||
|
assert parsed is not None
|
||||||
|
assert parsed["user"]["topOfMind"]["summary"].startswith("反感“作为 AI")
|
||||||
@ -27,6 +27,7 @@ import { IframeTestPanel } from "@/components/workspace/iframe-test-panel";
|
|||||||
import { InputBox } from "@/components/workspace/input-box";
|
import { InputBox } from "@/components/workspace/input-box";
|
||||||
import { MessageList } from "@/components/workspace/messages";
|
import { MessageList } from "@/components/workspace/messages";
|
||||||
import { ThreadContext } from "@/components/workspace/messages/context";
|
import { ThreadContext } from "@/components/workspace/messages/context";
|
||||||
|
import { ThreadMemoryTestPanel } from "@/components/workspace/thread-memory-test-panel";
|
||||||
import { ThreadTitle } from "@/components/workspace/thread-title";
|
import { ThreadTitle } from "@/components/workspace/thread-title";
|
||||||
import { Tooltip } from "@/components/workspace/tooltip";
|
import { Tooltip } from "@/components/workspace/tooltip";
|
||||||
import { useSpecificChatMode } from "@/components/workspace/use-chat-mode";
|
import { useSpecificChatMode } from "@/components/workspace/use-chat-mode";
|
||||||
@ -705,6 +706,7 @@ export default function ChatPage() {
|
|||||||
|
|
||||||
{/* MARK: 开发测试:iframe 通信功能测试面板 */}
|
{/* MARK: 开发测试:iframe 通信功能测试面板 */}
|
||||||
{/* {process.env.NODE_ENV !== "production" && <IframeTestPanel />} */}
|
{/* {process.env.NODE_ENV !== "production" && <IframeTestPanel />} */}
|
||||||
|
{/* <ThreadMemoryTestPanel threadId={threadId} /> */}
|
||||||
</div>
|
</div>
|
||||||
</ThreadContext.Provider>
|
</ThreadContext.Provider>
|
||||||
);
|
);
|
||||||
|
|||||||
142
frontend/src/components/workspace/thread-memory-test-panel.tsx
Normal file
142
frontend/src/components/workspace/thread-memory-test-panel.tsx
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Textarea } from "@/components/ui/textarea";
|
||||||
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
|
type ThreadMemoryTestPanelProps = {
|
||||||
|
threadId?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ThreadMemoryTestPanel({ threadId }: ThreadMemoryTestPanelProps) {
|
||||||
|
const [memorySummary, setMemorySummary] = useState("");
|
||||||
|
const [memoryVersion, setMemoryVersion] = useState<number | null>(null);
|
||||||
|
const [loadingSummary, setLoadingSummary] = useState(false);
|
||||||
|
const [savingSummary, setSavingSummary] = useState(false);
|
||||||
|
const [deletingMemory, setDeletingMemory] = useState(false);
|
||||||
|
const [open, setOpen] = useState(true);
|
||||||
|
|
||||||
|
if (!threadId || threadId === "new") return null;
|
||||||
|
|
||||||
|
const handleLoadMemorySummary = async () => {
|
||||||
|
setLoadingSummary(true);
|
||||||
|
try {
|
||||||
|
const res = await fetch(
|
||||||
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory-summary`,
|
||||||
|
);
|
||||||
|
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
||||||
|
const data = (await res.json()) as { summary: string; memoryVersion: number };
|
||||||
|
setMemorySummary(data.summary ?? "");
|
||||||
|
setMemoryVersion(data.memoryVersion ?? 0);
|
||||||
|
toast.success("已加载会话记忆");
|
||||||
|
} catch {
|
||||||
|
toast.error("加载会话记忆失败");
|
||||||
|
} finally {
|
||||||
|
setLoadingSummary(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSaveMemorySummary = async () => {
|
||||||
|
if (memoryVersion == null) return;
|
||||||
|
setSavingSummary(true);
|
||||||
|
try {
|
||||||
|
const res = await fetch(
|
||||||
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory-summary`,
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({ summary: memorySummary, memoryVersion }),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (res.status === 409) {
|
||||||
|
toast.error("记忆已更新,请先重新加载再保存");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
||||||
|
const data = (await res.json()) as { memoryVersion?: number };
|
||||||
|
if (typeof data.memoryVersion === "number") setMemoryVersion(data.memoryVersion);
|
||||||
|
toast.success("会话记忆已保存");
|
||||||
|
} catch {
|
||||||
|
toast.error("保存会话记忆失败");
|
||||||
|
} finally {
|
||||||
|
setSavingSummary(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDeleteMemory = async () => {
|
||||||
|
setDeletingMemory(true);
|
||||||
|
try {
|
||||||
|
const res = await fetch(
|
||||||
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory`,
|
||||||
|
{ method: "DELETE" },
|
||||||
|
);
|
||||||
|
if (!res.ok) throw new Error(`HTTP ${res.status}`);
|
||||||
|
setMemorySummary("");
|
||||||
|
setMemoryVersion(0);
|
||||||
|
toast.success("当前会话记忆已删除");
|
||||||
|
} catch {
|
||||||
|
toast.error("删除会话记忆失败");
|
||||||
|
} finally {
|
||||||
|
setDeletingMemory(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="fixed right-4 bottom-4 z-50 w-[380px] rounded-lg border border-ws-divider bg-ws-surface-elevated p-3 shadow-lg">
|
||||||
|
<div className="mb-2 flex items-center justify-between">
|
||||||
|
<div className="text-sm font-semibold">Thread Memory TestPanel</div>
|
||||||
|
<Button size="sm" variant="ghost" onClick={() => setOpen((v) => !v)}>
|
||||||
|
{open ? "收起" : "展开"}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{open && (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
onClick={() => {
|
||||||
|
void handleLoadMemorySummary();
|
||||||
|
}}
|
||||||
|
disabled={loadingSummary}
|
||||||
|
>
|
||||||
|
{loadingSummary ? "加载中..." : "查看记忆"}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
onClick={() => {
|
||||||
|
void handleSaveMemorySummary();
|
||||||
|
}}
|
||||||
|
disabled={savingSummary || memoryVersion == null}
|
||||||
|
>
|
||||||
|
{savingSummary ? "保存中..." : "保存记忆"}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="destructive"
|
||||||
|
onClick={() => {
|
||||||
|
void handleDeleteMemory();
|
||||||
|
}}
|
||||||
|
disabled={deletingMemory}
|
||||||
|
>
|
||||||
|
{deletingMemory ? "删除中..." : "删除记忆"}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-ws-text-subtle-strong">
|
||||||
|
threadId: {threadId.slice(0, 8)}... | version:{" "}
|
||||||
|
{memoryVersion == null ? "-" : memoryVersion}
|
||||||
|
</div>
|
||||||
|
<Textarea
|
||||||
|
value={memorySummary}
|
||||||
|
onChange={(e) => setMemorySummary(e.target.value)}
|
||||||
|
placeholder="这里显示会话记忆总结,可编辑后保存"
|
||||||
|
className="min-h-32 bg-white/80"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user