169 lines
4.9 KiB
Python
169 lines
4.9 KiB
Python
import logging
|
|
from collections.abc import AsyncGenerator
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Request, Response
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from src.gateway.config import get_gateway_config
|
|
|
|
logger = logging.getLogger(name=__name__)

router = APIRouter(tags=["proxy"])

# A single httpx client is shared by every proxied request, so no client has to
# be created and closed around each request - in particular not while a
# streaming response is still being relayed. It is created lazily by
# get_http_client() and released by close_http_client().
_http_client: httpx.AsyncClient | None = None
|
|
|
|
|
|
def get_http_client() -> httpx.AsyncClient:
    """Return the module-wide HTTP client, creating it on first use.

    Subsequent calls return the same instance until close_http_client()
    resets it.

    Returns:
        The shared httpx AsyncClient instance.
    """
    global _http_client
    if _http_client is not None:
        return _http_client
    _http_client = httpx.AsyncClient()
    return _http_client
|
|
|
|
|
|
async def close_http_client() -> None:
    """Close and forget the shared HTTP client, if one has been created.

    Does nothing when no client exists, so it is safe to call more than once.
    After a successful close, the next get_http_client() call builds a new
    client.
    """
    global _http_client
    if _http_client is None:
        return
    await _http_client.aclose()
    _http_client = None
|
|
|
|
|
|
# Incoming request headers that are NOT copied onto the upstream request
# (matched against the lower-cased header name).
EXCLUDED_HEADERS = {
    # Describe the incoming request itself; httpx derives its own values from
    # the target URL and the forwarded body.
    "host",
    "content-length",
    # Hop-by-hop headers as listed in RFC 2616 section 13.5.1.
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
}
|
|
|
|
|
|
async def stream_sse_response(stream_ctx, response: httpx.Response) -> AsyncGenerator[bytes, None]:
    """Relay an upstream Server-Sent Events body chunk by chunk.

    The caller has already entered ``stream_ctx``; this generator takes over
    responsibility for exiting it once iteration stops, whether the body was
    fully relayed, an error was raised, or the consumer closed the generator.

    Args:
        stream_ctx: The already-entered httpx stream context manager.
        response: The httpx streaming response obtained by entering ``stream_ctx``.

    Yields:
        Body chunks as produced by ``response.aiter_bytes()``.
    """
    try:
        upstream_chunks = response.aiter_bytes()
        async for chunk_bytes in upstream_chunks:
            yield chunk_bytes
    finally:
        # NOTE(review): this cleanup only runs once iteration has started; a
        # generator that is discarded before its first step never reaches it.
        await stream_ctx.__aexit__(None, None, None)
|
|
|
|
|
|
# Upstream response headers that must not be copied onto the rebuilt
# (non-streaming) response: connection-level headers, plus the body-framing
# headers that no longer describe the body once httpx has read it.
# ``Response.aread()`` returns the body with any Content-Encoding already
# decoded, and Starlette computes Content-Length from the bytes we hand it.
_STRIPPED_RESPONSE_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "transfer-encoding",
        "content-encoding",
        "content-length",
    }
)


async def proxy_request(request: Request, path: str) -> Response | StreamingResponse:
    """Proxy a request to the LangGraph server.

    The request is forwarded with its method, query string, body and all
    headers not listed in ``EXCLUDED_HEADERS``. An upstream
    ``text/event-stream`` response is relayed as it arrives. Any other
    response is read fully and returned with the upstream status and headers.

    Args:
        request: The incoming FastAPI request.
        path: The path to proxy to, appended to the configured ``langgraph_url``.

    Returns:
        A ``StreamingResponse`` for SSE upstream responses, otherwise a
        ``Response`` carrying the upstream status, headers and body. If the
        upstream request times out, a 504 JSON error response is returned. Any
        other httpx request error yields a 502 JSON error response.
    """
    config = get_gateway_config()
    # str() + rstrip("/") keep a base URL with a trailing slash (or a URL
    # object whose string form ends in one) from producing "//path".
    base_url = str(config.langgraph_url).rstrip("/")
    target_url = f"{base_url}/{path}"

    # Preserve query parameters verbatim.
    if request.url.query:
        target_url = f"{target_url}?{request.url.query}"

    # Prepare headers (exclude hop-by-hop headers and ones httpx regenerates).
    headers = {key: value for key, value in request.headers.items() if key.lower() not in EXCLUDED_HEADERS}

    # Only methods other than GET/HEAD carry a body.
    body = None
    if request.method not in ("GET", "HEAD"):
        body = await request.body()

    client = get_http_client()

    try:
        # Open a streaming request so the response headers can be inspected
        # before the body arrives; SSE can then be relayed without buffering.
        stream_ctx = client.stream(
            method=request.method,
            url=target_url,
            headers=headers,
            content=body,
            timeout=config.stream_timeout,
        )
        response = await stream_ctx.__aenter__()

        content_type = response.headers.get("content-type", "")

        # Server-Sent Events: hand the open stream to the generator, which
        # exits the stream context once relaying stops.
        if "text/event-stream" in content_type:
            return StreamingResponse(
                stream_sse_response(stream_ctx, response),
                status_code=response.status_code,
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        # Non-SSE: buffer the body, and release the upstream connection even
        # if reading fails part-way (e.g. a read timeout); otherwise the
        # connection would leak when the except handlers below return.
        try:
            content = await response.aread()
        finally:
            await stream_ctx.__aexit__(None, None, None)

        proxied = Response(content=content, status_code=response.status_code)
        # multi_items() keeps repeated headers (e.g. several Set-Cookie lines)
        # separate instead of comma-joining them as dict() would.
        for key, value in response.headers.multi_items():
            if key.lower() not in _STRIPPED_RESPONSE_HEADERS:
                proxied.headers.append(key, value)
        return proxied

    except httpx.TimeoutException:
        logger.error("Proxy request to %s timed out", target_url)
        return Response(
            content='{"error": "Proxy request timed out"}',
            status_code=504,
            media_type="application/json",
        )
    except httpx.RequestError as e:
        logger.error("Proxy request to %s failed: %s", target_url, e)
        return Response(
            content='{"error": "Proxy request failed"}',
            status_code=502,
            media_type="application/json",
        )
|
|
|
|
|
|
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_langgraph(request: Request, path: str) -> Response:
    """Catch-all route that hands every matching request to ``proxy_request``.

    Only GET, POST, PUT, DELETE and PATCH are routed here.

    Args:
        request: The incoming FastAPI request.
        path: The segment captured by the ``{path:path}`` route parameter.

    Returns:
        Whatever ``proxy_request`` produces for this request.
    """
    proxied_response = await proxy_request(request, path)
    return proxied_response
|