428 lines
16 KiB
Python
428 lines
16 KiB
Python
"""Universal third-party API proxy router with integrated billing.
|
|
|
|
Endpoint: ANY /api/proxy/{provider}/{path...}
|
|
|
|
The caller (a sandbox skill script) should set:
|
|
X-Thread-Id: <thread_id> — used for billing reservation (injected via THREAD_ID env var)
|
|
X-Idempotency-Key: <uuid> — optional; deduplicates submit calls
|
|
|
|
The gateway automatically:
|
|
1. Injects the provider's API key from the configured env var.
|
|
2. For *submit* routes: reserves billing, forwards, records task state.
|
|
3. For *query* routes: forwards, detects terminal status, finalizes billing once.
|
|
4. For all other routes: transparent passthrough, no billing side-effects.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from fastapi.responses import JSONResponse, Response
|
|
|
|
from app.gateway.third_party_proxy import billing, proxy
|
|
from app.gateway.third_party_proxy.ledger import CallRecord, get_ledger
|
|
|
|
# Module logger; every message emitted in this module is prefixed "[ThirdPartyProxy]".
logger = logging.getLogger(__name__)

# All routes defined below are mounted under /api/proxy.
router = APIRouter(prefix="/api/proxy", tags=["third-party-proxy"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@router.api_route("/{provider}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(provider: str, path: str, request: Request) -> Response:
    """Route a proxied third-party call to the submit, query, or passthrough handler.

    Raises a 404 ``HTTPException`` when ``proxy.get_provider_config`` has no
    configuration for *provider*. The request is classified against the
    provider's submit and query route tables; a submit match takes precedence
    over a query match, and anything unmatched is forwarded without billing.
    """
    cfg = proxy.get_provider_config(provider)
    if cfg is None:
        raise HTTPException(
            status_code=404,
            detail=f"Provider '{provider}' is not configured or the proxy is disabled.",
        )

    http_method = request.method
    # Force exactly one leading slash so patterns such as /openapi/v2/** match.
    normalized_path = "/" + path.lstrip("/")

    incoming = request.headers
    thread_id = incoming.get("x-thread-id")
    idempotency_key = incoming.get("x-idempotency-key")

    raw_body = await request.body()
    parsed_body: dict[str, Any] | None = _try_parse_json(raw_body)

    submit_match = proxy.match_submit_route(cfg, http_method, normalized_path)
    query_match = proxy.match_query_route(cfg, http_method, normalized_path)

    if submit_match:
        route_kind = "submit"
    elif query_match:
        route_kind = "query"
    else:
        route_kind = "passthrough"
    logger.info(
        "[ThirdPartyProxy] route=%s provider=%s method=%s path=%s",
        route_kind,
        provider,
        http_method,
        normalized_path,
    )

    if submit_match:
        return await _handle_submit(
            provider=provider,
            provider_config=cfg,
            method=http_method,
            path=normalized_path,
            request=request,
            body=raw_body,
            thread_id=thread_id,
            idempotency_key=idempotency_key,
            task_id_jsonpath=submit_match.task_id_jsonpath,
            route_frozen_amount=submit_match.frozen_amount,
            route_frozen_type=submit_match.frozen_type,
        )

    if query_match:
        return await _handle_query(
            provider=provider,
            provider_config=cfg,
            method=http_method,
            path=normalized_path,
            request=request,
            body=raw_body,
            request_json=parsed_body,
            query_route=query_match,
        )

    # No billed route matched: forward transparently with no ledger/billing side effects.
    return await _passthrough(
        provider_config=cfg,
        method=http_method,
        path=normalized_path,
        request=request,
        body=raw_body,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Submit handler
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _handle_submit(
    *,
    provider: str,
    provider_config,
    method: str,
    path: str,
    request: Request,
    body: bytes,
    thread_id: str | None,
    idempotency_key: str | None,
    task_id_jsonpath: str,
    route_frozen_amount: float | None,
    route_frozen_type: int | None,
) -> Response:
    """Handle a *submit* route: reserve billing, forward, and record the provider task.

    Flow:
      1. When ``idempotency_key`` matches a ledger record that already holds a
         cached response, that response is replayed and the provider is not called.
      2. A new ledger record is created and billing is reserved. The route-level
         ``route_frozen_amount`` / ``route_frozen_type`` override the provider
         config defaults when they are not ``None``.
      3. The request is forwarded. A transport exception, an HTTP status >= 400,
         or a 2xx body with no value at ``task_id_jsonpath`` each release the
         reservation through ``_finalize_zero``.
      4. Otherwise the provider task id is stored via ``ledger.set_running`` so a
         later query route can look the record up and finalize billing.

    Args:
        provider: Provider name from the URL.
        provider_config: Provider configuration from ``proxy.get_provider_config``.
        method: HTTP method of the incoming request.
        path: Normalized provider path (always starts with "/").
        request: Incoming request; its headers and query string are forwarded.
        body: Raw request body, forwarded unchanged.
        thread_id: Value of the X-Thread-Id header, passed to ``billing.reserve``.
        idempotency_key: Value of the X-Idempotency-Key header, if any.
        task_id_jsonpath: Path to the provider task id in the response JSON.
        route_frozen_amount: Optional per-route reservation amount override.
        route_frozen_type: Optional per-route reservation type override.

    Returns:
        For HTTP >= 400, the provider's raw body, status and headers. Otherwise
        a JSON response with the parsed body (``{}`` when the body was not a JSON
        object) and an ``X-Proxy-Call-Id`` header.

    Raises:
        HTTPException: 502 when ``proxy.forward_request`` raises.
    """
    ledger = get_ledger()

    # Idempotency: if we've already handled this exact submit, return the cached response
    # NOTE(review): a record found for this key but without last_response falls through
    # and a second record is created with the same key — confirm ledger.create tolerates that.
    if idempotency_key:
        existing = ledger.get_by_idempotency_key(provider, idempotency_key)
        if existing is not None and existing.last_response is not None:
            logger.info("[ThirdPartyProxy] idempotent submit: proxy_call_id=%s", existing.proxy_call_id)
            return _proxy_response(existing.last_response, existing.proxy_call_id)

    record = ledger.create(provider, thread_id, idempotency_key)

    # Reserve billing before touching the provider
    reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount
    reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type
    # NOTE(review): the reservation is keyed by record.call_id while every ledger call
    # below uses record.proxy_call_id — confirm both attributes exist and this is intended.
    frozen_id = await billing.reserve(
        thread_id=thread_id,
        call_id=record.call_id,
        provider=provider,
        operation=path,
        frozen_amount=reserve_frozen_amount,
        frozen_type=reserve_frozen_type,
    )
    # A falsy frozen_id means no reservation was recorded (presumably e.g. when
    # thread_id is missing — TODO confirm); _finalize_zero then skips billing.
    if frozen_id:
        ledger.set_reserved(record.proxy_call_id, frozen_id)

    # Forward to provider
    try:
        status_code, resp_headers, resp_body = await proxy.forward_request(
            provider_config=provider_config,
            method=method,
            path=path,
            headers=dict(request.headers),
            body=body,
            query_params=str(request.query_params),
        )
    except Exception as exc:
        # NOTE(review): this reason contains a space, unlike the underscore-style
        # reasons used elsewhere ("error_http_<code>", "no_task_id").
        await _finalize_zero(frozen_id, record.proxy_call_id, "error exception")
        raise HTTPException(status_code=502, detail=f"Provider unreachable: {exc}") from exc

    resp_json = _try_parse_json(resp_body)

    # HTTP-level failure: release the reservation and relay the provider body unchanged.
    if status_code >= 400:
        reason = f"error_http_{status_code}"
        await _finalize_zero(frozen_id, record.proxy_call_id, reason)
        if resp_json is not None:
            ledger.update_response(record.proxy_call_id, resp_json)
        return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json")

    # Extract task_id from response; no task_id means provider rejected at business level
    provider_task_id: str | None = None
    if resp_json is not None:
        raw = proxy.jsonpath_get(resp_json, task_id_jsonpath)
        if raw is not None:
            provider_task_id = str(raw)

    if provider_task_id:
        # Billing stays reserved; a later query route finalizes it.
        ledger.set_running(record.proxy_call_id, provider_task_id)
    else:
        # No async task ID usually means provider-side business rejection.
        # Propagate errorCode (falling back to "code") into finalize_reason.
        error_code = None
        if resp_json is not None:
            raw_error_code = resp_json.get("errorCode")
            if raw_error_code is None:
                raw_error_code = resp_json.get("code")
            if raw_error_code is not None:
                error_code = str(raw_error_code)

        finalize_reason = error_code or "no_task_id"
        await _finalize_zero(frozen_id, record.proxy_call_id, finalize_reason)

    # Cache the response so an idempotent retry can replay it.
    if resp_json is not None:
        ledger.update_response(record.proxy_call_id, resp_json)

    return _proxy_response(resp_json or {}, record.proxy_call_id, status_code, resp_headers)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Query handler
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _handle_query(
    *,
    provider: str,
    provider_config,
    method: str,
    path: str,
    request: Request,
    body: bytes,
    request_json: dict[str, Any] | None,
    query_route,
) -> Response:
    """Handle a *query* route: forward, detect a terminal status, finalize billing once.

    The ledger record is located from the provider task id found in the
    request JSON at ``query_route.request_task_id_jsonpath``. If that record is
    already finalized and has a cached response, the cache is returned without
    calling the provider.

    Otherwise the query is forwarded. When the response status (read at
    ``query_route.status_jsonpath`` and stringified) is in ``success_values`` or
    ``failure_values``, this handler tries to claim finalization via
    ``ledger.try_claim_finalize``. Only the claimant finalizes billing. On
    success the amount comes from ``_resolve_final_amount``; on failure it is 0.

    Args:
        provider: Provider name from the URL.
        provider_config: Provider configuration from ``proxy.get_provider_config``.
        method: HTTP method of the incoming request.
        path: Normalized provider path (always starts with "/").
        request: Incoming request; its headers and query string are forwarded.
        body: Raw request body, forwarded unchanged.
        request_json: ``body`` parsed as a JSON object, or ``None``.
        query_route: The matched query-route configuration.

    Returns:
        For HTTP >= 400 or a non-JSON-object body, the provider's raw response.
        Otherwise a JSON response, with ``X-Proxy-Call-Id`` when a record matched.

    Raises:
        HTTPException: 502 when ``proxy.forward_request`` raises.
    """
    ledger = get_ledger()

    # Locate the call record by provider_task_id embedded in the request body.
    # NOTE(review): only the JSON body is inspected; a query that carries the task id
    # in the URL path or query string will not match a record and will not finalize billing.
    provider_task_id: str | None = None
    if request_json:
        raw = proxy.jsonpath_get(request_json, query_route.request_task_id_jsonpath)
        if raw is not None:
            provider_task_id = str(raw)

    record: CallRecord | None = None
    if provider_task_id:
        record = ledger.get_by_task_id(provider, provider_task_id)

    # Already at terminal state — return cached result without calling the provider again
    if record is not None and ledger.is_finalized(record.proxy_call_id) and record.last_response is not None:
        logger.info("[ThirdPartyProxy] query already finalized, returning cache: proxy_call_id=%s", record.proxy_call_id)
        return _proxy_response(record.last_response, record.proxy_call_id)

    # Forward query to provider
    try:
        status_code, resp_headers, resp_body = await proxy.forward_request(
            provider_config=provider_config,
            method=method,
            path=path,
            headers=dict(request.headers),
            body=body,
            query_params=str(request.query_params),
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Provider query failed: {exc}") from exc

    # Errors and non-object bodies are relayed as-is and never trigger finalization.
    resp_json = _try_parse_json(resp_body)
    if status_code >= 400 or resp_json is None:
        return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json")

    # Detect terminal status in the response. The value is compared as a string,
    # so configured success/failure values must be the str() form.
    status_value = proxy.jsonpath_get(resp_json, query_route.status_jsonpath)
    status_str = str(status_value) if status_value is not None else None
    is_success = status_str in query_route.success_values
    is_failure = status_str in query_route.failure_values

    logger.debug(
        "[ThirdPartyProxy] query terminal check: provider=%s task_id=%s status=%s is_success=%s is_failure=%s",
        provider,
        provider_task_id,
        status_str,
        is_success,
        is_failure,
    )

    if record is not None and (is_success or is_failure):
        logger.info(
            "[ThirdPartyProxy] finalize candidate: proxy_call_id=%s provider_task_id=%s terminal_status=%s",
            record.proxy_call_id,
            provider_task_id,
            status_str,
        )
        # Atomically claim finalize rights — only one concurrent query wins
        # NOTE(review): if billing.finalize below raises, the claim is held but neither
        # set_finalized nor set_finalize_failed runs — confirm billing.finalize never raises.
        if ledger.try_claim_finalize(record.proxy_call_id):
            logger.info(
                "[ThirdPartyProxy] finalize claimed: proxy_call_id=%s",
                record.proxy_call_id,
            )
            final_amount: float = 0.0
            # Recomputed here only for the debug log; _resolve_final_amount applies the same rule.
            usage_paths = list(query_route.usage_jsonpaths or [])
            if not usage_paths and query_route.usage_jsonpath:
                usage_paths = [query_route.usage_jsonpath]
            # Failed tasks are finalized at 0; only successes are charged usage.
            if is_success:
                final_amount = _resolve_final_amount(resp_json, query_route)

            logger.debug(
                "[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s final_amount=%s usage_paths=%s legacy_path=%s",
                record.proxy_call_id,
                final_amount,
                usage_paths,
                query_route.usage_jsonpath,
            )

            # If a status is in both value lists, success wins.
            task_state = "SUCCESS" if is_success else "FAILED"
            finalize_reason = "success" if is_success else "error"

            logger.info(
                "[ThirdPartyProxy] finalize start: proxy_call_id=%s reason=%s task_state=%s has_frozen_id=%s",
                record.proxy_call_id,
                finalize_reason,
                task_state,
                bool(record.frozen_id),
            )

            if record.frozen_id:
                ok = await billing.finalize(
                    frozen_id=record.frozen_id,
                    final_amount=final_amount,
                    finalize_reason=finalize_reason,
                )
                logger.info(
                    "[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s",
                    record.proxy_call_id,
                    ok,
                )
                if ok:
                    ledger.set_finalized(record.proxy_call_id, task_state)
                else:
                    ledger.set_finalize_failed(record.proxy_call_id, task_state)
            else:
                # Nothing was reserved, so only the ledger state is closed out.
                logger.info(
                    "[ThirdPartyProxy] finalize skipped billing call (no frozen_id): proxy_call_id=%s",
                    record.proxy_call_id,
                )
                ledger.set_finalized(record.proxy_call_id, task_state)

            # Cache the terminal response for the early-return path above.
            ledger.update_response(record.proxy_call_id, resp_json)
        else:
            logger.info(
                "[ThirdPartyProxy] finalize claim denied (already processed): proxy_call_id=%s",
                record.proxy_call_id,
            )

    proxy_call_id = record.proxy_call_id if record else None
    return _proxy_response(resp_json, proxy_call_id, status_code, resp_headers)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Passthrough handler
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _passthrough(*, provider_config, method: str, path: str, request: Request, body: bytes) -> Response:
    """Relay a request to the provider and return its reply unchanged.

    No ledger record is created and no billing call is made. Any exception
    from ``proxy.forward_request`` is reported as a 502 ``HTTPException``.
    """
    try:
        outbound = {
            "provider_config": provider_config,
            "method": method,
            "path": path,
            "headers": dict(request.headers),
            "body": body,
            "query_params": str(request.query_params),
        }
        upstream_status, upstream_headers, upstream_body = await proxy.forward_request(**outbound)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Provider request failed: {exc}") from exc

    return Response(content=upstream_body, status_code=upstream_status, headers=upstream_headers)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _finalize_zero(frozen_id: str | None, proxy_call_id: str, reason: str) -> None:
    """Release a billing reservation by finalizing it at amount 0.

    Used on failure paths. Only logs and returns when ``frozen_id`` is empty or
    when ``ledger.try_claim_finalize`` refuses the claim. After a successful
    claim the record is marked via ``set_finalized`` (billing returned truthy)
    or ``set_finalize_failed`` (billing returned falsy).
    """
    ledger = get_ledger()
    logger.info(
        "[ThirdPartyProxy] finalize_zero requested: proxy_call_id=%s reason=%s has_frozen_id=%s",
        proxy_call_id,
        reason,
        bool(frozen_id),
    )

    if not frozen_id:
        logger.debug("[ThirdPartyProxy] finalize_zero skipped: no frozen_id proxy_call_id=%s", proxy_call_id)
        return

    if not ledger.try_claim_finalize(proxy_call_id):
        logger.info("[ThirdPartyProxy] finalize_zero claim denied: proxy_call_id=%s", proxy_call_id)
        return

    logger.info("[ThirdPartyProxy] finalize_zero claimed: proxy_call_id=%s", proxy_call_id)
    billed_ok = await billing.finalize(frozen_id=frozen_id, final_amount=0, finalize_reason=reason)
    logger.info("[ThirdPartyProxy] finalize_zero result: proxy_call_id=%s ok=%s", proxy_call_id, billed_ok)

    # The callers in this file pass failure reasons, so this is normally "FAILED".
    state = "SUCCESS" if reason == "success" else "FAILED"
    mark_record = ledger.set_finalized if billed_ok else ledger.set_finalize_failed
    mark_record(proxy_call_id, state)
|
|
|
|
|
|
def _try_parse_json(data: bytes) -> dict[str, Any] | None:
|
|
if not data:
|
|
return None
|
|
try:
|
|
parsed = json.loads(data)
|
|
return parsed if isinstance(parsed, dict) else None
|
|
except (json.JSONDecodeError, ValueError):
|
|
return None
|
|
|
|
|
|
def _resolve_final_amount(resp_json: dict[str, Any], query_route) -> float:
    """Sum the billable usage found at the query route's configured JSON paths.

    Path selection:
      1) ``query_route.usage_jsonpaths`` — each path is read and values are summed;
      2) otherwise the legacy single ``query_route.usage_jsonpath``.

    A value is skipped, not raised on, when it is:
    - missing;
    - a JSON boolean;
    - not convertible to float, or too large to convert;
    - non-finite (NaN/inf, including strings such as "nan" or "inf").

    The result is therefore a finite float, and ``0.0`` when nothing usable is found.
    """
    import math  # local import keeps this helper self-contained

    paths = list(query_route.usage_jsonpaths or [])
    if not paths and query_route.usage_jsonpath:
        paths = [query_route.usage_jsonpath]

    total = 0.0
    for path in paths:
        raw = proxy.jsonpath_get(resp_json, path)
        # JSON true/false would otherwise coerce to 1.0/0.0 and be billed.
        if raw is None or isinstance(raw, bool):
            continue
        try:
            amount = float(raw)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: float() of an int beyond the double range.
            continue
        # float() accepts "nan"/"inf"; such values would poison the billed total.
        if not math.isfinite(amount):
            continue
        total += amount

    return total
|
|
|
|
|
|
# Upstream headers that describe the provider's original byte stream. They are
# wrong once the payload is re-serialized by JSONResponse, which computes its own
# Content-Length and emits an uncompressed, non-chunk-encoded body.
_RERENDER_UNSAFE_HEADERS = frozenset({"content-length", "content-encoding", "transfer-encoding"})


def _proxy_response(
    data: dict[str, Any],
    proxy_call_id: str | None,
    status_code: int = 200,
    extra_headers: dict[str, str] | None = None,
) -> JSONResponse:
    """Build a JSON response for *data*, tagged with the ledger call id.

    Args:
        data: JSON-serializable object to send as the body.
        proxy_call_id: Ledger id; added as ``X-Proxy-Call-Id`` when truthy.
        status_code: HTTP status to return (default 200).
        extra_headers: Headers to copy, typically the provider's response
            headers. Framing headers (Content-Length, Content-Encoding,
            Transfer-Encoding) are dropped, matched case-insensitively. The body
            is re-serialized (and may even be ``{}`` instead of the provider's
            body), so the upstream values would not describe it.

    Returns:
        A ``JSONResponse`` carrying *data* and the filtered headers.
    """
    headers: dict[str, str] = {
        name: value
        for name, value in dict(extra_headers or {}).items()
        if name.lower() not in _RERENDER_UNSAFE_HEADERS
    }
    if proxy_call_id:
        headers["X-Proxy-Call-Id"] = proxy_call_id
    return JSONResponse(content=data, status_code=status_code, headers=headers)
|