feat: enhance billing integration with usage token extraction and API key handling

This commit is contained in:
Titan 2026-04-30 10:54:08 +08:00
parent 169332ab29
commit f584c3e53b
7 changed files with 272 additions and 39 deletions

View File

@ -67,11 +67,13 @@ async def proxy_request(provider: str, path: str, request: Request) -> Response:
path=path,
request=request,
body=body,
request_json=request_json,
thread_id=thread_id,
idempotency_key=idempotency_key,
task_id_jsonpath=submit_route.task_id_jsonpath,
route_frozen_amount=submit_route.frozen_amount,
route_frozen_type=submit_route.frozen_type,
route_frozen_token=submit_route.frozen_token,
)
if query_route:
@ -109,11 +111,13 @@ async def _handle_submit(
path: str,
request: Request,
body: bytes,
request_json: dict[str, Any] | None,
thread_id: str | None,
idempotency_key: str | None,
task_id_jsonpath: str,
route_frozen_amount: float | None,
route_frozen_type: int | None,
route_frozen_token: int | None,
) -> Response:
ledger = get_ledger()
@ -129,6 +133,7 @@ async def _handle_submit(
# Reserve billing before touching the provider
reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount
reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type
reserve_frozen_token = route_frozen_token if route_frozen_token is not None else provider_config.frozen_token
frozen_id = await billing.reserve(
thread_id=thread_id,
call_id=record.call_id,
@ -136,9 +141,11 @@ async def _handle_submit(
operation=path,
frozen_amount=reserve_frozen_amount,
frozen_type=reserve_frozen_type,
frozen_token=reserve_frozen_token,
request_payload=request_json,
)
if frozen_id:
ledger.set_reserved(record.proxy_call_id, frozen_id)
ledger.set_reserved(record.proxy_call_id, frozen_id, reserve_frozen_type)
# Forward to provider
try:
@ -156,6 +163,32 @@ async def _handle_submit(
resp_json = _try_parse_json(resp_body)
if resp_json is None:
if frozen_id and reserve_frozen_type == 1:
usage_input_tokens, usage_output_tokens = _extract_usage_tokens_from_submit_stream(resp_body)
logger.debug(
"[ThirdPartyProxy] submit stream usage resolved: proxy_call_id=%s usage_input_tokens=%s usage_output_tokens=%s",
record.proxy_call_id,
usage_input_tokens,
usage_output_tokens,
)
if ledger.try_claim_finalize(record.proxy_call_id):
ok = await billing.finalize(
frozen_id=frozen_id,
final_amount=0.0,
finalize_reason="success",
usage_input_tokens=usage_input_tokens,
usage_output_tokens=usage_output_tokens,
)
if ok:
ledger.set_finalized(record.proxy_call_id, "SUCCESS")
else:
ledger.set_finalize_failed(record.proxy_call_id, "FAILED")
media_type = resp_headers.get("content-type")
return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type=media_type)
# HTTP-level failure
if status_code >= 400:
reason = f"error_http_{status_code}"
@ -272,17 +305,30 @@ async def _handle_query(
"[ThirdPartyProxy] finalize claimed: proxy_call_id=%s",
record.proxy_call_id,
)
final_amount: float = 0.0
resolved_frozen_type = (
record.frozen_type if record.frozen_type is not None else provider_config.frozen_type
)
usage_input_tokens = 0
usage_output_tokens = 0
usage_paths = list(query_route.usage_jsonpaths or [])
if not usage_paths and query_route.usage_jsonpath:
usage_paths = [query_route.usage_jsonpath]
final_amount: float = 0.0
if is_success:
if resolved_frozen_type == 1:
usage_input_tokens, usage_output_tokens = _extract_usage_tokens(resp_json)
else:
final_amount = _resolve_final_amount(resp_json, query_route)
logger.debug(
"[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s final_amount=%s usage_paths=%s legacy_path=%s",
"[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s frozen_type=%s final_amount=%s usage_input_tokens=%s usage_output_tokens=%s usage_paths=%s legacy_path=%s",
record.proxy_call_id,
resolved_frozen_type,
final_amount,
usage_input_tokens,
usage_output_tokens,
usage_paths,
query_route.usage_jsonpath,
)
@ -303,6 +349,8 @@ async def _handle_query(
frozen_id=record.frozen_id,
final_amount=final_amount,
finalize_reason=finalize_reason,
usage_input_tokens=usage_input_tokens,
usage_output_tokens=usage_output_tokens,
)
logger.info(
"[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s",
@ -415,6 +463,61 @@ def _resolve_final_amount(resp_json: dict[str, Any], query_route) -> float:
return total
def _extract_usage_tokens(resp_json: dict[str, Any]) -> tuple[int, int]:
usage = resp_json.get("usage")
if not isinstance(usage, dict):
return 0, 0
input_tokens = _as_int(usage.get("input_tokens"))
if input_tokens == 0:
input_tokens = _as_int(usage.get("prompt_tokens"))
output_tokens = _as_int(usage.get("output_tokens"))
if output_tokens == 0:
output_tokens = _as_int(usage.get("completion_tokens"))
return input_tokens, output_tokens
def _extract_usage_tokens_from_submit_stream(resp_body: bytes) -> tuple[int, int]:
"""Extract usage tokens from the final SSE chunk in a submit stream response."""
if not resp_body:
return 0, 0
input_tokens = 0
output_tokens = 0
for raw_line in resp_body.splitlines():
line = raw_line.decode("utf-8", errors="replace").strip()
if not line.startswith("data:"):
continue
payload_str = line[5:].strip()
if not payload_str or payload_str == "[DONE]":
continue
try:
payload = json.loads(payload_str)
except (json.JSONDecodeError, ValueError):
continue
if isinstance(payload, dict):
in_tokens, out_tokens = _extract_usage_tokens(payload)
if in_tokens or out_tokens:
input_tokens, output_tokens = in_tokens, out_tokens
return input_tokens, output_tokens
def _as_int(value: Any) -> int:
if isinstance(value, int):
return value
if isinstance(value, float):
return int(value)
if isinstance(value, str):
try:
return int(float(value))
except ValueError:
return 0
return 0
def _proxy_response(
data: dict[str, Any],
proxy_call_id: str | None,

View File

@ -10,6 +10,7 @@ from __future__ import annotations
import logging
from datetime import datetime, timedelta
from typing import Any
import httpx
@ -28,6 +29,8 @@ async def reserve(
operation: str,
frozen_amount: float,
frozen_type: int | None,
frozen_token: int = 0,
request_payload: dict[str, Any] | None = None,
) -> str | None:
"""Reserve billing before forwarding a submit call.
@ -44,19 +47,25 @@ async def reserve(
)
return None
resolved_frozen_type = frozen_type if frozen_type is not None else cfg.frozen_type
expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
payload = {
payload: dict[str, Any] = {
"sessionId": thread_id,
"callId": call_id,
"modelName": provider,
"modelName": _extract_model_name(request_payload) or provider,
"question": f"skill invokes {operation.split('/')[-1]}",
"frozenAmount": frozen_amount,
"frozenType": frozen_type if frozen_type is not None else cfg.frozen_type,
"estimatedInputTokens": 0,
"estimatedOutputTokens": 0,
"frozenType": resolved_frozen_type,
"expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"),
}
if resolved_frozen_type == 1:
payload["estimatedInputTokens"] = int(frozen_token)
payload["estimatedOutputTokens"] = int(frozen_token)
else:
payload["frozenAmount"] = frozen_amount
payload["estimatedInputTokens"] = 0
payload["estimatedOutputTokens"] = 0
logger.info(
"[ThirdPartyProxy][Billing] reserve request: url=%s call_id=%s provider=%s thread_id=%s",
cfg.reserve_url,
@ -114,6 +123,8 @@ async def finalize(
frozen_id: str,
final_amount: float,
finalize_reason: str,
usage_input_tokens: int = 0,
usage_output_tokens: int = 0,
) -> bool:
"""Finalize billing after a third-party call reaches a terminal state.
@ -135,9 +146,9 @@ async def finalize(
payload = {
"frozenId": frozen_id,
"finalAmount": final_amount,
"usageInputTokens": 0,
"usageOutputTokens": 0,
"usageTotalTokens": 0,
"usageInputTokens": usage_input_tokens,
"usageOutputTokens": usage_output_tokens,
"usageTotalTokens": usage_input_tokens + usage_output_tokens,
"finalizeReason": finalize_reason,
}
@ -188,3 +199,12 @@ def _is_success(data: dict) -> bool:
if isinstance(status, int) and status in _SUCCESS_STATUS_CODES:
return True
return data.get("success") is True
def _extract_model_name(request_payload: dict[str, Any] | None) -> str | None:
if not isinstance(request_payload, dict):
return None
model = request_payload.get("model")
if isinstance(model, str) and model:
return model
return None

View File

@ -27,6 +27,7 @@ class CallRecord:
# call_id is sent to the billing platform (callId in reserve payload)
call_id: str
frozen_id: str | None = None
frozen_type: int | None = None
provider_task_id: str | None = None
billing_state: BillingState = "UNRESERVED"
task_state: TaskState = "PENDING"
@ -109,16 +110,18 @@ class CallLedger:
def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None:
return self._get_by_idem_key_locked(provider, idempotency_key)
def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None:
def set_reserved(self, proxy_call_id: str, frozen_id: str, frozen_type: int | None = None) -> None:
with self._lock:
record = self._records.get(proxy_call_id)
if record:
record.frozen_id = frozen_id
record.frozen_type = frozen_type
record.billing_state = "RESERVED"
logger.info(
"[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s",
"[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s frozen_type=%s",
proxy_call_id,
frozen_id,
frozen_type,
)
# logger.debug(
# "[ThirdPartyProxy][Ledger] reserve state: call_id=%s provider=%s task_state=%s",

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import json
import logging
import os
from typing import Any
@ -17,16 +18,7 @@ from deerflow.config.third_party_proxy_config import (
logger = logging.getLogger(__name__)
_SENSITIVE_HEADERS = frozenset(
[
"authorization",
"proxy-authorization",
"x-api-key",
"api-key",
"cookie",
"set-cookie",
]
)
API_KEY_MARKER = "__API_KEY_MARKER__"
# ---------------------------------------------------------------------------
# Provider config lookup
@ -154,17 +146,6 @@ _STRIP_RESPONSE_HEADERS = frozenset(
)
def _sanitize_headers(headers: dict[str, str]) -> dict[str, str]:
"""Return a copy of headers with sensitive values redacted."""
sanitized: dict[str, str] = {}
for key, value in headers.items():
if key.lower() in _SENSITIVE_HEADERS:
sanitized[key] = "***"
else:
sanitized[key] = value
return sanitized
def _preview_body(data: bytes, limit: int = 2048) -> str:
"""Return a safe textual preview of body bytes for debugging logs."""
if not data:
@ -176,6 +157,53 @@ def _preview_body(data: bytes, limit: int = 2048) -> str:
return text
def _replace_api_key_marker_in_headers(headers: dict[str, str], api_key: str) -> dict[str, str]:
"""Replace API key marker placeholders in header values."""
replaced: dict[str, str] = {}
for key, value in headers.items():
if isinstance(value, str) and API_KEY_MARKER in value:
replaced[key] = value.replace(API_KEY_MARKER, api_key)
else:
replaced[key] = value
return replaced
def _header_value(headers: dict[str, str], key: str) -> str | None:
target = key.lower()
for h_key, h_val in headers.items():
if h_key.lower() == target:
return h_val
return None
def _replace_api_key_marker_in_json(data: Any, api_key: str) -> Any:
if isinstance(data, str):
return data.replace(API_KEY_MARKER, api_key)
if isinstance(data, list):
return [_replace_api_key_marker_in_json(item, api_key) for item in data]
if isinstance(data, dict):
return {k: _replace_api_key_marker_in_json(v, api_key) for k, v in data.items()}
return data
def _replace_api_key_marker_in_body(headers: dict[str, str], body: bytes, api_key: str) -> bytes:
"""Replace API key marker in JSON body payloads only."""
if not body:
return body
content_type = _header_value(headers, "content-type") or ""
if "application/json" not in content_type.lower():
return body
try:
parsed = json.loads(body)
except (json.JSONDecodeError, ValueError):
return body
replaced = _replace_api_key_marker_in_json(parsed, api_key)
return json.dumps(replaced, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
async def forward_request(
*,
provider_config: ThirdPartyProviderConfig,
@ -202,6 +230,9 @@ async def forward_request(
if provider_config.api_key_env:
api_key = os.getenv(provider_config.api_key_env)
if api_key:
# Dependency-injection style: replace marker placeholders first.
forward_headers = _replace_api_key_marker_in_headers(forward_headers, api_key)
body = _replace_api_key_marker_in_body(forward_headers, body, api_key)
forward_headers[provider_config.api_key_header] = provider_config.api_key_prefix + api_key
else:
logger.warning(
@ -212,7 +243,7 @@ async def forward_request(
logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
logger.debug(
"[ThirdPartyProxy] request headers=%s",
_sanitize_headers(forward_headers)
forward_headers,
)
logger.debug(
"[ThirdPartyProxy] request body(%dB)=%s",
@ -236,7 +267,7 @@ async def forward_request(
logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
logger.debug(
"[ThirdPartyProxy] response headers=%s",
_sanitize_headers(response_headers)
response_headers,
)
logger.debug(
"[ThirdPartyProxy] response body(%dB)=%s",

View File

@ -28,6 +28,11 @@ class SubmitRouteConfig(BaseModel):
default=None,
description="Optional route-level override for billing reserve payload frozenType",
)
frozen_token: int | None = Field(
default=None,
ge=0,
description="Optional route-level override for billing reserve payload estimatedInputTokens/estimatedOutputTokens when frozenType=1",
)
class QueryRouteConfig(BaseModel):
@ -96,6 +101,11 @@ class ThirdPartyProviderConfig(BaseModel):
default=None,
description="Billing frozen type for this provider (frozenType). If omitted, falls back to billing.frozen_type",
)
frozen_token: int = Field(
default=0,
ge=0,
description="Estimated token amount used for reserve payload when frozenType=1",
)
submit_routes: list[SubmitRouteConfig] = Field(
default_factory=list,
description="Route patterns that identify submit (task-create) requests",

View File

@ -3,9 +3,16 @@
from __future__ import annotations
from app.gateway.third_party_proxy.ledger import CallLedger
from app.gateway.routers.third_party import _resolve_final_amount
from app.gateway.routers.third_party import (
_extract_usage_tokens,
_extract_usage_tokens_from_submit_stream,
_resolve_final_amount,
)
from app.gateway.third_party_proxy.proxy import (
API_KEY_MARKER,
_path_matches,
_replace_api_key_marker_in_body,
_replace_api_key_marker_in_headers,
jsonpath_get,
match_query_route,
match_submit_route,
@ -225,3 +232,61 @@ class TestResolveFinalAmount:
resp_json = {"usage": {"thirdPartyConsumeMoney": "1.5"}}
amount = _resolve_final_amount(resp_json, route)
assert amount == 1.5
class TestExtractUsageTokens:
def test_prefers_openai_usage_keys(self):
resp_json = {
"usage": {
"prompt_tokens": 123,
"completion_tokens": 45,
}
}
input_tokens, output_tokens = _extract_usage_tokens(resp_json)
assert input_tokens == 123
assert output_tokens == 45
def test_supports_generic_usage_keys(self):
resp_json = {
"usage": {
"input_tokens": "88",
"output_tokens": "12",
}
}
input_tokens, output_tokens = _extract_usage_tokens(resp_json)
assert input_tokens == 88
assert output_tokens == 12
class TestExtractUsageTokensFromSubmitStream:
def test_extracts_usage_from_final_sse_chunk(self):
body = (
b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n'
b'data: {"id":"x","choices":[],"usage":{"prompt_tokens":22,"completion_tokens":17}}\n\n'
b'data: [DONE]\n\n'
)
input_tokens, output_tokens = _extract_usage_tokens_from_submit_stream(body)
assert input_tokens == 22
assert output_tokens == 17
def test_returns_zero_when_no_usage_found(self):
body = b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n'
input_tokens, output_tokens = _extract_usage_tokens_from_submit_stream(body)
assert input_tokens == 0
assert output_tokens == 0
class TestApiKeyMarkerReplacement:
def test_replace_marker_in_headers(self):
headers = {"Authorization": f"Bearer {API_KEY_MARKER}", "Content-Type": "application/json"}
replaced = _replace_api_key_marker_in_headers(headers, "real-key")
assert replaced["Authorization"] == "Bearer real-key"
def test_replace_marker_in_json_body(self):
headers = {"Content-Type": "application/json"}
body = (
b'{"apiKey":"__API_KEY_MARKER__","nested":{"token":"Bearer __API_KEY_MARKER__"}}'
)
replaced = _replace_api_key_marker_in_body(headers, body, "real-key")
assert b'"apiKey":"real-key"' in replaced
assert b'"token":"Bearer real-key"' in replaced

View File

@ -105,6 +105,7 @@ third_party_proxy:
api_key_header: Authorization
api_key_prefix: "Bearer "
timeout_seconds: 60.0
frozen_token: 32768
submit_routes:
- path_pattern: "/compatible-mode/v1/chat/completions"
task_id_jsonpath: "id"