deerflow2/backend/packages/harness/deerflow/agents/memory/thread_summary.py

301 lines
10 KiB
Python

"""Thread memory summary generation and application helpers."""
from __future__ import annotations

import copy
import hashlib
import json
import logging
import re
from typing import Any

from deerflow.agents.memory.thread_prompt import create_empty_thread_memory
from deerflow.agents.memory.thread_storage import get_thread_memory_storage
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
from deerflow.agents.memory.updater import _extract_text
from deerflow.config.thread_memory_config import get_thread_memory_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
# Prompt used by render_thread_memory_summary: asks the summary model to turn the
# stored memory (its "user", "history" and "facts" sections serialized as JSON)
# into plain text the user can edit. Filled via str.format(memory_json=...).
SUMMARY_RENDER_PROMPT = """You are an assistant that renders thread memory into natural language.
Thread memory JSON:
<memory_json>
{memory_json}
</memory_json>
Task:
- Output a concise, human-friendly editable profile summary.
- Keep the original language of the memory content where possible.
- Cover user profile, history, and key facts.
- Return plain text only (no markdown code fences).
"""
# Prompt used by apply_thread_memory_summary: asks the summary model to convert
# the user's edited text back into a structured JSON patch against the current
# memory. Filled via str.format(current_memory_json=..., edited_summary=...), so
# every literal brace in the schema below is doubled ({{ / }}). Braces inside the
# substituted values are not interpreted by str.format, so user text is safe here.
# When merging, only the "summary" strings under user/history entries and the
# "facts" list of the returned patch are consumed.
SUMMARY_PARSE_PROMPT = """You convert user-edited natural-language memory into a structured patch JSON.
Current thread memory JSON:
<current_memory_json>
{current_memory_json}
</current_memory_json>
Edited summary text:
<edited_summary>
{edited_summary}
</edited_summary>
Return JSON only with this schema (all fields optional):
{{
"user": {{
"workContext": {{"summary": string}},
"personalContext": {{"summary": string}},
"topOfMind": {{"summary": string}}
}},
"history": {{
"recentMonths": {{"summary": string}},
"earlierContext": {{"summary": string}},
"longTermBackground": {{"summary": string}}
}},
"facts": [
{{
"content": string,
"category": "preference"|"knowledge"|"context"|"behavior"|"goal"|"correction",
"confidence": number
}}
]
}}
"""
class ThreadMemoryConflictError(RuntimeError):
    """Signal that a compare-and-swap save of thread memory was rejected.

    The storage refused the write because the version the caller expected no
    longer matches the stored one (typically another writer saved first).
    """
def _get_summary_model():
    """Create the chat model used to render and parse thread-memory summaries.

    The model name is taken from the thread-memory config; the model is built
    with ``thinking_enabled=False`` and ``stream_usage=False``.
    """
    model_name = get_thread_memory_config().model_name
    return create_chat_model(
        name=model_name,
        thinking_enabled=False,
        stream_usage=False,
    )
def _strip_code_fence(text: str) -> str:
cleaned = text.strip()
if not cleaned.startswith("```"):
return cleaned
lines = cleaned.split("\n")
return "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]).strip()
def _loads_dict_with_repair(candidate: str, *, mode: str) -> dict[str, Any] | None:
    """Decode *candidate* as a JSON object, retrying once after escaping stray inner quotes.

    Returns the decoded dict, or ``None`` if neither the original nor the
    repaired text decodes to a JSON object. *mode* is only used to label the
    log line emitted when the repair was needed.
    """
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        pass
    else:
        return parsed if isinstance(parsed, dict) else None

    repaired = _escape_inner_quotes_in_json_strings(candidate)
    if repaired == candidate:
        # The repair found nothing to change, so a retry would fail identically.
        return None
    try:
        parsed = json.loads(repaired)
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict):
        return None
    logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=%s", mode)
    return parsed


def _extract_json_object(text: str) -> dict[str, Any] | None:
    """Best-effort extraction of a JSON object from raw model output.

    Candidates are tried in order, each first as-is and then after
    ``_escape_inner_quotes_in_json_strings`` repairs unescaped inner quotes:

    1. the whole text with any surrounding code fence removed;
    2. the span from the first ``{`` to the last ``}`` in that text.

    The second candidate is also tried when the whole text is valid JSON but
    not an object (e.g. an array wrapping the object).

    Args:
        text: Raw model output.

    Returns:
        The first candidate that decodes to a ``dict``, or ``None``.
    """
    cleaned = _strip_code_fence(text)
    parsed = _loads_dict_with_repair(cleaned, mode="full_text")
    if parsed is not None:
        return parsed
    # Greedy on purpose: first "{" to last "}" keeps nested objects intact.
    match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
    if match is None:
        return None
    return _loads_dict_with_repair(match.group(0), mode="regex_object")
def _escape_inner_quotes_in_json_strings(text: str) -> str:
"""Heuristically repair unescaped inner double quotes inside JSON strings.
If a quote appears while inside a string but the next non-space character is
not a valid string terminator (comma, object/array close, or key colon), it is
treated as content and escaped.
"""
out: list[str] = []
in_string = False
escape = False
n = len(text)
i = 0
while i < n:
ch = text[i]
if not in_string:
out.append(ch)
if ch == '"':
in_string = True
i += 1
continue
if escape:
out.append(ch)
escape = False
i += 1
continue
if ch == "\\":
out.append(ch)
escape = True
i += 1
continue
if ch == '"':
j = i + 1
while j < n and text[j].isspace():
j += 1
next_char = text[j] if j < n else ""
# Valid JSON string terminators in context:
# - key string: :
# - value string: , } ]
if next_char in {":", ",", "}", "]", ""}:
out.append(ch)
in_string = False
else:
out.append('\\"')
i += 1
continue
out.append(ch)
i += 1
return "".join(out)
def _merge_summary_patch(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
    """Return a new thread-memory dict built from *base* with *patch* applied.

    The result starts from ``create_empty_thread_memory()`` plus *base*'s
    ``ownerId``. Its ``user`` and ``history`` sections hold the template's
    entries overlaid with *base*'s entries. That way a key missing from an older
    stored memory (e.g. ``topOfMind``) can still be patched, assuming the
    template defines it. From *patch*:

    * ``user`` / ``history``: for each key present in the section, a string
      ``summary`` replaces that entry's summary. Unknown keys and non-string
      summaries are ignored, and a malformed (non-dict) stored entry is
      replaced.
    * ``facts``: a list replaces the fact list wholesale.

    Neither *base* nor *patch* is mutated. *base*'s nested data is deep-copied,
    because *base* may be the object returned by storage and the save that
    follows can still fail.

    Args:
        base: Current thread memory.
        patch: Parsed patch (see ``SUMMARY_PARSE_PROMPT`` for its shape).

    Returns:
        The merged memory dict.
    """
    template = create_empty_thread_memory()
    merged: dict[str, Any] = {"ownerId": base.get("ownerId"), **template}

    for section_name in ("user", "history"):
        section: dict[str, Any] = {}
        template_section = template.get(section_name)
        if isinstance(template_section, dict):
            section.update(copy.deepcopy(template_section))
        base_section = base.get(section_name)
        if isinstance(base_section, dict):
            # Deep copy: the nested entries are written below and must not leak
            # back into the caller's memory object.
            section.update(copy.deepcopy(base_section))
        merged[section_name] = section

    base_facts = base.get("facts")
    merged["facts"] = copy.deepcopy(base_facts) if isinstance(base_facts, list) else []

    for section_name in ("user", "history"):
        section_patch = patch.get(section_name, {})
        if not isinstance(section_patch, dict):
            continue
        target = merged[section_name]
        for key, value in section_patch.items():
            if key not in target or not isinstance(value, dict):
                continue
            summary = value.get("summary")
            if not isinstance(summary, str):
                continue
            entry = target[key]
            if isinstance(entry, dict):
                entry["summary"] = summary
            else:
                # Stored entry is malformed (not a dict): replace it instead of crashing.
                target[key] = {"summary": summary}

    facts_patch = patch.get("facts")
    if isinstance(facts_patch, list):
        merged["facts"] = list(facts_patch)
    return merged
def render_thread_memory_summary(thread_id: str) -> dict[str, Any]:
    """Render the stored memory of *thread_id* as editable natural-language text.

    Loads the thread memory, falling back to an empty memory when nothing is
    stored. The summary model then describes the ``user``, ``history`` and
    ``facts`` sections, and any code fence is stripped from its reply.

    Returns:
        ``{"threadId": thread_id, "memoryVersion": int, "summary": str}``. The
        version is the stored ``memoryVersion`` (0 if absent). Presumably this
        is what a client passes back as ``expected_version`` when applying an
        edited summary.
    """
    stored = get_thread_memory_storage().load(thread_id)
    if stored is None:
        stored = {"ownerId": None, **create_empty_thread_memory()}

    sections = {}
    sections["user"] = stored.get("user", {})
    sections["history"] = stored.get("history", {})
    sections["facts"] = stored.get("facts", [])
    memory_json = json.dumps(sections, ensure_ascii=False, indent=2)

    reply = _get_summary_model().invoke(SUMMARY_RENDER_PROMPT.format(memory_json=memory_json))
    summary_text = _strip_code_fence(_extract_text(reply.content))

    result: dict[str, Any] = {}
    result["threadId"] = thread_id
    result["memoryVersion"] = int(stored.get("memoryVersion", 0))
    result["summary"] = summary_text
    return result
def _log_summary_parse_failure(thread_id: str, raw: str) -> None:
    """Log why no JSON object could be extracted from the model's patch output."""
    stripped = _strip_code_fence(raw)
    try:
        json.loads(stripped)
    except json.JSONDecodeError as exc:
        # Show the decoder's failure position with up to 80 chars of context on each side.
        logger.warning(
            "THREAD_SUMMARY_DEBUG parse_error thread=%s msg=%s line=%d col=%d pos=%d snippet=%r",
            thread_id,
            exc.msg,
            exc.lineno,
            exc.colno,
            exc.pos,
            stripped[max(0, exc.pos - 80): exc.pos + 80],
        )
    else:
        # The text decodes as JSON, yet no object could be extracted from it.
        logger.warning(
            "THREAD_SUMMARY_DEBUG parse_error thread=%s msg=no_json_object_extracted raw_head=%r",
            thread_id,
            stripped[:200],
        )


def apply_thread_memory_summary(thread_id: str, edited_summary: str, expected_version: int) -> dict[str, Any]:
    """Apply a user-edited natural-language summary back onto a thread's memory.

    The summary model converts *edited_summary* into a JSON patch against the
    current memory (see ``SUMMARY_PARSE_PROMPT``). If no JSON object can be
    extracted, the edited text is stored verbatim as ``user.topOfMind.summary``.

    The merged memory is passed through ``ThreadMemoryUpdater._scrub_sensitive``,
    its ``ownerId`` is pinned to the stored owner, and it is saved with
    ``expected_version`` as a compare-and-swap guard.

    Args:
        thread_id: Thread whose memory is updated.
        edited_summary: Free text edited by the user.
        expected_version: Memory version the edit was based on.

    Returns:
        The memory reloaded from storage after the save. If the reload returns
        ``None``, the saved dict is returned with ``threadId`` and
        ``memoryVersion`` (= *expected_version*) keys added.

    Raises:
        ThreadMemoryConflictError: if ``storage.save`` reports failure,
            presumably because the stored version differs from *expected_version*.
    """
    storage = get_thread_memory_storage()
    current = storage.load(thread_id)
    base = {"ownerId": None, **create_empty_thread_memory()} if current is None else current
    memory_payload = {
        "user": base.get("user", {}),
        "history": base.get("history", {}),
        "facts": base.get("facts", []),
    }
    prompt = SUMMARY_PARSE_PROMPT.format(
        current_memory_json=json.dumps(memory_payload, ensure_ascii=False, indent=2),
        edited_summary=edited_summary,
    )
    raw = _extract_text(_get_summary_model().invoke(prompt).content)

    # Success-path diagnostics contain user memory content, so they go to DEBUG
    # and are only computed when DEBUG is enabled.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        logger.debug(
            "THREAD_SUMMARY_DEBUG parse_raw_meta thread=%s raw_length=%d raw_sha256=%s",
            thread_id,
            len(raw),
            hashlib.sha256(raw.encode("utf-8")).hexdigest(),
        )

    patch = _extract_json_object(raw)
    if patch is None:
        _log_summary_parse_failure(thread_id, raw)
        logger.warning("THREAD_SUMMARY_DEBUG parse_fallback thread=%s", thread_id)
        # No structured patch available: keep the user's edit verbatim as top-of-mind text.
        patch = {"user": {"topOfMind": {"summary": edited_summary.strip()}}}
    elif debug_enabled:
        logger.debug(
            "THREAD_SUMMARY_DEBUG parse_success thread=%s patch=%s",
            thread_id,
            json.dumps(patch, ensure_ascii=False)[:2000],
        )

    merged = _merge_summary_patch(base, patch)
    scrubbed = ThreadMemoryUpdater()._scrub_sensitive(merged, thread_id)
    # Always keep the stored owner, whatever the scrubber returned.
    scrubbed["ownerId"] = base.get("ownerId")

    if debug_enabled:
        facts = scrubbed.get("facts")
        logger.debug(
            "THREAD_SUMMARY_DEBUG apply_cleaned thread=%s cleaned=%s",
            thread_id,
            json.dumps(
                {
                    "user": scrubbed.get("user", {}),
                    "history": scrubbed.get("history", {}),
                    "facts_count": len(facts) if isinstance(facts, list) else 0,
                },
                ensure_ascii=False,
            )[:2000],
        )

    if not storage.save(thread_id, scrubbed, expected_version=expected_version):
        raise ThreadMemoryConflictError(f"Thread memory version conflict for {thread_id}")
    latest = storage.load(thread_id)
    # NOTE(review): after a successful save the stored version has presumably
    # advanced, so expected_version here may be stale. Confirm against storage.
    return latest if latest is not None else {"threadId": thread_id, "memoryVersion": expected_version, **scrubbed}