feat(proxy): add third-party proxy module with billing integration
- Introduced a new third-party proxy package for handling async task APIs. - Implemented billing client with reserve and finalize functionalities. - Created an in-memory ledger to track call states and ensure idempotency. - Added route classification for submit and query requests. - Configured third-party provider settings and routes in the application config. - Updated local backend to support Docker networking for sandbox containers.
This commit is contained in:
parent
8d5b01a59b
commit
dabe529cc7
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
|
@ -17,21 +18,39 @@ from app.gateway.routers import (
|
|||
runs,
|
||||
skills,
|
||||
suggestions,
|
||||
third_party,
|
||||
thread_runs,
|
||||
threads,
|
||||
uploads,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
# Configure logging with env override
|
||||
import os
|
||||
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
# Configure logging (prefer config.yaml log_level, fallback to LOG_LEVEL env)
|
||||
env_log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
log_level = env_log_level
|
||||
try:
|
||||
configured_log_level = get_app_config().log_level.upper()
|
||||
if configured_log_level:
|
||||
log_level = configured_log_level
|
||||
except Exception:
|
||||
# Keep startup resilient even if config is temporarily invalid/unavailable.
|
||||
log_level = env_log_level
|
||||
|
||||
resolved_log_level = getattr(logging, log_level, logging.INFO)
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level, logging.INFO),
|
||||
level=resolved_log_level,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
# Uvicorn installs logging handlers before app import; force reconfigure so
|
||||
# config.yaml log_level reliably takes effect.
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Ensure package loggers inherit the intended level even under custom handlers.
|
||||
logging.getLogger().setLevel(resolved_log_level)
|
||||
logging.getLogger("app").setLevel(resolved_log_level)
|
||||
logging.getLogger("deerflow").setLevel(resolved_log_level)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -162,6 +181,10 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
|||
"name": "health",
|
||||
"description": "Health check and system status endpoints",
|
||||
},
|
||||
{
|
||||
"name": "third-party-proxy",
|
||||
"description": "Universal third-party API proxy with billing integration (/api/proxy/{provider}/...)",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -207,6 +230,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
|||
# Stateless Runs API (stream/wait without a pre-existing thread)
|
||||
app.include_router(runs.router)
|
||||
|
||||
# Third-party API proxy with billing integration
|
||||
app.include_router(third_party.router)
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health_check() -> dict:
|
||||
"""Health check endpoint.
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
||||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, third_party, thread_runs, threads, uploads
|
||||
|
||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "third_party", "threads", "thread_runs", "uploads"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,403 @@
|
|||
"""Universal third-party API proxy router with integrated billing.
|
||||
|
||||
Endpoint: ANY /api/proxy/{provider}/{path...}
|
||||
|
||||
The caller (a sandbox skill script) should set:
|
||||
X-Thread-Id: <thread_id> — used for billing reservation (injected via THREAD_ID env var)
|
||||
X-Idempotency-Key: <uuid> — optional; deduplicates submit calls
|
||||
|
||||
The gateway automatically:
|
||||
1. Injects the provider's API key from the configured env var.
|
||||
2. For *submit* routes: reserves billing, forwards, records task state.
|
||||
3. For *query* routes: forwards, detects terminal status, finalizes billing once.
|
||||
4. For all other routes: transparent passthrough, no billing side-effects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app.gateway.third_party_proxy import billing, proxy
|
||||
from app.gateway.third_party_proxy.ledger import CallRecord, get_ledger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/proxy", tags=["third-party-proxy"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.api_route("/{provider}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
|
||||
async def proxy_request(provider: str, path: str, request: Request) -> Response:
|
||||
"""Universal proxy endpoint for third-party API calls with billing integration."""
|
||||
provider_config = proxy.get_provider_config(provider)
|
||||
if provider_config is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Provider '{provider}' is not configured or the proxy is disabled.",
|
||||
)
|
||||
|
||||
method = request.method
|
||||
# Normalise: ensure leading slash so patterns like /openapi/v2/** match correctly
|
||||
path = "/" + path.lstrip("/")
|
||||
|
||||
thread_id = request.headers.get("x-thread-id")
|
||||
idempotency_key = request.headers.get("x-idempotency-key")
|
||||
|
||||
body = await request.body()
|
||||
request_json: dict[str, Any] | None = _try_parse_json(body)
|
||||
|
||||
submit_route = proxy.match_submit_route(provider_config, method, path)
|
||||
query_route = proxy.match_query_route(provider_config, method, path)
|
||||
logger.info("[ThirdPartyProxy] route=%s provider=%s method=%s path=%s", "submit" if submit_route else "query" if query_route else "passthrough", provider, method, path)
|
||||
|
||||
if submit_route:
|
||||
return await _handle_submit(
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
method=method,
|
||||
path=path,
|
||||
request=request,
|
||||
body=body,
|
||||
thread_id=thread_id,
|
||||
idempotency_key=idempotency_key,
|
||||
task_id_jsonpath=submit_route.task_id_jsonpath,
|
||||
route_frozen_amount=submit_route.frozen_amount,
|
||||
route_frozen_type=submit_route.frozen_type,
|
||||
)
|
||||
|
||||
if query_route:
|
||||
return await _handle_query(
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
method=method,
|
||||
path=path,
|
||||
request=request,
|
||||
body=body,
|
||||
request_json=request_json,
|
||||
query_route=query_route,
|
||||
)
|
||||
|
||||
# Pure passthrough — no billing, no state
|
||||
return await _passthrough(
|
||||
provider_config=provider_config,
|
||||
method=method,
|
||||
path=path,
|
||||
request=request,
|
||||
body=body,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Submit handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _handle_submit(
|
||||
*,
|
||||
provider: str,
|
||||
provider_config,
|
||||
method: str,
|
||||
path: str,
|
||||
request: Request,
|
||||
body: bytes,
|
||||
thread_id: str | None,
|
||||
idempotency_key: str | None,
|
||||
task_id_jsonpath: str,
|
||||
route_frozen_amount: float | None,
|
||||
route_frozen_type: int | None,
|
||||
) -> Response:
|
||||
ledger = get_ledger()
|
||||
|
||||
# Idempotency: if we've already handled this exact submit, return the cached response
|
||||
if idempotency_key:
|
||||
existing = ledger.get_by_idempotency_key(provider, idempotency_key)
|
||||
if existing is not None and existing.last_response is not None:
|
||||
logger.info("[ThirdPartyProxy] idempotent submit: proxy_call_id=%s", existing.proxy_call_id)
|
||||
return _proxy_response(existing.last_response, existing.proxy_call_id)
|
||||
|
||||
record = ledger.create(provider, thread_id, idempotency_key)
|
||||
|
||||
# Reserve billing before touching the provider
|
||||
reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount
|
||||
reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type
|
||||
frozen_id = await billing.reserve(
|
||||
thread_id=thread_id,
|
||||
call_id=record.call_id,
|
||||
provider=provider,
|
||||
operation=path,
|
||||
frozen_amount=reserve_frozen_amount,
|
||||
frozen_type=reserve_frozen_type,
|
||||
)
|
||||
if frozen_id:
|
||||
ledger.set_reserved(record.proxy_call_id, frozen_id)
|
||||
|
||||
# Forward to provider
|
||||
try:
|
||||
status_code, resp_headers, resp_body = await proxy.forward_request(
|
||||
provider_config=provider_config,
|
||||
method=method,
|
||||
path=path,
|
||||
headers=dict(request.headers),
|
||||
body=body,
|
||||
query_params=str(request.query_params),
|
||||
)
|
||||
except Exception as exc:
|
||||
await _finalize_zero(frozen_id, record.proxy_call_id, "error exception")
|
||||
raise HTTPException(status_code=502, detail=f"Provider unreachable: {exc}") from exc
|
||||
|
||||
resp_json = _try_parse_json(resp_body)
|
||||
|
||||
# HTTP-level failure
|
||||
if status_code >= 400:
|
||||
reason = f"error_http_{status_code}"
|
||||
await _finalize_zero(frozen_id, record.proxy_call_id, reason)
|
||||
if resp_json is not None:
|
||||
ledger.update_response(record.proxy_call_id, resp_json)
|
||||
return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json")
|
||||
|
||||
# Extract task_id from response; no task_id means provider rejected at business level
|
||||
provider_task_id: str | None = None
|
||||
if resp_json is not None:
|
||||
raw = proxy.jsonpath_get(resp_json, task_id_jsonpath)
|
||||
if raw is not None:
|
||||
provider_task_id = str(raw)
|
||||
|
||||
if provider_task_id:
|
||||
ledger.set_running(record.proxy_call_id, provider_task_id)
|
||||
else:
|
||||
# No async task ID usually means provider-side business rejection.
|
||||
# Propagate errorCode (if present) into finalize_reason.
|
||||
error_code = None
|
||||
if resp_json is not None:
|
||||
raw_error_code = resp_json.get("errorCode")
|
||||
if raw_error_code is None:
|
||||
raw_error_code = resp_json.get("code")
|
||||
if raw_error_code is not None:
|
||||
error_code = str(raw_error_code)
|
||||
|
||||
finalize_reason = error_code or "no_task_id"
|
||||
await _finalize_zero(frozen_id, record.proxy_call_id, finalize_reason)
|
||||
|
||||
if resp_json is not None:
|
||||
ledger.update_response(record.proxy_call_id, resp_json)
|
||||
|
||||
return _proxy_response(resp_json or {}, record.proxy_call_id, status_code, resp_headers)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _handle_query(
|
||||
*,
|
||||
provider: str,
|
||||
provider_config,
|
||||
method: str,
|
||||
path: str,
|
||||
request: Request,
|
||||
body: bytes,
|
||||
request_json: dict[str, Any] | None,
|
||||
query_route,
|
||||
) -> Response:
|
||||
ledger = get_ledger()
|
||||
|
||||
# Locate the call record by provider_task_id embedded in the request body
|
||||
provider_task_id: str | None = None
|
||||
if request_json:
|
||||
raw = proxy.jsonpath_get(request_json, query_route.request_task_id_jsonpath)
|
||||
if raw is not None:
|
||||
provider_task_id = str(raw)
|
||||
|
||||
record: CallRecord | None = None
|
||||
if provider_task_id:
|
||||
record = ledger.get_by_task_id(provider, provider_task_id)
|
||||
|
||||
# Already at terminal state — return cached result without calling the provider again
|
||||
if record is not None and ledger.is_finalized(record.proxy_call_id) and record.last_response is not None:
|
||||
logger.info("[ThirdPartyProxy] query already finalized, returning cache: proxy_call_id=%s", record.proxy_call_id)
|
||||
return _proxy_response(record.last_response, record.proxy_call_id)
|
||||
|
||||
# Forward query to provider
|
||||
try:
|
||||
status_code, resp_headers, resp_body = await proxy.forward_request(
|
||||
provider_config=provider_config,
|
||||
method=method,
|
||||
path=path,
|
||||
headers=dict(request.headers),
|
||||
body=body,
|
||||
query_params=str(request.query_params),
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Provider query failed: {exc}") from exc
|
||||
|
||||
resp_json = _try_parse_json(resp_body)
|
||||
if status_code >= 400 or resp_json is None:
|
||||
return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json")
|
||||
|
||||
# Detect terminal status in the response
|
||||
status_value = proxy.jsonpath_get(resp_json, query_route.status_jsonpath)
|
||||
status_str = str(status_value) if status_value is not None else None
|
||||
is_success = status_str in query_route.success_values
|
||||
is_failure = status_str in query_route.failure_values
|
||||
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy] query terminal check: provider=%s task_id=%s status=%s is_success=%s is_failure=%s",
|
||||
provider,
|
||||
provider_task_id,
|
||||
status_str,
|
||||
is_success,
|
||||
is_failure,
|
||||
)
|
||||
|
||||
if record is not None and (is_success or is_failure):
|
||||
logger.info(
|
||||
"[ThirdPartyProxy] finalize candidate: proxy_call_id=%s provider_task_id=%s terminal_status=%s",
|
||||
record.proxy_call_id,
|
||||
provider_task_id,
|
||||
status_str,
|
||||
)
|
||||
# Atomically claim finalize rights — only one concurrent query wins
|
||||
if ledger.try_claim_finalize(record.proxy_call_id):
|
||||
logger.info(
|
||||
"[ThirdPartyProxy] finalize claimed: proxy_call_id=%s",
|
||||
record.proxy_call_id,
|
||||
)
|
||||
final_amount: float = 0.0
|
||||
if is_success and query_route.usage_jsonpath:
|
||||
raw_amount = proxy.jsonpath_get(resp_json, query_route.usage_jsonpath)
|
||||
try:
|
||||
final_amount = float(raw_amount) if raw_amount is not None else 0.0
|
||||
except (TypeError, ValueError):
|
||||
final_amount = 0.0
|
||||
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s final_amount=%s usage_path=%s",
|
||||
record.proxy_call_id,
|
||||
final_amount,
|
||||
query_route.usage_jsonpath,
|
||||
)
|
||||
|
||||
task_state = "SUCCESS" if is_success else "FAILED"
|
||||
finalize_reason = "success" if is_success else "error"
|
||||
|
||||
logger.info(
|
||||
"[ThirdPartyProxy] finalize start: proxy_call_id=%s reason=%s task_state=%s has_frozen_id=%s",
|
||||
record.proxy_call_id,
|
||||
finalize_reason,
|
||||
task_state,
|
||||
bool(record.frozen_id),
|
||||
)
|
||||
|
||||
if record.frozen_id:
|
||||
ok = await billing.finalize(
|
||||
frozen_id=record.frozen_id,
|
||||
final_amount=final_amount,
|
||||
finalize_reason=finalize_reason,
|
||||
)
|
||||
logger.info(
|
||||
"[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s",
|
||||
record.proxy_call_id,
|
||||
ok,
|
||||
)
|
||||
if ok:
|
||||
ledger.set_finalized(record.proxy_call_id, task_state)
|
||||
else:
|
||||
ledger.set_finalize_failed(record.proxy_call_id, task_state)
|
||||
else:
|
||||
logger.info(
|
||||
"[ThirdPartyProxy] finalize skipped billing call (no frozen_id): proxy_call_id=%s",
|
||||
record.proxy_call_id,
|
||||
)
|
||||
ledger.set_finalized(record.proxy_call_id, task_state)
|
||||
|
||||
ledger.update_response(record.proxy_call_id, resp_json)
|
||||
else:
|
||||
logger.info(
|
||||
"[ThirdPartyProxy] finalize claim denied (already processed): proxy_call_id=%s",
|
||||
record.proxy_call_id,
|
||||
)
|
||||
|
||||
proxy_call_id = record.proxy_call_id if record else None
|
||||
return _proxy_response(resp_json, proxy_call_id, status_code, resp_headers)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Passthrough handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _passthrough(*, provider_config, method: str, path: str, request: Request, body: bytes) -> Response:
|
||||
try:
|
||||
status_code, resp_headers, resp_body = await proxy.forward_request(
|
||||
provider_config=provider_config,
|
||||
method=method,
|
||||
path=path,
|
||||
headers=dict(request.headers),
|
||||
body=body,
|
||||
query_params=str(request.query_params),
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Provider request failed: {exc}") from exc
|
||||
|
||||
return Response(content=resp_body, status_code=status_code, headers=resp_headers)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _finalize_zero(frozen_id: str | None, proxy_call_id: str, reason: str) -> None:
|
||||
"""Finalize with amount=0 when billing was reserved but the call failed."""
|
||||
ledger = get_ledger()
|
||||
logger.info(
|
||||
"[ThirdPartyProxy] finalize_zero requested: proxy_call_id=%s reason=%s has_frozen_id=%s",
|
||||
proxy_call_id,
|
||||
reason,
|
||||
bool(frozen_id),
|
||||
)
|
||||
if frozen_id and ledger.try_claim_finalize(proxy_call_id):
|
||||
logger.info("[ThirdPartyProxy] finalize_zero claimed: proxy_call_id=%s", proxy_call_id)
|
||||
ok = await billing.finalize(frozen_id=frozen_id, final_amount=0, finalize_reason=reason)
|
||||
logger.info("[ThirdPartyProxy] finalize_zero result: proxy_call_id=%s ok=%s", proxy_call_id, ok)
|
||||
task_state = "SUCCESS" if reason == "success" else "FAILED"
|
||||
if ok:
|
||||
ledger.set_finalized(proxy_call_id, task_state)
|
||||
else:
|
||||
ledger.set_finalize_failed(proxy_call_id, task_state)
|
||||
elif not frozen_id:
|
||||
logger.debug("[ThirdPartyProxy] finalize_zero skipped: no frozen_id proxy_call_id=%s", proxy_call_id)
|
||||
else:
|
||||
logger.info("[ThirdPartyProxy] finalize_zero claim denied: proxy_call_id=%s", proxy_call_id)
|
||||
|
||||
|
||||
def _try_parse_json(data: bytes) -> dict[str, Any] | None:
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(data)
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _proxy_response(
|
||||
data: dict[str, Any],
|
||||
proxy_call_id: str | None,
|
||||
status_code: int = 200,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> JSONResponse:
|
||||
headers: dict[str, str] = dict(extra_headers or {})
|
||||
if proxy_call_id:
|
||||
headers["X-Proxy-Call-Id"] = proxy_call_id
|
||||
return JSONResponse(content=data, status_code=status_code, headers=headers)
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Third-party proxy package."""
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
"""Thin async billing client for the third-party proxy.
|
||||
|
||||
Calls the same reserve/finalize HTTP endpoints as BillingMiddleware,
|
||||
but with semantics appropriate for third-party task calls:
|
||||
- estimatedTokens = 0 (not applicable)
|
||||
- finalAmount = actual provider monetary charge (thirdPartyConsumeMoney)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import httpx
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SUCCESS_STATUS_CODES = {200, 1000}
|
||||
|
||||
|
||||
async def reserve(
|
||||
*,
|
||||
thread_id: str | None,
|
||||
call_id: str,
|
||||
provider: str,
|
||||
operation: str,
|
||||
frozen_amount: float,
|
||||
frozen_type: int | None,
|
||||
) -> str | None:
|
||||
"""Reserve billing before forwarding a submit call.
|
||||
|
||||
Returns the frozen_id string on success, or None if billing is disabled
|
||||
or the reserve call fails (non-blocking — proxy continues in that case).
|
||||
"""
|
||||
cfg = get_app_config().billing
|
||||
if not cfg.enabled or not cfg.reserve_url:
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Billing] reserve skipped: enabled=%s reserve_url=%s call_id=%s",
|
||||
cfg.enabled,
|
||||
cfg.reserve_url,
|
||||
call_id,
|
||||
)
|
||||
return None
|
||||
|
||||
expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
|
||||
payload = {
|
||||
"sessionId": thread_id,
|
||||
"callId": call_id,
|
||||
"modelName": provider,
|
||||
"question": f"skill invokes {operation.split('/')[-1]}",
|
||||
"frozenAmount": frozen_amount,
|
||||
"frozenType": frozen_type if frozen_type is not None else cfg.frozen_type,
|
||||
"estimatedInputTokens": 0,
|
||||
"estimatedOutputTokens": 0,
|
||||
"expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Billing] reserve request: url=%s call_id=%s provider=%s thread_id=%s",
|
||||
cfg.reserve_url,
|
||||
call_id,
|
||||
provider,
|
||||
thread_id,
|
||||
)
|
||||
logger.debug("[ThirdPartyProxy][Billing] reserve payload: %s", payload)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=cfg.timeout_seconds) as client:
|
||||
resp = await client.post(cfg.reserve_url, headers=cfg.headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
data: dict = resp.json()
|
||||
except Exception as exc:
|
||||
logger.warning("[ThirdPartyProxy][Billing] reserve HTTP error: %s", exc)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Billing] reserve response: call_id=%s status_code=%s",
|
||||
call_id,
|
||||
resp.status_code,
|
||||
)
|
||||
logger.debug("[ThirdPartyProxy][Billing] reserve response body: %s", data)
|
||||
|
||||
if not _is_success(data):
|
||||
logger.warning(
|
||||
"[ThirdPartyProxy][Billing] reserve rejected: call_id=%s status=%s payload=%s",
|
||||
call_id,
|
||||
data.get("status") or data.get("code"),
|
||||
data,
|
||||
)
|
||||
return None
|
||||
|
||||
frozen_id = (data.get("data") or {}).get("frozenId")
|
||||
if not isinstance(frozen_id, str) or not frozen_id:
|
||||
logger.warning(
|
||||
"[ThirdPartyProxy][Billing] reserve response missing frozenId: call_id=%s payload=%s",
|
||||
call_id,
|
||||
data,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info("[ThirdPartyProxy][Billing] reserve ok: call_id=%s frozen_id=%s", call_id, frozen_id)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Billing] reserve success details: provider=%s operation=%s expire_at=%s",
|
||||
provider,
|
||||
operation,
|
||||
payload["expireAt"],
|
||||
)
|
||||
return frozen_id
|
||||
|
||||
|
||||
async def finalize(
|
||||
*,
|
||||
frozen_id: str,
|
||||
final_amount: float,
|
||||
finalize_reason: str,
|
||||
) -> bool:
|
||||
"""Finalize billing after a third-party call reaches a terminal state.
|
||||
|
||||
final_amount is the actual provider charge (e.g. thirdPartyConsumeMoney from RunningHub).
|
||||
Pass 0 for failed/cancelled calls.
|
||||
Returns True on success.
|
||||
"""
|
||||
cfg = get_app_config().billing
|
||||
if not cfg.enabled or not cfg.finalize_url:
|
||||
# Billing not configured — treat as success so the caller marks the record finalized
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Billing] finalize skipped: enabled=%s finalize_url=%s frozen_id=%s",
|
||||
cfg.enabled,
|
||||
cfg.finalize_url,
|
||||
frozen_id,
|
||||
)
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"frozenId": frozen_id,
|
||||
"finalAmount": final_amount,
|
||||
"usageInputTokens": 0,
|
||||
"usageOutputTokens": 0,
|
||||
"usageTotalTokens": 0,
|
||||
"finalizeReason": finalize_reason,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Billing] finalize request: frozen_id=%s amount=%s reason=%s url=%s",
|
||||
frozen_id,
|
||||
final_amount,
|
||||
finalize_reason,
|
||||
cfg.finalize_url,
|
||||
)
|
||||
logger.debug("[ThirdPartyProxy][Billing] finalize payload: %s", payload)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=cfg.timeout_seconds) as client:
|
||||
resp = await client.post(cfg.finalize_url, headers=cfg.headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
data: dict = resp.json()
|
||||
except Exception as exc:
|
||||
logger.warning("[ThirdPartyProxy][Billing] finalize HTTP error: frozen_id=%s err=%s", frozen_id, exc)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Billing] finalize response: frozen_id=%s status_code=%s",
|
||||
frozen_id,
|
||||
resp.status_code,
|
||||
)
|
||||
logger.debug("[ThirdPartyProxy][Billing] finalize response body: %s", data)
|
||||
|
||||
if not _is_success(data):
|
||||
logger.warning(
|
||||
"[ThirdPartyProxy][Billing] finalize rejected: frozen_id=%s status=%s payload=%s",
|
||||
frozen_id,
|
||||
data.get("status") or data.get("code"),
|
||||
data,
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("[ThirdPartyProxy][Billing] finalize ok: frozen_id=%s", frozen_id)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Billing] finalize success details: amount=%s reason=%s",
|
||||
final_amount,
|
||||
finalize_reason,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _is_success(data: dict) -> bool:
|
||||
status = data.get("status") or data.get("code")
|
||||
if isinstance(status, int) and status in _SUCCESS_STATUS_CODES:
|
||||
return True
|
||||
return data.get("success") is True
|
||||
|
|
@ -0,0 +1,289 @@
|
|||
"""In-memory call state ledger for the third-party proxy.
|
||||
|
||||
Tracks each proxied call from reserve → submit → query → finalize,
|
||||
enforcing idempotency and ensuring billing finalize runs exactly once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BillingState = Literal["UNRESERVED", "RESERVED", "FINALIZED", "FINALIZE_FAILED"]
|
||||
TaskState = Literal["PENDING", "RUNNING", "SUCCESS", "FAILED", "UNKNOWN"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallRecord:
|
||||
proxy_call_id: str
|
||||
provider: str
|
||||
thread_id: str | None
|
||||
# call_id is sent to the billing platform (callId in reserve payload)
|
||||
call_id: str
|
||||
frozen_id: str | None = None
|
||||
provider_task_id: str | None = None
|
||||
billing_state: BillingState = "UNRESERVED"
|
||||
task_state: TaskState = "PENDING"
|
||||
created_at: float = field(default_factory=time.time)
|
||||
finalized_at: float | None = None
|
||||
error: str | None = None
|
||||
idempotency_key: str | None = None
|
||||
# Cached last provider response — returned for repeat queries after finalization
|
||||
last_response: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CallLedger:
|
||||
"""Thread-safe in-memory ledger for third-party proxy call records."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._records: dict[str, CallRecord] = {} # proxy_call_id → record
|
||||
self._task_index: dict[str, str] = {} # "{provider}:{provider_task_id}" → proxy_call_id
|
||||
self._idem_index: dict[str, str] = {} # "{provider}:{idem_key}" → proxy_call_id
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def create(
|
||||
self,
|
||||
provider: str,
|
||||
thread_id: str | None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> CallRecord:
|
||||
"""Create a new call record, or return the existing one if idempotency key matches."""
|
||||
with self._lock:
|
||||
if idempotency_key:
|
||||
existing = self._get_by_idem_key_locked(provider, idempotency_key)
|
||||
if existing is not None:
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Ledger] idempotent hit: provider=%s proxy_call_id=%s idem_key=%s",
|
||||
provider,
|
||||
existing.proxy_call_id,
|
||||
idempotency_key,
|
||||
)
|
||||
# logger.debug(
|
||||
# "[ThirdPartyProxy][Ledger] existing record reused: call_id=%s task_id=%s billing_state=%s task_state=%s",
|
||||
# existing.call_id,
|
||||
# existing.provider_task_id,
|
||||
# existing.billing_state,
|
||||
# existing.task_state,
|
||||
# )
|
||||
return existing
|
||||
|
||||
record = CallRecord(
|
||||
proxy_call_id=str(uuid4()),
|
||||
provider=provider,
|
||||
thread_id=thread_id,
|
||||
call_id=str(uuid4()),
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
self._records[record.proxy_call_id] = record
|
||||
if idempotency_key:
|
||||
self._idem_index[f"{provider}:{idempotency_key}"] = record.proxy_call_id
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Ledger] created record: provider=%s proxy_call_id=%s call_id=%s thread_id=%s",
|
||||
provider,
|
||||
record.proxy_call_id,
|
||||
record.call_id,
|
||||
thread_id,
|
||||
)
|
||||
# logger.debug(
|
||||
# "[ThirdPartyProxy][Ledger] create details: idem_key=%s billing_state=%s task_state=%s",
|
||||
# idempotency_key,
|
||||
# record.billing_state,
|
||||
# record.task_state,
|
||||
# )
|
||||
return record
|
||||
|
||||
def get(self, proxy_call_id: str) -> CallRecord | None:
|
||||
return self._records.get(proxy_call_id)
|
||||
|
||||
def get_by_task_id(self, provider: str, provider_task_id: str) -> CallRecord | None:
|
||||
key = f"{provider}:{provider_task_id}"
|
||||
proxy_call_id = self._task_index.get(key)
|
||||
return self._records.get(proxy_call_id) if proxy_call_id else None
|
||||
|
||||
def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None:
|
||||
return self._get_by_idem_key_locked(provider, idempotency_key)
|
||||
|
||||
def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None:
|
||||
with self._lock:
|
||||
record = self._records.get(proxy_call_id)
|
||||
if record:
|
||||
record.frozen_id = frozen_id
|
||||
record.billing_state = "RESERVED"
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s",
|
||||
proxy_call_id,
|
||||
frozen_id,
|
||||
)
|
||||
# logger.debug(
|
||||
# "[ThirdPartyProxy][Ledger] reserve state: call_id=%s provider=%s task_state=%s",
|
||||
# record.call_id,
|
||||
# record.provider,
|
||||
# record.task_state,
|
||||
# )
|
||||
else:
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] set_reserved ignored for missing record: proxy_call_id=%s",
|
||||
proxy_call_id,
|
||||
)
|
||||
|
||||
def set_running(self, proxy_call_id: str, provider_task_id: str) -> None:
|
||||
with self._lock:
|
||||
record = self._records.get(proxy_call_id)
|
||||
if record:
|
||||
record.provider_task_id = provider_task_id
|
||||
record.task_state = "RUNNING"
|
||||
self._task_index[f"{record.provider}:{provider_task_id}"] = proxy_call_id
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Ledger] running: proxy_call_id=%s provider_task_id=%s",
|
||||
proxy_call_id,
|
||||
provider_task_id,
|
||||
)
|
||||
# logger.debug(
|
||||
# "[ThirdPartyProxy][Ledger] running state: provider=%s call_id=%s billing_state=%s",
|
||||
# record.provider,
|
||||
# record.call_id,
|
||||
# record.billing_state,
|
||||
# )
|
||||
else:
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] set_running ignored for missing record: proxy_call_id=%s provider_task_id=%s",
|
||||
proxy_call_id,
|
||||
provider_task_id,
|
||||
)
|
||||
|
||||
def try_claim_finalize(self, proxy_call_id: str) -> bool:
|
||||
"""Atomically claim finalization rights. Returns True only once per record."""
|
||||
with self._lock:
|
||||
record = self._records.get(proxy_call_id)
|
||||
if record is None:
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] finalize claim denied: missing record proxy_call_id=%s",
|
||||
proxy_call_id,
|
||||
)
|
||||
return False
|
||||
if record.billing_state in ("FINALIZED", "FINALIZE_FAILED"):
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] finalize claim denied: proxy_call_id=%s billing_state=%s",
|
||||
proxy_call_id,
|
||||
record.billing_state,
|
||||
)
|
||||
return False
|
||||
# Mark as finalized immediately to prevent concurrent finalize
|
||||
record.billing_state = "FINALIZED"
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Ledger] finalize claimed: proxy_call_id=%s",
|
||||
proxy_call_id,
|
||||
)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] finalize claim state: call_id=%s provider=%s task_state=%s frozen_id=%s",
|
||||
record.call_id,
|
||||
record.provider,
|
||||
record.task_state,
|
||||
record.frozen_id,
|
||||
)
|
||||
return True
|
||||
|
||||
def set_finalized(self, proxy_call_id: str, task_state: TaskState) -> None:
|
||||
with self._lock:
|
||||
record = self._records.get(proxy_call_id)
|
||||
if record:
|
||||
record.task_state = task_state
|
||||
record.billing_state = "FINALIZED"
|
||||
record.finalized_at = time.time()
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Ledger] finalized: proxy_call_id=%s task_state=%s",
|
||||
proxy_call_id,
|
||||
task_state,
|
||||
)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] finalized state: provider=%s call_id=%s frozen_id=%s finalized_at=%s",
|
||||
record.provider,
|
||||
record.call_id,
|
||||
record.frozen_id,
|
||||
record.finalized_at,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] set_finalized ignored for missing record: proxy_call_id=%s task_state=%s",
|
||||
proxy_call_id,
|
||||
task_state,
|
||||
)
|
||||
|
||||
def set_finalize_failed(self, proxy_call_id: str, task_state: TaskState) -> None:
|
||||
with self._lock:
|
||||
record = self._records.get(proxy_call_id)
|
||||
if record:
|
||||
record.task_state = task_state
|
||||
record.billing_state = "FINALIZE_FAILED"
|
||||
record.finalized_at = time.time()
|
||||
logger.info(
|
||||
"[ThirdPartyProxy][Ledger] finalize failed: proxy_call_id=%s task_state=%s",
|
||||
proxy_call_id,
|
||||
task_state,
|
||||
)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] finalize failure state: provider=%s call_id=%s frozen_id=%s finalized_at=%s",
|
||||
record.provider,
|
||||
record.call_id,
|
||||
record.frozen_id,
|
||||
record.finalized_at,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] set_finalize_failed ignored for missing record: proxy_call_id=%s task_state=%s",
|
||||
proxy_call_id,
|
||||
task_state,
|
||||
)
|
||||
|
||||
def update_response(self, proxy_call_id: str, response: dict[str, Any]) -> None:
|
||||
with self._lock:
|
||||
record = self._records.get(proxy_call_id)
|
||||
if record:
|
||||
record.last_response = response
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] cached response: proxy_call_id=%s keys=%s",
|
||||
proxy_call_id,
|
||||
sorted(response.keys()),
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy][Ledger] update_response ignored for missing record: proxy_call_id=%s",
|
||||
proxy_call_id,
|
||||
)
|
||||
|
||||
def is_finalized(self, proxy_call_id: str) -> bool:
|
||||
record = self._records.get(proxy_call_id)
|
||||
return record is not None and record.billing_state in ("FINALIZED", "FINALIZE_FAILED")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_by_idem_key_locked(self, provider: str, idempotency_key: str) -> CallRecord | None:
|
||||
key = f"{provider}:{idempotency_key}"
|
||||
proxy_call_id = self._idem_index.get(key)
|
||||
return self._records.get(proxy_call_id) if proxy_call_id else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_ledger: CallLedger | None = None
|
||||
_ledger_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_ledger() -> CallLedger:
|
||||
global _ledger
|
||||
if _ledger is None:
|
||||
with _ledger_lock:
|
||||
if _ledger is None:
|
||||
_ledger = CallLedger()
|
||||
logger.info("[ThirdPartyProxy][Ledger] singleton initialized")
|
||||
return _ledger
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
"""HTTP forwarding, route classification, and JSONPath extraction for the third-party proxy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.third_party_proxy_config import (
|
||||
QueryRouteConfig,
|
||||
SubmitRouteConfig,
|
||||
ThirdPartyProviderConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SENSITIVE_HEADERS = frozenset(
|
||||
[
|
||||
"authorization",
|
||||
"proxy-authorization",
|
||||
"x-api-key",
|
||||
"api-key",
|
||||
"cookie",
|
||||
"set-cookie",
|
||||
]
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider config lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_provider_config(provider: str) -> ThirdPartyProviderConfig | None:
|
||||
"""Return the provider config for *provider*, or None if not configured/disabled."""
|
||||
cfg = get_app_config().third_party_proxy
|
||||
if not cfg.enabled:
|
||||
return None
|
||||
return cfg.providers.get(provider)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def match_submit_route(
|
||||
config: ThirdPartyProviderConfig,
|
||||
method: str,
|
||||
path: str,
|
||||
) -> SubmitRouteConfig | None:
|
||||
"""Return the first submit route that matches (method, path), or None."""
|
||||
for route in config.submit_routes:
|
||||
if route.method.upper() != method.upper():
|
||||
continue
|
||||
if not _path_matches(path, route.path_pattern):
|
||||
continue
|
||||
if route.exclude_path_pattern and _path_matches(path, route.exclude_path_pattern):
|
||||
continue
|
||||
return route
|
||||
return None
|
||||
|
||||
|
||||
def match_query_route(
|
||||
config: ThirdPartyProviderConfig,
|
||||
method: str,
|
||||
path: str,
|
||||
) -> QueryRouteConfig | None:
|
||||
"""Return the first query route that matches (method, path), or None."""
|
||||
for route in config.query_routes:
|
||||
if route.method.upper() != method.upper():
|
||||
continue
|
||||
if _path_matches(path, route.path_pattern):
|
||||
return route
|
||||
return None
|
||||
|
||||
|
||||
def _path_matches(path: str, pattern: str) -> bool:
|
||||
"""Match *path* against a glob-ish *pattern*.
|
||||
|
||||
Rules:
|
||||
- Pattern ending in /** matches the prefix and any sub-path.
|
||||
- Otherwise exact match.
|
||||
"""
|
||||
# Normalise trailing slashes
|
||||
path = path.rstrip("/") or "/"
|
||||
pattern = pattern.rstrip("/") or "/"
|
||||
|
||||
if pattern.endswith("/**"):
|
||||
prefix = pattern[:-3]
|
||||
return path == prefix or path.startswith(prefix + "/")
|
||||
|
||||
return path == pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal path evaluator (dot-notation shorthand only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def jsonpath_get(data: Any, path: str) -> Any:
|
||||
"""Extract a value from *data* using a simple dot-notation shorthand path.
|
||||
|
||||
Supports paths like: taskId usage.thirdPartyConsumeMoney
|
||||
Paths with a leading '$' are intentionally not supported.
|
||||
Returns None if any segment is missing or the input is not a dict.
|
||||
"""
|
||||
if not isinstance(path, str):
|
||||
return None
|
||||
|
||||
remainder = path.strip()
|
||||
if not remainder or remainder.startswith("$"):
|
||||
return None
|
||||
|
||||
current: Any = data
|
||||
for part in remainder.split("."):
|
||||
if not part:
|
||||
return None
|
||||
if not isinstance(current, dict):
|
||||
return None
|
||||
current = current.get(part)
|
||||
if current is None:
|
||||
return None
|
||||
return current
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP forwarding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Request headers we never forward (hop-by-hop, sensitive, or proxy-internal)
|
||||
_STRIP_REQUEST_HEADERS = frozenset(
|
||||
[
|
||||
"host",
|
||||
"content-length",
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
"x-thread-id",
|
||||
"x-idempotency-key",
|
||||
]
|
||||
)
|
||||
|
||||
# Response headers we strip before returning to the caller
|
||||
_STRIP_RESPONSE_HEADERS = frozenset(
|
||||
[
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"content-encoding",
|
||||
"content-length",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_headers(headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Return a copy of headers with sensitive values redacted."""
|
||||
sanitized: dict[str, str] = {}
|
||||
for key, value in headers.items():
|
||||
if key.lower() in _SENSITIVE_HEADERS:
|
||||
sanitized[key] = "***"
|
||||
else:
|
||||
sanitized[key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def _preview_body(data: bytes, limit: int = 2048) -> str:
|
||||
"""Return a safe textual preview of body bytes for debugging logs."""
|
||||
if not data:
|
||||
return ""
|
||||
chunk = data[:limit]
|
||||
text = chunk.decode("utf-8", errors="replace")
|
||||
if len(data) > limit:
|
||||
text += f" ...<truncated {len(data) - limit} bytes>"
|
||||
return text
|
||||
|
||||
|
||||
async def forward_request(
|
||||
*,
|
||||
provider_config: ThirdPartyProviderConfig,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, str],
|
||||
body: bytes,
|
||||
query_params: str,
|
||||
) -> tuple[int, dict[str, str], bytes]:
|
||||
"""Forward *method* *path* to the provider and return (status_code, headers, body).
|
||||
|
||||
The provider's API key (read from the environment variable named in
|
||||
``provider_config.api_key_env``) is injected automatically, replacing
|
||||
any Authorization header the caller might have sent.
|
||||
"""
|
||||
target_url = provider_config.base_url.rstrip("/") + "/" + path.lstrip("/")
|
||||
if query_params:
|
||||
target_url += "?" + query_params
|
||||
|
||||
# Build forwarded headers: drop internal/hop-by-hop, then inject API key
|
||||
forward_headers = {
|
||||
k: v for k, v in headers.items() if k.lower() not in _STRIP_REQUEST_HEADERS
|
||||
}
|
||||
if provider_config.api_key_env:
|
||||
api_key = os.getenv(provider_config.api_key_env)
|
||||
if api_key:
|
||||
forward_headers[provider_config.api_key_header] = provider_config.api_key_prefix + api_key
|
||||
else:
|
||||
logger.warning(
|
||||
"[ThirdPartyProxy] api_key_env '%s' is not set for provider",
|
||||
provider_config.api_key_env,
|
||||
)
|
||||
|
||||
logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy] request headers=%s",
|
||||
_sanitize_headers(forward_headers)
|
||||
)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy] request body(%dB)=%s",
|
||||
len(body),
|
||||
_preview_body(body),
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=provider_config.timeout_seconds) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=target_url,
|
||||
headers=forward_headers,
|
||||
content=body,
|
||||
)
|
||||
|
||||
response_headers = {
|
||||
k: v
|
||||
for k, v in response.headers.items()
|
||||
if k.lower() not in _STRIP_RESPONSE_HEADERS
|
||||
}
|
||||
logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy] response headers=%s",
|
||||
_sanitize_headers(response_headers)
|
||||
)
|
||||
logger.debug(
|
||||
"[ThirdPartyProxy] response body(%dB)=%s",
|
||||
len(response.content),
|
||||
_preview_body(response.content),
|
||||
)
|
||||
return response.status_code, response_headers, response.content
|
||||
|
|
@ -261,6 +261,12 @@ class LocalContainerBackend(SandboxBackend):
|
|||
]
|
||||
)
|
||||
|
||||
# On Linux, containers started via DooD (Docker-out-of-Docker) do not
|
||||
# automatically resolve host.docker.internal. Add the mapping explicitly
|
||||
# so sandbox containers can call back into the host-exposed gateway.
|
||||
if self._runtime == "docker":
|
||||
cmd.extend(["--add-host", "host.docker.internal:host-gateway"])
|
||||
|
||||
# Environment variables (static config first, runtime overrides last)
|
||||
for key, value in self._environment.items():
|
||||
cmd.extend(["-e", f"{key}={value}"])
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from deerflow.config.skills_config import SkillsConfig
|
|||
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
|
||||
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
|
||||
from deerflow.config.third_party_proxy_config import ThirdPartyProxyConfig
|
||||
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
|
||||
from deerflow.config.token_usage_config import TokenUsageConfig
|
||||
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
|
||||
|
|
@ -42,6 +43,7 @@ class AppConfig(BaseModel):
|
|||
|
||||
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
|
||||
billing: BillingConfig = Field(default_factory=BillingConfig, description="External billing reservation/finalization configuration")
|
||||
third_party_proxy: ThirdPartyProxyConfig = Field(default_factory=ThirdPartyProxyConfig, description="Third-party API proxy with billing integration")
|
||||
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
|
||||
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
||||
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
"""Configuration for the third-party API proxy with billing integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SubmitRouteConfig(BaseModel):
|
||||
"""Identifies a submit request — triggers billing reserve + task state tracking."""
|
||||
|
||||
method: str = Field(default="POST", description="HTTP method to match (case-insensitive)")
|
||||
path_pattern: str = Field(
|
||||
description="Glob-style path pattern. Use ** to match any sub-path, e.g. /openapi/v2/**"
|
||||
)
|
||||
exclude_path_pattern: str | None = Field(
|
||||
default=None,
|
||||
description="If set, paths matching this pattern are excluded from submit handling",
|
||||
)
|
||||
task_id_jsonpath: str = Field(
|
||||
description="Dot-path into the *response* body to extract the provider task ID, e.g. taskId"
|
||||
)
|
||||
frozen_amount: float | None = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
description="Optional route-level override for billing reserve payload frozenAmount",
|
||||
)
|
||||
frozen_type: int | None = Field(
|
||||
default=None,
|
||||
description="Optional route-level override for billing reserve payload frozenType",
|
||||
)
|
||||
|
||||
|
||||
class QueryRouteConfig(BaseModel):
|
||||
"""Identifies a query/poll request — checks for terminal status + triggers billing finalize."""
|
||||
|
||||
method: str = Field(default="POST", description="HTTP method to match (case-insensitive)")
|
||||
path_pattern: str = Field(description="Glob-style path pattern for the query endpoint")
|
||||
request_task_id_jsonpath: str = Field(
|
||||
description="Dot-path into the *request* body to extract the task ID being queried"
|
||||
)
|
||||
status_jsonpath: str = Field(
|
||||
description="Dot-path into the response body to read the task status value"
|
||||
)
|
||||
success_values: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Status string values that indicate successful terminal state, e.g. [\"SUCCESS\"]",
|
||||
)
|
||||
failure_values: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Status string values that indicate failed terminal state, e.g. [\"FAILED\", \"CANCELLED\"]",
|
||||
)
|
||||
usage_jsonpath: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Dot-path into the response body for the actual monetary cost to pass to billing finalize. "
|
||||
"E.g. usage.thirdPartyConsumeMoney"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ThirdPartyProviderConfig(BaseModel):
|
||||
"""Configuration for a single third-party API platform."""
|
||||
|
||||
base_url: str = Field(description="Base URL of the provider, e.g. https://www.runninghub.cn")
|
||||
api_key_env: str | None = Field(
|
||||
default=None,
|
||||
description="Name of the environment variable holding the API key",
|
||||
)
|
||||
api_key_header: str = Field(
|
||||
default="Authorization",
|
||||
description="Request header name for the API key",
|
||||
)
|
||||
api_key_prefix: str = Field(
|
||||
default="Bearer ",
|
||||
description="String prepended to the API key value in the header",
|
||||
)
|
||||
timeout_seconds: float = Field(
|
||||
default=30.0,
|
||||
gt=0,
|
||||
description="HTTP request timeout when forwarding to the provider",
|
||||
)
|
||||
frozen_amount: float = Field(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
description="Amount to reserve in billing reserve payload (frozenAmount)",
|
||||
)
|
||||
frozen_type: int | None = Field(
|
||||
default=None,
|
||||
description="Billing frozen type for this provider (frozenType). If omitted, falls back to billing.frozen_type",
|
||||
)
|
||||
submit_routes: list[SubmitRouteConfig] = Field(
|
||||
default_factory=list,
|
||||
description="Route patterns that identify submit (task-create) requests",
|
||||
)
|
||||
query_routes: list[QueryRouteConfig] = Field(
|
||||
default_factory=list,
|
||||
description="Route patterns that identify query/poll requests",
|
||||
)
|
||||
|
||||
|
||||
class ThirdPartyProxyConfig(BaseModel):
|
||||
"""Top-level configuration for the third-party API proxy."""
|
||||
|
||||
enabled: bool = Field(default=False, description="Enable the proxy endpoint")
|
||||
providers: dict[str, ThirdPartyProviderConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Keyed by provider name (used in the URL path /api/proxy/{provider}/...)",
|
||||
)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.community.aio_sandbox.local_backend import LocalContainerBackend, _format_container_mount
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,192 @@
|
|||
"""Unit tests for the third-party proxy module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.gateway.third_party_proxy.ledger import CallLedger
|
||||
from app.gateway.third_party_proxy.proxy import (
|
||||
_path_matches,
|
||||
jsonpath_get,
|
||||
match_query_route,
|
||||
match_submit_route,
|
||||
)
|
||||
from deerflow.config.third_party_proxy_config import (
|
||||
QueryRouteConfig,
|
||||
SubmitRouteConfig,
|
||||
ThirdPartyProviderConfig,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _path_matches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPathMatches:
|
||||
def test_exact_match(self):
|
||||
assert _path_matches("/openapi/v2/query", "/openapi/v2/query")
|
||||
|
||||
def test_exact_no_match(self):
|
||||
assert not _path_matches("/openapi/v2/query", "/openapi/v2/submit")
|
||||
|
||||
def test_glob_matches_prefix(self):
|
||||
assert _path_matches("/openapi/v2/vidu/submit", "/openapi/v2/**")
|
||||
|
||||
def test_glob_matches_prefix_itself(self):
|
||||
assert _path_matches("/openapi/v2", "/openapi/v2/**")
|
||||
|
||||
def test_glob_no_match_different_prefix(self):
|
||||
assert not _path_matches("/other/v2/submit", "/openapi/v2/**")
|
||||
|
||||
def test_trailing_slashes_normalised(self):
|
||||
assert _path_matches("/openapi/v2/query/", "/openapi/v2/query")
|
||||
|
||||
def test_glob_excludes_sibling_prefix(self):
|
||||
# /openapi/v2/** should not match /openapi/v2extra/foo
|
||||
assert not _path_matches("/openapi/v2extra/foo", "/openapi/v2/**")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# jsonpath_get
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJsonpathGet:
|
||||
def test_single_key(self):
|
||||
assert jsonpath_get({"taskId": "abc"}, "taskId") == "abc"
|
||||
|
||||
def test_nested_key(self):
|
||||
data = {"usage": {"thirdPartyConsumeMoney": 1.23}}
|
||||
assert jsonpath_get(data, "usage.thirdPartyConsumeMoney") == 1.23
|
||||
|
||||
def test_missing_key_returns_none(self):
|
||||
assert jsonpath_get({"foo": "bar"}, "taskId") is None
|
||||
|
||||
def test_rejects_dollar_prefixed_path(self):
|
||||
assert jsonpath_get({"taskId": "abc"}, "$.taskId") is None
|
||||
|
||||
def test_short_path_supported(self):
|
||||
assert jsonpath_get({"x": 1}, "x") == 1
|
||||
|
||||
def test_non_dict_intermediate(self):
|
||||
data = {"usage": "not-a-dict"}
|
||||
assert jsonpath_get(data, "usage.something") is None
|
||||
|
||||
def test_none_input(self):
|
||||
assert jsonpath_get(None, "x") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# match_submit_route / match_query_route
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PROVIDER_CFG = ThirdPartyProviderConfig(
|
||||
base_url="https://example.com",
|
||||
api_key_env="TEST_API_KEY",
|
||||
submit_routes=[
|
||||
SubmitRouteConfig(
|
||||
method="POST",
|
||||
path_pattern="/openapi/v2/**",
|
||||
exclude_path_pattern="/openapi/v2/query",
|
||||
task_id_jsonpath="taskId",
|
||||
)
|
||||
],
|
||||
query_routes=[
|
||||
QueryRouteConfig(
|
||||
method="POST",
|
||||
path_pattern="/openapi/v2/query",
|
||||
request_task_id_jsonpath="taskId",
|
||||
status_jsonpath="status",
|
||||
success_values=["SUCCESS"],
|
||||
failure_values=["FAILED", "CANCELLED"],
|
||||
usage_jsonpath="usage.thirdPartyConsumeMoney",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestMatchRoutes:
|
||||
def test_submit_matches_non_query_path(self):
|
||||
result = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/vidu/submit")
|
||||
assert result is not None
|
||||
assert result.task_id_jsonpath == "taskId"
|
||||
|
||||
def test_submit_excluded_by_exclude_pattern(self):
|
||||
result = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/query")
|
||||
assert result is None
|
||||
|
||||
def test_submit_wrong_method(self):
|
||||
result = match_submit_route(_PROVIDER_CFG, "GET", "/openapi/v2/vidu/submit")
|
||||
assert result is None
|
||||
|
||||
def test_query_matches(self):
|
||||
result = match_query_route(_PROVIDER_CFG, "POST", "/openapi/v2/query")
|
||||
assert result is not None
|
||||
assert result.status_jsonpath == "status"
|
||||
|
||||
def test_query_wrong_method(self):
|
||||
result = match_query_route(_PROVIDER_CFG, "GET", "/openapi/v2/query")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CallLedger
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCallLedger:
|
||||
def _make_ledger(self) -> CallLedger:
|
||||
return CallLedger()
|
||||
|
||||
def test_create_and_get(self):
|
||||
ledger = self._make_ledger()
|
||||
rec = ledger.create("prov", "tid", None)
|
||||
assert rec.provider == "prov"
|
||||
found = ledger.get(rec.proxy_call_id)
|
||||
assert found is not None
|
||||
assert found.proxy_call_id == rec.proxy_call_id
|
||||
|
||||
def test_set_reserved(self):
|
||||
ledger = self._make_ledger()
|
||||
rec = ledger.create("prov", "tid", None)
|
||||
ledger.set_reserved(rec.proxy_call_id, "frozen-123")
|
||||
found = ledger.get(rec.proxy_call_id)
|
||||
assert found.frozen_id == "frozen-123"
|
||||
assert found.billing_state == "RESERVED"
|
||||
|
||||
def test_set_running(self):
|
||||
ledger = self._make_ledger()
|
||||
rec = ledger.create("prov", "tid", None)
|
||||
ledger.set_running(rec.proxy_call_id, "task-abc")
|
||||
found = ledger.get_by_task_id("prov", "task-abc")
|
||||
assert found is not None
|
||||
assert found.proxy_call_id == rec.proxy_call_id
|
||||
|
||||
def test_try_claim_finalize_once(self):
|
||||
ledger = self._make_ledger()
|
||||
rec = ledger.create("prov", "tid", None)
|
||||
# First claim should succeed
|
||||
assert ledger.try_claim_finalize(rec.proxy_call_id) is True
|
||||
# Second claim should fail — already in progress/done
|
||||
assert ledger.try_claim_finalize(rec.proxy_call_id) is False
|
||||
|
||||
def test_is_finalized(self):
|
||||
ledger = self._make_ledger()
|
||||
rec = ledger.create("prov", "tid", None)
|
||||
assert ledger.is_finalized(rec.proxy_call_id) is False
|
||||
ledger.try_claim_finalize(rec.proxy_call_id)
|
||||
ledger.set_finalized(rec.proxy_call_id, "SUCCESS")
|
||||
assert ledger.is_finalized(rec.proxy_call_id) is True
|
||||
|
||||
def test_idempotency_key_dedup(self):
|
||||
ledger = self._make_ledger()
|
||||
rec1 = ledger.create("prov", "tid", "idem-key-1")
|
||||
rec2 = ledger.get_by_idempotency_key("prov", "idem-key-1")
|
||||
assert rec2 is not None
|
||||
assert rec2.proxy_call_id == rec1.proxy_call_id
|
||||
|
||||
def test_update_response(self):
|
||||
ledger = self._make_ledger()
|
||||
rec = ledger.create("prov", "tid", None)
|
||||
ledger.update_response(rec.proxy_call_id, {"result": "ok"})
|
||||
found = ledger.get(rec.proxy_call_id)
|
||||
assert found.last_response == {"result": "ok"}
|
||||
|
|
@ -49,6 +49,51 @@ billing:
|
|||
# Authorization: "Bearer your-secret-token"
|
||||
# X-App-Id: "deer-flow"
|
||||
|
||||
# ============================================================================
|
||||
# Third-Party Transparent Proxy
|
||||
# ============================================================================
|
||||
# Exposes /api/proxy/{provider}/... and handles reserve/finalize around
|
||||
# third-party async task APIs such as RunningHub.
|
||||
|
||||
third_party_proxy:
|
||||
enabled: false
|
||||
providers:
|
||||
runninghub:
|
||||
base_url: https://www.runninghub.cn
|
||||
api_key_env: RUNNINGHUB_API_KEY
|
||||
api_key_header: Authorization
|
||||
api_key_prefix: "Bearer "
|
||||
timeout_seconds: 30.0
|
||||
frozen_type: 2
|
||||
submit_routes:
|
||||
- path_pattern: "/openapi/v2/**"
|
||||
exclude_path_pattern: "/openapi/v2/query"
|
||||
task_id_jsonpath: "taskId"
|
||||
# Optional per-model billing override examples:
|
||||
# frozen_amount: 10.0
|
||||
# frozen_type: 2
|
||||
|
||||
# Example: model-specific reserve policy
|
||||
# - path_pattern: "/openapi/v2/rhart-image/z-image/turbo-lora"
|
||||
# task_id_jsonpath: "taskId"
|
||||
# frozen_amount: 10.0
|
||||
# frozen_type: 2
|
||||
# - path_pattern: "/openapi/v2/vidu/text-to-video-q3-turbo"
|
||||
# task_id_jsonpath: "taskId"
|
||||
# frozen_amount: 50.0
|
||||
# frozen_type: 2
|
||||
# - path_pattern: "/openapi/v2/wan-2.7/image-edit"
|
||||
# task_id_jsonpath: "taskId"
|
||||
# frozen_amount: 20.0
|
||||
# frozen_type: 2
|
||||
query_routes:
|
||||
- path_pattern: "/openapi/v2/query"
|
||||
request_task_id_jsonpath: "taskId"
|
||||
status_jsonpath: "status"
|
||||
success_values: ["SUCCESS"]
|
||||
failure_values: ["FAILED", "CANCELLED"]
|
||||
usage_jsonpath: "usage.thirdPartyConsumeMoney"
|
||||
|
||||
# ============================================================================
|
||||
# Token Usage Tracking
|
||||
# ============================================================================
|
||||
|
|
|
|||
|
|
@ -121,6 +121,10 @@ services:
|
|||
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
|
||||
container_name: deer-flow-gateway
|
||||
command: sh -c "cd backend && uv sync && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --reload --reload-include='*.yaml .env' > /app/logs/gateway.log 2>&1"
|
||||
ports:
|
||||
# Expose to host so DooD-started sandbox containers can reach the gateway
|
||||
# via host.docker.internal:8001
|
||||
- "8001:8001"
|
||||
volumes:
|
||||
- ../backend/:/app/backend/
|
||||
# Preserve the .venv built during Docker image build — mounting the full backend/
|
||||
|
|
|
|||
|
|
@ -69,6 +69,10 @@ services:
|
|||
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
|
||||
container_name: deer-flow-gateway
|
||||
command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers 2"
|
||||
ports:
|
||||
# Expose gateway port for direct access (e.g. for API clients or testing tools like Postman).
|
||||
# via host.docker.internal:8001
|
||||
- "8001:8001"
|
||||
volumes:
|
||||
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
|
||||
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
|
||||
|
|
|
|||
Loading…
Reference in New Issue