From f584c3e53b5476da0e132a0d940922287594774d Mon Sep 17 00:00:00 2001 From: Titan Date: Thu, 30 Apr 2026 10:54:08 +0800 Subject: [PATCH] feat: enhance billing integration with usage token extraction and API key handling --- backend/app/gateway/routers/third_party.py | 111 +++++++++++++++++- .../app/gateway/third_party_proxy/billing.py | 38 ++++-- .../app/gateway/third_party_proxy/ledger.py | 7 +- .../app/gateway/third_party_proxy/proxy.py | 77 ++++++++---- .../config/third_party_proxy_config.py | 10 ++ backend/tests/test_third_party_proxy.py | 67 ++++++++++- config.example.yaml | 1 + 7 files changed, 272 insertions(+), 39 deletions(-) diff --git a/backend/app/gateway/routers/third_party.py b/backend/app/gateway/routers/third_party.py index fbc14f8a..3bda4d7e 100644 --- a/backend/app/gateway/routers/third_party.py +++ b/backend/app/gateway/routers/third_party.py @@ -67,11 +67,13 @@ async def proxy_request(provider: str, path: str, request: Request) -> Response: path=path, request=request, body=body, + request_json=request_json, thread_id=thread_id, idempotency_key=idempotency_key, task_id_jsonpath=submit_route.task_id_jsonpath, route_frozen_amount=submit_route.frozen_amount, route_frozen_type=submit_route.frozen_type, + route_frozen_token=submit_route.frozen_token, ) if query_route: @@ -109,11 +111,13 @@ async def _handle_submit( path: str, request: Request, body: bytes, + request_json: dict[str, Any] | None, thread_id: str | None, idempotency_key: str | None, task_id_jsonpath: str, route_frozen_amount: float | None, route_frozen_type: int | None, + route_frozen_token: int | None, ) -> Response: ledger = get_ledger() @@ -129,6 +133,7 @@ async def _handle_submit( # Reserve billing before touching the provider reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type + reserve_frozen_token = route_frozen_token if route_frozen_token is not None else provider_config.frozen_token frozen_id = await billing.reserve( thread_id=thread_id, call_id=record.call_id, @@ -136,9 +141,11 @@ async def _handle_submit( operation=path, frozen_amount=reserve_frozen_amount, frozen_type=reserve_frozen_type, + frozen_token=reserve_frozen_token, + request_payload=request_json, ) if frozen_id: - ledger.set_reserved(record.proxy_call_id, frozen_id) + ledger.set_reserved(record.proxy_call_id, frozen_id, reserve_frozen_type) # Forward to provider try: @@ -156,6 +163,32 @@ async def _handle_submit( resp_json = _try_parse_json(resp_body) + if resp_json is None: + if frozen_id and reserve_frozen_type == 1: + usage_input_tokens, usage_output_tokens = _extract_usage_tokens_from_submit_stream(resp_body) + logger.debug( + "[ThirdPartyProxy] submit stream usage resolved: proxy_call_id=%s usage_input_tokens=%s usage_output_tokens=%s", + record.proxy_call_id, + usage_input_tokens, + usage_output_tokens, + ) + + if ledger.try_claim_finalize(record.proxy_call_id): + ok = await billing.finalize( + frozen_id=frozen_id, + final_amount=0.0, + finalize_reason="success", + usage_input_tokens=usage_input_tokens, + usage_output_tokens=usage_output_tokens, + ) + if ok: + ledger.set_finalized(record.proxy_call_id, "SUCCESS") + else: + ledger.set_finalize_failed(record.proxy_call_id, "FAILED") + + media_type = resp_headers.get("content-type") + return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type=media_type) + # HTTP-level failure if status_code >= 400: reason = f"error_http_{status_code}" @@ -272,17 +305,30 @@ async def _handle_query( "[ThirdPartyProxy] finalize claimed: proxy_call_id=%s", record.proxy_call_id, ) - final_amount: float = 0.0 + resolved_frozen_type = ( + record.frozen_type if record.frozen_type is not None else provider_config.frozen_type + ) + + usage_input_tokens = 0 + usage_output_tokens = 0 usage_paths = list(query_route.usage_jsonpaths or []) if not usage_paths and query_route.usage_jsonpath: usage_paths = [query_route.usage_jsonpath] + + final_amount: float = 0.0 if is_success: - final_amount = _resolve_final_amount(resp_json, query_route) + if resolved_frozen_type == 1: + usage_input_tokens, usage_output_tokens = _extract_usage_tokens(resp_json) + else: + final_amount = _resolve_final_amount(resp_json, query_route) logger.debug( - "[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s final_amount=%s usage_paths=%s legacy_path=%s", + "[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s frozen_type=%s final_amount=%s usage_input_tokens=%s usage_output_tokens=%s usage_paths=%s legacy_path=%s", record.proxy_call_id, + resolved_frozen_type, final_amount, + usage_input_tokens, + usage_output_tokens, usage_paths, query_route.usage_jsonpath, ) @@ -303,6 +349,8 @@ async def _handle_query( frozen_id=record.frozen_id, final_amount=final_amount, finalize_reason=finalize_reason, + usage_input_tokens=usage_input_tokens, + usage_output_tokens=usage_output_tokens, ) logger.info( "[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s", @@ -415,6 +463,61 @@ def _resolve_final_amount(resp_json: dict[str, Any], query_route) -> float: return total +def _extract_usage_tokens(resp_json: dict[str, Any]) -> tuple[int, int]: + usage = resp_json.get("usage") + if not isinstance(usage, dict): + return 0, 0 + + input_tokens = _as_int(usage.get("input_tokens")) + if input_tokens == 0: + input_tokens = _as_int(usage.get("prompt_tokens")) + + output_tokens = _as_int(usage.get("output_tokens")) + if output_tokens == 0: + output_tokens = _as_int(usage.get("completion_tokens")) + + return input_tokens, output_tokens + + +def _extract_usage_tokens_from_submit_stream(resp_body: bytes) -> tuple[int, int]: + """Extract usage tokens from the final SSE chunk in a submit stream response.""" + if not resp_body: + return 0, 0 + + input_tokens = 0 + output_tokens = 0 + for raw_line in resp_body.splitlines(): + line = raw_line.decode("utf-8", errors="replace").strip() + if not line.startswith("data:"): + continue + payload_str = line[5:].strip() + if not payload_str or payload_str == "[DONE]": + continue + try: + payload = json.loads(payload_str) + except (json.JSONDecodeError, ValueError): + continue + if isinstance(payload, dict): + in_tokens, out_tokens = _extract_usage_tokens(payload) + if in_tokens or out_tokens: + input_tokens, output_tokens = in_tokens, out_tokens + + return input_tokens, output_tokens + + +def _as_int(value: Any) -> int: + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + try: + return int(float(value)) + except ValueError: + return 0 + return 0 + + def _proxy_response( data: dict[str, Any], proxy_call_id: str | None, diff --git a/backend/app/gateway/third_party_proxy/billing.py b/backend/app/gateway/third_party_proxy/billing.py index 0c863670..59efd2e2 100644 --- a/backend/app/gateway/third_party_proxy/billing.py +++ b/backend/app/gateway/third_party_proxy/billing.py @@ -10,6 +10,7 @@ from __future__ import annotations import logging from datetime import datetime, timedelta +from typing import Any import httpx @@ -28,6 +29,8 @@ async def reserve( operation: str, frozen_amount: float, frozen_type: int | None, + frozen_token: int = 0, + request_payload: dict[str, Any] | None = None, ) -> str | None: """Reserve billing before forwarding a submit call. @@ -44,19 +47,25 @@ async def reserve( ) return None + resolved_frozen_type = frozen_type if frozen_type is not None else cfg.frozen_type expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds) - payload = { + payload: dict[str, Any] = { "sessionId": thread_id, "callId": call_id, - "modelName": provider, + "modelName": _extract_model_name(request_payload) or provider, "question": f"skill invokes {operation.split('/')[-1]}", - "frozenAmount": frozen_amount, - "frozenType": frozen_type if frozen_type is not None else cfg.frozen_type, - "estimatedInputTokens": 0, - "estimatedOutputTokens": 0, + "frozenType": resolved_frozen_type, "expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"), } + if resolved_frozen_type == 1: + payload["estimatedInputTokens"] = int(frozen_token) + payload["estimatedOutputTokens"] = int(frozen_token) + else: + payload["frozenAmount"] = frozen_amount + payload["estimatedInputTokens"] = 0 + payload["estimatedOutputTokens"] = 0 + logger.info( "[ThirdPartyProxy][Billing] reserve request: url=%s call_id=%s provider=%s thread_id=%s", cfg.reserve_url, @@ -114,6 +123,8 @@ async def finalize( frozen_id: str, final_amount: float, finalize_reason: str, + usage_input_tokens: int = 0, + usage_output_tokens: int = 0, ) -> bool: """Finalize billing after a third-party call reaches a terminal state. @@ -135,9 +146,9 @@ async def finalize( payload = { "frozenId": frozen_id, "finalAmount": final_amount, - "usageInputTokens": 0, - "usageOutputTokens": 0, - "usageTotalTokens": 0, + "usageInputTokens": usage_input_tokens, + "usageOutputTokens": usage_output_tokens, + "usageTotalTokens": usage_input_tokens + usage_output_tokens, "finalizeReason": finalize_reason, } @@ -188,3 +199,12 @@ def _is_success(data: dict) -> bool: if isinstance(status, int) and status in _SUCCESS_STATUS_CODES: return True return data.get("success") is True + + +def _extract_model_name(request_payload: dict[str, Any] | None) -> str | None: + if not isinstance(request_payload, dict): + return None + model = request_payload.get("model") + if isinstance(model, str) and model: + return model + return None diff --git a/backend/app/gateway/third_party_proxy/ledger.py b/backend/app/gateway/third_party_proxy/ledger.py index 42a02861..c8fcf569 100644 --- a/backend/app/gateway/third_party_proxy/ledger.py +++ b/backend/app/gateway/third_party_proxy/ledger.py @@ -27,6 +27,7 @@ class CallRecord: # call_id is sent to the billing platform (callId in reserve payload) call_id: str frozen_id: str | None = None + frozen_type: int | None = None provider_task_id: str | None = None billing_state: BillingState = "UNRESERVED" task_state: TaskState = "PENDING" @@ -109,16 +110,18 @@ class CallLedger: def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None: return self._get_by_idem_key_locked(provider, idempotency_key) - def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None: + def set_reserved(self, proxy_call_id: str, frozen_id: str, frozen_type: int | None = None) -> None: with self._lock: record = self._records.get(proxy_call_id) if record: record.frozen_id = frozen_id + record.frozen_type = frozen_type record.billing_state = "RESERVED" logger.info( - "[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s", + "[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s frozen_type=%s", proxy_call_id, frozen_id, + frozen_type, ) # logger.debug( # "[ThirdPartyProxy][Ledger] reserve state: call_id=%s provider=%s task_state=%s", diff --git a/backend/app/gateway/third_party_proxy/proxy.py b/backend/app/gateway/third_party_proxy/proxy.py index 912814e6..573e9751 100644 --- a/backend/app/gateway/third_party_proxy/proxy.py +++ b/backend/app/gateway/third_party_proxy/proxy.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import logging import os from typing import Any @@ -17,16 +18,7 @@ from deerflow.config.third_party_proxy_config import ( logger = logging.getLogger(__name__) -_SENSITIVE_HEADERS = frozenset( - [ - "authorization", - "proxy-authorization", - "x-api-key", - "api-key", - "cookie", - "set-cookie", - ] -) +API_KEY_MARKER = "__API_KEY_MARKER__" # --------------------------------------------------------------------------- # Provider config lookup @@ -154,17 +146,6 @@ _STRIP_RESPONSE_HEADERS = frozenset( ) -def _sanitize_headers(headers: dict[str, str]) -> dict[str, str]: - """Return a copy of headers with sensitive values redacted.""" - sanitized: dict[str, str] = {} - for key, value in headers.items(): - if key.lower() in _SENSITIVE_HEADERS: - sanitized[key] = "***" - else: - sanitized[key] = value - return sanitized - - def _preview_body(data: bytes, limit: int = 2048) -> str: """Return a safe textual preview of body bytes for debugging logs.""" if not data: @@ -176,6 +157,53 @@ def _preview_body(data: bytes, limit: int = 2048) -> str: return text +def _replace_api_key_marker_in_headers(headers: dict[str, str], api_key: str) -> dict[str, str]: + """Replace API key marker placeholders in header values.""" + replaced: dict[str, str] = {} + for key, value in headers.items(): + if isinstance(value, str) and API_KEY_MARKER in value: + replaced[key] = value.replace(API_KEY_MARKER, api_key) + else: + replaced[key] = value + return replaced + + +def _header_value(headers: dict[str, str], key: str) -> str | None: + target = key.lower() + for h_key, h_val in headers.items(): + if h_key.lower() == target: + return h_val + return None + + +def _replace_api_key_marker_in_json(data: Any, api_key: str) -> Any: + if isinstance(data, str): + return data.replace(API_KEY_MARKER, api_key) + if isinstance(data, list): + return [_replace_api_key_marker_in_json(item, api_key) for item in data] + if isinstance(data, dict): + return {k: _replace_api_key_marker_in_json(v, api_key) for k, v in data.items()} + return data + + +def _replace_api_key_marker_in_body(headers: dict[str, str], body: bytes, api_key: str) -> bytes: + """Replace API key marker in JSON body payloads only.""" + if not body: + return body + + content_type = _header_value(headers, "content-type") or "" + if "application/json" not in content_type.lower(): + return body + + try: + parsed = json.loads(body) + except (json.JSONDecodeError, ValueError): + return body + + replaced = _replace_api_key_marker_in_json(parsed, api_key) + return json.dumps(replaced, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + + async def forward_request( *, provider_config: ThirdPartyProviderConfig, @@ -202,6 +230,9 @@ async def forward_request( if provider_config.api_key_env: api_key = os.getenv(provider_config.api_key_env) if api_key: + # Dependency-injection style: replace marker placeholders first. + forward_headers = _replace_api_key_marker_in_headers(forward_headers, api_key) + body = _replace_api_key_marker_in_body(forward_headers, body, api_key) forward_headers[provider_config.api_key_header] = provider_config.api_key_prefix + api_key else: logger.warning( @@ -212,7 +243,7 @@ async def forward_request( logger.info("[ThirdPartyProxy] → %s %s", method, target_url) logger.debug( "[ThirdPartyProxy] request headers=%s", - _sanitize_headers(forward_headers) + forward_headers, ) logger.debug( "[ThirdPartyProxy] request body(%dB)=%s", @@ -236,7 +267,7 @@ async def forward_request( logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code) logger.debug( "[ThirdPartyProxy] response headers=%s", - _sanitize_headers(response_headers) + response_headers, ) logger.debug( "[ThirdPartyProxy] response body(%dB)=%s", diff --git a/backend/packages/harness/deerflow/config/third_party_proxy_config.py b/backend/packages/harness/deerflow/config/third_party_proxy_config.py index 81f9d3a5..890c2036 100644 --- a/backend/packages/harness/deerflow/config/third_party_proxy_config.py +++ b/backend/packages/harness/deerflow/config/third_party_proxy_config.py @@ -28,6 +28,11 @@ class SubmitRouteConfig(BaseModel): default=None, description="Optional route-level override for billing reserve payload frozenType", ) + frozen_token: int | None = Field( + default=None, + ge=0, + description="Optional route-level override for billing reserve payload estimatedInputTokens/estimatedOutputTokens when frozenType=1", + ) class QueryRouteConfig(BaseModel): @@ -96,6 +101,11 @@ class ThirdPartyProviderConfig(BaseModel): default=None, description="Billing frozen type for this provider (frozenType). If omitted, falls back to billing.frozen_type", ) + frozen_token: int = Field( + default=0, + ge=0, + description="Estimated token amount used for reserve payload when frozenType=1", + ) submit_routes: list[SubmitRouteConfig] = Field( default_factory=list, description="Route patterns that identify submit (task-create) requests", diff --git a/backend/tests/test_third_party_proxy.py b/backend/tests/test_third_party_proxy.py index b0da77fd..c5a9ba86 100644 --- a/backend/tests/test_third_party_proxy.py +++ b/backend/tests/test_third_party_proxy.py @@ -3,9 +3,16 @@ from __future__ import annotations from app.gateway.third_party_proxy.ledger import CallLedger -from app.gateway.routers.third_party import _resolve_final_amount +from app.gateway.routers.third_party import ( + _extract_usage_tokens, + _extract_usage_tokens_from_submit_stream, + _resolve_final_amount, +) from app.gateway.third_party_proxy.proxy import ( + API_KEY_MARKER, _path_matches, + _replace_api_key_marker_in_body, + _replace_api_key_marker_in_headers, jsonpath_get, match_query_route, match_submit_route, @@ -225,3 +232,61 @@ class TestResolveFinalAmount: resp_json = {"usage": {"thirdPartyConsumeMoney": "1.5"}} amount = _resolve_final_amount(resp_json, route) assert amount == 1.5 + + +class TestExtractUsageTokens: + def test_prefers_openai_usage_keys(self): + resp_json = { + "usage": { + "prompt_tokens": 123, + "completion_tokens": 45, + } + } + input_tokens, output_tokens = _extract_usage_tokens(resp_json) + assert input_tokens == 123 + assert output_tokens == 45 + + def test_supports_generic_usage_keys(self): + resp_json = { + "usage": { + "input_tokens": "88", + "output_tokens": "12", + } + } + input_tokens, output_tokens = _extract_usage_tokens(resp_json) + assert input_tokens == 88 + assert output_tokens == 12 + + +class TestExtractUsageTokensFromSubmitStream: + def test_extracts_usage_from_final_sse_chunk(self): + body = ( + b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n' + b'data: {"id":"x","choices":[],"usage":{"prompt_tokens":22,"completion_tokens":17}}\n\n' + b'data: [DONE]\n\n' + ) + input_tokens, output_tokens = _extract_usage_tokens_from_submit_stream(body) + assert input_tokens == 22 + assert output_tokens == 17 + + def test_returns_zero_when_no_usage_found(self): + body = b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n' + input_tokens, output_tokens = _extract_usage_tokens_from_submit_stream(body) + assert input_tokens == 0 + assert output_tokens == 0 + + +class TestApiKeyMarkerReplacement: + def test_replace_marker_in_headers(self): + headers = {"Authorization": f"Bearer {API_KEY_MARKER}", "Content-Type": "application/json"} + replaced = _replace_api_key_marker_in_headers(headers, "real-key") + assert replaced["Authorization"] == "Bearer real-key" + + def test_replace_marker_in_json_body(self): + headers = {"Content-Type": "application/json"} + body = ( + b'{"apiKey":"__API_KEY_MARKER__","nested":{"token":"Bearer __API_KEY_MARKER__"}}' + ) + replaced = _replace_api_key_marker_in_body(headers, body, "real-key") + assert b'"apiKey":"real-key"' in replaced + assert b'"token":"Bearer real-key"' in replaced diff --git a/config.example.yaml b/config.example.yaml index f05cc496..0e912161 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -105,6 +105,7 @@ third_party_proxy: api_key_header: Authorization api_key_prefix: "Bearer " timeout_seconds: 60.0 + frozen_token: 32768 submit_routes: - path_pattern: "/compatible-mode/v1/chat/completions" task_id_jsonpath: "id"