deerflow2/backend/tests/test_thread_memory_summary.py

104 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from unittest.mock import patch
import pytest
from deerflow.agents.memory.thread_summary import (
ThreadMemoryConflictError,
_extract_json_object,
apply_thread_memory_summary,
render_thread_memory_summary,
)
def test_render_thread_memory_summary_returns_text():
    """Rendering returns the thread id, the stored memory version and the model's text."""

    class _FakeStorage:
        # Serves an empty memory document at version 2 for whichever thread is asked for.
        def load(self, tid):
            return {
                "threadId": tid,
                "user": {},
                "history": {},
                "facts": [],
                "memoryVersion": 2,
            }

    class _FakeReply:
        content = "用户总结"

    class _FakeModel:
        # Ignores the prompt and always answers with the same canned summary.
        def invoke(self, prompt):
            return _FakeReply()

    with (
        patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=_FakeStorage()),
        patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=_FakeModel()),
    ):
        rendered = render_thread_memory_summary("t1")

    assert rendered["threadId"] == "t1"
    assert rendered["memoryVersion"] == 2
    assert rendered["summary"] == "用户总结"
def test_apply_thread_memory_summary_raises_conflict_on_cas_failure():
    """A rejected compare-and-swap save surfaces as ThreadMemoryConflictError.

    The storage stub loads memory at version 1 but refuses every save. It also
    records each save attempt, so the test can confirm that the conflict comes
    from the rejected write itself rather than from an earlier check.
    """

    class _Storage:
        def __init__(self):
            # expected_version passed to each save() call, in call order.
            self.save_attempts = []

        def load(self, _tid):
            return {"threadId": "t1", "ownerId": None, "user": {}, "history": {}, "facts": [], "memoryVersion": 1}

        def save(self, _tid, _data, expected_version=None):
            self.save_attempts.append(expected_version)
            # Simulate a lost CAS race: the stored version no longer matches.
            return False

    storage = _Storage()
    fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "{}"})()})()
    fake_updater = type("U", (), {"_scrub_sensitive": lambda self, data, _thread_id: data})()
    with (
        patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=storage),
        patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model),
        patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=fake_updater),
    ):
        with pytest.raises(ThreadMemoryConflictError):
            apply_thread_memory_summary("t1", "更新内容", 1)

    # Without this check, a conflict raised before the write would pass vacuously.
    assert storage.save_attempts, "save() was never attempted; the conflict did not come from CAS"
    # The guarded write should carry the caller's expected version.
    assert storage.save_attempts[0] == 1
def test_apply_thread_memory_summary_falls_back_when_model_output_is_not_json():
    """When the model answers in prose, the user's own text is stored as topOfMind."""

    class _Storage:
        def __init__(self):
            # Last payload passed to save(); None until the first write.
            self.saved = None

        def load(self, _tid):
            if self.saved is None:
                return {
                    "threadId": "t1",
                    "ownerId": None,
                    "user": {"topOfMind": {"summary": ""}},
                    "history": {},
                    "facts": [],
                    "memoryVersion": 1,
                }
            # After a write, serve the written payload at the bumped version.
            return {"threadId": "t1", "memoryVersion": 2, **self.saved}

        def save(self, _tid, data, expected_version=None):
            self.saved = data
            return True

    class _PlainTextReply:
        content = "这是自然语言不是JSON"

    class _FakeModel:
        def invoke(self, prompt):
            return _PlainTextReply()

    class _PassThroughUpdater:
        # Leaves the memory document untouched.
        def _scrub_sensitive(self, data, _thread_id):
            return data

    storage = _Storage()
    user_text = "我最近在做线程记忆功能"
    with (
        patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=storage),
        patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=_FakeModel()),
        patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=_PassThroughUpdater()),
    ):
        result = apply_thread_memory_summary("t1", user_text, 1)

    assert storage.saved is not None
    assert storage.saved["user"]["topOfMind"]["summary"] == user_text
    assert result["user"]["topOfMind"]["summary"] == user_text
def test_extract_json_object_repairs_inner_unescaped_quotes():
    """A stray ASCII quote inside a string value is repaired instead of failing.

    The summary opens a quotation with U+201C (a legal character inside a JSON
    string) but closes it with a plain ASCII '"'. That ASCII quote ends the
    JSON string early, so strict parsing fails. The two quotes look almost
    identical, so U+201C is written as an escape to keep the intent visible.
    """
    raw = """
{
  "user": {
    "topOfMind": {
      "summary": "反感\u201c作为 AI"这种句式,认为回答不用寒暄直接说重点。"
    }
  },
  "history": {},
  "facts": []
}
""".strip()
    parsed = _extract_json_object(raw)
    assert parsed is not None
    assert parsed["user"]["topOfMind"]["summary"].startswith("反感\u201c作为 AI")
    # The repair must not swallow or corrupt the rest of the object.
    assert parsed["history"] == {}
    assert parsed["facts"] == []