90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
"""Middleware that stamps conversation messages with backend timestamps."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any
|
|
from typing import override
|
|
from zoneinfo import ZoneInfo
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langgraph.runtime import Runtime
|
|
|
|
# Key written into a message's ``additional_kwargs`` to hold the backend timestamp.
_TIMESTAMP_KEY = "deerflow_created_at"


def _load_beijing_tz():
    """Return the ``Asia/Shanghai`` zone, or a fixed UTC+08:00 offset.

    The fixed offset is used when the zone cannot be constructed, e.g. because
    the zoneinfo database is unavailable on this host.
    """
    try:
        return ZoneInfo("Asia/Shanghai")
    except Exception:
        return timezone(timedelta(hours=8))


_BEIJING_TZ = _load_beijing_tz()
|
|
|
|
|
|
def _beijing_iso_millis(dt: datetime) -> str:
    """Format *dt* as an ISO-8601 string in Beijing time, to the millisecond.

    The value is first converted with ``datetime.astimezone`` (which treats a
    naive datetime as system local time), then rendered with
    ``timespec="milliseconds"`` so the result always has exactly 3 fractional
    digits and a ``+08:00``-style offset.
    """
    localized = dt.astimezone(_BEIJING_TZ)
    return localized.isoformat(timespec="milliseconds")
|
|
|
|
|
|
def _extract_existing_timestamp(message: Any) -> str | None:
|
|
if isinstance(message, dict):
|
|
top = message.get("created_at")
|
|
if isinstance(top, str) and top:
|
|
return top
|
|
additional_kwargs = message.get("additional_kwargs")
|
|
if isinstance(additional_kwargs, dict):
|
|
value = additional_kwargs.get(_TIMESTAMP_KEY) or additional_kwargs.get("created_at")
|
|
if isinstance(value, str) and value:
|
|
return value
|
|
return None
|
|
|
|
additional_kwargs = getattr(message, "additional_kwargs", None)
|
|
if isinstance(additional_kwargs, dict):
|
|
value = additional_kwargs.get(_TIMESTAMP_KEY) or additional_kwargs.get("created_at")
|
|
if isinstance(value, str) and value:
|
|
return value
|
|
return None
|
|
|
|
|
|
def _stamp_message(message: Any, timestamp: str) -> None:
    """Write *timestamp* under ``_TIMESTAMP_KEY`` unless *message* already has one.

    Dict messages: a missing or non-dict ``"additional_kwargs"`` entry is
    replaced by a fresh dict before stamping.

    Other objects: a missing or non-dict ``additional_kwargs`` attribute is
    replaced by a fresh dict. If that assignment raises, the message is left
    unstamped (best effort). An existing dict is mutated in place.
    """
    if _extract_existing_timestamp(message):
        return

    if isinstance(message, dict):
        bucket = message.get("additional_kwargs")
        if not isinstance(bucket, dict):
            bucket = message["additional_kwargs"] = {}
    else:
        bucket = getattr(message, "additional_kwargs", None)
        if not isinstance(bucket, dict):
            bucket = {}
            try:
                message.additional_kwargs = bucket
            except Exception:
                # The object refuses attribute assignment; stamping is best effort.
                return

    bucket[_TIMESTAMP_KEY] = timestamp
|
|
|
|
|
|
def _stamp_messages(messages: list[Any]) -> None:
    """Stamp, in place, every message in *messages* that lacks a timestamp.

    One base time (``now`` in Beijing time) is taken per call. The k-th message
    that actually needs a stamp receives ``now + k`` milliseconds, so stamps
    assigned in one batch increase in list order.

    Messages that already carry a timestamp are skipped and do not consume an
    offset. Otherwise a long, already-stamped history would push the newest
    messages' timestamps into the future by one millisecond per prior message.
    """
    now = datetime.now(_BEIJING_TZ)
    offset_ms = 0
    for message in messages:
        if _extract_existing_timestamp(message):
            continue
        _stamp_message(message, _beijing_iso_millis(now + timedelta(milliseconds=offset_ms)))
        offset_ms += 1
|
|
|
|
|
|
def _stamp_state_messages(state: AgentState) -> None:
    """Stamp ``state["messages"]`` in place when it is a list; otherwise do nothing."""
    candidate = state.get("messages")
    if isinstance(candidate, list):
        _stamp_messages(candidate)


class MessageTimestampMiddleware(AgentMiddleware):
    """Ensure every persisted conversation message has a backend timestamp.

    After each model call, the message list in the agent state is passed to
    ``_stamp_messages``, which mutates the messages in place. Both hooks
    return ``None``, so no explicit state update is emitted.

    NOTE(review): whether in-place mutation without a returned update is
    picked up by the checkpointer cannot be seen from this file; confirm.
    """

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        _stamp_state_messages(state)
        return None

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        _stamp_state_messages(state)
        return None
|