297 lines
11 KiB
Python
297 lines
11 KiB
Python
"""Middleware for logging LLM token usage and optionally reporting to an external billing API."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain_core.messages import HumanMessage
|
|
from langgraph.runtime import Runtime
|
|
|
|
from deerflow.config.app_config import get_app_config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TokenUsageMiddleware(AgentMiddleware):
    """Logs token usage from model response usage_metadata and optionally reports to an external billing API."""

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Sync hook: log token usage after each model call when enabled.

        NOTE(review): unlike ``aafter_model``, this path never schedules a
        billing report — presumably because there is no running event loop to
        host the fire-and-forget task. Confirm this asymmetry is intended.
        """
        cfg = get_app_config().token_usage
        self._log_trigger("after_model", cfg)
        if getattr(cfg, "enabled", False):
            self._log_usage(state)
        return None

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async hook: log token usage and optionally schedule a billing report."""
        cfg = get_app_config().token_usage
        self._log_trigger("aafter_model", cfg)
        if getattr(cfg, "enabled", False):
            self._log_usage(state)
        if getattr(cfg, "report_enabled", False):
            _schedule_usage_report(state, runtime)
        return None

    @staticmethod
    def _log_trigger(hook: str, cfg: object) -> None:
        """Log which hook fired and the relevant config flags (shared by both hooks)."""
        logger.info(
            "[TokenUsageMiddleware] %s triggered: enabled=%s report_enabled=%s",
            hook,
            getattr(cfg, "enabled", False),
            getattr(cfg, "report_enabled", False),
        )

    def _log_usage(self, state: AgentState) -> None:
        """Log input/output/total token counts from the newest message that carries usage."""
        messages = state.get("messages", [])
        logger.info("[TokenUsageMiddleware] _log_usage messages_count=%s", len(messages))
        if not messages:
            logger.info("[TokenUsageMiddleware] _log_usage skip: no messages")
            return None
        usage = _extract_usage_from_messages(messages)
        if usage:
            logger.info(
                "LLM token usage: input=%s output=%s total=%s",
                usage.get("input_tokens", "?"),
                usage.get("output_tokens", "?"),
                usage.get("total_tokens", "?"),
            )
        return None
|
|
|
|
|
|
# Strong references to in-flight report tasks. asyncio.create_task keeps only
# a weak reference to the task, so a task nobody references can be garbage
# collected before it finishes (see asyncio docs note on create_task).
_PENDING_REPORT_TASKS: set[asyncio.Task] = set()


def _schedule_usage_report(state: AgentState, runtime: Runtime) -> None:
    """Fire-and-forget: schedule a billing report request after each model call.

    Must run inside an active event loop (it is invoked from the async
    ``aafter_model`` hook). Silently skips when there are no messages, no
    extractable token usage, or no configured report URL.
    """
    cfg = get_app_config().token_usage
    logger.info("[TokenUsageMiddleware] _schedule_usage_report entered")

    messages = state.get("messages", [])
    if not messages:
        logger.info("[TokenUsageMiddleware] skip report: no messages")
        return

    last = messages[-1]
    usage = _extract_usage_from_messages(messages)
    if not usage:
        logger.info("[TokenUsageMiddleware] skip report: no token usage found")
        _log_message_diagnostics(last)
        return

    thread_id = _extract_thread_id(runtime)
    model_key = _extract_model_key(runtime, last)
    # Prefer a human-friendly display name from app config when available.
    model_name = model_key
    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg and model_cfg.display_name:
            model_name = model_cfg.display_name

    # NOTE(review): falls back to int 0 when no provider response id exists, so
    # "transactionNo" is sometimes a str and sometimes an int — confirm the
    # billing API accepts both.
    transaction_no = _extract_transaction_no(last) or 0

    question = _extract_latest_question(messages)
    # Truncate long questions for the billing record.
    if isinstance(question, str) and len(question) > 27:
        question = question[:27] + "。。。"

    payload: dict = {
        "sessionId": thread_id,
        "inputToken": usage.get("input_tokens"),
        "outputToken": usage.get("output_tokens"),
        "totalTokens": usage.get("total_tokens"),
        "modelName": model_name,
        "question": question,
        "transactionNo": transaction_no,
        "remark": "",
    }

    logger.info("Token billing payload: %s", payload)

    if not cfg.report_url:
        logger.info("[TokenUsageMiddleware] skip report: report_url is empty")
        return

    # Hold a strong reference until completion so the task cannot be
    # garbage-collected mid-flight; the done-callback drops it afterwards.
    task = asyncio.create_task(
        _post_usage_report(cfg.report_url, cfg.report_headers, payload),
        name="token_usage_report",
    )
    _PENDING_REPORT_TASKS.add(task)
    task.add_done_callback(_PENDING_REPORT_TASKS.discard)
|
|
|
|
|
|
async def _post_usage_report(url: str, headers: dict[str, str], payload: dict) -> None:
    """POST billing payload to the configured external endpoint.

    Best-effort: any failure (import, network, HTTP error status) is logged
    as a warning and never raised to the caller.
    """
    try:
        # Imported lazily so the middleware loads even without httpx installed.
        import httpx

        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code >= 400:
                logger.warning(
                    "Token billing report HTTP %s from %s",
                    response.status_code,
                    url,
                )
    except Exception as exc:
        logger.warning("Failed to report token billing to %s: %s", url, exc)
|
|
|
|
|
|
def _extract_latest_question(messages: list) -> str:
    """Extract latest user question text from message history.

    Walks the history newest-first and returns the text of the first
    HumanMessage found; returns "" when the history has no human turn.
    """
    latest = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
    if latest is None:
        return ""

    content = getattr(latest, "content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multimodal content: collect plain strings and dict parts' "text" fields.
        texts: list[str] = []
        for chunk in content:
            if isinstance(chunk, str):
                texts.append(chunk)
            elif isinstance(chunk, dict) and isinstance(chunk.get("text"), str):
                texts.append(chunk["text"])
        return "\n".join(t for t in texts if t)
    # Unknown content shape: fall back to its string form.
    return str(content)
|
|
|
|
|
|
def _extract_transaction_no(last_message: object) -> str | None:
|
|
"""Get provider response id from AI message for billing transactionNo."""
|
|
msg_id = getattr(last_message, "id", None)
|
|
if isinstance(msg_id, str) and msg_id:
|
|
return msg_id
|
|
|
|
response_metadata = getattr(last_message, "response_metadata", None)
|
|
if isinstance(response_metadata, dict):
|
|
meta_id = response_metadata.get("id") or response_metadata.get("response_id")
|
|
if isinstance(meta_id, str) and meta_id:
|
|
return meta_id
|
|
|
|
additional_kwargs = getattr(last_message, "additional_kwargs", None)
|
|
if isinstance(additional_kwargs, dict):
|
|
kw_id = additional_kwargs.get("id")
|
|
if isinstance(kw_id, str) and kw_id:
|
|
return kw_id
|
|
|
|
return None
|
|
|
|
|
|
def _extract_thread_id(runtime: Runtime) -> str | None:
|
|
"""Extract thread_id safely across LangGraph runtime variants."""
|
|
context = getattr(runtime, "context", None)
|
|
if isinstance(context, dict):
|
|
thread_id = context.get("thread_id")
|
|
if isinstance(thread_id, str) and thread_id:
|
|
return thread_id
|
|
|
|
config = getattr(runtime, "config", None)
|
|
if isinstance(config, dict):
|
|
thread_id = config.get("configurable", {}).get("thread_id")
|
|
if isinstance(thread_id, str) and thread_id:
|
|
return thread_id
|
|
|
|
return None
|
|
|
|
|
|
def _extract_model_key(runtime: Runtime, last_message: object) -> str | None:
|
|
"""Extract model key/name safely from runtime config or message metadata."""
|
|
config = getattr(runtime, "config", None)
|
|
if isinstance(config, dict):
|
|
configurable = config.get("configurable", {})
|
|
model_key = configurable.get("model") or configurable.get("model_name")
|
|
if isinstance(model_key, str) and model_key:
|
|
return model_key
|
|
|
|
response_metadata = getattr(last_message, "response_metadata", None)
|
|
if isinstance(response_metadata, dict):
|
|
model_key = response_metadata.get("model_name") or response_metadata.get("model")
|
|
if isinstance(model_key, str) and model_key:
|
|
return model_key
|
|
|
|
return None
|
|
|
|
|
|
def _extract_usage(last_message: object) -> dict[str, int] | None:
    """Extract token usage from common provider/LangChain metadata shapes.

    Probes candidate payloads in priority order — usage_metadata, then
    response_metadata (usage / token_usage), then additional_kwargs
    (usage / token_usage) — and returns the first that normalizes.
    """
    # Each probe pairs a raw payload with the exact log line emitted when it
    # yields usable usage data.
    probes: list[tuple[object, str]] = [
        (
            getattr(last_message, "usage_metadata", None),
            "[TokenUsageMiddleware] usage source=usage_metadata",
        ),
    ]

    response_metadata = getattr(last_message, "response_metadata", None)
    if isinstance(response_metadata, dict):
        probes.append(
            (
                response_metadata.get("usage"),
                "[TokenUsageMiddleware] usage source=response_metadata.usage",
            )
        )
        probes.append(
            (
                response_metadata.get("token_usage"),
                "[TokenUsageMiddleware] usage source=response_metadata.token_usage",
            )
        )

    additional_kwargs = getattr(last_message, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        probes.append(
            (
                additional_kwargs.get("usage"),
                "[TokenUsageMiddleware] usage source=additional_kwargs.usage",
            )
        )
        probes.append(
            (
                additional_kwargs.get("token_usage"),
                "[TokenUsageMiddleware] usage source=additional_kwargs.token_usage",
            )
        )

    for raw, log_line in probes:
        usage = _normalize_usage_dict(raw)
        if usage:
            logger.info(log_line)
            return usage
    return None
|
|
|
|
|
|
def _extract_usage_from_messages(messages: list) -> dict[str, int] | None:
    """Find token usage from the most recent message that contains it.

    Scans newest-first and returns the first truthy usage dict, else None.
    """
    per_message = (_extract_usage(msg) for msg in reversed(messages))
    return next((usage for usage in per_message if usage), None)
|
|
|
|
|
|
def _log_message_diagnostics(last_message: object) -> None:
    """Log lightweight diagnostics for message metadata shape when usage is missing."""

    def sorted_keys(container: object) -> list:
        # Only dict-shaped metadata has keys worth listing.
        return sorted(container.keys()) if isinstance(container, dict) else []

    logger.info(
        "[TokenUsageMiddleware] diagnostics: message_type=%s has_usage_metadata=%s response_metadata_keys=%s additional_kwargs_keys=%s",
        type(last_message).__name__,
        getattr(last_message, "usage_metadata", None) is not None,
        sorted_keys(getattr(last_message, "response_metadata", None)),
        sorted_keys(getattr(last_message, "additional_kwargs", None)),
    )
|
|
|
|
|
|
def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
|
|
"""Normalize token usage keys to input_tokens/output_tokens/total_tokens."""
|
|
if not isinstance(raw_usage, dict):
|
|
return None
|
|
|
|
input_tokens = raw_usage.get("input_tokens")
|
|
if input_tokens is None:
|
|
input_tokens = raw_usage.get("prompt_tokens")
|
|
|
|
output_tokens = raw_usage.get("output_tokens")
|
|
if output_tokens is None:
|
|
output_tokens = raw_usage.get("completion_tokens")
|
|
|
|
total_tokens = raw_usage.get("total_tokens")
|
|
if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
|
|
total_tokens = input_tokens + output_tokens
|
|
|
|
if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
|
|
return None
|
|
|
|
return {
|
|
"input_tokens": int(input_tokens or 0),
|
|
"output_tokens": int(output_tokens or 0),
|
|
"total_tokens": int(total_tokens or 0),
|
|
}
|