From b49e838980cae8731329e80a7e7561d01070753e Mon Sep 17 00:00:00 2001 From: MT-Mint <798521692@qq.com> Date: Fri, 8 May 2026 10:19:09 +0800 Subject: [PATCH] =?UTF-8?q?feat:json=E4=BC=9A=E8=AF=9D=E8=AE=B0=E5=BF=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/gateway/routers/threads.py | 5 + .../deerflow/agents/lead_agent/prompt.py | 25 + .../deerflow/agents/memory/thread_prompt.py | 129 +++ .../deerflow/agents/memory/thread_queue.py | 76 ++ .../deerflow/agents/memory/thread_storage.py | 246 ++++++ .../deerflow/agents/memory/thread_updater.py | 132 +++ .../agents/middlewares/memory_middleware.py | 29 +- .../harness/deerflow/config/__init__.py | 3 + .../harness/deerflow/config/app_config.py | 4 + .../deerflow/config/thread_memory_config.py | 50 ++ .../harness/deerflow/models/factory.py | 26 +- .../tests/test_thread_memory_middleware.py | 32 + backend/tests/test_thread_memory_prompt.py | 28 + backend/tests/test_thread_memory_storage.py | 29 + docs/per-thread-memory-design-brainstorm.md | 760 ++++++++++++++++++ docs/thread-memory-manual-test-checklist.md | 213 +++++ 16 files changed, 1767 insertions(+), 20 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/memory/thread_prompt.py create mode 100644 backend/packages/harness/deerflow/agents/memory/thread_queue.py create mode 100644 backend/packages/harness/deerflow/agents/memory/thread_storage.py create mode 100644 backend/packages/harness/deerflow/agents/memory/thread_updater.py create mode 100644 backend/packages/harness/deerflow/config/thread_memory_config.py create mode 100644 backend/tests/test_thread_memory_middleware.py create mode 100644 backend/tests/test_thread_memory_prompt.py create mode 100644 backend/tests/test_thread_memory_storage.py create mode 100644 docs/per-thread-memory-design-brainstorm.md create mode 100644 docs/thread-memory-manual-test-checklist.md diff --git a/backend/app/gateway/routers/threads.py 
b/backend/app/gateway/routers/threads.py index 80860498..dc891a65 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -22,6 +22,7 @@ from pydantic import BaseModel, Field from app.gateway.deps import get_checkpointer, get_store from deerflow.config.paths import Paths, get_paths +from deerflow.agents.memory.thread_storage import delete_thread_memory_data from deerflow.runtime import serialize_channel_values # --------------------------------------------------------------------------- @@ -240,6 +241,10 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe await checkpointer.adelete_thread(thread_id) except Exception: logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id) + try: + delete_thread_memory_data(thread_id) + except Exception: + logger.debug("Could not delete thread memory for thread %s (not critical)", thread_id) return response diff --git a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py index f08ecd9b..8e473c2b 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py @@ -391,9 +391,34 @@ def _get_memory_context(agent_name: str | None = None) -> str: """ try: from deerflow.agents.memory import format_memory_for_injection, get_memory_data + from deerflow.agents.memory.thread_prompt import format_thread_memory_for_injection + from deerflow.agents.memory.thread_storage import get_thread_memory_data from deerflow.config.memory_config import get_memory_config + from deerflow.config.thread_memory_config import get_thread_memory_config + from langgraph.config import get_config config = get_memory_config() + thread_config = get_thread_memory_config() + config_data = get_config() + thread_id = config_data.get("configurable", {}).get("thread_id") + + if thread_config.enabled and 
thread_config.injection_enabled and thread_id: + thread_memory = get_thread_memory_data(thread_id) + if thread_memory is not None: + thread_content = format_thread_memory_for_injection( + { + "profile": thread_memory.get("profile", {}), + "preferences": thread_memory.get("preferences", {}), + "facts": thread_memory.get("facts", []), + }, + max_tokens=thread_config.max_injection_tokens, + ) + if thread_content.strip(): + return f""" +{thread_content} + +""" + if not config.enabled or not config.injection_enabled: return "" diff --git a/backend/packages/harness/deerflow/agents/memory/thread_prompt.py b/backend/packages/harness/deerflow/agents/memory/thread_prompt.py new file mode 100644 index 00000000..834d1d54 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/memory/thread_prompt.py @@ -0,0 +1,129 @@ +"""Prompt and formatting helpers for per-thread memory.""" + +from __future__ import annotations + +import json +from typing import Any + +from deerflow.agents.memory.prompt import _coerce_confidence, _count_tokens, format_conversation_for_update + +THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system. + +Current per-thread memory: + +{existing_memory} + + +Conversation: + +{conversation} + + +Return JSON only with this schema: +{{ + "profile": {{ + "name": string|null, + "role": string|null, + "expertise": string[], + "language": "zh-CN"|"en-US"|null, + "context": string|null + }}, + "preferences": {{ + "tone": "casual"|"formal"|"technical"|"friendly"|null, + "verbosity": "concise"|"detailed"|null, + "codeStyle": string|null, + "other": string|null + }}, + "facts": [ + {{ + "content": string, + "category": "tech_stack"|"preference"|"personal"|"context"|"goal", + "confidence": number + }} + ] +}} + +Rules: +- Keep only stable and useful user profile facts. +- Do not store sensitive personal data (phone/email/address/password/token/id/bank). +- Deduplicate and keep high-confidence facts. +- Return valid JSON only. 
+""" + + +def create_empty_thread_memory() -> dict[str, Any]: + return { + "profile": {"name": None, "role": None, "expertise": [], "language": None, "context": None}, + "preferences": {"tone": None, "verbosity": None, "codeStyle": None, "other": None}, + "facts": [], + } + + +def format_thread_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str: + if not memory_data: + return "" + + profile = memory_data.get("profile") or {} + preferences = memory_data.get("preferences") or {} + facts = memory_data.get("facts") or [] + + profile_lines: list[str] = [] + for key, label in (("name", "Name"), ("role", "Role"), ("language", "Language"), ("context", "Context")): + value = profile.get(key) + if isinstance(value, str) and value.strip(): + profile_lines.append(f"- {label}: {value.strip()}") + expertise = profile.get("expertise") + if isinstance(expertise, list): + cleaned = [str(item).strip() for item in expertise if str(item).strip()] + if cleaned: + profile_lines.append(f"- Expertise: {', '.join(cleaned)}") + + pref_lines: list[str] = [] + for key, label in (("tone", "Tone"), ("verbosity", "Verbosity"), ("codeStyle", "Code Style"), ("other", "Other")): + value = preferences.get(key) + if isinstance(value, str) and value.strip(): + pref_lines.append(f"- {label}: {value.strip()}") + + sections: list[str] = [] + if profile_lines: + sections.append("Profile:\n" + "\n".join(profile_lines)) + if pref_lines: + sections.append("Preferences:\n" + "\n".join(pref_lines)) + + # Facts are lowest priority: include by confidence/recency and trim by token budget. 
+ ranked_facts = sorted( + ( + f + for f in facts + if isinstance(f, dict) and isinstance(f.get("content"), str) and f.get("content", "").strip() + ), + key=lambda f: (_coerce_confidence(f.get("confidence"), default=0.0), str(f.get("createdAt", ""))), + reverse=True, + ) + base = "\n\n".join(sections) + running = _count_tokens(base) if base else 0 + fact_lines: list[str] = [] + if ranked_facts: + running += _count_tokens("\n\nFacts:\n" if base else "Facts:\n") + for fact in ranked_facts: + line = ( + f"- [{str(fact.get('category', 'context')).strip() or 'context'} | " + f"{_coerce_confidence(fact.get('confidence'), default=0.0):.2f}] {fact.get('content').strip()}" + ) + candidate = ("\n" + line) if fact_lines else line + cost = _count_tokens(candidate) + if running + cost > max_tokens: + break + fact_lines.append(line) + running += cost + if fact_lines: + sections.append("Facts:\n" + "\n".join(fact_lines)) + + return "\n\n".join(sections) + + +def build_thread_memory_prompt(existing_memory: dict[str, Any], messages: list[Any]) -> str: + return THREAD_MEMORY_UPDATE_PROMPT.format( + existing_memory=json.dumps(existing_memory, ensure_ascii=False, indent=2), + conversation=format_conversation_for_update(messages), + ) diff --git a/backend/packages/harness/deerflow/agents/memory/thread_queue.py b/backend/packages/harness/deerflow/agents/memory/thread_queue.py new file mode 100644 index 00000000..f09b17ca --- /dev/null +++ b/backend/packages/harness/deerflow/agents/memory/thread_queue.py @@ -0,0 +1,76 @@ +"""Debounced queue for per-thread memory updates.""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from deerflow.config.thread_memory_config import get_thread_memory_config + + +@dataclass +class ThreadConversationContext: + thread_id: str + messages: list[Any] + timestamp: datetime = field(default_factory=datetime.utcnow) + + +class ThreadMemoryUpdateQueue: + 
def __init__(self): + self._queue: list[ThreadConversationContext] = [] + self._lock = threading.Lock() + self._timer: threading.Timer | None = None + self._processing = False + + def add(self, thread_id: str, messages: list[Any]) -> None: + config = get_thread_memory_config() + if not config.enabled: + return + with self._lock: + self._queue = [c for c in self._queue if c.thread_id != thread_id] + self._queue.append(ThreadConversationContext(thread_id=thread_id, messages=messages)) + self._reset_timer() + + def _reset_timer(self) -> None: + config = get_thread_memory_config() + if self._timer is not None: + self._timer.cancel() + self._timer = threading.Timer(config.debounce_seconds, self._process_queue) + self._timer.daemon = True + self._timer.start() + + def _process_queue(self) -> None: + from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater + + with self._lock: + if self._processing: + self._reset_timer() + return + if not self._queue: + return + self._processing = True + contexts = self._queue.copy() + self._queue.clear() + self._timer = None + + try: + updater = ThreadMemoryUpdater() + for context in contexts: + updater.update_memory(context.messages, context.thread_id) + finally: + with self._lock: + self._processing = False + + +_thread_queue: ThreadMemoryUpdateQueue | None = None +_lock = threading.Lock() + + +def get_thread_memory_queue() -> ThreadMemoryUpdateQueue: + global _thread_queue + with _lock: + if _thread_queue is None: + _thread_queue = ThreadMemoryUpdateQueue() + return _thread_queue diff --git a/backend/packages/harness/deerflow/agents/memory/thread_storage.py b/backend/packages/harness/deerflow/agents/memory/thread_storage.py new file mode 100644 index 00000000..d4a670b6 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/memory/thread_storage.py @@ -0,0 +1,246 @@ +"""Storage providers for per-thread memory.""" + +from __future__ import annotations + +import abc +import json +import logging +import sqlite3 +import 
threading +from datetime import datetime +from pathlib import Path +from typing import Any + +from deerflow.agents.memory.thread_prompt import create_empty_thread_memory +from deerflow.config.paths import get_paths +from deerflow.config.thread_memory_config import get_thread_memory_config + +logger = logging.getLogger(__name__) + + +class ThreadMemoryStorage(abc.ABC): + @abc.abstractmethod + def load(self, thread_id: str) -> dict[str, Any] | None: + pass + + @abc.abstractmethod + def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool: + pass + + @abc.abstractmethod + def delete(self, thread_id: str) -> bool: + pass + + +def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]: + return { + "threadId": row[0], + "ownerId": row[1], + "profile": json.loads(row[2]), + "preferences": json.loads(row[3]), + "facts": json.loads(row[4]), + "memoryVersion": int(row[5]), + "lastUpdated": str(row[6]), + } + + +class SqliteThreadMemoryStorage(ThreadMemoryStorage): + def __init__(self, db_path: str): + path = Path(db_path) + if not path.is_absolute(): + path = get_paths().base_dir / path + path.parent.mkdir(parents=True, exist_ok=True) + self._conn = sqlite3.connect(str(path), check_same_thread=False) + self._lock = threading.Lock() + with self._lock: + self._conn.execute( + """ + CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id TEXT PRIMARY KEY, + owner_id TEXT NULL, + profile TEXT NOT NULL DEFAULT '{}', + preferences TEXT NOT NULL DEFAULT '{}', + facts TEXT NOT NULL DEFAULT '[]', + memory_version INTEGER NOT NULL DEFAULT 0, + last_updated TEXT NOT NULL DEFAULT (datetime('now')) + ) + """ + ) + self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)") + self._conn.commit() + + def load(self, thread_id: str) -> dict[str, Any] | None: + with self._lock: + row = self._conn.execute( + "SELECT thread_id, owner_id, profile, preferences, facts, memory_version, last_updated " + "FROM 
thread_memory WHERE thread_id = ?", + (thread_id,), + ).fetchone() + return _row_to_memory(row) if row else None + + def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool: + now = datetime.utcnow().isoformat() + "Z" + owner_id = data.get("ownerId") + if expected_version is None: + expected_version = 0 + with self._lock: + cur = self._conn.execute( + """ + INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version, last_updated) + VALUES (?, ?, ?, ?, ?, 0, ?) + ON CONFLICT(thread_id) DO NOTHING + """, + ( + thread_id, + owner_id, + json.dumps(data.get("profile", {}), ensure_ascii=False), + json.dumps(data.get("preferences", {}), ensure_ascii=False), + json.dumps(data.get("facts", []), ensure_ascii=False), + now, + ), + ) + if cur.rowcount == 1: + self._conn.commit() + return True + + cur = self._conn.execute( + """ + UPDATE thread_memory + SET owner_id = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ? + WHERE thread_id = ? AND memory_version = ? 
+ """, + ( + owner_id, + json.dumps(data.get("profile", {}), ensure_ascii=False), + json.dumps(data.get("preferences", {}), ensure_ascii=False), + json.dumps(data.get("facts", []), ensure_ascii=False), + now, + thread_id, + expected_version, + ), + ) + self._conn.commit() + return cur.rowcount == 1 + + def delete(self, thread_id: str) -> bool: + with self._lock: + self._conn.execute("DELETE FROM thread_memory WHERE thread_id = ?", (thread_id,)) + self._conn.commit() + return True + + +class MysqlThreadMemoryStorage(ThreadMemoryStorage): + def __init__(self, host: str, port: int, user: str, password: str, database: str): + import pymysql + + self._conn = pymysql.connect(host=host, port=port, user=user, password=password, database=database, charset="utf8mb4") + with self._conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id VARCHAR(64) PRIMARY KEY, + owner_id VARCHAR(64) NULL, + profile JSON NOT NULL, + preferences JSON NOT NULL, + facts JSON NOT NULL, + memory_version INT NOT NULL DEFAULT 0, + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + INDEX idx_owner_id (owner_id) + ) + """ + ) + self._conn.commit() + + def load(self, thread_id: str) -> dict[str, Any] | None: + with self._conn.cursor() as cur: + cur.execute( + "SELECT thread_id, owner_id, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s", + (thread_id,), + ) + row = cur.fetchone() + return _row_to_memory(row) if row else None + + def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool: + if expected_version is None: + expected_version = 0 + owner_id = data.get("ownerId") + with self._conn.cursor() as cur: + cur.execute( + """ + INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version) + VALUES (%s, %s, %s, %s, %s, 0) + ON DUPLICATE KEY UPDATE thread_id = thread_id + """, + ( + thread_id, + owner_id, 
+ json.dumps(data.get("profile", {}), ensure_ascii=False), + json.dumps(data.get("preferences", {}), ensure_ascii=False), + json.dumps(data.get("facts", []), ensure_ascii=False), + ), + ) + if cur.rowcount == 1: + self._conn.commit() + return True + cur.execute( + """ + UPDATE thread_memory + SET owner_id = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1 + WHERE thread_id = %s AND memory_version = %s + """, + ( + owner_id, + json.dumps(data.get("profile", {}), ensure_ascii=False), + json.dumps(data.get("preferences", {}), ensure_ascii=False), + json.dumps(data.get("facts", []), ensure_ascii=False), + thread_id, + expected_version, + ), + ) + self._conn.commit() + return cur.rowcount == 1 + + def delete(self, thread_id: str) -> bool: + with self._conn.cursor() as cur: + cur.execute("DELETE FROM thread_memory WHERE thread_id = %s", (thread_id,)) + self._conn.commit() + return True + + +_thread_storage: ThreadMemoryStorage | None = None +_thread_storage_lock = threading.Lock() + + +def get_thread_memory_storage() -> ThreadMemoryStorage: + global _thread_storage + if _thread_storage is not None: + return _thread_storage + + with _thread_storage_lock: + if _thread_storage is not None: + return _thread_storage + config = get_thread_memory_config() + if config.database.type == "mysql": + mysql = config.database.mysql + _thread_storage = MysqlThreadMemoryStorage( + host=mysql.host, + port=mysql.port, + user=mysql.user, + password=mysql.password, + database=mysql.database, + ) + else: + _thread_storage = SqliteThreadMemoryStorage(config.database.sqlite.path) + return _thread_storage + + +def get_thread_memory_data(thread_id: str) -> dict[str, Any] | None: + return get_thread_memory_storage().load(thread_id) + + +def delete_thread_memory_data(thread_id: str) -> bool: + return get_thread_memory_storage().delete(thread_id) + + +def initial_thread_memory_record() -> dict[str, Any]: + return {"ownerId": None, **create_empty_thread_memory()} 
diff --git a/backend/packages/harness/deerflow/agents/memory/thread_updater.py b/backend/packages/harness/deerflow/agents/memory/thread_updater.py new file mode 100644 index 00000000..8f8c0090 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/memory/thread_updater.py @@ -0,0 +1,132 @@ +"""Per-thread memory updater.""" + +from __future__ import annotations + +import json +import logging +import re +import uuid +from datetime import datetime +from typing import Any + +from deerflow.agents.memory.updater import _extract_text +from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, create_empty_thread_memory +from deerflow.agents.memory.thread_storage import get_thread_memory_storage +from deerflow.config.thread_memory_config import get_thread_memory_config +from deerflow.models import create_chat_model + +logger = logging.getLogger(__name__) + +_SENSITIVE_PATTERNS = ( + re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"), + re.compile(r"\b(?:\+?\d[\d -]{7,}\d)\b"), + re.compile(r"\b(?:api[_-]?key|token|password|passwd|secret)\b", re.IGNORECASE), + re.compile(r"\b\d{15,19}\b"), # bank-card like +) + + +class ThreadMemoryUpdater: + def __init__(self, model_name: str | None = None): + self._model_name = model_name + + def _get_model(self): + config = get_thread_memory_config() + # Non-stream invoke path: some OpenAI-compatible gateways reject + # stream_options when stream=false, so force stream_usage off here. 
+ return create_chat_model( + name=self._model_name or config.model_name, + thinking_enabled=False, + stream_usage=False, + ) + + def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]: + def safe_text(val: Any) -> str | None: + if not isinstance(val, str): + return None + text = val.strip() + if not text: + return None + if any(p.search(text) for p in _SENSITIVE_PATTERNS): + logger.info("thread_memory sensitive value dropped for thread=%s", thread_id) + return None + return text + + profile = data.get("profile", {}) + preferences = data.get("preferences", {}) + facts = data.get("facts", []) + cleaned = create_empty_thread_memory() + + cleaned["profile"]["name"] = safe_text(profile.get("name")) + cleaned["profile"]["role"] = safe_text(profile.get("role")) + cleaned["profile"]["language"] = safe_text(profile.get("language")) + cleaned["profile"]["context"] = safe_text(profile.get("context")) + expertise = profile.get("expertise") + if isinstance(expertise, list): + cleaned["profile"]["expertise"] = [x for x in (safe_text(item) for item in expertise) if x] + + cleaned["preferences"]["tone"] = safe_text(preferences.get("tone")) + cleaned["preferences"]["verbosity"] = safe_text(preferences.get("verbosity")) + cleaned["preferences"]["codeStyle"] = safe_text(preferences.get("codeStyle")) + cleaned["preferences"]["other"] = safe_text(preferences.get("other")) + + seen: set[str] = set() + for fact in facts if isinstance(facts, list) else []: + if not isinstance(fact, dict): + continue + content = safe_text(fact.get("content")) + if not content: + continue + key = content.casefold() + if key in seen: + continue + seen.add(key) + confidence = float(fact.get("confidence", 0.5)) + cleaned["facts"].append( + { + "id": f"fact_{uuid.uuid4().hex[:8]}", + "content": content, + "category": str(fact.get("category", "context")).strip() or "context", + "confidence": max(0.0, min(1.0, confidence)), + "createdAt": datetime.utcnow().isoformat() + "Z", + 
"source": thread_id, + } + ) + return cleaned + + def update_memory(self, messages: list[Any], thread_id: str) -> bool: + config = get_thread_memory_config() + if not config.enabled or not messages or not thread_id: + return False + + storage = get_thread_memory_storage() + current = storage.load(thread_id) + base_memory = create_empty_thread_memory() if current is None else { + "profile": current.get("profile", {}), + "preferences": current.get("preferences", {}), + "facts": current.get("facts", []), + } + prompt = build_thread_memory_prompt(base_memory, messages) + if not prompt.strip(): + return False + + try: + response = self._get_model().invoke(prompt) + response_text = _extract_text(response.content).strip() + if response_text.startswith("```"): + lines = response_text.split("\n") + response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) + parsed = json.loads(response_text) + cleaned = self._scrub_sensitive(parsed, thread_id) + + expected_version = 0 if current is None else int(current.get("memoryVersion", 0)) + if storage.save(thread_id, cleaned, expected_version=expected_version): + return True + + # conflict retry once + latest = storage.load(thread_id) + latest_version = 0 if latest is None else int(latest.get("memoryVersion", 0)) + logger.info("thread_memory conflict detected, retrying once: thread=%s version=%s", thread_id, latest_version) + return storage.save(thread_id, cleaned, expected_version=latest_version) + except Exception: + logger.exception("Thread memory update failed for thread=%s", thread_id) + return False diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index a0fc7c60..033525fd 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -10,7 +10,9 @@ from langgraph.config import get_config from 
langgraph.runtime import Runtime from deerflow.agents.memory.queue import get_memory_queue +from deerflow.agents.memory.thread_queue import get_thread_memory_queue from deerflow.config.memory_config import get_memory_config +from deerflow.config.thread_memory_config import get_thread_memory_config logger = logging.getLogger(__name__) @@ -206,8 +208,9 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): Returns: None (no state changes needed from this middleware). """ - config = get_memory_config() - if not config.enabled: + global_config = get_memory_config() + thread_config = get_thread_memory_config() + if not global_config.enabled and not thread_config.enabled: return None # Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata @@ -239,13 +242,19 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): # Queue the filtered conversation for memory update correction_detected = detect_correction(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) - queue = get_memory_queue() - queue.add( - thread_id=thread_id, - messages=filtered_messages, - agent_name=self._agent_name, - correction_detected=correction_detected, - reinforcement_detected=reinforcement_detected, - ) + if global_config.enabled: + queue = get_memory_queue() + queue.add( + thread_id=thread_id, + messages=filtered_messages, + agent_name=self._agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + ) + if thread_config.enabled: + get_thread_memory_queue().add( + thread_id=thread_id, + messages=filtered_messages, + ) return None diff --git a/backend/packages/harness/deerflow/config/__init__.py b/backend/packages/harness/deerflow/config/__init__.py index c41be373..b244b060 100644 --- a/backend/packages/harness/deerflow/config/__init__.py +++ b/backend/packages/harness/deerflow/config/__init__.py @@ -2,6 +2,7 @@ from .app_config import 
get_app_config from .billing_config import BillingConfig from .extensions_config import ExtensionsConfig, get_extensions_config from .memory_config import MemoryConfig, get_memory_config +from .thread_memory_config import ThreadMemoryConfig, get_thread_memory_config from .paths import Paths, get_paths from .skills_config import SkillsConfig from .tracing_config import ( @@ -22,6 +23,8 @@ __all__ = [ "get_extensions_config", "MemoryConfig", "get_memory_config", + "ThreadMemoryConfig", + "get_thread_memory_config", "get_tracing_config", "get_explicitly_enabled_tracing_providers", "get_enabled_tracing_providers", diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index 228c9b14..0973d8ef 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -25,6 +25,7 @@ from deerflow.config.title_config import TitleConfig, load_title_config_from_dic from deerflow.config.token_usage_config import TokenUsageConfig from deerflow.config.tool_config import ToolConfig, ToolGroupConfig from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict +from deerflow.config.thread_memory_config import ThreadMemoryConfig, load_thread_memory_config_from_dict load_dotenv() @@ -55,6 +56,7 @@ class AppConfig(BaseModel): title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration") summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration") memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration") + thread_memory: ThreadMemoryConfig = Field(default_factory=ThreadMemoryConfig, description="Per-thread memory subsystem configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime 
configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") model_config = ConfigDict(extra="allow", frozen=False) @@ -118,6 +120,8 @@ class AppConfig(BaseModel): # Load memory config if present if "memory" in config_data: load_memory_config_from_dict(config_data["memory"]) + if "thread_memory" in config_data: + load_thread_memory_config_from_dict(config_data["thread_memory"]) # Load subagents config if present if "subagents" in config_data: diff --git a/backend/packages/harness/deerflow/config/thread_memory_config.py b/backend/packages/harness/deerflow/config/thread_memory_config.py new file mode 100644 index 00000000..a335ce35 --- /dev/null +++ b/backend/packages/harness/deerflow/config/thread_memory_config.py @@ -0,0 +1,50 @@ +"""Configuration for per-thread memory mechanism.""" + +from pydantic import BaseModel, Field + + +class ThreadMemorySqliteConfig(BaseModel): + path: str = Field(default="thread_memory.db", description="SQLite database file path") + + +class ThreadMemoryMysqlConfig(BaseModel): + host: str = Field(default="localhost") + port: int = Field(default=3306) + user: str = Field(default="root") + password: str = Field(default="") + database: str = Field(default="deerflow") + + +class ThreadMemoryDatabaseConfig(BaseModel): + type: str = Field(default="sqlite", description="Database type: sqlite or mysql") + sqlite: ThreadMemorySqliteConfig = Field(default_factory=ThreadMemorySqliteConfig) + mysql: ThreadMemoryMysqlConfig = Field(default_factory=ThreadMemoryMysqlConfig) + + +class ThreadMemoryConfig(BaseModel): + enabled: bool = Field(default=True) + debounce_seconds: int = Field(default=30, ge=1, le=300) + model_name: str | None = Field(default=None) + max_facts: int = Field(default=100, ge=10, le=500) + fact_confidence_threshold: float = Field(default=0.7, ge=0.0, le=1.0) + injection_enabled: bool = Field(default=True) + max_injection_tokens: int = Field(default=2000, 
ge=100, le=8000) + bootstrap_from_global: bool = Field(default=False) + database: ThreadMemoryDatabaseConfig = Field(default_factory=ThreadMemoryDatabaseConfig) + + +_thread_memory_config: ThreadMemoryConfig = ThreadMemoryConfig() + + +def get_thread_memory_config() -> ThreadMemoryConfig: + return _thread_memory_config + + +def set_thread_memory_config(config: ThreadMemoryConfig) -> None: + global _thread_memory_config + _thread_memory_config = config + + +def load_thread_memory_config_from_dict(config_dict: dict) -> None: + global _thread_memory_config + _thread_memory_config = ThreadMemoryConfig(**config_dict) diff --git a/backend/packages/harness/deerflow/models/factory.py b/backend/packages/harness/deerflow/models/factory.py index b17f4577..46510761 100644 --- a/backend/packages/harness/deerflow/models/factory.py +++ b/backend/packages/harness/deerflow/models/factory.py @@ -88,18 +88,24 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, * if not has_stream_usage: model_settings_from_config["stream_usage"] = True + effective_stream_usage = kwargs.get("stream_usage", model_settings_from_config.get("stream_usage")) # Some OpenAI-compatible providers only return usage in streaming mode # when stream_options.include_usage is explicitly enabled. 
- stream_options_source = "kwargs" if "stream_options" in kwargs else "config" - stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options") - if stream_options is None: - model_settings_from_config["stream_options"] = {"include_usage": True} - elif isinstance(stream_options, dict) and "include_usage" not in stream_options: - patched_stream_options = {**stream_options, "include_usage": True} - if stream_options_source == "kwargs": - kwargs["stream_options"] = patched_stream_options - else: - model_settings_from_config["stream_options"] = patched_stream_options + if effective_stream_usage: + stream_options_source = "kwargs" if "stream_options" in kwargs else "config" + stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options") + if stream_options is None: + model_settings_from_config["stream_options"] = {"include_usage": True} + elif isinstance(stream_options, dict) and "include_usage" not in stream_options: + patched_stream_options = {**stream_options, "include_usage": True} + if stream_options_source == "kwargs": + kwargs["stream_options"] = patched_stream_options + else: + model_settings_from_config["stream_options"] = patched_stream_options + else: + # Some OpenAI-compatible endpoints reject stream_options when stream is false. + model_settings_from_config.pop("stream_options", None) + kwargs.pop("stream_options", None) except Exception: # Keep model creation robust when langchain_openai isn't available. 
pass diff --git a/backend/tests/test_thread_memory_middleware.py b/backend/tests/test_thread_memory_middleware.py new file mode 100644 index 00000000..d766b75d --- /dev/null +++ b/backend/tests/test_thread_memory_middleware.py @@ -0,0 +1,32 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage, HumanMessage + +from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware +from deerflow.config.memory_config import MemoryConfig +from deerflow.config.thread_memory_config import ThreadMemoryConfig + + +def test_thread_memory_queue_runs_even_if_global_memory_disabled(): + middleware = MemoryMiddleware() + state = {"messages": [HumanMessage(content="My name is Alice"), AIMessage(content="Nice to meet you")]} + runtime = SimpleNamespace(context={"thread_id": "thread-test"}) + + mock_global_queue = MagicMock() + mock_thread_queue = MagicMock() + + with ( + patch("deerflow.agents.middlewares.memory_middleware.get_memory_config", return_value=MemoryConfig(enabled=False)), + patch( + "deerflow.agents.middlewares.memory_middleware.get_thread_memory_config", + return_value=ThreadMemoryConfig(enabled=True), + ), + patch("deerflow.agents.middlewares.memory_middleware.get_memory_queue", return_value=mock_global_queue), + patch("deerflow.agents.middlewares.memory_middleware.get_thread_memory_queue", return_value=mock_thread_queue), + ): + middleware.after_agent(state, runtime) + + mock_global_queue.add.assert_not_called() + mock_thread_queue.add.assert_called_once() + diff --git a/backend/tests/test_thread_memory_prompt.py b/backend/tests/test_thread_memory_prompt.py new file mode 100644 index 00000000..b66eba07 --- /dev/null +++ b/backend/tests/test_thread_memory_prompt.py @@ -0,0 +1,28 @@ +from langchain_core.messages import HumanMessage + +from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, format_thread_memory_for_injection + + +def 
test_thread_memory_injection_keeps_profile_and_preferences_under_small_budget(monkeypatch): + monkeypatch.setattr("deerflow.agents.memory.thread_prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text)) + memory = { + "profile": {"name": "Alice", "role": "Engineer", "expertise": ["Python", "React"], "language": "en-US", "context": "Building APIs"}, + "preferences": {"tone": "technical", "verbosity": "concise", "codeStyle": "typed-first", "other": "tests first"}, + "facts": [ + {"content": "Fact one that might be trimmed", "category": "context", "confidence": 0.9}, + {"content": "Fact two that might be trimmed", "category": "context", "confidence": 0.8}, + ], + } + + result = format_thread_memory_for_injection(memory, max_tokens=140) + assert "Profile:" in result + assert "Preferences:" in result + + +def test_build_thread_memory_prompt_does_not_raise_format_key_error(): + prompt = build_thread_memory_prompt( + {"profile": {}, "preferences": {}, "facts": []}, + [HumanMessage(content="My name is Alice.")], + ) + assert "Current per-thread memory" in prompt + assert '"profile"' in prompt diff --git a/backend/tests/test_thread_memory_storage.py b/backend/tests/test_thread_memory_storage.py new file mode 100644 index 00000000..de3b7796 --- /dev/null +++ b/backend/tests/test_thread_memory_storage.py @@ -0,0 +1,29 @@ +from deerflow.agents.memory.thread_storage import SqliteThreadMemoryStorage + + +def _payload(): + return { + "ownerId": None, + "profile": {"name": "A", "role": None, "expertise": [], "language": None, "context": None}, + "preferences": {"tone": None, "verbosity": None, "codeStyle": None, "other": None}, + "facts": [], + } + + +def test_sqlite_thread_memory_compare_and_swap(tmp_path): + storage = SqliteThreadMemoryStorage(str(tmp_path / "thread-memory.db")) + thread_id = "thread-1" + + assert storage.save(thread_id, _payload(), expected_version=0) is True + loaded = storage.load(thread_id) + assert loaded is not None + assert 
loaded["memoryVersion"] == 0 + + # wrong expected version should fail + assert storage.save(thread_id, _payload(), expected_version=9) is False + # correct version should pass and increment + assert storage.save(thread_id, _payload(), expected_version=0) is True + loaded2 = storage.load(thread_id) + assert loaded2 is not None + assert loaded2["memoryVersion"] == 1 + diff --git a/docs/per-thread-memory-design-brainstorm.md b/docs/per-thread-memory-design-brainstorm.md new file mode 100644 index 00000000..71df0c8c --- /dev/null +++ b/docs/per-thread-memory-design-brainstorm.md @@ -0,0 +1,760 @@ +# Per-Thread Memory Brainstorm + +Date: 2026-05-07 + +## Background + +Deerflow 现有的记忆功能是单租户的——不同会话都属于同一个用户,所有对话共享一份全局 `memory.json`。 + +要做一个新的记忆功能:不同对话属于不同用户,每个会话都有一个长期记忆,内容包括用户的使用习惯、个人信息、个人喜好和偏好语气。 + +## 现有记忆系统 + +- **存储**:单一全局 `backend/.deer-flow/memory.json`,所有会话共享 +- **认证**:没有用户认证,没有用户隔离(better-auth 已搭建但未启用) +- **结构**: + - `user`: workContext / personalContext / topOfMind + - `history`: recentMonths / earlierContext / longTermBackground + - `facts[]`: id, content, category, confidence, source +- **读路径**:system prompt 生成时注入 `<memory>...</memory>` XML 标签 +- **写路径**:MemoryMiddleware 在对话后过滤消息 → MemoryUpdateQueue debounce 30s → MemoryUpdater 调 LLM 提取更新 → 原子写入 +- **配置**:`config.yaml > memory`(enabled, debounce_seconds, max_facts, max_injection_tokens 等) + +--- + +## 决策记录 + +### 存储方式: 数据库 + +~~文件存储 `threads/{thread_id}/profile-memory.json`~~ → **改为数据库表**,通过 `thread_id` 区分用户。 + +### 数据库: SQLite(本地/测试) + MySQL(生产环境) + +### 表结构: 单表 + JSON 列(Option A) + +### 依赖: 最小化,不引入 SQLAlchemy + +SQLite 用标准库 `sqlite3`,MySQL 用 `pymysql`(纯 Python,轻量)。 + +### 与全局记忆关系: 策略 B(fallback) + +Per-thread 有记忆就用 per-thread 的,没有就 fallback 到全局记忆。 + +### 首次对话: 不主动询问用户偏好 + +--- + +## 1. 
数据库表设计 + +```sql +-- SQLite +CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id TEXT PRIMARY KEY, + profile TEXT NOT NULL DEFAULT '{}', + preferences TEXT NOT NULL DEFAULT '{}', + facts TEXT NOT NULL DEFAULT '[]', + last_updated TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- MySQL +CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id VARCHAR(64) PRIMARY KEY, + profile JSON NOT NULL, + preferences JSON NOT NULL, + facts JSON NOT NULL, + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +); +``` + +**profile** ({}): + +| 字段 | 类型 | 说明 | +|------|------|------| +| `name` | `string \| null` | 用户称呼 | +| `role` | `string \| null` | 职业/角色 | +| `expertise` | `string[]` | 技术栈/专业领域 | +| `language` | `"zh-CN" \| "en-US" \| null` | 使用的语言 | +| `context` | `string \| null` | 其他上下文(自由文本) | + +**preferences** ({}): + +| 字段 | 类型 | 说明 | +|------|------|------| +| `tone` | `"casual" \| "formal" \| "technical" \| "friendly" \| null` | 语气偏好 | +| `verbosity` | `"concise" \| "detailed" \| null` | 回答详细程度 | +| `codeStyle` | `string \| null` | 代码风格偏好 | +| `other` | `string \| null` | 其他偏好(自由文本) | + +**facts** ([]):复用现有全局记忆的 fact 结构 + +```json +{ + "id": "fact_abc123", + "content": "用户在使用 React + TypeScript", + "category": "tech_stack | preference | personal | context | goal", + "confidence": 0.9, + "createdAt": "2026-05-07T...", + "source": "thread_id" +} +``` + +**说明**:三个 JSON 字段在 SQLite 中存为 TEXT(sqlite3 标准库没有原生 JSON 类型),在 MySQL 中存为 JSON。代码层面读写时做 `json.dumps` / `json.loads`,对上层透明。 + +## 2. 
config.yaml 新增配置段 + +```yaml +thread_memory: + enabled: true + debounce_seconds: 30 + model_name: null # null = 使用默认模型 + max_facts: 100 + fact_confidence_threshold: 0.7 + injection_enabled: true + max_injection_tokens: 2000 + + database: + type: sqlite # sqlite | mysql + sqlite: + path: "thread_memory.db" + mysql: + host: "localhost" + port: 3306 + user: "root" + password: "$MYSQL_PASSWORD" + database: "deerflow" +``` + +大部分字段和现有 `memory` 配置段语义相同,可以在两个配置段之间复用。`database` 段按 type 取子段,工厂函数只读自己需要的部分。 + +## 3. 存储层设计 + +### 3.1 抽象接口 + +```python +# deerflow/agents/memory/thread_storage.py + +import abc +import json +import sqlite3 +from datetime import datetime +from typing import Any + + +class ThreadMemoryStorage(abc.ABC): + + @abc.abstractmethod + def load(self, thread_id: str) -> dict[str, Any] | None: + """加载指定 thread 的记忆,不存在返回 None。""" + ... + + @abc.abstractmethod + def save(self, thread_id: str, data: dict[str, Any]) -> bool: + """保存指定 thread 的记忆(upsert)。""" + ... + + @abc.abstractmethod + def delete(self, thread_id: str) -> bool: + """删除指定 thread 的记忆(thread 被删除时联动)。""" + ... 
+ + +def _create_empty_memory() -> dict[str, Any]: + """Per-thread 记忆的初始空结构。""" + return { + "profile": { + "name": None, + "role": None, + "expertise": [], + "language": None, + "context": None, + }, + "preferences": { + "tone": None, + "verbosity": None, + "codeStyle": None, + "other": None, + }, + "facts": [], + } + + +def _row_to_memory(row: tuple) -> dict[str, Any]: + """将数据库行转为 memory dict。SQLite 的 JSON 列存的是 TEXT,需要 parse。""" + return { + "threadId": row[0], + "profile": json.loads(row[1]), + "preferences": json.loads(row[2]), + "facts": json.loads(row[3]), + "lastUpdated": row[4], + } +``` + +### 3.2 SQLite 实现(本地测试) + +```python +class SqliteThreadMemoryStorage(ThreadMemoryStorage): + + def __init__(self, db_path: str): + self._conn = sqlite3.connect(db_path) + self._conn.execute(""" + CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id TEXT PRIMARY KEY, + profile TEXT NOT NULL DEFAULT '{}', + preferences TEXT NOT NULL DEFAULT '{}', + facts TEXT NOT NULL DEFAULT '[]', + last_updated TEXT NOT NULL DEFAULT (datetime('now')) + ) + """) + self._conn.commit() + + def load(self, thread_id: str) -> dict | None: + row = self._conn.execute( + "SELECT thread_id, profile, preferences, facts, last_updated " + "FROM thread_memory WHERE thread_id = ?", + (thread_id,) + ).fetchone() + return _row_to_memory(row) if row else None + + def save(self, thread_id: str, data: dict) -> bool: + now = datetime.utcnow().isoformat() + "Z" + self._conn.execute(""" + INSERT INTO thread_memory (thread_id, profile, preferences, facts, last_updated) + VALUES (?, ?, ?, ?, ?) 
+ ON CONFLICT(thread_id) DO UPDATE SET + profile = excluded.profile, + preferences = excluded.preferences, + facts = excluded.facts, + last_updated = excluded.last_updated + """, ( + thread_id, + json.dumps(data["profile"], ensure_ascii=False), + json.dumps(data["preferences"], ensure_ascii=False), + json.dumps(data["facts"], ensure_ascii=False), + now, + )) + self._conn.commit() + return True + + def delete(self, thread_id: str) -> bool: + self._conn.execute("DELETE FROM thread_memory WHERE thread_id = ?", (thread_id,)) + self._conn.commit() + return True +``` + +### 3.3 MySQL 实现(生产环境) + +```python +class MysqlThreadMemoryStorage(ThreadMemoryStorage): + + def __init__(self, host: str, port: int, user: str, password: str, database: str): + import pymysql + self._conn = pymysql.connect( + host=host, port=port, user=user, password=password, database=database, + charset="utf8mb4", + ) + with self._conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id VARCHAR(64) PRIMARY KEY, + profile JSON NOT NULL, + preferences JSON NOT NULL, + facts JSON NOT NULL, + last_updated TIMESTAMP NOT NULL + DEFAULT CURRENT_TIMESTAMP + ON UPDATE CURRENT_TIMESTAMP + ) + """) + self._conn.commit() + + def load(self, thread_id: str) -> dict | None: + with self._conn.cursor() as cur: + cur.execute( + "SELECT thread_id, profile, preferences, facts, last_updated " + "FROM thread_memory WHERE thread_id = %s", + (thread_id,) + ) + row = cur.fetchone() + return _row_to_memory(row) if row else None + + def save(self, thread_id: str, data: dict) -> bool: + now = datetime.utcnow() + with self._conn.cursor() as cur: + cur.execute(""" + INSERT INTO thread_memory (thread_id, profile, preferences, facts, last_updated) + VALUES (%s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + profile = VALUES(profile), + preferences = VALUES(preferences), + facts = VALUES(facts), + last_updated = VALUES(last_updated) + """, ( + thread_id, + json.dumps(data["profile"], 
ensure_ascii=False), + json.dumps(data["preferences"], ensure_ascii=False), + json.dumps(data["facts"], ensure_ascii=False), + now, + )) + self._conn.commit() + return True + + def delete(self, thread_id: str) -> bool: + with self._conn.cursor() as cur: + cur.execute("DELETE FROM thread_memory WHERE thread_id = %s", (thread_id,)) + self._conn.commit() + return True +``` + +### 3.4 工厂函数 + +```python +def get_thread_memory_storage() -> ThreadMemoryStorage: + """从 config 读取 database 配置,构建对应的 storage 实例(单例)。""" + config = get_thread_memory_config() + db = config.database + + if db.type == "sqlite": + return SqliteThreadMemoryStorage(db.sqlite.path) + elif db.type == "mysql": + return MysqlThreadMemoryStorage( + host=db.mysql.host, + port=db.mysql.port, + user=db.mysql.user, + password=db.mysql.password, + database=db.mysql.database, + ) + else: + raise ValueError(f"Unknown thread_memory database type: {db.type}") +``` + +### 3.5 注意事项 + +- **JSON 在 SQLite 中存为 TEXT**:`sqlite3` 标准库没有 JSON 类型,用 TEXT 存储 `json.dumps` 的结果。读写时做序列化/反序列化。MySQL 用原生 JSON 列,`pymysql` 自动处理。 +- **upsert 语法差异**:SQLite 用 `ON CONFLICT ... DO UPDATE SET`,MySQL 用 `ON DUPLICATE KEY UPDATE`,语义等价。 +- **连接管理**:两个实现都在 `__init__` 创建连接并持有。单线程场景没问题。如果将来需要并发,可以加连接池或改为每次操作创建连接。 + +--- + +## 4. upsert 语义:全量替换 vs 合并更新 + +### 两种模式 + +**模式 A — 增量合并**(LLM 出 delta,应用层合并): + +``` +LLM 输入: 现有记忆 + 新对话 +LLM 输出: { profile: { name: "新值", shouldUpdate: true }, newFacts: [...], factsToRemove: [...] } +应用层: 读取现有记忆 → 按 delta 逐字段合并 → 写入 +``` + +现有全局记忆用的就是这个模式。LLM 输出里带 `shouldUpdate` 标记和 `factsToRemove` 列表,应用代码做合并。 + +**模式 B — 全量替换**(LLM 出完整状态,应用层直接覆盖): + +``` +LLM 输入: 现有记忆 + 新对话 +LLM 输出: { profile: { name: "...", role: "...", ... }, preferences: {...}, facts: [...] } +应用层: INSERT ... ON CONFLICT DO UPDATE(整行覆盖) +``` + +### 选择模式 B 的理由 + +1. **profile 和 preferences 本身很小**。每个对象 5-6 个字段,全部输出最多几十个 token,增量节省的 token 可以忽略。 + +2. 
**去重和淘汰由 LLM 负责,应用层零逻辑**。LLM 看到了完整的现有记忆,在 prompt 中就能决定哪些 facts 要保留、哪些过时了要删、哪些要合并。应用代码只需要 `json.dumps` + upsert。 + +3. **避免字段删除的尴尬**。如果 LLM 想把 `profile.context` 从 `"前端开发者"` 改成 `null`(表示不再确定这个信息),增量模式需要额外表达"显式置 null"还是"不变",全量替换没有歧义。 + +4. **和现有全局记忆的模式不同是合理的**。全局记忆的 `history` 有大量的对话摘要文本,不适合全量替换。Per-thread 记忆的 profile/preferences 是结构化的元数据,全量输出成本低。 + +### 具体流程 + +``` +用户对话结束 + ↓ +MemoryMiddleware.after_agent() 提取 user + final AI 消息 + ↓ +queue.add(thread_id, messages) # debounce 30s + ↓ +ThreadMemoryUpdater.update() + 1. 从 DB 读取现有记忆(不存在就用 _create_empty_memory()) + 2. 构建 prompt: "以下是用户的现有画像和偏好:{existing_memory},以下是新的对话:{conversation},请更新用户画像。" + 3. LLM 返回完整的 profile + preferences + facts + 4. storage.save(thread_id, data) # upsert 整行覆盖 +``` + +**关键点**:LLM 的 prompt 里放了**现有记忆**,LLM 看到之后自己决定: +- 保留哪些 facts +- 更新哪些 profile 字段 +- 新增什么偏好 +- 删除过时的信息(不输出就是删除) + +应用代码不做任何合并判断,只负责把 LLM 输出写入数据库。 + +--- + +## 5. 更新路径 + +### 5.1 MemoryMiddleware 改造(最小改动) + +在现有 `MemoryMiddleware.after_agent()` 中加一段逻辑,当 `thread_id` 存在时,同时向 per-thread 记忆的 queue 推一条: + +```python +# 现有逻辑:全局记忆 +queue = get_memory_queue() +queue.add(thread_id=thread_id, messages=filtered_messages, ...) 
+ +# 新增:per-thread 记忆 +if thread_id: + thread_queue = get_thread_memory_queue() + thread_queue.add(thread_id=thread_id, messages=filtered_messages) +``` + +### 5.2 ThreadMemoryUpdater + +新类,结构类似现有的 `MemoryUpdater`,但使用不同的 prompt 和存储后端: + +```python +class ThreadMemoryUpdater: + def update(self, messages, thread_id): + storage = get_thread_memory_storage() + existing = storage.load(thread_id) or _create_empty_memory() + + prompt = THREAD_MEMORY_UPDATE_PROMPT.format( + existing_memory=json.dumps(existing, ensure_ascii=False), + conversation=format_conversation(messages), + ) + + response = model.invoke(prompt) + new_memory = parse_llm_output(response) # { profile, preferences, facts } + + storage.save(thread_id, new_memory) +``` + +### 5.3 Prompt 设计要点 + +与全局记忆 prompt 的关键区别: + +| | 全局记忆 prompt | Per-thread 记忆 prompt | +|---|---|---| +| **目标** | "对话中发生了什么" | "这个人是谁、喜欢什么" | +| **输出** | user context 摘要 + history 摘要 + facts | profile + preferences + facts | +| **侧重** | 保留对话内容的事实性信息 | 推断用户的身份、偏好、风格 | +| **语气影响** | 无 | 输出 `preferences.tone` 直接影响后续回复风格 | + +--- + +## 6. 读取路径(注入 System Prompt) + +```python +def inject_thread_memory(system_prompt: str, thread_id: str) -> str: + storage = get_thread_memory_storage() + memory = storage.load(thread_id) + + if memory is None: + # fallback 到全局记忆 + return inject_global_memory(system_prompt) + + # 生成 标签注入 system prompt + profile_xml = _format_profile_xml(memory) + return system_prompt + "\n" + profile_xml +``` + +注入内容的 XML 结构示例: + +```xml + + + 张三 + 全栈工程师 + React, TypeScript, Python + zh-CN + 在做一个电商项目 + + + casual + detailed + prefers functional components with hooks + + +``` + +语气偏好(`preferences.tone`)不直接改 system prompt 模板,而是放在 `` XML 里让 LLM 自己理解。方式简单,不用维护 prompt 模板的分支逻辑。如果发现 LLM 不遵循,再考虑动态改写 prompt 模板。 + +--- + +## 7. Thread 删除时的联动 + +Gateway 已有 `DELETE /api/threads/{id}`。在现有 handler 中加一行: + +```python +# app/gateway/routers/threads.py +@router.delete("/api/threads/{thread_id}") +async def delete_thread(thread_id: str): + # ... 
现有清理逻辑 ... + + # 新增:删除 per-thread 记忆 + get_thread_memory_storage().delete(thread_id) +``` + +--- + +## 8. 实施步骤 + +1. **新增配置模型** — `thread_memory_config.py`(参考现有 `memory_config.py`) +2. **新增存储层** — `thread_storage.py`(`ThreadMemoryStorage` + `SqliteThreadMemoryStorage` + `MysqlThreadMemoryStorage`) +3. **新增 prompt** — `thread_memory_prompt.py`(用于 LLM 提取用户画像) +4. **新增 updater** — 或扩展现有 `MemoryUpdater`,根据 `thread_id` 参数路由到不同逻辑 +5. **改造 middleware** — `MemoryMiddleware` 中加 per-thread 记忆的 queue 逻辑 +6. **改造注入** — system prompt 生成时注入 `` 标签 +7. **扩展 thread 删除 handler** — 联动删除 DB 记录 +8. **写入测试** — `test_thread_memory_storage.py`, `test_thread_memory_updater.py` + +## 9. 待确认事项 + +- [ ] pymysql 作为新依赖是否 OK? +- [ ] `database` 配置段结构是否合适? +- [ ] upsert 使用全量替换模式(模式 B)是否认同? + +## 10. 第二轮脑暴(风险前置) + +下面这轮不是改大方向,而是把容易在落地时踩坑的点先钉住。 + +### 10.1 隔离键:`thread_id` 是否足够? + +当前设计用 `thread_id` 作为主键隔离用户记忆,简单可行。但有一个隐含前提: +- 一个 thread 永远只对应一个真实用户 + +如果未来支持“同一用户多 thread 共享画像”或“thread 可能转移 owner”,只用 `thread_id` 会限制扩展。 + +可选路径: + +- 路径 A(维持现状,推荐短期):主键 `thread_id`,最快上线。 +- 路径 B(兼容未来):增加 `owner_id`(可空),并加索引 `(owner_id, thread_id)`。 + +建议: +- 第一版继续 `thread_id`,但在表里预留 `owner_id` nullable 字段,避免后续大迁移。 + +### 10.2 并发一致性:同一 thread 的并发写覆盖问题 + +场景:同一 thread 在短时间内触发多次 update,后到达的旧结果可能覆盖先到达的新结果。 + +可选保护: + +- 方案 A:`last_updated` 乐观锁(更新时带 where 条件) +- 方案 B:`memory_version` 整数版本号(推荐) +- 方案 C:严格串行队列(单 thread 单 worker) + +建议: +- 加 `memory_version`(默认 0)。`save` 时做 compare-and-swap 语义: + - 读取 version = n + - 写入时要求 version 仍为 n,成功后 version = n+1 + - 失败则重试一次(重新 load + merge prompt 再写) + +这样不需要分布式锁,也能规避“旧结果回写”。 + +### 10.3 记忆质量控制:防止噪声和幻觉固化 + +LLM 抽取用户画像时,最大风险是把一次性表达当长期偏好。 + +建议加三道门: + +1. 事实类别阈值 +- `preference` 类阈值可略低(如 0.7) +- `personal` 类阈值更高(如 0.85) + +2. 稳定性规则 +- 同类偏好至少被 2 次独立对话支持,才提升为 profile/preference 的强字段 + +3. 
冲突降级 +- 新旧事实冲突时,不立刻删旧值 +- 先把旧值降权并标记 `supersededBy`,下一轮再淘汰 + +### 10.4 隐私与合规:先定义“不能记”的边界 + +建议在 prompt 与代码都加 denylist(双保险): + +- 默认不写入:身份证号、手机号、邮箱、住址、银行卡、密码/API Key 等敏感信息 +- 允许写入:技术偏好、工作语境、沟通风格、项目目标 + +实现上: +- 在 `ThreadMemoryUpdater` parse 后做一次 server-side scrub +- 命中敏感模式就丢弃并打审计日志(不落库原文) + +### 10.5 注入预算:避免 memory 挤爆上下文 + +当前有 `max_injection_tokens`,但还缺“裁剪策略”。 + +建议固定优先级: +1. profile(最高) +2. preferences +3. facts(按 confidence + recency 排序后截断) + +当超预算时: +- 永远保留 profile/preference +- 只裁剪 facts + +### 10.6 可观测性:上线后如何判断有效 + +建议最小指标集: + +- `thread_memory_update_total{status=ok|error}` +- `thread_memory_injection_tokens` +- `thread_memory_fact_count` +- `thread_memory_update_latency_ms` +- `thread_memory_conflict_retry_total` + +加两条抽样日志: +- 更新前后摘要 diff(脱敏后) +- 注入片段长度与截断原因 + +### 10.7 迁移与回滚策略(从全局记忆过渡) + +你已选 fallback 策略,这很好。建议再补两个机制: + +- 冷启动导入(可选) + - 首次访问 thread 且无 per-thread 记录时,从全局记忆抽取一份“弱画像”写入 + - 打 `bootstrapped_from_global=true` + +- 一键回滚 + - 配置开关 `thread_memory.injection_enabled=false` 时,立刻只走全局注入 + - 更新链路可继续跑,便于回滚期间保留数据 + +### 10.8 API 语义建议(便于后续运维) + +即使第一版 UI 不暴露,也建议预留内部接口: + +- `GET /internal/thread-memory/{thread_id}`(脱敏视图) +- `DELETE /internal/thread-memory/{thread_id}` +- `POST /internal/thread-memory/{thread_id}/rebuild` + +这样排障时不用直接查库。 + +--- + +## 11. 第三轮决策清单(进入实现前最后拍板) + +- [ ] 表结构是否预留 `owner_id` 与 `memory_version`? +- [ ] 是否采用 `memory_version` 方案处理并发覆盖? +- [ ] 敏感信息 denylist 范围是否按 10.4 执行? +- [ ] 注入裁剪优先级是否固定为 profile > preferences > facts? +- [ ] 是否需要“冷启动导入”全局记忆到 per-thread? +- [ ] 是否要在首版就加内部运维接口? + +如果以上 6 项确定,基本就能把实现风险压到可控范围内。 + +## 12. 
默认拍板方案(建议直接采用) + +目标:在不显著增加复杂度的前提下,拿到“可上线 + 可回滚 + 可演进”的第一版。 + +### 12.1 表结构默认值 + +采用:**预留 `owner_id` + 引入 `memory_version`**。 + +SQLite: + +```sql +CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id TEXT PRIMARY KEY, + owner_id TEXT NULL, + profile TEXT NOT NULL DEFAULT '{}', + preferences TEXT NOT NULL DEFAULT '{}', + facts TEXT NOT NULL DEFAULT '[]', + memory_version INTEGER NOT NULL DEFAULT 0, + last_updated TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id); +``` + +MySQL: + +```sql +CREATE TABLE IF NOT EXISTS thread_memory ( + thread_id VARCHAR(64) PRIMARY KEY, + owner_id VARCHAR(64) NULL, + profile JSON NOT NULL, + preferences JSON NOT NULL, + facts JSON NOT NULL, + memory_version INT NOT NULL DEFAULT 0, + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + INDEX idx_owner_id (owner_id) +); +``` + +### 12.2 并发一致性默认值 + +采用:**`memory_version` 乐观并发控制 + 失败重试 1 次**。 + +保存逻辑: +- `load()` 读出 `memory_version=n` +- `save()` 时执行条件更新(`WHERE thread_id=? 
AND memory_version=n`) +- 成功则 `memory_version=n+1` +- 如果受影响行数为 0,说明被并发写抢先,重读并重试一次 + +这能防止“旧更新覆盖新更新”,同时实现复杂度可控。 + +### 12.3 隐私策略默认值 + +采用:**默认拒绝敏感信息入库(代码层 hard filter)**。 + +默认 denylist: +- 手机号 +- 邮箱 +- 身份证号/护照号 +- 银行卡号 +- 密码/API Key/Token +- 详细住址 + +规则: +- 命中则从 `profile/preferences/facts` 中删除该片段 +- 仅记录脱敏审计信息(类型 + 时间 + thread_id),不记录原文 + +### 12.4 注入裁剪默认值 + +采用固定优先级:**`profile > preferences > facts`**。 + +当超过 `max_injection_tokens`: +- 必保留:`profile`、`preferences` +- 裁剪:`facts`(按 `confidence DESC, createdAt DESC` 排序后截断) + +这能保证人格与风格信息稳定注入,不被历史 facts 挤掉。 + +### 12.5 冷启动策略默认值 + +采用:**首版不开启自动冷启动导入**(`bootstrap_from_global=false`)。 + +理由: +- 降低“全局脏数据复制到 thread”风险 +- 逻辑更清晰,便于观察 per-thread 记忆真实质量 + +补充: +- 保留 fallback(你当前已定) +- 后续若需要可加后台任务做可控回填 + +### 12.6 内部运维接口默认值 + +采用:**首版只加读接口,写接口延后**。 + +第一版建议: +- `GET /internal/thread-memory/{thread_id}`(脱敏后返回) + +暂不做: +- `DELETE /internal/thread-memory/{thread_id}`(已有 thread delete 联动可覆盖主场景) +- `POST /internal/thread-memory/{thread_id}/rebuild`(二期再加) + +这样可以先满足排障可见性,避免过早扩大运维面。 + +--- + +## 13. 实施前冻结版 Checklist(可直接转开发) + +- [ ] DDL 按 12.1 落地(含 `owner_id`, `memory_version`, index) +- [ ] Storage `save()` 改为 compare-and-swap 语义 +- [ ] Updater 增加一次冲突重试 +- [ ] parse 后执行敏感信息 scrub +- [ ] 注入模块按 `profile > preferences > facts` 裁剪 +- [ ] fallback 保持开启,冷启动导入保持关闭 +- [ ] 增加最小指标与脱敏 diff 日志 +- [ ] 增加内部只读排障接口 + +到这一步,方案已经可以进入实现,不需要再做大改。 diff --git a/docs/thread-memory-manual-test-checklist.md b/docs/thread-memory-manual-test-checklist.md new file mode 100644 index 00000000..c8537224 --- /dev/null +++ b/docs/thread-memory-manual-test-checklist.md @@ -0,0 +1,213 @@ +# Thread Memory 手动测试清单 + +日期:`2026-05-08` +测试人:`__________` + +--- + +## 0. 
前置检查 + +- [ ] 已拉取包含以下修复的最新代码并重启后端进程 + - `memory.enabled=false` 时仍允许 `thread_memory` 更新 + - `thread_prompt` 的 JSON 模板转义修复(避免 `KeyError: "profile"`) + - `thread_updater` 使用非流式安全参数(避免 `stream_options` 400) +- [ ] `config.yaml` 中已启用 `thread_memory.enabled: true` +- [ ] 确认使用的是预期配置文件(当前项目根目录 `config.yaml`) + +--- + +## 1. 基础写入与读取 + +前置条件: +- 选择一个新的 `thread_id`(例:`1f571481-e3ae-42b5-a513-945bf8f1cbef`) + +步骤: +1. 在该线程发送 2-3 轮消息,包含姓名、角色、偏好语气等信息 +2. 等待 `debounce_seconds`(默认 30 秒) +3. 查询 `thread_memory` 表 + +期望: +- 出现该 `thread_id` 记录 +- `profile/preferences/facts` 有对应内容 + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 2. Per-Thread 隔离 + +前置条件: +- 准备两个线程 `thread_A`、`thread_B` + +步骤: +1. 在 A 中输入“前端背景”信息 +2. 在 B 中输入“后端背景”信息 +3. 分别等待写入完成后查看两条记录 + +期望: +- A 仅保存 A 的画像,B 仅保存 B 的画像 +- 两个线程不串数据 + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 3. 全局记忆 Fallback + +前置条件: +- 全局 memory 有内容 +- 新建一个尚无 per-thread 记录的线程 + +步骤: +1. 先在该新线程发一轮普通消息 +2. 观察回复是否体现全局记忆 +3. 再继续对话触发 per-thread 写入后观察注入变化 + +期望: +- 无 per-thread 时可 fallback 到全局 +- 有 per-thread 后优先使用 per-thread + +结果: +- [ ] 通过 +- [ ] 失败(备注:`未执行(N/A):当前环境 memory.enabled=false,全局记忆关闭,本用例不适用`) + +--- + +## 4. 注入裁剪优先级(Profile > Preferences > Facts) + +前置条件: +- 某线程已有大量 facts + +步骤: +1. 人为积累 facts 到接近/超过注入预算 +2. 保持 profile/preferences 有值 +3. 观察注入后的表现 + +期望: +- 超预算时保留 profile + preferences +- 优先裁剪 facts + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 5. 敏感信息过滤 + +步骤: +1. 在对话中输入邮箱、手机号、token/password 等敏感样例 +2. 等待写入后查库 + +期望: +- 敏感信息不应落入 `profile/preferences/facts` + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 6. 并发覆盖保护(CAS + version) + +步骤: +1. 同一 `thread_id` 短时间内触发两次更新(尽量并发) +2. 观察最终数据与日志 + +期望: +- 不出现明显“旧数据覆盖新数据” +- 冲突时可见重试行为(日志) + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 7. Debounce 生效 + +步骤: +1. 在 30 秒内连续发送多条消息 +2. 观察写库频率 + +期望: +- 多条输入被合并处理,不是每条都立即写库 + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 8. 线程删除联动清理 + +步骤: +1. 
对已有 per-thread 记录的线程调用 `DELETE /api/threads/{thread_id}` +2. 查询 `thread_memory` 表 + +期望: +- 对应 `thread_id` 记录被删除 + +结果: +- [ ] 通过 +- [ ] 失败(备注:`未执行:当前产品决策不接受“删线程即删记忆”,需改为用户显式触发清除后再复测`) + +--- + +## 9. SQLite 自动建表与路径 + +步骤: +1. 删除现有 `thread_memory.db`(测试环境) +2. 重启服务并触发一轮写入 +3. 检查 DB 文件和表结构 + +期望: +- 自动创建 DB 文件与 `thread_memory` 表 +- 索引 `idx_thread_memory_owner_id` 存在 + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 10. 配置开关验证 + +步骤: +1. 关闭 `thread_memory.enabled`,重启并测试写入 +2. 开启 `thread_memory.enabled`,关闭 `thread_memory.injection_enabled`,重启并测试注入 + +期望: +- `enabled=false`:不更新 per-thread +- `injection_enabled=false`:不注入 per-thread(可 fallback) + +结果: +- [x] 通过 +- [ ] 失败(备注:`________________`) + +--- + +## 11. 已知错误回归验证 + +### 11.1 `KeyError: "profile"` 回归 +- [x] 未再出现 `thread_prompt.py` 的 `KeyError` 报错 + +### 11.2 `stream_options` 400 回归 +- [x] 未再出现 `"'stream_options' only set this when you set stream: true"` 报错 + +备注:`________________` + +--- + +## 测试总结 + +- 总用例数:`11` +- 通过数:`____` +- 失败数:`____` +- 结论: + - [ ] 可上线 + - [ ] 需修复后复测