feat(backend): 增加摘要标题与产物状态对账

This commit is contained in:
肖应宇 2026-04-24 17:04:05 +08:00 committed by MT-Fire
parent 31f4bdb99a
commit 256a2d36ec
6 changed files with 271 additions and 5 deletions

View File

@ -2,10 +2,12 @@ import logging
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
from langchain_core.messages.human import HumanMessage
from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.artifact_reconcile_middleware import ArtifactReconcileMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.message_timestamp_middleware import MessageTimestampMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
@ -23,6 +25,15 @@ from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
SUMMARY_MESSAGE_TITLE = "以下是目前对话的摘要:"
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""Summarization middleware with DeerFlow's user-facing summary heading."""
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
return [HumanMessage(content=f"{SUMMARY_MESSAGE_TITLE}\n\n{summary}")]
def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
@ -78,7 +89,7 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt
return SummarizationMiddleware(**kwargs)
return DeerFlowSummarizationMiddleware(**kwargs)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
@ -234,6 +245,9 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
if get_app_config().token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Reconcile stale artifact entries against real outputs files.
middlewares.append(ArtifactReconcileMiddleware())
# Stamp every conversation message with backend timestamp metadata.
middlewares.append(MessageTimestampMiddleware())

View File

@ -0,0 +1,114 @@
import logging
from pathlib import Path
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import (
ARTIFACTS_REPLACE_SENTINEL,
ThreadDataState,
)
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
logger = logging.getLogger(__name__)
_OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs/"
_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH = _OUTPUTS_VIRTUAL_PREFIX.lstrip("/")
class ArtifactReconcileState(AgentState):
"""Compatible with the `ThreadState` schema."""
artifacts: NotRequired[list[str] | None]
thread_data: NotRequired[ThreadDataState | None]
class ArtifactReconcileMiddleware(AgentMiddleware[ArtifactReconcileState]):
"""Keep artifact state aligned with files currently in outputs."""
state_schema = ArtifactReconcileState
def _to_outputs_file(self, virtual_path: str, outputs_dir: Path) -> Path | None:
stripped = virtual_path.lstrip("/")
if not stripped.startswith(_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH):
# Keep non-outputs paths untouched; this middleware is for outputs drift.
return None
relative = stripped[len(_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH) :]
if not relative:
return None
candidate = (outputs_dir / relative).resolve()
try:
candidate.relative_to(outputs_dir)
except ValueError:
return None
return candidate
def _to_virtual_artifact(self, actual_path: Path, outputs_dir: Path) -> str | None:
try:
relative = actual_path.resolve().relative_to(outputs_dir)
except ValueError:
return None
return f"{_OUTPUTS_VIRTUAL_PREFIX}{relative.as_posix()}"
def _discover_outputs(self, outputs_dir: Path) -> list[str]:
if not outputs_dir.is_dir():
return []
discovered: list[str] = []
for path in sorted(outputs_dir.rglob("*")):
if not path.is_file():
continue
virtual_path = self._to_virtual_artifact(path, outputs_dir)
if virtual_path:
discovered.append(virtual_path)
return discovered
@override
def before_model(
self,
state: ArtifactReconcileState,
runtime: Runtime, # noqa: ARG002
) -> dict | None:
artifacts = state.get("artifacts") or []
thread_data = state.get("thread_data") or {}
outputs_path = thread_data.get("outputs_path")
if not outputs_path:
return None
outputs_dir = Path(outputs_path).resolve()
kept: list[str] = []
changed = False
for artifact in artifacts:
if not isinstance(artifact, str):
changed = True
continue
actual_path = self._to_outputs_file(artifact, outputs_dir)
if actual_path is None:
kept.append(artifact)
continue
if actual_path.exists() and actual_path.is_file():
kept.append(artifact)
else:
changed = True
logger.info(
"Reconciled stale artifact from state: virtual=%s outputs_dir=%s",
artifact,
outputs_dir,
)
discovered = self._discover_outputs(outputs_dir)
merged = list(dict.fromkeys([*kept, *discovered]))
if merged != kept:
changed = True
if not changed:
return None
return {"artifacts": [ARTIFACTS_REPLACE_SENTINEL, *merged]}

View File

@ -2,6 +2,8 @@ from typing import Annotated, NotRequired, TypedDict
from langchain.agents import AgentState
ARTIFACTS_REPLACE_SENTINEL = "__deerflow_replace_artifacts__"
class SandboxState(TypedDict):
sandbox_id: NotRequired[str | None]
@ -20,6 +22,8 @@ class ViewedImageData(TypedDict):
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
"""Reducer for artifacts list - merges and deduplicates artifacts."""
if new and new[0] == ARTIFACTS_REPLACE_SENTINEL:
return list(dict.fromkeys(new[1:]))
if existing is None:
return new or []
if new is None:

View File

@ -0,0 +1,89 @@
from types import SimpleNamespace
from deerflow.agents.middlewares.artifact_reconcile_middleware import (
ArtifactReconcileMiddleware,
)
from deerflow.agents.thread_state import ARTIFACTS_REPLACE_SENTINEL
def test_before_model_prunes_missing_outputs_artifacts(tmp_path):
outputs_dir = tmp_path / "outputs"
outputs_dir.mkdir()
existing = outputs_dir / "keep.md"
existing.write_text("ok", encoding="utf-8")
middleware = ArtifactReconcileMiddleware()
state = {
"thread_data": {"outputs_path": str(outputs_dir)},
"artifacts": [
"/mnt/user-data/outputs/keep.md",
"/mnt/user-data/outputs/missing.md",
],
}
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
assert result == {
"artifacts": [ARTIFACTS_REPLACE_SENTINEL, "/mnt/user-data/outputs/keep.md"]
}
def test_before_model_returns_none_when_no_changes(tmp_path):
outputs_dir = tmp_path / "outputs"
outputs_dir.mkdir()
existing = outputs_dir / "keep.md"
existing.write_text("ok", encoding="utf-8")
middleware = ArtifactReconcileMiddleware()
state = {
"thread_data": {"outputs_path": str(outputs_dir)},
"artifacts": ["/mnt/user-data/outputs/keep.md"],
}
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
assert result is None
def test_before_model_adds_unpresented_outputs_files(tmp_path):
outputs_dir = tmp_path / "outputs"
outputs_dir.mkdir()
existing = outputs_dir / "keep.md"
existing.write_text("ok", encoding="utf-8")
extra = outputs_dir / "extra.md"
extra.write_text("ok", encoding="utf-8")
middleware = ArtifactReconcileMiddleware()
state = {
"thread_data": {"outputs_path": str(outputs_dir)},
"artifacts": ["/mnt/user-data/outputs/keep.md"],
}
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
assert result == {
"artifacts": [
ARTIFACTS_REPLACE_SENTINEL,
"/mnt/user-data/outputs/keep.md",
"/mnt/user-data/outputs/extra.md",
]
}
def test_before_model_discovers_outputs_when_artifacts_empty(tmp_path):
outputs_dir = tmp_path / "outputs"
outputs_dir.mkdir()
report = outputs_dir / "report.md"
report.write_text("ok", encoding="utf-8")
middleware = ArtifactReconcileMiddleware()
state = {
"thread_data": {"outputs_path": str(outputs_dir)},
"artifacts": [],
}
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
assert result == {
"artifacts": [ARTIFACTS_REPLACE_SENTINEL, "/mnt/user-data/outputs/report.md"]
}

View File

@ -147,7 +147,8 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
)
captured: dict[str, object] = {}
fake_model = object()
fake_model = MagicMock()
fake_model._llm_type = "test-chat"
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
captured["name"] = name
@ -156,10 +157,20 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware()
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert middleware["model"] is fake_model
assert isinstance(middleware, lead_agent_module.DeerFlowSummarizationMiddleware)
assert middleware.model is fake_model
def test_deerflow_summarization_middleware_uses_chinese_summary_title():
middleware = lead_agent_module.DeerFlowSummarizationMiddleware(
model=MagicMock(),
trigger=("messages", 2),
)
messages = middleware._build_new_messages("旧上下文")
assert messages[0].content == "以下是目前对话的摘要:\n\n旧上下文"

View File

@ -0,0 +1,34 @@
from deerflow.agents.thread_state import (
ARTIFACTS_REPLACE_SENTINEL,
merge_artifacts,
)
def test_merge_artifacts_default_merge_dedup():
existing = ["/mnt/user-data/outputs/a.md", "/mnt/user-data/outputs/b.md"]
new = ["/mnt/user-data/outputs/b.md", "/mnt/user-data/outputs/c.md"]
result = merge_artifacts(existing, new)
assert result == [
"/mnt/user-data/outputs/a.md",
"/mnt/user-data/outputs/b.md",
"/mnt/user-data/outputs/c.md",
]
def test_merge_artifacts_supports_replace_sentinel():
existing = ["/mnt/user-data/outputs/a.md", "/mnt/user-data/outputs/b.md"]
new = [
ARTIFACTS_REPLACE_SENTINEL,
"/mnt/user-data/outputs/b.md",
"/mnt/user-data/outputs/c.md",
"/mnt/user-data/outputs/c.md",
]
result = merge_artifacts(existing, new)
assert result == [
"/mnt/user-data/outputs/b.md",
"/mnt/user-data/outputs/c.md",
]