# deerflow2/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py

"""Middleware for logging LLM token usage and optionally reporting to an external billing API."""
import asyncio
import logging
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime

from deerflow.config.app_config import get_app_config

logger = logging.getLogger(__name__)


class TokenUsageMiddleware(AgentMiddleware):
    """Logs token usage from model response usage_metadata and optionally reports to an external billing API."""

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        cfg = get_app_config().token_usage
        logger.info(
            "[TokenUsageMiddleware] after_model triggered: enabled=%s report_enabled=%s",
            getattr(cfg, "enabled", False),
            getattr(cfg, "report_enabled", False),
        )
        if getattr(cfg, "enabled", False):
            self._log_usage(state)
        return None

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        cfg = get_app_config().token_usage
        logger.info(
            "[TokenUsageMiddleware] aafter_model triggered: enabled=%s report_enabled=%s",
            getattr(cfg, "enabled", False),
            getattr(cfg, "report_enabled", False),
        )
        if getattr(cfg, "enabled", False):
            self._log_usage(state)
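        # Billing reports are scheduled only from this async hook:
        # _schedule_usage_report relies on asyncio.create_task, which needs a
        # running event loop that the sync after_model cannot assume.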
        if getattr(cfg, "report_enabled", False):
            _schedule_usage_report(state, runtime)
        return None

    def _log_usage(self, state: AgentState) -> None:
        messages = state.get("messages", [])
        logger.info("[TokenUsageMiddleware] _log_usage messages_count=%s", len(messages))
        if not messages:
            logger.info("[TokenUsageMiddleware] _log_usage skip: no messages")
            return
        usage = _extract_usage_from_messages(messages)
        if usage:
            logger.info(
                "LLM token usage: input=%s output=%s total=%s",
                usage.get("input_tokens", "?"),
                usage.get("output_tokens", "?"),
                usage.get("total_tokens", "?"),
            )


# Keep strong references to in-flight report tasks; asyncio holds only weak
# references, so an otherwise unreferenced task can be garbage-collected
# before it finishes.
_report_tasks: set[asyncio.Task] = set()


def _schedule_usage_report(state: AgentState, runtime: Runtime) -> None:
    """Fire-and-forget: schedule a billing report request after each model call."""
    cfg = get_app_config().token_usage
    logger.info("[TokenUsageMiddleware] _schedule_usage_report entered")
    messages = state.get("messages", [])
    if not messages:
        logger.info("[TokenUsageMiddleware] skip report: no messages")
        return
    last = messages[-1]
    usage = _extract_usage_from_messages(messages)
    if not usage:
        logger.info("[TokenUsageMiddleware] skip report: no token usage found")
        _log_message_diagnostics(last)
        return
    thread_id = _extract_thread_id(runtime)
    model_key = _extract_model_key(runtime, last)
    model_name = model_key
    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg and model_cfg.display_name:
            model_name = model_cfg.display_name
    transaction_no = _extract_transaction_no(last) or 0
    question = _extract_latest_question(messages)
    if isinstance(question, str) and len(question) > 27:
        question = question[:27] + "。。。"
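    # camelCase keys are assumed to match the external billing endpoint's
    # schema; values come from the normalized usage dict and runtime metadata.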
    payload: dict = {
        "sessionId": thread_id,
        "inputToken": usage.get("input_tokens"),
        "outputToken": usage.get("output_tokens"),
        "totalTokens": usage.get("total_tokens"),
        "modelName": model_name,
        "question": question,
        "transactionNo": transaction_no,
        "remark": "",
    }
    logger.info("Token billing payload: %s", payload)
    if not cfg.report_url:
        logger.info("[TokenUsageMiddleware] skip report: report_url is empty")
        return
    task = asyncio.create_task(
        _post_usage_report(cfg.report_url, cfg.report_headers, payload),
        name="token_usage_report",
    )
    _report_tasks.add(task)
    task.add_done_callback(_report_tasks.discard)


async def _post_usage_report(url: str, headers: dict[str, str], payload: dict) -> None:
    """POST billing payload to the configured external endpoint."""
    try:
        import httpx

        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code >= 400:
                logger.warning(
                    "Token billing report HTTP %s from %s",
                    response.status_code,
                    url,
                )
    except Exception as exc:
        logger.warning("Failed to report token billing to %s: %s", url, exc)


def _extract_latest_question(messages: list) -> str:
    """Extract the latest user question text from the message history."""
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            content = getattr(msg, "content", "")
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                parts: list[str] = []
                for part in content:
                    if isinstance(part, str):
                        parts.append(part)
                    elif isinstance(part, dict):
                        text = part.get("text")
                        if isinstance(text, str):
                            parts.append(text)
                return "\n".join(p for p in parts if p)
            return str(content)
    return ""


def _extract_transaction_no(last_message: object) -> str | None:
    """Get the provider response id from the AI message for the billing transactionNo."""
    msg_id = getattr(last_message, "id", None)
    if isinstance(msg_id, str) and msg_id:
        return msg_id
    response_metadata = getattr(last_message, "response_metadata", None)
    if isinstance(response_metadata, dict):
        meta_id = response_metadata.get("id") or response_metadata.get("response_id")
        if isinstance(meta_id, str) and meta_id:
            return meta_id
    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        kw_id = additional_kwargs.get("id")
        if isinstance(kw_id, str) and kw_id:
            return kw_id
    return None


def _extract_thread_id(runtime: Runtime) -> str | None:
    """Extract thread_id safely across LangGraph runtime variants."""
    context = getattr(runtime, "context", None)
    if isinstance(context, dict):
        thread_id = context.get("thread_id")
        if isinstance(thread_id, str) and thread_id:
            return thread_id
    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        thread_id = config.get("configurable", {}).get("thread_id")
        if isinstance(thread_id, str) and thread_id:
            return thread_id
    return None


def _extract_model_key(runtime: Runtime, last_message: object) -> str | None:
    """Extract the model key/name safely from runtime config or message metadata."""
    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        configurable = config.get("configurable", {})
        model_key = configurable.get("model") or configurable.get("model_name")
        if isinstance(model_key, str) and model_key:
            return model_key
    response_metadata = getattr(last_message, "response_metadata", None)
    if isinstance(response_metadata, dict):
        model_key = response_metadata.get("model_name") or response_metadata.get("model")
        if isinstance(model_key, str) and model_key:
            return model_key
    return None


def _extract_usage(last_message: object) -> dict[str, int] | None:
    """Extract token usage from common provider/LangChain metadata shapes."""
    # Primary LangChain shape.
    usage_metadata = getattr(last_message, "usage_metadata", None)
    usage = _normalize_usage_dict(usage_metadata)
    if usage:
        logger.info("[TokenUsageMiddleware] usage source=usage_metadata")
        return usage
    # Common provider shape on AIMessage.response_metadata.usage.
    response_metadata = getattr(last_message, "response_metadata", None)
    if isinstance(response_metadata, dict):
        usage = _normalize_usage_dict(response_metadata.get("usage"))
        if usage:
            logger.info("[TokenUsageMiddleware] usage source=response_metadata.usage")
            return usage
        usage = _normalize_usage_dict(response_metadata.get("token_usage"))
        if usage:
            logger.info("[TokenUsageMiddleware] usage source=response_metadata.token_usage")
            return usage
    # Some providers attach usage-like payloads to additional_kwargs.
    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        usage = _normalize_usage_dict(additional_kwargs.get("usage"))
        if usage:
            logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.usage")
            return usage
        usage = _normalize_usage_dict(additional_kwargs.get("token_usage"))
        if usage:
            logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.token_usage")
            return usage
    return None


def _extract_usage_from_messages(messages: list) -> dict[str, int] | None:
    """Find token usage from the most recent message that contains it."""
    for msg in reversed(messages):
        usage = _extract_usage(msg)
        if usage:
            return usage
    return None


def _log_message_diagnostics(last_message: object) -> None:
    """Log lightweight diagnostics for message metadata shape when usage is missing."""
    response_metadata = getattr(last_message, "response_metadata", None)
    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    logger.info(
        "[TokenUsageMiddleware] diagnostics: message_type=%s has_usage_metadata=%s "
        "response_metadata_keys=%s additional_kwargs_keys=%s",
        type(last_message).__name__,
        getattr(last_message, "usage_metadata", None) is not None,
        sorted(response_metadata.keys()) if isinstance(response_metadata, dict) else [],
        sorted(additional_kwargs.keys()) if isinstance(additional_kwargs, dict) else [],
    )


def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
    """Normalize token usage keys to input_tokens/output_tokens/total_tokens."""
    if not isinstance(raw_usage, dict):
        return None
    input_tokens = raw_usage.get("input_tokens")
    if input_tokens is None:
        input_tokens = raw_usage.get("prompt_tokens")
    output_tokens = raw_usage.get("output_tokens")
    if output_tokens is None:
        output_tokens = raw_usage.get("completion_tokens")
    total_tokens = raw_usage.get("total_tokens")
    if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
        total_tokens = input_tokens + output_tokens
    if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
        return None
    return {
        "input_tokens": int(input_tokens or 0),
        "output_tokens": int(output_tokens or 0),
        "total_tokens": int(total_tokens or 0),
    }
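

# Illustrative normalization (hypothetical values): an OpenAI-style dict such as
#     {"prompt_tokens": 12, "completion_tokens": 3}
# normalizes to
#     {"input_tokens": 12, "output_tokens": 3, "total_tokens": 15}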