feat: 对齐df的注入模式
This commit is contained in:
parent
7b7ba7698e
commit
6829d41895
@ -6,7 +6,7 @@ import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from deerflow.agents.memory.prompt import _coerce_confidence, _count_tokens, format_conversation_for_update
|
||||
from deerflow.agents.memory.prompt import format_conversation_for_update, format_memory_for_injection
|
||||
|
||||
THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system.
|
||||
|
||||
@ -100,65 +100,7 @@ def _infer_preferred_memory_language(messages: list[Any]) -> str:
|
||||
|
||||
|
||||
def format_thread_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||
if not memory_data:
|
||||
return ""
|
||||
|
||||
user = memory_data.get("user") or {}
|
||||
history = memory_data.get("history") or {}
|
||||
facts = memory_data.get("facts") or []
|
||||
|
||||
user_lines: list[str] = []
|
||||
for key, label in (("workContext", "Work Context"), ("personalContext", "Personal Context"), ("topOfMind", "Top Of Mind")):
|
||||
section = user.get(key) if isinstance(user, dict) else None
|
||||
if isinstance(section, dict):
|
||||
summary = section.get("summary")
|
||||
if isinstance(summary, str) and summary.strip():
|
||||
user_lines.append(f"- {label}: {summary.strip()}")
|
||||
|
||||
history_lines: list[str] = []
|
||||
for key, label in (("recentMonths", "Recent Months"), ("earlierContext", "Earlier Context"), ("longTermBackground", "Long-Term Background")):
|
||||
section = history.get(key) if isinstance(history, dict) else None
|
||||
if isinstance(section, dict):
|
||||
summary = section.get("summary")
|
||||
if isinstance(summary, str) and summary.strip():
|
||||
history_lines.append(f"- {label}: {summary.strip()}")
|
||||
|
||||
sections: list[str] = []
|
||||
if user_lines:
|
||||
sections.append("User:\n" + "\n".join(user_lines))
|
||||
if history_lines:
|
||||
sections.append("History:\n" + "\n".join(history_lines))
|
||||
|
||||
# Facts are lowest priority: include by confidence/recency and trim by token budget.
|
||||
ranked_facts = sorted(
|
||||
(
|
||||
f
|
||||
for f in facts
|
||||
if isinstance(f, dict) and isinstance(f.get("content"), str) and f.get("content", "").strip()
|
||||
),
|
||||
key=lambda f: (_coerce_confidence(f.get("confidence"), default=0.0), str(f.get("createdAt", ""))),
|
||||
reverse=True,
|
||||
)
|
||||
base = "\n\n".join(sections)
|
||||
running = _count_tokens(base) if base else 0
|
||||
fact_lines: list[str] = []
|
||||
if ranked_facts:
|
||||
running += _count_tokens("\n\nFacts:\n" if base else "Facts:\n")
|
||||
for fact in ranked_facts:
|
||||
line = (
|
||||
f"- [{str(fact.get('category', 'context')).strip() or 'context'} | "
|
||||
f"{_coerce_confidence(fact.get('confidence'), default=0.0):.2f}] {fact.get('content').strip()}"
|
||||
)
|
||||
candidate = ("\n" + line) if fact_lines else line
|
||||
cost = _count_tokens(candidate)
|
||||
if running + cost > max_tokens:
|
||||
break
|
||||
fact_lines.append(line)
|
||||
running += cost
|
||||
if fact_lines:
|
||||
sections.append("Facts:\n" + "\n".join(fact_lines))
|
||||
|
||||
return "\n\n".join(sections)
|
||||
return format_memory_for_injection(memory_data, max_tokens=max_tokens)
|
||||
|
||||
|
||||
def build_thread_memory_prompt(existing_memory: dict[str, Any], messages: list[Any]) -> str:
|
||||
|
||||
@ -4,7 +4,7 @@ from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, for
|
||||
|
||||
|
||||
def test_thread_memory_injection_keeps_profile_and_preferences_under_small_budget(monkeypatch):
|
||||
monkeypatch.setattr("deerflow.agents.memory.thread_prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
|
||||
memory = {
|
||||
"user": {
|
||||
"workContext": {"summary": "Building APIs", "updatedAt": "2026-05-08T00:00:00Z"},
|
||||
@ -23,7 +23,7 @@ def test_thread_memory_injection_keeps_profile_and_preferences_under_small_budge
|
||||
}
|
||||
|
||||
result = format_thread_memory_for_injection(memory, max_tokens=140)
|
||||
assert "User:" in result
|
||||
assert "User Context:" in result
|
||||
assert "History:" in result
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user