104 lines
3.8 KiB
Python
104 lines
3.8 KiB
Python
from unittest.mock import patch
|
||
|
||
import pytest
|
||
|
||
from deerflow.agents.memory.thread_summary import (
|
||
ThreadMemoryConflictError,
|
||
_extract_json_object,
|
||
apply_thread_memory_summary,
|
||
render_thread_memory_summary,
|
||
)
|
||
|
||
|
||
def test_render_thread_memory_summary_returns_text():
    """render_thread_memory_summary should echo the stored thread metadata plus the model's text."""

    def _load(self, tid):
        # Minimal stored record for whatever thread id is requested.
        return {"threadId": tid, "user": {}, "history": {}, "facts": [], "memoryVersion": 2}

    def _invoke(self, prompt):
        # Each call builds a fresh response object that only exposes `.content`.
        return type("R", (), {"content": "用户总结"})()

    fake_storage = type("S", (), {"load": _load})()
    fake_model = type("M", (), {"invoke": _invoke})()

    storage_patch = patch(
        "deerflow.agents.memory.thread_summary.get_thread_memory_storage",
        return_value=fake_storage,
    )
    model_patch = patch(
        "deerflow.agents.memory.thread_summary._get_summary_model",
        return_value=fake_model,
    )
    with storage_patch, model_patch:
        result = render_thread_memory_summary("t1")

    assert result["threadId"] == "t1"
    assert result["memoryVersion"] == 2
    assert result["summary"] == "用户总结"
|
||
|
||
|
||
def test_apply_thread_memory_summary_raises_conflict_on_cas_failure():
    """A save that storage rejects (returns False) must surface as ThreadMemoryConflictError."""

    class _RejectingStorage:
        # Serves a version-1 record; save() always returns False, simulating a failed
        # compare-and-swap on `expected_version`.
        def load(self, _tid):
            return {"threadId": "t1", "ownerId": None, "user": {}, "history": {}, "facts": [], "memoryVersion": 1}

        def save(self, _tid, _data, expected_version=None):
            return False

    def _invoke(self, prompt):
        # The fake model always answers with an empty JSON object.
        return type("R", (), {"content": "{}"})()

    def _scrub(self, data, _thread_id):
        # Identity scrubber: returns `data` unchanged.
        return data

    fake_model = type("M", (), {"invoke": _invoke})()
    fake_updater = type("U", (), {"_scrub_sensitive": _scrub})()

    storage_patch = patch(
        "deerflow.agents.memory.thread_summary.get_thread_memory_storage",
        return_value=_RejectingStorage(),
    )
    model_patch = patch(
        "deerflow.agents.memory.thread_summary._get_summary_model",
        return_value=fake_model,
    )
    updater_patch = patch(
        "deerflow.agents.memory.thread_summary.ThreadMemoryUpdater",
        return_value=fake_updater,
    )
    with storage_patch, model_patch, updater_patch:
        with pytest.raises(ThreadMemoryConflictError):
            apply_thread_memory_summary("t1", "更新内容", 1)
|
||
|
||
|
||
def test_apply_thread_memory_summary_falls_back_when_model_output_is_not_json():
    """If the model replies with plain prose instead of JSON, the user's text becomes the top-of-mind summary."""

    class _RecordingStorage:
        # Records the last payload passed to save(). Once something has been saved,
        # load() returns it merged over a version-2 header (keys in the saved payload win).
        def __init__(self):
            self.saved = None

        def load(self, _tid):
            if self.saved is None:
                return {
                    "threadId": "t1",
                    "ownerId": None,
                    "user": {"topOfMind": {"summary": ""}},
                    "history": {},
                    "facts": [],
                    "memoryVersion": 1,
                }
            return {"threadId": "t1", "memoryVersion": 2, **self.saved}

        def save(self, _tid, data, expected_version=None):
            self.saved = data
            return True

    def _invoke(self, prompt):
        # Natural-language (non-JSON) reply to force the fallback path.
        return type("R", (), {"content": "这是自然语言,不是JSON"})()

    def _scrub(self, data, _thread_id):
        # Identity scrubber: returns `data` unchanged.
        return data

    storage = _RecordingStorage()
    user_text = "我最近在做线程记忆功能"
    fake_model = type("M", (), {"invoke": _invoke})()
    fake_updater = type("U", (), {"_scrub_sensitive": _scrub})()

    storage_patch = patch(
        "deerflow.agents.memory.thread_summary.get_thread_memory_storage",
        return_value=storage,
    )
    model_patch = patch(
        "deerflow.agents.memory.thread_summary._get_summary_model",
        return_value=fake_model,
    )
    updater_patch = patch(
        "deerflow.agents.memory.thread_summary.ThreadMemoryUpdater",
        return_value=fake_updater,
    )
    with storage_patch, model_patch, updater_patch:
        result = apply_thread_memory_summary("t1", user_text, 1)

    assert storage.saved is not None
    assert storage.saved["user"]["topOfMind"]["summary"] == user_text
    assert result["user"]["topOfMind"]["summary"] == user_text
|
||
|
||
|
||
def test_extract_json_object_repairs_inner_unescaped_quotes():
    """_extract_json_object should still parse output whose string value contains a stray unescaped ASCII quote."""
    # The summary opens with a curly quote but closes the phrase with a bare ASCII `"`,
    # which makes the text invalid JSON as written.
    model_output = """
{
  "user": {
    "topOfMind": {
      "summary": "反感“作为 AI"这种句式,认为回答不用寒暄直接说重点。"
    }
  },
  "history": {},
  "facts": []
}
""".strip()

    parsed = _extract_json_object(model_output)

    assert parsed is not None
    summary = parsed["user"]["topOfMind"]["summary"]
    assert summary.startswith("反感“作为 AI")
|