"""Per-thread memory updater.""" from __future__ import annotations import json import logging import re import uuid from datetime import UTC, datetime from typing import Any from deerflow.agents.memory.updater import _extract_text from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, create_empty_thread_memory from deerflow.agents.memory.thread_storage import get_thread_memory_storage from deerflow.config.thread_memory_config import get_thread_memory_config from deerflow.models import create_chat_model logger = logging.getLogger(__name__) _SENSITIVE_PATTERNS = ( re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"), re.compile(r"\b(?:\+?\d[\d -]{7,}\d)\b"), re.compile(r"\b(?:api[_-]?key|token|password|passwd|secret)\b", re.IGNORECASE), re.compile(r"\b\d{15,19}\b"), # bank-card like ) class ThreadMemoryUpdater: def __init__(self, model_name: str | None = None): self._model_name = model_name def _get_model(self): config = get_thread_memory_config() # Non-stream invoke path: some OpenAI-compatible gateways reject # stream_options when stream=false, so force stream_usage off here. return create_chat_model( name=self._model_name or config.model_name, thinking_enabled=False, stream_usage=False, ) def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]: def safe_confidence(val: Any, default: float = 0.5) -> float: try: parsed = float(val) except (TypeError, ValueError): return default return max(0.0, min(1.0, parsed)) def safe_text(val: Any) -> str | None: if not isinstance(val, str): return None text = val.strip() if not text: return None if any(p.search(text) for p in _SENSITIVE_PATTERNS): logger.info("thread_memory sensitive value dropped for thread=%s", thread_id) return None return text user = data.get("user", {}) history = data.get("history", {}) facts = data.get("facts", []) cleaned = create_empty_thread_memory() def copy_summary_section(target_parent: dict[str, Any], target_key: str, source_parent: Any): if not isinstance(source_parent, dict): return source_section = source_parent.get(target_key) if not isinstance(source_section, dict): return summary = safe_text(source_section.get("summary")) updated_at = safe_text(source_section.get("updatedAt")) if summary: target_parent[target_key]["summary"] = summary if updated_at: target_parent[target_key]["updatedAt"] = updated_at elif summary: target_parent[target_key]["updatedAt"] = datetime.now(UTC).isoformat().replace("+00:00", "Z") copy_summary_section(cleaned["user"], "workContext", user) copy_summary_section(cleaned["user"], "personalContext", user) copy_summary_section(cleaned["user"], "topOfMind", user) copy_summary_section(cleaned["history"], "recentMonths", history) copy_summary_section(cleaned["history"], "earlierContext", history) copy_summary_section(cleaned["history"], "longTermBackground", history) seen: set[str] = set() for fact in facts if isinstance(facts, list) else []: if not isinstance(fact, dict): continue content = safe_text(fact.get("content")) if not content: continue key = content.casefold() if key in seen: continue seen.add(key) confidence = safe_confidence(fact.get("confidence", 0.5)) cleaned["facts"].append( { "id": f"fact_{uuid.uuid4().hex[:8]}", "content": content, "category": str(fact.get("category", "context")).strip() or "context", "confidence": confidence, "createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"), "source": thread_id, } ) return cleaned def update_memory(self, messages: list[Any], thread_id: str) -> bool: config = get_thread_memory_config() if not config.enabled or not messages or not thread_id: return False storage = get_thread_memory_storage() current = storage.load(thread_id) base_memory = create_empty_thread_memory() if current is None else { "user": current.get("user", {}), "history": current.get("history", {}), "facts": current.get("facts", []), } prompt = build_thread_memory_prompt(base_memory, messages) if not prompt.strip(): return False try: response = self._get_model().invoke(prompt) response_text = _extract_text(response.content).strip() if response_text.startswith("```"): lines = response_text.split("\n") response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) parsed = json.loads(response_text) cleaned = self._scrub_sensitive(parsed, thread_id) expected_version = 0 if current is None else int(current.get("memoryVersion", 0)) if storage.save(thread_id, cleaned, expected_version=expected_version): return True # conflict retry once latest = storage.load(thread_id) latest_version = 0 if latest is None else int(latest.get("memoryVersion", 0)) logger.info("thread_memory conflict detected, retrying once: thread=%s version=%s", thread_id, latest_version) return storage.save(thread_id, cleaned, expected_version=latest_version) except Exception: logger.exception("Thread memory update failed for thread=%s", thread_id) return False