deerflow2/backend/tests/test_thread_memory_middleware.py
2026-05-08 10:19:09 +08:00

33 lines
1.3 KiB
Python

from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.thread_memory_config import ThreadMemoryConfig
def test_thread_memory_queue_runs_even_if_global_memory_disabled():
    """Verify ``after_agent`` enqueues to the thread queue when global memory is off.

    The config getters in ``memory_middleware`` are patched so that the global
    memory config reports ``enabled=False`` while the thread memory config
    reports ``enabled=True``. Both queue getters are patched to return mocks so
    the calls to ``add`` can be inspected afterwards.
    """
    # Arrange: a two-message conversation and a runtime carrying a thread id.
    conversation = [
        HumanMessage(content="My name is Alice"),
        AIMessage(content="Nice to meet you"),
    ]
    agent_state = {"messages": conversation}
    agent_runtime = SimpleNamespace(context={"thread_id": "thread-test"})

    global_queue = MagicMock()
    thread_queue = MagicMock()

    # Patchers are built up front; nothing is replaced until the ``with`` below
    # enters them, in the same order as they are listed.
    global_config_off = patch(
        "deerflow.agents.middlewares.memory_middleware.get_memory_config",
        return_value=MemoryConfig(enabled=False),
    )
    thread_config_on = patch(
        "deerflow.agents.middlewares.memory_middleware.get_thread_memory_config",
        return_value=ThreadMemoryConfig(enabled=True),
    )
    global_queue_getter = patch(
        "deerflow.agents.middlewares.memory_middleware.get_memory_queue",
        return_value=global_queue,
    )
    thread_queue_getter = patch(
        "deerflow.agents.middlewares.memory_middleware.get_thread_memory_queue",
        return_value=thread_queue,
    )

    # Act: run the hook with all four replacements active.
    with global_config_off, thread_config_on, global_queue_getter, thread_queue_getter:
        MemoryMiddleware().after_agent(agent_state, agent_runtime)

    # Assert: only the thread-scoped queue received an ``add`` call.
    global_queue.add.assert_not_called()
    thread_queue.add.assert_called_once()