deerflow2/backend/packages/harness/deerflow/agents/memory/thread_updater.py
2026-05-18 16:03:53 +08:00

133 lines
5.5 KiB
Python

"""Per-thread memory updater."""
from __future__ import annotations

import json
import logging
import re
import uuid
from datetime import datetime, timezone
from typing import Any

from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, create_empty_thread_memory
from deerflow.agents.memory.thread_storage import get_thread_memory_storage
from deerflow.agents.memory.updater import _extract_text
from deerflow.config.thread_memory_config import get_thread_memory_config
from deerflow.models import create_chat_model
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Patterns that mark a string as sensitive. A value is dropped by
# ``ThreadMemoryUpdater._scrub_sensitive`` as soon as ANY pattern finds a match
# anywhere in it (``search``, not ``fullmatch``).
_SENSITIVE_PATTERNS = (
    # E-mail addresses: local part, "@", domain, and a TLD of 2+ letters.
    re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"),
    # Phone-number-like runs: an optional "+", a digit, at least seven
    # digits/spaces/hyphens, and a closing digit (nine or more characters).
    # Any bare run of nine or more digits matches this pattern as well.
    re.compile(r"\b(?:\+?\d[\d -]{7,}\d)\b"),
    # Credential keywords, case-insensitive. This flags the keyword itself, so
    # any text that merely mentions e.g. "password" is treated as sensitive.
    re.compile(r"\b(?:api[_-]?key|token|password|passwd|secret)\b", re.IGNORECASE),
    # Bank-card-like numbers: a standalone run of 15 to 19 digits.
    re.compile(r"\b\d{15,19}\b"),  # bank-card like
)
class ThreadMemoryUpdater:
    """Refresh the stored memory of one thread from its conversation.

    The updater asks a chat model for a JSON memory document, removes
    sensitive-looking values from the model output, and persists the result
    through the thread memory storage, passing the version it read as
    ``expected_version``.
    """

    # Fallbacks applied when the model omits a fact field or sends an unusable value.
    _DEFAULT_CONFIDENCE = 0.5
    _DEFAULT_CATEGORY = "context"

    def __init__(self, model_name: str | None = None):
        """Create an updater.

        Args:
            model_name: Chat model to use. When falsy, ``model_name`` from the
                thread memory config is used instead.
        """
        self._model_name = model_name

    def _get_model(self):
        """Build the chat model used for the memory update call."""
        config = get_thread_memory_config()
        # Non-stream invoke path: some OpenAI-compatible gateways reject
        # stream_options when stream=false, so force stream_usage off here.
        return create_chat_model(
            name=self._model_name or config.model_name,
            thinking_enabled=False,
            stream_usage=False,
        )

    @staticmethod
    def _as_dict(value: Any) -> dict[str, Any]:
        """Return *value* if it is a dict, otherwise an empty dict.

        Model output is untrusted: a section such as ``"profile"`` may arrive
        as ``null`` or a string, which must not crash the scrubber.
        """
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _coerce_confidence(value: Any, default: float = 0.5) -> float:
        """Convert *value* to a float clamped to ``[0.0, 1.0]``.

        Args:
            value: Raw confidence taken from the model output.
            default: Returned when *value* cannot be read as a number
                (``None``, non-numeric text, NaN, an integer too large for a
                float).

        Returns:
            The clamped confidence, or *default* for unusable input.
        """
        try:
            confidence = float(value)
        except (TypeError, ValueError, OverflowError):
            return default
        # NaN is the only float that differs from itself; without this guard
        # min()/max() would silently turn it into 1.0.
        if confidence != confidence:
            return default
        return max(0.0, min(1.0, confidence))

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence from *text*, if present.

        The opening fence line (with any language tag) is always dropped; the
        closing line is dropped only when it is exactly the fence marker.
        """
        text = text.strip()
        if not text.startswith("```"):
            return text
        lines = text.split("\n")
        body = lines[1:-1] if lines[-1] == "```" else lines[1:]
        return "\n".join(body)

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string ending in ``Z``."""
        # datetime.utcnow() is deprecated since Python 3.12; use an aware
        # datetime and keep the previous "...Z" output format.
        return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

    def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]:
        """Build a sanitized memory document from raw model output.

        Only known profile/preference fields and well-formed facts are copied
        into a fresh empty memory. String values that are blank or match one of
        ``_SENSITIVE_PATTERNS`` are dropped. Facts are de-duplicated on their
        case-folded content and receive a new id and creation timestamp.

        Args:
            data: Parsed model output. Non-dict input, or non-dict ``profile``
                / ``preferences`` sections, are treated as empty.
            thread_id: Thread identifier; logged when a value is dropped and
                stored as each fact's ``source``.

        Returns:
            The sanitized memory document.
        """

        def safe_text(val: Any) -> str | None:
            """Return the stripped string, or None if unusable or sensitive."""
            if not isinstance(val, str):
                return None
            text = val.strip()
            if not text:
                return None
            if any(p.search(text) for p in _SENSITIVE_PATTERNS):
                logger.info("thread_memory sensitive value dropped for thread=%s", thread_id)
                return None
            return text

        data = self._as_dict(data)
        profile = self._as_dict(data.get("profile"))
        preferences = self._as_dict(data.get("preferences"))
        facts = data.get("facts")

        cleaned = create_empty_thread_memory()
        for field in ("name", "role", "language", "context"):
            cleaned["profile"][field] = safe_text(profile.get(field))
        expertise = profile.get("expertise")
        if isinstance(expertise, list):
            cleaned["profile"]["expertise"] = [x for x in map(safe_text, expertise) if x]
        for field in ("tone", "verbosity", "codeStyle", "other"):
            cleaned["preferences"][field] = safe_text(preferences.get(field))

        # One timestamp for the whole pass: all facts come from the same update.
        created_at = self._utc_timestamp()
        seen: set[str] = set()
        for fact in facts if isinstance(facts, list) else ():
            if not isinstance(fact, dict):
                continue
            content = safe_text(fact.get("content"))
            if not content:
                continue
            key = content.casefold()
            if key in seen:
                continue
            seen.add(key)

            # The category is model-provided free text too: scrub it like every
            # other string, and never store a JSON null as the text "None".
            raw_category = fact.get("category")
            category = None if raw_category is None else safe_text(str(raw_category))

            cleaned["facts"].append(
                {
                    "id": f"fact_{uuid.uuid4().hex[:8]}",
                    "content": content,
                    "category": category or self._DEFAULT_CATEGORY,
                    "confidence": self._coerce_confidence(
                        fact.get("confidence", self._DEFAULT_CONFIDENCE),
                        self._DEFAULT_CONFIDENCE,
                    ),
                    "createdAt": created_at,
                    "source": thread_id,
                }
            )
        return cleaned

    def update_memory(self, messages: list[Any], thread_id: str) -> bool:
        """Regenerate and persist the memory of *thread_id* from *messages*.

        Args:
            messages: Conversation messages handed to the prompt builder.
            thread_id: Identifier of the thread whose memory is updated.

        Returns:
            True when the memory was saved. False when the feature is disabled,
            an input or the prompt is empty, the model output is unusable, the
            save still fails after one retry, or an error occurs while calling
            the model, parsing, or saving.
        """
        config = get_thread_memory_config()
        if not config.enabled or not messages or not thread_id:
            return False

        storage = get_thread_memory_storage()
        current = storage.load(thread_id)
        if current is None:
            base_memory = create_empty_thread_memory()
        else:
            base_memory = {
                "profile": current.get("profile", {}),
                "preferences": current.get("preferences", {}),
                "facts": current.get("facts", []),
            }

        prompt = build_thread_memory_prompt(base_memory, messages)
        if not prompt.strip():
            return False

        try:
            response = self._get_model().invoke(prompt)
            response_text = self._strip_code_fence(_extract_text(response.content))
            parsed = json.loads(response_text)
            if not isinstance(parsed, dict):
                # Valid JSON but not an object (list, string, number): nothing
                # to scrub, and saving an empty memory would wipe stored data.
                logger.warning("Thread memory update returned non-object JSON for thread=%s", thread_id)
                return False
            cleaned = self._scrub_sensitive(parsed, thread_id)

            expected_version = 0 if current is None else int(current.get("memoryVersion", 0))
            if storage.save(thread_id, cleaned, expected_version=expected_version):
                return True

            # conflict retry once
            # NOTE(review): `cleaned` was derived from the snapshot loaded
            # above, so this retry can overwrite whatever a concurrent writer
            # stored in between. Re-running the prompt against `latest` would
            # avoid that, at the cost of a second model call.
            latest = storage.load(thread_id)
            latest_version = 0 if latest is None else int(latest.get("memoryVersion", 0))
            logger.info("thread_memory conflict detected, retrying once: thread=%s version=%s", thread_id, latest_version)
            return bool(storage.save(thread_id, cleaned, expected_version=latest_version))
        except Exception:
            # Best-effort boundary: a failed memory update must not break the caller.
            logger.exception("Thread memory update failed for thread=%s", thread_id)
            return False