301 lines
10 KiB
Python
301 lines
10 KiB
Python
"""Thread memory summary generation and application helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import hashlib
|
|
from typing import Any
|
|
|
|
from deerflow.agents.memory.thread_prompt import create_empty_thread_memory
|
|
from deerflow.agents.memory.thread_storage import get_thread_memory_storage
|
|
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
|
|
from deerflow.agents.memory.updater import _extract_text
|
|
from deerflow.config.thread_memory_config import get_thread_memory_config
|
|
from deerflow.models import create_chat_model
|
|
|
|
# Module-level logger; every helper in this file logs through it.
logger = logging.getLogger(__name__)
|
|
|
|
SUMMARY_RENDER_PROMPT = """You are an assistant that renders thread memory into natural language.
|
|
|
|
Thread memory JSON:
|
|
<memory_json>
|
|
{memory_json}
|
|
</memory_json>
|
|
|
|
Task:
|
|
- Output a concise, human-friendly editable profile summary.
|
|
- Keep the original language of the memory content where possible.
|
|
- Cover user profile, history, and key facts.
|
|
- Return plain text only (no markdown code fences).
|
|
"""
|
|
|
|
SUMMARY_PARSE_PROMPT = """You convert user-edited natural-language memory into a structured patch JSON.
|
|
|
|
Current thread memory JSON:
|
|
<current_memory_json>
|
|
{current_memory_json}
|
|
</current_memory_json>
|
|
|
|
Edited summary text:
|
|
<edited_summary>
|
|
{edited_summary}
|
|
</edited_summary>
|
|
|
|
Return JSON only with this schema (all fields optional):
|
|
{{
|
|
"user": {{
|
|
"workContext": {{"summary": string}},
|
|
"personalContext": {{"summary": string}},
|
|
"topOfMind": {{"summary": string}}
|
|
}},
|
|
"history": {{
|
|
"recentMonths": {{"summary": string}},
|
|
"earlierContext": {{"summary": string}},
|
|
"longTermBackground": {{"summary": string}}
|
|
}},
|
|
"facts": [
|
|
{{
|
|
"content": string,
|
|
"category": "preference"|"knowledge"|"context"|"behavior"|"goal"|"correction",
|
|
"confidence": number
|
|
}}
|
|
]
|
|
}}
|
|
"""
|
|
|
|
|
|
class ThreadMemoryConflictError(RuntimeError):
    """Raised when compare-and-swap save fails due to version mismatch.

    ``apply_thread_memory_summary`` raises this when the storage ``save`` call,
    invoked with the caller's ``expected_version``, returns a falsy result.
    """
|
|
|
|
|
|
def _get_summary_model():
    """Create the chat model used to render and parse thread-memory summaries.

    The model name comes from the thread-memory configuration; both
    ``thinking_enabled`` and ``stream_usage`` are passed as ``False``.
    """
    model_name = get_thread_memory_config().model_name
    return create_chat_model(
        name=model_name,
        thinking_enabled=False,
        stream_usage=False,
    )
|
|
|
|
|
|
def _strip_code_fence(text: str) -> str:
|
|
cleaned = text.strip()
|
|
if not cleaned.startswith("```"):
|
|
return cleaned
|
|
lines = cleaned.split("\n")
|
|
return "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]).strip()
|
|
|
|
|
|
def _extract_json_object(text: str) -> dict[str, Any] | None:
    """Best-effort extraction of a JSON object from model output.

    After removing a surrounding code fence, the following attempts are made
    in order, and the first one that settles the outcome wins:

    1. Parse the whole text. If it is valid JSON the result is returned when
       it is a dict, otherwise ``None`` (no further attempts are made).
    2. Repair unescaped inner quotes in the whole text and parse again; only a
       dict result is accepted here.
    3. Take the span from the first ``{`` to the last ``}`` and parse it. If
       it is valid JSON the result is returned when it is a dict, otherwise
       ``None``.
    4. Repair unescaped inner quotes in that span and parse once more.

    Returns:
        The parsed dict, or ``None`` when no attempt produced one.
    """
    payload = _strip_code_fence(text)

    # Attempt 1: the payload as-is.
    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        pass
    else:
        return decoded if isinstance(decoded, dict) else None

    # Attempt 2: the payload with inner quotes repaired (skipped when the
    # repair changed nothing, since it would fail identically).
    fixed = _escape_inner_quotes_in_json_strings(payload)
    if fixed != payload:
        try:
            decoded = json.loads(fixed)
        except json.JSONDecodeError:
            decoded = None
        if isinstance(decoded, dict):
            logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=full_text")
            return decoded

    # Attempt 3: the widest brace-delimited span inside the payload.
    found = re.search(r"\{.*\}", payload, flags=re.DOTALL)
    if found is None:
        return None
    candidate = found.group(0)
    try:
        decoded = json.loads(candidate)
    except json.JSONDecodeError:
        pass
    else:
        return decoded if isinstance(decoded, dict) else None

    # Attempt 4: that span with inner quotes repaired.
    fixed = _escape_inner_quotes_in_json_strings(candidate)
    if fixed == candidate:
        return None
    try:
        decoded = json.loads(fixed)
    except json.JSONDecodeError:
        return None
    if not isinstance(decoded, dict):
        return None
    logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=regex_object")
    return decoded
|
|
|
|
|
|
def _escape_inner_quotes_in_json_strings(text: str) -> str:
|
|
"""Heuristically repair unescaped inner double quotes inside JSON strings.
|
|
|
|
If a quote appears while inside a string but the next non-space character is
|
|
not a valid string terminator (comma, object/array close, or key colon), it is
|
|
treated as content and escaped.
|
|
"""
|
|
out: list[str] = []
|
|
in_string = False
|
|
escape = False
|
|
n = len(text)
|
|
i = 0
|
|
while i < n:
|
|
ch = text[i]
|
|
if not in_string:
|
|
out.append(ch)
|
|
if ch == '"':
|
|
in_string = True
|
|
i += 1
|
|
continue
|
|
|
|
if escape:
|
|
out.append(ch)
|
|
escape = False
|
|
i += 1
|
|
continue
|
|
|
|
if ch == "\\":
|
|
out.append(ch)
|
|
escape = True
|
|
i += 1
|
|
continue
|
|
|
|
if ch == '"':
|
|
j = i + 1
|
|
while j < n and text[j].isspace():
|
|
j += 1
|
|
next_char = text[j] if j < n else ""
|
|
# Valid JSON string terminators in context:
|
|
# - key string: :
|
|
# - value string: , } ]
|
|
if next_char in {":", ",", "}", "]", ""}:
|
|
out.append(ch)
|
|
in_string = False
|
|
else:
|
|
out.append('\\"')
|
|
i += 1
|
|
continue
|
|
|
|
out.append(ch)
|
|
i += 1
|
|
|
|
return "".join(out)
|
|
|
|
|
|
def _merge_summary_patch(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
    """Overlay a parsed summary patch onto a copy of the stored thread memory.

    Args:
        base: The current thread memory. It is not modified.
        patch: The patch parsed from the edited summary. Only the ``user`` and
            ``history`` sections and the ``facts`` list are consulted.

    Returns:
        A new memory dict built from ``ownerId`` plus the empty-memory
        template, with the ``user``/``history``/``facts`` of *base* copied in
        and then patched:

        * for ``user`` and ``history``, a patch entry is applied only when its
          key already exists in the section, its value is a dict, and that
          dict carries a string ``summary``; only ``summary`` is replaced;
        * when the patch has a ``facts`` list, it replaces the facts wholesale.
    """
    merged = {"ownerId": base.get("ownerId"), **create_empty_thread_memory()}

    for section_name in ("user", "history"):
        base_section = base.get(section_name)
        if isinstance(base_section, dict):
            merged[section_name] = dict(base_section)

    base_facts = base.get("facts")
    merged["facts"] = list(base_facts) if isinstance(base_facts, list) else []

    for section_name in ("user", "history"):
        section_patch = patch.get(section_name, {})
        if not isinstance(section_patch, dict):
            continue
        section = merged[section_name]
        for key, value in section_patch.items():
            if key not in section or not isinstance(value, dict):
                continue
            summary = value.get("summary")
            if not isinstance(summary, str):
                continue
            # The section is only a shallow copy of base's section, so its
            # entries are still the caller's dicts. Build a fresh entry rather
            # than assigning into the shared one, otherwise `base` would be
            # mutated behind the caller's back. A non-dict entry is replaced
            # instead of raising TypeError on item assignment.
            existing = section[key]
            entry = dict(existing) if isinstance(existing, dict) else {}
            entry["summary"] = summary
            section[key] = entry

    facts_patch = patch.get("facts")
    if isinstance(facts_patch, list):
        merged["facts"] = facts_patch

    return merged
|
|
|
|
|
|
def render_thread_memory_summary(thread_id: str) -> dict[str, Any]:
    """Render the stored memory of *thread_id* as editable natural-language text.

    The ``user``, ``history`` and ``facts`` sections of the stored memory (or
    of an empty memory when nothing is stored) are serialised to JSON, placed
    into ``SUMMARY_RENDER_PROMPT`` and sent to the summary model. Any code
    fence around the model's reply is removed.

    Returns:
        A dict with ``threadId``, ``memoryVersion`` (``0`` when the memory has
        no such key) and the rendered ``summary`` text.
    """
    stored = get_thread_memory_storage().load(thread_id)
    if stored is None:
        memory = {"ownerId": None, **create_empty_thread_memory()}
    else:
        memory = stored

    memory_json = json.dumps(
        {
            "user": memory.get("user", {}),
            "history": memory.get("history", {}),
            "facts": memory.get("facts", []),
        },
        ensure_ascii=False,
        indent=2,
    )
    reply = _get_summary_model().invoke(SUMMARY_RENDER_PROMPT.format(memory_json=memory_json))
    summary_text = _strip_code_fence(_extract_text(reply.content))

    version = int(memory.get("memoryVersion", 0))
    return {"threadId": thread_id, "memoryVersion": version, "summary": summary_text}
|
|
|
|
|
|
def _log_summary_parse_failure(thread_id: str, raw: str) -> None:
    """Log why the model reply for *thread_id* yielded no patch object.

    The fence-stripped reply is parsed once more purely to obtain the decoder
    error position; when it parses cleanly (valid JSON that is not an object)
    only the head of the reply is logged.
    """
    stripped = _strip_code_fence(raw)
    try:
        json.loads(stripped)
    except json.JSONDecodeError as exc:
        snippet_start = max(0, exc.pos - 80)
        logger.warning(
            "THREAD_SUMMARY_DEBUG parse_error thread=%s msg=%s line=%d col=%d pos=%d snippet=%r",
            thread_id,
            exc.msg,
            exc.lineno,
            exc.colno,
            exc.pos,
            stripped[snippet_start : exc.pos + 80],
        )
    else:
        logger.warning(
            "THREAD_SUMMARY_DEBUG parse_error thread=%s msg=no_json_object_extracted raw_head=%r",
            thread_id,
            stripped[:200],
        )
    logger.warning("THREAD_SUMMARY_DEBUG parse_fallback thread=%s", thread_id)


def apply_thread_memory_summary(thread_id: str, edited_summary: str, expected_version: int) -> dict[str, Any]:
    """Convert a user-edited summary back into thread memory and save it.

    Steps:
        1. Load the stored memory (an empty memory when nothing is stored).
        2. Ask the summary model to turn *edited_summary* into a patch JSON.
        3. If no JSON object can be extracted from the reply, fall back to a
           patch that stores the stripped edited text as
           ``user.topOfMind.summary``.
        4. Merge the patch, run the result through
           ``ThreadMemoryUpdater._scrub_sensitive`` and restore ``ownerId``
           from the stored memory.
        5. Save with ``expected_version``.

    Args:
        thread_id: Thread whose memory is updated.
        edited_summary: Natural-language summary as edited by the user.
        expected_version: Version passed to ``storage.save`` for the
            compare-and-swap check.

    Returns:
        The memory re-loaded from storage after the save. When that load
        returns ``None``, a dict with ``threadId``, ``memoryVersion`` set to
        *expected_version*, and the saved content.

    Raises:
        ThreadMemoryConflictError: ``storage.save`` returned a falsy result.
    """
    storage = get_thread_memory_storage()
    stored = storage.load(thread_id)
    base = stored if stored is not None else {"ownerId": None, **create_empty_thread_memory()}

    current_memory_json = json.dumps(
        {
            "user": base.get("user", {}),
            "history": base.get("history", {}),
            "facts": base.get("facts", []),
        },
        ensure_ascii=False,
        indent=2,
    )
    prompt = SUMMARY_PARSE_PROMPT.format(
        current_memory_json=current_memory_json,
        edited_summary=edited_summary,
    )
    reply = _get_summary_model().invoke(prompt)
    raw = _extract_text(reply.content)

    # Only the size and digest of the raw reply are logged at this point.
    logger.warning(
        "THREAD_SUMMARY_DEBUG parse_raw_meta thread=%s raw_length=%d raw_sha256=%s",
        thread_id,
        len(raw),
        hashlib.sha256(raw.encode("utf-8")).hexdigest(),
    )

    patch = _extract_json_object(raw)
    if patch is not None:
        # NOTE(review): this emits patch content at WARNING level before
        # _scrub_sensitive has run - confirm that is acceptable for production.
        logger.warning(
            "THREAD_SUMMARY_DEBUG parse_success thread=%s patch=%s",
            thread_id,
            json.dumps(patch, ensure_ascii=False)[:2000],
        )
    else:
        _log_summary_parse_failure(thread_id, raw)
        # Keep the user's edit rather than discarding it when parsing failed.
        patch = {"user": {"topOfMind": {"summary": edited_summary.strip()}}}

    merged = _merge_summary_patch(base, patch if isinstance(patch, dict) else {})
    scrubbed = ThreadMemoryUpdater()._scrub_sensitive(merged, thread_id)
    scrubbed["ownerId"] = base.get("ownerId")

    facts = scrubbed.get("facts")
    log_view = {
        "user": scrubbed.get("user", {}),
        "history": scrubbed.get("history", {}),
        "facts_count": len(facts) if isinstance(facts, list) else 0,
    }
    logger.warning(
        "THREAD_SUMMARY_DEBUG apply_cleaned thread=%s cleaned=%s",
        thread_id,
        json.dumps(log_view, ensure_ascii=False)[:2000],
    )

    saved = storage.save(thread_id, scrubbed, expected_version=expected_version)
    if not saved:
        raise ThreadMemoryConflictError(f"Thread memory version conflict for {thread_id}")

    latest = storage.load(thread_id)
    if latest is not None:
        return latest
    return {"threadId": thread_id, "memoryVersion": expected_version, **scrubbed}
|