feat: add billing reservation and finalization middleware with configuration (pre + call_id)

This commit is contained in:
Titan 2026-04-12 15:33:37 +08:00
parent 97247c3f28
commit a5cf6c87e5
11 changed files with 1024 additions and 3 deletions

View File

@ -284,6 +284,11 @@ async def start_run(
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
if "configurable" in config and isinstance(config["configurable"], dict):
config["configurable"].setdefault("run_id", record.run_id)
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("run_id", record.run_id)
# Merge DeerFlow-specific context overrides into configurable.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).

View File

@ -294,6 +294,45 @@ title:
max_words: 6
max_chars: 60
model_name: null # Use first model in list
### Billing Reservation/Finalization
External billing can reserve before each model call and finalize after completion.
This is independent from `token_usage` reporting.
```yaml
billing:
enabled: false
include_subagents: false
fail_closed: true
block_only_specific_reserve_codes: true
blocking_reserve_codes: [-1104, -1106]
frozen_type: 1
reserve_url: http://localhost:19001/accountFrozen/frozen
finalize_url: http://localhost:19001/accountFrozen/release
timeout_seconds: 10
default_expire_seconds: 1800
# default_estimated_output_tokens: 4096
# headers:
# Authorization: Bearer your-secret-token
```
For `frozen_type=1` (token billing):
- Reserve request sends `estimatedInputTokens` and `estimatedOutputTokens`.
- `estimatedInputTokens` is estimated with a simple string-length rule from the latest user input.
- `estimatedOutputTokens` is resolved from model `max_tokens`.
- Finalize request keeps `finalAmount=0`; billing platform computes final cost from
`usageInputTokens`/`usageOutputTokens`/`usageTotalTokens`.
Reserve blocking policy:
- With `block_only_specific_reserve_codes=true` (recommended), model calls are blocked
only when reserve API returns a code in `blocking_reserve_codes` (default `[-1104, -1106]`).
- For all other failures (reserve/finalize HTTP failure, 5xx, invalid reserve response),
DeerFlow logs warnings and continues model calls.
- Set `block_only_specific_reserve_codes=false` to restore legacy `fail_closed` behavior.
If model `max_tokens` is unavailable, DeerFlow uses `default_estimated_output_tokens`
when configured.
### GitHub API Token (Optional for GitHub Deep Research Skill)

View File

@ -0,0 +1,629 @@
"""Middleware for external billing reservation/finalization per model call."""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, override
from uuid import uuid4
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config.app_config import get_app_config
logger = logging.getLogger(__name__)
_SUCCESS_STATUS_CODES = {200, 1000}
_INSUFFICIENT_BALANCE_CODE = -1106
@dataclass
class _ReserveContext:
    """State captured when a reservation succeeds, needed to finalize it later."""

    frozen_id: str  # reservation id returned by the reserve endpoint; keys the finalize call
    call_id: str  # per-model-call id (LangGraph run_id when available, else a uuid4)
    session_id: str | None  # thread id, when resolvable from the runtime
    model_name: str | None  # resolved display name of the billed model
    estimated_input_tokens: int  # pre-call estimate sent in the reserve payload
    estimated_output_tokens: int  # pre-call estimate sent in the reserve payload
class BillingMiddleware(AgentMiddleware[AgentState]):
    """Reserve before model call and finalize after completion.

    Each model call is wrapped: a reservation is attempted first; if the
    reserve step returns a blocking AIMessage, the model is never invoked.
    Otherwise the call proceeds and the reservation is finalized in a
    ``finally`` block — even when the handler raises — with a reason string
    describing how the call ended.
    """

    @override
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        # Pure passthrough when billing is disabled in config.
        cfg = get_app_config().billing
        if not cfg.enabled:
            return handler(request)
        reserve_ctx, block_result = _reserve_sync(request)
        if block_result is not None:
            # Reserve decided to block: return the message instead of calling the model.
            return block_result
        response: ModelCallResult | None = None
        finalize_reason = "success"
        try:
            response = handler(request)
            return response
        except GraphBubbleUp:
            # Graph-level control-flow exception (presumably cancel/interrupt) — re-raise.
            finalize_reason = "cancel"
            raise
        except TimeoutError:
            finalize_reason = "timeout"
            raise
        except Exception:
            finalize_reason = "error"
            raise
        finally:
            # Finalize only if a reservation was actually created; response stays
            # None on the exception paths, so usage falls back to state messages.
            if reserve_ctx is not None:
                _finalize_sync(request, reserve_ctx, response, finalize_reason)

    @override
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        # Async twin of wrap_model_call — keep the two paths in lockstep.
        cfg = get_app_config().billing
        if not cfg.enabled:
            return await handler(request)
        reserve_ctx, block_result = await _reserve_async(request)
        if block_result is not None:
            return block_result
        response: ModelCallResult | None = None
        finalize_reason = "success"
        try:
            response = await handler(request)
            return response
        except GraphBubbleUp:
            finalize_reason = "cancel"
            raise
        except TimeoutError:
            finalize_reason = "timeout"
            raise
        except Exception:
            finalize_reason = "error"
            raise
        finally:
            if reserve_ctx is not None:
                await _finalize_async(request, reserve_ctx, response, finalize_reason)
def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None, str | None, int, int]:
    """Build the reserve request body.

    Returns ``(payload, session_id, model_name, estimated_input_tokens,
    estimated_output_tokens)``. Raises ValueError (propagated from
    ``_resolve_estimated_output_tokens``) when no output-token estimate can be
    resolved.
    """
    cfg = get_app_config().billing
    session_id = _extract_thread_id(request)
    run_id = _extract_run_id(request)
    model_key = _extract_model_key_from_runtime(request)
    model_name = _resolve_model_name(model_key)
    estimated_input_tokens = _estimate_input_tokens(request.messages)
    estimated_output_tokens = _resolve_estimated_output_tokens(request, model_key)
    # Prefer the LangGraph run_id so billing records correlate with runs;
    # otherwise each model call gets a fresh uuid4.
    call_id = run_id or str(uuid4())
    if not run_id:
        # Log enough runtime detail to diagnose why run_id was not propagated.
        runtime = getattr(request, "runtime", None)
        runtime_context = getattr(runtime, "context", None)
        runtime_config = getattr(runtime, "config", None)
        context_keys = sorted(runtime_context.keys()) if isinstance(runtime_context, dict) else []
        config_keys = sorted(runtime_config.keys()) if isinstance(runtime_config, dict) else []
        logger.warning(
            "[BillingMiddleware] run_id missing in runtime; fallback callId=%s context_type=%s config_type=%s context_keys=%s config_keys=%s",
            call_id,
            type(runtime_context).__name__ if runtime_context is not None else "None",
            type(runtime_config).__name__ if runtime_config is not None else "None",
            context_keys,
            config_keys,
        )
    logger.info(
        "[BillingMiddleware] id mapping: thread_id=%s run_id=%s call_id=%s model_name=%s",
        session_id,
        run_id,
        call_id,
        model_name,
    )
    # NOTE(review): expireAt uses naive local time via datetime.now() — confirm
    # the billing platform expects server-local "%Y-%m-%d %H:%M:%S" timestamps.
    expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
    payload: dict[str, Any] = {
        "sessionId": session_id,
        "callId": call_id,
        "modelName": model_name,
        "frozenType": cfg.frozen_type,
        "estimatedInputTokens": estimated_input_tokens,
        "estimatedOutputTokens": estimated_output_tokens,
        "expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"),
    }
    return payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens


def _extract_run_id(request: ModelRequest) -> str | None:  # noqa: ARG001
    """Best-effort lookup of the current LangGraph run id from context variables.

    ``request`` is intentionally unused (noqa ARG001); the id is read from
    ContextVars set by the LangGraph runtime. Returns None when unavailable.
    """
    # Primary: LangGraph injects run_id into the top-level RunnableConfig
    # (langgraph_api/stream.py:218) and propagates it via var_child_runnable_config
    # throughout graph node execution.
    try:
        from langchain_core.runnables.config import var_child_runnable_config

        lc_config = var_child_runnable_config.get()
        if isinstance(lc_config, dict):
            run_id = lc_config.get("run_id")
            if run_id is not None:
                return str(run_id)
    except Exception:
        # Broad catch is deliberate: id lookup must never break billing.
        pass
    # Fallback: LangGraph API worker sets run_id via set_logging_context() before
    # astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139).
    try:
        from langgraph_api.logging import worker_config as lg_worker_config

        worker_ctx = lg_worker_config.get()
        if isinstance(worker_ctx, dict):
            run_id = worker_ctx.get("run_id")
            if isinstance(run_id, str) and run_id:
                return run_id
    except Exception:
        pass
    return None
def _reserve_failure_message(status_code: int | None) -> str:
    """User-facing text returned when a reservation blocks the model call."""
    if status_code in _blocking_reserve_code_set():
        return "The account balance is insufficient for this model call."
    return "Billing reservation failed. Please try again later."


def _blocking_reserve_code_set() -> set[int]:
    """Reserve-response codes configured to block model calls."""
    billing_cfg = get_app_config().billing
    return {int(code) for code in billing_cfg.blocking_reserve_codes}


def _should_block_reserve_failure(status_code: int | None) -> bool:
    """Decide whether a failed/rejected reserve should block the model call.

    With ``block_only_specific_reserve_codes`` enabled, only configured codes
    block; otherwise the legacy ``fail_closed`` flag decides for all failures.
    """
    billing_cfg = get_app_config().billing
    if not billing_cfg.block_only_specific_reserve_codes:
        return billing_cfg.fail_closed
    return status_code in _blocking_reserve_code_set()
def _extract_frozen_id(payload: dict[str, Any]) -> str | None:
data = payload.get("data")
if not isinstance(data, dict):
return None
frozen_id = data.get("frozenId")
if isinstance(frozen_id, str) and frozen_id:
return frozen_id
return None
def _extract_response_status(payload: dict[str, Any]) -> int | None:
status = payload.get("status")
if isinstance(status, int):
return status
# Backward compatibility with old response schema
code = payload.get("code")
if isinstance(code, int):
return code
return None
def _is_success_payload(payload: dict[str, Any]) -> bool:
status = _extract_response_status(payload)
if isinstance(status, int) and status in _SUCCESS_STATUS_CODES:
return True
# Backward compatibility with old response schema
success = payload.get("success")
if success is True:
return True
return False
async def _reserve_async(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Attempt a billing reservation before a model call (async).

    Returns ``(reserve_ctx, block_message)``: a context when the reservation
    succeeded, or an AIMessage when the model call must be blocked. Both None
    means "no reservation, but proceed" (the fail-open path).
    """
    cfg = get_app_config().billing
    if not cfg.reserve_url:
        logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty")
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation endpoint is not configured.")
        return None, None
    try:
        payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = _reserve_payload(request)
    except ValueError as exc:
        # Raised when estimatedOutputTokens cannot be resolved.
        logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content=str(exc))
        return None, None
    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    response = await _post_async(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] reserve response: %s", response)
    if response is None:
        # Transport-level failure; block only if policy says so.
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation request failed.")
        return None, None
    if not _is_success_payload(response):
        status_code = _extract_response_status(response)
        logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response)
        if _should_block_reserve_failure(status_code):
            return None, AIMessage(content=_reserve_failure_message(status_code))
        return None, None
    frozen_id = _extract_frozen_id(response)
    if not frozen_id:
        # Success status but unusable body — treat like a reserve failure.
        logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation response is invalid.")
        return None, None
    call_id = payload["callId"]
    return (
        _ReserveContext(
            frozen_id=frozen_id,
            call_id=call_id,
            session_id=session_id,
            model_name=model_name,
            estimated_input_tokens=estimated_input_tokens,
            estimated_output_tokens=estimated_output_tokens,
        ),
        None,
    )


def _reserve_sync(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Sync twin of ``_reserve_async`` — keep the two code paths in lockstep."""
    cfg = get_app_config().billing
    if not cfg.reserve_url:
        logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty")
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation endpoint is not configured.")
        return None, None
    try:
        payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = _reserve_payload(request)
    except ValueError as exc:
        logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content=str(exc))
        return None, None
    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    response = _post_sync(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] reserve response: %s", response)
    if response is None:
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation request failed.")
        return None, None
    if not _is_success_payload(response):
        status_code = _extract_response_status(response)
        logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response)
        if _should_block_reserve_failure(status_code):
            return None, AIMessage(content=_reserve_failure_message(status_code))
        return None, None
    frozen_id = _extract_frozen_id(response)
    if not frozen_id:
        logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation response is invalid.")
        return None, None
    call_id = payload["callId"]
    return (
        _ReserveContext(
            frozen_id=frozen_id,
            call_id=call_id,
            session_id=session_id,
            model_name=model_name,
            estimated_input_tokens=estimated_input_tokens,
            estimated_output_tokens=estimated_output_tokens,
        ),
        None,
    )
def _build_finalize_payload(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> dict[str, Any]:
    """Build the finalize request body.

    ``finalAmount`` stays 0: the platform computes the final cost from the
    usage token fields. Unknown usage is reported as zeros.
    """
    usage = _extract_usage(request, response) or {}
    return {
        "frozenId": reserve_ctx.frozen_id,
        "finalAmount": 0,
        "usageInputTokens": usage.get("input_tokens", 0),
        "usageOutputTokens": usage.get("output_tokens", 0),
        "usageTotalTokens": usage.get("total_tokens", 0),
        "finalizeReason": finalize_reason,
    }


async def _finalize_async(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Release the reservation (async). Failures are logged, never raised."""
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return
    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    result = await _post_async(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] finalize response: %s", result)
    if result is None:
        logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", reserve_ctx.frozen_id)
    elif not _is_success_payload(result):
        logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", reserve_ctx.frozen_id, result)


def _finalize_sync(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Release the reservation (sync). Failures are logged, never raised."""
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return
    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    result = _post_sync(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] finalize response: %s", result)
    if result is None:
        logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", reserve_ctx.frozen_id)
    elif not _is_success_payload(result):
        logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", reserve_ctx.frozen_id, result)
def _extract_thread_id(request: ModelRequest) -> str | None:
context = getattr(request.runtime, "context", None)
thread_id = getattr(context, "thread_id", None)
if isinstance(thread_id, str) and thread_id:
return thread_id
if isinstance(context, dict):
thread_id = context.get("thread_id")
if isinstance(thread_id, str) and thread_id:
return thread_id
config = getattr(request.runtime, "config", None)
configurable = getattr(config, "configurable", None)
thread_id = getattr(configurable, "thread_id", None)
if isinstance(thread_id, str) and thread_id:
return thread_id
if isinstance(config, dict):
thread_id = config.get("configurable", {}).get("thread_id")
if isinstance(thread_id, str) and thread_id:
return thread_id
return None
def _extract_model_key_from_runtime(request: ModelRequest) -> str | None:
config = getattr(request.runtime, "config", None)
configurable = getattr(config, "configurable", None)
model_key = getattr(configurable, "model", None) or getattr(configurable, "model_name", None)
if isinstance(model_key, str) and model_key:
return model_key
if isinstance(config, dict):
configurable = config.get("configurable", {})
model_key = configurable.get("model") or configurable.get("model_name")
if isinstance(model_key, str) and model_key:
return model_key
# Fall back to the model instance's own identifier
model_name = getattr(request.model, "model_name", None)
if isinstance(model_name, str) and model_name:
return model_name
return None
def _resolve_model_name(model_key: str | None) -> str | None:
    """Translate a model key into its configured display name (key as fallback)."""
    if not model_key:
        return None
    model_cfg = get_app_config().get_model_config(model_key)
    display = model_cfg.display_name if model_cfg else None
    return display or model_key


def _resolve_estimated_output_tokens(request: ModelRequest, model_key: str | None) -> int:
    """Resolve ``estimatedOutputTokens`` for the reserve payload.

    Tries, in order: the model config's ``max_tokens`` extra, the request's
    ``model_settings``, the model instance's ``max_tokens`` attribute, then the
    configured ``default_estimated_output_tokens``. Raises ValueError when none
    of those yields a positive int.
    """
    billing_cfg = get_app_config().billing

    def _as_positive_int(value: object) -> int | None:
        # Only strictly positive ints count as a usable estimate.
        if isinstance(value, int) and value > 0:
            return value
        return None

    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg is not None:
            extra = model_cfg.model_extra or {}
            candidate = _as_positive_int(extra.get("max_tokens"))
            if candidate is not None:
                return candidate
    candidate = _as_positive_int(request.model_settings.get("max_tokens"))
    if candidate is not None:
        return candidate
    # Fall back to the model instance's own max_tokens attribute.
    candidate = _as_positive_int(getattr(request.model, "max_tokens", None))
    if candidate is not None:
        return candidate
    if billing_cfg.default_estimated_output_tokens is not None:
        return billing_cfg.default_estimated_output_tokens
    raise ValueError("Unable to resolve estimatedOutputTokens from model max_tokens.")
def _estimate_input_tokens(messages: list[Any]) -> int:
    """Character-count estimate of input tokens from the newest user message.

    Product requirement: use simple string-length estimation, not a tokenizer.
    """
    latest_text = _extract_latest_user_text(messages)
    return len(latest_text) if latest_text else 0


def _extract_latest_user_text(messages: list[Any]) -> str:
    """Return the text of the most recent HumanMessage; list content is flattened."""
    for message in reversed(messages):
        if not isinstance(message, HumanMessage):
            continue
        content = getattr(message, "content", "")
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            # Unexpected content shape — stringify rather than fail.
            return str(content)
        fragments: list[str] = []
        for part in content:
            if isinstance(part, str):
                fragments.append(part)
            elif isinstance(part, dict):
                text = part.get("text")
                if isinstance(text, str):
                    fragments.append(text)
        return "\n".join(fragment for fragment in fragments if fragment)
    return ""
def _extract_usage(request: ModelRequest, response: ModelCallResult | None) -> dict[str, int] | None:
if response is None:
usage = None
else:
usage = _extract_usage_from_obj(response)
if usage:
return usage
messages = getattr(response, "messages", None)
usage = _extract_usage_from_messages(messages)
if usage:
return usage
state = getattr(request, "state", None)
if isinstance(state, dict):
usage = _extract_usage_from_messages(state.get("messages"))
if usage:
return usage
runtime_context = getattr(request.runtime, "context", None)
if isinstance(runtime_context, dict):
usage = _extract_usage_from_messages(runtime_context.get("messages"))
if usage:
return usage
return None
def _extract_usage_from_messages(messages: object) -> dict[str, int] | None:
if not isinstance(messages, list):
return None
for msg in reversed(messages):
usage = _extract_usage_from_obj(msg)
if usage:
return usage
return None
def _extract_usage_from_obj(obj: object) -> dict[str, int] | None:
usage_metadata = getattr(obj, "usage_metadata", None)
usage = _normalize_usage_dict(usage_metadata)
if usage:
return usage
response_metadata = getattr(obj, "response_metadata", None)
if isinstance(response_metadata, dict):
usage = _normalize_usage_dict(response_metadata.get("usage"))
if usage:
return usage
usage = _normalize_usage_dict(response_metadata.get("token_usage"))
if usage:
return usage
additional_kwargs = getattr(obj, "additional_kwargs", None)
if isinstance(additional_kwargs, dict):
usage = _normalize_usage_dict(additional_kwargs.get("usage"))
if usage:
return usage
usage = _normalize_usage_dict(additional_kwargs.get("token_usage"))
if usage:
return usage
return None
def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
if not isinstance(raw_usage, dict):
return None
input_tokens = raw_usage.get("input_tokens")
if input_tokens is None:
input_tokens = raw_usage.get("prompt_tokens")
output_tokens = raw_usage.get("output_tokens")
if output_tokens is None:
output_tokens = raw_usage.get("completion_tokens")
total_tokens = raw_usage.get("total_tokens")
if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
total_tokens = input_tokens + output_tokens
if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
return None
return {
"input_tokens": int(input_tokens or 0),
"output_tokens": int(output_tokens or 0),
"total_tokens": int(total_tokens or 0),
}
async def _post_async(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """POST JSON asynchronously; return the dict body, or None on any failure."""
    try:
        import httpx

        async with httpx.AsyncClient(timeout=timeout_seconds) as client:
            resp = await client.post(url, headers=headers, json=payload)
            resp.raise_for_status()
            body = resp.json()
            return body if isinstance(body, dict) else None
    except Exception as exc:
        # Billing must never crash the agent; log and report failure as None.
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None


def _post_sync(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """POST JSON synchronously; return the dict body, or None on any failure."""
    try:
        import httpx

        with httpx.Client(timeout=timeout_seconds) as client:
            resp = client.post(url, headers=headers, json=payload)
            resp.raise_for_status()
            body = resp.json()
            return body if isinstance(body, dict) else None
    except Exception as exc:
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None

View File

@ -91,6 +91,14 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware())
from deerflow.config.app_config import get_app_config
billing_cfg = get_app_config().billing
if billing_cfg.enabled and (include_uploads or billing_cfg.include_subagents):
from deerflow.agents.middlewares.billing_middleware import BillingMiddleware
middlewares.append(BillingMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)

View File

@ -1,4 +1,5 @@
from .app_config import get_app_config
from .billing_config import BillingConfig
from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
@ -13,6 +14,7 @@ from .tracing_config import (
__all__ = [
"get_app_config",
"BillingConfig",
"Paths",
"get_paths",
"SkillsConfig",

View File

@ -9,6 +9,7 @@ from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.billing_config import BillingConfig
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
@ -40,6 +41,7 @@ class AppConfig(BaseModel):
"""Config for the DeerFlow application"""
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
billing: BillingConfig = Field(default_factory=BillingConfig, description="External billing reservation/finalization configuration")
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
sandbox: SandboxConfig = Field(description="Sandbox configuration")

View File

@ -0,0 +1,62 @@
"""Configuration for reservation/finalization billing integration."""
from pydantic import BaseModel, Field
class BillingConfig(BaseModel):
    """Configuration for external billing reservation/finalization calls.

    All settings live under the app config's ``billing`` key. The middleware is
    inert unless ``enabled`` is true; ``reserve_url``/``finalize_url`` must
    also be set for any HTTP calls to happen.
    """

    enabled: bool = Field(default=False, description="Enable external billing middleware.")
    include_subagents: bool = Field(
        default=False,
        description="Whether billing applies to subagent model calls as well.",
    )
    # Blocking policy: fail_closed is the legacy behavior, consulted only when
    # block_only_specific_reserve_codes is false.
    fail_closed: bool = Field(
        default=True,
        description="Block model calls when reserve request fails or balance is insufficient.",
    )
    block_only_specific_reserve_codes: bool = Field(
        default=True,
        description=(
            "When true, only reserve responses with codes in blocking_reserve_codes block model calls. "
            "When false, fallback to fail_closed behavior for all reserve failures."
        ),
    )
    blocking_reserve_codes: list[int] = Field(
        default_factory=lambda: [-1104, -1106],
        description="Reserve response codes that should block model calls when block_only_specific_reserve_codes is enabled.",
    )
    frozen_type: int = Field(
        default=1,
        ge=1,
        description="Frozen type sent to the platform. Current flow uses 1 for token billing.",
    )
    # Transport settings for the reserve/finalize endpoints.
    reserve_url: str | None = Field(
        default=None,
        description="HTTP(S) endpoint for creating frozen reservations.",
    )
    finalize_url: str | None = Field(
        default=None,
        description="HTTP(S) endpoint for finalizing frozen reservations.",
    )
    headers: dict[str, str] = Field(
        default_factory=dict,
        description="Extra HTTP headers included in reserve/finalize requests.",
    )
    timeout_seconds: float = Field(
        default=10.0,
        gt=0,
        le=120,
        description="HTTP request timeout for reserve/finalize calls.",
    )
    default_expire_seconds: int = Field(
        default=1800,
        ge=60,
        le=86400,
        description="Default reservation expiration seconds when expireAt is included.",
    )
    default_estimated_output_tokens: int | None = Field(
        default=None,
        ge=1,
        description="Fallback estimatedOutputTokens when model max_tokens is unavailable.",
    )

View File

@ -89,12 +89,13 @@ async def run_agent(
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store)
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config["context"].setdefault("run_id", run_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
runnable_config = RunnableConfig(**config)

View File

@ -226,15 +226,18 @@ class SubagentExecutor:
try:
agent = self._create_agent()
state = self._build_initial_state(task)
subagent_model_name = _get_model_name(self.config, self.parent_model)
# Build config with thread_id for sandbox access and recursion limit
run_config: RunnableConfig = {
"recursion_limit": self.config.max_turns,
}
context = {}
configurable: dict[str, Any] = {"model_name": subagent_model_name}
if self.thread_id:
run_config["configurable"] = {"thread_id": self.thread_id}
configurable["thread_id"] = self.thread_id
context["thread_id"] = self.thread_id
run_config["configurable"] = configurable
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}")

View File

@ -0,0 +1,241 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.billing_middleware import BillingMiddleware
def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
billing = SimpleNamespace(
enabled=enabled,
include_subagents=include_subagents,
fail_closed=True,
block_only_specific_reserve_codes=True,
blocking_reserve_codes=[-1104, -1106],
frozen_type=1,
reserve_url="http://billing.local/accountFrozen/frozen",
finalize_url="http://billing.local/accountFrozen/release",
headers={"Authorization": "Bearer x"},
timeout_seconds=3.0,
default_expire_seconds=1800,
default_estimated_output_tokens=None,
)
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
return SimpleNamespace(
billing=billing,
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
)
def _request_with_latest_user_text(text: str):
request = MagicMock()
request.messages = [HumanMessage(content="old"), HumanMessage(content=text)]
request.model_settings = {}
request.runtime = SimpleNamespace(
config={"configurable": {"thread_id": "thread-1", "model_name": "gpt-4"}},
context={"thread_id": "thread-1"},
)
return request
@pytest.mark.anyio
async def test_awrap_model_call_uses_estimated_tokens_and_finalizes(monkeypatch):
    """Happy path: reserve sends pre-call estimates, finalize reports real usage."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        # Record every outgoing call; succeed for both reserve and finalize.
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 11, "output_tokens": 22, "total_tokens": 33}))
    # Simulate LangGraph propagating run_id via the runnable-config ContextVar.
    token = var_child_runnable_config.set({"run_id": "run-1"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)
    assert isinstance(result, AIMessage)
    assert len(seen_payloads) == 2
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-1"
    assert reserve_payload["frozenType"] == 1
    # Input estimate is the character count of the latest user message.
    assert reserve_payload["estimatedInputTokens"] == len("hello world")
    # Output estimate comes from the fake model config's max_tokens.
    assert reserve_payload["estimatedOutputTokens"] == 4096
    assert "frozenAmount" not in reserve_payload
    finalize_payload = seen_payloads[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["finalAmount"] == 0
    assert finalize_payload["usageInputTokens"] == 11
    assert finalize_payload["usageOutputTokens"] == 22
    assert finalize_payload["usageTotalTokens"] == 33
    assert finalize_payload["finalizeReason"] == "success"


@pytest.mark.anyio
async def test_awrap_model_call_fail_closed_on_insufficient_balance(monkeypatch):
    """Reserve code -1106 (blocking) must prevent the model call entirely."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def fake_post(url, headers, payload, timeout_seconds):
        return {"status": -1106, "message": "insufficient balance", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")
    handler = AsyncMock(return_value=AIMessage(content="should not run"))
    result = await middleware.awrap_model_call(request, handler)
    # The middleware returns the blocking AIMessage; the handler never runs.
    assert isinstance(result, AIMessage)
    assert "insufficient" in str(result.content).lower()
    handler.assert_not_awaited()
@pytest.mark.anyio
async def test_awrap_model_call_finalize_uses_state_messages_usage_when_response_missing_usage(monkeypatch):
    """When the model response has no usage_metadata, the finalize payload
    falls back to the usage recorded on the AIMessage in request.state."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    recorded = []

    async def capture_post(url, headers, payload, timeout_seconds):
        recorded.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", capture_post)

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    # State already carries an AIMessage with usage; the handler's response does not.
    request.state = {
        "messages": [
            HumanMessage(content="hello world"),
            AIMessage(
                content="ok",
                usage_metadata={"input_tokens": 101, "output_tokens": 202, "total_tokens": 303},
            ),
        ]
    }
    handler = AsyncMock(return_value=AIMessage(content="ok"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert len(recorded) == 2
    finalize_payload = recorded[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["usageInputTokens"] == 101
    assert finalize_payload["usageOutputTokens"] == 202
    assert finalize_payload["usageTotalTokens"] == 303
@pytest.mark.anyio
async def test_awrap_model_call_does_not_block_on_non_blocking_reserve_code(monkeypatch):
    """A reserve failure whose status is NOT in blocking_reserve_codes must
    not block the call: the model handler still runs and its response is
    returned unchanged."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def busy_reserve(url, headers, payload, timeout_seconds):
        if url.endswith("/frozen"):
            # Failure code outside blocking_reserve_codes (-1104/-1106).
            return {"status": 5001, "message": "platform busy", "data": {}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", busy_reserve)

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")
    handler = AsyncMock(return_value=AIMessage(content="model-ran"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert result.content == "model-ran"
    handler.assert_awaited_once()
@pytest.mark.anyio
async def test_awrap_model_call_uses_runnable_config_run_id(monkeypatch):
    """run_id is sourced from var_child_runnable_config, which LangGraph populates
    via langgraph_api/stream.py during graph node execution."""
    from langchain_core.runnables.config import var_child_runnable_config
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )

    token = var_child_runnable_config.set({"run_id": "run-from-ctx"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        # Always restore the context var so later tests see a clean context.
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    # Guard the [0] index below and match the sibling tests: both the reserve
    # and the finalize POST must have been issued.
    assert len(seen_payloads) == 2
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-ctx"
@pytest.mark.anyio
async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
    """Fallback: run_id from langgraph_api.logging.worker_config when var_child_runnable_config is unset."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    import langgraph_api.logging as lg_logging

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )

    token = lg_logging.worker_config.set({"run_id": "run-from-worker"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        # Always restore the context var so later tests see a clean context.
        lg_logging.worker_config.reset(token)

    assert isinstance(result, AIMessage)
    # Guard the [0] index below and match the sibling tests: both the reserve
    # and the finalize POST must have been issued.
    assert len(seen_payloads) == 2
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-worker"

View File

@ -12,7 +12,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 5
config_version: 7
# ============================================================================
# Logging
@ -20,6 +20,35 @@ config_version: 5
# Log level for deerflow modules (debug/info/warning/error)
log_level: info
# ============================================================================
# Billing Reservation/Finalization
# ============================================================================
# Reserve before each LLM call and finalize after call completion.
# Keep this independent from token_usage reporting.
billing:
enabled: false
include_subagents: false
fail_closed: true
# true: only block when reserve returns a code in blocking_reserve_codes
# false: fall back to fail_closed behavior for all reserve failures
block_only_specific_reserve_codes: true
blocking_reserve_codes: [-1104, -1106]
frozen_type: 1
timeout_seconds: 10
default_expire_seconds: 1800
# When model config has no max_tokens, this fallback is used for
# estimatedOutputTokens. If unset and fail_closed=true, billing blocks calls.
# default_estimated_output_tokens: 4096
# reserve_url: "http://localhost:19001/accountFrozen/frozen"
# finalize_url: "http://localhost:19001/accountFrozen/release"
# headers:
# Authorization: "Bearer your-secret-token"
# X-App-Id: "deer-flow"
# ============================================================================
# Token Usage Tracking
# ============================================================================