feat(token-usage): Rewrite middleware to support external billing API reporting
This commit is contained in:
parent
6b900ccb60
commit
3bfe2e0203
|
|
@ -1,32 +1,56 @@
|
|||
"""Middleware for logging LLM token usage."""
|
||||
"""Middleware for logging LLM token usage and optionally reporting to an external billing API."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenUsageMiddleware(AgentMiddleware):
|
||||
"""Logs token usage from model response usage_metadata."""
|
||||
"""Logs token usage from model response usage_metadata and optionally reports to an external billing API."""
|
||||
|
||||
|
||||
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Log token usage after a synchronous model call.

    Mirrors ``aafter_model``: logs usage when ``enabled`` and schedules a
    billing report when ``report_enabled``. Scheduling relies on
    ``asyncio.create_task``, which requires a running event loop — in a
    plain synchronous invocation there is none, so the report is skipped
    with a log line instead of raising. (Previously the sync path logged
    ``report_enabled`` but never acted on it at all.)

    Returns ``None``: this hook never mutates agent state.
    """
    cfg = get_app_config().token_usage
    logger.info(
        "[TokenUsageMiddleware] after_model triggered: enabled=%s report_enabled=%s",
        getattr(cfg, "enabled", False),
        getattr(cfg, "report_enabled", False),
    )
    if getattr(cfg, "enabled", False):
        self._log_usage(state)
    if getattr(cfg, "report_enabled", False):
        try:
            # Only schedule when a loop is already running in this thread;
            # _schedule_usage_report uses asyncio.create_task internally.
            asyncio.get_running_loop()
        except RuntimeError:
            logger.info(
                "[TokenUsageMiddleware] skip report: no running event loop in sync after_model"
            )
        else:
            _schedule_usage_report(state, runtime)
    return None
|
||||
|
||||
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Async hook: log token usage and, when configured, schedule a billing report.

    Always returns ``None`` — the agent state is never modified here.
    """
    cfg = get_app_config().token_usage
    log_enabled = getattr(cfg, "enabled", False)
    report_enabled = getattr(cfg, "report_enabled", False)
    logger.info(
        "[TokenUsageMiddleware] aafter_model triggered: enabled=%s report_enabled=%s",
        log_enabled,
        report_enabled,
    )
    if log_enabled:
        self._log_usage(state)
    if report_enabled:
        # Fire-and-forget; the helper handles payload assembly and skipping.
        _schedule_usage_report(state, runtime)
    return None
|
||||
|
||||
def _log_usage(self, state: AgentState) -> None:
|
||||
messages = state.get("messages", [])
|
||||
logger.info("[TokenUsageMiddleware] _log_usage messages_count=%s", len(messages))
|
||||
if not messages:
|
||||
logger.info("[TokenUsageMiddleware] _log_usage skip: no messages")
|
||||
return None
|
||||
last = messages[-1]
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
usage = _extract_usage_from_messages(messages)
|
||||
if usage:
|
||||
logger.info(
|
||||
"LLM token usage: input=%s output=%s total=%s",
|
||||
|
|
@ -35,3 +59,238 @@ class TokenUsageMiddleware(AgentMiddleware):
|
|||
usage.get("total_tokens", "?"),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Strong references to in-flight report tasks. asyncio.create_task keeps only
# a weak reference to its result, so without this set a pending report could
# be garbage-collected before it finishes (documented asyncio pitfall).
_pending_report_tasks: set[asyncio.Task] = set()


def _schedule_usage_report(state: AgentState, runtime: Runtime) -> None:
    """Fire-and-forget: schedule a billing report request after each model call.

    Skips silently (with an info log) when there are no messages, no token
    usage can be extracted, or ``report_url`` is not configured. Must be
    called with a running event loop.
    """
    cfg = get_app_config().token_usage
    logger.info("[TokenUsageMiddleware] _schedule_usage_report entered")

    messages = state.get("messages", [])
    if not messages:
        logger.info("[TokenUsageMiddleware] skip report: no messages")
        return
    last = messages[-1]
    usage = _extract_usage_from_messages(messages)
    if not usage:
        logger.info("[TokenUsageMiddleware] skip report: no token usage found")
        _log_message_diagnostics(last)
        return

    thread_id = _extract_thread_id(runtime)
    # Prefer the configured display name for the model when available.
    model_key = _extract_model_key(runtime, last)
    model_name = model_key
    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg and model_cfg.display_name:
            model_name = model_cfg.display_name

    transaction_no = _extract_transaction_no(last)
    if not transaction_no:
        # NOTE(review): falls back to 0 when the provider supplies no response
        # id; confirm the billing backend accepts a numeric placeholder here.
        transaction_no = 0

    # Truncate long questions for the billing record (27 chars + ellipsis).
    question = _extract_latest_question(messages)
    if isinstance(question, str) and len(question) > 27:
        question = question[:27] + "。。。"
    payload: dict = {
        "sessionId": thread_id,
        "inputToken": usage.get("input_tokens"),
        "outputToken": usage.get("output_tokens"),
        "totalTokens": usage.get("total_tokens"),
        "modelName": model_name,
        "question": question,
        "transactionNo": transaction_no,
        "remark": "",
    }

    logger.info("Token billing payload: %s", payload)

    if not cfg.report_url:
        logger.info("[TokenUsageMiddleware] skip report: report_url is empty")
        return

    task = asyncio.create_task(
        _post_usage_report(cfg.report_url, cfg.report_headers, payload),
        name="token_usage_report",
    )
    # Hold a strong reference until the task completes, then drop it.
    _pending_report_tasks.add(task)
    task.add_done_callback(_pending_report_tasks.discard)
|
||||
|
||||
|
||||
async def _post_usage_report(url: str, headers: dict[str, str], payload: dict) -> None:
    """POST billing payload to the configured external endpoint."""
    try:
        # Imported lazily so the middleware module loads even without httpx.
        import httpx

        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(url, json=payload, headers=headers)
            # 4xx/5xx responses are logged, not raised — reporting is best-effort.
            if resp.status_code >= 400:
                logger.warning(
                    "Token billing report HTTP %s from %s",
                    resp.status_code,
                    url,
                )
    except Exception as exc:
        # Best-effort delivery: never let a billing failure break the agent.
        logger.warning("Failed to report token billing to %s: %s", url, exc)
|
||||
|
||||
|
||||
def _extract_latest_question(messages: list) -> str:
    """Extract latest user question text from message history.

    Returns the content of the most recent ``HumanMessage``: plain strings
    as-is, list content flattened to its text parts joined by newlines, and
    anything else via ``str()``. Returns ``""`` when no human message exists.
    """
    latest = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
    if latest is None:
        return ""
    content = getattr(latest, "content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multimodal content: collect raw strings and dict parts' "text" fields.
        texts: list[str] = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict):
                maybe_text = part.get("text")
                if isinstance(maybe_text, str):
                    texts.append(maybe_text)
        return "\n".join(t for t in texts if t)
    return str(content)
|
||||
|
||||
|
||||
def _extract_transaction_no(last_message: object) -> str | None:
|
||||
"""Get provider response id from AI message for billing transactionNo."""
|
||||
msg_id = getattr(last_message, "id", None)
|
||||
if isinstance(msg_id, str) and msg_id:
|
||||
return msg_id
|
||||
|
||||
response_metadata = getattr(last_message, "response_metadata", None)
|
||||
if isinstance(response_metadata, dict):
|
||||
meta_id = response_metadata.get("id") or response_metadata.get("response_id")
|
||||
if isinstance(meta_id, str) and meta_id:
|
||||
return meta_id
|
||||
|
||||
additional_kwargs = getattr(last_message, "additional_kwargs", None)
|
||||
if isinstance(additional_kwargs, dict):
|
||||
kw_id = additional_kwargs.get("id")
|
||||
if isinstance(kw_id, str) and kw_id:
|
||||
return kw_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_thread_id(runtime: Runtime) -> str | None:
    """Extract thread_id safely across LangGraph runtime variants.

    Looks in ``runtime.context['thread_id']`` first, then falls back to
    ``runtime.config['configurable']['thread_id']``. Only a non-empty
    string is returned; otherwise ``None``.
    """
    candidates: list[object] = []

    context = getattr(runtime, "context", None)
    if isinstance(context, dict):
        candidates.append(context.get("thread_id"))

    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        candidates.append(config.get("configurable", {}).get("thread_id"))

    # First valid candidate wins — preserves context-before-config precedence.
    for candidate in candidates:
        if isinstance(candidate, str) and candidate:
            return candidate
    return None
|
||||
|
||||
|
||||
def _extract_model_key(runtime: Runtime, last_message: object) -> str | None:
    """Extract model key/name safely from runtime config or message metadata.

    Precedence: ``config['configurable']['model'|'model_name']``, then the
    message's ``response_metadata['model_name'|'model']``. Only non-empty
    strings are returned.
    """

    def _usable(candidate: object) -> bool:
        return isinstance(candidate, str) and bool(candidate)

    config = getattr(runtime, "config", None)
    if isinstance(config, dict):
        configurable = config.get("configurable", {})
        candidate = configurable.get("model") or configurable.get("model_name")
        if _usable(candidate):
            return candidate

    meta = getattr(last_message, "response_metadata", None)
    if isinstance(meta, dict):
        candidate = meta.get("model_name") or meta.get("model")
        if _usable(candidate):
            return candidate

    return None
|
||||
|
||||
|
||||
def _extract_usage(last_message: object) -> dict[str, int] | None:
    """Extract token usage from common provider/LangChain metadata shapes.

    Probes, in priority order: ``usage_metadata`` (primary LangChain shape),
    ``response_metadata.usage`` / ``.token_usage``, then
    ``additional_kwargs.usage`` / ``.token_usage``. The first candidate that
    normalizes to a usage dict wins; its source is logged.
    """
    # Build (source-label, raw-candidate) pairs in priority order.
    candidates: list[tuple[str, object]] = [
        ("usage_metadata", getattr(last_message, "usage_metadata", None)),
    ]

    response_metadata = getattr(last_message, "response_metadata", None)
    if isinstance(response_metadata, dict):
        candidates.append(("response_metadata.usage", response_metadata.get("usage")))
        candidates.append(("response_metadata.token_usage", response_metadata.get("token_usage")))

    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        candidates.append(("additional_kwargs.usage", additional_kwargs.get("usage")))
        candidates.append(("additional_kwargs.token_usage", additional_kwargs.get("token_usage")))

    for source, raw in candidates:
        usage = _normalize_usage_dict(raw)
        if usage:
            logger.info("[TokenUsageMiddleware] usage source=%s", source)
            return usage
    return None
|
||||
|
||||
|
||||
def _extract_usage_from_messages(messages: list) -> dict[str, int] | None:
    """Find token usage from the most recent message that contains it.

    Walks messages newest-first and returns the first non-empty usage dict,
    or ``None`` when no message carries usage metadata.
    """
    usages = (_extract_usage(msg) for msg in reversed(messages))
    return next((usage for usage in usages if usage), None)
|
||||
|
||||
|
||||
def _log_message_diagnostics(last_message: object) -> None:
    """Log lightweight diagnostics for message metadata shape when usage is missing.

    Emits the message type, whether ``usage_metadata`` is present, and the
    sorted key sets of ``response_metadata`` / ``additional_kwargs`` so a
    maintainer can see which provider shape was actually received.
    """

    def _sorted_keys(mapping: object) -> list:
        # Non-dict (or missing) metadata is reported as an empty key list.
        return sorted(mapping.keys()) if isinstance(mapping, dict) else []

    logger.info(
        "[TokenUsageMiddleware] diagnostics: message_type=%s has_usage_metadata=%s response_metadata_keys=%s additional_kwargs_keys=%s",
        type(last_message).__name__,
        getattr(last_message, "usage_metadata", None) is not None,
        _sorted_keys(getattr(last_message, "response_metadata", None)),
        _sorted_keys(getattr(last_message, "additional_kwargs", None)),
    )
|
||||
|
||||
|
||||
def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
|
||||
"""Normalize token usage keys to input_tokens/output_tokens/total_tokens."""
|
||||
if not isinstance(raw_usage, dict):
|
||||
return None
|
||||
|
||||
input_tokens = raw_usage.get("input_tokens")
|
||||
if input_tokens is None:
|
||||
input_tokens = raw_usage.get("prompt_tokens")
|
||||
|
||||
output_tokens = raw_usage.get("output_tokens")
|
||||
if output_tokens is None:
|
||||
output_tokens = raw_usage.get("completion_tokens")
|
||||
|
||||
total_tokens = raw_usage.get("total_tokens")
|
||||
if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
|
||||
total_tokens = input_tokens + output_tokens
|
||||
|
||||
if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
|
||||
return None
|
||||
|
||||
return {
|
||||
"input_tokens": int(input_tokens or 0),
|
||||
"output_tokens": int(output_tokens or 0),
|
||||
"total_tokens": int(total_tokens or 0),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,3 +5,12 @@ class TokenUsageConfig(BaseModel):
|
|||
"""Configuration for token usage tracking."""
|
||||
|
||||
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
|
||||
report_enabled: bool = Field(default=False, description="Enable reporting of token usage to external billing API")
|
||||
report_url: str | None = Field(
|
||||
default=None,
|
||||
description="HTTP(S) endpoint to POST token billing requests to. If unset, external billing reporting is disabled.",
|
||||
)
|
||||
report_headers: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Extra HTTP headers included in each billing request (e.g. Authorization: Bearer <token>).",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -78,6 +78,32 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
|||
elif "reasoning_effort" not in model_settings_from_config:
|
||||
model_settings_from_config["reasoning_effort"] = "medium"
|
||||
|
||||
# For ChatOpenAI-compatible providers, request usage in streaming responses
|
||||
# so token middleware can reliably read usage metadata.
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
if issubclass(model_class, ChatOpenAI):
|
||||
has_stream_usage = "stream_usage" in model_settings_from_config or "stream_usage" in kwargs
|
||||
if not has_stream_usage:
|
||||
model_settings_from_config["stream_usage"] = True
|
||||
|
||||
# Some OpenAI-compatible providers only return usage in streaming mode
|
||||
# when stream_options.include_usage is explicitly enabled.
|
||||
stream_options_source = "kwargs" if "stream_options" in kwargs else "config"
|
||||
stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options")
|
||||
if stream_options is None:
|
||||
model_settings_from_config["stream_options"] = {"include_usage": True}
|
||||
elif isinstance(stream_options, dict) and "include_usage" not in stream_options:
|
||||
patched_stream_options = {**stream_options, "include_usage": True}
|
||||
if stream_options_source == "kwargs":
|
||||
kwargs["stream_options"] = patched_stream_options
|
||||
else:
|
||||
model_settings_from_config["stream_options"] = patched_stream_options
|
||||
except Exception:
|
||||
# Keep model creation robust when langchain_openai isn't available.
|
||||
pass
|
||||
|
||||
model_instance = model_class(**kwargs, **model_settings_from_config)
|
||||
|
||||
callbacks = build_tracing_callbacks()
|
||||
|
|
|
|||
|
|
@ -25,8 +25,23 @@ log_level: info
|
|||
# ============================================================================
|
||||
# Track LLM token usage per model call (input/output/total tokens)
|
||||
# Logs at info level via TokenUsageMiddleware
|
||||
|
||||
token_usage:
|
||||
enabled: false
|
||||
enabled: false # Whether to log token usage to local logs
|
||||
report_enabled: false # Whether to report token usage to external billing API
|
||||
|
||||
# Optional: POST billing records to an external HTTP API after every model call.
|
||||
# Payload schema follows the billing endpoint contract:
|
||||
# sessionId, inputToken, outputToken, totalTokens, modelName, question,
|
||||
# transactionNo, remark.
|
||||
# transactionNo comes from the model response id (AIMessage.id). If the
# provider supplies no id, the middleware reports transactionNo as 0.
|
||||
# report_url: "http://localhost:19001/api/account/reduceBalanceWithToken"
|
||||
|
||||
# Optional: HTTP headers sent with every billing request (e.g. auth).
|
||||
# report_headers:
|
||||
# Authorization: "Bearer your-secret-token"
|
||||
# X-App-Id: "deer-flow"
|
||||
|
||||
# ============================================================================
|
||||
# Models Configuration
|
||||
|
|
|
|||
Loading…
Reference in New Issue