"""Universal third-party API proxy router with integrated billing. Endpoint: ANY /api/proxy/{provider}/{path...} The caller (a sandbox skill script) should set: X-Thread-Id: — used for billing reservation (injected via THREAD_ID env var) X-Idempotency-Key: — optional; deduplicates submit calls The gateway automatically: 1. Injects the provider's API key from the configured env var. 2. For *submit* routes: reserves billing, forwards, records task state. 3. For *query* routes: forwards, detects terminal status, finalizes billing once. 4. For all other routes: transparent passthrough, no billing side-effects. """ from __future__ import annotations import json import logging from typing import Any from fastapi import APIRouter, HTTPException, Request from fastapi.responses import JSONResponse, Response from app.gateway.third_party_proxy import billing, proxy from app.gateway.third_party_proxy.ledger import CallRecord, get_ledger logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/proxy", tags=["third-party-proxy"]) # --------------------------------------------------------------------------- # Main entry point # --------------------------------------------------------------------------- @router.api_route("/{provider}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) async def proxy_request(provider: str, path: str, request: Request) -> Response: """Universal proxy endpoint for third-party API calls with billing integration.""" provider_config = proxy.get_provider_config(provider) if provider_config is None: raise HTTPException( status_code=404, detail=f"Provider '{provider}' is not configured or the proxy is disabled.", ) method = request.method # Normalise: ensure leading slash so patterns like /openapi/v2/** match correctly path = "/" + path.lstrip("/") thread_id = request.headers.get("x-thread-id") idempotency_key = request.headers.get("x-idempotency-key") body = await request.body() request_json: dict[str, Any] | None = _try_parse_json(body) submit_route = proxy.match_submit_route(provider_config, method, path) query_route = proxy.match_query_route(provider_config, method, path) logger.info("[ThirdPartyProxy] route=%s provider=%s method=%s path=%s", "submit" if submit_route else "query" if query_route else "passthrough", provider, method, path) if submit_route: return await _handle_submit( provider=provider, provider_config=provider_config, method=method, path=path, request=request, body=body, thread_id=thread_id, idempotency_key=idempotency_key, task_id_jsonpath=submit_route.task_id_jsonpath, route_frozen_amount=submit_route.frozen_amount, route_frozen_type=submit_route.frozen_type, ) if query_route: return await _handle_query( provider=provider, provider_config=provider_config, method=method, path=path, request=request, body=body, request_json=request_json, query_route=query_route, ) # Pure passthrough — no billing, no state return await _passthrough( provider_config=provider_config, method=method, path=path, request=request, body=body, ) # --------------------------------------------------------------------------- # Submit handler # --------------------------------------------------------------------------- async def _handle_submit( *, provider: str, provider_config, method: str, path: str, request: Request, body: bytes, thread_id: str | None, idempotency_key: str | None, task_id_jsonpath: str, route_frozen_amount: float | None, route_frozen_type: int | None, ) -> Response: ledger = get_ledger() # Idempotency: if we've already handled this exact submit, return the cached response if idempotency_key: existing = ledger.get_by_idempotency_key(provider, idempotency_key) if existing is not None and existing.last_response is not None: logger.info("[ThirdPartyProxy] idempotent submit: proxy_call_id=%s", existing.proxy_call_id) return _proxy_response(existing.last_response, existing.proxy_call_id) record = ledger.create(provider, thread_id, idempotency_key) # Reserve billing before touching the provider reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type frozen_id = await billing.reserve( thread_id=thread_id, call_id=record.call_id, provider=provider, operation=path, frozen_amount=reserve_frozen_amount, frozen_type=reserve_frozen_type, ) if frozen_id: ledger.set_reserved(record.proxy_call_id, frozen_id) # Forward to provider try: status_code, resp_headers, resp_body = await proxy.forward_request( provider_config=provider_config, method=method, path=path, headers=dict(request.headers), body=body, query_params=str(request.query_params), ) except Exception as exc: await _finalize_zero(frozen_id, record.proxy_call_id, "error exception") raise HTTPException(status_code=502, detail=f"Provider unreachable: {exc}") from exc resp_json = _try_parse_json(resp_body) # HTTP-level failure if status_code >= 400: reason = f"error_http_{status_code}" await _finalize_zero(frozen_id, record.proxy_call_id, reason) if resp_json is not None: ledger.update_response(record.proxy_call_id, resp_json) return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json") # Extract task_id from response; no task_id means provider rejected at business level provider_task_id: str | None = None if resp_json is not None: raw = proxy.jsonpath_get(resp_json, task_id_jsonpath) if raw is not None: provider_task_id = str(raw) if provider_task_id: ledger.set_running(record.proxy_call_id, provider_task_id) else: # No async task ID usually means provider-side business rejection. # Propagate errorCode (if present) into finalize_reason. error_code = None if resp_json is not None: raw_error_code = resp_json.get("errorCode") if raw_error_code is None: raw_error_code = resp_json.get("code") if raw_error_code is not None: error_code = str(raw_error_code) finalize_reason = error_code or "no_task_id" await _finalize_zero(frozen_id, record.proxy_call_id, finalize_reason) if resp_json is not None: ledger.update_response(record.proxy_call_id, resp_json) return _proxy_response(resp_json or {}, record.proxy_call_id, status_code, resp_headers) # --------------------------------------------------------------------------- # Query handler # --------------------------------------------------------------------------- async def _handle_query( *, provider: str, provider_config, method: str, path: str, request: Request, body: bytes, request_json: dict[str, Any] | None, query_route, ) -> Response: ledger = get_ledger() # Locate the call record by provider_task_id embedded in the request body provider_task_id: str | None = None if request_json: raw = proxy.jsonpath_get(request_json, query_route.request_task_id_jsonpath) if raw is not None: provider_task_id = str(raw) record: CallRecord | None = None if provider_task_id: record = ledger.get_by_task_id(provider, provider_task_id) # Already at terminal state — return cached result without calling the provider again if record is not None and ledger.is_finalized(record.proxy_call_id) and record.last_response is not None: logger.info("[ThirdPartyProxy] query already finalized, returning cache: proxy_call_id=%s", record.proxy_call_id) return _proxy_response(record.last_response, record.proxy_call_id) # Forward query to provider try: status_code, resp_headers, resp_body = await proxy.forward_request( provider_config=provider_config, method=method, path=path, headers=dict(request.headers), body=body, query_params=str(request.query_params), ) except Exception as exc: raise HTTPException(status_code=502, detail=f"Provider query failed: {exc}") from exc resp_json = _try_parse_json(resp_body) if status_code >= 400 or resp_json is None: return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json") # Detect terminal status in the response status_value = proxy.jsonpath_get(resp_json, query_route.status_jsonpath) status_str = str(status_value) if status_value is not None else None is_success = status_str in query_route.success_values is_failure = status_str in query_route.failure_values logger.debug( "[ThirdPartyProxy] query terminal check: provider=%s task_id=%s status=%s is_success=%s is_failure=%s", provider, provider_task_id, status_str, is_success, is_failure, ) if record is not None and (is_success or is_failure): logger.info( "[ThirdPartyProxy] finalize candidate: proxy_call_id=%s provider_task_id=%s terminal_status=%s", record.proxy_call_id, provider_task_id, status_str, ) # Atomically claim finalize rights — only one concurrent query wins if ledger.try_claim_finalize(record.proxy_call_id): logger.info( "[ThirdPartyProxy] finalize claimed: proxy_call_id=%s", record.proxy_call_id, ) final_amount: float = 0.0 if is_success and query_route.usage_jsonpath: raw_amount = proxy.jsonpath_get(resp_json, query_route.usage_jsonpath) try: final_amount = float(raw_amount) if raw_amount is not None else 0.0 except (TypeError, ValueError): final_amount = 0.0 logger.debug( "[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s final_amount=%s usage_path=%s", record.proxy_call_id, final_amount, query_route.usage_jsonpath, ) task_state = "SUCCESS" if is_success else "FAILED" finalize_reason = "success" if is_success else "error" logger.info( "[ThirdPartyProxy] finalize start: proxy_call_id=%s reason=%s task_state=%s has_frozen_id=%s", record.proxy_call_id, finalize_reason, task_state, bool(record.frozen_id), ) if record.frozen_id: ok = await billing.finalize( frozen_id=record.frozen_id, final_amount=final_amount, finalize_reason=finalize_reason, ) logger.info( "[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s", record.proxy_call_id, ok, ) if ok: ledger.set_finalized(record.proxy_call_id, task_state) else: ledger.set_finalize_failed(record.proxy_call_id, task_state) else: logger.info( "[ThirdPartyProxy] finalize skipped billing call (no frozen_id): proxy_call_id=%s", record.proxy_call_id, ) ledger.set_finalized(record.proxy_call_id, task_state) ledger.update_response(record.proxy_call_id, resp_json) else: logger.info( "[ThirdPartyProxy] finalize claim denied (already processed): proxy_call_id=%s", record.proxy_call_id, ) proxy_call_id = record.proxy_call_id if record else None return _proxy_response(resp_json, proxy_call_id, status_code, resp_headers) # --------------------------------------------------------------------------- # Passthrough handler # --------------------------------------------------------------------------- async def _passthrough(*, provider_config, method: str, path: str, request: Request, body: bytes) -> Response: try: status_code, resp_headers, resp_body = await proxy.forward_request( provider_config=provider_config, method=method, path=path, headers=dict(request.headers), body=body, query_params=str(request.query_params), ) except Exception as exc: raise HTTPException(status_code=502, detail=f"Provider request failed: {exc}") from exc return Response(content=resp_body, status_code=status_code, headers=resp_headers) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- async def _finalize_zero(frozen_id: str | None, proxy_call_id: str, reason: str) -> None: """Finalize with amount=0 when billing was reserved but the call failed.""" ledger = get_ledger() logger.info( "[ThirdPartyProxy] finalize_zero requested: proxy_call_id=%s reason=%s has_frozen_id=%s", proxy_call_id, reason, bool(frozen_id), ) if frozen_id and ledger.try_claim_finalize(proxy_call_id): logger.info("[ThirdPartyProxy] finalize_zero claimed: proxy_call_id=%s", proxy_call_id) ok = await billing.finalize(frozen_id=frozen_id, final_amount=0, finalize_reason=reason) logger.info("[ThirdPartyProxy] finalize_zero result: proxy_call_id=%s ok=%s", proxy_call_id, ok) task_state = "SUCCESS" if reason == "success" else "FAILED" if ok: ledger.set_finalized(proxy_call_id, task_state) else: ledger.set_finalize_failed(proxy_call_id, task_state) elif not frozen_id: logger.debug("[ThirdPartyProxy] finalize_zero skipped: no frozen_id proxy_call_id=%s", proxy_call_id) else: logger.info("[ThirdPartyProxy] finalize_zero claim denied: proxy_call_id=%s", proxy_call_id) def _try_parse_json(data: bytes) -> dict[str, Any] | None: if not data: return None try: parsed = json.loads(data) return parsed if isinstance(parsed, dict) else None except (json.JSONDecodeError, ValueError): return None def _proxy_response( data: dict[str, Any], proxy_call_id: str | None, status_code: int = 200, extra_headers: dict[str, str] | None = None, ) -> JSONResponse: headers: dict[str, str] = dict(extra_headers or {}) if proxy_call_id: headers["X-Proxy-Call-Id"] = proxy_call_id return JSONResponse(content=data, status_code=status_code, headers=headers)