deerflow2/backend/app/gateway/routers/third_party.py

531 lines
20 KiB
Python

"""Universal third-party API proxy router with integrated billing.

Endpoint: GET|POST|PUT|DELETE|PATCH /api/proxy/{provider}/{path...}

The caller (a sandbox skill script) should set these request headers:
    X-Thread-Id: <thread_id> — used for the billing reservation (injected via the THREAD_ID env var)
    X-Idempotency-Key: <uuid> — optional; deduplicates submit calls

The gateway automatically:
    1. Injects the provider's API key from the configured env var.
    2. For *submit* routes: reserves billing, forwards, records task state.
    3. For *query* routes: forwards, detects terminal status, finalizes billing once.
    4. For all other routes: transparent passthrough, no billing side-effects.
"""
from __future__ import annotations

import json
import logging
import math
from typing import Any

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, Response

from app.gateway.third_party_proxy import billing, proxy
from app.gateway.third_party_proxy.ledger import CallRecord, get_ledger
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/proxy", tags=["third-party-proxy"])
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
@router.api_route("/{provider}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(provider: str, path: str, request: Request) -> Response:
    """Dispatch a proxied third-party call to the submit, query, or passthrough handler.

    The provider configuration is looked up first. The request is then
    classified by matching its method and normalised path against the
    provider's submit routes and query routes. A submit match takes precedence
    over a query match; a request matching neither is forwarded unchanged with
    no billing or ledger activity in this module.

    Args:
        provider: Provider name taken from the URL.
        path: Remainder of the URL after the provider segment.
        request: Incoming request; its method, headers, query string and body
            are handed to the selected handler.

    Returns:
        The response produced by the selected handler.

    Raises:
        HTTPException: 404 when no configuration is returned for ``provider``.
    """
    provider_config = proxy.get_provider_config(provider)
    if provider_config is None:
        raise HTTPException(
            status_code=404,
            detail=f"Provider '{provider}' is not configured or the proxy is disabled.",
        )

    method = request.method
    # Force exactly one leading slash so patterns like /openapi/v2/** match.
    path = "/" + path.lstrip("/")

    thread_id = request.headers.get("x-thread-id")
    idempotency_key = request.headers.get("x-idempotency-key")

    body = await request.body()
    request_json: dict[str, Any] | None = _try_parse_json(body)

    submit_route = proxy.match_submit_route(provider_config, method, path)
    query_route = proxy.match_query_route(provider_config, method, path)

    if submit_route:
        route_kind = "submit"
    elif query_route:
        route_kind = "query"
    else:
        route_kind = "passthrough"
    logger.info(
        "[ThirdPartyProxy] route=%s provider=%s method=%s path=%s",
        route_kind,
        provider,
        method,
        path,
    )

    # Arguments every handler needs, gathered once.
    forward_args = {
        "provider_config": provider_config,
        "method": method,
        "path": path,
        "request": request,
        "body": body,
    }

    if submit_route:
        return await _handle_submit(
            provider=provider,
            request_json=request_json,
            thread_id=thread_id,
            idempotency_key=idempotency_key,
            task_id_jsonpath=submit_route.task_id_jsonpath,
            route_frozen_amount=submit_route.frozen_amount,
            route_frozen_type=submit_route.frozen_type,
            route_frozen_token=submit_route.frozen_token,
            **forward_args,
        )

    if query_route:
        return await _handle_query(
            provider=provider,
            request_json=request_json,
            query_route=query_route,
            **forward_args,
        )

    # Neither submit nor query: forward as-is.
    return await _passthrough(**forward_args)
# ---------------------------------------------------------------------------
# Submit handler
# ---------------------------------------------------------------------------
async def _handle_submit(
    *,
    provider: str,
    provider_config,
    method: str,
    path: str,
    request: Request,
    body: bytes,
    request_json: dict[str, Any] | None,
    thread_id: str | None,
    idempotency_key: str | None,
    task_id_jsonpath: str,
    route_frozen_amount: float | None,
    route_frozen_type: int | None,
    route_frozen_token: int | None,
) -> Response:
    """Reserve billing, forward a submit call, and record the resulting task state.

    Flow:
      1. When an idempotency key is supplied and the ledger already holds a
         record with a stored response for it, that stored response is
         returned and the provider is not contacted.
      2. A ledger record is created and billing is reserved. Route-level
         ``frozen_*`` values win over the provider-level ones when set.
      3. The request is forwarded to the provider.
      4. The reservation is settled or the task is marked running:
         - forwarding raised: released at zero, HTTP 502 raised;
         - non-JSON body with status >= 400: released at zero with reason
           ``error_http_<status>``;
         - non-JSON body, success status, frozen type 1: finalized with the
           token usage parsed from the SSE body;
         - non-JSON body, success status, any other frozen type: released at
           zero with reason ``no_task_id`` because no task id exists for a
           later query to settle it;
         - JSON body with status >= 400: released at zero;
         - JSON body without a task id: released at zero with the provider's
           ``errorCode``/``code`` (or ``no_task_id``) as the reason;
         - JSON body with a task id: ledger record marked running.

    Returns:
        The provider's raw response for non-JSON bodies and HTTP errors,
        otherwise a JSON response carrying ``X-Proxy-Call-Id``.

    Raises:
        HTTPException: 502 when forwarding to the provider raises.
    """
    ledger = get_ledger()

    # Idempotency: replay the stored response for an already-handled submit.
    if idempotency_key:
        existing = ledger.get_by_idempotency_key(provider, idempotency_key)
        if existing is not None and existing.last_response is not None:
            logger.info("[ThirdPartyProxy] idempotent submit: proxy_call_id=%s", existing.proxy_call_id)
            return _proxy_response(existing.last_response, existing.proxy_call_id)

    record = ledger.create(provider, thread_id, idempotency_key)

    # Reserve billing before touching the provider.
    reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount
    reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type
    reserve_frozen_token = route_frozen_token if route_frozen_token is not None else provider_config.frozen_token
    frozen_id = await billing.reserve(
        thread_id=thread_id,
        # NOTE(review): every other access uses record.proxy_call_id; confirm
        # that CallRecord really exposes a separate `call_id` attribute.
        call_id=record.call_id,
        provider=provider,
        operation=path,
        frozen_amount=reserve_frozen_amount,
        frozen_type=reserve_frozen_type,
        frozen_token=reserve_frozen_token,
        request_payload=request_json,
    )
    if frozen_id:
        ledger.set_reserved(record.proxy_call_id, frozen_id, reserve_frozen_type)

    # Forward to provider.
    try:
        status_code, resp_headers, resp_body = await proxy.forward_request(
            provider_config=provider_config,
            method=method,
            path=path,
            headers=dict(request.headers),
            body=body,
            query_params=str(request.query_params),
        )
    except Exception as exc:
        await _finalize_zero(frozen_id, record.proxy_call_id, "error exception")
        raise HTTPException(status_code=502, detail=f"Provider unreachable: {exc}") from exc

    resp_json = _try_parse_json(resp_body)

    if resp_json is None:
        # Body is not a JSON object (stream, HTML error page, empty body, ...).
        if status_code >= 400:
            # A provider error must never be settled as a successful call.
            await _finalize_zero(frozen_id, record.proxy_call_id, f"error_http_{status_code}")
        elif frozen_id and reserve_frozen_type == 1:
            usage_input_tokens, usage_output_tokens = _extract_usage_tokens_from_submit_stream(resp_body)
            logger.debug(
                "[ThirdPartyProxy] submit stream usage resolved: proxy_call_id=%s usage_input_tokens=%s usage_output_tokens=%s",
                record.proxy_call_id,
                usage_input_tokens,
                usage_output_tokens,
            )
            if ledger.try_claim_finalize(record.proxy_call_id):
                ok = await billing.finalize(
                    frozen_id=frozen_id,
                    final_amount=0.0,
                    finalize_reason="success",
                    usage_input_tokens=usage_input_tokens,
                    usage_output_tokens=usage_output_tokens,
                )
                if ok:
                    ledger.set_finalized(record.proxy_call_id, "SUCCESS")
                else:
                    ledger.set_finalize_failed(record.proxy_call_id, "FAILED")
        else:
            # No task id can be extracted from a non-JSON body, so no later
            # query could settle this reservation; release it now, as the JSON
            # path below does for responses without a task id.
            await _finalize_zero(frozen_id, record.proxy_call_id, "no_task_id")
        media_type = resp_headers.get("content-type")
        return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type=media_type)

    # From here on resp_json is always a dict.

    # HTTP-level failure.
    if status_code >= 400:
        await _finalize_zero(frozen_id, record.proxy_call_id, f"error_http_{status_code}")
        ledger.update_response(record.proxy_call_id, resp_json)
        return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json")

    # Extract the task id; its absence means the provider rejected the call
    # at business level.
    raw_task_id = proxy.jsonpath_get(resp_json, task_id_jsonpath)
    provider_task_id: str | None = str(raw_task_id) if raw_task_id is not None else None

    if provider_task_id:
        ledger.set_running(record.proxy_call_id, provider_task_id)
    else:
        # Propagate the provider's error code (if present) as the reason.
        raw_error_code = resp_json.get("errorCode")
        if raw_error_code is None:
            raw_error_code = resp_json.get("code")
        error_code = str(raw_error_code) if raw_error_code is not None else None
        finalize_reason = error_code or "no_task_id"
        await _finalize_zero(frozen_id, record.proxy_call_id, finalize_reason)

    ledger.update_response(record.proxy_call_id, resp_json)
    return _proxy_response(resp_json, record.proxy_call_id, status_code, resp_headers)
# ---------------------------------------------------------------------------
# Query handler
# ---------------------------------------------------------------------------
async def _handle_query(
    *,
    provider: str,
    provider_config,
    method: str,
    path: str,
    request: Request,
    body: bytes,
    request_json: dict[str, Any] | None,
    query_route,
) -> Response:
    """Forward a status query and settle billing once the task is terminal.

    The provider task id is read from the JSON request body via
    ``query_route.request_task_id_jsonpath`` and used to look up the ledger
    record. If that record is already finalized and has a stored response,
    the stored response is returned without contacting the provider.

    Otherwise the query is forwarded. When the status found at
    ``query_route.status_jsonpath`` is one of the route's success or failure
    values and a ledger record exists, the handler tries to claim the
    finalize step. If the claim succeeds it calls billing (when a reservation
    exists) and stores the response in the ledger.

    Returns:
        The raw provider response for HTTP errors and non-JSON bodies,
        otherwise a JSON response (with ``X-Proxy-Call-Id`` when a record
        was found).

    Raises:
        HTTPException: 502 when forwarding to the provider raises.
    """
    ledger = get_ledger()

    # Find the ledger record via the task id carried in the request body.
    provider_task_id: str | None = None
    if request_json:
        task_id_value = proxy.jsonpath_get(request_json, query_route.request_task_id_jsonpath)
        if task_id_value is not None:
            provider_task_id = str(task_id_value)

    record: CallRecord | None = None
    if provider_task_id:
        record = ledger.get_by_task_id(provider, provider_task_id)

    # Terminal state already recorded: answer from the ledger.
    if record is not None and ledger.is_finalized(record.proxy_call_id) and record.last_response is not None:
        logger.info("[ThirdPartyProxy] query already finalized, returning cache: proxy_call_id=%s", record.proxy_call_id)
        return _proxy_response(record.last_response, record.proxy_call_id)

    # Ask the provider.
    try:
        status_code, resp_headers, resp_body = await proxy.forward_request(
            provider_config=provider_config,
            method=method,
            path=path,
            headers=dict(request.headers),
            body=body,
            query_params=str(request.query_params),
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Provider query failed: {exc}") from exc

    resp_json = _try_parse_json(resp_body)
    if resp_json is None or status_code >= 400:
        return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json")

    # Classify the reported status.
    status_value = proxy.jsonpath_get(resp_json, query_route.status_jsonpath)
    status_str = None if status_value is None else str(status_value)
    is_success = status_str in query_route.success_values
    is_failure = status_str in query_route.failure_values
    logger.debug(
        "[ThirdPartyProxy] query terminal check: provider=%s task_id=%s status=%s is_success=%s is_failure=%s",
        provider,
        provider_task_id,
        status_str,
        is_success,
        is_failure,
    )

    reached_terminal = is_success or is_failure
    if record is not None and reached_terminal:
        logger.info(
            "[ThirdPartyProxy] finalize candidate: proxy_call_id=%s provider_task_id=%s terminal_status=%s",
            record.proxy_call_id,
            provider_task_id,
            status_str,
        )
        # Only the caller whose claim succeeds performs the finalize step.
        if not ledger.try_claim_finalize(record.proxy_call_id):
            logger.info(
                "[ThirdPartyProxy] finalize claim denied (already processed): proxy_call_id=%s",
                record.proxy_call_id,
            )
        else:
            logger.info(
                "[ThirdPartyProxy] finalize claimed: proxy_call_id=%s",
                record.proxy_call_id,
            )

            # The record's frozen type wins; fall back to the provider's.
            resolved_frozen_type = record.frozen_type
            if resolved_frozen_type is None:
                resolved_frozen_type = provider_config.frozen_type

            # Only used for the debug log below.
            usage_paths = list(query_route.usage_jsonpaths or [])
            if not usage_paths and query_route.usage_jsonpath:
                usage_paths = [query_route.usage_jsonpath]

            final_amount: float = 0.0
            usage_input_tokens, usage_output_tokens = 0, 0
            if is_success and resolved_frozen_type == 1:
                usage_input_tokens, usage_output_tokens = _extract_usage_tokens(resp_json)
            elif is_success:
                final_amount = _resolve_final_amount(resp_json, query_route)

            logger.debug(
                "[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s frozen_type=%s final_amount=%s usage_input_tokens=%s usage_output_tokens=%s usage_paths=%s legacy_path=%s",
                record.proxy_call_id,
                resolved_frozen_type,
                final_amount,
                usage_input_tokens,
                usage_output_tokens,
                usage_paths,
                query_route.usage_jsonpath,
            )

            if is_success:
                task_state, finalize_reason = "SUCCESS", "success"
            else:
                task_state, finalize_reason = "FAILED", "error"

            frozen_id = record.frozen_id
            logger.info(
                "[ThirdPartyProxy] finalize start: proxy_call_id=%s reason=%s task_state=%s has_frozen_id=%s",
                record.proxy_call_id,
                finalize_reason,
                task_state,
                bool(frozen_id),
            )

            if not frozen_id:
                logger.info(
                    "[ThirdPartyProxy] finalize skipped billing call (no frozen_id): proxy_call_id=%s",
                    record.proxy_call_id,
                )
                ledger.set_finalized(record.proxy_call_id, task_state)
            else:
                ok = await billing.finalize(
                    frozen_id=frozen_id,
                    final_amount=final_amount,
                    finalize_reason=finalize_reason,
                    usage_input_tokens=usage_input_tokens,
                    usage_output_tokens=usage_output_tokens,
                )
                logger.info(
                    "[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s",
                    record.proxy_call_id,
                    ok,
                )
                if ok:
                    ledger.set_finalized(record.proxy_call_id, task_state)
                else:
                    ledger.set_finalize_failed(record.proxy_call_id, task_state)

            # Store the terminal response for later cache hits.
            ledger.update_response(record.proxy_call_id, resp_json)

    proxy_call_id = record.proxy_call_id if record else None
    return _proxy_response(resp_json, proxy_call_id, status_code, resp_headers)
# ---------------------------------------------------------------------------
# Passthrough handler
# ---------------------------------------------------------------------------
async def _passthrough(*, provider_config, method: str, path: str, request: Request, body: bytes) -> Response:
    """Forward the request to the provider and relay its response unchanged.

    No ledger or billing call is made here.

    Raises:
        HTTPException: 502 when forwarding to the provider raises.
    """
    try:
        status_code, resp_headers, resp_body = await proxy.forward_request(
            provider_config=provider_config,
            method=method,
            path=path,
            headers=dict(request.headers),
            body=body,
            query_params=str(request.query_params),
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Provider request failed: {exc}") from exc
    else:
        return Response(
            content=resp_body,
            status_code=status_code,
            headers=resp_headers,
        )
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _finalize_zero(frozen_id: str | None, proxy_call_id: str, reason: str) -> None:
    """Settle a reservation at amount 0 and record the outcome in the ledger.

    Nothing is sent to billing when ``frozen_id`` is empty or when the
    finalize claim for ``proxy_call_id`` is not granted. The ledger state is
    ``"SUCCESS"`` only when ``reason`` equals ``"success"``; every other
    reason is recorded as ``"FAILED"``.
    """
    ledger = get_ledger()
    logger.info(
        "[ThirdPartyProxy] finalize_zero requested: proxy_call_id=%s reason=%s has_frozen_id=%s",
        proxy_call_id,
        reason,
        bool(frozen_id),
    )

    # Nothing was reserved, so there is nothing to release.
    if not frozen_id:
        logger.debug("[ThirdPartyProxy] finalize_zero skipped: no frozen_id proxy_call_id=%s", proxy_call_id)
        return

    # The finalize claim was not granted for this call.
    if not ledger.try_claim_finalize(proxy_call_id):
        logger.info("[ThirdPartyProxy] finalize_zero claim denied: proxy_call_id=%s", proxy_call_id)
        return

    logger.info("[ThirdPartyProxy] finalize_zero claimed: proxy_call_id=%s", proxy_call_id)
    ok = await billing.finalize(frozen_id=frozen_id, final_amount=0, finalize_reason=reason)
    logger.info("[ThirdPartyProxy] finalize_zero result: proxy_call_id=%s ok=%s", proxy_call_id, ok)

    if reason == "success":
        task_state = "SUCCESS"
    else:
        task_state = "FAILED"

    record_outcome = ledger.set_finalized if ok else ledger.set_finalize_failed
    record_outcome(proxy_call_id, task_state)
def _try_parse_json(data: bytes) -> dict[str, Any] | None:
if not data:
return None
try:
parsed = json.loads(data)
return parsed if isinstance(parsed, dict) else None
except (json.JSONDecodeError, ValueError):
return None
def _resolve_final_amount(resp_json: dict[str, Any], query_route) -> float:
    """Resolve the final billing amount from the configured usage paths.

    Priority:
      1) ``usage_jsonpaths`` (sum of all usable numeric values)
      2) legacy ``usage_jsonpath`` (single value)

    Values that are missing, non-numeric, too large for a float, or
    non-finite (NaN / Infinity) are skipped so one bad field cannot turn the
    whole amount into NaN or Infinity.

    Args:
        resp_json: Decoded provider response.
        query_route: Route config providing ``usage_jsonpaths`` and
            ``usage_jsonpath``.

    Returns:
        The summed amount, ``0.0`` when nothing usable was found.
    """
    usage_paths = list(query_route.usage_jsonpaths or [])
    if not usage_paths and query_route.usage_jsonpath:
        usage_paths = [query_route.usage_jsonpath]

    total = 0.0
    for path in usage_paths:
        raw = proxy.jsonpath_get(resp_json, path)
        if raw is None:
            continue
        try:
            amount = float(raw)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: an int too large to be represented as a float.
            continue
        # json.loads accepts NaN/Infinity literals, and float() accepts the
        # strings "nan"/"inf"; neither is a billable amount.
        if not math.isfinite(amount):
            continue
        total += amount
    return total
def _extract_usage_tokens(resp_json: dict[str, Any]) -> tuple[int, int]:
    """Read ``(input, output)`` token counts from the ``usage`` object.

    ``input_tokens`` / ``output_tokens`` are preferred. When one of them
    resolves to zero, ``prompt_tokens`` / ``completion_tokens`` is used for
    that side instead. Returns ``(0, 0)`` when ``usage`` is not a dict.
    """
    usage = resp_json.get("usage")
    if not isinstance(usage, dict):
        return 0, 0

    # `or` falls through to the alternative key when the first count is 0.
    prompt_count = _as_int(usage.get("input_tokens")) or _as_int(usage.get("prompt_tokens"))
    completion_count = _as_int(usage.get("output_tokens")) or _as_int(usage.get("completion_tokens"))
    return prompt_count, completion_count
def _extract_usage_tokens_from_submit_stream(resp_body: bytes) -> tuple[int, int]:
    """Return the token usage reported by the last usage-bearing SSE event.

    Every ``data:`` line of the body is inspected in order. Lines that are
    empty, ``[DONE]``, not valid JSON, or not a JSON object are ignored. The
    counts from the last event that reports a non-zero input or output count
    are returned; ``(0, 0)`` when no event does.
    """
    latest: tuple[int, int] = (0, 0)
    if not resp_body:
        return latest

    prefix = "data:"
    for raw_line in resp_body.splitlines():
        text = raw_line.decode("utf-8", errors="replace").strip()
        if not text.startswith(prefix):
            continue

        payload_text = text[len(prefix):].strip()
        if payload_text in ("", "[DONE]"):
            continue

        try:
            event = json.loads(payload_text)
        except (json.JSONDecodeError, ValueError):
            continue
        if not isinstance(event, dict):
            continue

        prompt_count, completion_count = _extract_usage_tokens(event)
        if prompt_count or completion_count:
            # Later events overwrite earlier ones.
            latest = (prompt_count, completion_count)

    return latest
def _as_int(value: Any) -> int:
if isinstance(value, int):
return value
if isinstance(value, float):
return int(value)
if isinstance(value, str):
try:
return int(float(value))
except ValueError:
return 0
return 0
def _proxy_response(
    data: dict[str, Any],
    proxy_call_id: str | None,
    status_code: int = 200,
    extra_headers: dict[str, str] | None = None,
) -> JSONResponse:
    """Build the JSON response returned to the caller.

    Args:
        data: JSON-serialisable payload; it is re-serialised by JSONResponse.
        proxy_call_id: When truthy, sent back as ``X-Proxy-Call-Id``.
        status_code: HTTP status of the response.
        extra_headers: Headers to copy onto the response (typically the
            provider's response headers).

    Returns:
        A ``JSONResponse`` carrying ``data`` and the filtered headers.
    """
    # These headers describe the upstream body. The body sent here is
    # re-serialised, so its length and encoding can differ; a copied
    # Content-Length would also stop the response class from computing its own.
    stale_body_headers = ("content-length", "content-encoding", "transfer-encoding")
    headers: dict[str, str] = {
        name: value
        for name, value in dict(extra_headers or {}).items()
        if name.lower() not in stale_body_headers
    }
    if proxy_call_id:
        headers["X-Proxy-Call-Id"] = proxy_call_id
    return JSONResponse(content=data, status_code=status_code, headers=headers)