deerflow2/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py

630 lines
22 KiB
Python

"""Middleware that reserves billing quota before each model call and finalizes it afterwards.

Reservation and finalization are JSON POSTs to the endpoints configured under
``billing`` in the app config (``reserve_url`` / ``finalize_url``).
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, override
from uuid import uuid4
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config.app_config import get_app_config
logger = logging.getLogger(__name__)

# Response ``status`` (or legacy ``code``) values treated as success by
# ``_is_success_payload``. Immutable so no caller can alter it at runtime.
_SUCCESS_STATUS_CODES = frozenset({200, 1000})
# NOTE(review): by its name, the billing service's "insufficient balance" status code — confirm.
_INSUFFICIENT_BALANCE_CODE = -1106
@dataclass()
class _ReserveContext:
    """State carried from a successful reservation to its finalization."""

    # Reservation handle taken from the reserve response (``data.frozenId``).
    frozen_id: str
    # ``callId`` that was sent in the reserve payload.
    call_id: str
    # Thread id sent as ``sessionId``; None when no thread id was found.
    session_id: str | None
    # Model name sent as ``modelName``; None when it could not be resolved.
    model_name: str | None
    # Token estimates that were sent with the reservation.
    estimated_input_tokens: int
    estimated_output_tokens: int
class BillingMiddleware(AgentMiddleware[AgentState]):
    """Reserve billing quota before each model call and finalize it afterwards.

    When ``billing.enabled`` is false the handler is called directly. Otherwise a
    reservation is attempted first. If that step yields a blocking ``AIMessage``,
    the message is returned instead of calling the model. When a reservation
    succeeded, it is always finalized after the handler finishes, with reason
    ``"success"``, ``"cancel"``, ``"timeout"`` or ``"error"``.
    """

    @override
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Synchronous path: reserve, call ``handler``, then finalize the reservation."""
        cfg = get_app_config().billing
        if not cfg.enabled:
            return handler(request)
        reserve_ctx, block_result = _reserve_sync(request)
        if block_result is not None:
            return block_result
        response: ModelCallResult | None = None
        finalize_reason = "success"
        try:
            response = handler(request)
            return response
        except GraphBubbleUp:
            finalize_reason = "cancel"
            raise
        except TimeoutError:
            finalize_reason = "timeout"
            raise
        except Exception:
            finalize_reason = "error"
            raise
        except BaseException:
            # KeyboardInterrupt / SystemExit are not ``Exception`` subclasses; without
            # this clause the reservation would be finalized as "success".
            finalize_reason = "cancel"
            raise
        finally:
            if reserve_ctx is not None:
                # Finalization is best-effort: a failure here must neither mask the
                # handler's exception nor turn a successful call into an error.
                try:
                    _finalize_sync(request, reserve_ctx, response, finalize_reason)
                except Exception:
                    logger.exception(
                        "[BillingMiddleware] finalize raised unexpectedly: frozenId=%s reason=%s",
                        reserve_ctx.frozen_id,
                        finalize_reason,
                    )

    @override
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Async path: reserve, await ``handler``, then finalize the reservation."""
        cfg = get_app_config().billing
        if not cfg.enabled:
            return await handler(request)
        reserve_ctx, block_result = await _reserve_async(request)
        if block_result is not None:
            return block_result
        response: ModelCallResult | None = None
        finalize_reason = "success"
        try:
            response = await handler(request)
            return response
        except GraphBubbleUp:
            finalize_reason = "cancel"
            raise
        except TimeoutError:
            finalize_reason = "timeout"
            raise
        except Exception:
            finalize_reason = "error"
            raise
        except BaseException:
            # asyncio.CancelledError (task cancellation) derives from BaseException;
            # without this clause a cancelled call would be finalized as "success".
            finalize_reason = "cancel"
            raise
        finally:
            if reserve_ctx is not None:
                # Best-effort finalization; see the synchronous path.
                try:
                    await _finalize_async(request, reserve_ctx, response, finalize_reason)
                except Exception:
                    logger.exception(
                        "[BillingMiddleware] finalize raised unexpectedly: frozenId=%s reason=%s",
                        reserve_ctx.frozen_id,
                        finalize_reason,
                    )
def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None, str | None, int, int]:
    """Assemble the JSON body for the billing reserve endpoint.

    Returns:
        ``(payload, session_id, model_name, estimated_input_tokens,
        estimated_output_tokens)``. The last four values are also embedded in
        ``payload``. ``callId`` is the run id when one is found; otherwise it is a
        freshly generated UUID4, and a warning with runtime diagnostics is logged.

    Raises:
        ValueError: propagated from ``_resolve_estimated_output_tokens`` when no
            output-token estimate can be resolved.
    """
    billing_cfg = get_app_config().billing
    thread_id = _extract_thread_id(request)
    run_id = _extract_run_id(request)
    model_key = _extract_model_key_from_runtime(request)
    display_model = _resolve_model_name(model_key)
    input_estimate = _estimate_input_tokens(request.messages)
    output_estimate = _resolve_estimated_output_tokens(request, model_key)

    # NOTE(review): presumably every model call within one LangGraph run sees the
    # same run_id, so those calls would share one callId — confirm the billing
    # service does not require callId to be unique per reservation.
    if run_id:
        call_id = run_id
    else:
        call_id = str(uuid4())
        runtime = getattr(request, "runtime", None)
        ctx = getattr(runtime, "context", None)
        conf = getattr(runtime, "config", None)
        logger.warning(
            "[BillingMiddleware] run_id missing in runtime; fallback callId=%s context_type=%s config_type=%s context_keys=%s config_keys=%s",
            call_id,
            "None" if ctx is None else type(ctx).__name__,
            "None" if conf is None else type(conf).__name__,
            sorted(ctx.keys()) if isinstance(ctx, dict) else [],
            sorted(conf.keys()) if isinstance(conf, dict) else [],
        )

    logger.info(
        "[BillingMiddleware] id mapping: thread_id=%s run_id=%s call_id=%s model_name=%s",
        thread_id,
        run_id,
        call_id,
        display_model,
    )

    # Naive local time, formatted without a timezone designator.
    expires = datetime.now() + timedelta(seconds=billing_cfg.default_expire_seconds)
    payload: dict[str, Any] = {
        "sessionId": thread_id,
        "callId": call_id,
        "modelName": display_model,
        "frozenType": billing_cfg.frozen_type,
        "estimatedInputTokens": input_estimate,
        "estimatedOutputTokens": output_estimate,
        "expireAt": expires.strftime("%Y-%m-%d %H:%M:%S"),
    }
    return payload, thread_id, display_model, input_estimate, output_estimate
def _normalize_run_id(value: object) -> str | None:
    """Return ``value`` as a non-empty string, or ``None`` when missing or empty."""
    if value is None:
        return None
    text = str(value)
    return text or None


def _extract_run_id(request: ModelRequest) -> str | None:  # noqa: ARG001
    """Best-effort lookup of the current LangGraph run id.

    ``request`` is unused; the id is read from context variables. Two sources are
    tried in order. A source whose package cannot be imported is skipped silently.
    Any other lookup error is logged at debug level and the next source is tried.
    An empty id falls through to the next source instead of being returned.
    """
    # Primary: LangGraph injects run_id into the top-level RunnableConfig
    # (langgraph_api/stream.py:218) and propagates it via var_child_runnable_config
    # throughout graph node execution.
    try:
        from langchain_core.runnables.config import var_child_runnable_config

        lc_config = var_child_runnable_config.get()
        if isinstance(lc_config, dict):
            run_id = _normalize_run_id(lc_config.get("run_id"))
            if run_id is not None:
                return run_id
    except ImportError:
        pass  # langchain_core not importable; try the next source.
    except Exception:
        logger.debug("[BillingMiddleware] run_id lookup via runnable config failed", exc_info=True)
    # Fallback: LangGraph API worker sets run_id via set_logging_context() before
    # astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139).
    try:
        from langgraph_api.logging import worker_config as lg_worker_config

        worker_ctx = lg_worker_config.get()
        if isinstance(worker_ctx, dict):
            run_id = _normalize_run_id(worker_ctx.get("run_id"))
            if run_id is not None:
                return run_id
    except ImportError:
        pass  # Not running inside the LangGraph API server.
    except Exception:
        logger.debug("[BillingMiddleware] run_id lookup via worker config failed", exc_info=True)
    return None
def _reserve_failure_message(status_code: int | None) -> str:
    """Return the user-facing message for a reservation that is being blocked.

    ``_INSUFFICIENT_BALANCE_CODE`` always gets the insufficient-balance wording, as
    do the codes listed in ``billing.blocking_reserve_codes``. Any other status
    (including ``None``) gets a generic retry message.
    """
    # Check the dedicated constant first so an insufficient-balance rejection is
    # worded correctly even if the operator configured other blocking codes.
    if status_code == _INSUFFICIENT_BALANCE_CODE or status_code in _blocking_reserve_code_set():
        return "The account balance is insufficient for this model call."
    return "Billing reservation failed. Please try again later."
def _blocking_reserve_code_set() -> set[int]:
    """Return ``billing.blocking_reserve_codes`` from the app config as a set of ints."""
    configured_codes = get_app_config().billing.blocking_reserve_codes
    return set(map(int, configured_codes))
def _should_block_reserve_failure(status_code: int | None) -> bool:
    """Decide whether a failed/absent reservation should block the model call.

    With ``billing.block_only_specific_reserve_codes`` set, only statuses listed in
    ``billing.blocking_reserve_codes`` block (so ``None`` never does). Otherwise the
    decision is ``billing.fail_closed`` regardless of the status.
    """
    billing_cfg = get_app_config().billing
    if not billing_cfg.block_only_specific_reserve_codes:
        return billing_cfg.fail_closed
    return status_code in _blocking_reserve_code_set()
def _extract_frozen_id(payload: dict[str, Any]) -> str | None:
    """Return ``payload["data"]["frozenId"]`` if it is a non-empty string, else ``None``."""
    data = payload.get("data")
    candidate = data.get("frozenId") if isinstance(data, dict) else None
    return candidate if isinstance(candidate, str) and candidate else None
def _extract_response_status(payload: dict[str, Any]) -> int | None:
    """Return the integer status of a billing response, or ``None``.

    The current schema carries it in ``status``; the old schema used ``code``.
    The first of those keys whose value is an ``int`` wins.
    """
    for key in ("status", "code"):  # "code" = backward compatibility with the old schema
        candidate = payload.get(key)
        if isinstance(candidate, int):
            return candidate
    return None
def _is_success_payload(payload: dict[str, Any]) -> bool:
    """Return True when a billing response reports success.

    Success means the integer status is one of ``_SUCCESS_STATUS_CODES``, or, for
    the old response schema, ``success`` is exactly ``True``.
    """
    status = _extract_response_status(payload)
    if status is not None and status in _SUCCESS_STATUS_CODES:
        return True
    # Backward compatibility with old response schema
    return payload.get("success") is True
@dataclass
class _PreparedReserve:
    """Everything needed to send one reserve request and interpret its reply."""

    # Endpoint settings, captured from the same config read that built the payload.
    url: str
    headers: dict[str, str]
    timeout_seconds: float
    # Values returned by ``_reserve_payload``.
    payload: dict[str, Any]
    session_id: str | None
    model_name: str | None
    estimated_input_tokens: int
    estimated_output_tokens: int


def _reserve_fallback(status_code: int | None, message: str) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Outcome when no reservation is made: block with ``message`` or let the call proceed."""
    if _should_block_reserve_failure(status_code):
        return None, AIMessage(content=message)
    return None, None


def _prepare_reserve(
    request: ModelRequest,
) -> tuple[_PreparedReserve | None, tuple[_ReserveContext | None, AIMessage | None]]:
    """Check configuration and build the reserve payload.

    Returns ``(prepared, outcome)``. When ``prepared`` is ``None`` the reservation
    cannot be attempted and ``outcome`` is the final result. Otherwise ``outcome``
    is ``(None, None)`` and should be ignored.
    """
    cfg = get_app_config().billing
    if not cfg.reserve_url:
        logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty")
        return None, _reserve_fallback(None, "Billing reservation endpoint is not configured.")
    try:
        payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = _reserve_payload(request)
    except ValueError as exc:
        logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc)
        return None, _reserve_fallback(None, str(exc))
    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    prepared = _PreparedReserve(
        url=cfg.reserve_url,
        headers=cfg.headers,
        timeout_seconds=cfg.timeout_seconds,
        payload=payload,
        session_id=session_id,
        model_name=model_name,
        estimated_input_tokens=estimated_input_tokens,
        estimated_output_tokens=estimated_output_tokens,
    )
    return prepared, (None, None)


def _complete_reserve(
    prepared: _PreparedReserve,
    response: dict[str, Any] | None,
) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Turn the reserve endpoint's reply (``None`` = request failed) into the final outcome."""
    logger.info("[BillingMiddleware] reserve response: %s", response)
    if response is None:
        return _reserve_fallback(None, "Billing reservation request failed.")
    if not _is_success_payload(response):
        status_code = _extract_response_status(response)
        logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response)
        if _should_block_reserve_failure(status_code):
            return None, AIMessage(content=_reserve_failure_message(status_code))
        return None, None
    frozen_id = _extract_frozen_id(response)
    if not frozen_id:
        logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response)
        return _reserve_fallback(None, "Billing reservation response is invalid.")
    return (
        _ReserveContext(
            frozen_id=frozen_id,
            call_id=prepared.payload["callId"],
            session_id=prepared.session_id,
            model_name=prepared.model_name,
            estimated_input_tokens=prepared.estimated_input_tokens,
            estimated_output_tokens=prepared.estimated_output_tokens,
        ),
        None,
    )


async def _reserve_async(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Reserve billing quota for ``request`` (async HTTP).

    Returns ``(reserve_ctx, block_result)``:
    - ``reserve_ctx`` is set when the reservation succeeded.
    - ``block_result`` is an ``AIMessage`` to return instead of calling the model.
    - Both ``None`` means the call proceeds without a reservation.
    """
    prepared, outcome = _prepare_reserve(request)
    if prepared is None:
        return outcome
    response = await _post_async(prepared.url, prepared.headers, prepared.payload, prepared.timeout_seconds)
    return _complete_reserve(prepared, response)


def _reserve_sync(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Blocking-HTTP counterpart of ``_reserve_async``; same return contract."""
    prepared, outcome = _prepare_reserve(request)
    if prepared is None:
        return outcome
    response = _post_sync(prepared.url, prepared.headers, prepared.payload, prepared.timeout_seconds)
    return _complete_reserve(prepared, response)
def _build_finalize_payload(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> dict[str, Any]:
    """Build the JSON body for the billing finalize endpoint.

    Token counts come from ``_extract_usage``; all three are 0 when it finds no
    usage. ``finalAmount`` is always sent as 0.
    """
    usage = _extract_usage(request, response)
    input_tokens = output_tokens = total_tokens = 0
    if usage:
        input_tokens = usage.get("input_tokens")
        output_tokens = usage.get("output_tokens")
        total_tokens = usage.get("total_tokens")
    return {
        "frozenId": reserve_ctx.frozen_id,
        "finalAmount": 0,
        "usageInputTokens": input_tokens,
        "usageOutputTokens": output_tokens,
        "usageTotalTokens": total_tokens,
        "finalizeReason": finalize_reason,
    }
def _prepare_finalize(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> tuple[Any, dict[str, Any]] | None:
    """Return ``(billing_cfg, payload)`` for a finalize call, or ``None`` if no finalize URL is set."""
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return None
    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    return cfg, payload


def _report_finalize_result(reserve_ctx: _ReserveContext, result: dict[str, Any] | None) -> None:
    """Log the finalize reply; a missing or unsuccessful reply is logged as a warning, never raised."""
    logger.info("[BillingMiddleware] finalize response: %s", result)
    if result is None:
        logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", reserve_ctx.frozen_id)
    elif not _is_success_payload(result):
        logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", reserve_ctx.frozen_id, result)


async def _finalize_async(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Settle the reservation in ``reserve_ctx`` via the finalize endpoint (async HTTP)."""
    prepared = _prepare_finalize(request, reserve_ctx, response, finalize_reason)
    if prepared is None:
        return
    cfg, payload = prepared
    result = await _post_async(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    _report_finalize_result(reserve_ctx, result)


def _finalize_sync(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Blocking-HTTP counterpart of ``_finalize_async``."""
    prepared = _prepare_finalize(request, reserve_ctx, response, finalize_reason)
    if prepared is None:
        return
    cfg, payload = prepared
    result = _post_sync(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    _report_finalize_result(reserve_ctx, result)
def _extract_thread_id(request: ModelRequest) -> str | None:
    """Return the conversation thread id for ``request``, or ``None``.

    Lookup order (first non-empty string wins):
    1. ``runtime.context.thread_id`` (attribute), then ``runtime.context["thread_id"]``.
    2. ``runtime.config.configurable.thread_id`` (attribute), then
       ``runtime.config["configurable"]["thread_id"]``.

    A missing ``runtime`` or a ``configurable`` entry that is not a dict yields
    ``None`` instead of raising.
    """
    runtime = getattr(request, "runtime", None)

    context = getattr(runtime, "context", None)
    thread_id = getattr(context, "thread_id", None)
    if isinstance(thread_id, str) and thread_id:
        return thread_id
    if isinstance(context, dict):
        thread_id = context.get("thread_id")
        if isinstance(thread_id, str) and thread_id:
            return thread_id

    config = getattr(runtime, "config", None)
    thread_id = getattr(getattr(config, "configurable", None), "thread_id", None)
    if isinstance(thread_id, str) and thread_id:
        return thread_id
    if isinstance(config, dict):
        # ``configurable`` may be present but None; don't call .get() on it blindly.
        configurable = config.get("configurable")
        if isinstance(configurable, dict):
            thread_id = configurable.get("thread_id")
            if isinstance(thread_id, str) and thread_id:
                return thread_id
    return None
def _extract_model_key_from_runtime(request: ModelRequest) -> str | None:
    """Return the model key selected for ``request``, or ``None``.

    Lookup order (first non-empty string wins):
    1. ``runtime.config.configurable`` attributes ``model``, then ``model_name``.
    2. ``runtime.config["configurable"]`` keys ``model``, then ``model_name``.
    3. ``request.model.model_name`` (the model instance's own identifier).

    A missing ``runtime``/``model`` or a non-dict ``configurable`` is tolerated.
    """
    config = getattr(getattr(request, "runtime", None), "config", None)

    configurable_obj = getattr(config, "configurable", None)
    for attr in ("model", "model_name"):
        model_key = getattr(configurable_obj, attr, None)
        if isinstance(model_key, str) and model_key:
            return model_key

    if isinstance(config, dict):
        configurable = config.get("configurable")
        if isinstance(configurable, dict):
            # Check each key separately so a non-str ``model`` value does not
            # hide a valid ``model_name``.
            for key in ("model", "model_name"):
                model_key = configurable.get(key)
                if isinstance(model_key, str) and model_key:
                    return model_key

    # Fall back to the model instance's own identifier
    model_name = getattr(getattr(request, "model", None), "model_name", None)
    if isinstance(model_name, str) and model_name:
        return model_name
    return None
def _resolve_model_name(model_key: str | None) -> str | None:
    """Map a model key to its configured ``display_name``.

    Falls back to the key itself when there is no config or no display name, and
    returns ``None`` for an empty/missing key.
    """
    if not model_key:
        return None
    model_cfg = get_app_config().get_model_config(model_key)
    display_name = model_cfg.display_name if model_cfg else None
    return display_name or model_key
def _resolve_estimated_output_tokens(request: ModelRequest, model_key: str | None) -> int:
    """Pick the output-token estimate sent with a reservation.

    The first positive ``int`` wins, in this order:
    1. ``max_tokens`` in the model's app-config extras.
    2. ``request.model_settings["max_tokens"]``.
    3. The model instance's ``max_tokens`` attribute.
    4. ``billing.default_estimated_output_tokens`` (used whenever it is not ``None``).

    Raises:
        ValueError: if none of those sources yields a value.
    """
    billing_cfg = get_app_config().billing

    def _positive(value: object) -> bool:
        return isinstance(value, int) and value > 0

    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg is not None:
            extras = model_cfg.model_extra
            configured = extras.get("max_tokens") if extras else None
            if _positive(configured):
                return configured

    requested = request.model_settings.get("max_tokens")
    if _positive(requested):
        return requested

    # Fall back to the model instance's own max_tokens attribute
    instance_limit = getattr(request.model, "max_tokens", None)
    if _positive(instance_limit):
        return instance_limit

    fallback = billing_cfg.default_estimated_output_tokens
    if fallback is None:
        raise ValueError("Unable to resolve estimatedOutputTokens from model max_tokens.")
    return fallback
def _estimate_input_tokens(messages: list[Any]) -> int:
    """Estimate input tokens as the character length of the latest user text.

    Product requirement: a plain string-length estimate is used, counting only the
    most recent user message. Returns 0 when there is no user text.
    """
    latest = _extract_latest_user_text(messages)
    return len(latest) if latest else 0
def _extract_latest_user_text(messages: list[Any]) -> str:
    """Return the text of the most recent ``HumanMessage`` in ``messages``.

    String content is returned as-is. List content is reduced to its string parts
    and the ``"text"`` values of dict parts, joined with newlines (empty pieces
    dropped). Any other content type is passed through ``str()``. Returns ``""``
    when there is no ``HumanMessage``.
    """
    latest_human = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
    if latest_human is None:
        return ""
    content = getattr(latest_human, "content", "")
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    texts: list[str] = []
    for part in content:
        if isinstance(part, str):
            text = part
        elif isinstance(part, dict):
            text = part.get("text")
        else:
            continue
        if isinstance(text, str) and text:
            texts.append(text)
    return "\n".join(texts)
def _extract_usage(request: ModelRequest, response: ModelCallResult | None) -> dict[str, int] | None:
    """Find token usage to report when finalizing a reservation.

    Sources are tried in order and the first hit wins:
    1. The model response object itself (skipped when ``response`` is ``None``).
    2. ``response.messages``, newest first.
    3. ``request.state["messages"]``, newest first.
    4. ``request.runtime.context["messages"]``, newest first.

    Returns ``None`` when no source carries usage.
    """
    if response is not None:
        found = _extract_usage_from_obj(response)
        if found:
            return found
    found = _extract_usage_from_messages(getattr(response, "messages", None))
    if found:
        return found

    # NOTE(review): the state/context messages presumably predate this model call,
    # so a hit below would be usage from an EARLIER call (e.g. when this call failed
    # or returned no usage) — confirm that is the intended billing behavior.
    state = getattr(request, "state", None)
    if isinstance(state, dict):
        found = _extract_usage_from_messages(state.get("messages"))
        if found:
            return found

    runtime_context = getattr(request.runtime, "context", None)
    if isinstance(runtime_context, dict):
        found = _extract_usage_from_messages(runtime_context.get("messages"))
        if found:
            return found
    return None
def _extract_usage_from_messages(messages: object) -> dict[str, int] | None:
    """Return usage from the newest message in ``messages`` that carries any.

    Non-list inputs, and lists where no message has usage, yield ``None``.
    """
    if not isinstance(messages, list):
        return None
    return next(
        (usage for usage in map(_extract_usage_from_obj, reversed(messages)) if usage),
        None,
    )
def _extract_usage_from_obj(obj: object) -> dict[str, int] | None:
    """Return normalized token usage found on ``obj``, or ``None``.

    Candidates are tried lazily, in order:
    1. ``obj.usage_metadata``.
    2. ``obj.response_metadata["usage"]``, then ``["token_usage"]``.
    3. ``obj.additional_kwargs["usage"]``, then ``["token_usage"]``.

    Metadata containers that are not dicts are skipped.
    """

    def _candidates():
        yield getattr(obj, "usage_metadata", None)
        for container_name in ("response_metadata", "additional_kwargs"):
            container = getattr(obj, container_name, None)
            if isinstance(container, dict):
                yield container.get("usage")
                yield container.get("token_usage")

    for raw in _candidates():
        usage = _normalize_usage_dict(raw)
        if usage:
            return usage
    return None
def _coerce_token_count(value: object) -> int | None:
    """Return ``value`` as a token count, or ``None`` if it is not a usable number.

    Accepts ints (but not bools), integral floats and decimal-digit strings.
    Anything else, including nested dicts, yields ``None`` rather than raising.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # is_integer() is False for nan/inf, so int() below cannot raise.
        return int(value) if value.is_integer() else None
    if isinstance(value, str) and value.strip().isdecimal():
        return int(value.strip())
    return None


def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
    """Normalize a provider usage dict to ``input_tokens``/``output_tokens``/``total_tokens``.

    Accepts both LangChain (``input_tokens``/``output_tokens``) and OpenAI-style
    (``prompt_tokens``/``completion_tokens``) key names.

    Returns ``None`` when ``raw_usage`` is not a dict or holds no usable count.
    Otherwise missing counts become 0, and a missing total is the sum of the
    input and output counts. Malformed values are ignored instead of raising,
    because this runs while finalizing a model call.
    """
    if not isinstance(raw_usage, dict):
        return None
    input_tokens = _coerce_token_count(raw_usage.get("input_tokens"))
    if input_tokens is None:
        input_tokens = _coerce_token_count(raw_usage.get("prompt_tokens"))
    output_tokens = _coerce_token_count(raw_usage.get("output_tokens"))
    if output_tokens is None:
        output_tokens = _coerce_token_count(raw_usage.get("completion_tokens"))
    total_tokens = _coerce_token_count(raw_usage.get("total_tokens"))
    if input_tokens is None and output_tokens is None and total_tokens is None:
        return None
    if total_tokens is None:
        total_tokens = (input_tokens or 0) + (output_tokens or 0)
    return {
        "input_tokens": input_tokens or 0,
        "output_tokens": output_tokens or 0,
        "total_tokens": total_tokens,
    }
async def _post_async(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """POST ``payload`` as JSON with an async httpx client and return the decoded object.

    Returns ``None`` when the body is not a JSON object, or on any failure (httpx
    not importable, network error, non-2xx status, undecodable JSON). Failures
    are logged as warnings.
    """
    try:
        import httpx

        async with httpx.AsyncClient(timeout=timeout_seconds) as client:
            reply = await client.post(url, headers=headers, json=payload)
            reply.raise_for_status()
            body = reply.json()
            return body if isinstance(body, dict) else None
    except Exception as exc:
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None
def _post_sync(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """POST ``payload`` as JSON with a blocking httpx client and return the decoded object.

    Returns ``None`` when the body is not a JSON object, or on any failure (httpx
    not importable, network error, non-2xx status, undecodable JSON). Failures
    are logged as warnings.
    """
    try:
        import httpx

        with httpx.Client(timeout=timeout_seconds) as client:
            reply = client.post(url, headers=headers, json=payload)
            reply.raise_for_status()
            body = reply.json()
            return body if isinstance(body, dict) else None
    except Exception as exc:
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None