diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 272801b6..1756901a 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -284,6 +284,11 @@ async def start_run( graph_input = normalize_input(body.input) config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) + if "configurable" in config and isinstance(config["configurable"], dict): + config["configurable"].setdefault("run_id", record.run_id) + if "context" in config and isinstance(config["context"], dict): + config["context"].setdefault("run_id", record.run_id) + # Merge DeerFlow-specific context overrides into configurable. # The ``context`` field is a custom extension for the langgraph-compat layer # that carries agent configuration (model_name, thinking_enabled, etc.). diff --git a/backend/docs/CONFIGURATION.md b/backend/docs/CONFIGURATION.md index 63791b82..d1d47593 100644 --- a/backend/docs/CONFIGURATION.md +++ b/backend/docs/CONFIGURATION.md @@ -294,6 +294,45 @@ title: max_words: 6 max_chars: 60 model_name: null # Use first model in list + +### Billing Reservation/Finalization + +External billing can reserve before each model call and finalize after completion. +This is independent from `token_usage` reporting. + +```yaml +billing: + enabled: false + include_subagents: false + fail_closed: true + block_only_specific_reserve_codes: true + blocking_reserve_codes: [-1104, -1106] + frozen_type: 1 + reserve_url: http://localhost:19001/accountFrozen/frozen + finalize_url: http://localhost:19001/accountFrozen/release + timeout_seconds: 10 + default_expire_seconds: 1800 + # default_estimated_output_tokens: 4096 + # headers: + # Authorization: Bearer your-secret-token +``` + +For `frozen_type=1` (token billing): +- Reserve request sends `estimatedInputTokens` and `estimatedOutputTokens`. +- `estimatedInputTokens` is estimated with a simple string-length rule from the latest user input. 
+- `estimatedOutputTokens` is resolved from model `max_tokens`. +- Finalize request keeps `finalAmount=0`; billing platform computes final cost from + `usageInputTokens`/`usageOutputTokens`/`usageTotalTokens`. + +Reserve blocking policy: +- With `block_only_specific_reserve_codes=true` (recommended), model calls are blocked + only when reserve API returns a code in `blocking_reserve_codes` (default `[-1104, -1106]`). +- For all other failures (reserve/finalize HTTP failure, 5xx, invalid reserve response), + DeerFlow logs warnings and continues model calls. +- Set `block_only_specific_reserve_codes=false` to restore legacy `fail_closed` behavior. + +If model `max_tokens` is unavailable, DeerFlow uses `default_estimated_output_tokens` +when configured. ``` ### GitHub API Token (Optional for GitHub Deep Research Skill) diff --git a/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py new file mode 100644 index 00000000..c2896b77 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py @@ -0,0 +1,629 @@ +"""Middleware for external billing reservation/finalization per model call.""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, override +from uuid import uuid4 + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse +from langchain_core.messages import AIMessage, HumanMessage +from langgraph.errors import GraphBubbleUp + +from deerflow.config.app_config import get_app_config + +logger = logging.getLogger(__name__) + +_SUCCESS_STATUS_CODES = {200, 1000} +_INSUFFICIENT_BALANCE_CODE = -1106 + + +@dataclass +class _ReserveContext: + frozen_id: str + 
class BillingMiddleware(AgentMiddleware[AgentState]):
    """Reserve before model call and finalize after completion."""

    @override
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Sync path: reserve, run the model, then always finalize the reservation."""
        cfg = get_app_config().billing
        if not cfg.enabled:
            # Billing disabled: pass straight through with no HTTP calls.
            return handler(request)

        reserve_ctx, block_result = _reserve_sync(request)
        if block_result is not None:
            # Reserve rejected with a blocking code: short-circuit with the
            # user-facing AIMessage instead of calling the model.
            return block_result

        response: ModelCallResult | None = None
        finalize_reason = "success"

        try:
            response = handler(request)
            return response
        # GraphBubbleUp is matched before the generic handlers so graph
        # cancellations are finalized as "cancel", not "error".
        except GraphBubbleUp:
            finalize_reason = "cancel"
            raise
        except TimeoutError:
            finalize_reason = "timeout"
            raise
        except Exception:
            finalize_reason = "error"
            raise
        finally:
            # Finalize even when the handler raised; response stays None in
            # that case and usage falls back to state/context extraction.
            if reserve_ctx is not None:
                _finalize_sync(request, reserve_ctx, response, finalize_reason)

    @override
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Async mirror of ``wrap_model_call`` with the same reserve/finalize contract."""
        cfg = get_app_config().billing
        if not cfg.enabled:
            return await handler(request)

        reserve_ctx, block_result = await _reserve_async(request)
        if block_result is not None:
            return block_result

        response: ModelCallResult | None = None
        finalize_reason = "success"

        try:
            response = await handler(request)
            return response
        except GraphBubbleUp:
            finalize_reason = "cancel"
            raise
        except TimeoutError:
            finalize_reason = "timeout"
            raise
        except Exception:
            finalize_reason = "error"
            raise
        finally:
            if reserve_ctx is not None:
                await _finalize_async(request, reserve_ctx, response, finalize_reason)


def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None, str | None, int, int]:
    """Build the reserve request body plus the identifiers needed for finalize.

    Returns:
        (payload, session_id, model_name, estimated_input_tokens,
        estimated_output_tokens). May raise ValueError via
        ``_resolve_estimated_output_tokens`` when no token estimate is available.
    """
    cfg = get_app_config().billing

    session_id = _extract_thread_id(request)
    run_id = _extract_run_id(request)
    model_key = _extract_model_key_from_runtime(request)
    model_name = _resolve_model_name(model_key)

    estimated_input_tokens = _estimate_input_tokens(request.messages)
    estimated_output_tokens = _resolve_estimated_output_tokens(request, model_key)

    # Prefer the LangGraph run_id as the billing callId; fall back to a random
    # UUID and log diagnostics about why the run_id lookup failed.
    call_id = run_id or str(uuid4())
    if not run_id:
        runtime = getattr(request, "runtime", None)
        runtime_context = getattr(runtime, "context", None)
        runtime_config = getattr(runtime, "config", None)
        context_keys = sorted(runtime_context.keys()) if isinstance(runtime_context, dict) else []
        config_keys = sorted(runtime_config.keys()) if isinstance(runtime_config, dict) else []
        logger.warning(
            "[BillingMiddleware] run_id missing in runtime; fallback callId=%s context_type=%s config_type=%s context_keys=%s config_keys=%s",
            call_id,
            type(runtime_context).__name__ if runtime_context is not None else "None",
            type(runtime_config).__name__ if runtime_config is not None else "None",
            context_keys,
            config_keys,
        )
    logger.info(
        "[BillingMiddleware] id mapping: thread_id=%s run_id=%s call_id=%s model_name=%s",
        session_id,
        run_id,
        call_id,
        model_name,
    )
    # NOTE(review): expireAt uses naive local time via datetime.now(); assumes
    # the billing platform shares the server's timezone — confirm.
    expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
    payload: dict[str, Any] = {
        "sessionId": session_id,
        "callId": call_id,
        "modelName": model_name,
        "frozenType": cfg.frozen_type,
        "estimatedInputTokens": estimated_input_tokens,
        "estimatedOutputTokens": estimated_output_tokens,
        "expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"),
    }
    return payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens
+ try: + from langchain_core.runnables.config import var_child_runnable_config + + lc_config = var_child_runnable_config.get() + if isinstance(lc_config, dict): + run_id = lc_config.get("run_id") + if run_id is not None: + return str(run_id) + except Exception: + pass + + # Fallback: LangGraph API worker sets run_id via set_logging_context() before + # astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139). + try: + from langgraph_api.logging import worker_config as lg_worker_config + + worker_ctx = lg_worker_config.get() + if isinstance(worker_ctx, dict): + run_id = worker_ctx.get("run_id") + if isinstance(run_id, str) and run_id: + return run_id + except Exception: + pass + + return None + + +def _reserve_failure_message(status_code: int | None) -> str: + if status_code in _blocking_reserve_code_set(): + return "The account balance is insufficient for this model call." + return "Billing reservation failed. Please try again later." + + +def _blocking_reserve_code_set() -> set[int]: + cfg = get_app_config().billing + return {int(code) for code in cfg.blocking_reserve_codes} + + +def _should_block_reserve_failure(status_code: int | None) -> bool: + cfg = get_app_config().billing + if cfg.block_only_specific_reserve_codes: + return status_code in _blocking_reserve_code_set() + return cfg.fail_closed + + +def _extract_frozen_id(payload: dict[str, Any]) -> str | None: + data = payload.get("data") + if not isinstance(data, dict): + return None + frozen_id = data.get("frozenId") + if isinstance(frozen_id, str) and frozen_id: + return frozen_id + return None + + +def _extract_response_status(payload: dict[str, Any]) -> int | None: + status = payload.get("status") + if isinstance(status, int): + return status + + # Backward compatibility with old response schema + code = payload.get("code") + if isinstance(code, int): + return code + + return None + + +def _is_success_payload(payload: dict[str, Any]) -> bool: + status = 
_extract_response_status(payload) + if isinstance(status, int) and status in _SUCCESS_STATUS_CODES: + return True + + # Backward compatibility with old response schema + success = payload.get("success") + if success is True: + return True + + return False + + +async def _reserve_async(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]: + cfg = get_app_config().billing + if not cfg.reserve_url: + logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty") + if _should_block_reserve_failure(None): + return None, AIMessage(content="Billing reservation endpoint is not configured.") + return None, None + + try: + payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = _reserve_payload(request) + except ValueError as exc: + logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc) + if _should_block_reserve_failure(None): + return None, AIMessage(content=str(exc)) + return None, None + + logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload) + response = await _post_async(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds) + logger.info("[BillingMiddleware] reserve response: %s", response) + if response is None: + if _should_block_reserve_failure(None): + return None, AIMessage(content="Billing reservation request failed.") + return None, None + + if not _is_success_payload(response): + status_code = _extract_response_status(response) + logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response) + if _should_block_reserve_failure(status_code): + return None, AIMessage(content=_reserve_failure_message(status_code)) + return None, None + + frozen_id = _extract_frozen_id(response) + if not frozen_id: + logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response) + if _should_block_reserve_failure(None): + return None, AIMessage(content="Billing reservation response is invalid.") + return 
def _reserve_sync(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Synchronous reserve call; mirrors ``_reserve_async``.

    Returns ``(reserve_ctx, block_message)``: exactly one side is meaningful.
    A non-None ``block_message`` tells the middleware to skip the model call
    and return that AIMessage to the user; otherwise ``reserve_ctx`` (possibly
    None for best-effort continue) carries the frozenId needed for finalize.
    """
    cfg = get_app_config().billing
    if not cfg.reserve_url:
        logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty")
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation endpoint is not configured.")
        return None, None

    try:
        payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = _reserve_payload(request)
    except ValueError as exc:
        # Raised when estimatedOutputTokens cannot be resolved.
        logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content=str(exc))
        return None, None

    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    response = _post_sync(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] reserve response: %s", response)
    if response is None:
        # HTTP failure (connection error, non-2xx, non-dict body).
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation request failed.")
        return None, None

    if not _is_success_payload(response):
        status_code = _extract_response_status(response)
        logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response)
        # Only specific codes (e.g. insufficient balance) block by default.
        if _should_block_reserve_failure(status_code):
            return None, AIMessage(content=_reserve_failure_message(status_code))
        return None, None

    frozen_id = _extract_frozen_id(response)
    if not frozen_id:
        logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation response is invalid.")
        return None, None

    call_id = payload["callId"]
    return (
        _ReserveContext(
            frozen_id=frozen_id,
            call_id=call_id,
            session_id=session_id,
            model_name=model_name,
            estimated_input_tokens=estimated_input_tokens,
            estimated_output_tokens=estimated_output_tokens,
        ),
        None,
    )


def _build_finalize_payload(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> dict[str, Any]:
    """Assemble the finalize body; finalAmount stays 0 so the billing
    platform computes cost from the reported usage token counts."""
    usage = _extract_usage(request, response)
    return {
        "frozenId": reserve_ctx.frozen_id,
        "finalAmount": 0,
        "usageInputTokens": usage.get("input_tokens") if usage else 0,
        "usageOutputTokens": usage.get("output_tokens") if usage else 0,
        "usageTotalTokens": usage.get("total_tokens") if usage else 0,
        "finalizeReason": finalize_reason,
    }


async def _finalize_async(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Best-effort finalize: failures are logged, never raised to the caller."""
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return

    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    result = await _post_async(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] finalize response: %s", result)
    if result is None:
        logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", reserve_ctx.frozen_id)
        return
    if not _is_success_payload(result):
        logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", reserve_ctx.frozen_id, result)
logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty") + return + + payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason) + logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload) + result = _post_sync(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds) + logger.info("[BillingMiddleware] finalize response: %s", result) + if result is None: + logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", reserve_ctx.frozen_id) + return + if not _is_success_payload(result): + logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", reserve_ctx.frozen_id, result) + + +def _extract_thread_id(request: ModelRequest) -> str | None: + context = getattr(request.runtime, "context", None) + thread_id = getattr(context, "thread_id", None) + if isinstance(thread_id, str) and thread_id: + return thread_id + + if isinstance(context, dict): + thread_id = context.get("thread_id") + if isinstance(thread_id, str) and thread_id: + return thread_id + + config = getattr(request.runtime, "config", None) + configurable = getattr(config, "configurable", None) + thread_id = getattr(configurable, "thread_id", None) + if isinstance(thread_id, str) and thread_id: + return thread_id + + if isinstance(config, dict): + thread_id = config.get("configurable", {}).get("thread_id") + if isinstance(thread_id, str) and thread_id: + return thread_id + return None + + +def _extract_model_key_from_runtime(request: ModelRequest) -> str | None: + config = getattr(request.runtime, "config", None) + configurable = getattr(config, "configurable", None) + model_key = getattr(configurable, "model", None) or getattr(configurable, "model_name", None) + if isinstance(model_key, str) and model_key: + return model_key + + if isinstance(config, dict): + configurable = config.get("configurable", {}) + model_key = configurable.get("model") or configurable.get("model_name") + if 
def _resolve_model_name(model_key: str | None) -> str | None:
    """Map an internal model key to its display name for billing payloads.

    Falls back to the key itself when no config entry or display name exists.
    """
    if not model_key:
        return None
    model_cfg = get_app_config().get_model_config(model_key)
    if model_cfg and model_cfg.display_name:
        return model_cfg.display_name
    return model_key


def _resolve_estimated_output_tokens(request: ModelRequest, model_key: str | None) -> int:
    """Resolve ``estimatedOutputTokens`` for a reserve request.

    Resolution order: configured model ``max_tokens`` extra ->
    per-request ``model_settings['max_tokens']`` -> the model instance's own
    ``max_tokens`` attribute -> ``billing.default_estimated_output_tokens``.

    Raises:
        ValueError: when no source yields a positive integer.
    """
    cfg = get_app_config().billing

    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg is not None:
            max_tokens = model_cfg.model_extra.get("max_tokens") if model_cfg.model_extra else None
            if isinstance(max_tokens, int) and max_tokens > 0:
                return max_tokens

    # Fix: read model_settings defensively, consistent with the getattr-based
    # lookups elsewhere in this module — a None/non-dict model_settings must
    # not crash billing with an AttributeError.
    model_settings = getattr(request, "model_settings", None)
    if isinstance(model_settings, dict):
        max_tokens_from_request = model_settings.get("max_tokens")
        if isinstance(max_tokens_from_request, int) and max_tokens_from_request > 0:
            return max_tokens_from_request

    # Fall back to the model instance's own max_tokens attribute
    max_tokens_from_model = getattr(request.model, "max_tokens", None)
    if isinstance(max_tokens_from_model, int) and max_tokens_from_model > 0:
        return max_tokens_from_model

    if cfg.default_estimated_output_tokens is not None:
        return cfg.default_estimated_output_tokens

    raise ValueError("Unable to resolve estimatedOutputTokens from model max_tokens.")
+ return len(latest_text) + + +def _extract_latest_user_text(messages: list[Any]) -> str: + for msg in reversed(messages): + if isinstance(msg, HumanMessage): + content = getattr(msg, "content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if isinstance(part, str): + parts.append(part) + elif isinstance(part, dict): + text = part.get("text") + if isinstance(text, str): + parts.append(text) + return "\n".join(p for p in parts if p) + return str(content) + return "" + + +def _extract_usage(request: ModelRequest, response: ModelCallResult | None) -> dict[str, int] | None: + if response is None: + usage = None + else: + usage = _extract_usage_from_obj(response) + if usage: + return usage + + messages = getattr(response, "messages", None) + usage = _extract_usage_from_messages(messages) + if usage: + return usage + + state = getattr(request, "state", None) + if isinstance(state, dict): + usage = _extract_usage_from_messages(state.get("messages")) + if usage: + return usage + + runtime_context = getattr(request.runtime, "context", None) + if isinstance(runtime_context, dict): + usage = _extract_usage_from_messages(runtime_context.get("messages")) + if usage: + return usage + + return None + + +def _extract_usage_from_messages(messages: object) -> dict[str, int] | None: + if not isinstance(messages, list): + return None + + for msg in reversed(messages): + usage = _extract_usage_from_obj(msg) + if usage: + return usage + + return None + + +def _extract_usage_from_obj(obj: object) -> dict[str, int] | None: + usage_metadata = getattr(obj, "usage_metadata", None) + usage = _normalize_usage_dict(usage_metadata) + if usage: + return usage + + response_metadata = getattr(obj, "response_metadata", None) + if isinstance(response_metadata, dict): + usage = _normalize_usage_dict(response_metadata.get("usage")) + if usage: + return usage + usage = 
_normalize_usage_dict(response_metadata.get("token_usage")) + if usage: + return usage + + additional_kwargs = getattr(obj, "additional_kwargs", None) + if isinstance(additional_kwargs, dict): + usage = _normalize_usage_dict(additional_kwargs.get("usage")) + if usage: + return usage + usage = _normalize_usage_dict(additional_kwargs.get("token_usage")) + if usage: + return usage + + return None + + +def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None: + if not isinstance(raw_usage, dict): + return None + + input_tokens = raw_usage.get("input_tokens") + if input_tokens is None: + input_tokens = raw_usage.get("prompt_tokens") + + output_tokens = raw_usage.get("output_tokens") + if output_tokens is None: + output_tokens = raw_usage.get("completion_tokens") + + total_tokens = raw_usage.get("total_tokens") + if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int): + total_tokens = input_tokens + output_tokens + + if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)): + return None + + return { + "input_tokens": int(input_tokens or 0), + "output_tokens": int(output_tokens or 0), + "total_tokens": int(total_tokens or 0), + } + + +async def _post_async(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None: + try: + import httpx + + async with httpx.AsyncClient(timeout=timeout_seconds) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + if isinstance(data, dict): + return data + return None + except Exception as exc: + logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc) + return None + + +def _post_sync(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None: + try: + import httpx + + with httpx.Client(timeout=timeout_seconds) as client: + response = client.post(url, 
headers=headers, json=payload) + response.raise_for_status() + data = response.json() + if isinstance(data, dict): + return data + return None + except Exception as exc: + logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc) + return None diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py index 52be28bf..096e6dac 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py @@ -91,6 +91,14 @@ def _build_runtime_middlewares( middlewares.append(DanglingToolCallMiddleware()) + from deerflow.config.app_config import get_app_config + + billing_cfg = get_app_config().billing + if billing_cfg.enabled and (include_uploads or billing_cfg.include_subagents): + from deerflow.agents.middlewares.billing_middleware import BillingMiddleware + + middlewares.append(BillingMiddleware()) + middlewares.append(LLMErrorHandlingMiddleware()) # Guardrail middleware (if configured) diff --git a/backend/packages/harness/deerflow/config/__init__.py b/backend/packages/harness/deerflow/config/__init__.py index aa379f2a..c41be373 100644 --- a/backend/packages/harness/deerflow/config/__init__.py +++ b/backend/packages/harness/deerflow/config/__init__.py @@ -1,4 +1,5 @@ from .app_config import get_app_config +from .billing_config import BillingConfig from .extensions_config import ExtensionsConfig, get_extensions_config from .memory_config import MemoryConfig, get_memory_config from .paths import Paths, get_paths @@ -13,6 +14,7 @@ from .tracing_config import ( __all__ = [ "get_app_config", + "BillingConfig", "Paths", "get_paths", "SkillsConfig", diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index cd233623..f15e0304 100644 --- 
class BillingConfig(BaseModel):
    """Configuration for external billing reservation/finalization calls."""

    # Master switch: when False the middleware passes model calls straight through.
    enabled: bool = Field(default=False, description="Enable external billing middleware.")
    include_subagents: bool = Field(
        default=False,
        description="Whether billing applies to subagent model calls as well.",
    )
    # Legacy policy knob; only consulted when block_only_specific_reserve_codes is False.
    fail_closed: bool = Field(
        default=True,
        description="Block model calls when reserve request fails or balance is insufficient.",
    )
    block_only_specific_reserve_codes: bool = Field(
        default=True,
        description=(
            "When true, only reserve responses with codes in blocking_reserve_codes block model calls. "
            "When false, fallback to fail_closed behavior for all reserve failures."
        ),
    )
    blocking_reserve_codes: list[int] = Field(
        default_factory=lambda: [-1104, -1106],
        description="Reserve response codes that should block model calls when block_only_specific_reserve_codes is enabled.",
    )
    frozen_type: int = Field(
        default=1,
        ge=1,
        description="Frozen type sent to the platform. Current flow uses 1 for token billing.",
    )
    # Both URLs default to None; the middleware logs and (per policy) skips or
    # blocks when an endpoint is unset.
    reserve_url: str | None = Field(
        default=None,
        description="HTTP(S) endpoint for creating frozen reservations.",
    )
    finalize_url: str | None = Field(
        default=None,
        description="HTTP(S) endpoint for finalizing frozen reservations.",
    )
    headers: dict[str, str] = Field(
        default_factory=dict,
        description="Extra HTTP headers included in reserve/finalize requests.",
    )
    timeout_seconds: float = Field(
        default=10.0,
        gt=0,
        le=120,
        description="HTTP request timeout for reserve/finalize calls.",
    )
    default_expire_seconds: int = Field(
        default=1800,
        ge=60,
        le=86400,
        description="Default reservation expiration seconds when expireAt is included.",
    )
    default_estimated_output_tokens: int | None = Field(
        default=None,
        ge=1,
        description="Fallback estimatedOutputTokens when model max_tokens is unavailable.",
    )
must do it manually) - runtime = Runtime(context={"thread_id": thread_id}, store=store) + runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store) # If the caller already set a ``context`` key (LangGraph >= 0.6.0 # prefers it over ``configurable`` for thread-level data), make # sure ``thread_id`` is available there too. if "context" in config and isinstance(config["context"], dict): config["context"].setdefault("thread_id", thread_id) + config["context"].setdefault("run_id", run_id) config.setdefault("configurable", {})["__pregel_runtime"] = runtime runnable_config = RunnableConfig(**config) diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index 8e1b1513..536ae024 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -226,15 +226,18 @@ class SubagentExecutor: try: agent = self._create_agent() state = self._build_initial_state(task) + subagent_model_name = _get_model_name(self.config, self.parent_model) # Build config with thread_id for sandbox access and recursion limit run_config: RunnableConfig = { "recursion_limit": self.config.max_turns, } context = {} + configurable: dict[str, Any] = {"model_name": subagent_model_name} if self.thread_id: - run_config["configurable"] = {"thread_id": self.thread_id} + configurable["thread_id"] = self.thread_id context["thread_id"] = self.thread_id + run_config["configurable"] = configurable logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}") diff --git a/backend/tests/test_billing_middleware.py b/backend/tests/test_billing_middleware.py new file mode 100644 index 00000000..193117ee --- /dev/null +++ b/backend/tests/test_billing_middleware.py @@ -0,0 +1,241 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest +from 
langchain_core.messages import AIMessage, HumanMessage + +from deerflow.agents.middlewares.billing_middleware import BillingMiddleware + + +def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True): + billing = SimpleNamespace( + enabled=enabled, + include_subagents=include_subagents, + fail_closed=True, + block_only_specific_reserve_codes=True, + blocking_reserve_codes=[-1104, -1106], + frozen_type=1, + reserve_url="http://billing.local/accountFrozen/frozen", + finalize_url="http://billing.local/accountFrozen/release", + headers={"Authorization": "Bearer x"}, + timeout_seconds=3.0, + default_expire_seconds=1800, + default_estimated_output_tokens=None, + ) + + model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096}) + return SimpleNamespace( + billing=billing, + get_model_config=lambda name: model_cfg if name == "gpt-4" else None, + ) + + +def _request_with_latest_user_text(text: str): + request = MagicMock() + request.messages = [HumanMessage(content="old"), HumanMessage(content=text)] + request.model_settings = {} + request.runtime = SimpleNamespace( + config={"configurable": {"thread_id": "thread-1", "model_name": "gpt-4"}}, + context={"thread_id": "thread-1"}, + ) + return request + + +@pytest.mark.anyio +async def test_awrap_model_call_uses_estimated_tokens_and_finalizes(monkeypatch): + from langchain_core.runnables.config import var_child_runnable_config + + from deerflow.agents.middlewares import billing_middleware as bm + + monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config()) + + seen_payloads = [] + + async def fake_post(url, headers, payload, timeout_seconds): + seen_payloads.append((url, headers, payload, timeout_seconds)) + if url.endswith("/frozen"): + return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}} + return {"status": 1000, "message": "ok", "data": {}} + + monkeypatch.setattr(bm, "_post_async", fake_post) + + middleware = BillingMiddleware() + request = 
_request_with_latest_user_text("hello world") + handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 11, "output_tokens": 22, "total_tokens": 33})) + + token = var_child_runnable_config.set({"run_id": "run-1"}) + try: + result = await middleware.awrap_model_call(request, handler) + finally: + var_child_runnable_config.reset(token) + + assert isinstance(result, AIMessage) + assert len(seen_payloads) == 2 + + reserve_payload = seen_payloads[0][2] + assert reserve_payload["callId"] == "run-1" + assert reserve_payload["frozenType"] == 1 + assert reserve_payload["estimatedInputTokens"] == len("hello world") + assert reserve_payload["estimatedOutputTokens"] == 4096 + assert "frozenAmount" not in reserve_payload + + finalize_payload = seen_payloads[1][2] + assert finalize_payload["frozenId"] == "frozen-123" + assert finalize_payload["finalAmount"] == 0 + assert finalize_payload["usageInputTokens"] == 11 + assert finalize_payload["usageOutputTokens"] == 22 + assert finalize_payload["usageTotalTokens"] == 33 + assert finalize_payload["finalizeReason"] == "success" + + +@pytest.mark.anyio +async def test_awrap_model_call_fail_closed_on_insufficient_balance(monkeypatch): + from deerflow.agents.middlewares import billing_middleware as bm + + monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config()) + + async def fake_post(url, headers, payload, timeout_seconds): + return {"status": -1106, "message": "insufficient balance", "data": {}} + + monkeypatch.setattr(bm, "_post_async", fake_post) + + middleware = BillingMiddleware() + request = _request_with_latest_user_text("question") + handler = AsyncMock(return_value=AIMessage(content="should not run")) + + result = await middleware.awrap_model_call(request, handler) + + assert isinstance(result, AIMessage) + assert "insufficient" in str(result.content).lower() + handler.assert_not_awaited() + + +@pytest.mark.anyio +async def 
test_awrap_model_call_finalize_uses_state_messages_usage_when_response_missing_usage(monkeypatch): + from deerflow.agents.middlewares import billing_middleware as bm + + monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config()) + + seen_payloads = [] + + async def fake_post(url, headers, payload, timeout_seconds): + seen_payloads.append((url, headers, payload, timeout_seconds)) + if url.endswith("/frozen"): + return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}} + return {"status": 1000, "message": "ok", "data": {}} + + monkeypatch.setattr(bm, "_post_async", fake_post) + + middleware = BillingMiddleware() + request = _request_with_latest_user_text("hello world") + request.state = { + "messages": [ + HumanMessage(content="hello world"), + AIMessage(content="ok", usage_metadata={"input_tokens": 101, "output_tokens": 202, "total_tokens": 303}), + ] + } + handler = AsyncMock(return_value=AIMessage(content="ok")) + + result = await middleware.awrap_model_call(request, handler) + + assert isinstance(result, AIMessage) + assert len(seen_payloads) == 2 + + finalize_payload = seen_payloads[1][2] + assert finalize_payload["frozenId"] == "frozen-123" + assert finalize_payload["usageInputTokens"] == 101 + assert finalize_payload["usageOutputTokens"] == 202 + assert finalize_payload["usageTotalTokens"] == 303 + + +@pytest.mark.anyio +async def test_awrap_model_call_does_not_block_on_non_blocking_reserve_code(monkeypatch): + from deerflow.agents.middlewares import billing_middleware as bm + + monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config()) + + async def fake_post(url, headers, payload, timeout_seconds): + if url.endswith("/frozen"): + return {"status": 5001, "message": "platform busy", "data": {}} + return {"status": 1000, "message": "ok", "data": {}} + + monkeypatch.setattr(bm, "_post_async", fake_post) + + middleware = BillingMiddleware() + request = _request_with_latest_user_text("question") + handler = 
AsyncMock(return_value=AIMessage(content="model-ran")) + + result = await middleware.awrap_model_call(request, handler) + + assert isinstance(result, AIMessage) + assert result.content == "model-ran" + handler.assert_awaited_once() + + +@pytest.mark.anyio +async def test_awrap_model_call_uses_runnable_config_run_id(monkeypatch): + """run_id is sourced from var_child_runnable_config, which LangGraph populates + via langgraph_api/stream.py during graph node execution.""" + from langchain_core.runnables.config import var_child_runnable_config + + from deerflow.agents.middlewares import billing_middleware as bm + + monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config()) + + seen_payloads = [] + + async def fake_post(url, headers, payload, timeout_seconds): + seen_payloads.append((url, headers, payload, timeout_seconds)) + if url.endswith("/frozen"): + return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}} + return {"status": 1000, "message": "ok", "data": {}} + + monkeypatch.setattr(bm, "_post_async", fake_post) + + middleware = BillingMiddleware() + request = _request_with_latest_user_text("hello world") + handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3})) + + token = var_child_runnable_config.set({"run_id": "run-from-ctx"}) + try: + result = await middleware.awrap_model_call(request, handler) + finally: + var_child_runnable_config.reset(token) + + assert isinstance(result, AIMessage) + reserve_payload = seen_payloads[0][2] + assert reserve_payload["callId"] == "run-from-ctx" + + +@pytest.mark.anyio +async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch): + """Fallback: run_id from langgraph_api.logging.worker_config when var_child_runnable_config is unset.""" + from deerflow.agents.middlewares import billing_middleware as bm + + monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config()) + + seen_payloads = [] + + 
async def fake_post(url, headers, payload, timeout_seconds): + seen_payloads.append((url, headers, payload, timeout_seconds)) + if url.endswith("/frozen"): + return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}} + return {"status": 1000, "message": "ok", "data": {}} + + monkeypatch.setattr(bm, "_post_async", fake_post) + + import langgraph_api.logging as lg_logging + + middleware = BillingMiddleware() + request = _request_with_latest_user_text("hello world") + handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3})) + + token = lg_logging.worker_config.set({"run_id": "run-from-worker"}) + try: + result = await middleware.awrap_model_call(request, handler) + finally: + lg_logging.worker_config.reset(token) + + assert isinstance(result, AIMessage) + reserve_payload = seen_payloads[0][2] + assert reserve_payload["callId"] == "run-from-worker" diff --git a/config.example.yaml b/config.example.yaml index d1deda46..31e90a0c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -12,7 +12,7 @@ # ============================================================================ # Bump this number when the config schema changes. # Run `make config-upgrade` to merge new fields into your local config.yaml. -config_version: 5 +config_version: 7 # ============================================================================ # Logging @@ -20,6 +20,35 @@ config_version: 5 # Log level for deerflow modules (debug/info/warning/error) log_level: info +# ============================================================================ +# Billing Reservation/Finalization +# ============================================================================ +# Reserve before each LLM call and finalize after call completion. +# Keep this independent from token_usage reporting. 
+
+billing:
+  enabled: false
+  include_subagents: false
+  fail_closed: true
+  # true: only block when reserve returns a code in blocking_reserve_codes
+  # false: fall back to fail_closed behavior for all reserve failures
+  block_only_specific_reserve_codes: true
+  blocking_reserve_codes: [-1104, -1106]
+  frozen_type: 1
+  timeout_seconds: 10
+  default_expire_seconds: 1800
+
+  # When model config has no max_tokens, this fallback is used for
+  # estimatedOutputTokens. If unset and fail_closed=true, billing blocks calls.
+  # default_estimated_output_tokens: 4096
+
+  # reserve_url: "http://localhost:19001/accountFrozen/frozen"
+  # finalize_url: "http://localhost:19001/accountFrozen/release"
+
+  # headers:
+  #   Authorization: "Bearer your-secret-token"
+  #   X-App-Id: "deer-flow"
+
 # ============================================================================
 # Token Usage Tracking
 # ============================================================================