118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
import logging
|
|
from pathlib import Path
|
|
from typing import NotRequired, override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langgraph.runtime import Runtime
|
|
|
|
from deerflow.agents.thread_state import (
|
|
ARTIFACTS_REPLACE_SENTINEL,
|
|
ThreadDataState,
|
|
)
|
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Virtual directory prefix (trailing slash included) under which files in the
# outputs directory are addressed.
_OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs/"

# The same prefix with any leading slashes removed, so it can be compared
# against artifact paths that have been normalized the same way.
_OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH = _OUTPUTS_VIRTUAL_PREFIX.lstrip("/")
|
|
|
|
|
|
class ArtifactReconcileState(AgentState):
    """State schema used by the artifact reconcile middleware.

    Compatible with the `ThreadState` schema.
    """

    # Virtual artifact paths recorded in state; the key may be missing or None.
    artifacts: NotRequired[list[str] | None]

    # Per-thread data; the key may be missing or None. The middleware reads
    # the "outputs_path" entry from it.
    thread_data: NotRequired[ThreadDataState | None]
|
|
|
|
|
|
class ArtifactReconcileMiddleware(AgentMiddleware[ArtifactReconcileState]):
    """Keep the ``artifacts`` state entry in sync with the outputs directory.

    Artifacts that point at output files which no longer exist are removed,
    and files found in the outputs directory that are not yet listed are added.
    """

    state_schema = ArtifactReconcileState

    def _to_outputs_file(self, virtual_path: str, outputs_dir: Path) -> Path | None:
        """Translate a virtual outputs path into a real path under *outputs_dir*.

        Args:
            virtual_path: Artifact path as stored in state.
            outputs_dir: Resolved directory that backs the virtual outputs prefix.

        Returns:
            The resolved filesystem path, or ``None`` when *virtual_path* is not
            under the outputs prefix, names the prefix itself, or resolves to a
            location outside *outputs_dir*.
        """
        prefix = _OUTPUTS_VIRTUAL_PREFIX_NO_LEADING_SLASH
        normalized = virtual_path.lstrip("/")

        # Non-outputs paths are not this middleware's concern; it only deals
        # with drift inside the outputs directory.
        if not normalized.startswith(prefix):
            return None

        tail = normalized.removeprefix(prefix)
        if tail == "":
            return None

        resolved = (outputs_dir / tail).resolve()
        # Reject anything that resolves outside outputs_dir (e.g. via "..").
        if not resolved.is_relative_to(outputs_dir):
            return None
        return resolved

    def _to_virtual_artifact(self, actual_path: Path, outputs_dir: Path) -> str | None:
        """Translate a real file path into its virtual outputs path.

        Args:
            actual_path: Filesystem path of a candidate output file.
            outputs_dir: Resolved directory that backs the virtual outputs prefix.

        Returns:
            The virtual path (outputs prefix + POSIX-style relative path), or
            ``None`` when the resolved file is not located under *outputs_dir*.
        """
        real_path = actual_path.resolve()
        try:
            inside = real_path.relative_to(outputs_dir)
        except ValueError:
            return None
        return f"{_OUTPUTS_VIRTUAL_PREFIX}{inside.as_posix()}"

    def _discover_outputs(self, outputs_dir: Path) -> list[str]:
        """List the virtual paths of all regular files under *outputs_dir*.

        The directory is walked recursively and entries are visited in sorted
        order. Returns an empty list when *outputs_dir* is not a directory.
        """
        if not outputs_dir.is_dir():
            return []

        virtual_candidates = (
            self._to_virtual_artifact(entry, outputs_dir)
            for entry in sorted(outputs_dir.rglob("*"))
            if entry.is_file()
        )
        # Drop files that could not be mapped back into the outputs prefix.
        return [virtual for virtual in virtual_candidates if virtual]

    @override
    def before_model(
        self,
        state: ArtifactReconcileState,
        runtime: Runtime,  # noqa: ARG002
    ) -> dict | None:
        """Reconcile ``state["artifacts"]`` against the files on disk.

        Args:
            state: Agent state; ``artifacts`` and ``thread_data["outputs_path"]``
                are read from it.
            runtime: Unused.

        Returns:
            ``None`` when no outputs path is configured or nothing needs to
            change; otherwise a state update whose ``artifacts`` list starts
            with ``ARTIFACTS_REPLACE_SENTINEL`` followed by the reconciled,
            de-duplicated artifact paths.
        """
        outputs_path = (state.get("thread_data") or {}).get("outputs_path")
        if not outputs_path:
            return None

        outputs_dir = Path(outputs_path).resolve()
        recorded = state.get("artifacts") or []

        survivors: list[str] = []
        dirty = False

        for entry in recorded:
            # Non-string entries and stray sentinel markers are never kept.
            if not isinstance(entry, str) or entry == ARTIFACTS_REPLACE_SENTINEL:
                dirty = True
                continue

            target = self._to_outputs_file(entry, outputs_dir)
            # A None target means "not an outputs path" and is kept untouched.
            # NOTE(review): a path that escapes outputs_dir also yields None
            # and is therefore kept as well - confirm that this is intended.
            if target is None or (target.exists() and target.is_file()):
                survivors.append(entry)
                continue

            dirty = True
            logger.info(
                "Reconciled stale artifact from state: virtual=%s outputs_dir=%s",
                entry,
                outputs_dir,
            )

        found_on_disk = self._discover_outputs(outputs_dir)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        combined = list(dict.fromkeys(survivors + found_on_disk))

        if not dirty and combined == survivors:
            return None

        # Presumably the sentinel tells the state reducer to replace the list
        # instead of appending to it - TODO confirm against thread_state.
        return {"artifacts": [ARTIFACTS_REPLACE_SENTINEL, *combined]}
|