247 lines
7.3 KiB
Python
247 lines
7.3 KiB
Python
"""HTTP forwarding, route classification, and JSONPath extraction for the third-party proxy."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from deerflow.config.app_config import get_app_config
|
|
from deerflow.config.third_party_proxy_config import (
|
|
QueryRouteConfig,
|
|
SubmitRouteConfig,
|
|
ThirdPartyProviderConfig,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Lower-case header names whose values _sanitize_headers replaces with "***"
# before header dicts are written to the debug log.
_SENSITIVE_HEADERS = frozenset(
    {
        "authorization",
        "proxy-authorization",
        "x-api-key",
        "api-key",
        "cookie",
        "set-cookie",
    }
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Provider config lookup
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def get_provider_config(provider: str) -> ThirdPartyProviderConfig | None:
    """Look up the configuration entry for *provider*.

    Returns None when the third-party proxy is switched off globally
    (``third_party_proxy.enabled`` is falsy) or when ``third_party_proxy.providers``
    has no entry named *provider*.
    """
    proxy_settings = get_app_config().third_party_proxy
    return proxy_settings.providers.get(provider) if proxy_settings.enabled else None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Route classification
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def match_submit_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> SubmitRouteConfig | None:
    """Find the submit route that applies to a ``(method, path)`` request.

    A route applies when its ``method`` equals *method* ignoring case, *path*
    matches its ``path_pattern``, and (if it defines one) *path* does not match
    its ``exclude_path_pattern``. Routes are tried in the order of
    ``config.submit_routes``; the first applicable one is returned, or None.
    """
    for candidate in config.submit_routes:
        if candidate.method.upper() != method.upper():
            continue
        included = _path_matches(path, candidate.path_pattern)
        # Only consult the exclusion pattern once the inclusion pattern matched.
        excluded = (
            included
            and bool(candidate.exclude_path_pattern)
            and _path_matches(path, candidate.exclude_path_pattern)
        )
        if included and not excluded:
            return candidate
    return None
|
|
|
|
|
|
def match_query_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> QueryRouteConfig | None:
    """Find the query route that applies to a ``(method, path)`` request.

    A route applies when its ``method`` equals *method* ignoring case and *path*
    matches its ``path_pattern``. Routes are tried lazily in the order of
    ``config.query_routes``; the first applicable one is returned, or None.
    """
    applicable = (
        candidate
        for candidate in config.query_routes
        if candidate.method.upper() == method.upper()
        and _path_matches(path, candidate.path_pattern)
    )
    return next(applicable, None)
|
|
|
|
|
|
def _path_matches(path: str, pattern: str) -> bool:
    """Match *path* against a glob-ish *pattern*.

    Both sides are normalised first: surrounding slashes are trimmed and a
    single leading "/" is added. So "v1/tasks", "/v1/tasks" and "/v1/tasks/"
    are equivalent, just as ``forward_request`` accepts paths with or without
    a leading slash. The empty string normalises to "/".

    Rules:
      - A pattern ending in "/**" matches its prefix itself and any sub-path
        below it: "/v1/**" matches "/v1" and "/v1/a/b" but not "/v1x".
        "/**" on its own matches every path.
      - Any other pattern must match exactly (case-sensitive).
    """
    # Normalising the leading slash as well as the trailing one keeps route
    # matching consistent regardless of how the caller captured the path.
    path = "/" + path.strip("/")
    pattern = "/" + pattern.strip("/")

    if pattern.endswith("/**"):
        prefix = pattern[:-3]
        # Require a "/" after the prefix so "/v1/**" does not match "/v1x".
        return path == prefix or path.startswith(prefix + "/")

    return path == pattern
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Minimal path evaluator (dot-notation shorthand only)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def jsonpath_get(data: Any, path: str) -> Any:
    """Walk nested dicts in *data* along a dot-separated key path.

    ``"taskId"`` reads ``data["taskId"]``; ``"usage.thirdPartyConsumeMoney"``
    reads ``data["usage"]["thirdPartyConsumeMoney"]``. Surrounding whitespace
    in *path* is ignored. JSONPath expressions starting with ``$`` are
    deliberately rejected.

    Returns None in any of these cases:
      - *path* is not a string, is blank, or starts with ``$``;
      - *path* contains an empty segment (e.g. ``"a..b"``);
      - a non-dict is reached before the last key (lists are not indexed);
      - a key is missing or maps to None.
    """
    if not isinstance(path, str):
        return None

    expr = path.strip()
    if not expr or expr.startswith("$"):
        return None

    node: Any = data
    for key in expr.split("."):
        if key == "" or not isinstance(node, dict):
            return None
        node = node.get(key)
        if node is None:
            return None
    return node
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HTTP forwarding
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Request headers we never forward to the provider:
#  - hop-by-hop headers (RFC 9110 §7.6.1) describe only the caller's connection
#    to this proxy, not the upstream connection;
#  - host / content-length are produced by httpx for the upstream request;
#  - accept-encoding is dropped so httpx advertises only the codings it can
#    decode. The response's content-encoding header is stripped below, so a
#    body httpx could not decode would otherwise reach the caller corrupted;
#  - x-thread-id / x-idempotency-key are proxy-internal.
_STRIP_REQUEST_HEADERS = frozenset(
    [
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        "keep-alive",
        "proxy-connection",
        "proxy-authorization",
        "te",
        "trailer",
        "upgrade",
        "accept-encoding",
        "x-thread-id",
        "x-idempotency-key",
    ]
)

# Response headers we strip before returning to the caller:
#  - hop-by-hop headers (RFC 9110 §7.6.1) belong to the upstream connection;
#  - content-encoding / content-length: the body handed back is httpx's
#    ``response.content``, which httpx has already decoded, so the upstream
#    values no longer describe it.
_STRIP_RESPONSE_HEADERS = frozenset(
    [
        "transfer-encoding",
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-connection",
        "trailer",
        "upgrade",
        "content-encoding",
        "content-length",
    ]
)
|
|
|
|
|
|
def _sanitize_headers(headers: dict[str, str]) -> dict[str, str]:
    """Build a log-safe copy of *headers*.

    Every header whose lower-cased name is in ``_SENSITIVE_HEADERS`` has its
    value replaced by ``"***"``. All other entries are copied unchanged, in
    their original order. The input mapping is not modified.
    """
    return {
        name: ("***" if name.lower() in _SENSITIVE_HEADERS else value)
        for name, value in headers.items()
    }
|
|
|
|
|
|
def _preview_body(data: bytes, limit: int = 2048) -> str:
    """Render at most *limit* bytes of *data* as text for debug logging.

    The bytes are decoded as UTF-8, with undecodable sequences replaced by
    U+FFFD, so arbitrary binary payloads never raise. When *data* is longer
    than *limit*, a ``" ...<truncated N bytes>"`` marker reports how many
    bytes were left out. Empty (or otherwise falsy) input yields "".
    """
    if not data:
        return ""
    preview = str(data[:limit], "utf-8", "replace")
    overflow = len(data) - limit
    if overflow > 0:
        return preview + f" ...<truncated {overflow} bytes>"
    return preview
|
|
|
|
|
|
async def forward_request(
    *,
    provider_config: ThirdPartyProviderConfig,
    method: str,
    path: str,
    headers: dict[str, str],
    body: bytes,
    query_params: str,
) -> tuple[int, dict[str, str], bytes]:
    """Forward *method* *path* to the provider and return (status_code, headers, body).

    Args:
        provider_config: Provider settings. ``base_url``, ``api_key_env``,
            ``api_key_header``, ``api_key_prefix`` and ``timeout_seconds`` are read.
        method: HTTP method used for the upstream request.
        path: Path appended to ``base_url``; a leading "/" is optional.
        headers: Incoming request headers. Names in ``_STRIP_REQUEST_HEADERS``
            (compared case-insensitively) are not forwarded.
        body: Raw request body, sent upstream unchanged.
        query_params: Query string appended after a "?" when non-empty.

    Returns:
        ``(status_code, headers, body)`` of the upstream response. Headers named
        in ``_STRIP_RESPONSE_HEADERS`` are removed. The body is httpx's
        ``response.content``.

    When the environment variable named in ``provider_config.api_key_env`` is
    set, its value (prefixed with ``api_key_prefix``) is sent in
    ``api_key_header``. It replaces any caller-supplied header of that name,
    whatever its letter case. If the variable is unset, a warning is logged and
    the caller's headers are forwarded as-is.

    Raises:
        Whatever ``httpx.AsyncClient.request`` raises (e.g. timeouts or
        connection errors); nothing is caught here.
    """
    target_url = provider_config.base_url.rstrip("/") + "/" + path.lstrip("/")
    if query_params:
        target_url += "?" + query_params

    # Build forwarded headers: drop internal/hop-by-hop, then inject API key
    forward_headers = {
        k: v for k, v in headers.items() if k.lower() not in _STRIP_REQUEST_HEADERS
    }

    # Name of the header that carries the injected key. Tracked so the debug log
    # can redact it even when it is not one of the names in _SENSITIVE_HEADERS.
    injected_header: str | None = None
    if provider_config.api_key_env:
        api_key = os.getenv(provider_config.api_key_env)
        if api_key:
            injected_header = provider_config.api_key_header
            # Incoming header names may differ in letter case from api_key_header
            # (e.g. "authorization" vs "Authorization"). A plain dict assignment
            # would keep both, sending two auth headers and leaking the caller's
            # credential upstream, so remove every case variant first.
            injected_lower = injected_header.lower()
            forward_headers = {
                k: v for k, v in forward_headers.items() if k.lower() != injected_lower
            }
            forward_headers[injected_header] = provider_config.api_key_prefix + api_key
        else:
            logger.warning(
                "[ThirdPartyProxy] api_key_env '%s' is not set for provider %s",
                provider_config.api_key_env,
                provider_config.base_url,
            )

    logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
    # Skip building sanitized copies and body previews unless they will be emitted.
    if logger.isEnabledFor(logging.DEBUG):
        logged_headers = _sanitize_headers(forward_headers)
        if injected_header is not None:
            logged_headers[injected_header] = "***"
        logger.debug("[ThirdPartyProxy] request headers=%s", logged_headers)
        logger.debug(
            "[ThirdPartyProxy] request body(%dB)=%s",
            len(body),
            _preview_body(body),
        )

    async with httpx.AsyncClient(timeout=provider_config.timeout_seconds) as client:
        response = await client.request(
            method=method,
            url=target_url,
            headers=forward_headers,
            content=body,
        )

    # NOTE(review): httpx's Headers.items() joins repeated headers into one
    # comma-separated value, which is lossy for Set-Cookie. The dict return type
    # cannot represent repeats; consider response.headers.multi_items() if
    # callers need them.
    response_headers = {
        k: v
        for k, v in response.headers.items()
        if k.lower() not in _STRIP_RESPONSE_HEADERS
    }
    content = response.content

    logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "[ThirdPartyProxy] response headers=%s",
            _sanitize_headers(response_headers),
        )
        logger.debug(
            "[ThirdPartyProxy] response body(%dB)=%s",
            len(content),
            _preview_body(content),
        )

    return response.status_code, response_headers, content
|