"""HTTP forwarding, route classification, and JSONPath extraction for the third-party proxy."""

from __future__ import annotations

import json
import logging
import os
from typing import Any

import httpx

from deerflow.config.app_config import get_app_config
from deerflow.config.third_party_proxy_config import (
    QueryRouteConfig,
    SubmitRouteConfig,
    ThirdPartyProviderConfig,
)

logger = logging.getLogger(__name__)

API_KEY_MARKER = "__API_KEY_MARKER__"

# Text written to debug logs in place of the real provider API key.
_REDACTED_PLACEHOLDER = "***"


# ---------------------------------------------------------------------------
# Provider config lookup
# ---------------------------------------------------------------------------


def get_provider_config(provider: str) -> ThirdPartyProviderConfig | None:
    """Return the provider config for *provider*, or None if not configured/disabled.

    None is returned when the proxy section of the app config has ``enabled``
    falsy, or when ``providers.get(provider)`` yields nothing.
    """
    cfg = get_app_config().third_party_proxy
    if not cfg.enabled:
        return None
    return cfg.providers.get(provider)


# ---------------------------------------------------------------------------
# Route classification
# ---------------------------------------------------------------------------


def match_submit_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> SubmitRouteConfig | None:
    """Return the first submit route that matches (method, path), or None.

    A route matches when its method equals *method* case-insensitively, its
    ``path_pattern`` matches *path*, and its ``exclude_path_pattern`` (if set)
    does not match *path*.
    """
    wanted_method = method.upper()
    for route in config.submit_routes:
        if route.method.upper() != wanted_method:
            continue
        if not _path_matches(path, route.path_pattern):
            continue
        if route.exclude_path_pattern and _path_matches(path, route.exclude_path_pattern):
            continue
        return route
    return None


def match_query_route(
    config: ThirdPartyProviderConfig,
    method: str,
    path: str,
) -> QueryRouteConfig | None:
    """Return the first query route that matches (method, path), or None.

    A route matches when its method equals *method* case-insensitively and its
    ``path_pattern`` matches *path*.
    """
    wanted_method = method.upper()
    for route in config.query_routes:
        if route.method.upper() != wanted_method:
            continue
        if _path_matches(path, route.path_pattern):
            return route
    return None


def _path_matches(path: str, pattern: str) -> bool:
    """Match *path* against a glob-ish *pattern*.

    Rules:
    - Pattern ending in /** matches the prefix and any sub-path.
    - Otherwise exact match.

    Trailing slashes on both arguments are ignored; an argument consisting
    only of slashes is treated as "/".
    """
    # Normalise trailing slashes
    path = path.rstrip("/") or "/"
    pattern = pattern.rstrip("/") or "/"

    if pattern.endswith("/**"):
        prefix = pattern[:-3]
        # The "/" suffix keeps "/api/**" from matching "/apiary".
        return path == prefix or path.startswith(prefix + "/")
    return path == pattern


# ---------------------------------------------------------------------------
# Minimal path evaluator (dot-notation shorthand only)
# ---------------------------------------------------------------------------


def jsonpath_get(data: Any, path: str) -> Any:
    """Extract a value from *data* using a simple dot-notation shorthand path.

    Supports paths like:
        taskId
        usage.thirdPartyConsumeMoney

    Paths with a leading '$' are intentionally not supported.
    Returns None if any segment is missing or the input is not a dict.
    A non-string path, a blank path, or a path with an empty segment
    (e.g. "a..b") also yields None.
    """
    if not isinstance(path, str):
        return None

    remainder = path.strip()
    if not remainder or remainder.startswith("$"):
        return None

    current: Any = data
    for part in remainder.split("."):
        if not part:
            return None
        if not isinstance(current, dict):
            return None
        current = current.get(part)
        if current is None:
            return None
    return current


# ---------------------------------------------------------------------------
# HTTP forwarding
# ---------------------------------------------------------------------------

# Request headers we never forward (hop-by-hop, sensitive, or proxy-internal)
_STRIP_REQUEST_HEADERS = frozenset(
    [
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        "x-thread-id",
        "x-idempotency-key",
    ]
)

# Response headers we strip before returning to the caller
_STRIP_RESPONSE_HEADERS = frozenset(
    [
        "transfer-encoding",
        "connection",
        "keep-alive",
        "content-encoding",
        "content-length",
    ]
)


def _preview_body(data: bytes, limit: int = 2048) -> str:
    """Return a safe textual preview of body bytes for debugging logs.

    At most *limit* bytes are decoded as UTF-8 (undecodable bytes are
    replaced); a trailing marker is appended when *data* was truncated.
    """
    if not data:
        return ""
    text = data[:limit].decode("utf-8", errors="replace")
    if len(data) > limit:
        text += " \n..."
    return text


def _replace_api_key_marker_in_headers(headers: dict[str, str], api_key: str) -> dict[str, str]:
    """Replace API key marker placeholders in header values.

    Returns a new dict; *headers* is not mutated. Non-string values are
    copied through untouched.
    """
    return {
        key: value.replace(API_KEY_MARKER, api_key) if isinstance(value, str) else value
        for key, value in headers.items()
    }


def _header_value(headers: dict[str, str], key: str) -> str | None:
    """Return the value of the first header whose name equals *key* case-insensitively, else None."""
    target = key.lower()
    for h_key, h_val in headers.items():
        if h_key.lower() == target:
            return h_val
    return None


def _replace_api_key_marker_in_json(data: Any, api_key: str) -> Any:
    """Return a copy of decoded JSON *data* with the marker replaced in every string value.

    Lists and dicts are walked recursively. Dict keys and non-string leaves
    are left as they are.
    """
    if isinstance(data, str):
        return data.replace(API_KEY_MARKER, api_key)
    if isinstance(data, list):
        return [_replace_api_key_marker_in_json(item, api_key) for item in data]
    if isinstance(data, dict):
        return {k: _replace_api_key_marker_in_json(v, api_key) for k, v in data.items()}
    return data


def _replace_api_key_marker_in_body(headers: dict[str, str], body: bytes, api_key: str) -> bytes:
    """Replace API key marker in JSON body payloads only.

    The body is returned untouched when it is empty, when the Content-Type
    header does not mention ``application/json``, when it fails to parse as
    JSON, or when the substitution changes nothing. Only a body that actually
    contained the marker is re-serialized (compact, UTF-8).
    """
    if not body:
        return body

    content_type = _header_value(headers, "content-type") or ""
    if "application/json" not in content_type.lower():
        return body

    try:
        parsed = json.loads(body)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError subclasses).
        return body

    replaced = _replace_api_key_marker_in_json(parsed, api_key)
    if replaced == parsed:
        # Nothing was substituted: forward the caller's exact bytes instead of a
        # re-serialized variant (whitespace, escapes, duplicate keys would change).
        return body
    return json.dumps(replaced, ensure_ascii=False, separators=(",", ":")).encode("utf-8")


def _redact_headers(headers: dict[str, str], secret: str | None) -> dict[str, str]:
    """Return *headers* with every occurrence of *secret* masked, for logging only."""
    if not secret:
        return headers
    return {
        key: value.replace(secret, _REDACTED_PLACEHOLDER) if isinstance(value, str) else value
        for key, value in headers.items()
    }


def _redact_body(body: bytes, secret: str | None) -> bytes:
    """Return *body* with the UTF-8 bytes of *secret* masked, for logging only."""
    if not secret or not body:
        return body
    return body.replace(secret.encode("utf-8"), _REDACTED_PLACEHOLDER.encode("utf-8"))


async def forward_request(
    *,
    provider_config: ThirdPartyProviderConfig,
    method: str,
    path: str,
    headers: dict[str, str],
    body: bytes,
    query_params: str,
) -> tuple[int, dict[str, str], bytes]:
    """Forward *method* *path* to the provider and return (status_code, headers, body).

    If configured, the provider API key from ``provider_config.api_key_env`` is used
    to replace API key marker placeholders in forwarded headers/body.

    Args:
        provider_config: Supplies ``base_url``, ``api_key_env`` and ``timeout_seconds``.
        method: HTTP method passed through to the provider.
        path: Request path appended to ``base_url`` (exactly one "/" joins them).
        headers: Incoming headers; names in ``_STRIP_REQUEST_HEADERS`` are dropped.
        body: Raw request body.
        query_params: Raw query string without the leading "?"; may be empty.

    Returns:
        The provider's status code, its headers minus ``_STRIP_RESPONSE_HEADERS``,
        and the response content.

    Exceptions raised by the httpx client are not caught here.
    """
    target_url = provider_config.base_url.rstrip("/") + "/" + path.lstrip("/")
    if query_params:
        target_url += "?" + query_params

    # Build forwarded headers: drop internal/hop-by-hop, then replace API key markers.
    forward_headers = {k: v for k, v in headers.items() if k.lower() not in _STRIP_REQUEST_HEADERS}

    api_key: str | None = None
    if provider_config.api_key_env:
        api_key = os.getenv(provider_config.api_key_env)
        if api_key:
            # Dependency-injection style: replace marker placeholders first.
            forward_headers = _replace_api_key_marker_in_headers(forward_headers, api_key)
            body = _replace_api_key_marker_in_body(forward_headers, body, api_key)
        else:
            logger.warning(
                "[ThirdPartyProxy] api_key_env '%s' is not set for provider",
                provider_config.api_key_env,
            )

    logger.info("[ThirdPartyProxy] → %s %s", method, target_url)
    if logger.isEnabledFor(logging.DEBUG):
        # The injected key is masked so the credential never reaches log storage;
        # the body is masked before truncation so no partial key survives the cut.
        logger.debug(
            "[ThirdPartyProxy] request headers=%s",
            _redact_headers(forward_headers, api_key),
        )
        logger.debug(
            "[ThirdPartyProxy] request body(%dB)=%s",
            len(body),
            _preview_body(_redact_body(body, api_key)),
        )

    async with httpx.AsyncClient(timeout=provider_config.timeout_seconds) as client:
        response = await client.request(
            method=method,
            url=target_url,
            headers=forward_headers,
            content=body,
        )

    response_headers = {
        k: v for k, v in response.headers.items() if k.lower() not in _STRIP_RESPONSE_HEADERS
    }

    logger.info("[ThirdPartyProxy] ← %s %s %d", method, target_url, response.status_code)
    logger.debug(
        "[ThirdPartyProxy] response headers=%s",
        response_headers,
    )
    logger.debug(
        "[ThirdPartyProxy] response body(%dB)=%s",
        len(response.content),
        _preview_body(response.content),
    )

    return response.status_code, response_headers, response.content