149 lines
6.1 KiB
Python
149 lines
6.1 KiB
Python
"""Per-thread memory updater."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
from deerflow.agents.memory.updater import _extract_text
|
|
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, create_empty_thread_memory
|
|
from deerflow.agents.memory.thread_storage import get_thread_memory_storage
|
|
from deerflow.config.thread_memory_config import get_thread_memory_config
|
|
from deerflow.models import create_chat_model
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Heuristic detectors for values that must not be persisted into thread
# memory. ThreadMemoryUpdater._scrub_sensitive drops any text value that
# matches at least one of these patterns.
_SENSITIVE_PATTERNS = tuple(
    re.compile(source, flags)
    for source, flags in (
        # e-mail addresses
        (r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", 0),
        # phone-number-like runs: at least 9 chars of digits/spaces/hyphens,
        # starting and ending with a digit (optional leading '+')
        (r"\b(?:\+?\d[\d -]{7,}\d)\b", 0),
        # credential-related keywords, matched case-insensitively
        (r"\b(?:api[_-]?key|token|password|passwd|secret)\b", re.IGNORECASE),
        # bank-card like
        (r"\b\d{15,19}\b", 0),
    )
)
|
|
|
|
|
|
class ThreadMemoryUpdater:
|
|
def __init__(self, model_name: str | None = None):
|
|
self._model_name = model_name
|
|
|
|
def _get_model(self):
|
|
config = get_thread_memory_config()
|
|
# Non-stream invoke path: some OpenAI-compatible gateways reject
|
|
# stream_options when stream=false, so force stream_usage off here.
|
|
return create_chat_model(
|
|
name=self._model_name or config.model_name,
|
|
thinking_enabled=False,
|
|
stream_usage=False,
|
|
)
|
|
|
|
def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]:
|
|
def safe_confidence(val: Any, default: float = 0.5) -> float:
|
|
try:
|
|
parsed = float(val)
|
|
except (TypeError, ValueError):
|
|
return default
|
|
return max(0.0, min(1.0, parsed))
|
|
|
|
def safe_text(val: Any) -> str | None:
|
|
if not isinstance(val, str):
|
|
return None
|
|
text = val.strip()
|
|
if not text:
|
|
return None
|
|
if any(p.search(text) for p in _SENSITIVE_PATTERNS):
|
|
logger.info("thread_memory sensitive value dropped for thread=%s", thread_id)
|
|
return None
|
|
return text
|
|
|
|
user = data.get("user", {})
|
|
history = data.get("history", {})
|
|
facts = data.get("facts", [])
|
|
cleaned = create_empty_thread_memory()
|
|
|
|
def copy_summary_section(target_parent: dict[str, Any], target_key: str, source_parent: Any):
|
|
if not isinstance(source_parent, dict):
|
|
return
|
|
source_section = source_parent.get(target_key)
|
|
if not isinstance(source_section, dict):
|
|
return
|
|
summary = safe_text(source_section.get("summary"))
|
|
updated_at = safe_text(source_section.get("updatedAt"))
|
|
if summary:
|
|
target_parent[target_key]["summary"] = summary
|
|
if updated_at:
|
|
target_parent[target_key]["updatedAt"] = updated_at
|
|
elif summary:
|
|
target_parent[target_key]["updatedAt"] = datetime.now(UTC).isoformat().replace("+00:00", "Z")
|
|
|
|
copy_summary_section(cleaned["user"], "workContext", user)
|
|
copy_summary_section(cleaned["user"], "personalContext", user)
|
|
copy_summary_section(cleaned["user"], "topOfMind", user)
|
|
copy_summary_section(cleaned["history"], "recentMonths", history)
|
|
copy_summary_section(cleaned["history"], "earlierContext", history)
|
|
copy_summary_section(cleaned["history"], "longTermBackground", history)
|
|
|
|
seen: set[str] = set()
|
|
for fact in facts if isinstance(facts, list) else []:
|
|
if not isinstance(fact, dict):
|
|
continue
|
|
content = safe_text(fact.get("content"))
|
|
if not content:
|
|
continue
|
|
key = content.casefold()
|
|
if key in seen:
|
|
continue
|
|
seen.add(key)
|
|
confidence = safe_confidence(fact.get("confidence", 0.5))
|
|
cleaned["facts"].append(
|
|
{
|
|
"id": f"fact_{uuid.uuid4().hex[:8]}",
|
|
"content": content,
|
|
"category": str(fact.get("category", "context")).strip() or "context",
|
|
"confidence": confidence,
|
|
"createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
|
"source": thread_id,
|
|
}
|
|
)
|
|
return cleaned
|
|
|
|
def update_memory(self, messages: list[Any], thread_id: str) -> bool:
|
|
config = get_thread_memory_config()
|
|
if not config.enabled or not messages or not thread_id:
|
|
return False
|
|
|
|
storage = get_thread_memory_storage()
|
|
current = storage.load(thread_id)
|
|
base_memory = create_empty_thread_memory() if current is None else {
|
|
"user": current.get("user", {}),
|
|
"history": current.get("history", {}),
|
|
"facts": current.get("facts", []),
|
|
}
|
|
prompt = build_thread_memory_prompt(base_memory, messages)
|
|
if not prompt.strip():
|
|
return False
|
|
|
|
try:
|
|
response = self._get_model().invoke(prompt)
|
|
response_text = _extract_text(response.content).strip()
|
|
if response_text.startswith("```"):
|
|
lines = response_text.split("\n")
|
|
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
|
parsed = json.loads(response_text)
|
|
cleaned = self._scrub_sensitive(parsed, thread_id)
|
|
|
|
expected_version = 0 if current is None else int(current.get("memoryVersion", 0))
|
|
if storage.save(thread_id, cleaned, expected_version=expected_version):
|
|
return True
|
|
|
|
# conflict retry once
|
|
latest = storage.load(thread_id)
|
|
latest_version = 0 if latest is None else int(latest.get("memoryVersion", 0))
|
|
logger.info("thread_memory conflict detected, retrying once: thread=%s version=%s", thread_id, latest_version)
|
|
return storage.save(thread_id, cleaned, expected_version=latest_version)
|
|
except Exception:
|
|
logger.exception("Thread memory update failed for thread=%s", thread_id)
|
|
return False
|