deerflow2/backend/app/gateway/third_party_proxy/proxy.py

276 lines
8.5 KiB
Python

"""HTTP forwarding, route classification, and JSONPath extraction for the third-party proxy."""
from __future__ import annotations
import json
import logging
import os
from typing import Any
import httpx
from deerflow.config.app_config import get_app_config
from deerflow.config.third_party_proxy_config import (
QueryRouteConfig,
SubmitRouteConfig,
ThirdPartyProviderConfig,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Placeholder text that is replaced with the provider's API key wherever it
# occurs in forwarded header values and in string values of JSON bodies.
API_KEY_MARKER = "__API_KEY_MARKER__"
# ---------------------------------------------------------------------------
# Provider config lookup
# ---------------------------------------------------------------------------
def get_provider_config(provider: str) -> ThirdPartyProviderConfig | None:
    """Look up the proxy settings registered under the name *provider*.

    Returns ``None`` when the third-party proxy is switched off in the app
    config, or when no provider entry with that name exists.
    """
    proxy_cfg = get_app_config().third_party_proxy
    # Only consult the provider table when the proxy feature is enabled.
    return proxy_cfg.providers.get(provider) if proxy_cfg.enabled else None
# ---------------------------------------------------------------------------
# Route classification
# ---------------------------------------------------------------------------
def match_submit_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> SubmitRouteConfig | None:
    """Find the first submit route of *config* that applies to (method, path).

    A route applies when its HTTP method equals *method* (compared
    case-insensitively), its ``path_pattern`` matches *path*, and *path* is
    not covered by the route's optional ``exclude_path_pattern``.
    Returns ``None`` when no route applies.
    """

    def _applies(route) -> bool:
        if route.method.upper() != method.upper():
            return False
        if not _path_matches(path, route.path_pattern):
            return False
        excluded = route.exclude_path_pattern
        # An unset/empty exclusion pattern never rules a route out.
        return not (excluded and _path_matches(path, excluded))

    # Routes are tried in their configured order; the first hit wins.
    return next((route for route in config.submit_routes if _applies(route)), None)
def match_query_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> QueryRouteConfig | None:
    """Find the first query route of *config* that applies to (method, path).

    A route applies when its HTTP method equals *method* (compared
    case-insensitively) and its ``path_pattern`` matches *path*.
    Returns ``None`` when no route applies.
    """
    # Routes are tried in their configured order; the first hit wins.
    for candidate in config.query_routes:
        methods_agree = candidate.method.upper() == method.upper()
        if methods_agree and _path_matches(path, candidate.path_pattern):
            return candidate
    return None
def _path_matches(path: str, pattern: str) -> bool:
"""Match *path* against a glob-ish *pattern*.
Rules:
- Pattern ending in /** matches the prefix and any sub-path.
- Otherwise exact match.
"""
# Normalise trailing slashes
path = path.rstrip("/") or "/"
pattern = pattern.rstrip("/") or "/"
if pattern.endswith("/**"):
prefix = pattern[:-3]
return path == prefix or path.startswith(prefix + "/")
return path == pattern
# ---------------------------------------------------------------------------
# Minimal path evaluator (dot-notation shorthand only)
# ---------------------------------------------------------------------------
def jsonpath_get(data: Any, path: str) -> Any:
    """Walk *data* along a dot-separated key path and return what is found.

    Example paths: ``taskId``, ``usage.thirdPartyConsumeMoney``.
    Full JSONPath syntax (a leading ``$``) is deliberately rejected.

    Returns ``None`` when *path* is not a string, is blank, starts with
    ``$``, contains an empty segment, or when any step lands on a non-dict
    value, a missing key, or a ``None`` value.
    """
    if not isinstance(path, str):
        return None

    trimmed = path.strip()
    if trimmed == "" or trimmed[0] == "$":
        return None

    node: Any = data
    for key in trimmed.split("."):
        # An empty segment (e.g. "a..b") or a non-dict container ends the walk.
        if key == "" or not isinstance(node, dict):
            return None
        node = node.get(key)
        if node is None:
            return None
    return node
# ---------------------------------------------------------------------------
# HTTP forwarding
# ---------------------------------------------------------------------------
# Request headers we never forward (hop-by-hop, sensitive, or proxy-internal)
_STRIP_REQUEST_HEADERS = frozenset(
[
"host",
"content-length",
"transfer-encoding",
"connection",
"x-thread-id",
"x-idempotency-key",
]
)
# Response headers we strip before returning to the caller
_STRIP_RESPONSE_HEADERS = frozenset(
[
"transfer-encoding",
"connection",
"keep-alive",
"content-encoding",
"content-length",
]
)
def _preview_body(data: bytes, limit: int = 2048) -> str:
"""Return a safe textual preview of body bytes for debugging logs."""
if not data:
return ""
chunk = data[:limit]
text = chunk.decode("utf-8", errors="replace")
if len(data) > limit:
text += f" ...<truncated {len(data) - limit} bytes>"
return text
def _replace_api_key_marker_in_headers(headers: dict[str, str], api_key: str) -> dict[str, str]:
    """Return a copy of *headers* with every API key marker swapped for *api_key*.

    Only string values that contain the marker are rewritten; all other
    entries are carried over untouched. The input mapping is not modified.
    """

    def _substitute(value):
        if isinstance(value, str) and API_KEY_MARKER in value:
            return value.replace(API_KEY_MARKER, api_key)
        return value

    return {name: _substitute(value) for name, value in headers.items()}
def _header_value(headers: dict[str, str], key: str) -> str | None:
target = key.lower()
for h_key, h_val in headers.items():
if h_key.lower() == target:
return h_val
return None
def _replace_api_key_marker_in_json(data: Any, api_key: str) -> Any:
    """Recursively swap the API key marker for *api_key* in decoded JSON data.

    Strings have every marker occurrence replaced. Lists and dicts are
    rebuilt with their items/values processed the same way (dict keys are
    left as they are). Any other value is returned unchanged.
    """
    if isinstance(data, dict):
        return {
            key: _replace_api_key_marker_in_json(value, api_key)
            for key, value in data.items()
        }
    if isinstance(data, list):
        return [_replace_api_key_marker_in_json(element, api_key) for element in data]
    if isinstance(data, str):
        return data.replace(API_KEY_MARKER, api_key)
    return data
def _replace_api_key_marker_in_body(headers: dict[str, str], body: bytes, api_key: str) -> bytes:
    """Replace the API key marker inside JSON request bodies only.

    The body is returned unchanged (same bytes) when it is empty, when the
    ``content-type`` header does not mention ``application/json``, when it
    cannot be parsed as JSON, or when parsing succeeds but no marker was
    substituted. Otherwise the substituted document is re-serialized as
    compact UTF-8 JSON.

    Args:
        headers: Request headers; ``content-type`` is looked up
            case-insensitively.
        body: Raw request payload.
        api_key: Value that replaces each marker occurrence.

    Returns:
        The payload to forward.
    """
    if not body:
        return body

    content_type = _header_value(headers, "content-type") or ""
    if "application/json" not in content_type.lower():
        return body

    try:
        parsed = json.loads(body)
    except (json.JSONDecodeError, ValueError):
        # Not valid JSON despite the content type: forward it verbatim.
        return body

    replaced = _replace_api_key_marker_in_json(parsed, api_key)
    if replaced == parsed:
        # Nothing was substituted. Keep the caller's exact bytes: a
        # parse/dump round trip can alter whitespace, float precision and
        # duplicate keys, which would corrupt signed or hashed payloads.
        return body

    return json.dumps(replaced, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
async def forward_request(
    *,
    provider_config: ThirdPartyProviderConfig,
    method: str,
    path: str,
    headers: dict[str, str],
    body: bytes,
    query_params: str,
) -> tuple[int, dict[str, str], bytes]:
    """Forward *method* *path* to the provider and return (status_code, headers, body).

    The target URL is ``provider_config.base_url`` joined with *path*, plus
    ``?query_params`` when *query_params* is non-empty. Request headers listed
    in ``_STRIP_REQUEST_HEADERS`` are dropped before forwarding, and response
    headers listed in ``_STRIP_RESPONSE_HEADERS`` are dropped before returning.

    If ``provider_config.api_key_env`` is set and that environment variable
    has a value, API key marker placeholders in the forwarded headers and JSON
    body are replaced with it. If the variable is unset or empty, a warning is
    logged and the request is forwarded without substitution.

    Debug logs show the request as it looked *before* the key was injected,
    so the provider secret is never written to the log.

    Args:
        provider_config: Provider settings (``base_url``, ``api_key_env``,
            ``timeout_seconds``).
        method: HTTP method to use upstream.
        path: Request path; leading slashes are normalised.
        headers: Incoming request headers.
        body: Incoming request payload.
        query_params: Raw query string without the leading ``?``.

    Returns:
        A ``(status_code, response_headers, response_body)`` tuple.
    """
    target_url = provider_config.base_url.rstrip("/") + "/" + path.lstrip("/")
    if query_params:
        target_url += "?" + query_params

    # Build forwarded headers: drop internal/hop-by-hop, then replace API key markers.
    forward_headers = {
        k: v for k, v in headers.items() if k.lower() not in _STRIP_REQUEST_HEADERS
    }

    # Snapshot for logging taken before the secret is injected. The header
    # replacement below builds a new dict, so this snapshot keeps the markers.
    loggable_headers = forward_headers
    loggable_body = body

    if provider_config.api_key_env:
        api_key = os.getenv(provider_config.api_key_env)
        if api_key:
            # Dependency-injection style: replace marker placeholders first.
            forward_headers = _replace_api_key_marker_in_headers(forward_headers, api_key)
            body = _replace_api_key_marker_in_body(forward_headers, body, api_key)
        else:
            logger.warning(
                "[ThirdPartyProxy] api_key_env '%s' is not set for provider",
                provider_config.api_key_env,
            )

    logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
    logger.debug(
        "[ThirdPartyProxy] request headers=%s",
        loggable_headers,
    )
    # The size is that of the payload actually sent; the preview is the
    # pre-substitution payload so the API key cannot leak into logs.
    logger.debug(
        "[ThirdPartyProxy] request body(%dB)=%s",
        len(body),
        _preview_body(loggable_body),
    )

    async with httpx.AsyncClient(timeout=provider_config.timeout_seconds) as client:
        response = await client.request(
            method=method,
            url=target_url,
            headers=forward_headers,
            content=body,
        )

    response_headers = {
        k: v
        for k, v in response.headers.items()
        if k.lower() not in _STRIP_RESPONSE_HEADERS
    }

    logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
    logger.debug(
        "[ThirdPartyProxy] response headers=%s",
        response_headers,
    )
    logger.debug(
        "[ThirdPartyProxy] response body(%dB)=%s",
        len(response.content),
        _preview_body(response.content),
    )
    return response.status_code, response_headers, response.content