"""Middleware for logging LLM token usage and optionally reporting to an external billing API.""" import asyncio import logging from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain_core.messages import HumanMessage from langgraph.runtime import Runtime from deerflow.config.app_config import get_app_config logger = logging.getLogger(__name__) class TokenUsageMiddleware(AgentMiddleware): """Logs token usage from model response usage_metadata and optionally reports to an external billing API.""" @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: cfg = get_app_config().token_usage logger.info( "[TokenUsageMiddleware] after_model triggered: enabled=%s report_enabled=%s", getattr(cfg, "enabled", False), getattr(cfg, "report_enabled", False), ) if getattr(cfg, "enabled", False): self._log_usage(state) return None @override async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: cfg = get_app_config().token_usage logger.info( "[TokenUsageMiddleware] aafter_model triggered: enabled=%s report_enabled=%s", getattr(cfg, "enabled", False), getattr(cfg, "report_enabled", False), ) if getattr(cfg, "enabled", False): self._log_usage(state) if getattr(cfg, "report_enabled", False): _schedule_usage_report(state, runtime) return None def _log_usage(self, state: AgentState) -> None: messages = state.get("messages", []) logger.info("[TokenUsageMiddleware] _log_usage messages_count=%s", len(messages)) if not messages: logger.info("[TokenUsageMiddleware] _log_usage skip: no messages") return None usage = _extract_usage_from_messages(messages) if usage: logger.info( "LLM token usage: input=%s output=%s total=%s", usage.get("input_tokens", "?"), usage.get("output_tokens", "?"), usage.get("total_tokens", "?"), ) return None def _schedule_usage_report(state: AgentState, runtime: Runtime) -> None: """Fire-and-forget: schedule a billing report 
    request after each model call."""
    cfg = get_app_config().token_usage
    logger.info("[TokenUsageMiddleware] _schedule_usage_report entered")
    messages = state.get("messages", [])
    if not messages:
        logger.info("[TokenUsageMiddleware] skip report: no messages")
        return
    last = messages[-1]
    usage = _extract_usage_from_messages(messages)
    if not usage:
        logger.info("[TokenUsageMiddleware] skip report: no token usage found")
        _log_message_diagnostics(last)
        return
    thread_id = _extract_thread_id(runtime)
    model_key = _extract_model_key(runtime, last)
    model_name = model_key
    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg and model_cfg.display_name:
            model_name = model_cfg.display_name
    transaction_no = _extract_transaction_no(last)
    if not transaction_no:
        transaction_no = 0
    question = _extract_latest_question(messages)
    if isinstance(question, str) and len(question) > 27:
        question = question[:27] + "。。。"
    payload: dict = {
        "sessionId": thread_id,
        "inputToken": usage.get("input_tokens"),
        "outputToken": usage.get("output_tokens"),
        "totalTokens": usage.get("total_tokens"),
        "modelName": model_name,
        "question": question,
        "transactionNo": transaction_no,
        "remark": "",
    }
    logger.info("Token billing payload: %s", payload)
    if not cfg.report_url:
        logger.info("[TokenUsageMiddleware] skip report: report_url is empty")
        return
    # Hold a strong reference to the task; otherwise the event loop only keeps
    # a weak reference and the report can be garbage-collected mid-flight.
    task = asyncio.create_task(
        _post_usage_report(cfg.report_url, cfg.report_headers, payload),
        name="token_usage_report",
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


# Strong references to in-flight fire-and-forget report tasks.
_background_tasks: set[asyncio.Task] = set()


async def _post_usage_report(url: str, headers: dict[str, str], payload: dict) -> None:
    """POST billing payload to the configured external endpoint."""
    try:
        import httpx  # local import so httpx stays an optional dependency

        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code >= 400:
                logger.warning(
                    "Token billing report HTTP %s from %s",
                    response.status_code,
                    url,
                )
    except Exception as exc:
        logger.warning("Failed to report token billing to %s: %s", url, exc)


def _extract_latest_question(messages:
                             list) -> str:
    """Extract latest user question text from message history."""
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            content = getattr(msg, "content", "")
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                parts: list[str] = []
                for part in content:
                    if isinstance(part, str):
                        parts.append(part)
                    elif isinstance(part, dict):
                        text = part.get("text")
                        if isinstance(text, str):
                            parts.append(text)
                return "\n".join(p for p in parts if p)
            return str(content)
    return ""


def _extract_transaction_no(last_message: object) -> str | None:
    """Get provider response id from AI message for billing transactionNo."""
    msg_id = getattr(last_message, "id", None)
    if isinstance(msg_id, str) and msg_id:
        return msg_id
    response_metadata = getattr(last_message, "response_metadata", None)
    if isinstance(response_metadata, dict):
        meta_id = response_metadata.get("id") or response_metadata.get("response_id")
        if isinstance(meta_id, str) and meta_id:
            return meta_id
    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        kw_id = additional_kwargs.get("id")
        if isinstance(kw_id, str) and kw_id:
            return kw_id
    return None


def _extract_thread_id(runtime: Runtime) -> str | None:
    """Extract thread_id safely across LangGraph runtime variants."""
    context = getattr(runtime, "context", None)
    if isinstance(context, dict):
        thread_id = context.get("thread_id")
        if isinstance(thread_id, str) and thread_id:
            return thread_id
    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        thread_id = config.get("configurable", {}).get("thread_id")
        if isinstance(thread_id, str) and thread_id:
            return thread_id
    return None


def _extract_model_key(runtime: Runtime, last_message: object) -> str | None:
    """Extract model key/name safely from runtime config or message metadata."""
    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        configurable = config.get("configurable", {})
        model_key = \
configurable.get("model") or configurable.get("model_name") if isinstance(model_key, str) and model_key: return model_key response_metadata = getattr(last_message, "response_metadata", None) if isinstance(response_metadata, dict): model_key = response_metadata.get("model_name") or response_metadata.get("model") if isinstance(model_key, str) and model_key: return model_key return None def _extract_usage(last_message: object) -> dict[str, int] | None: """Extract token usage from common provider/LangChain metadata shapes.""" # Primary LangChain shape. usage_metadata = getattr(last_message, "usage_metadata", None) usage = _normalize_usage_dict(usage_metadata) if usage: logger.info("[TokenUsageMiddleware] usage source=usage_metadata") return usage # Common provider shape on AIMessage.response_metadata.usage response_metadata = getattr(last_message, "response_metadata", None) if isinstance(response_metadata, dict): usage = _normalize_usage_dict(response_metadata.get("usage")) if usage: logger.info("[TokenUsageMiddleware] usage source=response_metadata.usage") return usage usage = _normalize_usage_dict(response_metadata.get("token_usage")) if usage: logger.info("[TokenUsageMiddleware] usage source=response_metadata.token_usage") return usage # Some providers attach usage-like payloads to additional_kwargs. 
    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        usage = _normalize_usage_dict(additional_kwargs.get("usage"))
        if usage:
            logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.usage")
            return usage
        usage = _normalize_usage_dict(additional_kwargs.get("token_usage"))
        if usage:
            logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.token_usage")
            return usage
    return None


def _extract_usage_from_messages(messages: list) -> dict[str, int] | None:
    """Find token usage from the most recent message that contains it."""
    for msg in reversed(messages):
        usage = _extract_usage(msg)
        if usage:
            return usage
    return None


def _log_message_diagnostics(last_message: object) -> None:
    """Log lightweight diagnostics for message metadata shape when usage is missing."""
    response_metadata = getattr(last_message, "response_metadata", None)
    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    logger.info(
        "[TokenUsageMiddleware] diagnostics: message_type=%s has_usage_metadata=%s "
        "response_metadata_keys=%s additional_kwargs_keys=%s",
        type(last_message).__name__,
        getattr(last_message, "usage_metadata", None) is not None,
        sorted(response_metadata.keys()) if isinstance(response_metadata, dict) else [],
        sorted(additional_kwargs.keys()) if isinstance(additional_kwargs, dict) else [],
    )


def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
    """Normalize token usage keys to input_tokens/output_tokens/total_tokens."""
    if not isinstance(raw_usage, dict):
        return None
    input_tokens = raw_usage.get("input_tokens")
    if input_tokens is None:
        input_tokens = raw_usage.get("prompt_tokens")
    output_tokens = raw_usage.get("output_tokens")
    if output_tokens is None:
        output_tokens = raw_usage.get("completion_tokens")
    total_tokens = raw_usage.get("total_tokens")
    if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
        total_tokens = input_tokens + output_tokens
    if not \
            any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
        return None
    return {
        "input_tokens": int(input_tokens or 0),
        "output_tokens": int(output_tokens or 0),
        "total_tokens": int(total_tokens or 0),
    }
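

# Usage sketch (hypothetical wiring, not part of this module's API): the
# middleware is attached wherever the agent is built. Assuming an agent
# factory that accepts a `middleware` list, such as langchain's
# `create_agent` — adjust names to this project's actual entry point:
#
#     from langchain.agents import create_agent
#
#     agent = create_agent(
#         model,
#         tools=tools,
#         middleware=[TokenUsageMiddleware()],
#     )
#
# With `token_usage.enabled` set, each model call then logs usage via
# `after_model`/`aafter_model`; with `report_enabled` and a `report_url`,
# the async path also posts the billing payload in the background.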