* feat(memory): add structured reflection and correction detection * fix(memory): align sourceError schema and prompt guidance --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
3e461d9d08
commit
0cdecf7b30
|
|
@ -49,6 +49,7 @@ class Fact(BaseModel):
|
||||||
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
|
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
|
||||||
createdAt: str = Field(default="", description="Creation timestamp")
|
createdAt: str = Field(default="", description="Creation timestamp")
|
||||||
source: str = Field(default="unknown", description="Source thread ID")
|
source: str = Field(default="unknown", description="Source thread ID")
|
||||||
|
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
|
||||||
|
|
||||||
|
|
||||||
class MemoryResponse(BaseModel):
|
class MemoryResponse(BaseModel):
|
||||||
|
|
@ -108,6 +109,7 @@ class MemoryStatusResponse(BaseModel):
|
||||||
@router.get(
|
@router.get(
|
||||||
"/memory",
|
"/memory",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Get Memory Data",
|
summary="Get Memory Data",
|
||||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||||
)
|
)
|
||||||
|
|
@ -152,6 +154,7 @@ async def get_memory() -> MemoryResponse:
|
||||||
@router.post(
|
@router.post(
|
||||||
"/memory/reload",
|
"/memory/reload",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Reload Memory Data",
|
summary="Reload Memory Data",
|
||||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||||
)
|
)
|
||||||
|
|
@ -171,6 +174,7 @@ async def reload_memory() -> MemoryResponse:
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/memory",
|
"/memory",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Clear All Memory Data",
|
summary="Clear All Memory Data",
|
||||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||||
)
|
)
|
||||||
|
|
@ -187,6 +191,7 @@ async def clear_memory() -> MemoryResponse:
|
||||||
@router.post(
|
@router.post(
|
||||||
"/memory/facts",
|
"/memory/facts",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Create Memory Fact",
|
summary="Create Memory Fact",
|
||||||
description="Create a single saved memory fact manually.",
|
description="Create a single saved memory fact manually.",
|
||||||
)
|
)
|
||||||
|
|
@ -209,6 +214,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/memory/facts/{fact_id}",
|
"/memory/facts/{fact_id}",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Delete Memory Fact",
|
summary="Delete Memory Fact",
|
||||||
description="Delete a single saved memory fact by its fact id.",
|
description="Delete a single saved memory fact by its fact id.",
|
||||||
)
|
)
|
||||||
|
|
@ -227,6 +233,7 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/memory/facts/{fact_id}",
|
"/memory/facts/{fact_id}",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Patch Memory Fact",
|
summary="Patch Memory Fact",
|
||||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||||
)
|
)
|
||||||
|
|
@ -252,6 +259,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||||
@router.get(
|
@router.get(
|
||||||
"/memory/export",
|
"/memory/export",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Export Memory Data",
|
summary="Export Memory Data",
|
||||||
description="Export the current global memory data as JSON for backup or transfer.",
|
description="Export the current global memory data as JSON for backup or transfer.",
|
||||||
)
|
)
|
||||||
|
|
@ -264,6 +272,7 @@ async def export_memory() -> MemoryResponse:
|
||||||
@router.post(
|
@router.post(
|
||||||
"/memory/import",
|
"/memory/import",
|
||||||
response_model=MemoryResponse,
|
response_model=MemoryResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Import Memory Data",
|
summary="Import Memory Data",
|
||||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||||
)
|
)
|
||||||
|
|
@ -317,6 +326,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||||
@router.get(
|
@router.get(
|
||||||
"/memory/status",
|
"/memory/status",
|
||||||
response_model=MemoryStatusResponse,
|
response_model=MemoryStatusResponse,
|
||||||
|
response_model_exclude_none=True,
|
||||||
summary="Get Memory Status",
|
summary="Get Memory Status",
|
||||||
description="Retrieve both memory configuration and current data in a single request.",
|
description="Retrieve both memory configuration and current data in a single request.",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,17 @@ Instructions:
|
||||||
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
|
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
|
||||||
3. Update the memory sections as needed following the detailed length guidelines below
|
3. Update the memory sections as needed following the detailed length guidelines below
|
||||||
|
|
||||||
|
Before extracting facts, perform a structured reflection on the conversation:
|
||||||
|
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
|
||||||
|
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
|
||||||
|
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
|
||||||
|
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
|
||||||
|
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
|
||||||
|
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
|
||||||
|
If yes, record them as facts with the most appropriate category and confidence.
|
||||||
|
|
||||||
|
{correction_hint}
|
||||||
|
|
||||||
Memory Section Guidelines:
|
Memory Section Guidelines:
|
||||||
|
|
||||||
**User Context** (Current state - concise summaries):
|
**User Context** (Current state - concise summaries):
|
||||||
|
|
@ -62,6 +73,7 @@ Memory Section Guidelines:
|
||||||
* context: Background facts (job title, projects, locations, languages)
|
* context: Background facts (job title, projects, locations, languages)
|
||||||
* behavior: Working patterns, communication habits, problem-solving approaches
|
* behavior: Working patterns, communication habits, problem-solving approaches
|
||||||
* goal: Stated objectives, learning targets, project ambitions
|
* goal: Stated objectives, learning targets, project ambitions
|
||||||
|
* correction: Explicit agent mistakes or user corrections, including the correct approach
|
||||||
- Confidence levels:
|
- Confidence levels:
|
||||||
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
|
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
|
||||||
* 0.7-0.8: Strongly implied from actions/discussions
|
* 0.7-0.8: Strongly implied from actions/discussions
|
||||||
|
|
@ -94,7 +106,7 @@ Output Format (JSON):
|
||||||
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||||
}},
|
}},
|
||||||
"newFacts": [
|
"newFacts": [
|
||||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||||
],
|
],
|
||||||
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
||||||
}}
|
}}
|
||||||
|
|
@ -104,6 +116,8 @@ Important Rules:
|
||||||
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
||||||
- Include specific metrics, version numbers, and proper nouns in facts
|
- Include specific metrics, version numbers, and proper nouns in facts
|
||||||
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
||||||
|
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
|
||||||
|
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
|
||||||
- Remove facts that are contradicted by new information
|
- Remove facts that are contradicted by new information
|
||||||
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
||||||
Keep 3-5 concurrent focus themes that are still active and relevant
|
Keep 3-5 concurrent focus themes that are still active and relevant
|
||||||
|
|
@ -126,7 +140,7 @@ Message:
|
||||||
Extract facts in this JSON format:
|
Extract facts in this JSON format:
|
||||||
{{
|
{{
|
||||||
"facts": [
|
"facts": [
|
||||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
|
@ -136,6 +150,7 @@ Categories:
|
||||||
- context: Background context (location, job, projects)
|
- context: Background context (location, job, projects)
|
||||||
- behavior: Behavioral patterns
|
- behavior: Behavioral patterns
|
||||||
- goal: User's goals or objectives
|
- goal: User's goals or objectives
|
||||||
|
- correction: Explicit corrections or mistakes to avoid repeating
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
- Only extract clear, specific facts
|
- Only extract clear, specific facts
|
||||||
|
|
@ -262,7 +277,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||||
continue
|
continue
|
||||||
category = str(fact.get("category", "context")).strip() or "context"
|
category = str(fact.get("category", "context")).strip() or "context"
|
||||||
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
||||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
source_error = fact.get("sourceError")
|
||||||
|
if category == "correction" and isinstance(source_error, str) and source_error.strip():
|
||||||
|
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
|
||||||
|
else:
|
||||||
|
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||||
|
|
||||||
# Each additional line is preceded by a newline (except the first).
|
# Each additional line is preceded by a newline (except the first).
|
||||||
line_text = ("\n" + line) if fact_lines else line
|
line_text = ("\n" + line) if fact_lines else line
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class ConversationContext:
|
||||||
messages: list[Any]
|
messages: list[Any]
|
||||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
agent_name: str | None = None
|
agent_name: str | None = None
|
||||||
|
correction_detected: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MemoryUpdateQueue:
|
class MemoryUpdateQueue:
|
||||||
|
|
@ -37,25 +38,38 @@ class MemoryUpdateQueue:
|
||||||
self._timer: threading.Timer | None = None
|
self._timer: threading.Timer | None = None
|
||||||
self._processing = False
|
self._processing = False
|
||||||
|
|
||||||
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
|
def add(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
messages: list[Any],
|
||||||
|
agent_name: str | None = None,
|
||||||
|
correction_detected: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Add a conversation to the update queue.
|
"""Add a conversation to the update queue.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
messages: The conversation messages.
|
messages: The conversation messages.
|
||||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||||
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
"""
|
"""
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
if not config.enabled:
|
if not config.enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
context = ConversationContext(
|
|
||||||
thread_id=thread_id,
|
|
||||||
messages=messages,
|
|
||||||
agent_name=agent_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
existing_context = next(
|
||||||
|
(context for context in self._queue if context.thread_id == thread_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||||
|
context = ConversationContext(
|
||||||
|
thread_id=thread_id,
|
||||||
|
messages=messages,
|
||||||
|
agent_name=agent_name,
|
||||||
|
correction_detected=merged_correction_detected,
|
||||||
|
)
|
||||||
|
|
||||||
# Check if this thread already has a pending update
|
# Check if this thread already has a pending update
|
||||||
# If so, replace it with the newer one
|
# If so, replace it with the newer one
|
||||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||||
|
|
@ -115,6 +129,7 @@ class MemoryUpdateQueue:
|
||||||
messages=context.messages,
|
messages=context.messages,
|
||||||
thread_id=context.thread_id,
|
thread_id=context.thread_id,
|
||||||
agent_name=context.agent_name,
|
agent_name=context.agent_name,
|
||||||
|
correction_detected=context.correction_detected,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||||
|
|
|
||||||
|
|
@ -266,13 +266,20 @@ class MemoryUpdater:
|
||||||
model_name = self._model_name or config.model_name
|
model_name = self._model_name or config.model_name
|
||||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||||
|
|
||||||
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
def update_memory(
|
||||||
|
self,
|
||||||
|
messages: list[Any],
|
||||||
|
thread_id: str | None = None,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
correction_detected: bool = False,
|
||||||
|
) -> bool:
|
||||||
"""Update memory based on conversation messages.
|
"""Update memory based on conversation messages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of conversation messages.
|
messages: List of conversation messages.
|
||||||
thread_id: Optional thread ID for tracking source.
|
thread_id: Optional thread ID for tracking source.
|
||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if update was successful, False otherwise.
|
True if update was successful, False otherwise.
|
||||||
|
|
@ -295,9 +302,19 @@ class MemoryUpdater:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Build prompt
|
# Build prompt
|
||||||
|
correction_hint = ""
|
||||||
|
if correction_detected:
|
||||||
|
correction_hint = (
|
||||||
|
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||||
|
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||||
|
"and record the correct approach as a fact with category "
|
||||||
|
'"correction" and confidence >= 0.95 when appropriate.'
|
||||||
|
)
|
||||||
|
|
||||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||||
current_memory=json.dumps(current_memory, indent=2),
|
current_memory=json.dumps(current_memory, indent=2),
|
||||||
conversation=conversation_text,
|
conversation=conversation_text,
|
||||||
|
correction_hint=correction_hint,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call LLM
|
# Call LLM
|
||||||
|
|
@ -383,6 +400,8 @@ class MemoryUpdater:
|
||||||
confidence = fact.get("confidence", 0.5)
|
confidence = fact.get("confidence", 0.5)
|
||||||
if confidence >= config.fact_confidence_threshold:
|
if confidence >= config.fact_confidence_threshold:
|
||||||
raw_content = fact.get("content", "")
|
raw_content = fact.get("content", "")
|
||||||
|
if not isinstance(raw_content, str):
|
||||||
|
continue
|
||||||
normalized_content = raw_content.strip()
|
normalized_content = raw_content.strip()
|
||||||
fact_key = _fact_content_key(normalized_content)
|
fact_key = _fact_content_key(normalized_content)
|
||||||
if fact_key is not None and fact_key in existing_fact_keys:
|
if fact_key is not None and fact_key in existing_fact_keys:
|
||||||
|
|
@ -396,6 +415,11 @@ class MemoryUpdater:
|
||||||
"createdAt": now,
|
"createdAt": now,
|
||||||
"source": thread_id or "unknown",
|
"source": thread_id or "unknown",
|
||||||
}
|
}
|
||||||
|
source_error = fact.get("sourceError")
|
||||||
|
if isinstance(source_error, str):
|
||||||
|
normalized_source_error = source_error.strip()
|
||||||
|
if normalized_source_error:
|
||||||
|
fact_entry["sourceError"] = normalized_source_error
|
||||||
current_memory["facts"].append(fact_entry)
|
current_memory["facts"].append(fact_entry)
|
||||||
if fact_key is not None:
|
if fact_key is not None:
|
||||||
existing_fact_keys.add(fact_key)
|
existing_fact_keys.add(fact_key)
|
||||||
|
|
@ -412,16 +436,22 @@ class MemoryUpdater:
|
||||||
return current_memory
|
return current_memory
|
||||||
|
|
||||||
|
|
||||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
def update_memory_from_conversation(
|
||||||
|
messages: list[Any],
|
||||||
|
thread_id: str | None = None,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
correction_detected: bool = False,
|
||||||
|
) -> bool:
|
||||||
"""Convenience function to update memory from a conversation.
|
"""Convenience function to update memory from a conversation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of conversation messages.
|
messages: List of conversation messages.
|
||||||
thread_id: Optional thread ID.
|
thread_id: Optional thread ID.
|
||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if successful, False otherwise.
|
True if successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
return updater.update_memory(messages, thread_id, agent_name)
|
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||||
|
_CORRECTION_PATTERNS = (
|
||||||
|
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||||
|
re.compile(r"不对"),
|
||||||
|
re.compile(r"你理解错了"),
|
||||||
|
re.compile(r"你理解有误"),
|
||||||
|
re.compile(r"重试"),
|
||||||
|
re.compile(r"重新来"),
|
||||||
|
re.compile(r"换一种"),
|
||||||
|
re.compile(r"改用"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddlewareState(AgentState):
|
class MemoryMiddlewareState(AgentState):
|
||||||
"""Compatible with the `ThreadState` schema."""
|
"""Compatible with the `ThreadState` schema."""
|
||||||
|
|
@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_message_text(message: Any) -> str:
|
||||||
|
"""Extract plain text from message content for filtering and signal detection."""
|
||||||
|
content = getattr(message, "content", "")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, str):
|
||||||
|
text_parts.append(part)
|
||||||
|
elif isinstance(part, dict):
|
||||||
|
text_val = part.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
text_parts.append(text_val)
|
||||||
|
return " ".join(text_parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||||
"""Filter messages to keep only user inputs and final assistant responses.
|
"""Filter messages to keep only user inputs and final assistant responses.
|
||||||
|
|
||||||
|
|
@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||||
Returns:
|
Returns:
|
||||||
Filtered list containing only user inputs and final assistant responses.
|
Filtered list containing only user inputs and final assistant responses.
|
||||||
"""
|
"""
|
||||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
|
||||||
|
|
||||||
filtered = []
|
filtered = []
|
||||||
skip_next_ai = False
|
skip_next_ai = False
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
msg_type = getattr(msg, "type", None)
|
msg_type = getattr(msg, "type", None)
|
||||||
|
|
||||||
if msg_type == "human":
|
if msg_type == "human":
|
||||||
content = getattr(msg, "content", "")
|
content_str = _extract_message_text(msg)
|
||||||
if isinstance(content, list):
|
|
||||||
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
|
|
||||||
content_str = str(content)
|
|
||||||
if "<uploaded_files>" in content_str:
|
if "<uploaded_files>" in content_str:
|
||||||
# Strip the ephemeral upload block; keep the user's real question.
|
# Strip the ephemeral upload block; keep the user's real question.
|
||||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||||
|
|
@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def detect_correction(messages: list[Any]) -> bool:
|
||||||
|
"""Detect explicit user corrections in recent conversation turns.
|
||||||
|
|
||||||
|
The queue keeps only one pending context per thread, so callers pass the
|
||||||
|
latest filtered message list. Checking only recent user turns keeps signal
|
||||||
|
detection conservative while avoiding stale corrections from long histories.
|
||||||
|
"""
|
||||||
|
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||||
|
|
||||||
|
for msg in recent_user_msgs:
|
||||||
|
content = _extract_message_text(msg).strip()
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||||
"""Middleware that queues conversation for memory update after agent execution.
|
"""Middleware that queues conversation for memory update after agent execution.
|
||||||
|
|
||||||
|
|
@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Queue the filtered conversation for memory update
|
# Queue the filtered conversation for memory update
|
||||||
|
correction_detected = detect_correction(filtered_messages)
|
||||||
queue = get_memory_queue()
|
queue = get_memory_queue()
|
||||||
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
|
queue.add(
|
||||||
|
thread_id=thread_id,
|
||||||
|
messages=filtered_messages,
|
||||||
|
agent_name=self._agent_name,
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -119,3 +119,38 @@ def test_format_memory_skips_non_string_content_facts() -> None:
|
||||||
# The formatted line for a list content would be "- [knowledge | 0.85] ['list']".
|
# The formatted line for a list content would be "- [knowledge | 0.85] ['list']".
|
||||||
assert "| 0.85]" not in result
|
assert "| 0.85]" not in result
|
||||||
assert "Valid fact" in result
|
assert "Valid fact" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_memory_renders_correction_source_error() -> None:
|
||||||
|
memory_data = {
|
||||||
|
"facts": [
|
||||||
|
{
|
||||||
|
"content": "Use make dev for local development.",
|
||||||
|
"category": "correction",
|
||||||
|
"confidence": 0.95,
|
||||||
|
"sourceError": "The agent previously suggested npm start.",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
result = format_memory_for_injection(memory_data, max_tokens=2000)
|
||||||
|
|
||||||
|
assert "Use make dev for local development." in result
|
||||||
|
assert "avoid: The agent previously suggested npm start." in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_memory_renders_correction_without_source_error_normally() -> None:
|
||||||
|
memory_data = {
|
||||||
|
"facts": [
|
||||||
|
{
|
||||||
|
"content": "Use make dev for local development.",
|
||||||
|
"category": "correction",
|
||||||
|
"confidence": 0.95,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
result = format_memory_for_injection(memory_data, max_tokens=2000)
|
||||||
|
|
||||||
|
assert "Use make dev for local development." in result
|
||||||
|
assert "avoid:" not in result
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||||
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_config(**overrides: object) -> MemoryConfig:
|
||||||
|
config = MemoryConfig()
|
||||||
|
for key, value in overrides.items():
|
||||||
|
setattr(config, key, value)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
|
||||||
|
queue = MemoryUpdateQueue()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch.object(queue, "_reset_timer"),
|
||||||
|
):
|
||||||
|
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
|
||||||
|
queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
|
||||||
|
|
||||||
|
assert len(queue._queue) == 1
|
||||||
|
assert queue._queue[0].messages == ["second"]
|
||||||
|
assert queue._queue[0].correction_detected is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||||
|
queue = MemoryUpdateQueue()
|
||||||
|
queue._queue = [
|
||||||
|
ConversationContext(
|
||||||
|
thread_id="thread-1",
|
||||||
|
messages=["conversation"],
|
||||||
|
agent_name="lead_agent",
|
||||||
|
correction_detected=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
mock_updater = MagicMock()
|
||||||
|
mock_updater.update_memory.return_value = True
|
||||||
|
|
||||||
|
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||||
|
queue._process_queue()
|
||||||
|
|
||||||
|
mock_updater.update_memory.assert_called_once_with(
|
||||||
|
messages=["conversation"],
|
||||||
|
thread_id="thread-1",
|
||||||
|
agent_name="lead_agent",
|
||||||
|
correction_detected=True,
|
||||||
|
)
|
||||||
|
|
@ -72,6 +72,56 @@ def test_import_memory_route_returns_imported_memory() -> None:
|
||||||
assert response.json()["facts"] == imported_memory["facts"]
|
assert response.json()["facts"] == imported_memory["facts"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_memory_route_preserves_source_error() -> None:
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(memory.router)
|
||||||
|
exported_memory = _sample_memory(
|
||||||
|
facts=[
|
||||||
|
{
|
||||||
|
"id": "fact_correction",
|
||||||
|
"content": "Use make dev for local development.",
|
||||||
|
"category": "correction",
|
||||||
|
"confidence": 0.95,
|
||||||
|
"createdAt": "2026-03-20T00:00:00Z",
|
||||||
|
"source": "thread-1",
|
||||||
|
"sourceError": "The agent previously suggested npm start.",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.gateway.routers.memory.get_memory_data", return_value=exported_memory):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/memory/export")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_memory_route_preserves_source_error() -> None:
    """POST /api/memory/import must preserve the optional sourceError fact field."""
    test_app = FastAPI()
    test_app.include_router(memory.router)
    correction_fact = {
        "id": "fact_correction",
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "createdAt": "2026-03-20T00:00:00Z",
        "source": "thread-1",
        "sourceError": "The agent previously suggested npm start.",
    }
    imported_memory = _sample_memory(facts=[correction_fact])

    with (
        patch("app.gateway.routers.memory.import_memory_data", return_value=imported_memory),
        TestClient(test_app) as client,
    ):
        response = client.post("/api/memory/import", json=imported_memory)

    assert response.status_code == 200
    assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
||||||
|
|
||||||
def test_clear_memory_route_returns_cleared_memory() -> None:
|
def test_clear_memory_route_returns_cleared_memory() -> None:
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(memory.router)
|
app.include_router(memory.router)
|
||||||
|
|
|
||||||
|
|
@ -146,6 +146,53 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||||
assert result["facts"][1]["source"] == "thread-9"
|
assert result["facts"][1]["source"] == "thread-9"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_updates_preserves_source_error() -> None:
    """_apply_updates must carry a non-empty sourceError through to the stored fact."""
    updater = MemoryUpdater()
    current_memory = _make_memory()
    new_fact = {
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "sourceError": "The agent previously suggested npm start.",
    }
    update_data = {"newFacts": [new_fact]}

    memory_config = _memory_config(max_facts=100, fact_confidence_threshold=0.7)
    with patch(
        "deerflow.agents.memory.updater.get_memory_config",
        return_value=memory_config,
    ):
        result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")

    assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
    assert result["facts"][0]["category"] == "correction"
|
||||||
|
|
||||||
|
def test_apply_updates_ignores_empty_source_error() -> None:
    """A whitespace-only sourceError must be dropped rather than stored."""
    updater = MemoryUpdater()
    current_memory = _make_memory()
    new_fact = {
        "content": "Use make dev for local development.",
        "category": "correction",
        "confidence": 0.95,
        "sourceError": " ",
    }
    update_data = {"newFacts": [new_fact]}

    memory_config = _memory_config(max_facts=100, fact_confidence_threshold=0.7)
    with patch(
        "deerflow.agents.memory.updater.get_memory_config",
        return_value=memory_config,
    ):
        result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")

    assert "sourceError" not in result["facts"][0]
|
||||||
|
|
||||||
def test_clear_memory_data_resets_all_sections() -> None:
|
def test_clear_memory_data_resets_all_sections() -> None:
|
||||||
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
|
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
|
||||||
result = clear_memory_data()
|
result = clear_memory_data()
|
||||||
|
|
@ -522,3 +569,53 @@ class TestUpdateMemoryStructuredResponse:
|
||||||
result = updater.update_memory([msg, ai_msg])
|
result = updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
def test_correction_hint_injected_when_detected(self):
    """correction_detected=True must inject the correction hint into the prompt."""
    updater = MemoryUpdater()
    structured_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
    model = self._make_mock_model(structured_json)

    storage_mock = MagicMock(save=MagicMock(return_value=True))
    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=storage_mock),
    ):
        human_msg = MagicMock()
        human_msg.type = "human"
        human_msg.content = "No, that's wrong."
        assistant_msg = MagicMock()
        assistant_msg.type = "ai"
        assistant_msg.content = "Understood"
        assistant_msg.tool_calls = []

        result = updater.update_memory([human_msg, assistant_msg], correction_detected=True)

    assert result is True
    # The first positional argument of the model call is the rendered prompt.
    prompt = model.invoke.call_args[0][0]
    assert "Explicit correction signals were detected" in prompt
||||||
|
def test_correction_hint_empty_when_not_detected(self):
    """correction_detected=False must leave the correction hint out of the prompt."""
    updater = MemoryUpdater()
    structured_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
    model = self._make_mock_model(structured_json)

    storage_mock = MagicMock(save=MagicMock(return_value=True))
    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=storage_mock),
    ):
        human_msg = MagicMock()
        human_msg.type = "human"
        human_msg.content = "Let's talk about memory."
        assistant_msg = MagicMock()
        assistant_msg.type = "ai"
        assistant_msg.content = "Sure"
        assistant_msg.tool_calls = []

        result = updater.update_memory([human_msg, assistant_msg], correction_detected=False)

    assert result is True
    # The first positional argument of the model call is the rendered prompt.
    prompt = model.invoke.call_args[0][0]
    assert "Explicit correction signals were detected" not in prompt
|
||||||
|
|
@ -10,7 +10,7 @@ persisting in long-term memory:
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
|
|
||||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory
|
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
|
|
@ -134,6 +134,64 @@ class TestFilterMessagesForMemory:
|
||||||
assert "<uploaded_files>" not in all_content
|
assert "<uploaded_files>" not in all_content
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# detect_correction
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectCorrection:
    """Tests for detect_correction, the correction-signal scanner over recent messages."""

    def test_detects_english_correction_signal(self):
        """An English 'that's wrong' style message counts as a correction."""
        conversation = [
            _human("Please help me run the project."),
            _ai("Use npm start."),
            _human("That's wrong, use make dev instead."),
            _ai("Understood."),
        ]
        assert detect_correction(conversation) is True

    def test_detects_chinese_correction_signal(self):
        """A Chinese correction phrase is also recognised."""
        conversation = [
            _human("帮我启动项目"),
            _ai("用 npm start"),
            _human("不对,改用 make dev"),
            _ai("明白了"),
        ]
        assert detect_correction(conversation) is True

    def test_returns_false_without_signal(self):
        """A plain, agreeable conversation must not be flagged."""
        conversation = [
            _human("Please explain the build setup."),
            _ai("Here is the build setup."),
            _human("Thanks, that makes sense."),
        ]
        assert detect_correction(conversation) is False

    def test_only_checks_recent_messages(self):
        """A correction buried early in a long conversation is ignored.

        NOTE(review): this assumes detect_correction only scans a trailing
        window of messages — confirm window size against the implementation.
        """
        conversation = [
            _human("That is wrong, use make dev instead."),
            _ai("Noted."),
            _human("Let's discuss tests."),
            _ai("Sure."),
            _human("What about linting?"),
            _ai("Use ruff."),
            _human("And formatting?"),
            _ai("Use make format."),
        ]
        assert detect_correction(conversation) is False

    def test_handles_list_content(self):
        """Multi-part (list) message content is flattened and scanned too."""
        conversation = [
            HumanMessage(content=["That is wrong,", {"type": "text", "text": "use make dev instead."}]),
            _ai("Updated."),
        ]
        assert detect_correction(conversation) is True
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# _strip_upload_mentions_from_memory
|
# _strip_upload_mentions_from_memory
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue