refactor(memory): 提取 JSON 工具函数到共享模块

将 thread_summary.py 中的 _strip_code_fence、_extract_json_object、
_escape_inner_quotes_in_json_strings 三个函数提取到新建的
json_utils.py 共享模块,thread_updater.py 同步使用统一接口。
This commit is contained in:
肖应宇 2026-06-11 17:47:07 +08:00
parent c17ba298fb
commit 252cd71fe0
3 changed files with 104 additions and 105 deletions

View File

@ -0,0 +1,95 @@
"""JSON extraction helpers for LLM-generated memory payloads."""
from __future__ import annotations
import json
import re
from typing import Any
def strip_code_fence(text: str) -> str:
cleaned = text.strip()
if not cleaned.startswith("```"):
return cleaned
lines = cleaned.split("\n")
return "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]).strip()
def escape_inner_quotes_in_json_strings(text: str) -> str:
"""Heuristically repair unescaped inner double quotes inside JSON strings."""
out: list[str] = []
in_string = False
escape = False
n = len(text)
i = 0
while i < n:
ch = text[i]
if not in_string:
out.append(ch)
if ch == '"':
in_string = True
i += 1
continue
if escape:
out.append(ch)
escape = False
i += 1
continue
if ch == "\\":
out.append(ch)
escape = True
i += 1
continue
if ch == '"':
j = i + 1
while j < n and text[j].isspace():
j += 1
next_char = text[j] if j < n else ""
if next_char in {":", ",", "}", "]", ""}:
out.append(ch)
in_string = False
else:
out.append('\\"')
i += 1
continue
out.append(ch)
i += 1
return "".join(out)
def extract_json_object(text: str) -> dict[str, Any] | None:
cleaned = strip_code_fence(text)
try:
parsed = json.loads(cleaned)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
repaired = escape_inner_quotes_in_json_strings(cleaned)
if repaired != cleaned:
try:
parsed = json.loads(repaired)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
if not match:
return None
candidate = match.group(0)
try:
parsed = json.loads(candidate)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
repaired = escape_inner_quotes_in_json_strings(candidate)
if repaired != candidate:
try:
parsed = json.loads(repaired)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
return None
return None

View File

@ -4,10 +4,14 @@ from __future__ import annotations
import json import json
import logging import logging
import re
import hashlib import hashlib
from typing import Any from typing import Any
from deerflow.agents.memory.json_utils import (
escape_inner_quotes_in_json_strings as _escape_inner_quotes_in_json_strings,
)
from deerflow.agents.memory.json_utils import extract_json_object as _extract_json_object
from deerflow.agents.memory.json_utils import strip_code_fence as _strip_code_fence
from deerflow.agents.memory.thread_prompt import create_empty_thread_memory from deerflow.agents.memory.thread_prompt import create_empty_thread_memory
from deerflow.agents.memory.thread_storage import get_thread_memory_storage from deerflow.agents.memory.thread_storage import get_thread_memory_storage
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
@ -74,106 +78,6 @@ def _get_summary_model():
config = get_thread_memory_config() config = get_thread_memory_config()
return create_chat_model(name=config.model_name, thinking_enabled=False, stream_usage=False) return create_chat_model(name=config.model_name, thinking_enabled=False, stream_usage=False)
def _strip_code_fence(text: str) -> str:
cleaned = text.strip()
if not cleaned.startswith("```"):
return cleaned
lines = cleaned.split("\n")
return "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]).strip()
def _extract_json_object(text: str) -> dict[str, Any] | None:
cleaned = _strip_code_fence(text)
try:
parsed = json.loads(cleaned)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
repaired = _escape_inner_quotes_in_json_strings(cleaned)
if repaired != cleaned:
try:
parsed = json.loads(repaired)
if isinstance(parsed, dict):
logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=full_text")
return parsed
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
candidate = match.group(0)
repaired = _escape_inner_quotes_in_json_strings(candidate)
if repaired != candidate:
try:
parsed = json.loads(repaired)
if isinstance(parsed, dict):
logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=regex_object")
return parsed
except json.JSONDecodeError:
return None
return None
def _escape_inner_quotes_in_json_strings(text: str) -> str:
"""Heuristically repair unescaped inner double quotes inside JSON strings.
If a quote appears while inside a string but the next non-space character is
not a valid string terminator (comma, object/array close, or key colon), it is
treated as content and escaped.
"""
out: list[str] = []
in_string = False
escape = False
n = len(text)
i = 0
while i < n:
ch = text[i]
if not in_string:
out.append(ch)
if ch == '"':
in_string = True
i += 1
continue
if escape:
out.append(ch)
escape = False
i += 1
continue
if ch == "\\":
out.append(ch)
escape = True
i += 1
continue
if ch == '"':
j = i + 1
while j < n and text[j].isspace():
j += 1
next_char = text[j] if j < n else ""
# Valid JSON string terminators in context:
# - key string: :
# - value string: , } ]
if next_char in {":", ",", "}", "]", ""}:
out.append(ch)
in_string = False
else:
out.append('\\"')
i += 1
continue
out.append(ch)
i += 1
return "".join(out)
def _merge_summary_patch(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]: def _merge_summary_patch(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
merged = {"ownerId": base.get("ownerId"), **create_empty_thread_memory()} merged = {"ownerId": base.get("ownerId"), **create_empty_thread_memory()}
merged["user"] = dict(base.get("user", {})) if isinstance(base.get("user"), dict) else merged["user"] merged["user"] = dict(base.get("user", {})) if isinstance(base.get("user"), dict) else merged["user"]

View File

@ -9,6 +9,7 @@ import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from deerflow.agents.memory.json_utils import extract_json_object
from deerflow.agents.memory.updater import _extract_text from deerflow.agents.memory.updater import _extract_text
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, create_empty_thread_memory from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, create_empty_thread_memory
from deerflow.agents.memory.thread_storage import get_thread_memory_storage from deerflow.agents.memory.thread_storage import get_thread_memory_storage
@ -128,10 +129,9 @@ class ThreadMemoryUpdater:
try: try:
response = self._get_model().invoke(prompt) response = self._get_model().invoke(prompt)
response_text = _extract_text(response.content).strip() response_text = _extract_text(response.content).strip()
if response_text.startswith("```"): parsed = extract_json_object(response_text)
lines = response_text.split("\n") if not isinstance(parsed, dict):
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) raise json.JSONDecodeError("No valid JSON object found", response_text, 0)
parsed = json.loads(response_text)
cleaned = self._scrub_sensitive(parsed, thread_id) cleaned = self._scrub_sensitive(parsed, thread_id)
expected_version = 0 if current is None else int(current.get("memoryVersion", 0)) expected_version = 0 if current is None else int(current.get("memoryVersion", 0))