deerflow2/backend/app/gateway/third_party_proxy/proxy.py

247 lines
7.3 KiB
Python

"""HTTP forwarding, route classification, and JSONPath extraction for the third-party proxy."""
from __future__ import annotations
import logging
import os
from typing import Any
import httpx
from deerflow.config.app_config import get_app_config
from deerflow.config.third_party_proxy_config import (
QueryRouteConfig,
SubmitRouteConfig,
ThirdPartyProviderConfig,
)
logger = logging.getLogger(__name__)
# Lower-case header names whose values are masked ("***") by
# _sanitize_headers() before headers are written to the debug log.
_SENSITIVE_HEADERS = frozenset(
    {
        "authorization",
        "proxy-authorization",
        "x-api-key",
        "api-key",
        "cookie",
        "set-cookie",
    }
)
# ---------------------------------------------------------------------------
# Provider config lookup
# ---------------------------------------------------------------------------
def get_provider_config(provider: str) -> ThirdPartyProviderConfig | None:
    """Look up the configuration registered under the name *provider*.

    Returns ``None`` when the third-party proxy section of the app config has
    ``enabled`` unset/falsy, or when no provider is registered under that name.
    """
    proxy_settings = get_app_config().third_party_proxy
    return proxy_settings.providers.get(provider) if proxy_settings.enabled else None
# ---------------------------------------------------------------------------
# Route classification
# ---------------------------------------------------------------------------
def match_submit_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> SubmitRouteConfig | None:
    """Find the submit route that handles the request (*method*, *path*).

    Routes are checked in the order they appear in ``config.submit_routes``.
    A route qualifies when its HTTP method equals *method* (case-insensitive),
    its ``path_pattern`` matches *path*, and - if the route defines an
    ``exclude_path_pattern`` - that exclusion pattern does NOT match *path*.

    Returns the first qualifying route, or ``None`` when there is none.
    """
    for candidate in config.submit_routes:
        same_method = candidate.method.upper() == method.upper()
        if not (same_method and _path_matches(path, candidate.path_pattern)):
            continue
        # An exclusion pattern, when present, vetoes an otherwise matching route.
        vetoed = bool(candidate.exclude_path_pattern) and _path_matches(
            path, candidate.exclude_path_pattern
        )
        if not vetoed:
            return candidate
    return None
def match_query_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> QueryRouteConfig | None:
    """Find the query route that handles the request (*method*, *path*).

    Routes are checked in the order they appear in ``config.query_routes``; a
    route qualifies when its HTTP method equals *method* (case-insensitive)
    and its ``path_pattern`` matches *path*.

    Returns the first qualifying route, or ``None`` when there is none.
    """
    qualifying = (
        candidate
        for candidate in config.query_routes
        if candidate.method.upper() == method.upper()
        and _path_matches(path, candidate.path_pattern)
    )
    return next(qualifying, None)
def _path_matches(path: str, pattern: str) -> bool:
    """Decide whether *path* satisfies the route *pattern*.

    Trailing slashes are ignored on both arguments (a value that becomes
    empty after stripping is treated as "/").

    Two pattern forms are understood:
    - ``<prefix>/**`` matches ``<prefix>`` itself and anything below it.
    - Any other pattern must be equal to the path.
    """
    wildcard_suffix = "/**"
    normalized_path = path.rstrip("/") or "/"
    normalized_pattern = pattern.rstrip("/") or "/"

    if not normalized_pattern.endswith(wildcard_suffix):
        return normalized_path == normalized_pattern

    base = normalized_pattern[: -len(wildcard_suffix)]
    # Require a "/" boundary so "/api/tasks/**" does not match "/api/tasksX".
    return normalized_path == base or normalized_path.startswith(f"{base}/")
# ---------------------------------------------------------------------------
# Minimal path evaluator (dot-notation shorthand only)
# ---------------------------------------------------------------------------
def jsonpath_get(data: Any, path: str) -> Any:
    """Read a nested value out of *data* by following a dotted key path.

    Example paths: ``taskId``, ``usage.thirdPartyConsumeMoney``.

    Only plain dot-separated keys are understood; a path starting with ``$``
    (real JSONPath syntax) is deliberately rejected. Surrounding whitespace
    of the whole path is ignored.

    Returns ``None`` when *path* is not a string, is empty, starts with ``$``,
    contains an empty segment, when a non-dict is reached before the path is
    exhausted, or when a key is missing or maps to ``None``.
    """
    if not isinstance(path, str):
        return None

    trimmed = path.strip()
    if trimmed == "" or trimmed.startswith("$"):
        return None

    node: Any = data
    for key in trimmed.split("."):
        # An empty key (e.g. "a..b") or a non-dict container ends the walk.
        if key == "" or not isinstance(node, dict):
            return None
        node = node.get(key)
        if node is None:
            return None
    return node
# ---------------------------------------------------------------------------
# HTTP forwarding
# ---------------------------------------------------------------------------
# Request headers we never forward to the provider (names are lower-case;
# forward_request() compares against ``key.lower()``):
#   - host / content-length: describe the inbound request, not the outbound one.
#   - transfer-encoding, connection, keep-alive, proxy-connection,
#     proxy-authorization, te, trailer, upgrade: hop-by-hop / connection-scoped
#     fields that an intermediary must not pass on (RFC 9110 section 7.6.1,
#     RFC 2616 section 13.5.1).
#   - x-thread-id / x-idempotency-key: proxy-internal headers.
# NOTE(review): caller-supplied ``authorization`` and ``cookie`` are NOT in this
# set and are therefore still forwarded - confirm whether that is intended.
_STRIP_REQUEST_HEADERS = frozenset(
    [
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        "keep-alive",
        "proxy-connection",
        "proxy-authorization",
        "te",
        "trailer",
        "upgrade",
        "x-thread-id",
        "x-idempotency-key",
    ]
)
# Lower-case names of provider response headers that forward_request() removes
# before handing the response headers back to its caller.
_STRIP_RESPONSE_HEADERS = frozenset(
    {
        "transfer-encoding",
        "connection",
        "keep-alive",
        "content-encoding",
        "content-length",
    }
)
def _sanitize_headers(headers: dict[str, str]) -> dict[str, str]:
    """Build a log-safe copy of *headers*.

    Every header whose lower-cased name is listed in ``_SENSITIVE_HEADERS``
    has its value replaced by ``"***"``; all other entries are copied as-is.
    The input mapping is not modified.
    """
    return {
        name: "***" if name.lower() in _SENSITIVE_HEADERS else value
        for name, value in headers.items()
    }
def _preview_body(data: bytes, limit: int = 2048) -> str:
    """Render at most *limit* leading bytes of *data* as text for debug logs.

    The bytes are decoded as UTF-8 with undecodable sequences replaced, so the
    result is always a valid string. When *data* is longer than *limit*, a
    marker stating how many bytes were omitted is appended. Empty input
    yields an empty string.
    """
    if not data:
        return ""

    preview = data[:limit].decode("utf-8", errors="replace")
    omitted = len(data) - limit
    if omitted > 0:
        return preview + f" ...<truncated {omitted} bytes>"
    return preview
async def forward_request(
    *,
    provider_config: ThirdPartyProviderConfig,
    method: str,
    path: str,
    headers: dict[str, str],
    body: bytes,
    query_params: str,
) -> tuple[int, dict[str, str], bytes]:
    """Forward *method* *path* to the provider and return (status_code, headers, body).

    The target URL is ``provider_config.base_url`` joined with *path*, plus
    ``?query_params`` when *query_params* is non-empty.

    The provider's API key (read from the environment variable named in
    ``provider_config.api_key_env``) is injected into the header named by
    ``provider_config.api_key_header``, replacing any header of the same name
    (compared case-insensitively) that the caller might have sent. If the
    environment variable is unset or empty, a warning is logged and the
    caller's headers are forwarded without an injected key.

    Args:
        provider_config: Provider settings (base URL, API-key settings, timeout).
        method: HTTP method to use for the outbound request.
        path: Request path relative to the provider base URL.
        headers: Inbound request headers; entries listed in
            ``_STRIP_REQUEST_HEADERS`` are dropped before forwarding.
        body: Raw request body, sent unchanged.
        query_params: Raw query string without the leading ``?`` (may be empty).

    Returns:
        ``(status_code, response_headers, response_body)`` where the headers
        exclude the names in ``_STRIP_RESPONSE_HEADERS``.

    Raises:
        Exceptions raised by ``httpx`` (e.g. timeouts, connection errors) are
        not caught here and propagate to the caller.
    """
    target_url = provider_config.base_url.rstrip("/") + "/" + path.lstrip("/")
    if query_params:
        target_url += "?" + query_params

    # Build forwarded headers: drop internal/hop-by-hop, then inject API key
    forward_headers = {
        k: v for k, v in headers.items() if k.lower() not in _STRIP_REQUEST_HEADERS
    }
    if provider_config.api_key_env:
        api_key = os.getenv(provider_config.api_key_env)
        if api_key:
            key_header = provider_config.api_key_header
            key_header_lower = key_header.lower()
            # Header names are case-insensitive, but dict keys are not: without
            # this filter a caller's "authorization" would be sent alongside the
            # injected "Authorization", leaking the caller's credential upstream.
            forward_headers = {
                k: v for k, v in forward_headers.items() if k.lower() != key_header_lower
            }
            forward_headers[key_header] = provider_config.api_key_prefix + api_key
        else:
            logger.warning(
                "[ThirdPartyProxy] api_key_env '%s' is not set for provider",
                provider_config.api_key_env,
            )

    # Sanitising headers and decoding body previews is wasted work unless the
    # DEBUG records will actually be emitted.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)

    logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
    if debug_enabled:
        logger.debug(
            "[ThirdPartyProxy] request headers=%s",
            _sanitize_headers(forward_headers),
        )
        logger.debug(
            "[ThirdPartyProxy] request body(%dB)=%s",
            len(body),
            _preview_body(body),
        )

    async with httpx.AsyncClient(timeout=provider_config.timeout_seconds) as client:
        response = await client.request(
            method=method,
            url=target_url,
            headers=forward_headers,
            content=body,
        )

    response_headers = {
        k: v
        for k, v in response.headers.items()
        if k.lower() not in _STRIP_RESPONSE_HEADERS
    }

    logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
    if debug_enabled:
        logger.debug(
            "[ThirdPartyProxy] response headers=%s",
            _sanitize_headers(response_headers),
        )
        logger.debug(
            "[ThirdPartyProxy] response body(%dB)=%s",
            len(response.content),
            _preview_body(response.content),
        )
    return response.status_code, response_headers, response.content