feat(token-usage): Rewrite middleware to support external billing API reporting
This commit is contained in:
parent
6b900ccb60
commit
3bfe2e0203
|
|
@ -1,32 +1,56 @@
|
||||||
"""Middleware for logging LLM token usage."""
|
"""Middleware for logging LLM token usage and optionally reporting to an external billing API."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TokenUsageMiddleware(AgentMiddleware):
    """Logs token usage from model response usage_metadata and optionally reports to an external billing API.

    Local logging is gated by ``token_usage.enabled``. External billing
    reporting is gated by ``token_usage.report_enabled`` and is only
    scheduled from the async hook, because the fire-and-forget report task
    needs a running event loop.
    """

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Sync hook: log usage only (no report is scheduled here — see class docstring)."""
        cfg = self._traced_config("after_model")
        if getattr(cfg, "enabled", False):
            self._log_usage(state)
        return None

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async hook: log usage and, when enabled, schedule a billing report."""
        cfg = self._traced_config("aafter_model")
        if getattr(cfg, "enabled", False):
            self._log_usage(state)
        if getattr(cfg, "report_enabled", False):
            _schedule_usage_report(state, runtime)
        return None

    def _traced_config(self, hook: str):
        """Fetch the token-usage config and log which hook fired (shared preamble of both hooks)."""
        cfg = get_app_config().token_usage
        logger.info(
            "[TokenUsageMiddleware] %s triggered: enabled=%s report_enabled=%s",
            hook,
            getattr(cfg, "enabled", False),
            getattr(cfg, "report_enabled", False),
        )
        return cfg

    def _log_usage(self, state: AgentState) -> None:
        """Log input/output/total token counts from the newest message carrying usage."""
        messages = state.get("messages", [])
        logger.info("[TokenUsageMiddleware] _log_usage messages_count=%s", len(messages))
        if not messages:
            logger.info("[TokenUsageMiddleware] _log_usage skip: no messages")
            return None
        usage = _extract_usage_from_messages(messages)
        if usage:
            logger.info(
                "LLM token usage: input=%s output=%s total=%s",
                usage.get("input_tokens", "?"),
                usage.get("output_tokens", "?"),
                usage.get("total_tokens", "?"),
            )
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def _schedule_usage_report(state: AgentState, runtime: Runtime) -> None:
    """Fire-and-forget: schedule a billing report request after each model call.

    Builds the billing payload (sessionId, token counts, model name, question
    excerpt, transactionNo) from the agent state and runtime, then spawns an
    asyncio task to POST it. Skips with an info log when there are no
    messages, no extractable token usage, or no configured report_url.
    """
    cfg = get_app_config().token_usage
    logger.info("[TokenUsageMiddleware] _schedule_usage_report entered")

    messages = state.get("messages", [])
    if not messages:
        logger.info("[TokenUsageMiddleware] skip report: no messages")
        return
    last = messages[-1]
    usage = _extract_usage_from_messages(messages)
    if not usage:
        logger.info("[TokenUsageMiddleware] skip report: no token usage found")
        _log_message_diagnostics(last)
        return

    thread_id = _extract_thread_id(runtime)
    model_key = _extract_model_key(runtime, last)
    model_name = model_key
    if model_key:
        # Prefer the configured human-readable display name over the raw key.
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg and model_cfg.display_name:
            model_name = model_cfg.display_name

    transaction_no = _extract_transaction_no(last)
    if not transaction_no:
        # NOTE(review): the sample config comments say a missing response id
        # should SKIP reporting; the code substitutes 0 instead — confirm
        # which behavior the billing endpoint expects.
        transaction_no = 0

    question = _extract_latest_question(messages)
    if isinstance(question, str) and len(question) > 27:
        question = question[:27] + "。。。"

    payload: dict = {
        "sessionId": thread_id,
        "inputToken": usage.get("input_tokens"),
        "outputToken": usage.get("output_tokens"),
        "totalTokens": usage.get("total_tokens"),
        "modelName": model_name,
        "question": question,
        "transactionNo": transaction_no,
        "remark": "",
    }

    logger.info("Token billing payload: %s", payload)

    if not cfg.report_url:
        logger.info("[TokenUsageMiddleware] skip report: report_url is empty")
        return

    try:
        asyncio.create_task(
            _post_usage_report(cfg.report_url, cfg.report_headers, payload),
            name="token_usage_report",
        )
    except RuntimeError:
        # asyncio.create_task requires a running event loop. If this is ever
        # called from a sync context, drop the report rather than crashing
        # the agent run.
        logger.warning("Token billing report skipped: no running event loop")
|
||||||
|
|
||||||
|
|
||||||
|
async def _post_usage_report(url: str, headers: dict[str, str], payload: dict) -> None:
    """Send one billing payload to the external endpoint via HTTP POST.

    Best-effort delivery: any failure (missing httpx, network error, HTTP
    error status) is logged as a warning and never propagated to the caller.
    """
    try:
        import httpx  # imported lazily so the middleware loads without httpx installed

        async with httpx.AsyncClient(timeout=10.0) as http:
            resp = await http.post(url, json=payload, headers=headers)
            status = resp.status_code
            if status >= 400:
                logger.warning(
                    "Token billing report HTTP %s from %s",
                    status,
                    url,
                )
    except Exception as exc:  # reporting must never break the agent run
        logger.warning("Failed to report token billing to %s: %s", url, exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_latest_question(messages: list) -> str:
|
||||||
|
"""Extract latest user question text from message history."""
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
content = getattr(msg, "content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, str):
|
||||||
|
parts.append(part)
|
||||||
|
elif isinstance(part, dict):
|
||||||
|
text = part.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "\n".join(p for p in parts if p)
|
||||||
|
return str(content)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_transaction_no(last_message: object) -> str | None:
|
||||||
|
"""Get provider response id from AI message for billing transactionNo."""
|
||||||
|
msg_id = getattr(last_message, "id", None)
|
||||||
|
if isinstance(msg_id, str) and msg_id:
|
||||||
|
return msg_id
|
||||||
|
|
||||||
|
response_metadata = getattr(last_message, "response_metadata", None)
|
||||||
|
if isinstance(response_metadata, dict):
|
||||||
|
meta_id = response_metadata.get("id") or response_metadata.get("response_id")
|
||||||
|
if isinstance(meta_id, str) and meta_id:
|
||||||
|
return meta_id
|
||||||
|
|
||||||
|
additional_kwargs = getattr(last_message, "additional_kwargs", None)
|
||||||
|
if isinstance(additional_kwargs, dict):
|
||||||
|
kw_id = additional_kwargs.get("id")
|
||||||
|
if isinstance(kw_id, str) and kw_id:
|
||||||
|
return kw_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_thread_id(runtime: Runtime) -> str | None:
|
||||||
|
"""Extract thread_id safely across LangGraph runtime variants."""
|
||||||
|
context = getattr(runtime, "context", None)
|
||||||
|
if isinstance(context, dict):
|
||||||
|
thread_id = context.get("thread_id")
|
||||||
|
if isinstance(thread_id, str) and thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
config = getattr(runtime, "config", None)
|
||||||
|
if isinstance(config, dict):
|
||||||
|
thread_id = config.get("configurable", {}).get("thread_id")
|
||||||
|
if isinstance(thread_id, str) and thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_model_key(runtime: Runtime, last_message: object) -> str | None:
|
||||||
|
"""Extract model key/name safely from runtime config or message metadata."""
|
||||||
|
config = getattr(runtime, "config", None)
|
||||||
|
if isinstance(config, dict):
|
||||||
|
configurable = config.get("configurable", {})
|
||||||
|
model_key = configurable.get("model") or configurable.get("model_name")
|
||||||
|
if isinstance(model_key, str) and model_key:
|
||||||
|
return model_key
|
||||||
|
|
||||||
|
response_metadata = getattr(last_message, "response_metadata", None)
|
||||||
|
if isinstance(response_metadata, dict):
|
||||||
|
model_key = response_metadata.get("model_name") or response_metadata.get("model")
|
||||||
|
if isinstance(model_key, str) and model_key:
|
||||||
|
return model_key
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_usage(last_message: object) -> dict[str, int] | None:
    """Pull normalized token usage out of a message, trying known metadata shapes.

    Lookup order: LangChain ``usage_metadata``; then ``response_metadata``
    under ``usage`` / ``token_usage``; then ``additional_kwargs`` under the
    same two keys. Logs which source matched; returns None when none do.
    """
    # Primary LangChain shape.
    found = _normalize_usage_dict(getattr(last_message, "usage_metadata", None))
    if found:
        logger.info("[TokenUsageMiddleware] usage source=usage_metadata")
        return found

    # Common provider shape on AIMessage.response_metadata.
    meta = getattr(last_message, "response_metadata", None)
    if isinstance(meta, dict):
        found = _normalize_usage_dict(meta.get("usage"))
        if found:
            logger.info("[TokenUsageMiddleware] usage source=response_metadata.usage")
            return found
        found = _normalize_usage_dict(meta.get("token_usage"))
        if found:
            logger.info("[TokenUsageMiddleware] usage source=response_metadata.token_usage")
            return found

    # Some providers attach usage-like payloads to additional_kwargs.
    extras = getattr(last_message, "additional_kwargs", None)
    if isinstance(extras, dict):
        found = _normalize_usage_dict(extras.get("usage"))
        if found:
            logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.usage")
            return found
        found = _normalize_usage_dict(extras.get("token_usage"))
        if found:
            logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.token_usage")
            return found

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_usage_from_messages(messages: list) -> dict[str, int] | None:
    """Return token usage from the newest message that carries it, or None."""
    extracted = (_extract_usage(msg) for msg in reversed(messages))
    return next((usage for usage in extracted if usage), None)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_message_diagnostics(last_message: object) -> None:
    """Log the metadata shape of a message so missing-usage cases can be debugged."""
    meta = getattr(last_message, "response_metadata", None)
    extras = getattr(last_message, "additional_kwargs", None)
    meta_keys = sorted(meta.keys()) if isinstance(meta, dict) else []
    extra_keys = sorted(extras.keys()) if isinstance(extras, dict) else []
    logger.info(
        "[TokenUsageMiddleware] diagnostics: message_type=%s has_usage_metadata=%s response_metadata_keys=%s additional_kwargs_keys=%s",
        type(last_message).__name__,
        getattr(last_message, "usage_metadata", None) is not None,
        meta_keys,
        extra_keys,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
|
||||||
|
"""Normalize token usage keys to input_tokens/output_tokens/total_tokens."""
|
||||||
|
if not isinstance(raw_usage, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
input_tokens = raw_usage.get("input_tokens")
|
||||||
|
if input_tokens is None:
|
||||||
|
input_tokens = raw_usage.get("prompt_tokens")
|
||||||
|
|
||||||
|
output_tokens = raw_usage.get("output_tokens")
|
||||||
|
if output_tokens is None:
|
||||||
|
output_tokens = raw_usage.get("completion_tokens")
|
||||||
|
|
||||||
|
total_tokens = raw_usage.get("total_tokens")
|
||||||
|
if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
|
||||||
|
total_tokens = input_tokens + output_tokens
|
||||||
|
|
||||||
|
if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_tokens": int(input_tokens or 0),
|
||||||
|
"output_tokens": int(output_tokens or 0),
|
||||||
|
"total_tokens": int(total_tokens or 0),
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,3 +5,12 @@ class TokenUsageConfig(BaseModel):
|
||||||
"""Configuration for token usage tracking."""
|
"""Configuration for token usage tracking."""
|
||||||
|
|
||||||
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
|
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
|
||||||
|
report_enabled: bool = Field(default=False, description="Enable reporting of token usage to external billing API")
|
||||||
|
report_url: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="HTTP(S) endpoint to POST token billing requests to. If unset, external billing reporting is disabled.",
|
||||||
|
)
|
||||||
|
report_headers: dict[str, str] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Extra HTTP headers included in each billing request (e.g. Authorization: Bearer <token>).",
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -78,6 +78,32 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||||
elif "reasoning_effort" not in model_settings_from_config:
|
elif "reasoning_effort" not in model_settings_from_config:
|
||||||
model_settings_from_config["reasoning_effort"] = "medium"
|
model_settings_from_config["reasoning_effort"] = "medium"
|
||||||
|
|
||||||
|
# For ChatOpenAI-compatible providers, request usage in streaming responses
|
||||||
|
# so token middleware can reliably read usage metadata.
|
||||||
|
try:
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
if issubclass(model_class, ChatOpenAI):
|
||||||
|
has_stream_usage = "stream_usage" in model_settings_from_config or "stream_usage" in kwargs
|
||||||
|
if not has_stream_usage:
|
||||||
|
model_settings_from_config["stream_usage"] = True
|
||||||
|
|
||||||
|
# Some OpenAI-compatible providers only return usage in streaming mode
|
||||||
|
# when stream_options.include_usage is explicitly enabled.
|
||||||
|
stream_options_source = "kwargs" if "stream_options" in kwargs else "config"
|
||||||
|
stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options")
|
||||||
|
if stream_options is None:
|
||||||
|
model_settings_from_config["stream_options"] = {"include_usage": True}
|
||||||
|
elif isinstance(stream_options, dict) and "include_usage" not in stream_options:
|
||||||
|
patched_stream_options = {**stream_options, "include_usage": True}
|
||||||
|
if stream_options_source == "kwargs":
|
||||||
|
kwargs["stream_options"] = patched_stream_options
|
||||||
|
else:
|
||||||
|
model_settings_from_config["stream_options"] = patched_stream_options
|
||||||
|
except Exception:
|
||||||
|
# Keep model creation robust when langchain_openai isn't available.
|
||||||
|
pass
|
||||||
|
|
||||||
model_instance = model_class(**kwargs, **model_settings_from_config)
|
model_instance = model_class(**kwargs, **model_settings_from_config)
|
||||||
|
|
||||||
callbacks = build_tracing_callbacks()
|
callbacks = build_tracing_callbacks()
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,23 @@ log_level: info
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Track LLM token usage per model call (input/output/total tokens)
|
# Track LLM token usage per model call (input/output/total tokens)
|
||||||
# Logs at info level via TokenUsageMiddleware
|
# Logs at info level via TokenUsageMiddleware
|
||||||
|
|
||||||
token_usage:
|
token_usage:
|
||||||
enabled: false
|
enabled: false # Whether to log token usage to local logs
|
||||||
|
report_enabled: false # Whether to report token usage to external billing API
|
||||||
|
|
||||||
|
# Optional: POST billing records to an external HTTP API after every model call.
|
||||||
|
# Payload schema follows the billing endpoint contract:
|
||||||
|
# sessionId, inputToken, outputToken, totalTokens, modelName, question,
|
||||||
|
# transactionNo, remark.
|
||||||
|
# transactionNo comes from model response id (AIMessage.id). If missing,
# the middleware substitutes 0 as the transactionNo (it does not skip reporting).
|
||||||
|
# report_url: "http://localhost:19001/api/account/reduceBalanceWithToken"
|
||||||
|
|
||||||
|
# Optional: HTTP headers sent with every billing request (e.g. auth).
|
||||||
|
# report_headers:
|
||||||
|
# Authorization: "Bearer your-secret-token"
|
||||||
|
# X-App-Id: "deer-flow"
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Models Configuration
|
# Models Configuration
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue