feat: enhance billing integration with usage token extraction and API key handling
This commit is contained in:
parent
169332ab29
commit
f584c3e53b
|
|
@ -67,11 +67,13 @@ async def proxy_request(provider: str, path: str, request: Request) -> Response:
|
||||||
path=path,
|
path=path,
|
||||||
request=request,
|
request=request,
|
||||||
body=body,
|
body=body,
|
||||||
|
request_json=request_json,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
idempotency_key=idempotency_key,
|
idempotency_key=idempotency_key,
|
||||||
task_id_jsonpath=submit_route.task_id_jsonpath,
|
task_id_jsonpath=submit_route.task_id_jsonpath,
|
||||||
route_frozen_amount=submit_route.frozen_amount,
|
route_frozen_amount=submit_route.frozen_amount,
|
||||||
route_frozen_type=submit_route.frozen_type,
|
route_frozen_type=submit_route.frozen_type,
|
||||||
|
route_frozen_token=submit_route.frozen_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_route:
|
if query_route:
|
||||||
|
|
@ -109,11 +111,13 @@ async def _handle_submit(
|
||||||
path: str,
|
path: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
body: bytes,
|
body: bytes,
|
||||||
|
request_json: dict[str, Any] | None,
|
||||||
thread_id: str | None,
|
thread_id: str | None,
|
||||||
idempotency_key: str | None,
|
idempotency_key: str | None,
|
||||||
task_id_jsonpath: str,
|
task_id_jsonpath: str,
|
||||||
route_frozen_amount: float | None,
|
route_frozen_amount: float | None,
|
||||||
route_frozen_type: int | None,
|
route_frozen_type: int | None,
|
||||||
|
route_frozen_token: int | None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
ledger = get_ledger()
|
ledger = get_ledger()
|
||||||
|
|
||||||
|
|
@ -129,6 +133,7 @@ async def _handle_submit(
|
||||||
# Reserve billing before touching the provider
|
# Reserve billing before touching the provider
|
||||||
reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount
|
reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount
|
||||||
reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type
|
reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type
|
||||||
|
reserve_frozen_token = route_frozen_token if route_frozen_token is not None else provider_config.frozen_token
|
||||||
frozen_id = await billing.reserve(
|
frozen_id = await billing.reserve(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
call_id=record.call_id,
|
call_id=record.call_id,
|
||||||
|
|
@ -136,9 +141,11 @@ async def _handle_submit(
|
||||||
operation=path,
|
operation=path,
|
||||||
frozen_amount=reserve_frozen_amount,
|
frozen_amount=reserve_frozen_amount,
|
||||||
frozen_type=reserve_frozen_type,
|
frozen_type=reserve_frozen_type,
|
||||||
|
frozen_token=reserve_frozen_token,
|
||||||
|
request_payload=request_json,
|
||||||
)
|
)
|
||||||
if frozen_id:
|
if frozen_id:
|
||||||
ledger.set_reserved(record.proxy_call_id, frozen_id)
|
ledger.set_reserved(record.proxy_call_id, frozen_id, reserve_frozen_type)
|
||||||
|
|
||||||
# Forward to provider
|
# Forward to provider
|
||||||
try:
|
try:
|
||||||
|
|
@ -156,6 +163,32 @@ async def _handle_submit(
|
||||||
|
|
||||||
resp_json = _try_parse_json(resp_body)
|
resp_json = _try_parse_json(resp_body)
|
||||||
|
|
||||||
|
if resp_json is None:
|
||||||
|
if frozen_id and reserve_frozen_type == 1:
|
||||||
|
usage_input_tokens, usage_output_tokens = _extract_usage_tokens_from_submit_stream(resp_body)
|
||||||
|
logger.debug(
|
||||||
|
"[ThirdPartyProxy] submit stream usage resolved: proxy_call_id=%s usage_input_tokens=%s usage_output_tokens=%s",
|
||||||
|
record.proxy_call_id,
|
||||||
|
usage_input_tokens,
|
||||||
|
usage_output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ledger.try_claim_finalize(record.proxy_call_id):
|
||||||
|
ok = await billing.finalize(
|
||||||
|
frozen_id=frozen_id,
|
||||||
|
final_amount=0.0,
|
||||||
|
finalize_reason="success",
|
||||||
|
usage_input_tokens=usage_input_tokens,
|
||||||
|
usage_output_tokens=usage_output_tokens,
|
||||||
|
)
|
||||||
|
if ok:
|
||||||
|
ledger.set_finalized(record.proxy_call_id, "SUCCESS")
|
||||||
|
else:
|
||||||
|
ledger.set_finalize_failed(record.proxy_call_id, "FAILED")
|
||||||
|
|
||||||
|
media_type = resp_headers.get("content-type")
|
||||||
|
return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type=media_type)
|
||||||
|
|
||||||
# HTTP-level failure
|
# HTTP-level failure
|
||||||
if status_code >= 400:
|
if status_code >= 400:
|
||||||
reason = f"error_http_{status_code}"
|
reason = f"error_http_{status_code}"
|
||||||
|
|
@ -272,17 +305,30 @@ async def _handle_query(
|
||||||
"[ThirdPartyProxy] finalize claimed: proxy_call_id=%s",
|
"[ThirdPartyProxy] finalize claimed: proxy_call_id=%s",
|
||||||
record.proxy_call_id,
|
record.proxy_call_id,
|
||||||
)
|
)
|
||||||
final_amount: float = 0.0
|
resolved_frozen_type = (
|
||||||
|
record.frozen_type if record.frozen_type is not None else provider_config.frozen_type
|
||||||
|
)
|
||||||
|
|
||||||
|
usage_input_tokens = 0
|
||||||
|
usage_output_tokens = 0
|
||||||
usage_paths = list(query_route.usage_jsonpaths or [])
|
usage_paths = list(query_route.usage_jsonpaths or [])
|
||||||
if not usage_paths and query_route.usage_jsonpath:
|
if not usage_paths and query_route.usage_jsonpath:
|
||||||
usage_paths = [query_route.usage_jsonpath]
|
usage_paths = [query_route.usage_jsonpath]
|
||||||
|
|
||||||
|
final_amount: float = 0.0
|
||||||
if is_success:
|
if is_success:
|
||||||
final_amount = _resolve_final_amount(resp_json, query_route)
|
if resolved_frozen_type == 1:
|
||||||
|
usage_input_tokens, usage_output_tokens = _extract_usage_tokens(resp_json)
|
||||||
|
else:
|
||||||
|
final_amount = _resolve_final_amount(resp_json, query_route)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s final_amount=%s usage_paths=%s legacy_path=%s",
|
"[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s frozen_type=%s final_amount=%s usage_input_tokens=%s usage_output_tokens=%s usage_paths=%s legacy_path=%s",
|
||||||
record.proxy_call_id,
|
record.proxy_call_id,
|
||||||
|
resolved_frozen_type,
|
||||||
final_amount,
|
final_amount,
|
||||||
|
usage_input_tokens,
|
||||||
|
usage_output_tokens,
|
||||||
usage_paths,
|
usage_paths,
|
||||||
query_route.usage_jsonpath,
|
query_route.usage_jsonpath,
|
||||||
)
|
)
|
||||||
|
|
@ -303,6 +349,8 @@ async def _handle_query(
|
||||||
frozen_id=record.frozen_id,
|
frozen_id=record.frozen_id,
|
||||||
final_amount=final_amount,
|
final_amount=final_amount,
|
||||||
finalize_reason=finalize_reason,
|
finalize_reason=finalize_reason,
|
||||||
|
usage_input_tokens=usage_input_tokens,
|
||||||
|
usage_output_tokens=usage_output_tokens,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s",
|
"[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s",
|
||||||
|
|
@ -415,6 +463,61 @@ def _resolve_final_amount(resp_json: dict[str, Any], query_route) -> float:
|
||||||
return total
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_usage_tokens(resp_json: dict[str, Any]) -> tuple[int, int]:
    """Return ``(input_tokens, output_tokens)`` read from a response's ``usage`` object.

    Each count is first looked up under the generic key (``input_tokens`` /
    ``output_tokens``). When that lookup yields 0 (missing, unparsable, or
    literally zero) the OpenAI-style key (``prompt_tokens`` /
    ``completion_tokens``) is used instead. Values are coerced with ``_as_int``.

    Args:
        resp_json: Parsed provider response. A value that is not a dict, or
            whose ``usage`` entry is not a dict, is treated as "no usage".

    Returns:
        A ``(input_tokens, output_tokens)`` tuple; ``(0, 0)`` when no usage
        information is available.
    """
    # A JSON body may legally decode to a list or scalar; without this guard
    # the ``.get`` below would raise AttributeError instead of reporting "no usage".
    if not isinstance(resp_json, dict):
        return 0, 0

    usage = resp_json.get("usage")
    if not isinstance(usage, dict):
        return 0, 0

    # ``or`` falls through to the OpenAI-style key only when the generic key
    # produced 0, matching the explicit ``== 0`` fallback this replaces.
    input_tokens = _as_int(usage.get("input_tokens")) or _as_int(usage.get("prompt_tokens"))
    output_tokens = _as_int(usage.get("output_tokens")) or _as_int(usage.get("completion_tokens"))
    return input_tokens, output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_usage_tokens_from_submit_stream(resp_body: bytes) -> tuple[int, int]:
    """Scan an SSE submit response and return the last non-zero usage it reports.

    Every line starting with ``data:`` is decoded as UTF-8 and parsed as JSON.
    Empty payloads, the ``[DONE]`` sentinel, unparsable payloads and non-object
    payloads are skipped. Each remaining object is passed to
    ``_extract_usage_tokens``; a later chunk replaces an earlier result only
    when at least one of its two counts is non-zero.

    Args:
        resp_body: Raw response bytes of the streamed submit call.

    Returns:
        ``(input_tokens, output_tokens)``; ``(0, 0)`` for an empty body or when
        no chunk carries usage.
    """
    latest_in = 0
    latest_out = 0
    if not resp_body:
        return latest_in, latest_out

    for raw in resp_body.splitlines():
        text = raw.decode("utf-8", errors="replace").strip()
        if not text.startswith("data:"):
            continue

        data_part = text[len("data:"):].strip()
        if data_part in ("", "[DONE]"):
            continue

        try:
            event = json.loads(data_part)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if not isinstance(event, dict):
            continue

        tokens_in, tokens_out = _extract_usage_tokens(event)
        if tokens_in or tokens_out:
            latest_in, latest_out = tokens_in, tokens_out

    return latest_in, latest_out
|
||||||
|
|
||||||
|
|
||||||
|
def _as_int(value: Any) -> int:
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value
|
||||||
|
if isinstance(value, float):
|
||||||
|
return int(value)
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return int(float(value))
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def _proxy_response(
|
def _proxy_response(
|
||||||
data: dict[str, Any],
|
data: dict[str, Any],
|
||||||
proxy_call_id: str | None,
|
proxy_call_id: str | None,
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -28,6 +29,8 @@ async def reserve(
|
||||||
operation: str,
|
operation: str,
|
||||||
frozen_amount: float,
|
frozen_amount: float,
|
||||||
frozen_type: int | None,
|
frozen_type: int | None,
|
||||||
|
frozen_token: int = 0,
|
||||||
|
request_payload: dict[str, Any] | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Reserve billing before forwarding a submit call.
|
"""Reserve billing before forwarding a submit call.
|
||||||
|
|
||||||
|
|
@ -44,19 +47,25 @@ async def reserve(
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
resolved_frozen_type = frozen_type if frozen_type is not None else cfg.frozen_type
|
||||||
expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
|
expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
|
||||||
payload = {
|
payload: dict[str, Any] = {
|
||||||
"sessionId": thread_id,
|
"sessionId": thread_id,
|
||||||
"callId": call_id,
|
"callId": call_id,
|
||||||
"modelName": provider,
|
"modelName": _extract_model_name(request_payload) or provider,
|
||||||
"question": f"skill invokes {operation.split('/')[-1]}",
|
"question": f"skill invokes {operation.split('/')[-1]}",
|
||||||
"frozenAmount": frozen_amount,
|
"frozenType": resolved_frozen_type,
|
||||||
"frozenType": frozen_type if frozen_type is not None else cfg.frozen_type,
|
|
||||||
"estimatedInputTokens": 0,
|
|
||||||
"estimatedOutputTokens": 0,
|
|
||||||
"expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"),
|
"expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resolved_frozen_type == 1:
|
||||||
|
payload["estimatedInputTokens"] = int(frozen_token)
|
||||||
|
payload["estimatedOutputTokens"] = int(frozen_token)
|
||||||
|
else:
|
||||||
|
payload["frozenAmount"] = frozen_amount
|
||||||
|
payload["estimatedInputTokens"] = 0
|
||||||
|
payload["estimatedOutputTokens"] = 0
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"[ThirdPartyProxy][Billing] reserve request: url=%s call_id=%s provider=%s thread_id=%s",
|
"[ThirdPartyProxy][Billing] reserve request: url=%s call_id=%s provider=%s thread_id=%s",
|
||||||
cfg.reserve_url,
|
cfg.reserve_url,
|
||||||
|
|
@ -114,6 +123,8 @@ async def finalize(
|
||||||
frozen_id: str,
|
frozen_id: str,
|
||||||
final_amount: float,
|
final_amount: float,
|
||||||
finalize_reason: str,
|
finalize_reason: str,
|
||||||
|
usage_input_tokens: int = 0,
|
||||||
|
usage_output_tokens: int = 0,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Finalize billing after a third-party call reaches a terminal state.
|
"""Finalize billing after a third-party call reaches a terminal state.
|
||||||
|
|
||||||
|
|
@ -135,9 +146,9 @@ async def finalize(
|
||||||
payload = {
|
payload = {
|
||||||
"frozenId": frozen_id,
|
"frozenId": frozen_id,
|
||||||
"finalAmount": final_amount,
|
"finalAmount": final_amount,
|
||||||
"usageInputTokens": 0,
|
"usageInputTokens": usage_input_tokens,
|
||||||
"usageOutputTokens": 0,
|
"usageOutputTokens": usage_output_tokens,
|
||||||
"usageTotalTokens": 0,
|
"usageTotalTokens": usage_input_tokens + usage_output_tokens,
|
||||||
"finalizeReason": finalize_reason,
|
"finalizeReason": finalize_reason,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -188,3 +199,12 @@ def _is_success(data: dict) -> bool:
|
||||||
if isinstance(status, int) and status in _SUCCESS_STATUS_CODES:
|
if isinstance(status, int) and status in _SUCCESS_STATUS_CODES:
|
||||||
return True
|
return True
|
||||||
return data.get("success") is True
|
return data.get("success") is True
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_model_name(request_payload: dict[str, Any] | None) -> str | None:
|
||||||
|
if not isinstance(request_payload, dict):
|
||||||
|
return None
|
||||||
|
model = request_payload.get("model")
|
||||||
|
if isinstance(model, str) and model:
|
||||||
|
return model
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ class CallRecord:
|
||||||
# call_id is sent to the billing platform (callId in reserve payload)
|
# call_id is sent to the billing platform (callId in reserve payload)
|
||||||
call_id: str
|
call_id: str
|
||||||
frozen_id: str | None = None
|
frozen_id: str | None = None
|
||||||
|
frozen_type: int | None = None
|
||||||
provider_task_id: str | None = None
|
provider_task_id: str | None = None
|
||||||
billing_state: BillingState = "UNRESERVED"
|
billing_state: BillingState = "UNRESERVED"
|
||||||
task_state: TaskState = "PENDING"
|
task_state: TaskState = "PENDING"
|
||||||
|
|
@ -109,16 +110,18 @@ class CallLedger:
|
||||||
def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None:
|
def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None:
|
||||||
return self._get_by_idem_key_locked(provider, idempotency_key)
|
return self._get_by_idem_key_locked(provider, idempotency_key)
|
||||||
|
|
||||||
def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None:
|
def set_reserved(self, proxy_call_id: str, frozen_id: str, frozen_type: int | None = None) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
record = self._records.get(proxy_call_id)
|
record = self._records.get(proxy_call_id)
|
||||||
if record:
|
if record:
|
||||||
record.frozen_id = frozen_id
|
record.frozen_id = frozen_id
|
||||||
|
record.frozen_type = frozen_type
|
||||||
record.billing_state = "RESERVED"
|
record.billing_state = "RESERVED"
|
||||||
logger.info(
|
logger.info(
|
||||||
"[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s",
|
"[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s frozen_type=%s",
|
||||||
proxy_call_id,
|
proxy_call_id,
|
||||||
frozen_id,
|
frozen_id,
|
||||||
|
frozen_type,
|
||||||
)
|
)
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# "[ThirdPartyProxy][Ledger] reserve state: call_id=%s provider=%s task_state=%s",
|
# "[ThirdPartyProxy][Ledger] reserve state: call_id=%s provider=%s task_state=%s",
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -17,16 +18,7 @@ from deerflow.config.third_party_proxy_config import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_SENSITIVE_HEADERS = frozenset(
|
API_KEY_MARKER = "__API_KEY_MARKER__"
|
||||||
[
|
|
||||||
"authorization",
|
|
||||||
"proxy-authorization",
|
|
||||||
"x-api-key",
|
|
||||||
"api-key",
|
|
||||||
"cookie",
|
|
||||||
"set-cookie",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Provider config lookup
|
# Provider config lookup
|
||||||
|
|
@ -154,17 +146,6 @@ _STRIP_RESPONSE_HEADERS = frozenset(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_headers(headers: dict[str, str]) -> dict[str, str]:
    """Build a copy of *headers* in which sensitive values are masked.

    A header is masked (value replaced by ``"***"``) when its lower-cased name
    is a member of ``_SENSITIVE_HEADERS``; every other header is copied as is.
    The input mapping is not modified.
    """
    return {
        name: "***" if name.lower() in _SENSITIVE_HEADERS else content
        for name, content in headers.items()
    }
|
|
||||||
|
|
||||||
|
|
||||||
def _preview_body(data: bytes, limit: int = 2048) -> str:
|
def _preview_body(data: bytes, limit: int = 2048) -> str:
|
||||||
"""Return a safe textual preview of body bytes for debugging logs."""
|
"""Return a safe textual preview of body bytes for debugging logs."""
|
||||||
if not data:
|
if not data:
|
||||||
|
|
@ -176,6 +157,53 @@ def _preview_body(data: bytes, limit: int = 2048) -> str:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_api_key_marker_in_headers(headers: dict[str, str], api_key: str) -> dict[str, str]:
    """Return a new header dict with ``API_KEY_MARKER`` swapped for *api_key*.

    Only string values that contain the marker are rewritten; all other values
    are carried over untouched. The input mapping is not modified.
    """

    def _fill(raw):
        # Leave non-string and marker-free values exactly as they were given.
        if isinstance(raw, str) and API_KEY_MARKER in raw:
            return raw.replace(API_KEY_MARKER, api_key)
        return raw

    return {name: _fill(raw) for name, raw in headers.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def _header_value(headers: dict[str, str], key: str) -> str | None:
|
||||||
|
target = key.lower()
|
||||||
|
for h_key, h_val in headers.items():
|
||||||
|
if h_key.lower() == target:
|
||||||
|
return h_val
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_api_key_marker_in_json(data: Any, api_key: str) -> Any:
    """Recursively substitute ``API_KEY_MARKER`` with *api_key* in a JSON-like value.

    Strings have every marker occurrence replaced. Lists and dicts are rebuilt
    with their items / values processed recursively (dict keys are kept as
    they are). Any other value is returned unchanged.
    """
    if isinstance(data, dict):
        return {
            name: _replace_api_key_marker_in_json(child, api_key)
            for name, child in data.items()
        }
    if isinstance(data, list):
        return [_replace_api_key_marker_in_json(child, api_key) for child in data]
    if isinstance(data, str):
        return data.replace(API_KEY_MARKER, api_key)
    return data
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_api_key_marker_in_body(headers: dict[str, str], body: bytes, api_key: str) -> bytes:
    """Replace the API key marker inside a JSON request body.

    The body is touched only when the ``content-type`` header (matched
    case-insensitively) contains ``application/json`` and the bytes parse as
    JSON. If the substitution changes nothing, the original bytes are returned
    as they were, so marker-free requests are forwarded byte-for-byte instead
    of being re-serialised.

    Args:
        headers: Request headers used to detect the content type.
        body: Raw request body.
        api_key: Value substituted for each marker occurrence.

    Returns:
        The original ``body`` when it is empty, not JSON, unparsable, or
        contains no marker; otherwise the compact UTF-8 JSON encoding of the
        payload with markers replaced.
    """
    if not body:
        return body

    content_type = _header_value(headers, "content-type") or ""
    if "application/json" not in content_type.lower():
        return body

    try:
        parsed = json.loads(body)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return body

    replaced = _replace_api_key_marker_in_json(parsed, api_key)
    # Re-encoding normalises whitespace, number formatting and duplicate keys;
    # avoid altering the client's payload when no marker was substituted.
    if replaced == parsed:
        return body

    return json.dumps(replaced, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
async def forward_request(
|
async def forward_request(
|
||||||
*,
|
*,
|
||||||
provider_config: ThirdPartyProviderConfig,
|
provider_config: ThirdPartyProviderConfig,
|
||||||
|
|
@ -202,6 +230,9 @@ async def forward_request(
|
||||||
if provider_config.api_key_env:
|
if provider_config.api_key_env:
|
||||||
api_key = os.getenv(provider_config.api_key_env)
|
api_key = os.getenv(provider_config.api_key_env)
|
||||||
if api_key:
|
if api_key:
|
||||||
|
# Dependency-injection style: replace marker placeholders first.
|
||||||
|
forward_headers = _replace_api_key_marker_in_headers(forward_headers, api_key)
|
||||||
|
body = _replace_api_key_marker_in_body(forward_headers, body, api_key)
|
||||||
forward_headers[provider_config.api_key_header] = provider_config.api_key_prefix + api_key
|
forward_headers[provider_config.api_key_header] = provider_config.api_key_prefix + api_key
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -212,7 +243,7 @@ async def forward_request(
|
||||||
logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
|
logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[ThirdPartyProxy] request headers=%s",
|
"[ThirdPartyProxy] request headers=%s",
|
||||||
_sanitize_headers(forward_headers)
|
forward_headers,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[ThirdPartyProxy] request body(%dB)=%s",
|
"[ThirdPartyProxy] request body(%dB)=%s",
|
||||||
|
|
@ -236,7 +267,7 @@ async def forward_request(
|
||||||
logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
|
logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[ThirdPartyProxy] response headers=%s",
|
"[ThirdPartyProxy] response headers=%s",
|
||||||
_sanitize_headers(response_headers)
|
response_headers,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[ThirdPartyProxy] response body(%dB)=%s",
|
"[ThirdPartyProxy] response body(%dB)=%s",
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,11 @@ class SubmitRouteConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional route-level override for billing reserve payload frozenType",
|
description="Optional route-level override for billing reserve payload frozenType",
|
||||||
)
|
)
|
||||||
|
frozen_token: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
ge=0,
|
||||||
|
description="Optional route-level override for billing reserve payload estimatedInputTokens/estimatedOutputTokens when frozenType=1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QueryRouteConfig(BaseModel):
|
class QueryRouteConfig(BaseModel):
|
||||||
|
|
@ -96,6 +101,11 @@ class ThirdPartyProviderConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Billing frozen type for this provider (frozenType). If omitted, falls back to billing.frozen_type",
|
description="Billing frozen type for this provider (frozenType). If omitted, falls back to billing.frozen_type",
|
||||||
)
|
)
|
||||||
|
frozen_token: int = Field(
|
||||||
|
default=0,
|
||||||
|
ge=0,
|
||||||
|
description="Estimated token amount used for reserve payload when frozenType=1",
|
||||||
|
)
|
||||||
submit_routes: list[SubmitRouteConfig] = Field(
|
submit_routes: list[SubmitRouteConfig] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Route patterns that identify submit (task-create) requests",
|
description="Route patterns that identify submit (task-create) requests",
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,16 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.gateway.third_party_proxy.ledger import CallLedger
|
from app.gateway.third_party_proxy.ledger import CallLedger
|
||||||
from app.gateway.routers.third_party import _resolve_final_amount
|
from app.gateway.routers.third_party import (
|
||||||
|
_extract_usage_tokens,
|
||||||
|
_extract_usage_tokens_from_submit_stream,
|
||||||
|
_resolve_final_amount,
|
||||||
|
)
|
||||||
from app.gateway.third_party_proxy.proxy import (
|
from app.gateway.third_party_proxy.proxy import (
|
||||||
|
API_KEY_MARKER,
|
||||||
_path_matches,
|
_path_matches,
|
||||||
|
_replace_api_key_marker_in_body,
|
||||||
|
_replace_api_key_marker_in_headers,
|
||||||
jsonpath_get,
|
jsonpath_get,
|
||||||
match_query_route,
|
match_query_route,
|
||||||
match_submit_route,
|
match_submit_route,
|
||||||
|
|
@ -225,3 +232,61 @@ class TestResolveFinalAmount:
|
||||||
resp_json = {"usage": {"thirdPartyConsumeMoney": "1.5"}}
|
resp_json = {"usage": {"thirdPartyConsumeMoney": "1.5"}}
|
||||||
amount = _resolve_final_amount(resp_json, route)
|
amount = _resolve_final_amount(resp_json, route)
|
||||||
assert amount == 1.5
|
assert amount == 1.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractUsageTokens:
    """Checks for ``_extract_usage_tokens`` on non-streaming responses."""

    def test_prefers_openai_usage_keys(self):
        """OpenAI-style ``prompt_tokens`` / ``completion_tokens`` are read from ``usage``."""
        usage = {
            "prompt_tokens": 123,
            "completion_tokens": 45,
        }
        got_in, got_out = _extract_usage_tokens({"usage": usage})
        assert (got_in, got_out) == (123, 45)

    def test_supports_generic_usage_keys(self):
        """Generic ``input_tokens`` / ``output_tokens`` given as numeric strings are converted."""
        usage = {
            "input_tokens": "88",
            "output_tokens": "12",
        }
        got_in, got_out = _extract_usage_tokens({"usage": usage})
        assert (got_in, got_out) == (88, 12)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractUsageTokensFromSubmitStream:
    """Checks for ``_extract_usage_tokens_from_submit_stream`` on SSE bodies."""

    def test_extracts_usage_from_final_sse_chunk(self):
        """Usage carried by the last data chunk before ``[DONE]`` is returned."""
        chunks = [
            b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n',
            b'data: {"id":"x","choices":[],"usage":{"prompt_tokens":22,"completion_tokens":17}}\n\n',
            b'data: [DONE]\n\n',
        ]
        got_in, got_out = _extract_usage_tokens_from_submit_stream(b"".join(chunks))
        assert (got_in, got_out) == (22, 17)

    def test_returns_zero_when_no_usage_found(self):
        """A stream without any ``usage`` object yields zero for both counts."""
        stream = b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n'
        got_in, got_out = _extract_usage_tokens_from_submit_stream(stream)
        assert (got_in, got_out) == (0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiKeyMarkerReplacement:
    """Checks that the API key marker is substituted in headers and JSON bodies."""

    def test_replace_marker_in_headers(self):
        """A marker embedded in a header value is replaced by the real key."""
        original = {
            "Authorization": f"Bearer {API_KEY_MARKER}",
            "Content-Type": "application/json",
        }
        result = _replace_api_key_marker_in_headers(original, "real-key")
        assert result["Authorization"] == "Bearer real-key"

    def test_replace_marker_in_json_body(self):
        """Markers at the top level and in nested objects of a JSON body are replaced."""
        json_headers = {"Content-Type": "application/json"}
        payload = b'{"apiKey":"__API_KEY_MARKER__","nested":{"token":"Bearer __API_KEY_MARKER__"}}'
        result = _replace_api_key_marker_in_body(json_headers, payload, "real-key")
        assert b'"apiKey":"real-key"' in result
        assert b'"token":"Bearer real-key"' in result
|
||||||
|
|
|
||||||
|
|
@ -105,6 +105,7 @@ third_party_proxy:
|
||||||
api_key_header: Authorization
|
api_key_header: Authorization
|
||||||
api_key_prefix: "Bearer "
|
api_key_prefix: "Bearer "
|
||||||
timeout_seconds: 60.0
|
timeout_seconds: 60.0
|
||||||
|
frozen_token: 32768
|
||||||
submit_routes:
|
submit_routes:
|
||||||
- path_pattern: "/compatible-mode/v1/chat/completions"
|
- path_pattern: "/compatible-mode/v1/chat/completions"
|
||||||
task_id_jsonpath: "id"
|
task_id_jsonpath: "id"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue