From dabe529cc7c869865b7983ea3cefdb0c4e42c4b6 Mon Sep 17 00:00:00 2001 From: Titan Date: Thu, 23 Apr 2026 15:27:49 +0800 Subject: [PATCH] feat(proxy): add third-party proxy module with billing integration - Introduced a new third-party proxy package for handling async task APIs. - Implemented billing client with reserve and finalize functionalities. - Created an in-memory ledger to track call states and ensure idempotency. - Added route classification for submit and query requests. - Configured third-party provider settings and routes in the application config. - Updated local backend to support Docker networking for sandbox containers. --- backend/app/gateway/app.py | 34 +- backend/app/gateway/routers/__init__.py | 4 +- backend/app/gateway/routers/third_party.py | 403 ++++++++++++++++++ .../app/gateway/third_party_proxy/__init__.py | 1 + .../app/gateway/third_party_proxy/billing.py | 190 +++++++++ .../app/gateway/third_party_proxy/ledger.py | 289 +++++++++++++ .../app/gateway/third_party_proxy/proxy.py | 246 +++++++++++ .../community/aio_sandbox/local_backend.py | 6 + .../harness/deerflow/config/app_config.py | 2 + .../config/third_party_proxy_config.py | 108 +++++ .../tests/test_aio_sandbox_local_backend.py | 2 +- backend/tests/test_third_party_proxy.py | 192 +++++++++ config.example.yaml | 45 ++ docker/docker-compose-dev.yaml | 4 + docker/docker-compose.yaml | 4 + 15 files changed, 1523 insertions(+), 7 deletions(-) create mode 100644 backend/app/gateway/routers/third_party.py create mode 100644 backend/app/gateway/third_party_proxy/__init__.py create mode 100644 backend/app/gateway/third_party_proxy/billing.py create mode 100644 backend/app/gateway/third_party_proxy/ledger.py create mode 100644 backend/app/gateway/third_party_proxy/proxy.py create mode 100644 backend/packages/harness/deerflow/config/third_party_proxy_config.py create mode 100644 backend/tests/test_third_party_proxy.py diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 
64d2d093..1df9230e 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -1,4 +1,5 @@ import logging +import os from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -17,21 +18,39 @@ from app.gateway.routers import ( runs, skills, suggestions, + third_party, thread_runs, threads, uploads, ) from deerflow.config.app_config import get_app_config -# Configure logging with env override -import os -log_level = os.environ.get("LOG_LEVEL", "INFO").upper() +# Configure logging (prefer config.yaml log_level, fallback to LOG_LEVEL env) +env_log_level = os.environ.get("LOG_LEVEL", "INFO").upper() +log_level = env_log_level +try: + configured_log_level = get_app_config().log_level.upper() + if configured_log_level: + log_level = configured_log_level +except Exception: + # Keep startup resilient even if config is temporarily invalid/unavailable. + log_level = env_log_level + +resolved_log_level = getattr(logging, log_level, logging.INFO) logging.basicConfig( - level=getattr(logging, log_level, logging.INFO), + level=resolved_log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", + # Uvicorn installs logging handlers before app import; force reconfigure so + # config.yaml log_level reliably takes effect. + force=True, ) +# Ensure package loggers inherit the intended level even under custom handlers. 
+logging.getLogger().setLevel(resolved_log_level) +logging.getLogger("app").setLevel(resolved_log_level) +logging.getLogger("deerflow").setLevel(resolved_log_level) + logger = logging.getLogger(__name__) @@ -162,6 +181,10 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an "name": "health", "description": "Health check and system status endpoints", }, + { + "name": "third-party-proxy", + "description": "Universal third-party API proxy with billing integration (/api/proxy/{provider}/...)", + }, ], ) @@ -207,6 +230,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an # Stateless Runs API (stream/wait without a pre-existing thread) app.include_router(runs.router) + # Third-party API proxy with billing integration + app.include_router(third_party.router) + @app.get("/health", tags=["health"]) async def health_check() -> dict: """Health check endpoint. diff --git a/backend/app/gateway/routers/__init__.py b/backend/app/gateway/routers/__init__.py index c5f67a39..b36258df 100644 --- a/backend/app/gateway/routers/__init__.py +++ b/backend/app/gateway/routers/__init__.py @@ -1,3 +1,3 @@ -from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads +from . import artifacts, assistants_compat, mcp, models, skills, suggestions, third_party, thread_runs, threads, uploads -__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"] +__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "third_party", "threads", "thread_runs", "uploads"] diff --git a/backend/app/gateway/routers/third_party.py b/backend/app/gateway/routers/third_party.py new file mode 100644 index 00000000..3e38f2b1 --- /dev/null +++ b/backend/app/gateway/routers/third_party.py @@ -0,0 +1,403 @@ +"""Universal third-party API proxy router with integrated billing. 
+ +Endpoint: ANY /api/proxy/{provider}/{path...} + +The caller (a sandbox skill script) should set: + X-Thread-Id: — used for billing reservation (injected via THREAD_ID env var) + X-Idempotency-Key: — optional; deduplicates submit calls + +The gateway automatically: + 1. Injects the provider's API key from the configured env var. + 2. For *submit* routes: reserves billing, forwards, records task state. + 3. For *query* routes: forwards, detects terminal status, finalizes billing once. + 4. For all other routes: transparent passthrough, no billing side-effects. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse, Response + +from app.gateway.third_party_proxy import billing, proxy +from app.gateway.third_party_proxy.ledger import CallRecord, get_ledger + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/proxy", tags=["third-party-proxy"]) + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + + +@router.api_route("/{provider}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) +async def proxy_request(provider: str, path: str, request: Request) -> Response: + """Universal proxy endpoint for third-party API calls with billing integration.""" + provider_config = proxy.get_provider_config(provider) + if provider_config is None: + raise HTTPException( + status_code=404, + detail=f"Provider '{provider}' is not configured or the proxy is disabled.", + ) + + method = request.method + # Normalise: ensure leading slash so patterns like /openapi/v2/** match correctly + path = "/" + path.lstrip("/") + + thread_id = request.headers.get("x-thread-id") + idempotency_key = request.headers.get("x-idempotency-key") + + body = await request.body() + request_json: dict[str, 
Any] | None = _try_parse_json(body) + + submit_route = proxy.match_submit_route(provider_config, method, path) + query_route = proxy.match_query_route(provider_config, method, path) + logger.info("[ThirdPartyProxy] route=%s provider=%s method=%s path=%s", "submit" if submit_route else "query" if query_route else "passthrough", provider, method, path) + + if submit_route: + return await _handle_submit( + provider=provider, + provider_config=provider_config, + method=method, + path=path, + request=request, + body=body, + thread_id=thread_id, + idempotency_key=idempotency_key, + task_id_jsonpath=submit_route.task_id_jsonpath, + route_frozen_amount=submit_route.frozen_amount, + route_frozen_type=submit_route.frozen_type, + ) + + if query_route: + return await _handle_query( + provider=provider, + provider_config=provider_config, + method=method, + path=path, + request=request, + body=body, + request_json=request_json, + query_route=query_route, + ) + + # Pure passthrough — no billing, no state + return await _passthrough( + provider_config=provider_config, + method=method, + path=path, + request=request, + body=body, + ) + + +# --------------------------------------------------------------------------- +# Submit handler +# --------------------------------------------------------------------------- + + +async def _handle_submit( + *, + provider: str, + provider_config, + method: str, + path: str, + request: Request, + body: bytes, + thread_id: str | None, + idempotency_key: str | None, + task_id_jsonpath: str, + route_frozen_amount: float | None, + route_frozen_type: int | None, +) -> Response: + ledger = get_ledger() + + # Idempotency: if we've already handled this exact submit, return the cached response + if idempotency_key: + existing = ledger.get_by_idempotency_key(provider, idempotency_key) + if existing is not None and existing.last_response is not None: + logger.info("[ThirdPartyProxy] idempotent submit: proxy_call_id=%s", existing.proxy_call_id) + return 
_proxy_response(existing.last_response, existing.proxy_call_id) + + record = ledger.create(provider, thread_id, idempotency_key) + + # Reserve billing before touching the provider + reserve_frozen_amount = route_frozen_amount if route_frozen_amount is not None else provider_config.frozen_amount + reserve_frozen_type = route_frozen_type if route_frozen_type is not None else provider_config.frozen_type + frozen_id = await billing.reserve( + thread_id=thread_id, + call_id=record.call_id, + provider=provider, + operation=path, + frozen_amount=reserve_frozen_amount, + frozen_type=reserve_frozen_type, + ) + if frozen_id: + ledger.set_reserved(record.proxy_call_id, frozen_id) + + # Forward to provider + try: + status_code, resp_headers, resp_body = await proxy.forward_request( + provider_config=provider_config, + method=method, + path=path, + headers=dict(request.headers), + body=body, + query_params=str(request.query_params), + ) + except Exception as exc: + await _finalize_zero(frozen_id, record.proxy_call_id, "error exception") + raise HTTPException(status_code=502, detail=f"Provider unreachable: {exc}") from exc + + resp_json = _try_parse_json(resp_body) + + # HTTP-level failure + if status_code >= 400: + reason = f"error_http_{status_code}" + await _finalize_zero(frozen_id, record.proxy_call_id, reason) + if resp_json is not None: + ledger.update_response(record.proxy_call_id, resp_json) + return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json") + + # Extract task_id from response; no task_id means provider rejected at business level + provider_task_id: str | None = None + if resp_json is not None: + raw = proxy.jsonpath_get(resp_json, task_id_jsonpath) + if raw is not None: + provider_task_id = str(raw) + + if provider_task_id: + ledger.set_running(record.proxy_call_id, provider_task_id) + else: + # No async task ID usually means provider-side business rejection. 
+ # Propagate errorCode (if present) into finalize_reason. + error_code = None + if resp_json is not None: + raw_error_code = resp_json.get("errorCode") + if raw_error_code is None: + raw_error_code = resp_json.get("code") + if raw_error_code is not None: + error_code = str(raw_error_code) + + finalize_reason = error_code or "no_task_id" + await _finalize_zero(frozen_id, record.proxy_call_id, finalize_reason) + + if resp_json is not None: + ledger.update_response(record.proxy_call_id, resp_json) + + return _proxy_response(resp_json or {}, record.proxy_call_id, status_code, resp_headers) + + +# --------------------------------------------------------------------------- +# Query handler +# --------------------------------------------------------------------------- + + +async def _handle_query( + *, + provider: str, + provider_config, + method: str, + path: str, + request: Request, + body: bytes, + request_json: dict[str, Any] | None, + query_route, +) -> Response: + ledger = get_ledger() + + # Locate the call record by provider_task_id embedded in the request body + provider_task_id: str | None = None + if request_json: + raw = proxy.jsonpath_get(request_json, query_route.request_task_id_jsonpath) + if raw is not None: + provider_task_id = str(raw) + + record: CallRecord | None = None + if provider_task_id: + record = ledger.get_by_task_id(provider, provider_task_id) + + # Already at terminal state — return cached result without calling the provider again + if record is not None and ledger.is_finalized(record.proxy_call_id) and record.last_response is not None: + logger.info("[ThirdPartyProxy] query already finalized, returning cache: proxy_call_id=%s", record.proxy_call_id) + return _proxy_response(record.last_response, record.proxy_call_id) + + # Forward query to provider + try: + status_code, resp_headers, resp_body = await proxy.forward_request( + provider_config=provider_config, + method=method, + path=path, + headers=dict(request.headers), + body=body, + 
query_params=str(request.query_params), + ) + except Exception as exc: + raise HTTPException(status_code=502, detail=f"Provider query failed: {exc}") from exc + + resp_json = _try_parse_json(resp_body) + if status_code >= 400 or resp_json is None: + return Response(content=resp_body, status_code=status_code, headers=resp_headers, media_type="application/json") + + # Detect terminal status in the response + status_value = proxy.jsonpath_get(resp_json, query_route.status_jsonpath) + status_str = str(status_value) if status_value is not None else None + is_success = status_str in query_route.success_values + is_failure = status_str in query_route.failure_values + + logger.debug( + "[ThirdPartyProxy] query terminal check: provider=%s task_id=%s status=%s is_success=%s is_failure=%s", + provider, + provider_task_id, + status_str, + is_success, + is_failure, + ) + + if record is not None and (is_success or is_failure): + logger.info( + "[ThirdPartyProxy] finalize candidate: proxy_call_id=%s provider_task_id=%s terminal_status=%s", + record.proxy_call_id, + provider_task_id, + status_str, + ) + # Atomically claim finalize rights — only one concurrent query wins + if ledger.try_claim_finalize(record.proxy_call_id): + logger.info( + "[ThirdPartyProxy] finalize claimed: proxy_call_id=%s", + record.proxy_call_id, + ) + final_amount: float = 0.0 + if is_success and query_route.usage_jsonpath: + raw_amount = proxy.jsonpath_get(resp_json, query_route.usage_jsonpath) + try: + final_amount = float(raw_amount) if raw_amount is not None else 0.0 + except (TypeError, ValueError): + final_amount = 0.0 + + logger.debug( + "[ThirdPartyProxy] finalize amount resolved: proxy_call_id=%s final_amount=%s usage_path=%s", + record.proxy_call_id, + final_amount, + query_route.usage_jsonpath, + ) + + task_state = "SUCCESS" if is_success else "FAILED" + finalize_reason = "success" if is_success else "error" + + logger.info( + "[ThirdPartyProxy] finalize start: proxy_call_id=%s reason=%s 
task_state=%s has_frozen_id=%s", + record.proxy_call_id, + finalize_reason, + task_state, + bool(record.frozen_id), + ) + + if record.frozen_id: + ok = await billing.finalize( + frozen_id=record.frozen_id, + final_amount=final_amount, + finalize_reason=finalize_reason, + ) + logger.info( + "[ThirdPartyProxy] finalize result: proxy_call_id=%s ok=%s", + record.proxy_call_id, + ok, + ) + if ok: + ledger.set_finalized(record.proxy_call_id, task_state) + else: + ledger.set_finalize_failed(record.proxy_call_id, task_state) + else: + logger.info( + "[ThirdPartyProxy] finalize skipped billing call (no frozen_id): proxy_call_id=%s", + record.proxy_call_id, + ) + ledger.set_finalized(record.proxy_call_id, task_state) + + ledger.update_response(record.proxy_call_id, resp_json) + else: + logger.info( + "[ThirdPartyProxy] finalize claim denied (already processed): proxy_call_id=%s", + record.proxy_call_id, + ) + + proxy_call_id = record.proxy_call_id if record else None + return _proxy_response(resp_json, proxy_call_id, status_code, resp_headers) + + +# --------------------------------------------------------------------------- +# Passthrough handler +# --------------------------------------------------------------------------- + + +async def _passthrough(*, provider_config, method: str, path: str, request: Request, body: bytes) -> Response: + try: + status_code, resp_headers, resp_body = await proxy.forward_request( + provider_config=provider_config, + method=method, + path=path, + headers=dict(request.headers), + body=body, + query_params=str(request.query_params), + ) + except Exception as exc: + raise HTTPException(status_code=502, detail=f"Provider request failed: {exc}") from exc + + return Response(content=resp_body, status_code=status_code, headers=resp_headers) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def 
_finalize_zero(frozen_id: str | None, proxy_call_id: str, reason: str) -> None: + """Finalize with amount=0 when billing was reserved but the call failed.""" + ledger = get_ledger() + logger.info( + "[ThirdPartyProxy] finalize_zero requested: proxy_call_id=%s reason=%s has_frozen_id=%s", + proxy_call_id, + reason, + bool(frozen_id), + ) + if frozen_id and ledger.try_claim_finalize(proxy_call_id): + logger.info("[ThirdPartyProxy] finalize_zero claimed: proxy_call_id=%s", proxy_call_id) + ok = await billing.finalize(frozen_id=frozen_id, final_amount=0, finalize_reason=reason) + logger.info("[ThirdPartyProxy] finalize_zero result: proxy_call_id=%s ok=%s", proxy_call_id, ok) + task_state = "SUCCESS" if reason == "success" else "FAILED" + if ok: + ledger.set_finalized(proxy_call_id, task_state) + else: + ledger.set_finalize_failed(proxy_call_id, task_state) + elif not frozen_id: + logger.debug("[ThirdPartyProxy] finalize_zero skipped: no frozen_id proxy_call_id=%s", proxy_call_id) + else: + logger.info("[ThirdPartyProxy] finalize_zero claim denied: proxy_call_id=%s", proxy_call_id) + + +def _try_parse_json(data: bytes) -> dict[str, Any] | None: + if not data: + return None + try: + parsed = json.loads(data) + return parsed if isinstance(parsed, dict) else None + except (json.JSONDecodeError, ValueError): + return None + + +def _proxy_response( + data: dict[str, Any], + proxy_call_id: str | None, + status_code: int = 200, + extra_headers: dict[str, str] | None = None, +) -> JSONResponse: + headers: dict[str, str] = dict(extra_headers or {}) + if proxy_call_id: + headers["X-Proxy-Call-Id"] = proxy_call_id + return JSONResponse(content=data, status_code=status_code, headers=headers) diff --git a/backend/app/gateway/third_party_proxy/__init__.py b/backend/app/gateway/third_party_proxy/__init__.py new file mode 100644 index 00000000..52d45938 --- /dev/null +++ b/backend/app/gateway/third_party_proxy/__init__.py @@ -0,0 +1 @@ +"""Third-party proxy package.""" diff --git 
a/backend/app/gateway/third_party_proxy/billing.py b/backend/app/gateway/third_party_proxy/billing.py new file mode 100644 index 00000000..0c863670 --- /dev/null +++ b/backend/app/gateway/third_party_proxy/billing.py @@ -0,0 +1,190 @@ +"""Thin async billing client for the third-party proxy. + +Calls the same reserve/finalize HTTP endpoints as BillingMiddleware, +but with semantics appropriate for third-party task calls: +- estimatedTokens = 0 (not applicable) +- finalAmount = actual provider monetary charge (thirdPartyConsumeMoney) +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta + +import httpx + +from deerflow.config.app_config import get_app_config + +logger = logging.getLogger(__name__) + +_SUCCESS_STATUS_CODES = {200, 1000} + + +async def reserve( + *, + thread_id: str | None, + call_id: str, + provider: str, + operation: str, + frozen_amount: float, + frozen_type: int | None, +) -> str | None: + """Reserve billing before forwarding a submit call. + + Returns the frozen_id string on success, or None if billing is disabled + or the reserve call fails (non-blocking — proxy continues in that case). 
+ """ + cfg = get_app_config().billing + if not cfg.enabled or not cfg.reserve_url: + logger.info( + "[ThirdPartyProxy][Billing] reserve skipped: enabled=%s reserve_url=%s call_id=%s", + cfg.enabled, + cfg.reserve_url, + call_id, + ) + return None + + expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds) + payload = { + "sessionId": thread_id, + "callId": call_id, + "modelName": provider, + "question": f"skill invokes {operation.split('/')[-1]}", + "frozenAmount": frozen_amount, + "frozenType": frozen_type if frozen_type is not None else cfg.frozen_type, + "estimatedInputTokens": 0, + "estimatedOutputTokens": 0, + "expireAt": expire_at.strftime("%Y-%m-%d %H:%M:%S"), + } + + logger.info( + "[ThirdPartyProxy][Billing] reserve request: url=%s call_id=%s provider=%s thread_id=%s", + cfg.reserve_url, + call_id, + provider, + thread_id, + ) + logger.debug("[ThirdPartyProxy][Billing] reserve payload: %s", payload) + try: + async with httpx.AsyncClient(timeout=cfg.timeout_seconds) as client: + resp = await client.post(cfg.reserve_url, headers=cfg.headers, json=payload) + resp.raise_for_status() + data: dict = resp.json() + except Exception as exc: + logger.warning("[ThirdPartyProxy][Billing] reserve HTTP error: %s", exc) + return None + + logger.info( + "[ThirdPartyProxy][Billing] reserve response: call_id=%s status_code=%s", + call_id, + resp.status_code, + ) + logger.debug("[ThirdPartyProxy][Billing] reserve response body: %s", data) + + if not _is_success(data): + logger.warning( + "[ThirdPartyProxy][Billing] reserve rejected: call_id=%s status=%s payload=%s", + call_id, + data.get("status") or data.get("code"), + data, + ) + return None + + frozen_id = (data.get("data") or {}).get("frozenId") + if not isinstance(frozen_id, str) or not frozen_id: + logger.warning( + "[ThirdPartyProxy][Billing] reserve response missing frozenId: call_id=%s payload=%s", + call_id, + data, + ) + return None + + logger.info("[ThirdPartyProxy][Billing] reserve ok: 
call_id=%s frozen_id=%s", call_id, frozen_id) + logger.debug( + "[ThirdPartyProxy][Billing] reserve success details: provider=%s operation=%s expire_at=%s", + provider, + operation, + payload["expireAt"], + ) + return frozen_id + + +async def finalize( + *, + frozen_id: str, + final_amount: float, + finalize_reason: str, +) -> bool: + """Finalize billing after a third-party call reaches a terminal state. + + final_amount is the actual provider charge (e.g. thirdPartyConsumeMoney from RunningHub). + Pass 0 for failed/cancelled calls. + Returns True on success. + """ + cfg = get_app_config().billing + if not cfg.enabled or not cfg.finalize_url: + # Billing not configured — treat as success so the caller marks the record finalized + logger.info( + "[ThirdPartyProxy][Billing] finalize skipped: enabled=%s finalize_url=%s frozen_id=%s", + cfg.enabled, + cfg.finalize_url, + frozen_id, + ) + return True + + payload = { + "frozenId": frozen_id, + "finalAmount": final_amount, + "usageInputTokens": 0, + "usageOutputTokens": 0, + "usageTotalTokens": 0, + "finalizeReason": finalize_reason, + } + + logger.info( + "[ThirdPartyProxy][Billing] finalize request: frozen_id=%s amount=%s reason=%s url=%s", + frozen_id, + final_amount, + finalize_reason, + cfg.finalize_url, + ) + logger.debug("[ThirdPartyProxy][Billing] finalize payload: %s", payload) + try: + async with httpx.AsyncClient(timeout=cfg.timeout_seconds) as client: + resp = await client.post(cfg.finalize_url, headers=cfg.headers, json=payload) + resp.raise_for_status() + data: dict = resp.json() + except Exception as exc: + logger.warning("[ThirdPartyProxy][Billing] finalize HTTP error: frozen_id=%s err=%s", frozen_id, exc) + return False + + logger.info( + "[ThirdPartyProxy][Billing] finalize response: frozen_id=%s status_code=%s", + frozen_id, + resp.status_code, + ) + logger.debug("[ThirdPartyProxy][Billing] finalize response body: %s", data) + + if not _is_success(data): + logger.warning( + "[ThirdPartyProxy][Billing] 
finalize rejected: frozen_id=%s status=%s payload=%s", + frozen_id, + data.get("status") or data.get("code"), + data, + ) + return False + + logger.info("[ThirdPartyProxy][Billing] finalize ok: frozen_id=%s", frozen_id) + logger.debug( + "[ThirdPartyProxy][Billing] finalize success details: amount=%s reason=%s", + final_amount, + finalize_reason, + ) + return True + + +def _is_success(data: dict) -> bool: + status = data.get("status") or data.get("code") + if isinstance(status, int) and status in _SUCCESS_STATUS_CODES: + return True + return data.get("success") is True diff --git a/backend/app/gateway/third_party_proxy/ledger.py b/backend/app/gateway/third_party_proxy/ledger.py new file mode 100644 index 00000000..42a02861 --- /dev/null +++ b/backend/app/gateway/third_party_proxy/ledger.py @@ -0,0 +1,289 @@ +"""In-memory call state ledger for the third-party proxy. + +Tracks each proxied call from reserve → submit → query → finalize, +enforcing idempotency and ensuring billing finalize runs exactly once. 
+""" + +from __future__ import annotations + +import logging +import threading +import time +from dataclasses import dataclass, field +from typing import Any, Literal +from uuid import uuid4 + +logger = logging.getLogger(__name__) + +BillingState = Literal["UNRESERVED", "RESERVED", "FINALIZED", "FINALIZE_FAILED"] +TaskState = Literal["PENDING", "RUNNING", "SUCCESS", "FAILED", "UNKNOWN"] + + +@dataclass +class CallRecord: + proxy_call_id: str + provider: str + thread_id: str | None + # call_id is sent to the billing platform (callId in reserve payload) + call_id: str + frozen_id: str | None = None + provider_task_id: str | None = None + billing_state: BillingState = "UNRESERVED" + task_state: TaskState = "PENDING" + created_at: float = field(default_factory=time.time) + finalized_at: float | None = None + error: str | None = None + idempotency_key: str | None = None + # Cached last provider response — returned for repeat queries after finalization + last_response: dict[str, Any] | None = None + + +class CallLedger: + """Thread-safe in-memory ledger for third-party proxy call records.""" + + def __init__(self) -> None: + self._records: dict[str, CallRecord] = {} # proxy_call_id → record + self._task_index: dict[str, str] = {} # "{provider}:{provider_task_id}" → proxy_call_id + self._idem_index: dict[str, str] = {} # "{provider}:{idem_key}" → proxy_call_id + self._lock = threading.Lock() + + def create( + self, + provider: str, + thread_id: str | None, + idempotency_key: str | None = None, + ) -> CallRecord: + """Create a new call record, or return the existing one if idempotency key matches.""" + with self._lock: + if idempotency_key: + existing = self._get_by_idem_key_locked(provider, idempotency_key) + if existing is not None: + logger.info( + "[ThirdPartyProxy][Ledger] idempotent hit: provider=%s proxy_call_id=%s idem_key=%s", + provider, + existing.proxy_call_id, + idempotency_key, + ) + # logger.debug( + # "[ThirdPartyProxy][Ledger] existing record reused: 
call_id=%s task_id=%s billing_state=%s task_state=%s", + # existing.call_id, + # existing.provider_task_id, + # existing.billing_state, + # existing.task_state, + # ) + return existing + + record = CallRecord( + proxy_call_id=str(uuid4()), + provider=provider, + thread_id=thread_id, + call_id=str(uuid4()), + idempotency_key=idempotency_key, + ) + self._records[record.proxy_call_id] = record + if idempotency_key: + self._idem_index[f"{provider}:{idempotency_key}"] = record.proxy_call_id + logger.info( + "[ThirdPartyProxy][Ledger] created record: provider=%s proxy_call_id=%s call_id=%s thread_id=%s", + provider, + record.proxy_call_id, + record.call_id, + thread_id, + ) + # logger.debug( + # "[ThirdPartyProxy][Ledger] create details: idem_key=%s billing_state=%s task_state=%s", + # idempotency_key, + # record.billing_state, + # record.task_state, + # ) + return record + + def get(self, proxy_call_id: str) -> CallRecord | None: + return self._records.get(proxy_call_id) + + def get_by_task_id(self, provider: str, provider_task_id: str) -> CallRecord | None: + key = f"{provider}:{provider_task_id}" + proxy_call_id = self._task_index.get(key) + return self._records.get(proxy_call_id) if proxy_call_id else None + + def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None: + return self._get_by_idem_key_locked(provider, idempotency_key) + + def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None: + with self._lock: + record = self._records.get(proxy_call_id) + if record: + record.frozen_id = frozen_id + record.billing_state = "RESERVED" + logger.info( + "[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s", + proxy_call_id, + frozen_id, + ) + # logger.debug( + # "[ThirdPartyProxy][Ledger] reserve state: call_id=%s provider=%s task_state=%s", + # record.call_id, + # record.provider, + # record.task_state, + # ) + else: + logger.debug( + "[ThirdPartyProxy][Ledger] set_reserved ignored for missing record: 
proxy_call_id=%s", + proxy_call_id, + ) + + def set_running(self, proxy_call_id: str, provider_task_id: str) -> None: + with self._lock: + record = self._records.get(proxy_call_id) + if record: + record.provider_task_id = provider_task_id + record.task_state = "RUNNING" + self._task_index[f"{record.provider}:{provider_task_id}"] = proxy_call_id + logger.info( + "[ThirdPartyProxy][Ledger] running: proxy_call_id=%s provider_task_id=%s", + proxy_call_id, + provider_task_id, + ) + # logger.debug( + # "[ThirdPartyProxy][Ledger] running state: provider=%s call_id=%s billing_state=%s", + # record.provider, + # record.call_id, + # record.billing_state, + # ) + else: + logger.debug( + "[ThirdPartyProxy][Ledger] set_running ignored for missing record: proxy_call_id=%s provider_task_id=%s", + proxy_call_id, + provider_task_id, + ) + + def try_claim_finalize(self, proxy_call_id: str) -> bool: + """Atomically claim finalization rights. Returns True only once per record.""" + with self._lock: + record = self._records.get(proxy_call_id) + if record is None: + logger.debug( + "[ThirdPartyProxy][Ledger] finalize claim denied: missing record proxy_call_id=%s", + proxy_call_id, + ) + return False + if record.billing_state in ("FINALIZED", "FINALIZE_FAILED"): + logger.debug( + "[ThirdPartyProxy][Ledger] finalize claim denied: proxy_call_id=%s billing_state=%s", + proxy_call_id, + record.billing_state, + ) + return False + # Mark as finalized immediately to prevent concurrent finalize + record.billing_state = "FINALIZED" + logger.info( + "[ThirdPartyProxy][Ledger] finalize claimed: proxy_call_id=%s", + proxy_call_id, + ) + logger.debug( + "[ThirdPartyProxy][Ledger] finalize claim state: call_id=%s provider=%s task_state=%s frozen_id=%s", + record.call_id, + record.provider, + record.task_state, + record.frozen_id, + ) + return True + + def set_finalized(self, proxy_call_id: str, task_state: TaskState) -> None: + with self._lock: + record = self._records.get(proxy_call_id) + if record: 
+ record.task_state = task_state + record.billing_state = "FINALIZED" + record.finalized_at = time.time() + logger.info( + "[ThirdPartyProxy][Ledger] finalized: proxy_call_id=%s task_state=%s", + proxy_call_id, + task_state, + ) + logger.debug( + "[ThirdPartyProxy][Ledger] finalized state: provider=%s call_id=%s frozen_id=%s finalized_at=%s", + record.provider, + record.call_id, + record.frozen_id, + record.finalized_at, + ) + else: + logger.debug( + "[ThirdPartyProxy][Ledger] set_finalized ignored for missing record: proxy_call_id=%s task_state=%s", + proxy_call_id, + task_state, + ) + + def set_finalize_failed(self, proxy_call_id: str, task_state: TaskState) -> None: + with self._lock: + record = self._records.get(proxy_call_id) + if record: + record.task_state = task_state + record.billing_state = "FINALIZE_FAILED" + record.finalized_at = time.time() + logger.info( + "[ThirdPartyProxy][Ledger] finalize failed: proxy_call_id=%s task_state=%s", + proxy_call_id, + task_state, + ) + logger.debug( + "[ThirdPartyProxy][Ledger] finalize failure state: provider=%s call_id=%s frozen_id=%s finalized_at=%s", + record.provider, + record.call_id, + record.frozen_id, + record.finalized_at, + ) + else: + logger.debug( + "[ThirdPartyProxy][Ledger] set_finalize_failed ignored for missing record: proxy_call_id=%s task_state=%s", + proxy_call_id, + task_state, + ) + + def update_response(self, proxy_call_id: str, response: dict[str, Any]) -> None: + with self._lock: + record = self._records.get(proxy_call_id) + if record: + record.last_response = response + logger.debug( + "[ThirdPartyProxy][Ledger] cached response: proxy_call_id=%s keys=%s", + proxy_call_id, + sorted(response.keys()), + ) + else: + logger.debug( + "[ThirdPartyProxy][Ledger] update_response ignored for missing record: proxy_call_id=%s", + proxy_call_id, + ) + + def is_finalized(self, proxy_call_id: str) -> bool: + record = self._records.get(proxy_call_id) + return record is not None and record.billing_state in 
("FINALIZED", "FINALIZE_FAILED") + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _get_by_idem_key_locked(self, provider: str, idempotency_key: str) -> CallRecord | None: + key = f"{provider}:{idempotency_key}" + proxy_call_id = self._idem_index.get(key) + return self._records.get(proxy_call_id) if proxy_call_id else None + + +# --------------------------------------------------------------------------- +# Module-level singleton +# --------------------------------------------------------------------------- + +_ledger: CallLedger | None = None +_ledger_lock = threading.Lock() + + +def get_ledger() -> CallLedger: + global _ledger + if _ledger is None: + with _ledger_lock: + if _ledger is None: + _ledger = CallLedger() + logger.info("[ThirdPartyProxy][Ledger] singleton initialized") + return _ledger diff --git a/backend/app/gateway/third_party_proxy/proxy.py b/backend/app/gateway/third_party_proxy/proxy.py new file mode 100644 index 00000000..912814e6 --- /dev/null +++ b/backend/app/gateway/third_party_proxy/proxy.py @@ -0,0 +1,246 @@ +"""HTTP forwarding, route classification, and JSONPath extraction for the third-party proxy.""" + +from __future__ import annotations + +import logging +import os +from typing import Any + +import httpx + +from deerflow.config.app_config import get_app_config +from deerflow.config.third_party_proxy_config import ( + QueryRouteConfig, + SubmitRouteConfig, + ThirdPartyProviderConfig, +) + +logger = logging.getLogger(__name__) + +_SENSITIVE_HEADERS = frozenset( + [ + "authorization", + "proxy-authorization", + "x-api-key", + "api-key", + "cookie", + "set-cookie", + ] +) + +# --------------------------------------------------------------------------- +# Provider config lookup +# --------------------------------------------------------------------------- + + +def get_provider_config(provider: str) -> 
ThirdPartyProviderConfig | None: + """Return the provider config for *provider*, or None if not configured/disabled.""" + cfg = get_app_config().third_party_proxy + if not cfg.enabled: + return None + return cfg.providers.get(provider) + + +# --------------------------------------------------------------------------- +# Route classification +# --------------------------------------------------------------------------- + + +def match_submit_route( + config: ThirdPartyProviderConfig, + method: str, + path: str, +) -> SubmitRouteConfig | None: + """Return the first submit route that matches (method, path), or None.""" + for route in config.submit_routes: + if route.method.upper() != method.upper(): + continue + if not _path_matches(path, route.path_pattern): + continue + if route.exclude_path_pattern and _path_matches(path, route.exclude_path_pattern): + continue + return route + return None + + +def match_query_route( + config: ThirdPartyProviderConfig, + method: str, + path: str, +) -> QueryRouteConfig | None: + """Return the first query route that matches (method, path), or None.""" + for route in config.query_routes: + if route.method.upper() != method.upper(): + continue + if _path_matches(path, route.path_pattern): + return route + return None + + +def _path_matches(path: str, pattern: str) -> bool: + """Match *path* against a glob-ish *pattern*. + + Rules: + - Pattern ending in /** matches the prefix and any sub-path. + - Otherwise exact match. 
+ """ + # Normalise trailing slashes + path = path.rstrip("/") or "/" + pattern = pattern.rstrip("/") or "/" + + if pattern.endswith("/**"): + prefix = pattern[:-3] + return path == prefix or path.startswith(prefix + "/") + + return path == pattern + + +# --------------------------------------------------------------------------- +# Minimal path evaluator (dot-notation shorthand only) +# --------------------------------------------------------------------------- + + +def jsonpath_get(data: Any, path: str) -> Any: + """Extract a value from *data* using a simple dot-notation shorthand path. + + Supports paths like: taskId usage.thirdPartyConsumeMoney + Paths with a leading '$' are intentionally not supported. + Returns None if any segment is missing or the input is not a dict. + """ + if not isinstance(path, str): + return None + + remainder = path.strip() + if not remainder or remainder.startswith("$"): + return None + + current: Any = data + for part in remainder.split("."): + if not part: + return None + if not isinstance(current, dict): + return None + current = current.get(part) + if current is None: + return None + return current + + +# --------------------------------------------------------------------------- +# HTTP forwarding +# --------------------------------------------------------------------------- + +# Request headers we never forward (hop-by-hop, sensitive, or proxy-internal) +_STRIP_REQUEST_HEADERS = frozenset( + [ + "host", + "content-length", + "transfer-encoding", + "connection", + "x-thread-id", + "x-idempotency-key", + ] +) + +# Response headers we strip before returning to the caller +_STRIP_RESPONSE_HEADERS = frozenset( + [ + "transfer-encoding", + "connection", + "keep-alive", + "content-encoding", + "content-length", + ] +) + + +def _sanitize_headers(headers: dict[str, str]) -> dict[str, str]: + """Return a copy of headers with sensitive values redacted.""" + sanitized: dict[str, str] = {} + for key, value in headers.items(): + if 
key.lower() in _SENSITIVE_HEADERS: + sanitized[key] = "***" + else: + sanitized[key] = value + return sanitized + + +def _preview_body(data: bytes, limit: int = 2048) -> str: + """Return a safe textual preview of body bytes for debugging logs.""" + if not data: + return "" + chunk = data[:limit] + text = chunk.decode("utf-8", errors="replace") + if len(data) > limit: + text += f" ..." + return text + + +async def forward_request( + *, + provider_config: ThirdPartyProviderConfig, + method: str, + path: str, + headers: dict[str, str], + body: bytes, + query_params: str, +) -> tuple[int, dict[str, str], bytes]: + """Forward *method* *path* to the provider and return (status_code, headers, body). + + The provider's API key (read from the environment variable named in + ``provider_config.api_key_env``) is injected automatically, replacing + any Authorization header the caller might have sent. + """ + target_url = provider_config.base_url.rstrip("/") + "/" + path.lstrip("/") + if query_params: + target_url += "?" 
+ query_params + + # Build forwarded headers: drop internal/hop-by-hop, then inject API key + forward_headers = { + k: v for k, v in headers.items() if k.lower() not in _STRIP_REQUEST_HEADERS + } + if provider_config.api_key_env: + api_key = os.getenv(provider_config.api_key_env) + if api_key: + forward_headers[provider_config.api_key_header] = provider_config.api_key_prefix + api_key + else: + logger.warning( + "[ThirdPartyProxy] api_key_env '%s' is not set for provider", + provider_config.api_key_env, + ) + + logger.info("[ThirdPartyProxy] → %s %s", method, target_url) + logger.debug( + "[ThirdPartyProxy] request headers=%s", + _sanitize_headers(forward_headers) + ) + logger.debug( + "[ThirdPartyProxy] request body(%dB)=%s", + len(body), + _preview_body(body), + ) + + async with httpx.AsyncClient(timeout=provider_config.timeout_seconds) as client: + response = await client.request( + method=method, + url=target_url, + headers=forward_headers, + content=body, + ) + + response_headers = { + k: v + for k, v in response.headers.items() + if k.lower() not in _STRIP_RESPONSE_HEADERS + } + logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code) + logger.debug( + "[ThirdPartyProxy] response headers=%s", + _sanitize_headers(response_headers) + ) + logger.debug( + "[ThirdPartyProxy] response body(%dB)=%s", + len(response.content), + _preview_body(response.content), + ) + return response.status_code, response_headers, response.content diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py index e93ea5e9..3030d325 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py @@ -261,6 +261,12 @@ class LocalContainerBackend(SandboxBackend): ] ) + # On Linux, containers started via DooD (Docker-out-of-Docker) do not + # automatically resolve 
host.docker.internal. Add the mapping explicitly + # so sandbox containers can call back into the host-exposed gateway. + if self._runtime == "docker": + cmd.extend(["--add-host", "host.docker.internal:host-gateway"]) + # Environment variables (static config first, runtime overrides last) for key, value in self._environment.items(): cmd.extend(["-e", f"{key}={value}"]) diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index f15e0304..228c9b14 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -20,6 +20,7 @@ from deerflow.config.skills_config import SkillsConfig from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict +from deerflow.config.third_party_proxy_config import ThirdPartyProxyConfig from deerflow.config.title_config import TitleConfig, load_title_config_from_dict from deerflow.config.token_usage_config import TokenUsageConfig from deerflow.config.tool_config import ToolConfig, ToolGroupConfig @@ -42,6 +43,7 @@ class AppConfig(BaseModel): log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)") billing: BillingConfig = Field(default_factory=BillingConfig, description="External billing reservation/finalization configuration") + third_party_proxy: ThirdPartyProxyConfig = Field(default_factory=ThirdPartyProxyConfig, description="Third-party API proxy with billing integration") token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration") models: list[ModelConfig] = Field(default_factory=list, description="Available models") sandbox: 
"""Configuration for the third-party API proxy with billing integration."""

from __future__ import annotations

from pydantic import BaseModel, Field


class SubmitRouteConfig(BaseModel):
    """Identifies a submit request — triggers billing reserve + task state tracking."""

    # Matching is case-insensitive on the HTTP method (see match_submit_route).
    method: str = Field(default="POST", description="HTTP method to match (case-insensitive)")
    path_pattern: str = Field(
        description="Glob-style path pattern. Use ** to match any sub-path, e.g. /openapi/v2/**"
    )
    # Lets a broad submit glob carve out the query endpoint handled by QueryRouteConfig.
    exclude_path_pattern: str | None = Field(
        default=None,
        description="If set, paths matching this pattern are excluded from submit handling",
    )
    task_id_jsonpath: str = Field(
        description="Dot-path into the *response* body to extract the provider task ID, e.g. taskId"
    )
    # Route-level overrides take precedence over provider-level frozen_amount/frozen_type.
    frozen_amount: float | None = Field(
        default=None,
        ge=0,
        description="Optional route-level override for billing reserve payload frozenAmount",
    )
    frozen_type: int | None = Field(
        default=None,
        description="Optional route-level override for billing reserve payload frozenType",
    )


class QueryRouteConfig(BaseModel):
    """Identifies a query/poll request — checks for terminal status + triggers billing finalize."""

    method: str = Field(default="POST", description="HTTP method to match (case-insensitive)")
    path_pattern: str = Field(description="Glob-style path pattern for the query endpoint")
    # The task ID is read from the *request* body here (unlike submit routes,
    # which read it from the response).
    request_task_id_jsonpath: str = Field(
        description="Dot-path into the *request* body to extract the task ID being queried"
    )
    status_jsonpath: str = Field(
        description="Dot-path into the response body to read the task status value"
    )
    # A status value not listed in either set is treated as non-terminal.
    success_values: list[str] = Field(
        default_factory=list,
        description="Status string values that indicate successful terminal state, e.g. [\"SUCCESS\"]",
    )
    failure_values: list[str] = Field(
        default_factory=list,
        description="Status string values that indicate failed terminal state, e.g. [\"FAILED\", \"CANCELLED\"]",
    )
    usage_jsonpath: str | None = Field(
        default=None,
        description=(
            "Dot-path into the response body for the actual monetary cost to pass to billing finalize. "
            "E.g. usage.thirdPartyConsumeMoney"
        ),
    )


class ThirdPartyProviderConfig(BaseModel):
    """Configuration for a single third-party API platform."""

    base_url: str = Field(description="Base URL of the provider, e.g. https://www.runninghub.cn")
    # The key itself lives in the environment, never in config files.
    api_key_env: str | None = Field(
        default=None,
        description="Name of the environment variable holding the API key",
    )
    api_key_header: str = Field(
        default="Authorization",
        description="Request header name for the API key",
    )
    api_key_prefix: str = Field(
        default="Bearer ",
        description="String prepended to the API key value in the header",
    )
    timeout_seconds: float = Field(
        default=30.0,
        gt=0,
        description="HTTP request timeout when forwarding to the provider",
    )
    frozen_amount: float = Field(
        default=0.0,
        ge=0,
        description="Amount to reserve in billing reserve payload (frozenAmount)",
    )
    frozen_type: int | None = Field(
        default=None,
        description="Billing frozen type for this provider (frozenType). If omitted, falls back to billing.frozen_type",
    )
    submit_routes: list[SubmitRouteConfig] = Field(
        default_factory=list,
        description="Route patterns that identify submit (task-create) requests",
    )
    query_routes: list[QueryRouteConfig] = Field(
        default_factory=list,
        description="Route patterns that identify query/poll requests",
    )


class ThirdPartyProxyConfig(BaseModel):
    """Top-level configuration for the third-party API proxy."""

    enabled: bool = Field(default=False, description="Enable the proxy endpoint")
    providers: dict[str, ThirdPartyProviderConfig] = Field(
        default_factory=dict,
        description="Keyed by provider name (used in the URL path /api/proxy/{provider}/...)",
    )
b/backend/tests/test_third_party_proxy.py new file mode 100644 index 00000000..35d86bc5 --- /dev/null +++ b/backend/tests/test_third_party_proxy.py @@ -0,0 +1,192 @@ +"""Unit tests for the third-party proxy module.""" + +from __future__ import annotations + +from app.gateway.third_party_proxy.ledger import CallLedger +from app.gateway.third_party_proxy.proxy import ( + _path_matches, + jsonpath_get, + match_query_route, + match_submit_route, +) +from deerflow.config.third_party_proxy_config import ( + QueryRouteConfig, + SubmitRouteConfig, + ThirdPartyProviderConfig, +) + + +# --------------------------------------------------------------------------- +# _path_matches +# --------------------------------------------------------------------------- + + +class TestPathMatches: + def test_exact_match(self): + assert _path_matches("/openapi/v2/query", "/openapi/v2/query") + + def test_exact_no_match(self): + assert not _path_matches("/openapi/v2/query", "/openapi/v2/submit") + + def test_glob_matches_prefix(self): + assert _path_matches("/openapi/v2/vidu/submit", "/openapi/v2/**") + + def test_glob_matches_prefix_itself(self): + assert _path_matches("/openapi/v2", "/openapi/v2/**") + + def test_glob_no_match_different_prefix(self): + assert not _path_matches("/other/v2/submit", "/openapi/v2/**") + + def test_trailing_slashes_normalised(self): + assert _path_matches("/openapi/v2/query/", "/openapi/v2/query") + + def test_glob_excludes_sibling_prefix(self): + # /openapi/v2/** should not match /openapi/v2extra/foo + assert not _path_matches("/openapi/v2extra/foo", "/openapi/v2/**") + + +# --------------------------------------------------------------------------- +# jsonpath_get +# --------------------------------------------------------------------------- + + +class TestJsonpathGet: + def test_single_key(self): + assert jsonpath_get({"taskId": "abc"}, "taskId") == "abc" + + def test_nested_key(self): + data = {"usage": {"thirdPartyConsumeMoney": 1.23}} + assert 
jsonpath_get(data, "usage.thirdPartyConsumeMoney") == 1.23 + + def test_missing_key_returns_none(self): + assert jsonpath_get({"foo": "bar"}, "taskId") is None + + def test_rejects_dollar_prefixed_path(self): + assert jsonpath_get({"taskId": "abc"}, "$.taskId") is None + + def test_short_path_supported(self): + assert jsonpath_get({"x": 1}, "x") == 1 + + def test_non_dict_intermediate(self): + data = {"usage": "not-a-dict"} + assert jsonpath_get(data, "usage.something") is None + + def test_none_input(self): + assert jsonpath_get(None, "x") is None + + +# --------------------------------------------------------------------------- +# match_submit_route / match_query_route +# --------------------------------------------------------------------------- + +_PROVIDER_CFG = ThirdPartyProviderConfig( + base_url="https://example.com", + api_key_env="TEST_API_KEY", + submit_routes=[ + SubmitRouteConfig( + method="POST", + path_pattern="/openapi/v2/**", + exclude_path_pattern="/openapi/v2/query", + task_id_jsonpath="taskId", + ) + ], + query_routes=[ + QueryRouteConfig( + method="POST", + path_pattern="/openapi/v2/query", + request_task_id_jsonpath="taskId", + status_jsonpath="status", + success_values=["SUCCESS"], + failure_values=["FAILED", "CANCELLED"], + usage_jsonpath="usage.thirdPartyConsumeMoney", + ) + ], +) + + +class TestMatchRoutes: + def test_submit_matches_non_query_path(self): + result = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/vidu/submit") + assert result is not None + assert result.task_id_jsonpath == "taskId" + + def test_submit_excluded_by_exclude_pattern(self): + result = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/query") + assert result is None + + def test_submit_wrong_method(self): + result = match_submit_route(_PROVIDER_CFG, "GET", "/openapi/v2/vidu/submit") + assert result is None + + def test_query_matches(self): + result = match_query_route(_PROVIDER_CFG, "POST", "/openapi/v2/query") + assert result is not None + assert 
result.status_jsonpath == "status" + + def test_query_wrong_method(self): + result = match_query_route(_PROVIDER_CFG, "GET", "/openapi/v2/query") + assert result is None + + +# --------------------------------------------------------------------------- +# CallLedger +# --------------------------------------------------------------------------- + + +class TestCallLedger: + def _make_ledger(self) -> CallLedger: + return CallLedger() + + def test_create_and_get(self): + ledger = self._make_ledger() + rec = ledger.create("prov", "tid", None) + assert rec.provider == "prov" + found = ledger.get(rec.proxy_call_id) + assert found is not None + assert found.proxy_call_id == rec.proxy_call_id + + def test_set_reserved(self): + ledger = self._make_ledger() + rec = ledger.create("prov", "tid", None) + ledger.set_reserved(rec.proxy_call_id, "frozen-123") + found = ledger.get(rec.proxy_call_id) + assert found.frozen_id == "frozen-123" + assert found.billing_state == "RESERVED" + + def test_set_running(self): + ledger = self._make_ledger() + rec = ledger.create("prov", "tid", None) + ledger.set_running(rec.proxy_call_id, "task-abc") + found = ledger.get_by_task_id("prov", "task-abc") + assert found is not None + assert found.proxy_call_id == rec.proxy_call_id + + def test_try_claim_finalize_once(self): + ledger = self._make_ledger() + rec = ledger.create("prov", "tid", None) + # First claim should succeed + assert ledger.try_claim_finalize(rec.proxy_call_id) is True + # Second claim should fail — already in progress/done + assert ledger.try_claim_finalize(rec.proxy_call_id) is False + + def test_is_finalized(self): + ledger = self._make_ledger() + rec = ledger.create("prov", "tid", None) + assert ledger.is_finalized(rec.proxy_call_id) is False + ledger.try_claim_finalize(rec.proxy_call_id) + ledger.set_finalized(rec.proxy_call_id, "SUCCESS") + assert ledger.is_finalized(rec.proxy_call_id) is True + + def test_idempotency_key_dedup(self): + ledger = self._make_ledger() + rec1 = 
ledger.create("prov", "tid", "idem-key-1") + rec2 = ledger.get_by_idempotency_key("prov", "idem-key-1") + assert rec2 is not None + assert rec2.proxy_call_id == rec1.proxy_call_id + + def test_update_response(self): + ledger = self._make_ledger() + rec = ledger.create("prov", "tid", None) + ledger.update_response(rec.proxy_call_id, {"result": "ok"}) + found = ledger.get(rec.proxy_call_id) + assert found.last_response == {"result": "ok"} diff --git a/config.example.yaml b/config.example.yaml index 31e90a0c..308236f9 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -49,6 +49,51 @@ billing: # Authorization: "Bearer your-secret-token" # X-App-Id: "deer-flow" +# ============================================================================ +# Third-Party Transparent Proxy +# ============================================================================ +# Exposes /api/proxy/{provider}/... and handles reserve/finalize around +# third-party async task APIs such as RunningHub. + +third_party_proxy: + enabled: false + providers: + runninghub: + base_url: https://www.runninghub.cn + api_key_env: RUNNINGHUB_API_KEY + api_key_header: Authorization + api_key_prefix: "Bearer " + timeout_seconds: 30.0 + frozen_type: 2 + submit_routes: + - path_pattern: "/openapi/v2/**" + exclude_path_pattern: "/openapi/v2/query" + task_id_jsonpath: "taskId" + # Optional per-model billing override examples: + # frozen_amount: 10.0 + # frozen_type: 2 + + # Example: model-specific reserve policy + # - path_pattern: "/openapi/v2/rhart-image/z-image/turbo-lora" + # task_id_jsonpath: "taskId" + # frozen_amount: 10.0 + # frozen_type: 2 + # - path_pattern: "/openapi/v2/vidu/text-to-video-q3-turbo" + # task_id_jsonpath: "taskId" + # frozen_amount: 50.0 + # frozen_type: 2 + # - path_pattern: "/openapi/v2/wan-2.7/image-edit" + # task_id_jsonpath: "taskId" + # frozen_amount: 20.0 + # frozen_type: 2 + query_routes: + - path_pattern: "/openapi/v2/query" + request_task_id_jsonpath: "taskId" + 
status_jsonpath: "status" + success_values: ["SUCCESS"] + failure_values: ["FAILED", "CANCELLED"] + usage_jsonpath: "usage.thirdPartyConsumeMoney" + # ============================================================================ # Token Usage Tracking # ============================================================================ diff --git a/docker/docker-compose-dev.yaml b/docker/docker-compose-dev.yaml index 3691c130..df5e6628 100644 --- a/docker/docker-compose-dev.yaml +++ b/docker/docker-compose-dev.yaml @@ -121,6 +121,10 @@ services: UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple} container_name: deer-flow-gateway command: sh -c "cd backend && uv sync && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --reload --reload-include='*.yaml .env' > /app/logs/gateway.log 2>&1" + ports: + # Expose to host so DooD-started sandbox containers can reach the gateway + # via host.docker.internal:8001 + - "8001:8001" volumes: - ../backend/:/app/backend/ # Preserve the .venv built during Docker image build — mounting the full backend/ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 98d54987..a3c2a1b4 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -69,6 +69,10 @@ services: UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple} container_name: deer-flow-gateway command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers 2" + ports: + # Expose gateway port for direct access (e.g. for API clients or testing tools like Postman). + # via host.docker.internal:8001 + - "8001:8001" volumes: - ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro - ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro