Merge remote-tracking branch 'origin/git-main' into feat/git-main-frondend-intergretion-oldhash-20260408-165134
This commit is contained in:
commit
7d579e695c
|
|
@ -284,6 +284,11 @@ async def start_run(
|
||||||
graph_input = normalize_input(body.input)
|
graph_input = normalize_input(body.input)
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
|
|
||||||
|
if "configurable" in config and isinstance(config["configurable"], dict):
|
||||||
|
config["configurable"].setdefault("run_id", record.run_id)
|
||||||
|
if "context" in config and isinstance(config["context"], dict):
|
||||||
|
config["context"].setdefault("run_id", record.run_id)
|
||||||
|
|
||||||
# Merge DeerFlow-specific context overrides into configurable.
|
# Merge DeerFlow-specific context overrides into configurable.
|
||||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
|
|
|
||||||
|
|
@ -294,6 +294,45 @@ title:
|
||||||
max_words: 6
|
max_words: 6
|
||||||
max_chars: 60
|
max_chars: 60
|
||||||
model_name: null # Use first model in list
|
model_name: null # Use first model in list
|
||||||
|
|
||||||
|
### Billing Reservation/Finalization
|
||||||
|
|
||||||
|
External billing can reserve before each model call and finalize after completion.
|
||||||
|
This is independent from `token_usage` reporting.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
billing:
|
||||||
|
enabled: false
|
||||||
|
include_subagents: false
|
||||||
|
fail_closed: true
|
||||||
|
block_only_specific_reserve_codes: true
|
||||||
|
blocking_reserve_codes: [-1104, -1106]
|
||||||
|
frozen_type: 1
|
||||||
|
reserve_url: http://localhost:19001/accountFrozen/frozen
|
||||||
|
finalize_url: http://localhost:19001/accountFrozen/release
|
||||||
|
timeout_seconds: 10
|
||||||
|
default_expire_seconds: 1800
|
||||||
|
# default_estimated_output_tokens: 4096
|
||||||
|
# headers:
|
||||||
|
# Authorization: Bearer your-secret-token
|
||||||
|
```
|
||||||
|
|
||||||
|
For `frozen_type=1` (token billing):
|
||||||
|
- Reserve request sends `estimatedInputTokens` and `estimatedOutputTokens`.
|
||||||
|
- `estimatedInputTokens` is estimated with a simple string-length rule from the latest user input.
|
||||||
|
- `estimatedOutputTokens` is resolved from model `max_tokens`.
|
||||||
|
- Finalize request keeps `finalAmount=0`; billing platform computes final cost from
|
||||||
|
`usageInputTokens`/`usageOutputTokens`/`usageTotalTokens`.
|
||||||
|
|
||||||
|
Reserve blocking policy:
|
||||||
|
- With `block_only_specific_reserve_codes=true` (recommended), model calls are blocked
|
||||||
|
only when reserve API returns a code in `blocking_reserve_codes` (default `[-1104, -1106]`).
|
||||||
|
- For all other failures (reserve/finalize HTTP failure, 5xx, invalid reserve response),
|
||||||
|
DeerFlow logs warnings and continues model calls.
|
||||||
|
- Set `block_only_specific_reserve_codes=false` to restore legacy `fail_closed` behavior.
|
||||||
|
|
||||||
|
If model `max_tokens` is unavailable, DeerFlow uses `default_estimated_output_tokens`
|
||||||
|
when configured.
|
||||||
```
|
```
|
||||||
|
|
||||||
### GitHub API Token (Optional for GitHub Deep Research Skill)
|
### GitHub API Token (Optional for GitHub Deep Research Skill)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,629 @@
|
||||||
|
"""Middleware for external billing reservation/finalization per model call."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, override
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from langchain.agents import AgentState
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from langgraph.errors import GraphBubbleUp
|
||||||
|
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SUCCESS_STATUS_CODES = {200, 1000}
|
||||||
|
_INSUFFICIENT_BALANCE_CODE = -1106
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ReserveContext:
|
||||||
|
frozen_id: str
|
||||||
|
call_id: str
|
||||||
|
session_id: str | None
|
||||||
|
model_name: str | None
|
||||||
|
estimated_input_tokens: int
|
||||||
|
estimated_output_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class BillingMiddleware(AgentMiddleware[AgentState]):
    """Reserve before model call and finalize after completion."""

    @override
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        # Billing disabled: pass straight through without reserve/finalize.
        if not get_app_config().billing.enabled:
            return handler(request)

        reserve_ctx, block_result = _reserve_sync(request)
        if block_result is not None:
            # Reservation was rejected and policy requires blocking the call.
            return block_result

        result: ModelCallResult | None = None
        reason = "success"
        try:
            result = handler(request)
            return result
        except GraphBubbleUp:
            reason = "cancel"
            raise
        except TimeoutError:
            reason = "timeout"
            raise
        except Exception:
            reason = "error"
            raise
        finally:
            # Always release the reservation, even when the call failed.
            if reserve_ctx is not None:
                _finalize_sync(request, reserve_ctx, result, reason)

    @override
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        # Billing disabled: pass straight through without reserve/finalize.
        if not get_app_config().billing.enabled:
            return await handler(request)

        reserve_ctx, block_result = await _reserve_async(request)
        if block_result is not None:
            # Reservation was rejected and policy requires blocking the call.
            return block_result

        result: ModelCallResult | None = None
        reason = "success"
        try:
            result = await handler(request)
            return result
        except GraphBubbleUp:
            reason = "cancel"
            raise
        except TimeoutError:
            reason = "timeout"
            raise
        except Exception:
            reason = "error"
            raise
        finally:
            # Always release the reservation, even when the call failed.
            if reserve_ctx is not None:
                await _finalize_async(request, reserve_ctx, result, reason)
|
||||||
|
|
||||||
|
|
||||||
|
def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None, str | None, int, int]:
    """Build the reserve request body plus the context values reused at finalize.

    Returns ``(payload, session_id, model_name, estimated_input_tokens,
    estimated_output_tokens)``. Raises ValueError when estimatedOutputTokens
    cannot be resolved (propagated from _resolve_estimated_output_tokens).
    """
    cfg = get_app_config().billing

    session_id = _extract_thread_id(request)
    run_id = _extract_run_id(request)
    model_key = _extract_model_key_from_runtime(request)
    model_name = _resolve_model_name(model_key)

    estimated_input_tokens = _estimate_input_tokens(request.messages)
    estimated_output_tokens = _resolve_estimated_output_tokens(request, model_key)

    # Prefer the LangGraph run_id as the call id; fall back to a fresh uuid.
    call_id = run_id or str(uuid4())
    if not run_id:
        runtime = getattr(request, "runtime", None)
        runtime_context = getattr(runtime, "context", None)
        runtime_config = getattr(runtime, "config", None)
        # Diagnostic detail helps trace why run_id propagation failed.
        context_keys = sorted(runtime_context.keys()) if isinstance(runtime_context, dict) else []
        config_keys = sorted(runtime_config.keys()) if isinstance(runtime_config, dict) else []
        logger.warning(
            "[BillingMiddleware] run_id missing in runtime; fallback callId=%s context_type=%s config_type=%s context_keys=%s config_keys=%s",
            call_id,
            type(runtime_context).__name__ if runtime_context is not None else "None",
            type(runtime_config).__name__ if runtime_config is not None else "None",
            context_keys,
            config_keys,
        )
    logger.info(
        "[BillingMiddleware] id mapping: thread_id=%s run_id=%s call_id=%s model_name=%s",
        session_id,
        run_id,
        call_id,
        model_name,
    )

    expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
    payload: dict[str, Any] = {
        "sessionId": session_id,
        "callId": call_id,
        "modelName": model_name,
        "frozenType": cfg.frozen_type,
        "estimatedInputTokens": estimated_input_tokens,
        "estimatedOutputTokens": estimated_output_tokens,
        "expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"),
    }
    return payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_run_id(request: ModelRequest) -> str | None: # noqa: ARG001
|
||||||
|
# Primary: LangGraph injects run_id into the top-level RunnableConfig
|
||||||
|
# (langgraph_api/stream.py:218) and propagates it via var_child_runnable_config
|
||||||
|
# throughout graph node execution.
|
||||||
|
try:
|
||||||
|
from langchain_core.runnables.config import var_child_runnable_config
|
||||||
|
|
||||||
|
lc_config = var_child_runnable_config.get()
|
||||||
|
if isinstance(lc_config, dict):
|
||||||
|
run_id = lc_config.get("run_id")
|
||||||
|
if run_id is not None:
|
||||||
|
return str(run_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: LangGraph API worker sets run_id via set_logging_context() before
|
||||||
|
# astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139).
|
||||||
|
try:
|
||||||
|
from langgraph_api.logging import worker_config as lg_worker_config
|
||||||
|
|
||||||
|
worker_ctx = lg_worker_config.get()
|
||||||
|
if isinstance(worker_ctx, dict):
|
||||||
|
run_id = worker_ctx.get("run_id")
|
||||||
|
if isinstance(run_id, str) and run_id:
|
||||||
|
return run_id
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _reserve_failure_message(status_code: int | None) -> str:
    """Map a reserve-API status code to the user-facing block message."""
    blocking = _blocking_reserve_code_set()
    return (
        "The account balance is insufficient for this model call."
        if status_code in blocking
        else "Billing reservation failed. Please try again later."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _blocking_reserve_code_set() -> set[int]:
    """Reserve-API codes that must hard-block the model call, as ints."""
    codes = get_app_config().billing.blocking_reserve_codes
    return {int(code) for code in codes}
|
||||||
|
|
||||||
|
|
||||||
|
def _should_block_reserve_failure(status_code: int | None) -> bool:
    """Decide whether a failed/rejected reservation blocks the model call.

    With block_only_specific_reserve_codes enabled (the recommended mode),
    only the configured codes block; otherwise the legacy fail_closed flag
    decides for every failure.
    """
    cfg = get_app_config().billing
    if not cfg.block_only_specific_reserve_codes:
        return cfg.fail_closed
    return status_code in _blocking_reserve_code_set()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_frozen_id(payload: dict[str, Any]) -> str | None:
|
||||||
|
data = payload.get("data")
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
frozen_id = data.get("frozenId")
|
||||||
|
if isinstance(frozen_id, str) and frozen_id:
|
||||||
|
return frozen_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_response_status(payload: dict[str, Any]) -> int | None:
|
||||||
|
status = payload.get("status")
|
||||||
|
if isinstance(status, int):
|
||||||
|
return status
|
||||||
|
|
||||||
|
# Backward compatibility with old response schema
|
||||||
|
code = payload.get("code")
|
||||||
|
if isinstance(code, int):
|
||||||
|
return code
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_success_payload(payload: dict[str, Any]) -> bool:
    """True when a billing response indicates success.

    Success is a status/code in _SUCCESS_STATUS_CODES, or the legacy
    boolean ``success`` flag set to exactly True.
    """
    status = _extract_response_status(payload)
    if isinstance(status, int) and status in _SUCCESS_STATUS_CODES:
        return True
    # Backward compatibility with old response schema
    return payload.get("success") is True
|
||||||
|
|
||||||
|
|
||||||
|
async def _reserve_async(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Reserve billing for one model call (async variant).

    Returns ``(context, None)`` on success, ``(None, block_message)`` when the
    call must be blocked per policy, and ``(None, None)`` when the reservation
    failed but the call may proceed (best-effort mode).
    """
    cfg = get_app_config().billing
    if not cfg.reserve_url:
        logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty")
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation endpoint is not configured.")
        return None, None

    try:
        payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = _reserve_payload(request)
    except ValueError as exc:
        logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content=str(exc))
        return None, None

    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    response = await _post_async(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] reserve response: %s", response)

    # No response at all (network/HTTP failure).
    if response is None:
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation request failed.")
        return None, None

    # Reserve API explicitly rejected the reservation.
    if not _is_success_payload(response):
        status_code = _extract_response_status(response)
        logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response)
        if _should_block_reserve_failure(status_code):
            return None, AIMessage(content=_reserve_failure_message(status_code))
        return None, None

    # Success payload must carry the frozenId we finalize against later.
    frozen_id = _extract_frozen_id(response)
    if not frozen_id:
        logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation response is invalid.")
        return None, None

    ctx = _ReserveContext(
        frozen_id=frozen_id,
        call_id=payload["callId"],
        session_id=session_id,
        model_name=model_name,
        estimated_input_tokens=estimated_input_tokens,
        estimated_output_tokens=estimated_output_tokens,
    )
    return ctx, None
|
||||||
|
|
||||||
|
|
||||||
|
def _reserve_sync(request: ModelRequest) -> tuple[_ReserveContext | None, AIMessage | None]:
    """Reserve billing for one model call (sync variant).

    Returns ``(context, None)`` on success, ``(None, block_message)`` when the
    call must be blocked per policy, and ``(None, None)`` when the reservation
    failed but the call may proceed (best-effort mode).
    """
    cfg = get_app_config().billing
    if not cfg.reserve_url:
        logger.warning("[BillingMiddleware] skip reserve: reserve_url is empty")
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation endpoint is not configured.")
        return None, None

    try:
        payload, session_id, model_name, estimated_input_tokens, estimated_output_tokens = _reserve_payload(request)
    except ValueError as exc:
        logger.warning("[BillingMiddleware] reserve payload invalid: %s", exc)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content=str(exc))
        return None, None

    logger.info("[BillingMiddleware] reserve request: url=%s payload=%s", cfg.reserve_url, payload)
    response = _post_sync(cfg.reserve_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] reserve response: %s", response)

    # No response at all (network/HTTP failure).
    if response is None:
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation request failed.")
        return None, None

    # Reserve API explicitly rejected the reservation.
    if not _is_success_payload(response):
        status_code = _extract_response_status(response)
        logger.warning("[BillingMiddleware] reserve rejected: status=%s payload=%s", status_code, response)
        if _should_block_reserve_failure(status_code):
            return None, AIMessage(content=_reserve_failure_message(status_code))
        return None, None

    # Success payload must carry the frozenId we finalize against later.
    frozen_id = _extract_frozen_id(response)
    if not frozen_id:
        logger.warning("[BillingMiddleware] reserve response missing frozenId: %s", response)
        if _should_block_reserve_failure(None):
            return None, AIMessage(content="Billing reservation response is invalid.")
        return None, None

    ctx = _ReserveContext(
        frozen_id=frozen_id,
        call_id=payload["callId"],
        session_id=session_id,
        model_name=model_name,
        estimated_input_tokens=estimated_input_tokens,
        estimated_output_tokens=estimated_output_tokens,
    )
    return ctx, None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_finalize_payload(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> dict[str, Any]:
    """Build the finalize request body.

    ``finalAmount`` is intentionally 0: the billing platform computes the
    final cost from the usage token counts.
    """
    usage = _extract_usage(request, response) or {}
    return {
        "frozenId": reserve_ctx.frozen_id,
        "finalAmount": 0,
        "usageInputTokens": usage.get("input_tokens", 0),
        "usageOutputTokens": usage.get("output_tokens", 0),
        "usageTotalTokens": usage.get("total_tokens", 0),
        "finalizeReason": finalize_reason,
    }
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_async(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Release a billing reservation after the model call (async variant).

    Failures are logged only; finalization never raises into the call path.
    """
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return

    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    result = await _post_async(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] finalize response: %s", result)

    if result is None:
        logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", reserve_ctx.frozen_id)
    elif not _is_success_payload(result):
        logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", reserve_ctx.frozen_id, result)
|
||||||
|
|
||||||
|
|
||||||
|
def _finalize_sync(
    request: ModelRequest,
    reserve_ctx: _ReserveContext,
    response: ModelCallResult | None,
    finalize_reason: str,
) -> None:
    """Release a billing reservation after the model call (sync variant).

    Failures are logged only; finalization never raises into the call path.
    """
    cfg = get_app_config().billing
    if not cfg.finalize_url:
        logger.warning("[BillingMiddleware] skip finalize: finalize_url is empty")
        return

    payload = _build_finalize_payload(request, reserve_ctx, response, finalize_reason)
    logger.info("[BillingMiddleware] finalize request: url=%s payload=%s", cfg.finalize_url, payload)
    result = _post_sync(cfg.finalize_url, cfg.headers, payload, cfg.timeout_seconds)
    logger.info("[BillingMiddleware] finalize response: %s", result)

    if result is None:
        logger.warning("[BillingMiddleware] finalize failed without response: frozenId=%s", reserve_ctx.frozen_id)
    elif not _is_success_payload(result):
        logger.warning("[BillingMiddleware] finalize rejected: frozenId=%s payload=%s", reserve_ctx.frozen_id, result)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_thread_id(request: ModelRequest) -> str | None:
|
||||||
|
context = getattr(request.runtime, "context", None)
|
||||||
|
thread_id = getattr(context, "thread_id", None)
|
||||||
|
if isinstance(thread_id, str) and thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
if isinstance(context, dict):
|
||||||
|
thread_id = context.get("thread_id")
|
||||||
|
if isinstance(thread_id, str) and thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
config = getattr(request.runtime, "config", None)
|
||||||
|
configurable = getattr(config, "configurable", None)
|
||||||
|
thread_id = getattr(configurable, "thread_id", None)
|
||||||
|
if isinstance(thread_id, str) and thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
if isinstance(config, dict):
|
||||||
|
thread_id = config.get("configurable", {}).get("thread_id")
|
||||||
|
if isinstance(thread_id, str) and thread_id:
|
||||||
|
return thread_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_model_key_from_runtime(request: ModelRequest) -> str | None:
|
||||||
|
config = getattr(request.runtime, "config", None)
|
||||||
|
configurable = getattr(config, "configurable", None)
|
||||||
|
model_key = getattr(configurable, "model", None) or getattr(configurable, "model_name", None)
|
||||||
|
if isinstance(model_key, str) and model_key:
|
||||||
|
return model_key
|
||||||
|
|
||||||
|
if isinstance(config, dict):
|
||||||
|
configurable = config.get("configurable", {})
|
||||||
|
model_key = configurable.get("model") or configurable.get("model_name")
|
||||||
|
if isinstance(model_key, str) and model_key:
|
||||||
|
return model_key
|
||||||
|
# Fall back to the model instance's own identifier
|
||||||
|
model_name = getattr(request.model, "model_name", None)
|
||||||
|
if isinstance(model_name, str) and model_name:
|
||||||
|
return model_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_model_name(model_key: str | None) -> str | None:
|
||||||
|
if not model_key:
|
||||||
|
return None
|
||||||
|
model_cfg = get_app_config().get_model_config(model_key)
|
||||||
|
if model_cfg and model_cfg.display_name:
|
||||||
|
return model_cfg.display_name
|
||||||
|
return model_key
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_estimated_output_tokens(request: ModelRequest, model_key: str | None) -> int:
    """Resolve estimatedOutputTokens for the reserve request.

    Resolution order: model config's extra "max_tokens", the request's
    model_settings, the model instance's max_tokens attribute, then the
    configured default. Raises ValueError when none is available.
    """
    cfg = get_app_config().billing

    if model_key:
        model_cfg = get_app_config().get_model_config(model_key)
        if model_cfg is not None:
            extra = model_cfg.model_extra or {}
            candidate = extra.get("max_tokens")
            if isinstance(candidate, int) and candidate > 0:
                return candidate

    candidate = request.model_settings.get("max_tokens")
    if isinstance(candidate, int) and candidate > 0:
        return candidate

    # Fall back to the model instance's own max_tokens attribute
    candidate = getattr(request.model, "max_tokens", None)
    if isinstance(candidate, int) and candidate > 0:
        return candidate

    if cfg.default_estimated_output_tokens is not None:
        return cfg.default_estimated_output_tokens

    raise ValueError("Unable to resolve estimatedOutputTokens from model max_tokens.")
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_input_tokens(messages: list[Any]) -> int:
    """Estimate input tokens from the latest user message.

    Product requirement: use simple string-length estimation for input
    tokens rather than a real tokenizer. Returns 0 when no user text exists.
    """
    latest_text = _extract_latest_user_text(messages)
    return len(latest_text) if latest_text else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_latest_user_text(messages: list[Any]) -> str:
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
content = getattr(msg, "content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, str):
|
||||||
|
parts.append(part)
|
||||||
|
elif isinstance(part, dict):
|
||||||
|
text = part.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "\n".join(p for p in parts if p)
|
||||||
|
return str(content)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_usage(request: ModelRequest, response: ModelCallResult | None) -> dict[str, int] | None:
    """Locate token usage for this call, probing several places in order.

    Checks the response object itself, its messages, the request state's
    messages, then the runtime context's messages. Returns None when no
    usage information is found anywhere.
    """
    if response is not None:
        usage = _extract_usage_from_obj(response)
        if usage:
            return usage

    usage = _extract_usage_from_messages(getattr(response, "messages", None))
    if usage:
        return usage

    state = getattr(request, "state", None)
    if isinstance(state, dict):
        usage = _extract_usage_from_messages(state.get("messages"))
        if usage:
            return usage

    runtime_context = getattr(request.runtime, "context", None)
    if isinstance(runtime_context, dict):
        usage = _extract_usage_from_messages(runtime_context.get("messages"))
        if usage:
            return usage

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_usage_from_messages(messages: object) -> dict[str, int] | None:
|
||||||
|
if not isinstance(messages, list):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for msg in reversed(messages):
|
||||||
|
usage = _extract_usage_from_obj(msg)
|
||||||
|
if usage:
|
||||||
|
return usage
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_usage_from_obj(obj: object) -> dict[str, int] | None:
    """Probe one message/response object for usage metadata.

    Looks at usage_metadata first, then response_metadata and
    additional_kwargs dicts under the "usage" and "token_usage" keys —
    the spots different providers stash token counts in.
    """
    if usage := _normalize_usage_dict(getattr(obj, "usage_metadata", None)):
        return usage

    for attr in ("response_metadata", "additional_kwargs"):
        container = getattr(obj, attr, None)
        if not isinstance(container, dict):
            continue
        for key in ("usage", "token_usage"):
            if usage := _normalize_usage_dict(container.get(key)):
                return usage

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_usage_dict(raw_usage: object) -> dict[str, int] | None:
|
||||||
|
if not isinstance(raw_usage, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
input_tokens = raw_usage.get("input_tokens")
|
||||||
|
if input_tokens is None:
|
||||||
|
input_tokens = raw_usage.get("prompt_tokens")
|
||||||
|
|
||||||
|
output_tokens = raw_usage.get("output_tokens")
|
||||||
|
if output_tokens is None:
|
||||||
|
output_tokens = raw_usage.get("completion_tokens")
|
||||||
|
|
||||||
|
total_tokens = raw_usage.get("total_tokens")
|
||||||
|
if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
|
||||||
|
total_tokens = input_tokens + output_tokens
|
||||||
|
|
||||||
|
if not any(isinstance(v, int) for v in (input_tokens, output_tokens, total_tokens)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_tokens": int(input_tokens or 0),
|
||||||
|
"output_tokens": int(output_tokens or 0),
|
||||||
|
"total_tokens": int(total_tokens or 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _post_async(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """POST JSON asynchronously and return the decoded dict body, or None.

    Best-effort by design: any failure (import, network, HTTP status,
    decoding, non-dict body) is logged and swallowed so billing problems
    never crash the model call path.
    """
    try:
        import httpx  # local import: keep httpx optional at module import time

        async with httpx.AsyncClient(timeout=timeout_seconds) as client:
            resp = await client.post(url, headers=headers, json=payload)
            resp.raise_for_status()
            body = resp.json()
            return body if isinstance(body, dict) else None
    except Exception as exc:
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def _post_sync(url: str, headers: dict[str, str], payload: dict[str, Any], timeout_seconds: float) -> dict[str, Any] | None:
    """POST JSON synchronously and return the decoded dict body, or None.

    Best-effort by design: any failure (import, network, HTTP status,
    decoding, non-dict body) is logged and swallowed so billing problems
    never crash the model call path.
    """
    try:
        import httpx  # local import: keep httpx optional at module import time

        with httpx.Client(timeout=timeout_seconds) as client:
            resp = client.post(url, headers=headers, json=payload)
            resp.raise_for_status()
            body = resp.json()
            return body if isinstance(body, dict) else None
    except Exception as exc:
        logger.warning("[BillingMiddleware] HTTP request failed: url=%s err=%s", url, exc)
        return None
|
||||||
|
|
@ -91,6 +91,14 @@ def _build_runtime_middlewares(
|
||||||
|
|
||||||
middlewares.append(DanglingToolCallMiddleware())
|
middlewares.append(DanglingToolCallMiddleware())
|
||||||
|
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
|
|
||||||
|
billing_cfg = get_app_config().billing
|
||||||
|
if billing_cfg.enabled and (include_uploads or billing_cfg.include_subagents):
|
||||||
|
from deerflow.agents.middlewares.billing_middleware import BillingMiddleware
|
||||||
|
|
||||||
|
middlewares.append(BillingMiddleware())
|
||||||
|
|
||||||
middlewares.append(LLMErrorHandlingMiddleware())
|
middlewares.append(LLMErrorHandlingMiddleware())
|
||||||
|
|
||||||
# Guardrail middleware (if configured)
|
# Guardrail middleware (if configured)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from .app_config import get_app_config
|
from .app_config import get_app_config
|
||||||
|
from .billing_config import BillingConfig
|
||||||
from .extensions_config import ExtensionsConfig, get_extensions_config
|
from .extensions_config import ExtensionsConfig, get_extensions_config
|
||||||
from .memory_config import MemoryConfig, get_memory_config
|
from .memory_config import MemoryConfig, get_memory_config
|
||||||
from .paths import Paths, get_paths
|
from .paths import Paths, get_paths
|
||||||
|
|
@ -13,6 +14,7 @@ from .tracing_config import (
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_app_config",
|
"get_app_config",
|
||||||
|
"BillingConfig",
|
||||||
"Paths",
|
"Paths",
|
||||||
"get_paths",
|
"get_paths",
|
||||||
"SkillsConfig",
|
"SkillsConfig",
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from deerflow.config.acp_config import load_acp_config_from_dict
|
from deerflow.config.acp_config import load_acp_config_from_dict
|
||||||
|
from deerflow.config.billing_config import BillingConfig
|
||||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig
|
from deerflow.config.extensions_config import ExtensionsConfig
|
||||||
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
||||||
|
|
@ -40,6 +41,7 @@ class AppConfig(BaseModel):
|
||||||
"""Config for the DeerFlow application"""
|
"""Config for the DeerFlow application"""
|
||||||
|
|
||||||
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
|
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
|
||||||
|
billing: BillingConfig = Field(default_factory=BillingConfig, description="External billing reservation/finalization configuration")
|
||||||
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
|
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
|
||||||
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
||||||
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Configuration for reservation/finalization billing integration."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class BillingConfig(BaseModel):
|
||||||
|
"""Configuration for external billing reservation/finalization calls."""
|
||||||
|
|
||||||
|
enabled: bool = Field(default=False, description="Enable external billing middleware.")
|
||||||
|
include_subagents: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether billing applies to subagent model calls as well.",
|
||||||
|
)
|
||||||
|
fail_closed: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Block model calls when reserve request fails or balance is insufficient.",
|
||||||
|
)
|
||||||
|
block_only_specific_reserve_codes: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"When true, only reserve responses with codes in blocking_reserve_codes block model calls. "
|
||||||
|
"When false, fallback to fail_closed behavior for all reserve failures."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
blocking_reserve_codes: list[int] = Field(
|
||||||
|
default_factory=lambda: [-1104, -1106],
|
||||||
|
description="Reserve response codes that should block model calls when block_only_specific_reserve_codes is enabled.",
|
||||||
|
)
|
||||||
|
frozen_type: int = Field(
|
||||||
|
default=1,
|
||||||
|
ge=1,
|
||||||
|
description="Frozen type sent to the platform. Current flow uses 1 for token billing.",
|
||||||
|
)
|
||||||
|
reserve_url: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="HTTP(S) endpoint for creating frozen reservations.",
|
||||||
|
)
|
||||||
|
finalize_url: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="HTTP(S) endpoint for finalizing frozen reservations.",
|
||||||
|
)
|
||||||
|
headers: dict[str, str] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Extra HTTP headers included in reserve/finalize requests.",
|
||||||
|
)
|
||||||
|
timeout_seconds: float = Field(
|
||||||
|
default=10.0,
|
||||||
|
gt=0,
|
||||||
|
le=120,
|
||||||
|
description="HTTP request timeout for reserve/finalize calls.",
|
||||||
|
)
|
||||||
|
default_expire_seconds: int = Field(
|
||||||
|
default=1800,
|
||||||
|
ge=60,
|
||||||
|
le=86400,
|
||||||
|
description="Default reservation expiration seconds when expireAt is included.",
|
||||||
|
)
|
||||||
|
default_estimated_output_tokens: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
ge=1,
|
||||||
|
description="Fallback estimatedOutputTokens when model max_tokens is unavailable.",
|
||||||
|
)
|
||||||
|
|
@ -89,12 +89,13 @@ async def run_agent(
|
||||||
|
|
||||||
# Inject runtime context so middlewares can access thread_id
|
# Inject runtime context so middlewares can access thread_id
|
||||||
# (langgraph-cli does this automatically; we must do it manually)
|
# (langgraph-cli does this automatically; we must do it manually)
|
||||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
|
||||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||||
# prefers it over ``configurable`` for thread-level data), make
|
# prefers it over ``configurable`` for thread-level data), make
|
||||||
# sure ``thread_id`` is available there too.
|
# sure ``thread_id`` is available there too.
|
||||||
if "context" in config and isinstance(config["context"], dict):
|
if "context" in config and isinstance(config["context"], dict):
|
||||||
config["context"].setdefault("thread_id", thread_id)
|
config["context"].setdefault("thread_id", thread_id)
|
||||||
|
config["context"].setdefault("run_id", run_id)
|
||||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||||
|
|
||||||
runnable_config = RunnableConfig(**config)
|
runnable_config = RunnableConfig(**config)
|
||||||
|
|
|
||||||
|
|
@ -226,15 +226,18 @@ class SubagentExecutor:
|
||||||
try:
|
try:
|
||||||
agent = self._create_agent()
|
agent = self._create_agent()
|
||||||
state = self._build_initial_state(task)
|
state = self._build_initial_state(task)
|
||||||
|
subagent_model_name = _get_model_name(self.config, self.parent_model)
|
||||||
|
|
||||||
# Build config with thread_id for sandbox access and recursion limit
|
# Build config with thread_id for sandbox access and recursion limit
|
||||||
run_config: RunnableConfig = {
|
run_config: RunnableConfig = {
|
||||||
"recursion_limit": self.config.max_turns,
|
"recursion_limit": self.config.max_turns,
|
||||||
}
|
}
|
||||||
context = {}
|
context = {}
|
||||||
|
configurable: dict[str, Any] = {"model_name": subagent_model_name}
|
||||||
if self.thread_id:
|
if self.thread_id:
|
||||||
run_config["configurable"] = {"thread_id": self.thread_id}
|
configurable["thread_id"] = self.thread_id
|
||||||
context["thread_id"] = self.thread_id
|
context["thread_id"] = self.thread_id
|
||||||
|
run_config["configurable"] = configurable
|
||||||
|
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,241 @@
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares.billing_middleware import BillingMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
|
||||||
|
billing = SimpleNamespace(
|
||||||
|
enabled=enabled,
|
||||||
|
include_subagents=include_subagents,
|
||||||
|
fail_closed=True,
|
||||||
|
block_only_specific_reserve_codes=True,
|
||||||
|
blocking_reserve_codes=[-1104, -1106],
|
||||||
|
frozen_type=1,
|
||||||
|
reserve_url="http://billing.local/accountFrozen/frozen",
|
||||||
|
finalize_url="http://billing.local/accountFrozen/release",
|
||||||
|
headers={"Authorization": "Bearer x"},
|
||||||
|
timeout_seconds=3.0,
|
||||||
|
default_expire_seconds=1800,
|
||||||
|
default_estimated_output_tokens=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
|
||||||
|
return SimpleNamespace(
|
||||||
|
billing=billing,
|
||||||
|
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _request_with_latest_user_text(text: str):
|
||||||
|
request = MagicMock()
|
||||||
|
request.messages = [HumanMessage(content="old"), HumanMessage(content=text)]
|
||||||
|
request.model_settings = {}
|
||||||
|
request.runtime = SimpleNamespace(
|
||||||
|
config={"configurable": {"thread_id": "thread-1", "model_name": "gpt-4"}},
|
||||||
|
context={"thread_id": "thread-1"},
|
||||||
|
)
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awrap_model_call_uses_estimated_tokens_and_finalizes(monkeypatch):
|
||||||
|
from langchain_core.runnables.config import var_child_runnable_config
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares import billing_middleware as bm
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
|
||||||
|
|
||||||
|
seen_payloads = []
|
||||||
|
|
||||||
|
async def fake_post(url, headers, payload, timeout_seconds):
|
||||||
|
seen_payloads.append((url, headers, payload, timeout_seconds))
|
||||||
|
if url.endswith("/frozen"):
|
||||||
|
return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
|
||||||
|
return {"status": 1000, "message": "ok", "data": {}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "_post_async", fake_post)
|
||||||
|
|
||||||
|
middleware = BillingMiddleware()
|
||||||
|
request = _request_with_latest_user_text("hello world")
|
||||||
|
handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 11, "output_tokens": 22, "total_tokens": 33}))
|
||||||
|
|
||||||
|
token = var_child_runnable_config.set({"run_id": "run-1"})
|
||||||
|
try:
|
||||||
|
result = await middleware.awrap_model_call(request, handler)
|
||||||
|
finally:
|
||||||
|
var_child_runnable_config.reset(token)
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
assert len(seen_payloads) == 2
|
||||||
|
|
||||||
|
reserve_payload = seen_payloads[0][2]
|
||||||
|
assert reserve_payload["callId"] == "run-1"
|
||||||
|
assert reserve_payload["frozenType"] == 1
|
||||||
|
assert reserve_payload["estimatedInputTokens"] == len("hello world")
|
||||||
|
assert reserve_payload["estimatedOutputTokens"] == 4096
|
||||||
|
assert "frozenAmount" not in reserve_payload
|
||||||
|
|
||||||
|
finalize_payload = seen_payloads[1][2]
|
||||||
|
assert finalize_payload["frozenId"] == "frozen-123"
|
||||||
|
assert finalize_payload["finalAmount"] == 0
|
||||||
|
assert finalize_payload["usageInputTokens"] == 11
|
||||||
|
assert finalize_payload["usageOutputTokens"] == 22
|
||||||
|
assert finalize_payload["usageTotalTokens"] == 33
|
||||||
|
assert finalize_payload["finalizeReason"] == "success"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awrap_model_call_fail_closed_on_insufficient_balance(monkeypatch):
|
||||||
|
from deerflow.agents.middlewares import billing_middleware as bm
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
|
||||||
|
|
||||||
|
async def fake_post(url, headers, payload, timeout_seconds):
|
||||||
|
return {"status": -1106, "message": "insufficient balance", "data": {}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "_post_async", fake_post)
|
||||||
|
|
||||||
|
middleware = BillingMiddleware()
|
||||||
|
request = _request_with_latest_user_text("question")
|
||||||
|
handler = AsyncMock(return_value=AIMessage(content="should not run"))
|
||||||
|
|
||||||
|
result = await middleware.awrap_model_call(request, handler)
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
assert "insufficient" in str(result.content).lower()
|
||||||
|
handler.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awrap_model_call_finalize_uses_state_messages_usage_when_response_missing_usage(monkeypatch):
|
||||||
|
from deerflow.agents.middlewares import billing_middleware as bm
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
|
||||||
|
|
||||||
|
seen_payloads = []
|
||||||
|
|
||||||
|
async def fake_post(url, headers, payload, timeout_seconds):
|
||||||
|
seen_payloads.append((url, headers, payload, timeout_seconds))
|
||||||
|
if url.endswith("/frozen"):
|
||||||
|
return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
|
||||||
|
return {"status": 1000, "message": "ok", "data": {}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "_post_async", fake_post)
|
||||||
|
|
||||||
|
middleware = BillingMiddleware()
|
||||||
|
request = _request_with_latest_user_text("hello world")
|
||||||
|
request.state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="hello world"),
|
||||||
|
AIMessage(content="ok", usage_metadata={"input_tokens": 101, "output_tokens": 202, "total_tokens": 303}),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
handler = AsyncMock(return_value=AIMessage(content="ok"))
|
||||||
|
|
||||||
|
result = await middleware.awrap_model_call(request, handler)
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
assert len(seen_payloads) == 2
|
||||||
|
|
||||||
|
finalize_payload = seen_payloads[1][2]
|
||||||
|
assert finalize_payload["frozenId"] == "frozen-123"
|
||||||
|
assert finalize_payload["usageInputTokens"] == 101
|
||||||
|
assert finalize_payload["usageOutputTokens"] == 202
|
||||||
|
assert finalize_payload["usageTotalTokens"] == 303
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awrap_model_call_does_not_block_on_non_blocking_reserve_code(monkeypatch):
|
||||||
|
from deerflow.agents.middlewares import billing_middleware as bm
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
|
||||||
|
|
||||||
|
async def fake_post(url, headers, payload, timeout_seconds):
|
||||||
|
if url.endswith("/frozen"):
|
||||||
|
return {"status": 5001, "message": "platform busy", "data": {}}
|
||||||
|
return {"status": 1000, "message": "ok", "data": {}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "_post_async", fake_post)
|
||||||
|
|
||||||
|
middleware = BillingMiddleware()
|
||||||
|
request = _request_with_latest_user_text("question")
|
||||||
|
handler = AsyncMock(return_value=AIMessage(content="model-ran"))
|
||||||
|
|
||||||
|
result = await middleware.awrap_model_call(request, handler)
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
assert result.content == "model-ran"
|
||||||
|
handler.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awrap_model_call_uses_runnable_config_run_id(monkeypatch):
|
||||||
|
"""run_id is sourced from var_child_runnable_config, which LangGraph populates
|
||||||
|
via langgraph_api/stream.py during graph node execution."""
|
||||||
|
from langchain_core.runnables.config import var_child_runnable_config
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares import billing_middleware as bm
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
|
||||||
|
|
||||||
|
seen_payloads = []
|
||||||
|
|
||||||
|
async def fake_post(url, headers, payload, timeout_seconds):
|
||||||
|
seen_payloads.append((url, headers, payload, timeout_seconds))
|
||||||
|
if url.endswith("/frozen"):
|
||||||
|
return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
|
||||||
|
return {"status": 1000, "message": "ok", "data": {}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "_post_async", fake_post)
|
||||||
|
|
||||||
|
middleware = BillingMiddleware()
|
||||||
|
request = _request_with_latest_user_text("hello world")
|
||||||
|
handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}))
|
||||||
|
|
||||||
|
token = var_child_runnable_config.set({"run_id": "run-from-ctx"})
|
||||||
|
try:
|
||||||
|
result = await middleware.awrap_model_call(request, handler)
|
||||||
|
finally:
|
||||||
|
var_child_runnable_config.reset(token)
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
reserve_payload = seen_payloads[0][2]
|
||||||
|
assert reserve_payload["callId"] == "run-from-ctx"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
|
||||||
|
"""Fallback: run_id from langgraph_api.logging.worker_config when var_child_runnable_config is unset."""
|
||||||
|
from deerflow.agents.middlewares import billing_middleware as bm
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
|
||||||
|
|
||||||
|
seen_payloads = []
|
||||||
|
|
||||||
|
async def fake_post(url, headers, payload, timeout_seconds):
|
||||||
|
seen_payloads.append((url, headers, payload, timeout_seconds))
|
||||||
|
if url.endswith("/frozen"):
|
||||||
|
return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
|
||||||
|
return {"status": 1000, "message": "ok", "data": {}}
|
||||||
|
|
||||||
|
monkeypatch.setattr(bm, "_post_async", fake_post)
|
||||||
|
|
||||||
|
import langgraph_api.logging as lg_logging
|
||||||
|
|
||||||
|
middleware = BillingMiddleware()
|
||||||
|
request = _request_with_latest_user_text("hello world")
|
||||||
|
handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}))
|
||||||
|
|
||||||
|
token = lg_logging.worker_config.set({"run_id": "run-from-worker"})
|
||||||
|
try:
|
||||||
|
result = await middleware.awrap_model_call(request, handler)
|
||||||
|
finally:
|
||||||
|
lg_logging.worker_config.reset(token)
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
reserve_payload = seen_payloads[0][2]
|
||||||
|
assert reserve_payload["callId"] == "run-from-worker"
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Bump this number when the config schema changes.
|
# Bump this number when the config schema changes.
|
||||||
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
||||||
config_version: 5
|
config_version: 7
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Logging
|
# Logging
|
||||||
|
|
@ -20,6 +20,35 @@ config_version: 5
|
||||||
# Log level for deerflow modules (debug/info/warning/error)
|
# Log level for deerflow modules (debug/info/warning/error)
|
||||||
log_level: info
|
log_level: info
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Billing Reservation/Finalization
|
||||||
|
# ============================================================================
|
||||||
|
# Reserve before each LLM call and finalize after call completion.
|
||||||
|
# Keep this independent from token_usage reporting.
|
||||||
|
|
||||||
|
billing:
|
||||||
|
enabled: false
|
||||||
|
include_subagents: false
|
||||||
|
fail_closed: true
|
||||||
|
# true: only block when reserve returns a code in blocking_reserve_codes
|
||||||
|
# false: fallback to fail_closed behavior for all reserve failures
|
||||||
|
block_only_specific_reserve_codes: true
|
||||||
|
blocking_reserve_codes: [-1104, -1106]
|
||||||
|
frozen_type: 1
|
||||||
|
timeout_seconds: 10
|
||||||
|
default_expire_seconds: 1800
|
||||||
|
|
||||||
|
# When model config has no max_tokens, this fallback is used for
|
||||||
|
# estimatedOutputTokens. If unset and fail_closed=true, billing blocks calls.
|
||||||
|
# default_estimated_output_tokens: 4096
|
||||||
|
|
||||||
|
# reserve_url: "http://localhost:19001/accountFrozen/frozen"
|
||||||
|
# finalize_url: "http://localhost:19001/accountFrozen/release"
|
||||||
|
|
||||||
|
# headers:
|
||||||
|
# Authorization: "Bearer your-secret-token"
|
||||||
|
# X-App-Id: "deer-flow"
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Token Usage Tracking
|
# Token Usage Tracking
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue