"""Middleware that reconciles the ``artifacts`` state with the outputs directory."""

import logging
from pathlib import Path
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from deerflow.agents.thread_state import (
    ARTIFACTS_REPLACE_SENTINEL,
    ThreadDataState,
)
from deerflow.config.paths import VIRTUAL_PATH_PREFIX

logger = logging.getLogger(__name__)

# Virtual prefix under which files of the outputs directory are addressed.
_OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs/"
# Same prefix without leading slashes, so artifacts match with or without one.
_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH = _OUTPUTS_VIRTUAL_PREFIX.lstrip("/")


class ArtifactReconcileState(AgentState):
    """Compatible with the `ThreadState` schema."""

    # Virtual paths of the artifacts currently recorded in state.
    artifacts: NotRequired[list[str] | None]
    # Per-thread data; this middleware only reads its "outputs_path" entry.
    thread_data: NotRequired[ThreadDataState | None]


class ArtifactReconcileMiddleware(AgentMiddleware[ArtifactReconcileState]):
    """Keep artifact state aligned with files currently in outputs."""

    state_schema = ArtifactReconcileState

    def _to_outputs_file(self, virtual_path: str, outputs_dir: Path) -> Path | None:
        """Map a virtual outputs path to the real file it refers to.

        Args:
            virtual_path: Artifact path as stored in state; leading slashes
                are ignored when matching the outputs prefix.
            outputs_dir: Already-resolved outputs directory.

        Returns:
            The resolved path inside ``outputs_dir``, or ``None`` when
            ``virtual_path`` is not under the outputs prefix, names the
            outputs root itself, cannot be resolved, or resolves to a
            location outside ``outputs_dir``.
        """
        stripped = virtual_path.lstrip("/")
        if not stripped.startswith(_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH):
            # Keep non-outputs paths untouched; this middleware is for outputs drift.
            return None

        relative = stripped.removeprefix(_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH)
        if not relative:
            return None

        try:
            # resolve() can itself raise ValueError for strings the OS cannot
            # represent as a path (e.g. an embedded NUL byte), so it has to
            # sit inside the try just like the containment check below.
            candidate = (outputs_dir / relative).resolve()
            # Reject anything that escapes outputs_dir (e.g. via "..").
            candidate.relative_to(outputs_dir)
        except ValueError:
            return None
        return candidate

    def _to_virtual_artifact(self, actual_path: Path, outputs_dir: Path) -> str | None:
        """Map a real file to its virtual outputs path.

        Args:
            actual_path: File found on disk.
            outputs_dir: Already-resolved outputs directory.

        Returns:
            The virtual path (outputs prefix + POSIX-style relative path), or
            ``None`` when the resolved file is not located under
            ``outputs_dir``.
        """
        try:
            relative = actual_path.resolve().relative_to(outputs_dir)
        except ValueError:
            return None
        return f"{_OUTPUTS_VIRTUAL_PREFIX}{relative.as_posix()}"

    def _discover_outputs(self, outputs_dir: Path) -> list[str]:
        """List the virtual paths of all regular files under ``outputs_dir``.

        Returns an empty list when ``outputs_dir`` is not a directory. Paths
        are visited in sorted order so the result is deterministic.
        """
        if not outputs_dir.is_dir():
            return []

        discovered: list[str] = []
        for path in sorted(outputs_dir.rglob("*")):
            if not path.is_file():
                continue
            virtual_path = self._to_virtual_artifact(path, outputs_dir)
            if virtual_path:
                discovered.append(virtual_path)
        return discovered

    @override
    def before_model(
        self,
        state: ArtifactReconcileState,
        runtime: Runtime,  # noqa: ARG002
    ) -> dict | None:
        """Drop stale artifacts and add files present on disk but not in state.

        Args:
            state: Agent state; ``artifacts`` and ``thread_data`` are read.
            runtime: Unused.

        Returns:
            ``None`` when ``thread_data`` has no ``outputs_path`` or when the
            artifact list is already in sync; otherwise a state update whose
            ``artifacts`` list starts with ``ARTIFACTS_REPLACE_SENTINEL``
            followed by the reconciled, de-duplicated paths.
        """
        artifacts = state.get("artifacts") or []
        thread_data = state.get("thread_data") or {}
        outputs_path = thread_data.get("outputs_path")
        if not outputs_path:
            return None

        outputs_dir = Path(outputs_path).resolve()
        kept: list[str] = []
        changed = False

        for artifact in artifacts:
            # Non-string entries and leftover sentinels are never real
            # artifacts; dropping them counts as a change.
            if not isinstance(artifact, str) or artifact == ARTIFACTS_REPLACE_SENTINEL:
                changed = True
                continue

            actual_path = self._to_outputs_file(artifact, outputs_dir)
            if actual_path is None:
                # Not an outputs-managed path: leave it alone.
                kept.append(artifact)
                continue

            # is_file() is already False for missing paths, so no separate
            # exists() check is needed.
            if actual_path.is_file():
                kept.append(artifact)
            else:
                changed = True
                logger.info(
                    "Reconciled stale artifact from state: virtual=%s outputs_dir=%s",
                    artifact,
                    outputs_dir,
                )

        discovered = self._discover_outputs(outputs_dir)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        merged = list(dict.fromkeys([*kept, *discovered]))
        if merged != kept:
            changed = True

        if not changed:
            return None
        # NOTE(review): the leading sentinel presumably tells the artifacts
        # reducer to replace rather than merge; confirm in thread_state.
        return {"artifacts": [ARTIFACTS_REPLACE_SENTINEL, *merged]}