feat(backend): 增加摘要标题与产物状态对账
This commit is contained in:
parent
31f4bdb99a
commit
256a2d36ec
|
|
@ -2,10 +2,12 @@ import logging
|
|||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from deerflow.agents.middlewares.artifact_reconcile_middleware import ArtifactReconcileMiddleware
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.message_timestamp_middleware import MessageTimestampMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
|
|
@ -23,6 +25,15 @@ from deerflow.models import create_chat_model
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUMMARY_MESSAGE_TITLE = "以下是目前对话的摘要:"
|
||||
|
||||
|
||||
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
"""Summarization middleware with DeerFlow's user-facing summary heading."""
|
||||
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
return [HumanMessage(content=f"{SUMMARY_MESSAGE_TITLE}\n\n{summary}")]
|
||||
|
||||
|
||||
def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
||||
|
|
@ -78,7 +89,7 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
|||
if config.summary_prompt is not None:
|
||||
kwargs["summary_prompt"] = config.summary_prompt
|
||||
|
||||
return SummarizationMiddleware(**kwargs)
|
||||
return DeerFlowSummarizationMiddleware(**kwargs)
|
||||
|
||||
|
||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
||||
|
|
@ -234,6 +245,9 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
|||
if get_app_config().token_usage.enabled:
|
||||
middlewares.append(TokenUsageMiddleware())
|
||||
|
||||
# Reconcile stale artifact entries against real outputs files.
|
||||
middlewares.append(ArtifactReconcileMiddleware())
|
||||
|
||||
# Stamp every conversation message with backend timestamp metadata.
|
||||
middlewares.append(MessageTimestampMiddleware())
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import (
|
||||
ARTIFACTS_REPLACE_SENTINEL,
|
||||
ThreadDataState,
|
||||
)
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs/"
|
||||
_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH = _OUTPUTS_VIRTUAL_PREFIX.lstrip("/")
|
||||
|
||||
|
||||
class ArtifactReconcileState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
artifacts: NotRequired[list[str] | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
|
||||
|
||||
class ArtifactReconcileMiddleware(AgentMiddleware[ArtifactReconcileState]):
|
||||
"""Keep artifact state aligned with files currently in outputs."""
|
||||
|
||||
state_schema = ArtifactReconcileState
|
||||
|
||||
def _to_outputs_file(self, virtual_path: str, outputs_dir: Path) -> Path | None:
|
||||
stripped = virtual_path.lstrip("/")
|
||||
if not stripped.startswith(_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH):
|
||||
# Keep non-outputs paths untouched; this middleware is for outputs drift.
|
||||
return None
|
||||
|
||||
relative = stripped[len(_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH) :]
|
||||
if not relative:
|
||||
return None
|
||||
|
||||
candidate = (outputs_dir / relative).resolve()
|
||||
try:
|
||||
candidate.relative_to(outputs_dir)
|
||||
except ValueError:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
def _to_virtual_artifact(self, actual_path: Path, outputs_dir: Path) -> str | None:
|
||||
try:
|
||||
relative = actual_path.resolve().relative_to(outputs_dir)
|
||||
except ValueError:
|
||||
return None
|
||||
return f"{_OUTPUTS_VIRTUAL_PREFIX}{relative.as_posix()}"
|
||||
|
||||
def _discover_outputs(self, outputs_dir: Path) -> list[str]:
|
||||
if not outputs_dir.is_dir():
|
||||
return []
|
||||
|
||||
discovered: list[str] = []
|
||||
for path in sorted(outputs_dir.rglob("*")):
|
||||
if not path.is_file():
|
||||
continue
|
||||
virtual_path = self._to_virtual_artifact(path, outputs_dir)
|
||||
if virtual_path:
|
||||
discovered.append(virtual_path)
|
||||
return discovered
|
||||
|
||||
@override
|
||||
def before_model(
|
||||
self,
|
||||
state: ArtifactReconcileState,
|
||||
runtime: Runtime, # noqa: ARG002
|
||||
) -> dict | None:
|
||||
artifacts = state.get("artifacts") or []
|
||||
thread_data = state.get("thread_data") or {}
|
||||
outputs_path = thread_data.get("outputs_path")
|
||||
if not outputs_path:
|
||||
return None
|
||||
|
||||
outputs_dir = Path(outputs_path).resolve()
|
||||
kept: list[str] = []
|
||||
changed = False
|
||||
|
||||
for artifact in artifacts:
|
||||
if not isinstance(artifact, str):
|
||||
changed = True
|
||||
continue
|
||||
|
||||
actual_path = self._to_outputs_file(artifact, outputs_dir)
|
||||
if actual_path is None:
|
||||
kept.append(artifact)
|
||||
continue
|
||||
|
||||
if actual_path.exists() and actual_path.is_file():
|
||||
kept.append(artifact)
|
||||
else:
|
||||
changed = True
|
||||
logger.info(
|
||||
"Reconciled stale artifact from state: virtual=%s outputs_dir=%s",
|
||||
artifact,
|
||||
outputs_dir,
|
||||
)
|
||||
|
||||
discovered = self._discover_outputs(outputs_dir)
|
||||
merged = list(dict.fromkeys([*kept, *discovered]))
|
||||
if merged != kept:
|
||||
changed = True
|
||||
|
||||
if not changed:
|
||||
return None
|
||||
|
||||
return {"artifacts": [ARTIFACTS_REPLACE_SENTINEL, *merged]}
|
||||
|
|
@ -2,6 +2,8 @@ from typing import Annotated, NotRequired, TypedDict
|
|||
|
||||
from langchain.agents import AgentState
|
||||
|
||||
ARTIFACTS_REPLACE_SENTINEL = "__deerflow_replace_artifacts__"
|
||||
|
||||
|
||||
class SandboxState(TypedDict):
|
||||
sandbox_id: NotRequired[str | None]
|
||||
|
|
@ -20,6 +22,8 @@ class ViewedImageData(TypedDict):
|
|||
|
||||
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
|
||||
"""Reducer for artifacts list - merges and deduplicates artifacts."""
|
||||
if new and new[0] == ARTIFACTS_REPLACE_SENTINEL:
|
||||
return list(dict.fromkeys(new[1:]))
|
||||
if existing is None:
|
||||
return new or []
|
||||
if new is None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
from deerflow.agents.middlewares.artifact_reconcile_middleware import (
|
||||
ArtifactReconcileMiddleware,
|
||||
)
|
||||
from deerflow.agents.thread_state import ARTIFACTS_REPLACE_SENTINEL
|
||||
|
||||
|
||||
def test_before_model_prunes_missing_outputs_artifacts(tmp_path):
|
||||
outputs_dir = tmp_path / "outputs"
|
||||
outputs_dir.mkdir()
|
||||
existing = outputs_dir / "keep.md"
|
||||
existing.write_text("ok", encoding="utf-8")
|
||||
|
||||
middleware = ArtifactReconcileMiddleware()
|
||||
state = {
|
||||
"thread_data": {"outputs_path": str(outputs_dir)},
|
||||
"artifacts": [
|
||||
"/mnt/user-data/outputs/keep.md",
|
||||
"/mnt/user-data/outputs/missing.md",
|
||||
],
|
||||
}
|
||||
|
||||
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
|
||||
|
||||
assert result == {
|
||||
"artifacts": [ARTIFACTS_REPLACE_SENTINEL, "/mnt/user-data/outputs/keep.md"]
|
||||
}
|
||||
|
||||
|
||||
def test_before_model_returns_none_when_no_changes(tmp_path):
|
||||
outputs_dir = tmp_path / "outputs"
|
||||
outputs_dir.mkdir()
|
||||
existing = outputs_dir / "keep.md"
|
||||
existing.write_text("ok", encoding="utf-8")
|
||||
|
||||
middleware = ArtifactReconcileMiddleware()
|
||||
state = {
|
||||
"thread_data": {"outputs_path": str(outputs_dir)},
|
||||
"artifacts": ["/mnt/user-data/outputs/keep.md"],
|
||||
}
|
||||
|
||||
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_before_model_adds_unpresented_outputs_files(tmp_path):
|
||||
outputs_dir = tmp_path / "outputs"
|
||||
outputs_dir.mkdir()
|
||||
existing = outputs_dir / "keep.md"
|
||||
existing.write_text("ok", encoding="utf-8")
|
||||
extra = outputs_dir / "extra.md"
|
||||
extra.write_text("ok", encoding="utf-8")
|
||||
|
||||
middleware = ArtifactReconcileMiddleware()
|
||||
state = {
|
||||
"thread_data": {"outputs_path": str(outputs_dir)},
|
||||
"artifacts": ["/mnt/user-data/outputs/keep.md"],
|
||||
}
|
||||
|
||||
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
|
||||
|
||||
assert result == {
|
||||
"artifacts": [
|
||||
ARTIFACTS_REPLACE_SENTINEL,
|
||||
"/mnt/user-data/outputs/keep.md",
|
||||
"/mnt/user-data/outputs/extra.md",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_before_model_discovers_outputs_when_artifacts_empty(tmp_path):
|
||||
outputs_dir = tmp_path / "outputs"
|
||||
outputs_dir.mkdir()
|
||||
report = outputs_dir / "report.md"
|
||||
report.write_text("ok", encoding="utf-8")
|
||||
|
||||
middleware = ArtifactReconcileMiddleware()
|
||||
state = {
|
||||
"thread_data": {"outputs_path": str(outputs_dir)},
|
||||
"artifacts": [],
|
||||
}
|
||||
|
||||
result = middleware.before_model(state, runtime=SimpleNamespace(context={}))
|
||||
|
||||
assert result == {
|
||||
"artifacts": [ARTIFACTS_REPLACE_SENTINEL, "/mnt/user-data/outputs/report.md"]
|
||||
}
|
||||
|
|
@ -147,7 +147,8 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
|||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
fake_model = object()
|
||||
fake_model = MagicMock()
|
||||
fake_model._llm_type = "test-chat"
|
||||
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
|
||||
captured["name"] = name
|
||||
|
|
@ -156,10 +157,20 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
|||
return fake_model
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware()
|
||||
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
assert middleware["model"] is fake_model
|
||||
assert isinstance(middleware, lead_agent_module.DeerFlowSummarizationMiddleware)
|
||||
assert middleware.model is fake_model
|
||||
|
||||
|
||||
def test_deerflow_summarization_middleware_uses_chinese_summary_title():
|
||||
middleware = lead_agent_module.DeerFlowSummarizationMiddleware(
|
||||
model=MagicMock(),
|
||||
trigger=("messages", 2),
|
||||
)
|
||||
|
||||
messages = middleware._build_new_messages("旧上下文")
|
||||
|
||||
assert messages[0].content == "以下是目前对话的摘要:\n\n旧上下文"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
from deerflow.agents.thread_state import (
|
||||
ARTIFACTS_REPLACE_SENTINEL,
|
||||
merge_artifacts,
|
||||
)
|
||||
|
||||
|
||||
def test_merge_artifacts_default_merge_dedup():
|
||||
existing = ["/mnt/user-data/outputs/a.md", "/mnt/user-data/outputs/b.md"]
|
||||
new = ["/mnt/user-data/outputs/b.md", "/mnt/user-data/outputs/c.md"]
|
||||
|
||||
result = merge_artifacts(existing, new)
|
||||
|
||||
assert result == [
|
||||
"/mnt/user-data/outputs/a.md",
|
||||
"/mnt/user-data/outputs/b.md",
|
||||
"/mnt/user-data/outputs/c.md",
|
||||
]
|
||||
|
||||
|
||||
def test_merge_artifacts_supports_replace_sentinel():
|
||||
existing = ["/mnt/user-data/outputs/a.md", "/mnt/user-data/outputs/b.md"]
|
||||
new = [
|
||||
ARTIFACTS_REPLACE_SENTINEL,
|
||||
"/mnt/user-data/outputs/b.md",
|
||||
"/mnt/user-data/outputs/c.md",
|
||||
"/mnt/user-data/outputs/c.md",
|
||||
]
|
||||
|
||||
result = merge_artifacts(existing, new)
|
||||
|
||||
assert result == [
|
||||
"/mnt/user-data/outputs/b.md",
|
||||
"/mnt/user-data/outputs/c.md",
|
||||
]
|
||||
Loading…
Reference in New Issue