feat(memory): structured reflection + correction detection in MemoryMiddleware (#1620) (#1668)

* feat(memory): add structured reflection and correction detection

* fix(memory): align sourceError schema and prompt guidance

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
AochenShen99 2026-04-01 16:45:29 +08:00 committed by GitHub
parent 3e461d9d08
commit 0cdecf7b30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 436 additions and 21 deletions

View File

@ -49,6 +49,7 @@ class Fact(BaseModel):
confidence: float = Field(default=0.5, description="Confidence score (0-1)") confidence: float = Field(default=0.5, description="Confidence score (0-1)")
createdAt: str = Field(default="", description="Creation timestamp") createdAt: str = Field(default="", description="Creation timestamp")
source: str = Field(default="unknown", description="Source thread ID") source: str = Field(default="unknown", description="Source thread ID")
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
class MemoryResponse(BaseModel): class MemoryResponse(BaseModel):
@ -108,6 +109,7 @@ class MemoryStatusResponse(BaseModel):
@router.get( @router.get(
"/memory", "/memory",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Get Memory Data", summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.", description="Retrieve the current global memory data including user context, history, and facts.",
) )
@ -152,6 +154,7 @@ async def get_memory() -> MemoryResponse:
@router.post( @router.post(
"/memory/reload", "/memory/reload",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Reload Memory Data", summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.", description="Reload memory data from the storage file, refreshing the in-memory cache.",
) )
@ -171,6 +174,7 @@ async def reload_memory() -> MemoryResponse:
@router.delete( @router.delete(
"/memory", "/memory",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Clear All Memory Data", summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.", description="Delete all saved memory data and reset the memory structure to an empty state.",
) )
@ -187,6 +191,7 @@ async def clear_memory() -> MemoryResponse:
@router.post( @router.post(
"/memory/facts", "/memory/facts",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Create Memory Fact", summary="Create Memory Fact",
description="Create a single saved memory fact manually.", description="Create a single saved memory fact manually.",
) )
@ -209,6 +214,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
@router.delete( @router.delete(
"/memory/facts/{fact_id}", "/memory/facts/{fact_id}",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Delete Memory Fact", summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.", description="Delete a single saved memory fact by its fact id.",
) )
@ -227,6 +233,7 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
@router.patch( @router.patch(
"/memory/facts/{fact_id}", "/memory/facts/{fact_id}",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Patch Memory Fact", summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.", description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
) )
@ -252,6 +259,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
@router.get( @router.get(
"/memory/export", "/memory/export",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Export Memory Data", summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.", description="Export the current global memory data as JSON for backup or transfer.",
) )
@ -264,6 +272,7 @@ async def export_memory() -> MemoryResponse:
@router.post( @router.post(
"/memory/import", "/memory/import",
response_model=MemoryResponse, response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Import Memory Data", summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.", description="Import and overwrite the current global memory data from a JSON payload.",
) )
@ -317,6 +326,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
@router.get( @router.get(
"/memory/status", "/memory/status",
response_model=MemoryStatusResponse, response_model=MemoryStatusResponse,
response_model_exclude_none=True,
summary="Get Memory Status", summary="Get Memory Status",
description="Retrieve both memory configuration and current data in a single request.", description="Retrieve both memory configuration and current data in a single request.",
) )

View File

@ -29,6 +29,17 @@ Instructions:
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies) 2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
3. Update the memory sections as needed following the detailed length guidelines below 3. Update the memory sections as needed following the detailed length guidelines below
Before extracting facts, perform a structured reflection on the conversation:
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
If yes, record them as facts with the most appropriate category and confidence.
{correction_hint}
Memory Section Guidelines: Memory Section Guidelines:
**User Context** (Current state - concise summaries): **User Context** (Current state - concise summaries):
@ -62,6 +73,7 @@ Memory Section Guidelines:
* context: Background facts (job title, projects, locations, languages) * context: Background facts (job title, projects, locations, languages)
* behavior: Working patterns, communication habits, problem-solving approaches * behavior: Working patterns, communication habits, problem-solving approaches
* goal: Stated objectives, learning targets, project ambitions * goal: Stated objectives, learning targets, project ambitions
* correction: Explicit agent mistakes or user corrections, including the correct approach
- Confidence levels: - Confidence levels:
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y") * 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
* 0.7-0.8: Strongly implied from actions/discussions * 0.7-0.8: Strongly implied from actions/discussions
@ -94,7 +106,7 @@ Output Format (JSON):
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }} "longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
}}, }},
"newFacts": [ "newFacts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }} {{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
], ],
"factsToRemove": ["fact_id_1", "fact_id_2"] "factsToRemove": ["fact_id_1", "fact_id_2"]
}} }}
@ -104,6 +116,8 @@ Important Rules:
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs) - Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
- Include specific metrics, version numbers, and proper nouns in facts - Include specific metrics, version numbers, and proper nouns in facts
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+) - Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
- Remove facts that are contradicted by new information - Remove facts that are contradicted by new information
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones - When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
Keep 3-5 concurrent focus themes that are still active and relevant Keep 3-5 concurrent focus themes that are still active and relevant
@ -126,7 +140,7 @@ Message:
Extract facts in this JSON format: Extract facts in this JSON format:
{{ {{
"facts": [ "facts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }} {{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
] ]
}} }}
@ -136,6 +150,7 @@ Categories:
- context: Background context (location, job, projects) - context: Background context (location, job, projects)
- behavior: Behavioral patterns - behavior: Behavioral patterns
- goal: User's goals or objectives - goal: User's goals or objectives
- correction: Explicit corrections or mistakes to avoid repeating
Rules: Rules:
- Only extract clear, specific facts - Only extract clear, specific facts
@ -262,7 +277,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
continue continue
category = str(fact.get("category", "context")).strip() or "context" category = str(fact.get("category", "context")).strip() or "context"
confidence = _coerce_confidence(fact.get("confidence"), default=0.0) confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
line = f"- [{category} | {confidence:.2f}] {content}" source_error = fact.get("sourceError")
if category == "correction" and isinstance(source_error, str) and source_error.strip():
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
else:
line = f"- [{category} | {confidence:.2f}] {content}"
# Each additional line is preceded by a newline (except the first). # Each additional line is preceded by a newline (except the first).
line_text = ("\n" + line) if fact_lines else line line_text = ("\n" + line) if fact_lines else line

View File

@ -20,6 +20,7 @@ class ConversationContext:
messages: list[Any] messages: list[Any]
timestamp: datetime = field(default_factory=datetime.utcnow) timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None agent_name: str | None = None
correction_detected: bool = False
class MemoryUpdateQueue: class MemoryUpdateQueue:
@ -37,25 +38,38 @@ class MemoryUpdateQueue:
self._timer: threading.Timer | None = None self._timer: threading.Timer | None = None
self._processing = False self._processing = False
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None: def add(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
) -> None:
"""Add a conversation to the update queue. """Add a conversation to the update queue.
Args: Args:
thread_id: The thread ID. thread_id: The thread ID.
messages: The conversation messages. messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory. agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
""" """
config = get_memory_config() config = get_memory_config()
if not config.enabled: if not config.enabled:
return return
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
)
with self._lock: with self._lock:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
)
# Check if this thread already has a pending update # Check if this thread already has a pending update
# If so, replace it with the newer one # If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id] self._queue = [c for c in self._queue if c.thread_id != thread_id]
@ -115,6 +129,7 @@ class MemoryUpdateQueue:
messages=context.messages, messages=context.messages,
thread_id=context.thread_id, thread_id=context.thread_id,
agent_name=context.agent_name, agent_name=context.agent_name,
correction_detected=context.correction_detected,
) )
if success: if success:
logger.info("Memory updated successfully for thread %s", context.thread_id) logger.info("Memory updated successfully for thread %s", context.thread_id)

View File

@ -266,13 +266,20 @@ class MemoryUpdater:
model_name = self._model_name or config.model_name model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False) return create_chat_model(name=model_name, thinking_enabled=False)
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool: def update_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages. """Update memory based on conversation messages.
Args: Args:
messages: List of conversation messages. messages: List of conversation messages.
thread_id: Optional thread ID for tracking source. thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
Returns: Returns:
True if update was successful, False otherwise. True if update was successful, False otherwise.
@ -295,9 +302,19 @@ class MemoryUpdater:
return False return False
# Build prompt # Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
prompt = MEMORY_UPDATE_PROMPT.format( prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2), current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text, conversation=conversation_text,
correction_hint=correction_hint,
) )
# Call LLM # Call LLM
@ -383,6 +400,8 @@ class MemoryUpdater:
confidence = fact.get("confidence", 0.5) confidence = fact.get("confidence", 0.5)
if confidence >= config.fact_confidence_threshold: if confidence >= config.fact_confidence_threshold:
raw_content = fact.get("content", "") raw_content = fact.get("content", "")
if not isinstance(raw_content, str):
continue
normalized_content = raw_content.strip() normalized_content = raw_content.strip()
fact_key = _fact_content_key(normalized_content) fact_key = _fact_content_key(normalized_content)
if fact_key is not None and fact_key in existing_fact_keys: if fact_key is not None and fact_key in existing_fact_keys:
@ -396,6 +415,11 @@ class MemoryUpdater:
"createdAt": now, "createdAt": now,
"source": thread_id or "unknown", "source": thread_id or "unknown",
} }
source_error = fact.get("sourceError")
if isinstance(source_error, str):
normalized_source_error = source_error.strip()
if normalized_source_error:
fact_entry["sourceError"] = normalized_source_error
current_memory["facts"].append(fact_entry) current_memory["facts"].append(fact_entry)
if fact_key is not None: if fact_key is not None:
existing_fact_keys.add(fact_key) existing_fact_keys.add(fact_key)
@ -412,16 +436,22 @@ class MemoryUpdater:
return current_memory return current_memory
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool: def update_memory_from_conversation(
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation. """Convenience function to update memory from a conversation.
Args: Args:
messages: List of conversation messages. messages: List of conversation messages.
thread_id: Optional thread ID. thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
Returns: Returns:
True if successful, False otherwise. True if successful, False otherwise.
""" """
updater = MemoryUpdater() updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name) return updater.update_memory(messages, thread_id, agent_name, correction_detected)

View File

@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Matches the ephemeral <uploaded_files>...</uploaded_files> block that is injected
# into user messages; it is stripped out before the message is persisted to memory.
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)

# Phrases (English and Chinese) that signal an explicit user correction of the
# agent's direction, understanding, or output. Consumed by detect_correction().
# NOTE(review): matching is substring-based for the Chinese patterns (no word
# boundaries), which is the conservative choice for unsegmented CJK text.
_CORRECTION_PATTERNS = (
    re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
    re.compile(r"\byou misunderstood\b", re.IGNORECASE),
    re.compile(r"\btry again\b", re.IGNORECASE),
    re.compile(r"\bredo\b", re.IGNORECASE),
    re.compile(r"不对"),  # "that's not right"
    re.compile(r"你理解错了"),  # "you misunderstood"
    re.compile(r"你理解有误"),  # "your understanding is mistaken"
    re.compile(r"重试"),  # "retry"
    re.compile(r"重新来"),  # "start over"
    re.compile(r"换一种"),  # "try a different approach"
    re.compile(r"改用"),  # "switch to using"
)
class MemoryMiddlewareState(AgentState): class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema.""" """Compatible with the `ThreadState` schema."""
@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState):
pass pass
def _extract_message_text(message: Any) -> str:
    """Extract plain text from a message's content for filtering and signal detection.

    Handles both plain-string content and list-form content (mixed strings and
    ``{"text": ...}`` dict parts); non-text dict parts are ignored. Any other
    content shape is stringified as-is.
    """
    content = getattr(message, "content", "")
    if not isinstance(content, list):
        return str(content)
    pieces: list[str] = []
    for item in content:
        if isinstance(item, str):
            pieces.append(item)
        elif isinstance(item, dict) and isinstance(item.get("text"), str):
            pieces.append(item["text"])
    return " ".join(pieces)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]: def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses. """Filter messages to keep only user inputs and final assistant responses.
@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
Returns: Returns:
Filtered list containing only user inputs and final assistant responses. Filtered list containing only user inputs and final assistant responses.
""" """
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
filtered = [] filtered = []
skip_next_ai = False skip_next_ai = False
for msg in messages: for msg in messages:
msg_type = getattr(msg, "type", None) msg_type = getattr(msg, "type", None)
if msg_type == "human": if msg_type == "human":
content = getattr(msg, "content", "") content_str = _extract_message_text(msg)
if isinstance(content, list):
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
content_str = str(content)
if "<uploaded_files>" in content_str: if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question. # Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip() stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
return filtered return filtered
def detect_correction(messages: list[Any]) -> bool:
    """Detect explicit user corrections in recent conversation turns.

    The queue keeps only one pending context per thread, so callers pass the
    latest filtered message list. Checking only recent user turns keeps signal
    detection conservative while avoiding stale corrections from long histories.
    """
    # Only the last few turns matter; older corrections are considered stale.
    for msg in messages[-6:]:
        if getattr(msg, "type", None) != "human":
            continue
        text = _extract_message_text(msg).strip()
        if text and any(pattern.search(text) for pattern in _CORRECTION_PATTERNS):
            return True
    return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution. """Middleware that queues conversation for memory update after agent execution.
@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None return None
# Queue the filtered conversation for memory update # Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
queue = get_memory_queue() queue = get_memory_queue()
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name) queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
)
return None return None

View File

@ -119,3 +119,38 @@ def test_format_memory_skips_non_string_content_facts() -> None:
# The formatted line for a list content would be "- [knowledge | 0.85] ['list']". # The formatted line for a list content would be "- [knowledge | 0.85] ['list']".
assert "| 0.85]" not in result assert "| 0.85]" not in result
assert "Valid fact" in result assert "Valid fact" in result
def test_format_memory_renders_correction_source_error() -> None:
    """A correction fact carrying sourceError is rendered with an 'avoid:' hint."""
    correction_fact = {
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "sourceError": "The agent previously suggested npm start.",
    }
    result = format_memory_for_injection({"facts": [correction_fact]}, max_tokens=2000)
    assert "Use make dev for local development." in result
    assert "avoid: The agent previously suggested npm start." in result
def test_format_memory_renders_correction_without_source_error_normally() -> None:
    """A correction fact without sourceError renders as a plain fact line."""
    correction_fact = {
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
    }
    result = format_memory_for_injection({"facts": [correction_fact]}, max_tokens=2000)
    assert "Use make dev for local development." in result
    assert "avoid:" not in result

View File

@ -0,0 +1,50 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def _memory_config(**overrides: object) -> MemoryConfig:
    """Build a MemoryConfig and apply the given attribute overrides to it."""
    config = MemoryConfig()
    for attr_name, attr_value in overrides.items():
        setattr(config, attr_name, attr_value)
    return config
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
    """Re-adding a thread with correction_detected=False keeps the earlier True flag."""
    queue = MemoryUpdateQueue()
    enabled_config = _memory_config(enabled=True)
    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=enabled_config),
        patch.object(queue, "_reset_timer"),
    ):
        queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
        queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
    pending = queue._queue
    assert len(pending) == 1
    # The newer messages replace the older context...
    assert pending[0].messages == ["second"]
    # ...but the sticky correction signal from the first add survives the merge.
    assert pending[0].correction_detected is True
def test_process_queue_forwards_correction_flag_to_updater() -> None:
    """_process_queue passes each context's correction flag through to MemoryUpdater."""
    pending_context = ConversationContext(
        thread_id="thread-1",
        messages=["conversation"],
        agent_name="lead_agent",
        correction_detected=True,
    )
    queue = MemoryUpdateQueue()
    queue._queue = [pending_context]
    mock_updater = MagicMock()
    mock_updater.update_memory.return_value = True
    with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
        queue._process_queue()
    mock_updater.update_memory.assert_called_once_with(
        messages=["conversation"],
        thread_id="thread-1",
        agent_name="lead_agent",
        correction_detected=True,
    )

View File

@ -72,6 +72,56 @@ def test_import_memory_route_returns_imported_memory() -> None:
assert response.json()["facts"] == imported_memory["facts"] assert response.json()["facts"] == imported_memory["facts"]
def test_export_memory_route_preserves_source_error() -> None:
    """GET /api/memory/export keeps the sourceError field on correction facts."""
    app = FastAPI()
    app.include_router(memory.router)
    correction_fact = {
        "id": "fact_correction",
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "createdAt": "2026-03-20T00:00:00Z",
        "source": "thread-1",
        "sourceError": "The agent previously suggested npm start.",
    }
    exported_memory = _sample_memory(facts=[correction_fact])
    with patch("app.gateway.routers.memory.get_memory_data", return_value=exported_memory):
        with TestClient(app) as client:
            response = client.get("/api/memory/export")
    assert response.status_code == 200
    assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
def test_import_memory_route_preserves_source_error() -> None:
    """POST /api/memory/import keeps the sourceError field on correction facts."""
    app = FastAPI()
    app.include_router(memory.router)
    correction_fact = {
        "id": "fact_correction",
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "createdAt": "2026-03-20T00:00:00Z",
        "source": "thread-1",
        "sourceError": "The agent previously suggested npm start.",
    }
    imported_memory = _sample_memory(facts=[correction_fact])
    with patch("app.gateway.routers.memory.import_memory_data", return_value=imported_memory):
        with TestClient(app) as client:
            response = client.post("/api/memory/import", json=imported_memory)
    assert response.status_code == 200
    assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
def test_clear_memory_route_returns_cleared_memory() -> None: def test_clear_memory_route_returns_cleared_memory() -> None:
app = FastAPI() app = FastAPI()
app.include_router(memory.router) app.include_router(memory.router)

View File

@ -146,6 +146,53 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
assert result["facts"][1]["source"] == "thread-9" assert result["facts"][1]["source"] == "thread-9"
def test_apply_updates_preserves_source_error() -> None:
    """_apply_updates copies a non-empty sourceError onto the stored fact."""
    updater = MemoryUpdater()
    new_fact = {
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "sourceError": "The agent previously suggested npm start.",
    }
    config = _memory_config(max_facts=100, fact_confidence_threshold=0.7)
    with patch("deerflow.agents.memory.updater.get_memory_config", return_value=config):
        result = updater._apply_updates(
            _make_memory(), {"newFacts": [new_fact]}, thread_id="thread-correction"
        )
    stored_fact = result["facts"][0]
    assert stored_fact["sourceError"] == "The agent previously suggested npm start."
    assert stored_fact["category"] == "correction"
def test_apply_updates_ignores_empty_source_error() -> None:
    """A whitespace-only sourceError is dropped rather than stored on the fact."""
    updater = MemoryUpdater()
    new_fact = {
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "sourceError": " ",
    }
    config = _memory_config(max_facts=100, fact_confidence_threshold=0.7)
    with patch("deerflow.agents.memory.updater.get_memory_config", return_value=config):
        result = updater._apply_updates(
            _make_memory(), {"newFacts": [new_fact]}, thread_id="thread-correction"
        )
    assert "sourceError" not in result["facts"][0]
def test_clear_memory_data_resets_all_sections() -> None: def test_clear_memory_data_resets_all_sections() -> None:
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True): with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
result = clear_memory_data() result = clear_memory_data()
@ -522,3 +569,53 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg]) result = updater.update_memory([msg, ai_msg])
assert result is True assert result is True
def test_correction_hint_injected_when_detected(self):
    """correction_detected=True injects the correction hint into the LLM prompt."""
    updater = MemoryUpdater()
    model = self._make_mock_model('{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}')
    human = MagicMock()
    human.type = "human"
    human.content = "No, that's wrong."
    ai = MagicMock()
    ai.type = "ai"
    ai.content = "Understood"
    ai.tool_calls = []
    storage = MagicMock(save=MagicMock(return_value=True))
    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=storage),
    ):
        assert updater.update_memory([human, ai], correction_detected=True) is True
    # First positional arg of the model call is the rendered prompt.
    prompt = model.invoke.call_args[0][0]
    assert "Explicit correction signals were detected" in prompt
def test_correction_hint_empty_when_not_detected(self):
    """correction_detected=False leaves the correction hint out of the LLM prompt."""
    updater = MemoryUpdater()
    model = self._make_mock_model('{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}')
    human = MagicMock()
    human.type = "human"
    human.content = "Let's talk about memory."
    ai = MagicMock()
    ai.type = "ai"
    ai.content = "Sure"
    ai.tool_calls = []
    storage = MagicMock(save=MagicMock(return_value=True))
    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=storage),
    ):
        assert updater.update_memory([human, ai], correction_detected=False) is True
    # First positional arg of the model call is the rendered prompt.
    prompt = model.invoke.call_args[0][0]
    assert "Explicit correction signals were detected" not in prompt

View File

@ -10,7 +10,7 @@ persisting in long-term memory:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
@ -134,6 +134,64 @@ class TestFilterMessagesForMemory:
assert "<uploaded_files>" not in all_content assert "<uploaded_files>" not in all_content
# ===========================================================================
# detect_correction
# ===========================================================================
class TestDetectCorrection:
    """Unit tests for detect_correction()."""

    def test_detects_english_correction_signal(self):
        """An explicit English correction phrase triggers detection."""
        conversation = [
            _human("Please help me run the project."),
            _ai("Use npm start."),
            _human("That's wrong, use make dev instead."),
            _ai("Understood."),
        ]
        assert detect_correction(conversation) is True

    def test_detects_chinese_correction_signal(self):
        """An explicit Chinese correction phrase triggers detection."""
        conversation = [
            _human("帮我启动项目"),
            _ai("用 npm start"),
            _human("不对,改用 make dev"),
            _ai("明白了"),
        ]
        assert detect_correction(conversation) is True

    def test_returns_false_without_signal(self):
        """Ordinary conversation without correction phrasing is not flagged."""
        conversation = [
            _human("Please explain the build setup."),
            _ai("Here is the build setup."),
            _human("Thanks, that makes sense."),
        ]
        assert detect_correction(conversation) is False

    def test_only_checks_recent_messages(self):
        """A correction older than the recent window (last 6 messages) is ignored."""
        conversation = [
            _human("That is wrong, use make dev instead."),
            _ai("Noted."),
            _human("Let's discuss tests."),
            _ai("Sure."),
            _human("What about linting?"),
            _ai("Use ruff."),
            _human("And formatting?"),
            _ai("Use make format."),
        ]
        assert detect_correction(conversation) is False

    def test_handles_list_content(self):
        """Correction phrases split across list-form content parts are detected."""
        conversation = [
            HumanMessage(content=["That is wrong,", {"type": "text", "text": "use make dev instead."}]),
            _ai("Updated."),
        ]
        assert detect_correction(conversation) is True
# =========================================================================== # ===========================================================================
# _strip_upload_mentions_from_memory # _strip_upload_mentions_from_memory
# =========================================================================== # ===========================================================================