diff --git a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py index 59c3423d..ba0b8514 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py @@ -1,32 +1,56 @@ -"""Middleware for logging LLM token usage.""" +"""Middleware for logging LLM token usage and optionally reporting to an external billing API.""" +import asyncio import logging from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import HumanMessage from langgraph.runtime import Runtime +from deerflow.config.app_config import get_app_config + logger = logging.getLogger(__name__) class TokenUsageMiddleware(AgentMiddleware): - """Logs token usage from model response usage_metadata.""" + """Logs token usage from model response usage_metadata and optionally reports to an external billing API.""" + @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: - return self._log_usage(state) + cfg = get_app_config().token_usage + logger.info( + "[TokenUsageMiddleware] after_model triggered: enabled=%s report_enabled=%s", + getattr(cfg, "enabled", False), + getattr(cfg, "report_enabled", False), + ) + if getattr(cfg, "enabled", False): + self._log_usage(state) + return None @override async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: - return self._log_usage(state) + cfg = get_app_config().token_usage + logger.info( + "[TokenUsageMiddleware] aafter_model triggered: enabled=%s report_enabled=%s", + getattr(cfg, "enabled", False), + getattr(cfg, "report_enabled", False), + ) + if getattr(cfg, "enabled", False): + self._log_usage(state) + if getattr(cfg, "report_enabled", False): + _schedule_usage_report(state, runtime) + return None def _log_usage(self, state: AgentState) -> None: messages = state.get("messages", []) + logger.info("[TokenUsageMiddleware] _log_usage messages_count=%s", len(messages)) if not messages: + logger.info("[TokenUsageMiddleware] _log_usage skip: no messages") return None - last = messages[-1] - usage = getattr(last, "usage_metadata", None) + usage = _extract_usage_from_messages(messages) if usage: logger.info( "LLM token usage: input=%s output=%s total=%s", @@ -35,3 +59,238 @@ class TokenUsageMiddleware(AgentMiddleware): usage.get("total_tokens", "?"), ) return None + + +def _schedule_usage_report(state: AgentState, runtime: Runtime) -> None: + """Fire-and-forget: schedule a billing report request after each model call.""" + cfg = get_app_config().token_usage + logger.info("[TokenUsageMiddleware] _schedule_usage_report entered") + + messages = state.get("messages", []) + if not messages: + logger.info("[TokenUsageMiddleware] skip report: no messages") + return + last = messages[-1] + usage = _extract_usage_from_messages(messages) + if not usage: + logger.info("[TokenUsageMiddleware] skip report: no token usage found") + _log_message_diagnostics(last) + return + + thread_id = _extract_thread_id(runtime) + model_key = _extract_model_key(runtime, last) + model_name = model_key + if model_key: + model_cfg = get_app_config().get_model_config(model_key) + if model_cfg and model_cfg.display_name: + model_name = model_cfg.display_name + + transaction_no = _extract_transaction_no(last) + if not transaction_no: + transaction_no = 0 + + question = 
_extract_latest_question(messages)
+    if isinstance(question, str) and len(question) > 27:
+        question = question[:27] + "。。。"
+    payload: dict = {
+        "sessionId": thread_id,
+        "inputToken": usage.get("input_tokens"),
+        "outputToken": usage.get("output_tokens"),
+        "totalTokens": usage.get("total_tokens"),
+        "modelName": model_name,
+        "question": question,
+        "transactionNo": transaction_no,
+        "remark": "",
+    }
+
+    logger.info("Token billing payload: %s", payload)
+
+    if not cfg.report_url:
+        logger.info("[TokenUsageMiddleware] skip report: report_url is empty")
+        return
+
+    task = asyncio.create_task(
+        _post_usage_report(cfg.report_url, cfg.report_headers, payload),
+        name="token_usage_report",
+    )
+    # Hold a strong reference until the task finishes; asyncio only keeps a
+    # weak reference, so an unreferenced task can be garbage-collected mid-flight.
+    _REPORT_TASKS.add(task)
+    task.add_done_callback(_REPORT_TASKS.discard)
+
+
+# Strong references to in-flight billing report tasks (see note above).
+_REPORT_TASKS: set[asyncio.Task] = set()
+
+
+async def _post_usage_report(url: str, headers: dict[str, str], payload: dict) -> None:
+    """POST billing payload to the configured external endpoint."""
+    try:
+        import httpx
+
+        async with httpx.AsyncClient(timeout=10.0) as client:
+            response = await client.post(url, json=payload, headers=headers)
+            if response.status_code >= 400:
+                logger.warning(
+                    "Token billing report HTTP %s from %s",
+                    response.status_code,
+                    url,
+                )
+    except Exception as exc:
+        logger.warning("Failed to report token billing to %s: %s", url, exc)
+
+
+def _extract_latest_question(messages: list) -> str:
+    """Extract the latest user question text from the message history."""
+    for msg in reversed(messages):
+        if isinstance(msg, HumanMessage):
+            content = getattr(msg, "content", "")
+            if isinstance(content, str):
+                return content
+            if isinstance(content, list):
+                parts: list[str] = []
+                for part in content:
+                    if isinstance(part, str):
+                        parts.append(part)
+                    elif isinstance(part, dict):
+                        text = part.get("text")
+                        if isinstance(text, str):
+                            parts.append(text)
+                return "\n".join(p for p in parts if p)
+            return str(content)
+    return ""
+
+
+def _extract_transaction_no(last_message: object) -> str | None:
+    """Get the provider response id from the AI message for the billing transactionNo."""
+    msg_id = getattr(last_message, "id", None)
+    if isinstance(msg_id, str) and msg_id:
+        return msg_id
+
+    response_metadata = getattr(last_message, "response_metadata", None)
+    if isinstance(response_metadata, dict):
+        meta_id = response_metadata.get("id") or response_metadata.get("response_id")
+        if isinstance(meta_id, str) and meta_id:
+            return meta_id
+
+    additional_kwargs = getattr(last_message, "additional_kwargs", None)
+    if isinstance(additional_kwargs, dict):
+        kw_id = additional_kwargs.get("id")
+        if isinstance(kw_id, str) and kw_id:
+            return kw_id
+
+    return None
+
+
+def _extract_thread_id(runtime: Runtime) -> str | None:
+    """Extract thread_id safely across LangGraph runtime variants."""
+    context = getattr(runtime, "context", None)
+    if isinstance(context, dict):
+        thread_id = context.get("thread_id")
+        if isinstance(thread_id, str) and thread_id:
+            return thread_id
+
+    config = getattr(runtime, "config", None)
+    if isinstance(config, dict):
+        thread_id = config.get("configurable", {}).get("thread_id")
+        if isinstance(thread_id, str) and thread_id:
+            return thread_id
+
+    return None
+
+
+def _extract_model_key(runtime: Runtime, last_message: object) -> str | None:
+    """Extract the model key/name safely from runtime config or message metadata."""
+    config = getattr(runtime, "config", None)
+    if isinstance(config, dict):
+        configurable = config.get("configurable", {})
+        model_key = configurable.get("model") or configurable.get("model_name")
+        if isinstance(model_key, str) and model_key:
+            return model_key
+
+    response_metadata = getattr(last_message, 
"response_metadata", None) + if isinstance(response_metadata, dict): + model_key = response_metadata.get("model_name") or response_metadata.get("model") + if isinstance(model_key, str) and model_key: + return model_key + + return None + + +def _extract_usage(last_message: object) -> dict[str, int] | None: + """Extract token usage from common provider/LangChain metadata shapes.""" + # Primary LangChain shape. + usage_metadata = getattr(last_message, "usage_metadata", None) + usage = _normalize_usage_dict(usage_metadata) + if usage: + logger.info("[TokenUsageMiddleware] usage source=usage_metadata") + return usage + + # Common provider shape on AIMessage.response_metadata.usage + response_metadata = getattr(last_message, "response_metadata", None) + if isinstance(response_metadata, dict): + usage = _normalize_usage_dict(response_metadata.get("usage")) + if usage: + logger.info("[TokenUsageMiddleware] usage source=response_metadata.usage") + return usage + + usage = _normalize_usage_dict(response_metadata.get("token_usage")) + if usage: + logger.info("[TokenUsageMiddleware] usage source=response_metadata.token_usage") + return usage + + # Some providers attach usage-like payloads to additional_kwargs. + additional_kwargs = getattr(last_message, "additional_kwargs", None) + if isinstance(additional_kwargs, dict): + usage = _normalize_usage_dict(additional_kwargs.get("usage")) + if usage: + logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.usage") + return usage + + usage = _normalize_usage_dict(additional_kwargs.get("token_usage")) + if usage: + logger.info("[TokenUsageMiddleware] usage source=additional_kwargs.token_usage") + return usage + + return None + + +def _extract_usage_from_messages(messages: list) -> dict[str, int] | None: + """Find token usage from the most recent message that contains it.""" + for msg in reversed(messages): + usage = _extract_usage(msg) + if usage: + return usage + return None + + +def _log_message_diagnostics(last_message: object) -> None: + """Log lightweight diagnostics for message metadata shape when usage is missing.""" + response_metadata = getattr(last_message, "response_metadata", None) + additional_kwargs = getattr(last_message, "additional_kwargs", None) + logger.info( + "[TokenUsageMiddleware] diagnostics: message_type=%s has_usage_metadata=%s response_metadata_keys=%s additional_kwargs_keys=%s", + type(last_message).__name__, + getattr(last_message, "usage_metadata", None) is not None, + sorted(response_metadata.keys()) if isinstance(response_metadata, dict) else [], + sorted(additional_kwargs.keys()) if isinstance(additional_kwargs, dict) else [], + ) + + +def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None: + """Normalize token usage keys to input_tokens/output_tokens/total_tokens.""" + if not isinstance(raw_usage, dict): + return None + + input_tokens = raw_usage.get("input_tokens") + if input_tokens is None: + input_tokens = raw_usage.get("prompt_tokens") + + output_tokens = raw_usage.get("output_tokens") + if output_tokens is None: + output_tokens = raw_usage.get("completion_tokens") + + total_tokens = raw_usage.get("total_tokens") + if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int): + total_tokens = input_tokens + output_tokens + + if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)): + return None + + return { + "input_tokens": int(input_tokens or 0), + "output_tokens": int(output_tokens or 0), + "total_tokens": int(total_tokens or 0), + 
}
diff --git a/backend/packages/harness/deerflow/config/token_usage_config.py b/backend/packages/harness/deerflow/config/token_usage_config.py
index ab1e2629..15834c14 100644
--- a/backend/packages/harness/deerflow/config/token_usage_config.py
+++ b/backend/packages/harness/deerflow/config/token_usage_config.py
@@ -5,3 +5,12 @@ class TokenUsageConfig(BaseModel):
     """Configuration for token usage tracking."""
 
     enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
+    report_enabled: bool = Field(default=False, description="Enable reporting of token usage to an external billing API")
+    report_url: str | None = Field(
+        default=None,
+        description="HTTP(S) endpoint to POST token billing requests to. If unset, external billing reporting is disabled.",
+    )
+    report_headers: dict[str, str] = Field(
+        default_factory=dict,
+        description="Extra HTTP headers included in each billing request (e.g. Authorization: Bearer <token>).",
+    )
diff --git a/backend/packages/harness/deerflow/models/factory.py b/backend/packages/harness/deerflow/models/factory.py
index 51332c5e..b17f4577 100644
--- a/backend/packages/harness/deerflow/models/factory.py
+++ b/backend/packages/harness/deerflow/models/factory.py
@@ -78,6 +78,32 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
         elif "reasoning_effort" not in model_settings_from_config:
             model_settings_from_config["reasoning_effort"] = "medium"
 
+    # For ChatOpenAI-compatible providers, request usage in streaming responses
+    # so the token middleware can reliably read usage metadata.
+    try:
+        from langchain_openai import ChatOpenAI
+
+        if issubclass(model_class, ChatOpenAI):
+            has_stream_usage = "stream_usage" in model_settings_from_config or "stream_usage" in kwargs
+            if not has_stream_usage:
+                model_settings_from_config["stream_usage"] = True
+
+            # Some OpenAI-compatible providers only return usage in streaming mode
+            # when stream_options.include_usage is explicitly enabled.
+            stream_options_source = "kwargs" if "stream_options" in kwargs else "config"
+            stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options")
+            if stream_options is None:
+                model_settings_from_config["stream_options"] = {"include_usage": True}
+            elif isinstance(stream_options, dict) and "include_usage" not in stream_options:
+                patched_stream_options = {**stream_options, "include_usage": True}
+                if stream_options_source == "kwargs":
+                    kwargs["stream_options"] = patched_stream_options
+                else:
+                    model_settings_from_config["stream_options"] = patched_stream_options
+    except ImportError:
+        # Keep model creation robust when langchain_openai isn't installed.
+        pass
+
     model_instance = model_class(**kwargs, **model_settings_from_config)
 
     callbacks = build_tracing_callbacks()
diff --git a/config.example.yaml b/config.example.yaml
index d6f38259..d1deda46 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -25,8 +25,26 @@ log_level: info
 # ============================================================================
 # Track LLM token usage per model call (input/output/total tokens)
 # Logs at info level via TokenUsageMiddleware
+
 token_usage:
-  enabled: false
+  enabled: false  # Whether to log token usage to local logs
+  report_enabled: false  # Whether to report token usage to the external billing API
+
+  # Optional: POST billing records to an external HTTP API after every model call.
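+  # Reporting is fire-and-forget from the async middleware hook: the request is
+  # scheduled as a background task, and failures are logged as warnings rather
+  # than blocking or failing the agent run.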
+  # Payload schema follows the billing endpoint contract:
+  #   sessionId, inputToken, outputToken, totalTokens, modelName, question,
+  #   transactionNo, remark.
+  # transactionNo is taken from the model response id (AIMessage.id or provider
+  # response metadata); if no id is found, it falls back to 0.
+  # report_url: "http://localhost:19001/api/account/reduceBalanceWithToken"
+
+  # Optional: HTTP headers sent with every billing request (e.g. auth).
+  # report_headers:
+  #   Authorization: "Bearer your-secret-token"
+  #   X-App-Id: "deer-flow"
 
 # ============================================================================
 # Models Configuration
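
A quick sanity check for the extraction helpers introduced above. This is an illustrative pytest-style sketch, not part of the diff: the module path is taken from the file headers, the test names are hypothetical, and it assumes deerflow and langchain_core are importable in the test environment.

    from langchain_core.messages import AIMessage, HumanMessage

    from deerflow.agents.middlewares.token_usage_middleware import (
        _extract_latest_question,
        _extract_usage,
        _normalize_usage_dict,
    )

    def test_openai_style_keys_are_normalized():
        # prompt_tokens/completion_tokens map to input/output; total is derived.
        usage = _normalize_usage_dict({"prompt_tokens": 11, "completion_tokens": 5})
        assert usage == {"input_tokens": 11, "output_tokens": 5, "total_tokens": 16}

    def test_usage_is_found_in_response_metadata():
        # Providers that omit usage_metadata but fill response_metadata still count.
        msg = AIMessage(
            content="hi",
            response_metadata={"token_usage": {"prompt_tokens": 3, "completion_tokens": 2}},
        )
        assert _extract_usage(msg) == {"input_tokens": 3, "output_tokens": 2, "total_tokens": 5}

    def test_latest_question_joins_structured_content():
        # List-of-parts HumanMessage content is flattened to newline-joined text.
        messages = [HumanMessage(content=[{"type": "text", "text": "hello"}, {"type": "text", "text": "world"}])]
        assert _extract_latest_question(messages) == "hello\nworld"

The expected values follow directly from _normalize_usage_dict's fallback rules: OpenAI-style keys are renamed, and a missing total is derived as input + output.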