fix(thread-memory): 修复语言识别与队列健壮性

This commit is contained in:
肖应宇 2026-05-08 11:45:56 +08:00
parent cba81112fd
commit 1c14be0c33
6 changed files with 152 additions and 32 deletions

View File

@ -6,6 +6,8 @@ import json
import re import re
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage
from deerflow.agents.memory.prompt import format_conversation_for_update, format_memory_for_injection from deerflow.agents.memory.prompt import format_conversation_for_update, format_memory_for_injection
THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system. THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system.
@ -68,13 +70,43 @@ def create_empty_thread_memory() -> dict[str, Any]:
} }
def _extract_human_text(content: Any) -> str:
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
chunks: list[str] = []
for item in content:
if isinstance(item, str):
stripped = item.strip()
if stripped:
chunks.append(stripped)
elif isinstance(item, dict):
text_val = item.get("text")
if isinstance(text_val, str):
stripped = text_val.strip()
if stripped:
chunks.append(stripped)
return "\n".join(chunks).strip()
return ""
def _infer_preferred_memory_language(messages: list[Any]) -> str: def _infer_preferred_memory_language(messages: list[Any]) -> str:
conversation = format_conversation_for_update(messages) user_texts: list[str] = []
if not conversation.strip(): for msg in messages:
if isinstance(msg, HumanMessage):
extracted = _extract_human_text(getattr(msg, "content", None))
if extracted:
user_texts.append(extracted)
if not user_texts:
return "same as the user's latest message" return "same as the user's latest message"
# Prioritize the latest user message; fallback to a short recent window.
recent_window = user_texts[-3:]
language_sample = "\n".join(recent_window)
# If user explicitly provides locale hints, prefer them. # If user explicitly provides locale hints, prefer them.
locale_match = re.search(r"\b([a-z]{2}-[A-Z]{2})\b", conversation) locale_match = re.search(r"\b([a-z]{2}-[A-Z]{2})\b", language_sample)
if locale_match: if locale_match:
return locale_match.group(1) return locale_match.group(1)
@ -90,7 +122,7 @@ def _infer_preferred_memory_language(messages: list[Any]) -> str:
"he-IL": r"[\u0590-\u05FF]", "he-IL": r"[\u0590-\u05FF]",
"el-GR": r"[\u0370-\u03FF]", "el-GR": r"[\u0370-\u03FF]",
} }
counts = {lang: len(re.findall(pattern, conversation)) for lang, pattern in script_patterns.items()} counts = {lang: len(re.findall(pattern, language_sample)) for lang, pattern in script_patterns.items()}
best_lang, best_count = max(counts.items(), key=lambda item: item[1]) best_lang, best_count = max(counts.items(), key=lambda item: item[1])
if best_count > 0: if best_count > 0:
return best_lang return best_lang

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import threading import threading
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from deerflow.config.thread_memory_config import get_thread_memory_config from deerflow.config.thread_memory_config import get_thread_memory_config
@ -14,54 +14,54 @@ from deerflow.config.thread_memory_config import get_thread_memory_config
class ThreadConversationContext: class ThreadConversationContext:
thread_id: str thread_id: str
messages: list[Any] messages: list[Any]
timestamp: datetime = field(default_factory=datetime.utcnow) timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
class ThreadMemoryUpdateQueue: class ThreadMemoryUpdateQueue:
def __init__(self): def __init__(self):
self._queue: list[ThreadConversationContext] = [] self._queue_by_thread: dict[str, ThreadConversationContext] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self._timer: threading.Timer | None = None self._timers: dict[str, threading.Timer] = {}
self._processing = False self._processing_threads: set[str] = set()
def add(self, thread_id: str, messages: list[Any]) -> None: def add(self, thread_id: str, messages: list[Any]) -> None:
config = get_thread_memory_config() config = get_thread_memory_config()
if not config.enabled: if not config.enabled:
return return
with self._lock: with self._lock:
self._queue = [c for c in self._queue if c.thread_id != thread_id] self._queue_by_thread[thread_id] = ThreadConversationContext(thread_id=thread_id, messages=messages)
self._queue.append(ThreadConversationContext(thread_id=thread_id, messages=messages)) self._reset_timer(thread_id)
self._reset_timer()
def _reset_timer(self) -> None: def _reset_timer(self, thread_id: str) -> None:
config = get_thread_memory_config() config = get_thread_memory_config()
if self._timer is not None: timer = self._timers.get(thread_id)
self._timer.cancel() if timer is not None:
self._timer = threading.Timer(config.debounce_seconds, self._process_queue) timer.cancel()
self._timer.daemon = True timer = threading.Timer(config.debounce_seconds, self._process_thread, args=(thread_id,))
self._timer.start() timer.daemon = True
self._timers[thread_id] = timer
timer.start()
def _process_queue(self) -> None: def _process_thread(self, thread_id: str) -> None:
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
with self._lock: with self._lock:
if self._processing: if thread_id in self._processing_threads:
self._reset_timer() self._reset_timer(thread_id)
return return
if not self._queue: context = self._queue_by_thread.pop(thread_id, None)
if context is None:
self._timers.pop(thread_id, None)
return return
self._processing = True self._processing_threads.add(thread_id)
contexts = self._queue.copy() self._timers.pop(thread_id, None)
self._queue.clear()
self._timer = None
try: try:
updater = ThreadMemoryUpdater() updater = ThreadMemoryUpdater()
for context in contexts:
updater.update_memory(context.messages, context.thread_id) updater.update_memory(context.messages, context.thread_id)
finally: finally:
with self._lock: with self._lock:
self._processing = False self._processing_threads.discard(thread_id)
_thread_queue: ThreadMemoryUpdateQueue | None = None _thread_queue: ThreadMemoryUpdateQueue | None = None

View File

@ -40,6 +40,13 @@ class ThreadMemoryUpdater:
) )
def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]: def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]:
def safe_confidence(val: Any, default: float = 0.5) -> float:
    """Coerce a confidence value into a float clamped to [0.0, 1.0].

    Args:
        val: The raw confidence value; may be a number, a numeric string,
            ``None``, or arbitrary junk.
        default: Value returned when ``val`` cannot be interpreted as a
            usable number.

    Returns:
        ``default`` when ``val`` is not convertible to float, overflows
        float range, or is NaN; otherwise the converted value clamped to
        the closed interval [0.0, 1.0].
    """
    import math  # function-scope import keeps this helper self-contained

    try:
        parsed = float(val)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() on an int too large for a double.
        return default
    # NaN compares False against everything, so min/max clamping would
    # silently turn it into 1.0; treat it as unparseable instead.
    if math.isnan(parsed):
        return default
    return max(0.0, min(1.0, parsed))
def safe_text(val: Any) -> str | None: def safe_text(val: Any) -> str | None:
if not isinstance(val, str): if not isinstance(val, str):
return None return None
@ -89,13 +96,13 @@ class ThreadMemoryUpdater:
if key in seen: if key in seen:
continue continue
seen.add(key) seen.add(key)
confidence = float(fact.get("confidence", 0.5)) confidence = safe_confidence(fact.get("confidence", 0.5))
cleaned["facts"].append( cleaned["facts"].append(
{ {
"id": f"fact_{uuid.uuid4().hex[:8]}", "id": f"fact_{uuid.uuid4().hex[:8]}",
"content": content, "content": content,
"category": str(fact.get("category", "context")).strip() or "context", "category": str(fact.get("category", "context")).strip() or "context",
"confidence": max(0.0, min(1.0, confidence)), "confidence": confidence,
"createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"), "createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
"source": thread_id, "source": thread_id,
} }

View File

@ -1,4 +1,4 @@
from langchain_core.messages import HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, format_thread_memory_for_injection from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, format_thread_memory_for_injection
@ -51,3 +51,31 @@ def test_build_thread_memory_prompt_prefers_japanese_for_japanese_conversation()
[HumanMessage(content="私は日本語で会話したいです。")], [HumanMessage(content="私は日本語で会話したいです。")],
) )
assert "Preferred memory language: ja-JP" in prompt assert "Preferred memory language: ja-JP" in prompt
def test_build_thread_memory_prompt_uses_user_messages_only_for_language_inference():
    """A Chinese human turn followed by English AI turns must yield zh-Hans."""
    empty_memory = {"user": {}, "history": {}, "facts": []}
    conversation = [
        HumanMessage(content="请用中文记录记忆"),
        AIMessage(content="Sure, I will answer in English with many many words."),
        AIMessage(content="More English content that should not change language inference."),
    ]

    rendered = build_thread_memory_prompt(empty_memory, conversation)

    assert "Preferred memory language: zh-Hans" in rendered
def test_build_thread_memory_prompt_handles_structured_human_content():
    """Human content given as a list of text blocks must still yield zh-Hans."""
    empty_memory = {"user": {}, "history": {}, "facts": []}
    structured_turn = HumanMessage(
        content=[
            {"type": "text", "text": "我希望记忆使用中文。"},
            {"type": "text", "text": "请继续。"},
        ]
    )
    assistant_turn = AIMessage(content="I can also reply in English.")

    rendered = build_thread_memory_prompt(empty_memory, [structured_turn, assistant_turn])

    assert "Preferred memory language: zh-Hans" in rendered

View File

@ -0,0 +1,33 @@
from unittest.mock import patch
from deerflow.agents.memory.thread_queue import ThreadMemoryUpdateQueue
def test_thread_queue_keeps_latest_message_per_thread():
    """Re-adding a thread id replaces its pending entry; other threads stay."""
    pending = ThreadMemoryUpdateQueue()
    submissions = (
        ("thread-a", ["msg-1"]),
        ("thread-b", ["msg-2"]),
        ("thread-a", ["msg-3"]),
    )

    # Stub out the timer so no real debounce timers are started.
    with patch.object(pending, "_reset_timer"):
        for thread_id, payload in submissions:
            pending.add(thread_id, payload)

    queued = pending._queue_by_thread
    assert set(queued) == {"thread-a", "thread-b"}
    assert queued["thread-a"].messages == ["msg-3"]
def test_thread_queue_processes_single_thread_without_affecting_others():
    """Processing one thread updates only that thread and leaves others queued."""
    pending = ThreadMemoryUpdateQueue()

    # Stub out the timer so no real debounce timers are started.
    with patch.object(pending, "_reset_timer"):
        pending.add("thread-a", ["a-msg"])
        pending.add("thread-b", ["b-msg"])

    recorded: list[tuple[list[str], str]] = []

    class _FakeUpdater:
        """Stand-in updater that records each call instead of doing work."""

        def update_memory(self, messages, thread_id):
            recorded.append((messages, thread_id))

    updater_target = "deerflow.agents.memory.thread_updater.ThreadMemoryUpdater"
    with patch(updater_target, _FakeUpdater):
        pending._process_thread("thread-a")

    assert recorded == [(["a-msg"], "thread-a")]
    assert "thread-b" in pending._queue_by_thread

View File

@ -0,0 +1,20 @@
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
def test_scrub_sensitive_tolerates_non_numeric_confidence():
    """Facts with a non-numeric or missing confidence are kept at 0.5."""
    raw_facts = [
        {"content": "Uses React", "category": "knowledge", "confidence": "high"},
        {"content": "Uses TypeScript", "category": "knowledge", "confidence": None},
    ]
    payload = {"user": {}, "history": {}, "facts": raw_facts}

    result = ThreadMemoryUpdater()._scrub_sensitive(payload, "thread-test")

    scrubbed = result["facts"]
    assert len(scrubbed) == 2
    assert [fact["confidence"] for fact in scrubbed] == [0.5, 0.5]