diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index dc891a65..a4e92974 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -21,6 +21,11 @@ from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field from app.gateway.deps import get_checkpointer, get_store +from deerflow.agents.memory.thread_summary import ( + ThreadMemoryConflictError, + apply_thread_memory_summary, + render_thread_memory_summary, +) from deerflow.config.paths import Paths, get_paths from deerflow.agents.memory.thread_storage import delete_thread_memory_data from deerflow.runtime import serialize_channel_values @@ -122,6 +127,27 @@ class ThreadHistoryRequest(BaseModel): before: str | None = Field(default=None, description="Cursor for pagination") +class ThreadMemorySummaryResponse(BaseModel): + threadId: str + memoryVersion: int + summary: str + + +class ThreadMemorySummaryUpdateRequest(BaseModel): + summary: str = Field(..., min_length=1, description="User-edited natural language memory summary") + memoryVersion: int = Field(..., ge=0, description="Expected memory version for CAS update") + + +class ThreadMemoryRecordResponse(BaseModel): + threadId: str + ownerId: str | None = None + user: dict[str, Any] = Field(default_factory=dict) + history: dict[str, Any] = Field(default_factory=dict) + facts: list[dict[str, Any]] = Field(default_factory=list) + memoryVersion: int = 0 + lastUpdated: str = "" + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -241,14 +267,21 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe await checkpointer.adelete_thread(thread_id) except Exception: logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id) - try: - delete_thread_memory_data(thread_id) - except Exception: 
@router.delete("/{thread_id}/memory", response_model=ThreadDeleteResponse)
async def delete_thread_memory(thread_id: str) -> ThreadDeleteResponse:
    """Delete the per-thread memory of one thread (explicit trigger).

    Args:
        thread_id: Identifier of the thread whose memory is removed.

    Returns:
        A ``ThreadDeleteResponse`` with ``success=True`` and a message naming
        the thread.

    Raises:
        HTTPException: 500 when ``delete_thread_memory_data`` raises; the
            original exception is logged with its traceback and chained.
    """
    try:
        delete_thread_memory_data(thread_id)
    except Exception as exc:
        logger.exception("Failed to delete thread memory for %s", thread_id)
        raise HTTPException(status_code=500, detail="Failed to delete thread memory.") from exc
    return ThreadDeleteResponse(success=True, message=f"Deleted thread memory for {thread_id}")


@router.get("/{thread_id}/memory-summary", response_model=ThreadMemorySummaryResponse)
async def get_thread_memory_summary(thread_id: str) -> ThreadMemorySummaryResponse:
    """Render per-thread memory as human-readable text for user inspection/editing.

    ``render_thread_memory_summary`` is a synchronous function, so it is run in
    the worker threadpool instead of directly on the event loop; calling it
    inline would stall every other request for as long as it takes.

    Args:
        thread_id: Identifier of the thread whose memory is rendered.

    Returns:
        The rendered summary together with the memory version it was built from.

    Raises:
        HTTPException: 500 when rendering fails for any reason.
    """
    # Local import: keeps this handler self-contained without touching the
    # module's import block.
    from fastapi.concurrency import run_in_threadpool

    try:
        payload = await run_in_threadpool(render_thread_memory_summary, thread_id)
    except Exception as exc:
        logger.exception("Failed to render thread memory summary for %s", thread_id)
        raise HTTPException(status_code=500, detail="Failed to render thread memory summary.") from exc
    return ThreadMemorySummaryResponse(**payload)


@router.post("/{thread_id}/memory-summary", response_model=ThreadMemoryRecordResponse)
async def update_thread_memory_summary(thread_id: str, body: ThreadMemorySummaryUpdateRequest) -> ThreadMemoryRecordResponse:
    """Apply edited natural-language summary back into structured thread memory.

    The update is a compare-and-swap: ``body.memoryVersion`` is the version the
    client last saw, and a mismatch is reported as a conflict rather than
    overwriting newer data. ``apply_thread_memory_summary`` is synchronous, so
    it is executed in the worker threadpool to keep the event loop responsive.

    Args:
        thread_id: Identifier of the thread whose memory is updated.
        body: Edited summary text plus the expected memory version.

    Returns:
        The stored memory record after the update.

    Raises:
        HTTPException: 409 when the memory changed since the client loaded it
            (``ThreadMemoryConflictError``); 500 for any other failure.
    """
    # Local import: keeps this handler self-contained without touching the
    # module's import block.
    from fastapi.concurrency import run_in_threadpool

    try:
        payload = await run_in_threadpool(
            apply_thread_memory_summary,
            thread_id,
            body.summary,
            body.memoryVersion,
        )
    except ThreadMemoryConflictError as exc:
        # Expected outcome of a lost CAS race: no traceback logging needed.
        raise HTTPException(status_code=409, detail="Thread memory changed; refresh and retry.") from exc
    except Exception as exc:
        logger.exception("Failed to apply thread memory summary for %s", thread_id)
        raise HTTPException(status_code=500, detail="Failed to apply thread memory summary.") from exc
    return ThreadMemoryRecordResponse(**payload)
class ThreadMemoryConflictError(RuntimeError):
    """Raised when compare-and-swap save fails due to version mismatch."""


def _get_summary_model():
    """Create the chat model used to render and parse memory summaries.

    The model name comes from the thread-memory configuration; thinking and
    usage streaming are both disabled.
    """
    config = get_thread_memory_config()
    return create_chat_model(name=config.model_name, thinking_enabled=False, stream_usage=False)


def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence from *text*, if present.

    Handles the usual multi-line form (opening fence line, optionally with a
    language tag, and a closing fence), a closing fence glued to the last
    content line, and a single-line ``` ```content``` ``` payload. Text that
    does not start with a fence is returned stripped but otherwise untouched.

    Args:
        text: Raw model output.

    Returns:
        The fenced content with surrounding whitespace removed.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned

    _fence_line, newline, rest = cleaned.partition("\n")
    if not newline:
        # Everything is on one line: keep what sits between the two fences.
        # A lone opening fence (with or without a language tag) has no content.
        if len(cleaned) >= 6 and cleaned.endswith("```"):
            return cleaned[3:-3].strip()
        return ""

    # `cleaned` is already stripped, so a closing fence, if any, is the very
    # last thing in `rest` whether it is on its own line or not.
    if rest.endswith("```"):
        rest = rest[:-3]
    return rest.strip()


def _loads_dict_with_repair(candidate: str, mode: str) -> dict[str, Any] | None:
    """Parse *candidate* as a JSON object, retrying once after quote repair.

    Args:
        candidate: Text expected to contain exactly one JSON document.
        mode: Label included in the log line emitted when repair was needed.

    Returns:
        The parsed dict, or ``None`` when the text is not valid JSON even after
        repair, or is valid JSON of another type.
    """
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        repaired = _escape_inner_quotes_in_json_strings(candidate)
        if repaired == candidate:
            return None
        try:
            parsed = json.loads(repaired)
        except json.JSONDecodeError:
            return None
        if not isinstance(parsed, dict):
            return None
        logging.getLogger(__name__).warning("THREAD_SUMMARY_DEBUG parse_repaired mode=%s", mode)
        return parsed
    return parsed if isinstance(parsed, dict) else None


def _extract_json_object(text: str) -> dict[str, Any] | None:
    """Extract a JSON object from model output.

    Two candidates are tried in order: the whole text (after code-fence
    removal), then the span from the first ``{`` to the last ``}``. Each
    candidate gets one repair attempt for unescaped inner quotes. The second
    candidate is also tried when the whole text is valid JSON of a non-object
    type, e.g. an array wrapping the object.

    Args:
        text: Raw model output.

    Returns:
        The parsed object, or ``None`` when no JSON object can be recovered.
    """
    cleaned = _strip_code_fence(text)
    parsed = _loads_dict_with_repair(cleaned, "full_text")
    if parsed is not None:
        return parsed

    # Greedy on purpose: nested objects need the outermost closing brace.
    match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
    if match is None:
        return None
    return _loads_dict_with_repair(match.group(0), "regex_object")


def _escape_inner_quotes_in_json_strings(text: str) -> str:
    """Heuristically repair unescaped inner double quotes inside JSON strings.

    If a quote appears while inside a string but the next non-space character is
    not a valid string terminator (comma, object/array close, key colon, or end
    of input), it is treated as content and escaped.

    This is a heuristic: an inner quote that happens to be followed by one of
    those characters is still taken as the end of the string.

    Args:
        text: JSON-like text that failed to parse.

    Returns:
        The text with suspected inner quotes escaped; identical to the input
        when nothing needed escaping.
    """
    terminators = {":", ",", "}", "]"}
    out: list[str] = []
    in_string = False
    escaped = False
    length = len(text)

    for index, ch in enumerate(text):
        if not in_string:
            out.append(ch)
            in_string = ch == '"'
            continue

        if escaped:
            # Character following a backslash is always literal content.
            out.append(ch)
            escaped = False
            continue

        if ch == "\\":
            out.append(ch)
            escaped = True
            continue

        if ch != '"':
            out.append(ch)
            continue

        # A quote inside a string: decide whether it closes the string by
        # peeking at the next non-whitespace character.
        lookahead = index + 1
        while lookahead < length and text[lookahead].isspace():
            lookahead += 1
        if lookahead >= length or text[lookahead] in terminators:
            out.append(ch)
            in_string = False
        else:
            out.append('\\"')

    return "".join(out)
def _summary_prompt_payload(memory: dict[str, Any]) -> dict[str, Any]:
    """Return the sections of a memory record that are shown to the model."""
    return {
        "user": memory.get("user", {}),
        "history": memory.get("history", {}),
        "facts": memory.get("facts", []),
    }


def _merge_summary_patch(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
    """Merge a model-produced patch into a thread memory record.

    The result starts from the empty-memory template, overlaid with the
    ``user``/``history`` sections and ``facts`` of *base*. Seeding from the
    template means a stored record with a sparse section (for example
    ``"user": {}``) can still receive patches for the template's keys. Keys
    that exist in neither the template nor *base* are ignored.

    *base* is never mutated: patched entries are rebuilt as new dicts.

    Args:
        base: Current memory record (or the empty template).
        patch: Parsed patch; only ``user``, ``history`` and ``facts`` are read.

    Returns:
        A new memory dict with the patch applied.
    """
    merged = {"ownerId": base.get("ownerId"), **create_empty_thread_memory()}

    for section_name in ("user", "history"):
        template_section = merged.get(section_name)
        section = dict(template_section) if isinstance(template_section, dict) else {}
        base_section = base.get(section_name)
        if isinstance(base_section, dict):
            section.update(base_section)
        merged[section_name] = section

        section_patch = patch.get(section_name)
        if not isinstance(section_patch, dict):
            continue
        for key, value in section_patch.items():
            if key not in section or not isinstance(value, dict):
                continue
            summary = value.get("summary")
            if not isinstance(summary, str):
                continue
            # Rebuild the entry instead of writing into it: the existing dict
            # may still be owned by `base`, and may not be a dict at all.
            existing = section[key]
            entry = dict(existing) if isinstance(existing, dict) else {}
            entry["summary"] = summary
            section[key] = entry

    base_facts = base.get("facts")
    merged["facts"] = list(base_facts) if isinstance(base_facts, list) else []

    facts_patch = patch.get("facts")
    if isinstance(facts_patch, list):
        # A present list replaces the stored facts wholesale. Non-dict items
        # are model noise and are dropped so the record stays a list of dicts.
        merged["facts"] = [fact for fact in facts_patch if isinstance(fact, dict)]

    return merged


def render_thread_memory_summary(thread_id: str) -> dict[str, Any]:
    """Render a thread's stored memory as editable natural-language text.

    Loads the record (falling back to the empty template when none exists),
    asks the summary model to describe it, and strips any code fence from the
    reply. This performs a blocking model call.

    Args:
        thread_id: Identifier of the thread.

    Returns:
        A dict with ``threadId``, ``memoryVersion`` (0 when missing or null)
        and ``summary``.
    """
    storage = get_thread_memory_storage()
    current = storage.load(thread_id)
    memory = {"ownerId": None, **create_empty_thread_memory()} if current is None else current

    prompt = SUMMARY_RENDER_PROMPT.format(
        memory_json=json.dumps(_summary_prompt_payload(memory), ensure_ascii=False, indent=2),
    )
    response = _get_summary_model().invoke(prompt)
    text = _strip_code_fence(_extract_text(response.content))
    return {
        "threadId": thread_id,
        "memoryVersion": int(memory.get("memoryVersion") or 0),
        "summary": text,
    }


def _log_summary_parse_failure(thread_id: str, raw: str) -> None:
    """Emit debug-level diagnostics explaining why *raw* yielded no JSON object."""
    if not logger.isEnabledFor(logging.DEBUG):
        return
    cleaned = _strip_code_fence(raw)
    try:
        json.loads(cleaned)
    except json.JSONDecodeError as decode_error:
        logger.debug(
            "THREAD_SUMMARY_DEBUG parse_error thread=%s msg=%s line=%d col=%d pos=%d snippet=%r",
            thread_id,
            decode_error.msg,
            decode_error.lineno,
            decode_error.colno,
            decode_error.pos,
            cleaned[max(0, decode_error.pos - 80): decode_error.pos + 80],
        )
    else:
        logger.debug(
            "THREAD_SUMMARY_DEBUG parse_error thread=%s msg=no_json_object_extracted raw_head=%r",
            thread_id,
            cleaned[:200],
        )


def apply_thread_memory_summary(thread_id: str, edited_summary: str, expected_version: int) -> dict[str, Any]:
    """Apply a user-edited summary to the structured thread memory.

    The summary model converts the edited text into a patch, which is merged
    into the current record, scrubbed, and saved with a compare-and-swap on
    *expected_version*. When the model reply contains no JSON object, the
    edited text is stored verbatim as ``user.topOfMind.summary`` instead.

    Diagnostics that contain memory content are logged at DEBUG level only, so
    user data does not reach logs under the default configuration.

    Args:
        thread_id: Identifier of the thread.
        edited_summary: Natural-language summary as edited by the user.
        expected_version: Memory version the edit was based on.

    Returns:
        The record as reloaded from storage after saving, or a record built
        from the saved data when the reload returns nothing.

    Raises:
        ThreadMemoryConflictError: When the compare-and-swap save is rejected.
    """
    storage = get_thread_memory_storage()
    current = storage.load(thread_id)
    base = {"ownerId": None, **create_empty_thread_memory()} if current is None else current

    prompt = SUMMARY_PARSE_PROMPT.format(
        current_memory_json=json.dumps(_summary_prompt_payload(base), ensure_ascii=False, indent=2),
        edited_summary=edited_summary,
    )
    response = _get_summary_model().invoke(prompt)
    raw = _extract_text(response.content)
    if logger.isEnabledFor(logging.DEBUG):
        # Guarded so the hash is only computed when it will actually be logged.
        logger.debug(
            "THREAD_SUMMARY_DEBUG parse_raw_meta thread=%s raw_length=%d raw_sha256=%s",
            thread_id,
            len(raw),
            hashlib.sha256(raw.encode("utf-8")).hexdigest(),
        )

    patch = _extract_json_object(raw)
    if patch is None:
        _log_summary_parse_failure(thread_id, raw)
        # Content-free, so it is safe to keep this visible at WARNING.
        logger.warning("THREAD_SUMMARY_DEBUG parse_fallback thread=%s", thread_id)
        patch = {
            "user": {
                "topOfMind": {
                    "summary": edited_summary.strip(),
                }
            }
        }
    elif logger.isEnabledFor(logging.DEBUG):
        # The patch has not been scrubbed yet; never log it above DEBUG.
        logger.debug(
            "THREAD_SUMMARY_DEBUG parse_success thread=%s patch=%s",
            thread_id,
            json.dumps(patch, ensure_ascii=False)[:2000],
        )

    merged = _merge_summary_patch(base, patch)
    cleaned = ThreadMemoryUpdater()._scrub_sensitive(merged, thread_id)
    cleaned["ownerId"] = base.get("ownerId")
    if logger.isEnabledFor(logging.DEBUG):
        facts = cleaned.get("facts", [])
        logger.debug(
            "THREAD_SUMMARY_DEBUG apply_cleaned thread=%s cleaned=%s",
            thread_id,
            json.dumps(
                {
                    "user": cleaned.get("user", {}),
                    "history": cleaned.get("history", {}),
                    "facts_count": len(facts if isinstance(facts, list) else []),
                },
                ensure_ascii=False,
            )[:2000],
        )

    if not storage.save(thread_id, cleaned, expected_version=expected_version):
        raise ThreadMemoryConflictError(f"Thread memory version conflict for {thread_id}")

    latest = storage.load(thread_id)
    return latest if latest is not None else {"threadId": thread_id, "memoryVersion": expected_version, **cleaned}
pytest + +from app.gateway.routers import threads + + +@pytest.mark.anyio +async def test_delete_thread_does_not_delete_thread_memory(): + request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(checkpointer=None, store=None))) + with ( + patch("app.gateway.routers.threads._delete_thread_data", return_value=threads.ThreadDeleteResponse(success=True, message="ok")), + patch("app.gateway.routers.threads.get_store", return_value=None), + patch("app.gateway.routers.threads.delete_thread_memory_data") as delete_memory, + ): + response = await threads.delete_thread_data("thread-1", request) + + assert response.success is True + delete_memory.assert_not_called() + + +@pytest.mark.anyio +async def test_delete_thread_memory_endpoint_calls_cleanup(): + with patch("app.gateway.routers.threads.delete_thread_memory_data") as delete_memory: + response = await threads.delete_thread_memory("thread-1") + + assert response.success is True + assert response.message == "Deleted thread memory for thread-1" + delete_memory.assert_called_once_with("thread-1") + diff --git a/backend/tests/test_thread_memory_summary.py b/backend/tests/test_thread_memory_summary.py new file mode 100644 index 00000000..fcf62a6d --- /dev/null +++ b/backend/tests/test_thread_memory_summary.py @@ -0,0 +1,103 @@ +from unittest.mock import patch + +import pytest + +from deerflow.agents.memory.thread_summary import ( + ThreadMemoryConflictError, + _extract_json_object, + apply_thread_memory_summary, + render_thread_memory_summary, +) + + +def test_render_thread_memory_summary_returns_text(): + fake_storage = type( + "S", + (), + {"load": lambda self, tid: {"threadId": tid, "user": {}, "history": {}, "facts": [], "memoryVersion": 2}}, + )() + fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "用户总结"})()})() + + with ( + patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=fake_storage), + 
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model), + ): + result = render_thread_memory_summary("t1") + + assert result["threadId"] == "t1" + assert result["memoryVersion"] == 2 + assert result["summary"] == "用户总结" + + +def test_apply_thread_memory_summary_raises_conflict_on_cas_failure(): + class _Storage: + def load(self, _tid): + return {"threadId": "t1", "ownerId": None, "user": {}, "history": {}, "facts": [], "memoryVersion": 1} + + def save(self, _tid, _data, expected_version=None): + return False + + fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "{}"})()})() + fake_updater = type("U", (), {"_scrub_sensitive": lambda self, data, _thread_id: data})() + + with ( + patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=_Storage()), + patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model), + patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=fake_updater), + ): + with pytest.raises(ThreadMemoryConflictError): + apply_thread_memory_summary("t1", "更新内容", 1) + + +def test_apply_thread_memory_summary_falls_back_when_model_output_is_not_json(): + class _Storage: + def __init__(self): + self.saved = None + + def load(self, _tid): + if self.saved is not None: + return {"threadId": "t1", "memoryVersion": 2, **self.saved} + return { + "threadId": "t1", + "ownerId": None, + "user": {"topOfMind": {"summary": ""}}, + "history": {}, + "facts": [], + "memoryVersion": 1, + } + + def save(self, _tid, data, expected_version=None): + self.saved = data + return True + + storage = _Storage() + fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "这是自然语言,不是JSON"})()})() + fake_updater = type("U", (), {"_scrub_sensitive": lambda self, data, _thread_id: data})() + + with ( + patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=storage), + 
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model), + patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=fake_updater), + ): + result = apply_thread_memory_summary("t1", "我最近在做线程记忆功能", 1) + + assert storage.saved is not None + assert storage.saved["user"]["topOfMind"]["summary"] == "我最近在做线程记忆功能" + assert result["user"]["topOfMind"]["summary"] == "我最近在做线程记忆功能" + + +def test_extract_json_object_repairs_inner_unescaped_quotes(): + raw = """ +{ + "user": { + "topOfMind": { + "summary": "反感“作为 AI"这种句式,认为回答不用寒暄直接说重点。" + } + }, + "history": {}, + "facts": [] +} +""".strip() + parsed = _extract_json_object(raw) + assert parsed is not None + assert parsed["user"]["topOfMind"]["summary"].startswith("反感“作为 AI") diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index d57ead93..a6cbdb3d 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -27,6 +27,7 @@ import { IframeTestPanel } from "@/components/workspace/iframe-test-panel"; import { InputBox } from "@/components/workspace/input-box"; import { MessageList } from "@/components/workspace/messages"; import { ThreadContext } from "@/components/workspace/messages/context"; +import { ThreadMemoryTestPanel } from "@/components/workspace/thread-memory-test-panel"; import { ThreadTitle } from "@/components/workspace/thread-title"; import { Tooltip } from "@/components/workspace/tooltip"; import { useSpecificChatMode } from "@/components/workspace/use-chat-mode"; @@ -705,6 +706,7 @@ export default function ChatPage() { {/* MARK: 开发测试:iframe 通信功能测试面板 */} {/* {process.env.NODE_ENV !== "production" && } */} + {/* */} ); diff --git a/frontend/src/components/workspace/thread-memory-test-panel.tsx b/frontend/src/components/workspace/thread-memory-test-panel.tsx new file mode 100644 index 00000000..ea3fd11e --- /dev/null +++ 
b/frontend/src/components/workspace/thread-memory-test-panel.tsx @@ -0,0 +1,142 @@ +"use client"; + +import { useState } from "react"; +import { toast } from "sonner"; + +import { Button } from "@/components/ui/button"; +import { Textarea } from "@/components/ui/textarea"; +import { getBackendBaseURL } from "@/core/config"; + +type ThreadMemoryTestPanelProps = { + threadId?: string; +}; + +export function ThreadMemoryTestPanel({ threadId }: ThreadMemoryTestPanelProps) { + const [memorySummary, setMemorySummary] = useState(""); + const [memoryVersion, setMemoryVersion] = useState(null); + const [loadingSummary, setLoadingSummary] = useState(false); + const [savingSummary, setSavingSummary] = useState(false); + const [deletingMemory, setDeletingMemory] = useState(false); + const [open, setOpen] = useState(true); + + if (!threadId || threadId === "new") return null; + + const handleLoadMemorySummary = async () => { + setLoadingSummary(true); + try { + const res = await fetch( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory-summary`, + ); + if (!res.ok) throw new Error(`HTTP ${res.status}`); + const data = (await res.json()) as { summary: string; memoryVersion: number }; + setMemorySummary(data.summary ?? ""); + setMemoryVersion(data.memoryVersion ?? 
0); + toast.success("已加载会话记忆"); + } catch { + toast.error("加载会话记忆失败"); + } finally { + setLoadingSummary(false); + } + }; + + const handleSaveMemorySummary = async () => { + if (memoryVersion == null) return; + setSavingSummary(true); + try { + const res = await fetch( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory-summary`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ summary: memorySummary, memoryVersion }), + }, + ); + if (res.status === 409) { + toast.error("记忆已更新,请先重新加载再保存"); + return; + } + if (!res.ok) throw new Error(`HTTP ${res.status}`); + const data = (await res.json()) as { memoryVersion?: number }; + if (typeof data.memoryVersion === "number") setMemoryVersion(data.memoryVersion); + toast.success("会话记忆已保存"); + } catch { + toast.error("保存会话记忆失败"); + } finally { + setSavingSummary(false); + } + }; + + const handleDeleteMemory = async () => { + setDeletingMemory(true); + try { + const res = await fetch( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory`, + { method: "DELETE" }, + ); + if (!res.ok) throw new Error(`HTTP ${res.status}`); + setMemorySummary(""); + setMemoryVersion(0); + toast.success("当前会话记忆已删除"); + } catch { + toast.error("删除会话记忆失败"); + } finally { + setDeletingMemory(false); + } + }; + + return ( + + + Thread Memory TestPanel + setOpen((v) => !v)}> + {open ? "收起" : "展开"} + + + {open && ( + + + { + void handleLoadMemorySummary(); + }} + disabled={loadingSummary} + > + {loadingSummary ? "加载中..." : "查看记忆"} + + { + void handleSaveMemorySummary(); + }} + disabled={savingSummary || memoryVersion == null} + > + {savingSummary ? "保存中..." : "保存记忆"} + + { + void handleDeleteMemory(); + }} + disabled={deletingMemory} + > + {deletingMemory ? "删除中..." : "删除记忆"} + + + + threadId: {threadId.slice(0, 8)}... | version:{" "} + {memoryVersion == null ? 
"-" : memoryVersion} + + setMemorySummary(e.target.value)} + placeholder="这里显示会话记忆总结,可编辑后保存" + className="min-h-32 bg-white/80" + /> + + )} + + ); +}