"""HTTP forwarding, route classification, and dot-path JSON extraction for the third-party proxy."""

from __future__ import annotations

import logging
import os
from collections.abc import Iterable
from typing import Any

import httpx

from deerflow.config.app_config import get_app_config
from deerflow.config.third_party_proxy_config import (
    QueryRouteConfig,
    SubmitRouteConfig,
    ThirdPartyProviderConfig,
)

logger = logging.getLogger(__name__)

# Header names (lower-case) whose values are replaced by "***" in debug logs.
_SENSITIVE_HEADERS = frozenset(
    [
        "authorization",
        "proxy-authorization",
        "x-api-key",
        "api-key",
        "cookie",
        "set-cookie",
    ]
)


# ---------------------------------------------------------------------------
# Provider config lookup
# ---------------------------------------------------------------------------


def get_provider_config(provider: str) -> ThirdPartyProviderConfig | None:
    """Return the config for *provider*, or None.

    None is returned when the third-party proxy is globally disabled
    (``third_party_proxy.enabled`` is false) or when *provider* has no entry in
    ``third_party_proxy.providers``.
    """
    cfg = get_app_config().third_party_proxy
    if not cfg.enabled:
        return None
    return cfg.providers.get(provider)


# ---------------------------------------------------------------------------
# Route classification
# ---------------------------------------------------------------------------


def match_submit_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> SubmitRouteConfig | None:
    """Return the first submit route matching (*method*, *path*), or None.

    Routes are tried in configuration order. A route matches when its HTTP
    method equals *method* (case-insensitive), *path* matches its
    ``path_pattern``, and, if the route sets ``exclude_path_pattern``, *path*
    does NOT match that exclusion pattern.
    """
    wanted = method.upper()
    for route in config.submit_routes:
        if route.method.upper() != wanted:
            continue
        if not _path_matches(path, route.path_pattern):
            continue
        if route.exclude_path_pattern and _path_matches(path, route.exclude_path_pattern):
            continue
        return route
    return None


def match_query_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> QueryRouteConfig | None:
    """Return the first query route matching (*method*, *path*), or None.

    Routes are tried in configuration order; the method comparison is
    case-insensitive and *path* must match the route's ``path_pattern``.
    """
    wanted = method.upper()
    for route in config.query_routes:
        if route.method.upper() != wanted:
            continue
        if _path_matches(path, route.path_pattern):
            return route
    return None


def _path_matches(path: str, pattern: str) -> bool:
    """Match *path* against a glob-ish *pattern*.

    Trailing slashes on both sides are ignored. Rules:
      - A pattern ending in ``/**`` matches its prefix itself and any sub-path
        (``/tasks/**`` matches ``/tasks`` and ``/tasks/1/status`` but not
        ``/tasks-archive``).
      - Any other pattern must match exactly.
    """
    # Normalise trailing slashes; an all-slash (or empty) value becomes "/".
    path = path.rstrip("/") or "/"
    pattern = pattern.rstrip("/") or "/"
    if pattern.endswith("/**"):
        prefix = pattern[:-3]
        # Require a "/" after the prefix so "/tasks" does not match "/tasks-x".
        return path == prefix or path.startswith(prefix + "/")
    return path == pattern


# ---------------------------------------------------------------------------
# Minimal path evaluator (dot-notation shorthand only)
# ---------------------------------------------------------------------------


def jsonpath_get(data: Any, path: str) -> Any:
    """Extract a value from *data* using a simple dot-notation shorthand path.

    Supports paths like:
        taskId
        usage.thirdPartyConsumeMoney

    Paths with a leading '$' are intentionally not supported. Surrounding
    whitespace on *path* is ignored.

    Returns None if *path* is not a string, is empty, starts with '$', contains
    an empty segment (e.g. ``a..b``), any segment is missing or maps to None,
    or an intermediate value is not a dict (lists are not indexed).
    """
    if not isinstance(path, str):
        return None
    remainder = path.strip()
    if not remainder or remainder.startswith("$"):
        return None

    current: Any = data
    for part in remainder.split("."):
        if not part or not isinstance(current, dict):
            return None
        current = current.get(part)
        if current is None:
            return None
    return current


# ---------------------------------------------------------------------------
# HTTP forwarding
# ---------------------------------------------------------------------------

# Hop-by-hop headers (RFC 9110 §7.6.1) describe a single connection and must
# not be relayed by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset(
    [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ]
)

# Request headers we never forward (hop-by-hop, transport-level, or proxy-internal).
_STRIP_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | frozenset(
    [
        "host",  # httpx derives Host from the target URL
        "content-length",  # httpx recomputes it from the forwarded body
        # httpx decodes the upstream body itself and Content-Encoding is stripped
        # from the response, so the caller's Accept-Encoding must not be relayed:
        # it could request a coding httpx cannot decode, yielding still-compressed
        # bytes with no Content-Encoding header. Without it httpx advertises only
        # the codings it supports.
        "accept-encoding",
        "x-thread-id",
        "x-idempotency-key",
    ]
)

# Response headers we strip before returning to the caller.
_STRIP_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | frozenset(
    [
        # response.content is already decoded by httpx, so the upstream
        # Content-Encoding / Content-Length no longer describe the body we return.
        "content-encoding",
        "content-length",
    ]
)


def _sanitize_headers(
    headers: dict[str, str],
    extra_sensitive: Iterable[str] = (),
) -> dict[str, str]:
    """Return a copy of *headers* with sensitive values replaced by ``"***"``.

    Args:
        headers: Headers to sanitize; the input dict is not modified.
        extra_sensitive: Additional header names (any case) to redact on top of
            ``_SENSITIVE_HEADERS``, e.g. a provider-specific API-key header.
            Empty names are ignored.

    Returns:
        A new dict with the same keys, safe to write to logs.
    """
    sensitive = _SENSITIVE_HEADERS | {name.lower() for name in extra_sensitive if name}
    return {
        key: "***" if key.lower() in sensitive else value
        for key, value in headers.items()
    }


def _preview_body(data: bytes, limit: int = 2048) -> str:
    """Return a printable preview of at most *limit* bytes of *data* for debug logs.

    Bytes are decoded as UTF-8 with invalid sequences replaced, so a multi-byte
    character cut at the limit shows up as U+FFFD. When *data* is longer than
    *limit*, a suffix reports how many bytes were omitted.
    """
    if not data:
        return ""
    limit = max(limit, 0)
    text = data[:limit].decode("utf-8", errors="replace")
    omitted = len(data) - limit
    if omitted > 0:
        text += f" ... (+{omitted}B truncated)"
    return text


async def forward_request(
    *,
    provider_config: ThirdPartyProviderConfig,
    method: str,
    path: str,
    headers: dict[str, str],
    body: bytes,
    query_params: str,
) -> tuple[int, dict[str, str], bytes]:
    """Forward *method* *path* to the provider and return (status_code, headers, body).

    The target URL is ``provider_config.base_url`` joined to *path* with exactly
    one slash, plus ``?query_params`` when that string is non-empty.

    Headers listed in ``_STRIP_REQUEST_HEADERS`` are dropped. If
    ``provider_config.api_key_env`` names an environment variable that is set,
    its value (prefixed with ``provider_config.api_key_prefix``) is injected
    under ``provider_config.api_key_header``. Any caller-supplied header with
    that name, in any letter case, is replaced. If the variable is unset, a
    warning is logged and the caller's headers are forwarded unchanged.

    Args:
        provider_config: Target provider (base URL, credentials, timeout).
        method: HTTP method to use upstream.
        path: Path appended to the provider base URL.
        headers: Incoming request headers.
        body: Raw request body, forwarded verbatim.
        query_params: Already-encoded query string without the leading '?'.

    Returns:
        ``(status_code, headers, body)``. *headers* excludes
        ``_STRIP_RESPONSE_HEADERS``; repeated upstream headers (e.g. several
        ``Set-Cookie``) arrive comma-joined into one value, as produced by httpx
        ``Headers.items()``. *body* is the decoded response content.

    Raises:
        httpx.HTTPError: transport failures (connect errors, timeouts, ...) are
            not caught here and propagate to the caller. Non-2xx statuses are
            returned, not raised.
    """
    # NOTE(review): *path* is appended verbatim; '..' segments are not rejected
    # here. Confirm callers validate it if it can come from untrusted input.
    target_url = provider_config.base_url.rstrip("/") + "/" + path.lstrip("/")
    if query_params:
        target_url += "?" + query_params

    # Build forwarded headers: drop internal/hop-by-hop ones, then inject the API key.
    forward_headers = {
        k: v for k, v in headers.items() if k.lower() not in _STRIP_REQUEST_HEADERS
    }

    if provider_config.api_key_env:
        api_key = os.getenv(provider_config.api_key_env)
        if api_key:
            key_header = provider_config.api_key_header
            key_header_lower = key_header.lower()
            # Plain dict assignment is case-sensitive: without this filter a caller's
            # "authorization" and a configured "Authorization" would both be sent.
            forward_headers = {
                k: v for k, v in forward_headers.items() if k.lower() != key_header_lower
            }
            forward_headers[key_header] = provider_config.api_key_prefix + api_key
        else:
            logger.warning(
                "[ThirdPartyProxy] api_key_env '%s' is not set for provider",
                provider_config.api_key_env,
            )

    # Also redact the provider's key header, which may not be a well-known name.
    redact = (provider_config.api_key_header,)
    debug_enabled = logger.isEnabledFor(logging.DEBUG)

    logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
    if debug_enabled:
        logger.debug(
            "[ThirdPartyProxy] request headers=%s",
            _sanitize_headers(forward_headers, redact),
        )
        logger.debug(
            "[ThirdPartyProxy] request body(%dB)=%s",
            len(body),
            _preview_body(body),
        )

    # NOTE(review): a new client (and connection pool) is created per call;
    # a shared client would allow connection reuse if throughput matters.
    async with httpx.AsyncClient(timeout=provider_config.timeout_seconds) as client:
        response = await client.request(
            method=method,
            url=target_url,
            headers=forward_headers,
            content=body,
        )

    response_headers = {
        k: v
        for k, v in response.headers.items()
        if k.lower() not in _STRIP_RESPONSE_HEADERS
    }
    response_body = response.content

    logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
    if debug_enabled:
        logger.debug(
            "[ThirdPartyProxy] response headers=%s",
            _sanitize_headers(response_headers, redact),
        )
        logger.debug(
            "[ThirdPartyProxy] response body(%dB)=%s",
            len(response_body),
            _preview_body(response_body),
        )
    return response.status_code, response_headers, response_body