deerflow2/backend/packages/harness/deerflow/agents/middlewares/artifact_reconcile_middlewa...

115 lines
3.6 KiB
Python

import logging
from pathlib import Path
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import (
ARTIFACTS_REPLACE_SENTINEL,
ThreadDataState,
)
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Virtual path prefix (ending in "/") under which output artifacts are addressed.
_OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs/"
# The same prefix with every leading "/" removed, so it can be compared against
# artifact paths that have had their leading slashes stripped.
_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH = _OUTPUTS_VIRTUAL_PREFIX.lstrip("/")
class ArtifactReconcileState(AgentState):
    """Compatible with the `ThreadState` schema.

    Declares the two state keys that
    `ArtifactReconcileMiddleware.before_model` reads: ``artifacts`` and
    ``thread_data``. Both keys may be absent or ``None``.
    """

    # Artifact paths tracked in the agent state.
    artifacts: NotRequired[list[str] | None]
    # Per-thread data; the middleware reads its "outputs_path" entry.
    thread_data: NotRequired[ThreadDataState | None]
class ArtifactReconcileMiddleware(AgentMiddleware[ArtifactReconcileState]):
    """Reconcile the ``artifacts`` state entry with the outputs directory.

    Before each model call, string entries that point under the virtual
    outputs prefix but have no matching regular file on disk are dropped, and
    every regular file found under the outputs directory is appended if it is
    not already listed. Entries outside the outputs prefix are left alone.
    """

    state_schema = ArtifactReconcileState

    def _to_outputs_file(self, virtual_path: str, outputs_dir: Path) -> Path | None:
        """Translate a virtual outputs path into a resolved filesystem path.

        Args:
            virtual_path: Artifact path as stored in state. Leading slashes
                are ignored when matching the outputs prefix.
            outputs_dir: Directory the virtual outputs prefix maps to. The
                containment check compares against it as given, so callers
                are expected to pass an already-resolved path.

        Returns:
            The resolved path, or ``None`` when ``virtual_path`` is not under
            the outputs prefix, names nothing below the prefix, or resolves
            to a location outside ``outputs_dir``.
        """
        prefix = _OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH
        normalized = virtual_path.lstrip("/")
        # Paths outside the outputs namespace are not this middleware's
        # concern; the caller keeps them untouched.
        if not normalized.startswith(prefix):
            return None

        tail = normalized.removeprefix(prefix)
        if tail == "":
            return None

        resolved = (outputs_dir / tail).resolve()
        # Reject anything that escapes the outputs directory once resolved.
        return resolved if resolved.is_relative_to(outputs_dir) else None

    def _to_virtual_artifact(self, actual_path: Path, outputs_dir: Path) -> str | None:
        """Map a file on disk back to its virtual outputs path.

        Args:
            actual_path: Filesystem path of a candidate artifact.
            outputs_dir: Directory the virtual outputs prefix maps to.

        Returns:
            The virtual path (outputs prefix plus the POSIX-style relative
            path), or ``None`` when the resolved file is not located under
            ``outputs_dir``.
        """
        try:
            tail = actual_path.resolve().relative_to(outputs_dir)
        except ValueError:
            return None
        return _OUTPUTS_VIRTUAL_PREFIX + tail.as_posix()

    def _discover_outputs(self, outputs_dir: Path) -> list[str]:
        """List the virtual paths of all regular files under ``outputs_dir``.

        The directory is walked recursively in sorted path order. Returns an
        empty list when ``outputs_dir`` is not a directory.
        """
        if not outputs_dir.is_dir():
            return []

        virtual_paths = (
            self._to_virtual_artifact(entry, outputs_dir)
            for entry in sorted(outputs_dir.rglob("*"))
            if entry.is_file()
        )
        # Drop files that could not be mapped back under the outputs dir.
        return [virtual for virtual in virtual_paths if virtual]

    @override
    def before_model(
        self,
        state: ArtifactReconcileState,
        runtime: Runtime,  # noqa: ARG002
    ) -> dict | None:
        """Drop stale output artifacts and add undiscovered output files.

        Args:
            state: Agent state; ``artifacts`` and
                ``thread_data["outputs_path"]`` are read from it.
            runtime: Unused.

        Returns:
            ``None`` when no outputs path is configured or the artifact list
            is already in sync; otherwise a state update whose ``artifacts``
            value is the replace sentinel followed by the reconciled list.
        """
        thread_data = state.get("thread_data") or {}
        outputs_path = thread_data.get("outputs_path")
        if not outputs_path:
            return None

        outputs_dir = Path(outputs_path).resolve()
        surviving: list[str] = []
        dirty = False

        for entry in state.get("artifacts") or []:
            # Non-string entries are discarded.
            if not isinstance(entry, str):
                dirty = True
                continue

            on_disk = self._to_outputs_file(entry, outputs_dir)
            # ``None`` means "not an outputs path": keep it as-is. Otherwise
            # keep the entry only while its file still exists.
            if on_disk is None or (on_disk.exists() and on_disk.is_file()):
                surviving.append(entry)
                continue

            dirty = True
            logger.info(
                "Reconciled stale artifact from state: virtual=%s outputs_dir=%s",
                entry,
                outputs_dir,
            )

        # Order-preserving de-duplication: kept entries first, then files
        # found on disk that are not yet tracked.
        merged = list(dict.fromkeys(surviving + self._discover_outputs(outputs_dir)))
        if not dirty and merged == surviving:
            return None

        # NOTE(review): the sentinel presumably tells the ``artifacts`` reducer
        # to replace rather than merge; confirm in deerflow.agents.thread_state.
        return {"artifacts": [ARTIFACTS_REPLACE_SENTINEL, *merged]}