feat(token-usage): Rewrite middleware to support external billing API reporting

This commit is contained in:
Titan 2026-04-08 15:19:32 +08:00
parent 6b900ccb60
commit 3bfe2e0203
4 changed files with 316 additions and 7 deletions

View File

@ -1,32 +1,56 @@
"""Middleware for logging LLM token usage.""" """Middleware for logging LLM token usage and optionally reporting to an external billing API."""
import asyncio
import logging import logging
from typing import override from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.app_config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TokenUsageMiddleware(AgentMiddleware): class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model response usage_metadata.""" """Logs token usage from model response usage_metadata and optionally reports to an external billing API."""
@override @override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state) cfg = get_app_config().token_usage
logger.info(
"[TokenUsageMiddleware] after_model triggered: enabled=%s report_enabled=%s",
getattr(cfg, "enabled", False),
getattr(cfg, "report_enabled", False),
)
if getattr(cfg, "enabled", False):
self._log_usage(state)
return None
@override @override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state) cfg = get_app_config().token_usage
logger.info(
"[TokenUsageMiddleware] aafter_model triggered: enabled=%s report_enabled=%s",
getattr(cfg, "enabled", False),
getattr(cfg, "report_enabled", False),
)
if getattr(cfg, "enabled", False):
self._log_usage(state)
if getattr(cfg, "report_enabled", False):
_schedule_usage_report(state, runtime)
return None
def _log_usage(self, state: AgentState) -> None: def _log_usage(self, state: AgentState) -> None:
messages = state.get("messages", []) messages = state.get("messages", [])
logger.info("[TokenUsageMiddleware] _log_usage messages_count=%s", len(messages))
if not messages: if not messages:
logger.info("[TokenUsageMiddleware] _log_usage skip: no messages")
return None return None
last = messages[-1] usage = _extract_usage_from_messages(messages)
usage = getattr(last, "usage_metadata", None)
if usage: if usage:
logger.info( logger.info(
"LLM token usage: input=%s output=%s total=%s", "LLM token usage: input=%s output=%s total=%s",
@ -35,3 +59,238 @@ class TokenUsageMiddleware(AgentMiddleware):
usage.get("total_tokens", "?"), usage.get("total_tokens", "?"),
) )
return None return None
# Strong references to in-flight report tasks: asyncio.create_task only keeps
# a weak reference to its task, so without this set a fire-and-forget report
# could be garbage-collected before the HTTP request completes.
_report_tasks: set[asyncio.Task] = set()


def _schedule_usage_report(state: AgentState, runtime: Runtime) -> None:
    """Fire-and-forget: schedule a billing report request after each model call.

    Builds the billing payload from the most recent token usage found in the
    message history and schedules an async POST to the configured endpoint.
    Skips (with an info log) when any of these is missing: report URL,
    messages, token usage, or a provider transaction id.

    Args:
        state: Agent state; ``state["messages"]`` is read for usage metadata.
        runtime: LangGraph runtime; used to extract thread id and model key.
    """
    cfg = get_app_config().token_usage
    logger.info("[TokenUsageMiddleware] _schedule_usage_report entered")
    # Bail out early before doing any payload work if reporting is unconfigured.
    if not cfg.report_url:
        logger.info("[TokenUsageMiddleware] skip report: report_url is empty")
        return
    messages = state.get("messages", [])
    if not messages:
        logger.info("[TokenUsageMiddleware] skip report: no messages")
        return
    last = messages[-1]
    usage = _extract_usage_from_messages(messages)
    if not usage:
        logger.info("[TokenUsageMiddleware] skip report: no token usage found")
        _log_message_diagnostics(last)
        return
    transaction_no = _extract_transaction_no(last)
    if not transaction_no:
        # Contract (see config docs): transactionNo must be the provider
        # response id. Skip rather than fabricate a billing transaction id.
        logger.info("[TokenUsageMiddleware] skip report: no transaction id on response")
        return
    thread_id = _extract_thread_id(runtime)
    model_key = _extract_model_key(runtime, last)
    model_name = model_key
    if model_key:
        # Prefer the human-readable display name when the model is configured.
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg and model_cfg.display_name:
            model_name = model_cfg.display_name
    question = _extract_latest_question(messages)
    if isinstance(question, str) and len(question) > 27:
        question = question[:27] + "。。。"
    payload: dict = {
        "sessionId": thread_id,
        "inputToken": usage.get("input_tokens"),
        "outputToken": usage.get("output_tokens"),
        "totalTokens": usage.get("total_tokens"),
        "modelName": model_name,
        "question": question,
        "transactionNo": transaction_no,
        "remark": "",
    }
    logger.info("Token billing payload: %s", payload)
    task = asyncio.create_task(
        _post_usage_report(cfg.report_url, cfg.report_headers, payload),
        name="token_usage_report",
    )
    # Retain the task until it finishes, then drop the reference.
    _report_tasks.add(task)
    task.add_done_callback(_report_tasks.discard)
async def _post_usage_report(url: str, headers: dict[str, str], payload: dict) -> None:
    """POST billing payload to the configured external endpoint.

    Best-effort: any failure (including httpx being unavailable) is logged
    as a warning and never propagated to the caller.
    """
    try:
        # Imported lazily so the middleware module loads even when httpx is
        # not installed; an ImportError is handled like any request failure.
        import httpx

        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(url, json=payload, headers=headers)
        if response.status_code >= 400:
            logger.warning(
                "Token billing report HTTP %s from %s",
                response.status_code,
                url,
            )
    except Exception as exc:
        logger.warning("Failed to report token billing to %s: %s", url, exc)
def _extract_latest_question(messages: list) -> str:
    """Extract latest user question text from message history.

    Scans from newest to oldest for a HumanMessage and returns its text,
    joining list-shaped content parts with newlines. Returns "" when no
    human message is present.
    """
    for msg in reversed(messages):
        if not isinstance(msg, HumanMessage):
            continue
        content = getattr(msg, "content", "")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Multimodal content: collect plain strings and {"text": ...} parts.
            fragments: list[str] = []
            for item in content:
                if isinstance(item, str):
                    fragments.append(item)
                elif isinstance(item, dict) and isinstance(item.get("text"), str):
                    fragments.append(item["text"])
            return "\n".join(f for f in fragments if f)
        # Unknown content shape: fall back to its string representation.
        return str(content)
    return ""
def _extract_transaction_no(last_message: object) -> str | None:
"""Get provider response id from AI message for billing transactionNo."""
msg_id = getattr(last_message, "id", None)
if isinstance(msg_id, str) and msg_id:
return msg_id
response_metadata = getattr(last_message, "response_metadata", None)
if isinstance(response_metadata, dict):
meta_id = response_metadata.get("id") or response_metadata.get("response_id")
if isinstance(meta_id, str) and meta_id:
return meta_id
additional_kwargs = getattr(last_message, "additional_kwargs", None)
if isinstance(additional_kwargs, dict):
kw_id = additional_kwargs.get("id")
if isinstance(kw_id, str) and kw_id:
return kw_id
return None
def _extract_thread_id(runtime: Runtime) -> str | None:
    """Extract thread_id safely across LangGraph runtime variants.

    Looks in ``runtime.context["thread_id"]`` first, then in
    ``runtime.config["configurable"]["thread_id"]``. Returns None when
    neither holds a non-empty string.
    """

    def _pick(candidate: object) -> str | None:
        return candidate if isinstance(candidate, str) and candidate else None

    context = getattr(runtime, "context", None)
    if isinstance(context, dict):
        found = _pick(context.get("thread_id"))
        if found:
            return found
    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        found = _pick(config.get("configurable", {}).get("thread_id"))
        if found:
            return found
    return None
def _extract_model_key(runtime: Runtime, last_message: object) -> str | None:
    """Extract model key/name safely from runtime config or message metadata.

    Prefers ``runtime.config["configurable"]`` (``model`` or ``model_name``),
    then the message's ``response_metadata`` (``model_name`` or ``model``).
    Returns None when no non-empty string candidate exists.
    """
    candidates: list = []
    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        configurable = config.get("configurable", {})
        candidates.append(configurable.get("model") or configurable.get("model_name"))
    response_metadata = getattr(last_message, "response_metadata", None)
    if isinstance(response_metadata, dict):
        candidates.append(
            response_metadata.get("model_name") or response_metadata.get("model")
        )
    # First non-empty string wins, preserving the priority order above.
    for candidate in candidates:
        if isinstance(candidate, str) and candidate:
            return candidate
    return None
def _extract_usage(last_message: object) -> dict[str, int] | None:
    """Extract token usage from common provider/LangChain metadata shapes.

    Tries ``usage_metadata`` first (the canonical LangChain field), then
    ``usage``/``token_usage`` nested under ``response_metadata`` and
    ``additional_kwargs``. Returns a normalized dict or None.
    """
    # Primary LangChain shape.
    normalized = _normalize_usage_dict(getattr(last_message, "usage_metadata", None))
    if normalized:
        logger.info("[TokenUsageMiddleware] usage source=usage_metadata")
        return normalized
    # Provider-specific fallbacks, checked in priority order.
    fallback_sources = (
        ("response_metadata", "usage",
         "[TokenUsageMiddleware] usage source=response_metadata.usage"),
        ("response_metadata", "token_usage",
         "[TokenUsageMiddleware] usage source=response_metadata.token_usage"),
        ("additional_kwargs", "usage",
         "[TokenUsageMiddleware] usage source=additional_kwargs.usage"),
        ("additional_kwargs", "token_usage",
         "[TokenUsageMiddleware] usage source=additional_kwargs.token_usage"),
    )
    for attr_name, key, log_message in fallback_sources:
        container = getattr(last_message, attr_name, None)
        if not isinstance(container, dict):
            continue
        normalized = _normalize_usage_dict(container.get(key))
        if normalized:
            logger.info(log_message)
            return normalized
    return None
def _extract_usage_from_messages(messages: list) -> dict[str, int] | None:
    """Find token usage from the most recent message that contains it."""
    # Walk newest-to-oldest and stop at the first message with usage data.
    return next(
        (usage for usage in (_extract_usage(m) for m in reversed(messages)) if usage),
        None,
    )
def _log_message_diagnostics(last_message: object) -> None:
    """Log lightweight diagnostics for message metadata shape when usage is missing."""

    def _dict_keys(value: object) -> list:
        # Sorted keys when the attribute is a dict, [] for anything else.
        return sorted(value.keys()) if isinstance(value, dict) else []

    logger.info(
        "[TokenUsageMiddleware] diagnostics: message_type=%s has_usage_metadata=%s response_metadata_keys=%s additional_kwargs_keys=%s",
        type(last_message).__name__,
        getattr(last_message, "usage_metadata", None) is not None,
        _dict_keys(getattr(last_message, "response_metadata", None)),
        _dict_keys(getattr(last_message, "additional_kwargs", None)),
    )
def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
"""Normalize token usage keys to input_tokens/output_tokens/total_tokens."""
if not isinstance(raw_usage, dict):
return None
input_tokens = raw_usage.get("input_tokens")
if input_tokens is None:
input_tokens = raw_usage.get("prompt_tokens")
output_tokens = raw_usage.get("output_tokens")
if output_tokens is None:
output_tokens = raw_usage.get("completion_tokens")
total_tokens = raw_usage.get("total_tokens")
if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
total_tokens = input_tokens + output_tokens
if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
return None
return {
"input_tokens": int(input_tokens or 0),
"output_tokens": int(output_tokens or 0),
"total_tokens": int(total_tokens or 0),
}

View File

@ -5,3 +5,12 @@ class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking.""" """Configuration for token usage tracking."""
enabled: bool = Field(default=False, description="Enable token usage tracking middleware") enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
report_enabled: bool = Field(default=False, description="Enable reporting of token usage to external billing API")
report_url: str | None = Field(
default=None,
description="HTTP(S) endpoint to POST token billing requests to. If unset, external billing reporting is disabled.",
)
report_headers: dict[str, str] = Field(
default_factory=dict,
description="Extra HTTP headers included in each billing request (e.g. Authorization: Bearer <token>).",
)

View File

@ -78,6 +78,32 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config: elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium" model_settings_from_config["reasoning_effort"] = "medium"
# For ChatOpenAI-compatible providers, request usage in streaming responses
# so token middleware can reliably read usage metadata.
try:
from langchain_openai import ChatOpenAI
if issubclass(model_class, ChatOpenAI):
has_stream_usage = "stream_usage" in model_settings_from_config or "stream_usage" in kwargs
if not has_stream_usage:
model_settings_from_config["stream_usage"] = True
# Some OpenAI-compatible providers only return usage in streaming mode
# when stream_options.include_usage is explicitly enabled.
stream_options_source = "kwargs" if "stream_options" in kwargs else "config"
stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options")
if stream_options is None:
model_settings_from_config["stream_options"] = {"include_usage": True}
elif isinstance(stream_options, dict) and "include_usage" not in stream_options:
patched_stream_options = {**stream_options, "include_usage": True}
if stream_options_source == "kwargs":
kwargs["stream_options"] = patched_stream_options
else:
model_settings_from_config["stream_options"] = patched_stream_options
except Exception:
# Keep model creation robust when langchain_openai isn't available.
pass
model_instance = model_class(**kwargs, **model_settings_from_config) model_instance = model_class(**kwargs, **model_settings_from_config)
callbacks = build_tracing_callbacks() callbacks = build_tracing_callbacks()

View File

@ -25,8 +25,23 @@ log_level: info
# ============================================================================ # ============================================================================
# Track LLM token usage per model call (input/output/total tokens) # Track LLM token usage per model call (input/output/total tokens)
# Logs at info level via TokenUsageMiddleware # Logs at info level via TokenUsageMiddleware
token_usage: token_usage:
enabled: false enabled: false # Whether to log token usage to local logs
report_enabled: false # Whether to report token usage to external billing API
# Optional: POST billing records to an external HTTP API after every model call.
# Payload schema follows the billing endpoint contract:
# sessionId, inputToken, outputToken, totalTokens, modelName, question,
# transactionNo, remark.
# transactionNo comes from model response id (AIMessage.id). If missing,
# middleware will skip reporting to avoid fake billing transaction IDs.
# report_url: "http://localhost:19001/api/account/reduceBalanceWithToken"
# Optional: HTTP headers sent with every billing request (e.g. auth).
# report_headers:
# Authorization: "Bearer your-secret-token"
# X-App-Id: "deer-flow"
# ============================================================================ # ============================================================================
# Models Configuration # Models Configuration