# Clawith/backend/app/services/llm_client.py
"""Unified LLM client for multiple providers.
Supports OpenAI-compatible APIs, Anthropic native API, and streaming/non-streaming modes.
Provides a consistent interface for all LLM operations across the application.
"""
from __future__ import annotations
import asyncio
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Literal
import httpx
from loguru import logger
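# Example (illustrative sketch, not executed): the high-level helpers at the end of this
# module wrap client creation, completion, and cleanup in one call. The provider, model,
# API key, and import path (inferred from the file location) are placeholder assumptions.
#
#   import asyncio
#   from app.services.llm_client import chat_complete
#
#   async def main() -> None:
#       result = await chat_complete(
#           provider="openai",
#           api_key="sk-...",
#           model="gpt-4o-mini",
#           messages=[{"role": "user", "content": "Say hello"}],
#       )
#       print(result["choices"][0]["message"]["content"])
#
#   asyncio.run(main())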
# ============================================================================
# Data Models
# ============================================================================
@dataclass
class LLMMessage:
"""Unified message format."""
role: Literal["system", "user", "assistant", "tool"]
content: str | list | None = None
tool_calls: list[dict] | None = None
tool_call_id: str | None = None
reasoning_content: str | None = None
reasoning_signature: str | None = None
dynamic_content: str | None = None
def to_openai_format(self) -> dict:
"""Convert to OpenAI format."""
msg: dict[str, Any] = {"role": self.role}
content = self.content
if self.role == "system" and self.dynamic_content:
            content = f"{content}\n\n{self.dynamic_content}" if content else self.dynamic_content
if content is not None:
msg["content"] = content
if self.tool_calls:
msg["tool_calls"] = self.tool_calls
if self.tool_call_id:
msg["tool_call_id"] = self.tool_call_id
if self.reasoning_content:
msg["reasoning_content"] = self.reasoning_content
return msg
def to_anthropic_format(self) -> dict | None:
"""Convert to Anthropic format (returns None for system messages)."""
if self.role == "system":
return None
role = self.role
# Tool response (from user to assistant)
if role == "tool":
# Build tool_result content: support both string and vision array formats
if isinstance(self.content, list):
# Vision content array: extract text parts and image parts
# Anthropic tool_result content supports [{type: "text", text: ...}, {type: "image", source: ...}]
tool_content_blocks = []
for part in self.content:
if part.get("type") == "text":
tool_content_blocks.append({"type": "text", "text": part.get("text", "")})
elif part.get("type") == "image_url":
# Convert OpenAI image_url format to Anthropic image source format
img_url = part.get("image_url", {}).get("url", "")
if img_url.startswith("data:image/"):
# Parse data URL: data:image/jpeg;base64,xxxxx
header, b64_data = img_url.split(",", 1)
media_type = header.split(":")[1].split(";")[0] # e.g. image/jpeg
tool_content_blocks.append({
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": b64_data,
}
})
result_content = tool_content_blocks if tool_content_blocks else (self.content or "")
else:
result_content = self.content or ""
return {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": self.tool_call_id,
"content": result_content,
}
]
}
content_blocks = []
# Add reasoning/thinking content if present
if self.role == "assistant" and self.reasoning_content:
content_blocks.append({
"type": "thinking",
"thinking": self.reasoning_content,
"signature": self.reasoning_signature or "synthetic_signature"
})
if self.content:
content_blocks.append({"type": "text", "text": self.content})
# Tool requests (from assistant to user)
if self.tool_calls:
for tc in self.tool_calls:
function_call = tc.get("function", {})
args = function_call.get("arguments", "{}")
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
content_blocks.append({
"type": "tool_use",
"id": tc.get("id", ""),
"name": function_call.get("name", ""),
"input": args
})
        # Collapse to a plain string when the only block is text; otherwise keep the block list
if len(content_blocks) == 1 and content_blocks[0]["type"] == "text":
content = content_blocks[0]["text"]
else:
content = content_blocks
return {"role": role, "content": content}
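# Example (hedged sketch): converting a tool-result message that carries an inline image
# back to Anthropic format. The tool_call_id and the truncated data URL are illustrative.
#
#   msg = LLMMessage(
#       role="tool",
#       tool_call_id="toolu_01",
#       content=[
#           {"type": "text", "text": "Screenshot captured"},
#           {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}},
#       ],
#   )
#   msg.to_anthropic_format()
#   # -> {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "toolu_01",
#   #      "content": [{"type": "text", ...}, {"type": "image", "source": {...}}]}]}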
@dataclass
class LLMResponse:
"""Unified response format."""
content: str
tool_calls: list[dict] = field(default_factory=list)
reasoning_content: str | None = None
reasoning_signature: str | None = None
finish_reason: str | None = None
usage: dict[str, int] | None = None
model: str | None = None
@dataclass
class LLMStreamChunk:
"""Stream chunk format."""
content: str = ""
reasoning_content: str = ""
tool_call: dict | None = None
finish_reason: str | None = None
is_finished: bool = False
usage: dict | None = None
# ============================================================================
# Type Definitions
# ============================================================================
ChunkCallback = Callable[[str], Coroutine[Any, Any, None]]
ToolCallback = Callable[[dict], Coroutine[Any, Any, None]]
ThinkingCallback = Callable[[str], Coroutine[Any, Any, None]]
# ============================================================================
# Base Client Interface
# ============================================================================
class LLMClient(ABC):
"""Abstract base class for LLM clients."""
def __init__(
self,
api_key: str,
base_url: str | None = None,
model: str | None = None,
timeout: float = 120.0,
):
self.api_key = api_key
self.base_url = base_url
self.model = model
self.timeout = timeout
@abstractmethod
async def complete(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Send a completion request and return the full response."""
pass
@abstractmethod
async def stream(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
on_chunk: ChunkCallback | None = None,
on_thinking: ThinkingCallback | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Send a streaming request and return the aggregated response."""
pass
@abstractmethod
def _get_headers(self) -> dict[str, str]:
"""Get request headers."""
pass
# ============================================================================
# OpenAI-Compatible Client
# ============================================================================
class OpenAICompatibleClient(LLMClient):
"""Client for OpenAI-compatible APIs (OpenAI, DeepSeek, Qwen, etc.)."""
DEFAULT_BASE_URL = "https://api.openai.com/v1"
def __init__(
self,
api_key: str,
base_url: str | None = None,
model: str | None = None,
timeout: float = 120.0,
supports_tool_choice: bool = True,
):
super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout)
self.supports_tool_choice = supports_tool_choice
self._client: httpx.AsyncClient | None = None
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None)
return self._client
def _get_headers(self) -> dict[str, str]:
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
def _normalize_base_url(self) -> str:
"""Normalize base URL by stripping trailing /chat/completions."""
url = self.base_url.rstrip("/")
if url.endswith("/chat/completions"):
url = url[: -len("/chat/completions")]
return url
def _build_payload(
self,
messages: list[LLMMessage],
tools: list[dict] | None,
temperature: float | None,
max_tokens: int | None,
stream: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
"""Build request payload."""
payload: dict[str, Any] = {
"model": self.model,
"messages": [m.to_openai_format() for m in messages],
"stream": stream,
}
if temperature is not None:
payload["temperature"] = temperature
# Request usage stats in streaming responses (OpenAI extension)
if stream:
payload["stream_options"] = {"include_usage": True}
if max_tokens:
payload["max_tokens"] = max_tokens
if tools:
payload["tools"] = tools
if self.supports_tool_choice:
payload["tool_choice"] = "auto"
payload["parallel_tool_calls"] = True
# Add any additional kwargs
payload.update(kwargs)
return payload
def _parse_stream_line(
self,
line: str,
in_think: bool,
tag_buffer: str,
json_buffer: str = "",
) -> tuple[LLMStreamChunk, bool, str, str]:
"""Parse a single SSE line from stream.
Returns (chunk, new_in_think, new_tag_buffer, new_json_buffer).
The json_buffer accumulates partial JSON from non-standard APIs that
split a single JSON object across multiple data: lines.
"""
chunk = LLMStreamChunk()
# SSE spec: "data:" may or may not have a space after the colon
if line.startswith("data: "):
data_str = line[6:]
elif line.startswith("data:"):
data_str = line[5:]
else:
# Non-data lines (comments, event types, empty) — never buffer
return chunk, in_think, tag_buffer, json_buffer
data_str = data_str.strip()
if not data_str:
return chunk, in_think, tag_buffer, json_buffer
if data_str == "[DONE]":
chunk.is_finished = True
return chunk, in_think, tag_buffer, ""
# Accumulate into json_buffer for split JSON handling
if json_buffer:
json_buffer += data_str
else:
json_buffer = data_str
try:
data = json.loads(json_buffer)
json_buffer = "" # Reset on successful parse
except json.JSONDecodeError:
# Cap buffer at 64KB to prevent memory leaks
if len(json_buffer) > 65536:
logger.warning("[LLM] JSON buffer exceeded 64KB, discarding")
json_buffer = ""
return chunk, in_think, tag_buffer, json_buffer
        # Some gateways include "error": null on success, so only raise when it's truthy
        if data.get("error"):
            raise LLMError(f"Stream error: {data['error']}")
# Parse usage from stream (returned in the final chunk with include_usage)
if data.get("usage"):
chunk.usage = data["usage"]
choices = data.get("choices", [])
if not choices:
return chunk, in_think, tag_buffer, json_buffer
choice = choices[0]
delta = choice.get("delta", {})
if choice.get("finish_reason"):
chunk.finish_reason = choice["finish_reason"]
# Reasoning content (DeepSeek R1)
if delta.get("reasoning_content"):
chunk.reasoning_content = delta["reasoning_content"]
# Regular content with think tag filtering
if delta.get("content"):
text = delta["content"]
chunk.content, in_think, tag_buffer = self._filter_think_tags(
text, in_think, tag_buffer
)
# Tool calls
if delta.get("tool_calls"):
for tc_delta in delta["tool_calls"]:
chunk.tool_call = tc_delta
break # Return one at a time
return chunk, in_think, tag_buffer, json_buffer
def _filter_think_tags(
self, text: str, in_think: bool, tag_buffer: str
) -> tuple[str, bool, str]:
"""Filter out <think>...</think> tags from content.
Returns (filtered_content, new_in_think, new_tag_buffer).
"""
tag_buffer += text
emit = ""
i = 0
buf = tag_buffer
while i < len(buf):
if not in_think:
# Look for <think open tag
if buf[i] == "<":
tag_candidate = buf[i:]
if tag_candidate.startswith("<think>"):
in_think = True
i += len("<think>")
continue
elif "<think>".startswith(tag_candidate):
# Partial match - keep in buffer
break
else:
emit += buf[i]
i += 1
else:
emit += buf[i]
i += 1
else:
# Inside think - look for </think> close tag
if buf[i] == "<":
tag_candidate = buf[i:]
if tag_candidate.startswith("</think>"):
in_think = False
i += len("</think>")
continue
elif "</think>".startswith(tag_candidate):
break
i += 1
tag_buffer = buf[i:]
return emit, in_think, tag_buffer
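    # Example (illustrative): the filter is called once per streamed delta and keeps a small
    # buffer so a tag split across deltas (e.g. "<thi" + "nk>") is still recognized.
    #
    #   client = OpenAICompatibleClient(api_key="sk-...", model="any")
    #   emit, in_think, buf = client._filter_think_tags("Hello <thi", False, "")
    #   # emit == "Hello ", buf == "<thi"
    #   emit, in_think, buf = client._filter_think_tags("nk>secret</think>done", in_think, buf)
    #   # emit == "done"; the text between the tags is suppressed from regular content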
async def complete(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Non-streaming completion."""
url = f"{self._normalize_base_url()}/chat/completions"
payload = self._build_payload(messages, tools, temperature, max_tokens, stream=False, **kwargs)
client = await self._get_client()
response = await client.post(url, json=payload, headers=self._get_headers())
if response.status_code >= 400:
error_text = response.text[:500]
raise LLMError(f"HTTP {response.status_code}: {error_text}")
data = response.json()
if "error" in data:
raise LLMError(f"API error: {data['error']}")
choice = data.get("choices", [{}])[0]
msg = choice.get("message", {})
return LLMResponse(
            content=msg.get("content") or "",
            tool_calls=msg.get("tool_calls") or [],
finish_reason=choice.get("finish_reason"),
usage=data.get("usage"),
model=data.get("model"),
)
async def stream(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
on_chunk: ChunkCallback | None = None,
on_thinking: ThinkingCallback | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Streaming completion."""
url = f"{self._normalize_base_url()}/chat/completions"
payload = self._build_payload(messages, tools, temperature, max_tokens, stream=True, **kwargs)
full_content = ""
full_reasoning = ""
tool_calls_data: list[dict] = []
last_finish_reason: str | None = None
final_usage: dict | None = None
in_think = False
tag_buffer = ""
json_buffer = "" # Buffer for non-standard APIs with split JSON (inspired by PR #120)
max_retries = 3
client = await self._get_client()
for attempt in range(max_retries):
try:
async with client.stream("POST", url, json=payload, headers=self._get_headers()) as resp:
if resp.status_code >= 400:
error_body = ""
async for chunk in resp.aiter_bytes():
error_body += chunk.decode(errors="replace")
raise LLMError(f"HTTP {resp.status_code}: {error_body[:500]}")
async for line in resp.aiter_lines():
chunk, in_think, tag_buffer, json_buffer = self._parse_stream_line(
line, in_think, tag_buffer, json_buffer
)
if chunk.is_finished:
break
if chunk.content:
full_content += chunk.content
if on_chunk:
await on_chunk(chunk.content)
if chunk.reasoning_content:
full_reasoning += chunk.reasoning_content
if on_thinking:
await on_thinking(chunk.reasoning_content)
if chunk.tool_call:
idx = chunk.tool_call.get("index", 0)
while len(tool_calls_data) <= idx:
tool_calls_data.append({"id": "", "function": {"name": "", "arguments": ""}})
tc = tool_calls_data[idx]
if chunk.tool_call.get("id"):
tc["id"] = chunk.tool_call["id"]
fn_delta = chunk.tool_call.get("function", {})
if fn_delta.get("name"):
tc["function"]["name"] += fn_delta["name"]
if fn_delta.get("arguments") is not None:
arg_chunk = fn_delta["arguments"]
if isinstance(arg_chunk, dict):
tc["function"]["arguments"] = json.dumps(arg_chunk, ensure_ascii=False)
else:
tc["function"]["arguments"] += str(arg_chunk)
if chunk.usage:
final_usage = chunk.usage
if chunk.finish_reason:
last_finish_reason = chunk.finish_reason
break
break
except (httpx.TransportError, httpx.ConnectTimeout) as e:
# TransportError covers all network-layer issues:
# - ConnectError, ReadError, WriteError (NetworkError subclasses)
# - RemoteProtocolError, LocalProtocolError (ProtocolError subclasses)
# The last case is common with local vLLM when the server closes
# the connection mid-stream (e.g. OOM, context limit exceeded).
if attempt < max_retries - 1:
                wait = attempt + 1  # linear backoff: 1s, 2s, ...
logger.warning(f"Stream attempt {attempt + 1} failed ({type(e).__name__}: {e}), retrying in {wait}s...")
await asyncio.sleep(wait)
full_content = ""
full_reasoning = ""
tool_calls_data = []
in_think = False
tag_buffer = ""
json_buffer = ""
else:
raise LLMError(f"Connection failed after {max_retries} attempts: {type(e).__name__}: {e}")
# Clean up any remaining think tags
full_content = re.sub(r"<think>[\s\S]*?</think>\s*", "", full_content).strip()
return LLMResponse(
content=full_content,
tool_calls=tool_calls_data,
reasoning_content=full_reasoning or None,
finish_reason=last_finish_reason,
usage=final_usage,
model=self.model,
)
async def close(self) -> None:
"""Close the HTTP client."""
if self._client and not self._client.is_closed:
try:
await asyncio.wait_for(self._client.aclose(), timeout=5.0)
except asyncio.TimeoutError:
logger.warning("[LLM] Client close timed out, forcing")
self._client = None
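# Example (hedged sketch): streaming with a chunk callback. The base URL, model, and key
# are placeholders; the aggregated LLMResponse is returned once the stream finishes.
#
#   async def run() -> None:
#       client = OpenAICompatibleClient(
#           api_key="sk-...",
#           base_url="https://api.deepseek.com/v1",
#           model="deepseek-chat",
#       )
#       async def on_chunk(text: str) -> None:
#           print(text, end="", flush=True)
#       try:
#           resp = await client.stream(
#               messages=[LLMMessage(role="user", content="Stream a haiku")],
#               on_chunk=on_chunk,
#           )
#           print("\nfinish:", resp.finish_reason, "usage:", resp.usage)
#       finally:
#           await client.close()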
# ============================================================================
# OpenAI Responses API Client
# ============================================================================
class OpenAIResponsesClient(LLMClient):
"""Client for OpenAI Responses API (`/v1/responses`)."""
DEFAULT_BASE_URL = "https://api.openai.com/v1"
def __init__(
self,
api_key: str,
base_url: str | None = None,
model: str | None = None,
timeout: float = 120.0,
supports_tool_choice: bool = True,
):
super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout)
self.supports_tool_choice = supports_tool_choice
self._client: httpx.AsyncClient | None = None
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None)
return self._client
def _get_headers(self) -> dict[str, str]:
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
def _normalize_base_url(self) -> str:
"""Normalize base URL by stripping trailing /responses endpoint."""
url = self.base_url.rstrip("/")
if url.endswith("/responses"):
url = url[: -len("/responses")]
return url
def _format_content_for_input(self, content: Any) -> Any:
"""Convert OpenAI chat-style content into Responses API input content."""
if not isinstance(content, list):
return content
formatted: list[dict[str, Any]] = []
for part in content:
if not isinstance(part, dict):
continue
ptype = part.get("type")
if ptype == "text":
formatted.append({"type": "input_text", "text": part.get("text", "")})
elif ptype == "image_url":
img = part.get("image_url", {})
if isinstance(img, dict):
formatted.append({"type": "input_image", "image_url": img.get("url", "")})
else:
formatted.append(part)
return formatted if formatted else content
def _messages_to_input(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
"""Convert canonical message format to Responses API input format."""
input_items: list[dict[str, Any]] = []
for msg in messages:
if msg.role in {"system", "user", "assistant"} and msg.content is not None:
item: dict[str, Any] = {"role": msg.role}
item["content"] = self._format_content_for_input(msg.content)
input_items.append(item)
if msg.role == "assistant" and msg.tool_calls:
for tc in msg.tool_calls:
fn = tc.get("function", {})
args = fn.get("arguments", "{}")
if isinstance(args, dict):
args = json.dumps(args, ensure_ascii=False)
input_items.append({
"type": "function_call",
"call_id": tc.get("id", ""),
"name": fn.get("name", ""),
"arguments": str(args or "{}"),
})
if msg.role == "tool":
input_items.append({
"type": "function_call_output",
"call_id": msg.tool_call_id or "",
"output": msg.content or "",
})
return input_items
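    # Example (illustrative): how a short tool-use exchange maps onto Responses API input
    # items. The model id is a placeholder; the shape shown is what this method produces.
    #
    #   client = OpenAIResponsesClient(api_key="sk-...", model="gpt-4.1-mini")
    #   msgs = [
    #       LLMMessage(role="user", content="What's the weather in Paris?"),
    #       LLMMessage(role="assistant", content=None, tool_calls=[{
    #           "id": "call_1", "type": "function",
    #           "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    #       }]),
    #       LLMMessage(role="tool", tool_call_id="call_1", content="18C, cloudy"),
    #   ]
    #   client._messages_to_input(msgs)
    #   # -> [{"role": "user", "content": "What's the weather in Paris?"},
    #   #     {"type": "function_call", "call_id": "call_1", "name": "get_weather",
    #   #      "arguments": '{"city": "Paris"}'},
    #   #     {"type": "function_call_output", "call_id": "call_1", "output": "18C, cloudy"}]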
def _convert_tools(self, tools: list[dict] | None) -> list[dict] | None:
"""Convert OpenAI tool schema to Responses API function tool schema."""
if not tools:
return None
converted: list[dict[str, Any]] = []
for tool in tools:
if tool.get("type") != "function":
continue
fn = tool.get("function", {})
converted.append({
"type": "function",
"name": fn.get("name", ""),
"description": fn.get("description", ""),
"parameters": fn.get("parameters", {"type": "object"}),
})
return converted or None
def _build_payload(
self,
messages: list[LLMMessage],
tools: list[dict] | None,
temperature: float | None,
max_tokens: int | None,
stream: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
"""Build request payload."""
payload: dict[str, Any] = {
"model": self.model,
"input": self._messages_to_input(messages),
"stream": stream,
}
if temperature is not None:
payload["temperature"] = temperature
if max_tokens:
payload["max_output_tokens"] = max_tokens
converted_tools = self._convert_tools(tools)
if converted_tools:
payload["tools"] = converted_tools
if self.supports_tool_choice:
payload["tool_choice"] = "auto"
payload.update(kwargs)
return payload
def _parse_response_data(self, data: dict[str, Any]) -> LLMResponse:
"""Convert Responses API payload into canonical LLMResponse."""
content_parts: list[str] = []
reasoning_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
for item in data.get("output", []) or []:
item_type = item.get("type")
if item_type == "message":
for c in item.get("content", []) or []:
c_type = c.get("type")
if c_type in {"output_text", "text"}:
content_parts.append(c.get("text", ""))
elif c_type == "reasoning":
reasoning_parts.append(c.get("summary", "") or c.get("text", ""))
elif item_type == "function_call":
args = item.get("arguments", "{}")
if isinstance(args, dict):
args = json.dumps(args, ensure_ascii=False)
tool_calls.append({
"id": item.get("call_id") or item.get("id", ""),
"type": "function",
"function": {
"name": item.get("name", ""),
"arguments": str(args or "{}"),
},
})
# Some Responses payloads include a pre-aggregated output_text field.
# Use it as a fallback when output blocks are empty.
if not content_parts and data.get("output_text"):
content_parts.append(str(data.get("output_text", "")))
usage = data.get("usage")
finish_reason = "tool_calls" if tool_calls else "stop"
return LLMResponse(
content="".join(content_parts),
tool_calls=tool_calls,
reasoning_content="".join(reasoning_parts) or None,
finish_reason=finish_reason,
usage=usage if isinstance(usage, dict) else None,
model=data.get("model"),
)
def _extract_api_error(self, data: dict[str, Any]) -> str | None:
"""Extract meaningful error message from Responses API payload."""
# OpenAI Responses often returns `"error": null` on success,
# so we must only treat it as error when it's truthy.
err = data.get("error")
if err:
if isinstance(err, dict):
msg = err.get("message") or str(err)
err_type = err.get("type")
err_code = err.get("code")
extra = []
if err_type:
extra.append(f"type={err_type}")
if err_code:
extra.append(f"code={err_code}")
suffix = f" ({', '.join(extra)})" if extra else ""
return f"{msg}{suffix}"
return str(err)
status = str(data.get("status") or "").lower()
if status in {"failed", "incomplete", "cancelled"}:
last_error = data.get("last_error")
incomplete = data.get("incomplete_details")
rid = data.get("id")
details: list[str] = [f"status={status}"]
if rid:
details.append(f"id={rid}")
if last_error:
details.append(f"last_error={last_error}")
if incomplete:
details.append(f"incomplete_details={incomplete}")
return "Responses API returned non-success status: " + "; ".join(details)
return None
def _build_error_log_context(self, data: dict[str, Any]) -> dict[str, Any]:
"""Build compact context for error logs."""
return {
"provider": "openai-response",
"model": self.model,
"response_id": data.get("id"),
"status": data.get("status"),
"incomplete_details": data.get("incomplete_details"),
"last_error": data.get("last_error"),
"has_output": bool(data.get("output")),
}
async def complete(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Non-streaming completion."""
url = f"{self._normalize_base_url()}/responses"
payload = self._build_payload(messages, tools, temperature, max_tokens, stream=False, **kwargs)
client = await self._get_client()
response = await client.post(url, json=payload, headers=self._get_headers())
if response.status_code >= 400:
error_text = response.text[:500]
raise LLMError(f"HTTP {response.status_code}: {error_text}")
data = response.json()
api_error = self._extract_api_error(data)
if api_error:
ctx = self._build_error_log_context(data)
            logger.error(f"OpenAIResponses API error: {api_error} | context={ctx}")
raise LLMError(api_error)
return self._parse_response_data(data)
async def stream(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
on_chunk: ChunkCallback | None = None,
on_thinking: ThinkingCallback | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Streaming completion.
        Minimal implementation: falls back to a non-streaming request and forwards the
        final text and reasoning content through the callbacks.
"""
response = await self.complete(
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
if on_chunk and response.content:
await on_chunk(response.content)
if on_thinking and response.reasoning_content:
await on_thinking(response.reasoning_content)
return response
async def close(self) -> None:
"""Close the HTTP client."""
if self._client and not self._client.is_closed:
try:
await asyncio.wait_for(self._client.aclose(), timeout=5.0)
except asyncio.TimeoutError:
logger.warning("[LLM] Client close timed out, forcing")
self._client = None
# ============================================================================
# Gemini Native Client
# ============================================================================
class GeminiClient(LLMClient):
"""Client for Gemini native API (`generateContent` / `streamGenerateContent`)."""
DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
def __init__(
self,
api_key: str,
base_url: str | None = None,
model: str | None = None,
timeout: float = 120.0,
supports_tool_choice: bool = True,
):
super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout)
self.supports_tool_choice = supports_tool_choice
self._client: httpx.AsyncClient | None = None
self._openai_fallback_client: OpenAICompatibleClient | None = None
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None)
return self._client
async def _get_openai_fallback_client(self) -> OpenAICompatibleClient:
"""Fallback for legacy `/openai` base URL deployments."""
if self._openai_fallback_client is None:
self._openai_fallback_client = OpenAICompatibleClient(
api_key=self.api_key,
base_url=self.base_url,
model=self.model,
timeout=self.timeout,
supports_tool_choice=self.supports_tool_choice,
)
return self._openai_fallback_client
def _is_openai_compatible_base(self) -> bool:
"""Detect legacy OpenAI-compatible Gemini gateway endpoint."""
url = self.base_url.rstrip("/").lower()
return url.endswith("/openai") or "/openai/" in url
def _get_headers(self) -> dict[str, str]:
return {
"Content-Type": "application/json",
"x-goog-api-key": self.api_key,
}
def _normalize_base_url(self) -> str:
"""Normalize base URL for Gemini native endpoints."""
url = self.base_url.rstrip("/")
if "/models/" in url and (url.endswith(":generateContent") or url.endswith(":streamGenerateContent")):
url = url.split("/models/")[0]
return url
def _normalize_model_name(self) -> str:
"""Normalize model id for native Gemini endpoint path."""
model = (self.model or "").strip()
if model.startswith("models/"):
model = model[len("models/"):]
return model
def _parse_data_url_image(self, data_url: str) -> tuple[str, str] | None:
"""Parse data URL into (mime_type, base64_data)."""
m = re.match(r"^data:([^;]+);base64,([A-Za-z0-9+/=]+)$", data_url or "")
if not m:
return None
return m.group(1), m.group(2)
def _content_to_gemini_parts(self, content: Any) -> list[dict[str, Any]]:
"""Convert canonical content into Gemini `parts`."""
if content is None:
return []
if isinstance(content, str):
return [{"text": content}]
if isinstance(content, list):
parts: list[dict[str, Any]] = []
for part in content:
if not isinstance(part, dict):
continue
ptype = part.get("type")
if ptype == "text":
text = part.get("text", "")
if text:
parts.append({"text": text})
elif ptype == "image_url":
image_obj = part.get("image_url", {})
image_url = image_obj.get("url", "") if isinstance(image_obj, dict) else ""
parsed = self._parse_data_url_image(image_url)
if parsed:
mime_type, b64_data = parsed
parts.append({
"inlineData": {
"mimeType": mime_type,
"data": b64_data,
}
})
elif image_url:
# Gemini native API requires uploaded files or inline data;
# preserve reference in text when URL cannot be inlined.
parts.append({"text": f"[image_url:{image_url}]"})
return parts
return [{"text": str(content)}]
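    # Example (illustrative): a data-URL image in OpenAI vision format becomes a Gemini
    # inlineData part, while plain text becomes a text part. Model id is a placeholder.
    #
    #   client = GeminiClient(api_key="AIza...", model="gemini-2.0-flash")
    #   client._content_to_gemini_parts([
    #       {"type": "text", "text": "Describe this image"},
    #       {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
    #   ])
    #   # -> [{"text": "Describe this image"},
    #   #     {"inlineData": {"mimeType": "image/png", "data": "iVBORw0KGgo="}}]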
def _extract_tool_name_map(self, messages: list[LLMMessage]) -> dict[str, str]:
"""Build tool_call_id -> function_name map from assistant messages."""
out: dict[str, str] = {}
for msg in messages:
if msg.role != "assistant" or not msg.tool_calls:
continue
for tc in msg.tool_calls:
tc_id = tc.get("id")
tc_name = tc.get("function", {}).get("name")
if tc_id and tc_name:
out[tc_id] = tc_name
return out
def _convert_tools(self, tools: list[dict] | None) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | None]:
"""Convert OpenAI-style tools to Gemini function declarations."""
if not tools:
return None, None
declarations: list[dict[str, Any]] = []
for tool in tools:
if tool.get("type") != "function":
continue
fn = tool.get("function", {})
decl: dict[str, Any] = {
"name": fn.get("name", ""),
"description": fn.get("description", ""),
}
params = fn.get("parameters")
if isinstance(params, dict):
decl["parameters"] = params
declarations.append(decl)
if not declarations:
return None, None
tools_payload = [{"functionDeclarations": declarations}]
tool_config = None
if self.supports_tool_choice:
tool_config = {"functionCallingConfig": {"mode": "AUTO"}}
return tools_payload, tool_config
def _build_payload(
self,
messages: list[LLMMessage],
tools: list[dict] | None,
        temperature: float | None,
max_tokens: int | None,
**kwargs: Any,
) -> dict[str, Any]:
"""Build Gemini request payload."""
system_blocks: list[str] = []
contents: list[dict[str, Any]] = []
tool_name_map = self._extract_tool_name_map(messages)
for msg in messages:
if msg.role == "system":
parts = self._content_to_gemini_parts(msg.content)
text_chunks = [p.get("text", "") for p in parts if p.get("text")]
if text_chunks:
system_blocks.append("\n".join(text_chunks))
continue
if msg.role == "user":
parts = self._content_to_gemini_parts(msg.content)
if parts:
contents.append({"role": "user", "parts": parts})
continue
if msg.role == "assistant":
parts = self._content_to_gemini_parts(msg.content)
if msg.tool_calls:
for tc in msg.tool_calls:
fn = tc.get("function", {})
args = fn.get("arguments", "{}")
if isinstance(args, str):
try:
parsed_args = json.loads(args)
except json.JSONDecodeError:
parsed_args = {}
elif isinstance(args, dict):
parsed_args = args
else:
parsed_args = {}
parts.append({
"functionCall": {
"name": fn.get("name", ""),
"args": parsed_args,
}
})
if parts:
contents.append({"role": "model", "parts": parts})
continue
if msg.role == "tool":
name = tool_name_map.get(msg.tool_call_id or "", msg.tool_call_id or "tool_result")
response_content = msg.content or ""
if isinstance(response_content, str):
try:
parsed = json.loads(response_content)
if isinstance(parsed, dict):
response_obj: dict[str, Any] = parsed
else:
response_obj = {"result": parsed}
except json.JSONDecodeError:
response_obj = {"result": response_content}
elif isinstance(response_content, dict):
response_obj = response_content
else:
response_obj = {"result": str(response_content)}
contents.append({
"role": "user",
"parts": [{
"functionResponse": {
"name": name,
"response": response_obj,
}
}],
})
generation_config: dict[str, Any] = {}
if temperature is not None:
generation_config["temperature"] = temperature
payload: dict[str, Any] = {
"contents": contents or [{"role": "user", "parts": [{"text": ""}]}],
"generationConfig": generation_config,
}
if max_tokens:
payload["generationConfig"]["maxOutputTokens"] = max_tokens
if system_blocks:
payload["systemInstruction"] = {
"parts": [{"text": "\n\n".join(system_blocks)}]
}
tools_payload, tool_config = self._convert_tools(tools)
if tools_payload:
payload["tools"] = tools_payload
if tool_config:
payload["toolConfig"] = tool_config
payload.update(kwargs)
return payload
def _normalize_usage(self, usage: dict[str, Any] | None) -> dict[str, int] | None:
"""Normalize Gemini usage metadata to unified usage dict."""
if not isinstance(usage, dict):
return None
input_tokens = int(usage.get("promptTokenCount", 0) or 0)
output_tokens = int(usage.get("candidatesTokenCount", 0) or 0)
total_tokens = int(usage.get("totalTokenCount", input_tokens + output_tokens) or 0)
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}
def _normalize_finish_reason(self, finish_reason: str | None, tool_calls: list[dict]) -> str | None:
"""Normalize Gemini finish reason to OpenAI-style labels."""
if tool_calls:
return "tool_calls"
if not finish_reason:
return None
mapping = {
"STOP": "stop",
"MAX_TOKENS": "length",
"SAFETY": "content_filter",
"RECITATION": "content_filter",
}
return mapping.get(finish_reason, "stop")
def _parse_response_data(self, data: dict[str, Any]) -> LLMResponse:
"""Convert Gemini native response into canonical LLMResponse."""
content_chunks: list[str] = []
tool_calls: list[dict[str, Any]] = []
seen_tool_calls: set[str] = set()
finish_reason = None
candidates = data.get("candidates") or []
if candidates:
candidate = candidates[0]
finish_reason = candidate.get("finishReason")
content_obj = candidate.get("content", {}) or {}
for part in content_obj.get("parts", []) or []:
text = part.get("text")
if text:
content_chunks.append(text)
function_call = part.get("functionCall")
if function_call:
name = function_call.get("name", "")
args = function_call.get("args", {})
args_str = json.dumps(args if isinstance(args, dict) else {}, ensure_ascii=False)
dedup_key = f"{name}:{args_str}"
if dedup_key in seen_tool_calls:
continue
seen_tool_calls.add(dedup_key)
tool_calls.append({
"id": f"call_{len(tool_calls) + 1}",
"type": "function",
"function": {
"name": name,
"arguments": args_str,
},
})
usage = self._normalize_usage(data.get("usageMetadata"))
return LLMResponse(
content="".join(content_chunks),
tool_calls=tool_calls,
finish_reason=self._normalize_finish_reason(finish_reason, tool_calls),
usage=usage,
model=data.get("modelVersion") or self.model,
)
async def complete(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Non-streaming completion."""
if self._is_openai_compatible_base():
fallback = await self._get_openai_fallback_client()
return await fallback.complete(
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
model_name = self._normalize_model_name()
url = f"{self._normalize_base_url()}/models/{model_name}:generateContent"
payload = self._build_payload(messages, tools, temperature, max_tokens, **kwargs)
client = await self._get_client()
response = await client.post(url, json=payload, headers=self._get_headers())
if response.status_code >= 400:
error_text = response.text[:500]
raise LLMError(f"HTTP {response.status_code}: {error_text}")
data = response.json()
if isinstance(data, dict) and data.get("error"):
raise LLMError(f"API error: {data['error']}")
return self._parse_response_data(data)
async def stream(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
on_chunk: ChunkCallback | None = None,
on_thinking: ThinkingCallback | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Streaming completion using Gemini SSE endpoint."""
if self._is_openai_compatible_base():
fallback = await self._get_openai_fallback_client()
return await fallback.stream(
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
on_chunk=on_chunk,
on_thinking=on_thinking,
**kwargs,
)
model_name = self._normalize_model_name()
url = f"{self._normalize_base_url()}/models/{model_name}:streamGenerateContent"
payload = self._build_payload(messages, tools, temperature, max_tokens, **kwargs)
full_text = ""
tool_calls: list[dict[str, Any]] = []
seen_tool_calls: set[str] = set()
final_usage: dict[str, int] | None = None
final_finish_reason: str | None = None
client = await self._get_client()
try:
async with client.stream(
"POST",
url,
params={"alt": "sse"},
json=payload,
headers=self._get_headers(),
) as resp:
if resp.status_code >= 400:
error_body = ""
async for chunk in resp.aiter_bytes():
error_body += chunk.decode(errors="replace")
raise LLMError(f"HTTP {resp.status_code}: {error_body[:500]}")
async for line in resp.aiter_lines():
if not line.startswith("data:"):
continue
data_str = line[len("data:"):].strip()
if not data_str:
continue
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
continue
if isinstance(data, dict) and data.get("error"):
raise LLMError(f"API error: {data['error']}")
usage = self._normalize_usage(data.get("usageMetadata"))
if usage:
final_usage = usage
candidates = data.get("candidates") or []
if not candidates:
continue
candidate = candidates[0]
                    content_obj = candidate.get("content", {}) or {}
                    for part in content_obj.get("parts", []) or []:
                        text = part.get("text")
                        if text:
                            full_text += text
                            if on_chunk:
                                await on_chunk(text)
                        function_call = part.get("functionCall")
                        if function_call:
                            name = function_call.get("name", "")
                            args = function_call.get("args", {})
                            args_str = json.dumps(args if isinstance(args, dict) else {}, ensure_ascii=False)
                            dedup_key = f"{name}:{args_str}"
                            if dedup_key in seen_tool_calls:
                                continue
                            seen_tool_calls.add(dedup_key)
                            tool_calls.append({
                                "id": f"call_{len(tool_calls) + 1}",
                                "type": "function",
                                "function": {
                                    "name": name,
                                    "arguments": args_str,
                                },
                            })
                    # Capture finishReason only after the chunk's parts have been processed:
                    # Gemini's final chunk can carry both the last content delta and finishReason.
                    final_finish_reason = candidate.get("finishReason") or final_finish_reason
                    if final_finish_reason:
                        break
except (httpx.TransportError, httpx.ConnectTimeout) as e:
# TransportError covers NetworkError (ConnectError, ReadError) and
# ProtocolError (RemoteProtocolError) — all common with local vLLM.
raise LLMError(f"Connection failed: {type(e).__name__}: {e}")
return LLMResponse(
content=full_text,
tool_calls=tool_calls,
finish_reason=self._normalize_finish_reason(final_finish_reason, tool_calls),
usage=final_usage,
model=self.model,
)
async def close(self) -> None:
"""Close the HTTP client."""
if self._openai_fallback_client:
await self._openai_fallback_client.close()
if self._client and not self._client.is_closed:
try:
await asyncio.wait_for(self._client.aclose(), timeout=5.0)
except asyncio.TimeoutError:
logger.warning("[LLM] Client close timed out, forcing")
self._client = None
# ============================================================================
# Anthropic Native Client
# ============================================================================
class AnthropicClient(LLMClient):
"""Client for Anthropic's native Messages API.
Supports Claude 3.x and Claude 3.7+ with extended thinking.
"""
DEFAULT_BASE_URL = "https://api.anthropic.com"
API_VERSION = "2023-06-01"
def __init__(
self,
api_key: str,
base_url: str | None = None,
model: str | None = None,
timeout: float = 120.0,
):
super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout)
self._client: httpx.AsyncClient | None = None
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None)
return self._client
def _get_headers(self) -> dict[str, str]:
return {
"Content-Type": "application/json",
"x-api-key": self.api_key,
"anthropic-version": self.API_VERSION,
"anthropic-beta": "prompt-caching-2024-07-31",
}
def _normalize_base_url(self) -> str:
"""Normalize base URL by stripping trailing API paths."""
url = self.base_url.rstrip("/")
if url.endswith("/v1/messages"):
url = url[: -len("/v1/messages")]
elif url.endswith("/v1/chat/completions"):
url = url[: -len("/v1/chat/completions")]
elif url.endswith("/v1"):
url = url[: -len("/v1")]
return url
def _build_payload(
self,
messages: list[LLMMessage],
tools: list[dict] | None,
temperature: float | None,
max_tokens: int | None,
stream: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
"""Build Anthropic request payload."""
system_blocks = []
anthropic_messages = []
for msg in messages:
if msg.role == "system":
if msg.content:
system_blocks.append({
"type": "text",
"text": msg.content,
"cache_control": {"type": "ephemeral"}
})
if msg.dynamic_content:
system_blocks.append({
"type": "text",
"text": f"\n{msg.dynamic_content}"
})
else:
formatted = msg.to_anthropic_format()
if formatted:
anthropic_messages.append(formatted)
        # Anthropic prompt caching: also mark the last user message with cache_control,
        # so the conversation prefix up to that point can be reused from the cache.
if anthropic_messages and anthropic_messages[-1]["role"] == "user":
user_msg = anthropic_messages[-1]
if isinstance(user_msg["content"], list) and user_msg["content"]:
# Ensure the last block of the user message has cache_control
user_msg["content"][-1]["cache_control"] = {"type": "ephemeral"}
elif isinstance(user_msg["content"], str):
user_msg["content"] = [
{
"type": "text",
"text": user_msg["content"],
"cache_control": {"type": "ephemeral"}
}
]
payload: dict[str, Any] = {
"model": self.model,
"messages": anthropic_messages,
"max_tokens": max_tokens or 4096,
"stream": stream,
}
if temperature is not None:
payload["temperature"] = temperature
if system_blocks:
payload["system"] = system_blocks
# Handle Extended Thinking
thinking = kwargs.pop("thinking", None)
if thinking:
payload["thinking"] = thinking
            # Anthropic requires temperature to be 1.0 (or left unset) when extended
            # thinking is enabled, so default to 1.0 unless the caller overrides it.
if "temperature" not in kwargs:
payload["temperature"] = 1.0
if tools:
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
anthropic_tools.append({
"name": func["name"],
"description": func.get("description", ""),
"input_schema": func.get("parameters", {"type": "object"}),
})
if anthropic_tools:
anthropic_tools[-1]["cache_control"] = {"type": "ephemeral"}
payload["tools"] = anthropic_tools
payload.update(kwargs)
return payload
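    # Example (hedged sketch): enabling extended thinking via the pass-through `thinking`
    # kwarg. The model id and token budget are illustrative, not recommendations.
    #
    #   anthropic_client = AnthropicClient(api_key="sk-ant-...", model="claude-sonnet-4-20250514")
    #
    #   async def on_thinking(text: str) -> None:
    #       print("[thinking]", text, end="")
    #
    #   resp = await anthropic_client.stream(
    #       messages=[LLMMessage(role="user", content="Plan a 3-step refactor")],
    #       max_tokens=8192,
    #       thinking={"type": "enabled", "budget_tokens": 2048},
    #       on_thinking=on_thinking,
    #   )
    #   # resp.reasoning_content holds the thinking text; resp.content holds the answer.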
async def complete(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Non-streaming completion."""
url = f"{self._normalize_base_url()}/v1/messages"
payload = self._build_payload(messages, tools, temperature, max_tokens, stream=False, **kwargs)
client = await self._get_client()
response = await client.post(url, json=payload, headers=self._get_headers())
if response.status_code >= 400:
error_text = response.text[:500]
raise LLMError(f"HTTP {response.status_code}: {error_text}")
data = response.json()
if data.get("type") == "error":
raise LLMError(f"API error: {data.get('error', {})}")
full_content = ""
full_reasoning = ""
full_signature = None
tool_calls = []
for block in data.get("content", []):
if block.get("type") == "text":
full_content += block.get("text", "")
elif block.get("type") == "thinking":
full_reasoning += block.get("thinking", "")
full_signature = block.get("signature")
elif block.get("type") == "tool_use":
tool_calls.append({
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False)
}
})
usage = None
if "usage" in data:
usage = {
"input_tokens": data["usage"].get("input_tokens", 0),
"output_tokens": data["usage"].get("output_tokens", 0),
}
return LLMResponse(
content=full_content,
tool_calls=tool_calls,
reasoning_content=full_reasoning or None,
reasoning_signature=full_signature,
finish_reason=data.get("stop_reason"),
usage=usage,
model=data.get("model"),
)
async def stream(
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
on_chunk: ChunkCallback | None = None,
on_thinking: ThinkingCallback | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Streaming completion."""
url = f"{self._normalize_base_url()}/v1/messages"
payload = self._build_payload(messages, tools, temperature, max_tokens, stream=True, **kwargs)
full_content = ""
full_reasoning = ""
full_signature = None
tool_calls_data: list[dict] = []
tool_call_index_map: dict[int, int] = {}
last_finish_reason: str | None = None
final_usage = None
final_model = self.model
client = await self._get_client()
try:
async with client.stream("POST", url, json=payload, headers=self._get_headers()) as resp:
if resp.status_code >= 400:
error_body = ""
async for chunk in resp.aiter_bytes():
error_body += chunk.decode(errors="replace")
raise LLMError(f"HTTP {resp.status_code}: {error_body[:500]}")
current_event = None
async for line in resp.aiter_lines():
if not line.strip():
continue
if line.startswith("event:"):
current_event = line[len("event:"):].strip()
logger.debug(f"[Anthropic SSE] event: {current_event}")
continue
if not line.startswith("data:"):
continue
data_str = line[len("data:"):].strip()
if data_str == "[DONE]":
logger.debug("[Anthropic SSE] received [DONE]")
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
continue
# Handle events
if current_event == "message_start":
msg = data.get("message", {})
if msg.get("model"):
final_model = msg["model"]
if msg.get("usage"):
final_usage = msg["usage"]
elif current_event == "content_block_start":
block = data.get("content_block", {})
idx = data.get("index", 0)
if block.get("type") == "tool_use":
tool_call_index_map[idx] = len(tool_calls_data)
tool_calls_data.append({
"id": block.get("id"),
"type": "function",
"function": {"name": block.get("name"), "arguments": ""}
})
elif current_event == "content_block_delta":
idx = data.get("index", 0)
delta = data.get("delta", {})
delta_type = delta.get("type")
if delta_type == "text_delta":
text = delta.get("text", "")
full_content += text
if on_chunk:
await on_chunk(text)
elif delta_type == "thinking_delta":
thought = delta.get("thinking", "")
full_reasoning += thought
if on_thinking:
await on_thinking(thought)
elif delta_type == "signature_delta":
full_signature = delta.get("signature")
elif delta_type == "input_json_delta":
if idx in tool_call_index_map:
tc_idx = tool_call_index_map[idx]
tool_calls_data[tc_idx]["function"]["arguments"] += delta.get("partial_json", "")
elif current_event == "message_delta":
delta = data.get("delta", {})
logger.debug(f"[Anthropic SSE] message_delta: stop_reason={delta.get('stop_reason')}, usage={bool(data.get('usage'))}")
if data.get("usage"):
final_usage = data["usage"]
if delta.get("stop_reason"):
last_finish_reason = delta["stop_reason"]
logger.debug("[Anthropic SSE] breaking on stop_reason")
break
elif current_event == "error":
error_info = data.get("error", {})
raise LLMError(f"Anthropic stream error ({error_info.get('type')}): {error_info.get('message')}")
elif current_event == "message_stop":
break
logger.debug(f"[Anthropic SSE] stream loop ended, content={len(full_content)} chars, finish={last_finish_reason}, tools={len(tool_calls_data)}")
except (httpx.TransportError, httpx.ConnectTimeout) as e:
# TransportError covers NetworkError (ConnectError, ReadError) and
# ProtocolError (RemoteProtocolError) — all common with local vLLM.
raise LLMError(f"Connection failed: {type(e).__name__}: {e}")
# Normalize stop reason to OpenAI style (optional but helpful for consistency)
if last_finish_reason == "end_turn":
last_finish_reason = "stop"
elif last_finish_reason == "tool_use":
last_finish_reason = "tool_calls"
logger.debug(f"[Anthropic SSE] returning LLMResponse: {len(full_content)} chars, finish={last_finish_reason}")
return LLMResponse(
content=full_content,
tool_calls=tool_calls_data,
reasoning_content=full_reasoning or None,
reasoning_signature=full_signature,
finish_reason=last_finish_reason,
usage=final_usage,
model=final_model,
)
async def close(self) -> None:
"""Close the HTTP client."""
if self._client and not self._client.is_closed:
try:
await asyncio.wait_for(self._client.aclose(), timeout=5.0)
except asyncio.TimeoutError:
logger.warning("[LLM] Client close timed out, forcing")
self._client = None
# ============================================================================
# Factory and Utilities
# ============================================================================
@dataclass(frozen=True)
class ProviderSpec:
"""Provider registry entry."""
provider: str
display_name: str
protocol: Literal["openai_compatible", "anthropic", "openai_responses", "gemini"]
default_base_url: str | None
supports_tool_choice: bool = True
default_max_tokens: int = 4096
model_max_tokens: dict[str, int] = field(default_factory=dict)
# Provider aliases accepted for compatibility
PROVIDER_ALIASES: dict[str, str] = {
"openai_response": "openai-response",
"openairesponses": "openai-response",
}
# Canonical provider registry (single source of truth)
PROVIDER_REGISTRY: dict[str, ProviderSpec] = {
"anthropic": ProviderSpec(
provider="anthropic",
display_name="Anthropic",
protocol="anthropic",
default_base_url="https://api.anthropic.com",
supports_tool_choice=False,
default_max_tokens=8192,
),
"openai": ProviderSpec(
provider="openai",
display_name="OpenAI",
protocol="openai_compatible",
default_base_url="https://api.openai.com/v1",
default_max_tokens=16384,
),
"openai-response": ProviderSpec(
provider="openai-response",
display_name="OpenAI Responses",
protocol="openai_responses",
default_base_url="https://api.openai.com/v1",
default_max_tokens=16384,
),
"azure": ProviderSpec(
provider="azure",
display_name="Azure OpenAI",
protocol="openai_compatible",
default_base_url=None,
default_max_tokens=16384,
),
"deepseek": ProviderSpec(
provider="deepseek",
display_name="DeepSeek",
protocol="openai_compatible",
default_base_url="https://api.deepseek.com/v1",
default_max_tokens=8192,
),
"qwen": ProviderSpec(
provider="qwen",
display_name="Qwen (DashScope)",
protocol="openai_compatible",
default_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
default_max_tokens=8192,
model_max_tokens={
"qwen-plus": 16384,
"qwen-long": 16384,
"qwen-turbo": 8192,
"qwen-max": 8192,
},
),
"minimax": ProviderSpec(
provider="minimax",
display_name="MiniMax",
protocol="openai_compatible",
default_base_url="https://api.minimaxi.com/v1",
default_max_tokens=16384,
),
"openrouter": ProviderSpec(
provider="openrouter",
display_name="OpenRouter",
protocol="openai_compatible",
default_base_url="https://openrouter.ai/api/v1",
default_max_tokens=4096,
),
"zhipu": ProviderSpec(
provider="zhipu",
display_name="Zhipu",
protocol="openai_compatible",
default_base_url="https://open.bigmodel.cn/api/paas/v4",
default_max_tokens=8192,
),
"baidu": ProviderSpec(
provider="baidu",
display_name="Baidu (Qianfan)",
protocol="openai_compatible",
default_base_url="https://qianfan.baidubce.com/v2",
supports_tool_choice=False,
default_max_tokens=4096,
),
"gemini": ProviderSpec(
provider="gemini",
display_name="Gemini",
protocol="gemini",
default_base_url="https://generativelanguage.googleapis.com/v1beta",
default_max_tokens=8192,
),
"kimi": ProviderSpec(
provider="kimi",
display_name="Kimi (Moonshot)",
protocol="openai_compatible",
default_base_url="https://api.moonshot.cn/v1",
default_max_tokens=8192,
),
"vllm": ProviderSpec(
provider="vllm",
display_name="vLLM",
protocol="openai_compatible",
default_base_url="http://localhost:8000/v1",
default_max_tokens=4096,
),
"ollama": ProviderSpec(
provider="ollama",
display_name="Ollama",
protocol="openai_compatible",
default_base_url="http://localhost:11434/v1",
default_max_tokens=4096,
),
"sglang": ProviderSpec(
provider="sglang",
display_name="SGLang",
protocol="openai_compatible",
default_base_url="http://localhost:30000/v1",
default_max_tokens=4096,
),
"custom": ProviderSpec(
provider="custom",
display_name="Custom",
protocol="openai_compatible",
default_base_url=None,
default_max_tokens=4096,
),
}
def normalize_provider(provider: str) -> str:
"""Normalize provider id with aliases and lowercase."""
p = (provider or "").strip().lower()
return PROVIDER_ALIASES.get(p, p)
def get_provider_spec(provider: str) -> ProviderSpec | None:
"""Get provider spec from registry."""
return PROVIDER_REGISTRY.get(normalize_provider(provider))
def get_provider_manifest() -> list[dict[str, Any]]:
"""List supported providers and capabilities for UI/config discovery."""
out: list[dict[str, Any]] = []
for spec in PROVIDER_REGISTRY.values():
out.append({
"provider": spec.provider,
"display_name": spec.display_name,
"protocol": spec.protocol,
"default_base_url": spec.default_base_url,
"supports_tool_choice": spec.supports_tool_choice,
"default_max_tokens": spec.default_max_tokens,
"model_max_tokens": spec.model_max_tokens,
"aliases": [k for k, v in PROVIDER_ALIASES.items() if v == spec.provider],
})
return out
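# Example (illustrative): registry lookups and alias normalization.
#
#   normalize_provider("OpenAI_Response")   # -> "openai-response"
#   spec = get_provider_spec("deepseek")
#   (spec.protocol, spec.default_base_url)  # -> ("openai_compatible", "https://api.deepseek.com/v1")
#   [p["provider"] for p in get_provider_manifest()][:3]
#   # -> ["anthropic", "openai", "openai-response"]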
# Backward-compatible constants derived from registry
PROVIDER_CLIENTS: dict[str, type[LLMClient]] = {
spec.provider: (
AnthropicClient
if spec.protocol == "anthropic"
else OpenAIResponsesClient
if spec.protocol == "openai_responses"
else GeminiClient
if spec.protocol == "gemini"
else OpenAICompatibleClient
)
for spec in PROVIDER_REGISTRY.values()
}
PROVIDER_URLS: dict[str, str | None] = {
spec.provider: spec.default_base_url for spec in PROVIDER_REGISTRY.values()
}
TOOL_CHOICE_PROVIDERS = {
spec.provider for spec in PROVIDER_REGISTRY.values() if spec.supports_tool_choice
}
MAX_TOKENS_BY_PROVIDER: dict[str, int] = {
spec.provider: spec.default_max_tokens for spec in PROVIDER_REGISTRY.values()
}
MAX_TOKENS_BY_MODEL: dict[str, int] = {
prefix: limit
for spec in PROVIDER_REGISTRY.values()
for prefix, limit in spec.model_max_tokens.items()
}
class LLMError(Exception):
"""Base exception for LLM client errors."""
pass
def get_provider_base_url(provider: str, custom_base_url: str | None = None) -> str | None:
"""Return the API base URL for a provider.
If a custom base_url is provided, it takes precedence.
Otherwise falls back to the default URL for the provider.
"""
if custom_base_url:
return custom_base_url
spec = get_provider_spec(provider)
if spec:
return spec.default_base_url
return PROVIDER_URLS.get(normalize_provider(provider))
def get_max_tokens(provider: str, model: str | None = None, max_output_tokens: int | None = None) -> int:
"""Return a safe max_tokens value for the given provider/model pair.
Priority: max_output_tokens (DB override) > model prefix > provider default > 4096
"""
spec = get_provider_spec(provider)
model_limits = spec.model_max_tokens if spec else MAX_TOKENS_BY_MODEL
# Highest priority: per-model DB override
if max_output_tokens and max_output_tokens > 0:
return max_output_tokens
# Check model-specific limits
if model:
for prefix, limit in model_limits.items():
if model.lower().startswith(prefix):
return limit
if spec:
return spec.default_max_tokens
# Provider default, falling back to safe 4096
return MAX_TOKENS_BY_PROVIDER.get(normalize_provider(provider), 4096)
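# Example (illustrative): how the max_tokens priority chain resolves.
#
#   get_max_tokens("qwen", "qwen-plus-latest")                    # -> 16384 (model prefix)
#   get_max_tokens("openai", "gpt-4o")                            # -> 16384 (provider default)
#   get_max_tokens("some-unknown-provider", "foo")                # -> 4096  (safe fallback)
#   get_max_tokens("openai", "gpt-4o", max_output_tokens=32768)   # -> 32768 (DB override)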
def create_llm_client(
provider: str,
api_key: str,
model: str,
base_url: str | None = None,
timeout: float = 120.0,
) -> LLMClient:
"""Create an LLM client for the given provider.
Args:
provider: Provider name (openai, anthropic, deepseek, etc.)
api_key: API key for authentication
model: Model name
base_url: Optional custom base URL
timeout: Request timeout in seconds
Returns:
An instance of the appropriate LLMClient subclass
Raises:
ValueError: If provider is not supported
"""
normalized_provider = normalize_provider(provider)
spec = get_provider_spec(normalized_provider)
# Get base URL
final_base_url = get_provider_base_url(normalized_provider, base_url)
# Create appropriate client
if spec and spec.protocol == "anthropic":
return AnthropicClient(
api_key=api_key,
base_url=final_base_url,
model=model,
timeout=timeout,
)
elif spec and spec.protocol == "openai_responses":
return OpenAIResponsesClient(
api_key=api_key,
base_url=final_base_url,
model=model,
timeout=timeout,
supports_tool_choice=spec.supports_tool_choice,
)
elif spec and spec.protocol == "gemini":
return GeminiClient(
api_key=api_key,
base_url=final_base_url,
model=model,
timeout=timeout,
supports_tool_choice=spec.supports_tool_choice,
)
elif normalized_provider in PROVIDER_CLIENTS:
supports_tool_choice = normalized_provider in TOOL_CHOICE_PROVIDERS
return OpenAICompatibleClient(
api_key=api_key,
base_url=final_base_url,
model=model,
timeout=timeout,
supports_tool_choice=supports_tool_choice,
)
else:
# Default to OpenAI-compatible for unknown providers
return OpenAICompatibleClient(
api_key=api_key,
base_url=final_base_url or PROVIDER_URLS["openai"],
model=model,
timeout=timeout,
supports_tool_choice=True,
)
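# Example (hedged sketch): the factory picks the client class from the provider's protocol;
# unknown providers fall back to the OpenAI-compatible client. Model ids are placeholders.
#
#   client = create_llm_client("anthropic", api_key="sk-ant-...", model="claude-sonnet-4-20250514")
#   isinstance(client, AnthropicClient)         # -> True
#   client = create_llm_client(
#       "my-local-gateway", api_key="x", model="llama-3", base_url="http://localhost:8000/v1"
#   )
#   isinstance(client, OpenAICompatibleClient)  # -> True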
# ============================================================================
# High-level Convenience Functions
# ============================================================================
async def chat_complete(
provider: str,
api_key: str,
model: str,
messages: list[dict],
base_url: str | None = None,
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float = 120.0,
) -> dict:
"""High-level function for non-streaming chat completion.
Returns response in OpenAI-compatible format for backward compatibility.
"""
client = create_llm_client(provider, api_key, model, base_url, timeout)
try:
llm_messages = [LLMMessage(**m) for m in messages]
response = await client.complete(
messages=llm_messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens or get_max_tokens(provider, model),
)
return {
"choices": [{
"message": {
"role": "assistant",
"content": response.content,
"tool_calls": response.tool_calls or None,
},
"finish_reason": response.finish_reason or "stop",
}],
"model": response.model or model,
"usage": response.usage or {},
}
finally:
await client.close()
async def chat_stream(
provider: str,
api_key: str,
model: str,
messages: list[dict],
base_url: str | None = None,
tools: list[dict] | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
timeout: float = 120.0,
on_chunk: ChunkCallback | None = None,
on_thinking: ThinkingCallback | None = None,
) -> dict:
"""High-level function for streaming chat completion.
Returns aggregated response in OpenAI-compatible format.
"""
client = create_llm_client(provider, api_key, model, base_url, timeout)
try:
llm_messages = [LLMMessage(**m) for m in messages]
response = await client.stream(
messages=llm_messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens or get_max_tokens(provider, model),
on_chunk=on_chunk,
on_thinking=on_thinking,
)
return {
"choices": [{
"message": {
"role": "assistant",
"content": response.content,
"tool_calls": response.tool_calls or None,
},
"finish_reason": response.finish_reason or "stop",
}],
"model": response.model or model,
"usage": response.usage or {},
}
finally:
await client.close()
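# Example (hedged sketch): streaming through the high-level helper with an async chunk
# callback, inside an async context. Provider, key, and model are placeholders.
#
#   async def print_chunk(text: str) -> None:
#       print(text, end="", flush=True)
#
#   result = await chat_stream(
#       provider="deepseek",
#       api_key="sk-...",
#       model="deepseek-chat",
#       messages=[{"role": "user", "content": "Summarize this repo in one line"}],
#       on_chunk=print_chunk,
#   )
#   # result mirrors the OpenAI chat-completions shape:
#   # result["choices"][0]["message"]["content"], result["usage"], result["model"]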