fix(thread-memory): 修复语言识别与队列健壮性

This commit is contained in:
肖应宇 2026-05-08 11:45:56 +08:00
parent cba81112fd
commit 1c14be0c33
6 changed files with 152 additions and 32 deletions

View File

@ -6,6 +6,8 @@ import json
import re
from typing import Any
from langchain_core.messages import HumanMessage
from deerflow.agents.memory.prompt import format_conversation_for_update, format_memory_for_injection
THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system.
@ -68,13 +70,43 @@ def create_empty_thread_memory() -> dict[str, Any]:
}
def _extract_human_text(content: Any) -> str:
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
chunks: list[str] = []
for item in content:
if isinstance(item, str):
stripped = item.strip()
if stripped:
chunks.append(stripped)
elif isinstance(item, dict):
text_val = item.get("text")
if isinstance(text_val, str):
stripped = text_val.strip()
if stripped:
chunks.append(stripped)
return "\n".join(chunks).strip()
return ""
def _infer_preferred_memory_language(messages: list[Any]) -> str:
conversation = format_conversation_for_update(messages)
if not conversation.strip():
user_texts: list[str] = []
for msg in messages:
if isinstance(msg, HumanMessage):
extracted = _extract_human_text(getattr(msg, "content", None))
if extracted:
user_texts.append(extracted)
if not user_texts:
return "same as the user's latest message"
# Prioritize the latest user message; fallback to a short recent window.
recent_window = user_texts[-3:]
language_sample = "\n".join(recent_window)
# If user explicitly provides locale hints, prefer them.
locale_match = re.search(r"\b([a-z]{2}-[A-Z]{2})\b", conversation)
locale_match = re.search(r"\b([a-z]{2}-[A-Z]{2})\b", language_sample)
if locale_match:
return locale_match.group(1)
@ -90,7 +122,7 @@ def _infer_preferred_memory_language(messages: list[Any]) -> str:
"he-IL": r"[\u0590-\u05FF]",
"el-GR": r"[\u0370-\u03FF]",
}
counts = {lang: len(re.findall(pattern, conversation)) for lang, pattern in script_patterns.items()}
counts = {lang: len(re.findall(pattern, language_sample)) for lang, pattern in script_patterns.items()}
best_lang, best_count = max(counts.items(), key=lambda item: item[1])
if best_count > 0:
return best_lang

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import threading
from dataclasses import dataclass, field
from datetime import datetime
from datetime import UTC, datetime
from typing import Any
from deerflow.config.thread_memory_config import get_thread_memory_config
@ -14,54 +14,54 @@ from deerflow.config.thread_memory_config import get_thread_memory_config
class ThreadConversationContext:
thread_id: str
messages: list[Any]
timestamp: datetime = field(default_factory=datetime.utcnow)
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
class ThreadMemoryUpdateQueue:
def __init__(self):
self._queue: list[ThreadConversationContext] = []
self._queue_by_thread: dict[str, ThreadConversationContext] = {}
self._lock = threading.Lock()
self._timer: threading.Timer | None = None
self._processing = False
self._timers: dict[str, threading.Timer] = {}
self._processing_threads: set[str] = set()
def add(self, thread_id: str, messages: list[Any]) -> None:
config = get_thread_memory_config()
if not config.enabled:
return
with self._lock:
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(ThreadConversationContext(thread_id=thread_id, messages=messages))
self._reset_timer()
self._queue_by_thread[thread_id] = ThreadConversationContext(thread_id=thread_id, messages=messages)
self._reset_timer(thread_id)
def _reset_timer(self) -> None:
def _reset_timer(self, thread_id: str) -> None:
config = get_thread_memory_config()
if self._timer is not None:
self._timer.cancel()
self._timer = threading.Timer(config.debounce_seconds, self._process_queue)
self._timer.daemon = True
self._timer.start()
timer = self._timers.get(thread_id)
if timer is not None:
timer.cancel()
timer = threading.Timer(config.debounce_seconds, self._process_thread, args=(thread_id,))
timer.daemon = True
self._timers[thread_id] = timer
timer.start()
def _process_queue(self) -> None:
def _process_thread(self, thread_id: str) -> None:
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
with self._lock:
if self._processing:
self._reset_timer()
if thread_id in self._processing_threads:
self._reset_timer(thread_id)
return
if not self._queue:
context = self._queue_by_thread.pop(thread_id, None)
if context is None:
self._timers.pop(thread_id, None)
return
self._processing = True
contexts = self._queue.copy()
self._queue.clear()
self._timer = None
self._processing_threads.add(thread_id)
self._timers.pop(thread_id, None)
try:
updater = ThreadMemoryUpdater()
for context in contexts:
updater.update_memory(context.messages, context.thread_id)
finally:
with self._lock:
self._processing = False
self._processing_threads.discard(thread_id)
_thread_queue: ThreadMemoryUpdateQueue | None = None

View File

@ -40,6 +40,13 @@ class ThreadMemoryUpdater:
)
def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]:
def safe_confidence(val: Any, default: float = 0.5) -> float:
try:
parsed = float(val)
except (TypeError, ValueError):
return default
return max(0.0, min(1.0, parsed))
def safe_text(val: Any) -> str | None:
if not isinstance(val, str):
return None
@ -89,13 +96,13 @@ class ThreadMemoryUpdater:
if key in seen:
continue
seen.add(key)
confidence = float(fact.get("confidence", 0.5))
confidence = safe_confidence(fact.get("confidence", 0.5))
cleaned["facts"].append(
{
"id": f"fact_{uuid.uuid4().hex[:8]}",
"content": content,
"category": str(fact.get("category", "context")).strip() or "context",
"confidence": max(0.0, min(1.0, confidence)),
"confidence": confidence,
"createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
"source": thread_id,
}

View File

@ -1,4 +1,4 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, format_thread_memory_for_injection
@ -51,3 +51,31 @@ def test_build_thread_memory_prompt_prefers_japanese_for_japanese_conversation()
[HumanMessage(content="私は日本語で会話したいです。")],
)
assert "Preferred memory language: ja-JP" in prompt
def test_build_thread_memory_prompt_uses_user_messages_only_for_language_inference():
prompt = build_thread_memory_prompt(
{"user": {}, "history": {}, "facts": []},
[
HumanMessage(content="请用中文记录记忆"),
AIMessage(content="Sure, I will answer in English with many many words."),
AIMessage(content="More English content that should not change language inference."),
],
)
assert "Preferred memory language: zh-Hans" in prompt
def test_build_thread_memory_prompt_handles_structured_human_content():
prompt = build_thread_memory_prompt(
{"user": {}, "history": {}, "facts": []},
[
HumanMessage(
content=[
{"type": "text", "text": "我希望记忆使用中文。"},
{"type": "text", "text": "请继续。"},
]
),
AIMessage(content="I can also reply in English."),
],
)
assert "Preferred memory language: zh-Hans" in prompt

View File

@ -0,0 +1,33 @@
from unittest.mock import patch
from deerflow.agents.memory.thread_queue import ThreadMemoryUpdateQueue
def test_thread_queue_keeps_latest_message_per_thread():
queue = ThreadMemoryUpdateQueue()
with patch.object(queue, "_reset_timer"):
queue.add("thread-a", ["msg-1"])
queue.add("thread-b", ["msg-2"])
queue.add("thread-a", ["msg-3"])
assert set(queue._queue_by_thread.keys()) == {"thread-a", "thread-b"}
assert queue._queue_by_thread["thread-a"].messages == ["msg-3"]
def test_thread_queue_processes_single_thread_without_affecting_others():
queue = ThreadMemoryUpdateQueue()
with patch.object(queue, "_reset_timer"):
queue.add("thread-a", ["a-msg"])
queue.add("thread-b", ["b-msg"])
updater_calls: list[tuple[list[str], str]] = []
class _FakeUpdater:
def update_memory(self, messages, thread_id):
updater_calls.append((messages, thread_id))
with patch("deerflow.agents.memory.thread_updater.ThreadMemoryUpdater", _FakeUpdater):
queue._process_thread("thread-a")
assert updater_calls == [(["a-msg"], "thread-a")]
assert "thread-b" in queue._queue_by_thread

View File

@ -0,0 +1,20 @@
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
def test_scrub_sensitive_tolerates_non_numeric_confidence():
updater = ThreadMemoryUpdater()
cleaned = updater._scrub_sensitive(
{
"user": {},
"history": {},
"facts": [
{"content": "Uses React", "category": "knowledge", "confidence": "high"},
{"content": "Uses TypeScript", "category": "knowledge", "confidence": None},
],
},
"thread-test",
)
assert len(cleaned["facts"]) == 2
assert cleaned["facts"][0]["confidence"] == 0.5
assert cleaned["facts"][1]["confidence"] == 0.5