"""Middleware that stamps conversation messages with backend timestamps."""

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Any
from typing import override
from zoneinfo import ZoneInfo

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

# Key under a message's ``additional_kwargs`` where this middleware writes its stamp.
_TIMESTAMP_KEY = "deerflow_created_at"

# Keys inside ``additional_kwargs`` that are checked, in this order, to decide
# whether a message is already stamped. A non-empty string under any of them
# counts as an existing timestamp.
_TIMESTAMP_LOOKUP_KEYS = (_TIMESTAMP_KEY, "created_at")

try:
    _BEIJING_TZ = ZoneInfo("Asia/Shanghai")
except Exception:
    # The IANA tz database may be unavailable (e.g. Windows without the
    # ``tzdata`` package). Asia/Shanghai currently has no DST, so a fixed
    # UTC+8 offset yields the same wall-clock times.
    _BEIJING_TZ = timezone(timedelta(hours=8))


def _beijing_iso_millis(dt: datetime) -> str:
    """Render *dt* as an ISO-8601 string in Beijing time, to the millisecond.

    Example: ``2024-01-02T03:04:05.678+08:00``. A naive *dt* is interpreted
    as local system time by ``datetime.astimezone``.
    """
    return dt.astimezone(_BEIJING_TZ).isoformat(timespec="milliseconds")


def _timestamp_from_mapping(mapping: dict[str, Any]) -> str | None:
    """Return the first non-empty string stored under a timestamp lookup key.

    Each key is checked independently, so a non-string value under one key
    does not hide a valid string under a later key.
    """
    for key in _TIMESTAMP_LOOKUP_KEYS:
        value = mapping.get(key)
        if isinstance(value, str) and value:
            return value
    return None


def _extract_existing_timestamp(message: Any) -> str | None:
    """Return a timestamp already attached to *message*, or ``None``.

    Dict messages are checked for a top-level ``"created_at"`` first, then
    their ``"additional_kwargs"`` entry. Any other object is checked through
    its ``additional_kwargs`` attribute. Only non-empty strings count.
    """
    if isinstance(message, dict):
        top_level = message.get("created_at")
        if isinstance(top_level, str) and top_level:
            return top_level
        additional_kwargs = message.get("additional_kwargs")
    else:
        additional_kwargs = getattr(message, "additional_kwargs", None)

    if isinstance(additional_kwargs, dict):
        return _timestamp_from_mapping(additional_kwargs)
    return None


def _stamp_message(message: Any, timestamp: str) -> bool:
    """Write *timestamp* into *message*'s ``additional_kwargs`` unless already stamped.

    Mutates *message* in place. Returns ``True`` if a stamp was written and
    ``False`` if the message was skipped: it already had a timestamp, or its
    ``additional_kwargs`` attribute could not be replaced.
    """
    if _extract_existing_timestamp(message):
        return False

    if isinstance(message, dict):
        additional_kwargs = message.get("additional_kwargs")
        if isinstance(additional_kwargs, dict):
            additional_kwargs[_TIMESTAMP_KEY] = timestamp
        else:
            message["additional_kwargs"] = {_TIMESTAMP_KEY: timestamp}
        return True

    additional_kwargs = getattr(message, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        additional_kwargs[_TIMESTAMP_KEY] = timestamp
        return True

    try:
        # Build the dict with the stamp already in it before assigning. That
        # way the stamp survives even if the attribute setter stores a copy
        # instead of the object we pass.
        message.additional_kwargs = {_TIMESTAMP_KEY: timestamp}
    except Exception:
        # Deliberate best effort: leave the message unstamped rather than
        # fail the agent run over a missing timestamp.
        return False
    return True


def _stamp_messages(messages: list[Any]) -> None:
    """Stamp, in place, every message in *messages* that has no timestamp yet.

    All messages stamped in one call share a single base "now". Each
    successive newly stamped message is offset by one more millisecond, so
    their stamps are distinct and sort in list order. The offset counts only
    messages stamped in this call, not their position in the whole history.
    This keeps long conversations from pushing new stamps into the future.
    """
    now = datetime.now(_BEIJING_TZ)
    stamped = 0
    for message in messages:
        timestamp = _beijing_iso_millis(now + timedelta(milliseconds=stamped))
        if _stamp_message(message, timestamp):
            stamped += 1


def _stamp_state_messages(state: AgentState) -> None:
    """Stamp ``state["messages"]`` in place if it is a list; otherwise do nothing."""
    messages = state.get("messages")
    if isinstance(messages, list):
        _stamp_messages(messages)


class MessageTimestampMiddleware(AgentMiddleware):
    """Ensure every persisted conversation message has a backend timestamp.

    After each model call, both the sync and async hooks walk
    ``state["messages"]``. Any message without a timestamp gets a Beijing-time
    ISO-8601 string under ``additional_kwargs["deerflow_created_at"]``.
    Messages are mutated in place and the hooks return ``None``, meaning no
    explicit state update.

    NOTE(review): this relies on in-place mutation of message objects that
    are already in the graph state. Confirm that the configured checkpointer
    re-persists them. If it only saves channels that received an update, an
    explicit ``{"messages": [...]}`` return may be required.
    """

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        _stamp_state_messages(state)
        return None

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        _stamp_state_messages(state)
        return None