deerflow2/backend/packages/harness/deerflow/agents/middlewares/message_timestamp_middlewar...

90 lines
3.0 KiB
Python

"""Middleware that stamps conversation messages with backend timestamps."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from typing import override
from zoneinfo import ZoneInfo
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
_TIMESTAMP_KEY = "deerflow_created_at"
try:
_BEIJING_TZ = ZoneInfo("Asia/Shanghai")
except Exception:
# Fallback when zoneinfo database is unavailable.
_BEIJING_TZ = timezone(timedelta(hours=8))
def _beijing_iso_millis(dt: datetime) -> str:
    """Format *dt* as ISO-8601 in the Beijing zone, truncated to milliseconds.

    Args:
        dt: The instant to format. It is converted with ``astimezone``, so a
            naive value is interpreted as local system time.

    Returns:
        A string such as ``2024-01-01T12:00:00.123+08:00``.
    """
    in_beijing = dt.astimezone(_BEIJING_TZ)
    return in_beijing.isoformat(timespec="milliseconds")
def _extract_existing_timestamp(message: Any) -> str | None:
if isinstance(message, dict):
top = message.get("created_at")
if isinstance(top, str) and top:
return top
additional_kwargs = message.get("additional_kwargs")
if isinstance(additional_kwargs, dict):
value = additional_kwargs.get(_TIMESTAMP_KEY) or additional_kwargs.get("created_at")
if isinstance(value, str) and value:
return value
return None
additional_kwargs = getattr(message, "additional_kwargs", None)
if isinstance(additional_kwargs, dict):
value = additional_kwargs.get(_TIMESTAMP_KEY) or additional_kwargs.get("created_at")
if isinstance(value, str) and value:
return value
return None
def _stamp_message(message: Any, timestamp: str) -> None:
    """Record *timestamp* under ``_TIMESTAMP_KEY`` unless the message already has one.

    If ``additional_kwargs`` is missing or not a dict, a new dict is attached
    first: as the ``"additional_kwargs"`` key for dict messages, or as the
    attribute for other objects. If the attribute assignment raises, the
    message is left unchanged.

    Args:
        message: A message dict or message-like object, mutated in place.
        timestamp: The timestamp string to store.
    """
    if _extract_existing_timestamp(message):
        return

    is_mapping = isinstance(message, dict)
    if is_mapping:
        extras = message.get("additional_kwargs")
    else:
        extras = getattr(message, "additional_kwargs", None)

    if not isinstance(extras, dict):
        extras = {}
        if is_mapping:
            message["additional_kwargs"] = extras
        else:
            try:
                message.additional_kwargs = extras
            except Exception:
                # Best effort: objects that reject the assignment stay unstamped.
                return

    extras[_TIMESTAMP_KEY] = timestamp
def _stamp_messages(messages: list[Any]) -> None:
    """Stamp, in place, every message in *messages* that has no timestamp yet.

    All stamps in a single call derive from one ``now`` reading. Each message
    is offset by its list index in milliseconds, so messages stamped in the same
    call get strictly increasing timestamps in list order. Messages that already
    carry a timestamp are skipped before any formatting. The history is
    re-scanned on every call, so this avoids building an ISO string per
    already-stamped message.

    Args:
        messages: Message dicts or message-like objects, mutated in place.
    """
    now = datetime.now(_BEIJING_TZ)
    for idx, message in enumerate(messages):
        if _extract_existing_timestamp(message):
            # _stamp_message would ignore it anyway; skip the costly formatting.
            continue
        _stamp_message(message, _beijing_iso_millis(now + timedelta(milliseconds=idx)))
class MessageTimestampMiddleware(AgentMiddleware):
    """Ensure every persisted conversation message has a backend timestamp.

    In the after-model hooks, any message in ``state["messages"]`` without a
    timestamp is stamped in place. The stamped messages that carry an ``id`` are
    also returned as a ``messages`` state update. The change is therefore
    written through the graph's messages reducer instead of depending only on
    in-place mutation.
    """

    @staticmethod
    def _message_id(message: Any) -> Any:
        """Return a message's ``id`` (dict key or attribute), or None if absent."""
        if isinstance(message, dict):
            return message.get("id")
        return getattr(message, "id", None)

    def _stamp_state_messages(self, state: AgentState) -> dict | None:
        """Stamp unstamped messages in *state* and build the matching state update.

        Returns:
            ``{"messages": [...]}`` holding the newly stamped messages that have
            an id, or None when nothing was stamped or ``messages`` is not a list.
        """
        messages = state.get("messages")
        if not isinstance(messages, list):
            return None

        pending = [m for m in messages if not _extract_existing_timestamp(m)]
        if not pending:
            return None

        _stamp_messages(messages)

        # NOTE(review): returning None only mutates objects in place, and the
        # messages channel is not marked as updated. A checkpointer that saves
        # only changed channels can then lose the stamps, e.g. on the last
        # reply of a run. Returning the stamped messages records the change.
        # This assumes the messages reducer replaces entries by id
        # (add_messages) — confirm. Messages without an id are kept as in-place
        # edits only, so they are never appended as duplicates.
        updated = [
            m
            for m in pending
            if _extract_existing_timestamp(m) and self._message_id(m)
        ]
        return {"messages": updated} if updated else None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._stamp_state_messages(state)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._stamp_state_messages(state)