fix(thread-memory): 修复语言识别与队列健壮性
This commit is contained in:
parent
cba81112fd
commit
1c14be0c33
@ -6,6 +6,8 @@ import json
|
|||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from deerflow.agents.memory.prompt import format_conversation_for_update, format_memory_for_injection
|
from deerflow.agents.memory.prompt import format_conversation_for_update, format_memory_for_injection
|
||||||
|
|
||||||
THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system.
|
THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system.
|
||||||
@ -68,13 +70,43 @@ def create_empty_thread_memory() -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_human_text(content: Any) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content.strip()
|
||||||
|
if isinstance(content, list):
|
||||||
|
chunks: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
stripped = item.strip()
|
||||||
|
if stripped:
|
||||||
|
chunks.append(stripped)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text_val = item.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
stripped = text_val.strip()
|
||||||
|
if stripped:
|
||||||
|
chunks.append(stripped)
|
||||||
|
return "\n".join(chunks).strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _infer_preferred_memory_language(messages: list[Any]) -> str:
|
def _infer_preferred_memory_language(messages: list[Any]) -> str:
|
||||||
conversation = format_conversation_for_update(messages)
|
user_texts: list[str] = []
|
||||||
if not conversation.strip():
|
for msg in messages:
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
extracted = _extract_human_text(getattr(msg, "content", None))
|
||||||
|
if extracted:
|
||||||
|
user_texts.append(extracted)
|
||||||
|
|
||||||
|
if not user_texts:
|
||||||
return "same as the user's latest message"
|
return "same as the user's latest message"
|
||||||
|
|
||||||
|
# Prioritize the latest user message; fallback to a short recent window.
|
||||||
|
recent_window = user_texts[-3:]
|
||||||
|
language_sample = "\n".join(recent_window)
|
||||||
|
|
||||||
# If user explicitly provides locale hints, prefer them.
|
# If user explicitly provides locale hints, prefer them.
|
||||||
locale_match = re.search(r"\b([a-z]{2}-[A-Z]{2})\b", conversation)
|
locale_match = re.search(r"\b([a-z]{2}-[A-Z]{2})\b", language_sample)
|
||||||
if locale_match:
|
if locale_match:
|
||||||
return locale_match.group(1)
|
return locale_match.group(1)
|
||||||
|
|
||||||
@ -90,7 +122,7 @@ def _infer_preferred_memory_language(messages: list[Any]) -> str:
|
|||||||
"he-IL": r"[\u0590-\u05FF]",
|
"he-IL": r"[\u0590-\u05FF]",
|
||||||
"el-GR": r"[\u0370-\u03FF]",
|
"el-GR": r"[\u0370-\u03FF]",
|
||||||
}
|
}
|
||||||
counts = {lang: len(re.findall(pattern, conversation)) for lang, pattern in script_patterns.items()}
|
counts = {lang: len(re.findall(pattern, language_sample)) for lang, pattern in script_patterns.items()}
|
||||||
best_lang, best_count = max(counts.items(), key=lambda item: item[1])
|
best_lang, best_count = max(counts.items(), key=lambda item: item[1])
|
||||||
if best_count > 0:
|
if best_count > 0:
|
||||||
return best_lang
|
return best_lang
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from deerflow.config.thread_memory_config import get_thread_memory_config
|
from deerflow.config.thread_memory_config import get_thread_memory_config
|
||||||
@ -14,54 +14,54 @@ from deerflow.config.thread_memory_config import get_thread_memory_config
|
|||||||
class ThreadConversationContext:
|
class ThreadConversationContext:
|
||||||
thread_id: str
|
thread_id: str
|
||||||
messages: list[Any]
|
messages: list[Any]
|
||||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
|
||||||
class ThreadMemoryUpdateQueue:
|
class ThreadMemoryUpdateQueue:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._queue: list[ThreadConversationContext] = []
|
self._queue_by_thread: dict[str, ThreadConversationContext] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._timer: threading.Timer | None = None
|
self._timers: dict[str, threading.Timer] = {}
|
||||||
self._processing = False
|
self._processing_threads: set[str] = set()
|
||||||
|
|
||||||
def add(self, thread_id: str, messages: list[Any]) -> None:
|
def add(self, thread_id: str, messages: list[Any]) -> None:
|
||||||
config = get_thread_memory_config()
|
config = get_thread_memory_config()
|
||||||
if not config.enabled:
|
if not config.enabled:
|
||||||
return
|
return
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
self._queue_by_thread[thread_id] = ThreadConversationContext(thread_id=thread_id, messages=messages)
|
||||||
self._queue.append(ThreadConversationContext(thread_id=thread_id, messages=messages))
|
self._reset_timer(thread_id)
|
||||||
self._reset_timer()
|
|
||||||
|
|
||||||
def _reset_timer(self) -> None:
|
def _reset_timer(self, thread_id: str) -> None:
|
||||||
config = get_thread_memory_config()
|
config = get_thread_memory_config()
|
||||||
if self._timer is not None:
|
timer = self._timers.get(thread_id)
|
||||||
self._timer.cancel()
|
if timer is not None:
|
||||||
self._timer = threading.Timer(config.debounce_seconds, self._process_queue)
|
timer.cancel()
|
||||||
self._timer.daemon = True
|
timer = threading.Timer(config.debounce_seconds, self._process_thread, args=(thread_id,))
|
||||||
self._timer.start()
|
timer.daemon = True
|
||||||
|
self._timers[thread_id] = timer
|
||||||
|
timer.start()
|
||||||
|
|
||||||
def _process_queue(self) -> None:
|
def _process_thread(self, thread_id: str) -> None:
|
||||||
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
|
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._processing:
|
if thread_id in self._processing_threads:
|
||||||
self._reset_timer()
|
self._reset_timer(thread_id)
|
||||||
return
|
return
|
||||||
if not self._queue:
|
context = self._queue_by_thread.pop(thread_id, None)
|
||||||
|
if context is None:
|
||||||
|
self._timers.pop(thread_id, None)
|
||||||
return
|
return
|
||||||
self._processing = True
|
self._processing_threads.add(thread_id)
|
||||||
contexts = self._queue.copy()
|
self._timers.pop(thread_id, None)
|
||||||
self._queue.clear()
|
|
||||||
self._timer = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updater = ThreadMemoryUpdater()
|
updater = ThreadMemoryUpdater()
|
||||||
for context in contexts:
|
updater.update_memory(context.messages, context.thread_id)
|
||||||
updater.update_memory(context.messages, context.thread_id)
|
|
||||||
finally:
|
finally:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._processing = False
|
self._processing_threads.discard(thread_id)
|
||||||
|
|
||||||
|
|
||||||
_thread_queue: ThreadMemoryUpdateQueue | None = None
|
_thread_queue: ThreadMemoryUpdateQueue | None = None
|
||||||
|
|||||||
@ -40,6 +40,13 @@ class ThreadMemoryUpdater:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]:
|
def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]:
|
||||||
|
def safe_confidence(val: Any, default: float = 0.5) -> float:
    """Coerce a confidence value into a float clamped to [0.0, 1.0].

    Returns ``default`` when ``val`` cannot be turned into a usable number:
    an unsupported type, a non-numeric string, an integer too large for
    ``float``, or NaN.
    """
    import math

    try:
        parsed = float(val)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int beyond the float range.
        return default
    # NaN fails every comparison, so the clamp below would turn it into 1.0.
    if math.isnan(parsed):
        return default
    return max(0.0, min(1.0, parsed))
|
||||||
|
|
||||||
def safe_text(val: Any) -> str | None:
|
def safe_text(val: Any) -> str | None:
|
||||||
if not isinstance(val, str):
|
if not isinstance(val, str):
|
||||||
return None
|
return None
|
||||||
@ -89,13 +96,13 @@ class ThreadMemoryUpdater:
|
|||||||
if key in seen:
|
if key in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(key)
|
seen.add(key)
|
||||||
confidence = float(fact.get("confidence", 0.5))
|
confidence = safe_confidence(fact.get("confidence", 0.5))
|
||||||
cleaned["facts"].append(
|
cleaned["facts"].append(
|
||||||
{
|
{
|
||||||
"id": f"fact_{uuid.uuid4().hex[:8]}",
|
"id": f"fact_{uuid.uuid4().hex[:8]}",
|
||||||
"content": content,
|
"content": content,
|
||||||
"category": str(fact.get("category", "context")).strip() or "context",
|
"category": str(fact.get("category", "context")).strip() or "context",
|
||||||
"confidence": max(0.0, min(1.0, confidence)),
|
"confidence": confidence,
|
||||||
"createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
"createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
||||||
"source": thread_id,
|
"source": thread_id,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, format_thread_memory_for_injection
|
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, format_thread_memory_for_injection
|
||||||
|
|
||||||
@ -51,3 +51,31 @@ def test_build_thread_memory_prompt_prefers_japanese_for_japanese_conversation()
|
|||||||
[HumanMessage(content="私は日本語で会話したいです。")],
|
[HumanMessage(content="私は日本語で会話したいです。")],
|
||||||
)
|
)
|
||||||
assert "Preferred memory language: ja-JP" in prompt
|
assert "Preferred memory language: ja-JP" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_thread_memory_prompt_uses_user_messages_only_for_language_inference():
    """A Chinese user turn followed by English assistant turns still yields zh-Hans."""
    empty_memory = {"user": {}, "history": {}, "facts": []}
    conversation = [
        HumanMessage(content="请用中文记录记忆"),
        AIMessage(content="Sure, I will answer in English with many many words."),
        AIMessage(content="More English content that should not change language inference."),
    ]

    rendered = build_thread_memory_prompt(empty_memory, conversation)

    assert "Preferred memory language: zh-Hans" in rendered
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_thread_memory_prompt_handles_structured_human_content():
    """A human message made of text blocks (list content) still yields zh-Hans."""
    empty_memory = {"user": {}, "history": {}, "facts": []}
    text_blocks = [
        {"type": "text", "text": "我希望记忆使用中文。"},
        {"type": "text", "text": "请继续。"},
    ]
    conversation = [
        HumanMessage(content=text_blocks),
        AIMessage(content="I can also reply in English."),
    ]

    rendered = build_thread_memory_prompt(empty_memory, conversation)

    assert "Preferred memory language: zh-Hans" in rendered
|
||||||
|
|||||||
33
backend/tests/test_thread_memory_queue.py
Normal file
33
backend/tests/test_thread_memory_queue.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from deerflow.agents.memory.thread_queue import ThreadMemoryUpdateQueue
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_queue_keeps_latest_message_per_thread():
    """Re-adding a thread replaces its pending context and leaves other threads alone."""
    queue = ThreadMemoryUpdateQueue()
    # add() returns early when the feature is disabled, so pin the config the
    # queue module looks up instead of relying on ambient settings. Timers are
    # stubbed out so no background thread is started.
    with patch("deerflow.agents.memory.thread_queue.get_thread_memory_config") as get_config, patch.object(queue, "_reset_timer"):
        get_config.return_value.enabled = True
        queue.add("thread-a", ["msg-1"])
        queue.add("thread-b", ["msg-2"])
        queue.add("thread-a", ["msg-3"])

    assert set(queue._queue_by_thread.keys()) == {"thread-a", "thread-b"}
    assert queue._queue_by_thread["thread-a"].messages == ["msg-3"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_queue_processes_single_thread_without_affecting_others():
    """Processing one thread updates only that thread and keeps the other queued."""
    queue = ThreadMemoryUpdateQueue()
    # add() returns early when the feature is disabled, so pin the config the
    # queue module looks up instead of relying on ambient settings. Timers are
    # stubbed out so no background thread is started.
    with patch("deerflow.agents.memory.thread_queue.get_thread_memory_config") as get_config, patch.object(queue, "_reset_timer"):
        get_config.return_value.enabled = True
        queue.add("thread-a", ["a-msg"])
        queue.add("thread-b", ["b-msg"])

    updater_calls: list[tuple[list[str], str]] = []

    class _FakeUpdater:
        def update_memory(self, messages, thread_id):
            updater_calls.append((messages, thread_id))

    # _process_thread imports the updater at call time, so patching the
    # attribute on its defining module is what the queue will pick up.
    with patch("deerflow.agents.memory.thread_updater.ThreadMemoryUpdater", _FakeUpdater):
        queue._process_thread("thread-a")

    assert updater_calls == [(["a-msg"], "thread-a")]
    assert "thread-b" in queue._queue_by_thread
    # The processed thread leaves the queue and its in-flight marker is cleared.
    assert "thread-a" not in queue._queue_by_thread
    assert "thread-a" not in queue._processing_threads
|
||||||
20
backend/tests/test_thread_memory_updater.py
Normal file
20
backend/tests/test_thread_memory_updater.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
|
||||||
|
|
||||||
|
|
||||||
|
def test_scrub_sensitive_tolerates_non_numeric_confidence():
    """Facts with a non-numeric confidence are kept and come back with 0.5."""
    raw_memory = {
        "user": {},
        "history": {},
        "facts": [
            {"content": "Uses React", "category": "knowledge", "confidence": "high"},
            {"content": "Uses TypeScript", "category": "knowledge", "confidence": None},
        ],
    }

    scrubbed = ThreadMemoryUpdater()._scrub_sensitive(raw_memory, "thread-test")

    # Both facts survive, each carrying the fallback confidence.
    confidences = [fact["confidence"] for fact in scrubbed["facts"]]
    assert confidences == [0.5, 0.5]
|
||||||
Loading…
Reference in New Issue
Block a user