"""Middleware for external billing reservation/finalization per model call.""" from __future__ import annotations import logging from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, override from uuid import uuid4 from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse from langchain_core.messages import AIMessage, HumanMessage from langgraph.errors import GraphBubbleUp from deerflow.config.app_config import get_app_config logger = logging.getLogger(__name__) _SUCCESS_STATUS_CODES = {200, 1000} _INSUFFICIENT_BALANCE_CODE = -1106 @dataclass class _ReserveContext: frozen_id: str call_id: str session_id: str | None model_name: str | None estimated_input_tokens: int estimated_output_tokens: int class BillingMiddleware(AgentMiddleware[AgentState]): """Reserve before model call and finalize after completion.""" @override def wrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: cfg = get_app_config().billing if not cfg.enabled: return handler(request) reserve_ctx, block_result = _reserve_sync(request) if block_result is not None: return block_result response: ModelCallResult | None = None finalize_reason = "success" try: response = handler(request) return response except GraphBubbleUp: finalize_reason = "cancel" raise except TimeoutError: finalize_reason = "timeout" raise except Exception: finalize_reason = "error" raise finally: if reserve_ctx is not None: _finalize_sync(request, reserve_ctx, response, finalize_reason) @override async def awrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: cfg = get_app_config().billing if not cfg.enabled: return await handler(request) reserve_ctx, block_result = await 
_reserve_async(request) if block_result is not None: return block_result response: ModelCallResult | None = None finalize_reason = "success" try: response = await handler(request) return response except GraphBubbleUp: finalize_reason = "cancel" raise except TimeoutError: finalize_reason = "timeout" raise except Exception: finalize_reason = "error" raise finally: if reserve_ctx is not None: await _finalize_async(request, reserve_ctx, response, finalize_reason) def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None, str | None, int, int]: cfg = get_app_config().billing session_id = _extract_thread_id(request) run_id = _extract_run_id(request) model_key = _extract_model_key_from_runtime(request) model_name = _resolve_model_name(model_key) estimated_input_tokens = _estimate_input_tokens(request.messages) estimated_output_tokens = _resolve_estimated_output_tokens(request, model_key) question = _extract_latest_question(request.messages) call_id = run_id or str(uuid4()) expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds) payload: dict[str, Any] = { "sessionId": session_id, "callId": call_id, "modelName": model_name, "question": question, "frozenType": cfg.frozen_type, "estimatedInputTokens": estimated_input_tokens, "estimatedOutputTokens": estimated_output_tokens, "expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"), } return payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens def _extract_run_id(request: ModelRequest) -> str | None: # noqa: ARG001 # Primary: use LangGraph's public runtime API to access the current RunnableConfig. # This matches the official guidance for code that needs config inside runtime-bound # execution, while middleware itself only receives ModelRequest(runtime=Runtime). try: from langgraph.config import get_config config = get_config() if isinstance(config, dict): # Depending on LangGraph API variant, run_id may live at different levels. 
run_id = config.get("run_id") if run_id is None: metadata = config.get("metadata") if isinstance(metadata, dict): run_id = metadata.get("run_id") if run_id is None: configurable = config.get("configurable") if isinstance(configurable, dict): run_id = configurable.get("run_id") if run_id is not None: return str(run_id) except RuntimeError: pass except Exception as exc: logger.warning("[BillingMiddleware] failed to read run_id from get_config(): %s", exc) # Fallback: LangGraph API worker sets run_id via set_logging_context() before # astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139). try: from langgraph_api.logging import worker_config as lg_worker_config worker_ctx = lg_worker_config.get() if isinstance(worker_ctx, dict): run_id = worker_ctx.get("run_id") if isinstance(run_id, str) and run_id: return run_id except Exception: pass return None def _reserve_failure_message(status_code: int | None) -> str: if status_code in _blocking_reserve_code_set(): return "The account balance is insufficient for this model call." return "Billing reservation failed. Please try again later." 
def _blocking_reserve_code_set() -> set[int]:
    """Reserve-rejection status codes that must block the model call."""
    cfg = get_app_config().billing
    return {int(code) for code in cfg.blocking_reserve_codes}


def _should_block_reserve_failure(status_code: int | None) -> bool:
    """Decide whether a reserve failure with *status_code* blocks the call.

    With ``block_only_specific_reserve_codes`` enabled, only configured codes
    block; otherwise the global ``fail_closed`` flag decides.
    """
    cfg = get_app_config().billing
    if cfg.block_only_specific_reserve_codes:
        return status_code in _blocking_reserve_code_set()
    return cfg.fail_closed


def _extract_frozen_id(payload: dict[str, Any]) -> str | None:
    """Pull the non-empty ``data.frozenId`` string out of a reserve response."""
    data = payload.get("data")
    if not isinstance(data, dict):
        return None
    frozen_id = data.get("frozenId")
    if isinstance(frozen_id, str) and frozen_id:
        return frozen_id
    return None


def _extract_response_status(payload: dict[str, Any]) -> int | None:
    """Return the integer ``status`` (or legacy ``code``) field, if present."""
    status = payload.get("status")
    if isinstance(status, int):
        return status
    # Backward compatibility with old response schema
    code = payload.get("code")
    if isinstance(code, int):
        return code
    return None


def _is_success_payload(payload: dict[str, Any]) -> bool:
    """True when a billing response signals success (status code or legacy flag)."""
    status = _extract_response_status(payload)
    if isinstance(status, int) and status in _SUCCESS_STATUS_CODES:
        return True
    # Backward compatibility with old response schema
    success = payload.get("success")
    if success is True:
        return True
    return False


def _prepare_reserve(
    request: ModelRequest,
) -> tuple[tuple[dict[str, Any], str | None, str | None, int, int] | None, AIMessage | None]:
    """Shared pre-flight for sync/async reserve: config check + payload build.

    Returns ``(prepared, block_message)`` where ``prepared`` is the
    ``_reserve_payload`` tuple on success, and ``block_message`` is a
    user-facing AIMessage when the failure should block the model call
    (fail-open paths return ``(None, None)``).
    """
    cfg = get_app_config().billing
    if not cfg.reserve_url:
        logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty")
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation endpoint is not configured.")
        return None, None
    try:
        return _reserve_payload(request), None
    except ValueError as exc:
        logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content=str(exc))
        return None, None


def _process_reserve_response(
    response: dict[str, Any] | None,
    payload: dict[str, Any],
    session_id: str | None,
    model_name: str | None,
    estimated_input_tokens: int,
    estimated_output_tokens: int,
) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Shared post-processing of a reserve HTTP response (sync and async paths).

    Validates the response and, on success, builds the ``_ReserveContext``
    used later for finalization. On failure, returns a blocking AIMessage
    only when ``_should_block_reserve_failure`` says so.
    """
    if response is None:
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation request failed.")
        return None, None
    if not _is_success_payload(response):
        status_code = _extract_response_status(response)
        logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response)
        if _should_block_reserve_failure(status_code):
            return None, AIMessage(content=_reserve_failure_message(status_code))
        return None, None
    frozen_id = _extract_frozen_id(response)
    if not frozen_id:
        logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation response is invalid.")
        return None, None
    return (
        _ReserveContext(
            frozen_id=frozen_id,
            call_id=payload["callId"],
            session_id=session_id,
            model_name=model_name,
            estimated_input_tokens=estimated_input_tokens,
            estimated_output_tokens=estimated_output_tokens,
        ),
        None,
    )


async def _reserve_async(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Reserve billing funds for this model call (async HTTP)."""
    prepared, block_message = _prepare_reserve(request)
    if prepared is None:
        return None, block_message
    payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = prepared
    cfg = get_app_config().billing
    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    response = await _post_async(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] reserve response: %s", response)
    return _process_reserve_response(
        response, payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens
    )


def _reserve_sync(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Reserve billing funds for this model call (blocking HTTP)."""
    prepared, block_message = _prepare_reserve(request)
    if prepared is None:
        return None, block_message
    payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = prepared
    cfg = get_app_config().billing
    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    response = _post_sync(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] reserve response: %s", response)
    return _process_reserve_response(
        response, payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens
    )


def _build_finalize_payload(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> dict[str, Any]:
    """Build the finalize request body from actual usage (zeros when unknown)."""
    usage = _extract_usage(request, response)
    # "finalAmount" is always 0 here — presumably the billing service derives
    # the charge from the token counts; confirm against the API contract.
    return {
        "frozenId": reserve_ctx.frozen_id,
        "finalAmount": 0,
        "usageInputTokens": usage.get("input_tokens") if usage else 0,
        "usageOutputTokens": usage.get("output_tokens") if usage else 0,
        "usageTotalTokens": usage.get("total_tokens") if usage else 0,
        "finalizeReason": finalize_reason,
    }


def _log_finalize_result(result: dict[str, Any] | None, frozen_id: str) -> None:
    """Shared logging for the finalize HTTP result (sync and async paths)."""
    logger.info("[BillingMiddleware] finalize response: %s", result)
    if result is None:
        logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", frozen_id)
        return
    if not _is_success_payload(result):
        logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", frozen_id, result)


async def _finalize_async(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Settle the reservation after the model call (async HTTP).

    Best-effort: failures are logged but never raised, so billing problems
    cannot mask the model call's own outcome.
    """
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return
    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    result = await _post_async(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    _log_finalize_result(result, reserve_ctx.frozen_id)


def _finalize_sync(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Settle the reservation after the model call (blocking HTTP)."""
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return
    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    result = _post_sync(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    _log_finalize_result(result, reserve_ctx.frozen_id)


def _extract_thread_id(request: ModelRequest) -> str | None:
    """Find the conversation thread id from runtime context or config."""
    context = getattr(request.runtime, "context", None)
    thread_id = getattr(context, "thread_id", None)
    if isinstance(thread_id, str) and thread_id:
        return thread_id
    if isinstance(context, dict):
        thread_id = context.get("thread_id")
        if isinstance(thread_id, str) and thread_id:
            return thread_id
    config = getattr(request.runtime, "config", None)
    configurable = getattr(config, "configurable", None)
    thread_id = getattr(configurable, "thread_id", None)
    if isinstance(thread_id, str) and thread_id:
        return thread_id
    if isinstance(config, dict):
        # Fix: guard against a non-dict "configurable" entry, which previously
        # raised AttributeError on .get().
        configurable = config.get("configurable")
        if isinstance(configurable, dict):
            thread_id = configurable.get("thread_id")
            if isinstance(thread_id, str) and thread_id:
                return thread_id
    return None


def _extract_model_key_from_runtime(request: ModelRequest) -> str | None:
    """Find the configured model key, falling back to the model instance name."""
    config = getattr(request.runtime, "config", None)
    configurable = getattr(config, "configurable", None)
    model_key = getattr(configurable, "model", None) or getattr(configurable, "model_name", None)
    if isinstance(model_key, str) and model_key:
        return model_key
    if isinstance(config, dict):
        # Fix: same non-dict "configurable" guard as in _extract_thread_id.
        configurable = config.get("configurable")
        if isinstance(configurable, dict):
            model_key = configurable.get("model") or configurable.get("model_name")
            if isinstance(model_key, str) and model_key:
                return model_key
    # Fall back to the model instance's own identifier
    model_name = getattr(request.model, "model_name", None)
    if isinstance(model_name, str) and model_name:
        return model_name
    return None


def _resolve_model_name(model_key: str | None) -> str | None:
    """Map a config model key to the billing model name (the key itself as fallback)."""
    if not model_key:
        return None
    model_cfg = get_app_config().get_model_config(model_key)
    if model_cfg and model_cfg.model:
        return model_cfg.model
    return model_key


def _resolve_estimated_output_tokens(request: ModelRequest, model_key: str | None) -> int:
    """Resolve the output-token estimate for the reservation.

    Precedence: model config ``max_tokens`` > request ``model_settings`` >
    model instance attribute > configured default.

    Raises:
        ValueError: when no source yields a positive integer.
    """
    cfg = get_app_config().billing
    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg is not None:
            max_tokens = model_cfg.model_extra.get("max_tokens") if model_cfg.model_extra else None
            if isinstance(max_tokens, int) and max_tokens > 0:
                return max_tokens
    max_tokens_from_request = request.model_settings.get("max_tokens")
    if isinstance(max_tokens_from_request, int) and max_tokens_from_request > 0:
        return max_tokens_from_request
    # Fall back to the model instance's own max_tokens attribute
    max_tokens_from_model = getattr(request.model, "max_tokens", None)
    if isinstance(max_tokens_from_model, int) and max_tokens_from_model > 0:
        return max_tokens_from_model
    if cfg.default_estimated_output_tokens is not None:
        return cfg.default_estimated_output_tokens
    raise ValueError("Unable to resolve estimatedOutputTokens from model max_tokens.")


def _estimate_input_tokens(messages: list[Any]) -> int:
    """Estimate input tokens from the latest user message.

    Product requirement: use simple string-length estimation for input tokens.
    """
    latest_text = _extract_latest_user_text(messages)
    if not latest_text:
        return 0
    return len(latest_text)


def _extract_latest_user_text(messages: list[Any]) -> str:
    """Return the text of the most recent HumanMessage ("" when none found)."""
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            content = getattr(msg, "content", "")
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                # Multimodal content: concatenate only the text parts.
                parts: list[str] = []
                for part in content:
                    if isinstance(part, str):
                        parts.append(part)
                    elif isinstance(part, dict):
                        text = part.get("text")
                        if isinstance(text, str):
                            parts.append(text)
                return "\n".join(p for p in parts if p)
            return str(content)
    return ""


def _extract_latest_question(messages: list[Any]) -> str:
    """Latest user text, truncated to 27 characters for the billing record."""
    question = _extract_latest_user_text(messages)
    if isinstance(question, str) and len(question) > 27:
        return question[:27] + "。。。"
    return question


def _extract_usage(request: ModelRequest, response: ModelCallResult | None) -> dict[str, int] | None:
    """Locate token usage, preferring the response, then state/context messages."""
    if response is None:
        usage = None
    else:
        usage = _extract_usage_from_obj(response)
        if usage:
            return usage
        messages = getattr(response, "messages", None)
        usage = _extract_usage_from_messages(messages)
        if usage:
            return usage
    state = getattr(request, "state", None)
    if isinstance(state, dict):
        usage = _extract_usage_from_messages(state.get("messages"))
        if usage:
            return usage
    runtime_context = getattr(request.runtime, "context", None)
    if isinstance(runtime_context, dict):
        usage = _extract_usage_from_messages(runtime_context.get("messages"))
        if usage:
            return usage
    return None


def _extract_usage_from_messages(messages: object) -> dict[str, int] | None:
    """Scan a message list newest-first for the first usage record."""
    if not isinstance(messages, list):
        return None
    for msg in reversed(messages):
        usage = _extract_usage_from_obj(msg)
        if usage:
            return usage
    return None


def _extract_usage_from_obj(obj: object) -> dict[str, int] | None:
    """Probe the usage attribute spots used by LangChain message/response objects."""
    usage_metadata = getattr(obj, "usage_metadata", None)
    usage = _normalize_usage_dict(usage_metadata)
    if usage:
        return usage
    response_metadata = getattr(obj, "response_metadata", None)
    if isinstance(response_metadata, dict):
        usage = _normalize_usage_dict(response_metadata.get("usage"))
        if usage:
            return usage
        usage = _normalize_usage_dict(response_metadata.get("token_usage"))
        if usage:
            return usage
    additional_kwargs = getattr(obj, "additional_kwargs", None)
    if isinstance(additional_kwargs, dict):
        usage = _normalize_usage_dict(additional_kwargs.get("usage"))
        if usage:
            return usage
        usage = _normalize_usage_dict(additional_kwargs.get("token_usage"))
        if usage:
            return usage
    return None


def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
    """Normalize a provider usage dict to input/output/total token counts.

    Accepts both LangChain-style (``input_tokens``/``output_tokens``) and
    OpenAI-style (``prompt_tokens``/``completion_tokens``) keys; returns None
    when no integer count is present at all.
    """
    if not isinstance(raw_usage, dict):
        return None
    input_tokens = raw_usage.get("input_tokens")
    if input_tokens is None:
        input_tokens = raw_usage.get("prompt_tokens")
    output_tokens = raw_usage.get("output_tokens")
    if output_tokens is None:
        output_tokens = raw_usage.get("completion_tokens")
    total_tokens = raw_usage.get("total_tokens")
    if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
        total_tokens = input_tokens + output_tokens
    if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
        return None
    return {
        "input_tokens": int(input_tokens or 0),
        "output_tokens": int(output_tokens or 0),
        "total_tokens": int(total_tokens or 0),
    }


async def _post_async(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """POST JSON and return the dict body, or None on any failure (best-effort)."""
    try:
        import httpx

        async with httpx.AsyncClient(timeout=timeout_seconds) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            if isinstance(data, dict):
                return data
            return None
    except Exception as exc:
        # Deliberately broad: billing HTTP issues must never break the model call.
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None


def _post_sync(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """Blocking variant of ``_post_async`` with the same error contract."""
    try:
        import httpx

        with httpx.Client(timeout=timeout_seconds) as client:
            response = client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            if isinstance(data, dict):
                return data
            return None
    except Exception as exc:
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None