"""Unified LLM client for multiple providers. Supports OpenAI-compatible APIs, Anthropic native API, and streaming/non-streaming modes. Provides a consistent interface for all LLM operations across the application. """ from __future__ import annotations import asyncio import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Callable, Coroutine, Literal import httpx from loguru import logger # ============================================================================ # Data Models # ============================================================================ @dataclass class LLMMessage: """Unified message format.""" role: Literal["system", "user", "assistant", "tool"] content: str | list | None = None tool_calls: list[dict] | None = None tool_call_id: str | None = None reasoning_content: str | None = None reasoning_signature: str | None = None dynamic_content: str | None = None def to_openai_format(self) -> dict: """Convert to OpenAI format.""" msg: dict[str, Any] = {"role": self.role} content = self.content if self.role == "system" and self.dynamic_content: content = f"{content}\n\n{self.dynamic_content}" if content is not None: msg["content"] = content if self.tool_calls: msg["tool_calls"] = self.tool_calls if self.tool_call_id: msg["tool_call_id"] = self.tool_call_id if self.reasoning_content: msg["reasoning_content"] = self.reasoning_content return msg def to_anthropic_format(self) -> dict | None: """Convert to Anthropic format (returns None for system messages).""" if self.role == "system": return None role = self.role # Tool response (from user to assistant) if role == "tool": # Build tool_result content: support both string and vision array formats if isinstance(self.content, list): # Vision content array: extract text parts and image parts # Anthropic tool_result content supports [{type: "text", text: ...}, {type: "image", source: ...}] tool_content_blocks = [] for part in self.content: if part.get("type") == "text": tool_content_blocks.append({"type": "text", "text": part.get("text", "")}) elif part.get("type") == "image_url": # Convert OpenAI image_url format to Anthropic image source format img_url = part.get("image_url", {}).get("url", "") if img_url.startswith("data:image/"): # Parse data URL: data:image/jpeg;base64,xxxxx header, b64_data = img_url.split(",", 1) media_type = header.split(":")[1].split(";")[0] # e.g. 
image/jpeg tool_content_blocks.append({ "type": "image", "source": { "type": "base64", "media_type": media_type, "data": b64_data, } }) result_content = tool_content_blocks if tool_content_blocks else (self.content or "") else: result_content = self.content or "" return { "role": "user", "content": [ { "type": "tool_result", "tool_use_id": self.tool_call_id, "content": result_content, } ] } content_blocks = [] # Add reasoning/thinking content if present if self.role == "assistant" and self.reasoning_content: content_blocks.append({ "type": "thinking", "thinking": self.reasoning_content, "signature": self.reasoning_signature or "synthetic_signature" }) if self.content: content_blocks.append({"type": "text", "text": self.content}) # Tool requests (from assistant to user) if self.tool_calls: for tc in self.tool_calls: function_call = tc.get("function", {}) args = function_call.get("arguments", "{}") if isinstance(args, str): try: args = json.loads(args) except json.JSONDecodeError: args = {} content_blocks.append({ "type": "tool_use", "id": tc.get("id", ""), "name": function_call.get("name", ""), "input": args }) # Handle the structure if len(content_blocks) == 1 and content_blocks[0]["type"] == "text": content = content_blocks[0]["text"] else: content = content_blocks return {"role": role, "content": content} @dataclass class LLMResponse: """Unified response format.""" content: str tool_calls: list[dict] = field(default_factory=list) reasoning_content: str | None = None reasoning_signature: str | None = None finish_reason: str | None = None usage: dict[str, int] | None = None model: str | None = None @dataclass class LLMStreamChunk: """Stream chunk format.""" content: str = "" reasoning_content: str = "" tool_call: dict | None = None finish_reason: str | None = None is_finished: bool = False usage: dict | None = None # ============================================================================ # Type Definitions # ============================================================================ ChunkCallback = Callable[[str], Coroutine[Any, Any, None]] ToolCallback = Callable[[dict], Coroutine[Any, Any, None]] ThinkingCallback = Callable[[str], Coroutine[Any, Any, None]] # ============================================================================ # Base Client Interface # ============================================================================ class LLMClient(ABC): """Abstract base class for LLM clients.""" def __init__( self, api_key: str, base_url: str | None = None, model: str | None = None, timeout: float = 120.0, ): self.api_key = api_key self.base_url = base_url self.model = model self.timeout = timeout @abstractmethod async def complete( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, **kwargs: Any, ) -> LLMResponse: """Send a completion request and return the full response.""" pass @abstractmethod async def stream( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, on_chunk: ChunkCallback | None = None, on_thinking: ThinkingCallback | None = None, **kwargs: Any, ) -> LLMResponse: """Send a streaming request and return the aggregated response.""" pass @abstractmethod def _get_headers(self) -> dict[str, str]: """Get request headers.""" pass # ============================================================================ # OpenAI-Compatible Client # ============================================================================ 
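# Illustrative example (not part of the public API): the unified LLMMessage
# model defined above is what the clients below consume, and it converts to
# each provider's wire format on demand.
#
#     >>> msg = LLMMessage(role="user", content="hello")
#     >>> msg.to_openai_format()
#     {'role': 'user', 'content': 'hello'}
#
# Assistant tool calls keep the OpenAI `tool_calls` shape; to_anthropic_format()
# translates them into Anthropic `tool_use` content blocks, and tool results
# become `tool_result` blocks on a user turn.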
class OpenAICompatibleClient(LLMClient): """Client for OpenAI-compatible APIs (OpenAI, DeepSeek, Qwen, etc.).""" DEFAULT_BASE_URL = "https://api.openai.com/v1" def __init__( self, api_key: str, base_url: str | None = None, model: str | None = None, timeout: float = 120.0, supports_tool_choice: bool = True, ): super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout) self.supports_tool_choice = supports_tool_choice self._client: httpx.AsyncClient | None = None async def _get_client(self) -> httpx.AsyncClient: """Get or create HTTP client.""" if self._client is None or self._client.is_closed: self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None) return self._client def _get_headers(self) -> dict[str, str]: return { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", } def _normalize_base_url(self) -> str: """Normalize base URL by stripping trailing /chat/completions.""" url = self.base_url.rstrip("/") if url.endswith("/chat/completions"): url = url[: -len("/chat/completions")] return url def _build_payload( self, messages: list[LLMMessage], tools: list[dict] | None, temperature: float | None, max_tokens: int | None, stream: bool = False, **kwargs: Any, ) -> dict[str, Any]: """Build request payload.""" payload: dict[str, Any] = { "model": self.model, "messages": [m.to_openai_format() for m in messages], "stream": stream, } if temperature is not None: payload["temperature"] = temperature # Request usage stats in streaming responses (OpenAI extension) if stream: payload["stream_options"] = {"include_usage": True} if max_tokens: payload["max_tokens"] = max_tokens if tools: payload["tools"] = tools if self.supports_tool_choice: payload["tool_choice"] = "auto" payload["parallel_tool_calls"] = True # Add any additional kwargs payload.update(kwargs) return payload def _parse_stream_line( self, line: str, in_think: bool, tag_buffer: str, json_buffer: str = "", ) -> tuple[LLMStreamChunk, bool, str, str]: """Parse a single SSE line from stream. Returns (chunk, new_in_think, new_tag_buffer, new_json_buffer). The json_buffer accumulates partial JSON from non-standard APIs that split a single JSON object across multiple data: lines. 
""" chunk = LLMStreamChunk() # SSE spec: "data:" may or may not have a space after the colon if line.startswith("data: "): data_str = line[6:] elif line.startswith("data:"): data_str = line[5:] else: # Non-data lines (comments, event types, empty) — never buffer return chunk, in_think, tag_buffer, json_buffer data_str = data_str.strip() if not data_str: return chunk, in_think, tag_buffer, json_buffer if data_str == "[DONE]": chunk.is_finished = True return chunk, in_think, tag_buffer, "" # Accumulate into json_buffer for split JSON handling if json_buffer: json_buffer += data_str else: json_buffer = data_str try: data = json.loads(json_buffer) json_buffer = "" # Reset on successful parse except json.JSONDecodeError: # Cap buffer at 64KB to prevent memory leaks if len(json_buffer) > 65536: logger.warning("[LLM] JSON buffer exceeded 64KB, discarding") json_buffer = "" return chunk, in_think, tag_buffer, json_buffer if "error" in data: raise LLMError(f"Stream error: {data['error']}") # Parse usage from stream (returned in the final chunk with include_usage) if data.get("usage"): chunk.usage = data["usage"] choices = data.get("choices", []) if not choices: return chunk, in_think, tag_buffer, json_buffer choice = choices[0] delta = choice.get("delta", {}) if choice.get("finish_reason"): chunk.finish_reason = choice["finish_reason"] # Reasoning content (DeepSeek R1) if delta.get("reasoning_content"): chunk.reasoning_content = delta["reasoning_content"] # Regular content with think tag filtering if delta.get("content"): text = delta["content"] chunk.content, in_think, tag_buffer = self._filter_think_tags( text, in_think, tag_buffer ) # Tool calls if delta.get("tool_calls"): for tc_delta in delta["tool_calls"]: chunk.tool_call = tc_delta break # Return one at a time return chunk, in_think, tag_buffer, json_buffer def _filter_think_tags( self, text: str, in_think: bool, tag_buffer: str ) -> tuple[str, bool, str]: """Filter out ... tags from content. Returns (filtered_content, new_in_think, new_tag_buffer). 
""" tag_buffer += text emit = "" i = 0 buf = tag_buffer while i < len(buf): if not in_think: # Look for "): in_think = True i += len("") continue elif "".startswith(tag_candidate): # Partial match - keep in buffer break else: emit += buf[i] i += 1 else: emit += buf[i] i += 1 else: # Inside think - look for close tag if buf[i] == "<": tag_candidate = buf[i:] if tag_candidate.startswith(""): in_think = False i += len("") continue elif "".startswith(tag_candidate): break i += 1 tag_buffer = buf[i:] return emit, in_think, tag_buffer async def complete( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, **kwargs: Any, ) -> LLMResponse: """Non-streaming completion.""" url = f"{self._normalize_base_url()}/chat/completions" payload = self._build_payload(messages, tools, temperature, max_tokens, stream=False, **kwargs) client = await self._get_client() response = await client.post(url, json=payload, headers=self._get_headers()) if response.status_code >= 400: error_text = response.text[:500] raise LLMError(f"HTTP {response.status_code}: {error_text}") data = response.json() if "error" in data: raise LLMError(f"API error: {data['error']}") choice = data.get("choices", [{}])[0] msg = choice.get("message", {}) return LLMResponse( content=msg.get("content", ""), tool_calls=msg.get("tool_calls", []), finish_reason=choice.get("finish_reason"), usage=data.get("usage"), model=data.get("model"), ) async def stream( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, on_chunk: ChunkCallback | None = None, on_thinking: ThinkingCallback | None = None, **kwargs: Any, ) -> LLMResponse: """Streaming completion.""" url = f"{self._normalize_base_url()}/chat/completions" payload = self._build_payload(messages, tools, temperature, max_tokens, stream=True, **kwargs) full_content = "" full_reasoning = "" tool_calls_data: list[dict] = [] last_finish_reason: str | None = None final_usage: dict | None = None in_think = False tag_buffer = "" json_buffer = "" # Buffer for non-standard APIs with split JSON (inspired by PR #120) max_retries = 3 client = await self._get_client() for attempt in range(max_retries): try: async with client.stream("POST", url, json=payload, headers=self._get_headers()) as resp: if resp.status_code >= 400: error_body = "" async for chunk in resp.aiter_bytes(): error_body += chunk.decode(errors="replace") raise LLMError(f"HTTP {resp.status_code}: {error_body[:500]}") async for line in resp.aiter_lines(): chunk, in_think, tag_buffer, json_buffer = self._parse_stream_line( line, in_think, tag_buffer, json_buffer ) if chunk.is_finished: break if chunk.content: full_content += chunk.content if on_chunk: await on_chunk(chunk.content) if chunk.reasoning_content: full_reasoning += chunk.reasoning_content if on_thinking: await on_thinking(chunk.reasoning_content) if chunk.tool_call: idx = chunk.tool_call.get("index", 0) while len(tool_calls_data) <= idx: tool_calls_data.append({"id": "", "function": {"name": "", "arguments": ""}}) tc = tool_calls_data[idx] if chunk.tool_call.get("id"): tc["id"] = chunk.tool_call["id"] fn_delta = chunk.tool_call.get("function", {}) if fn_delta.get("name"): tc["function"]["name"] += fn_delta["name"] if fn_delta.get("arguments") is not None: arg_chunk = fn_delta["arguments"] if isinstance(arg_chunk, dict): tc["function"]["arguments"] = json.dumps(arg_chunk, ensure_ascii=False) else: tc["function"]["arguments"] += 
str(arg_chunk) if chunk.usage: final_usage = chunk.usage if chunk.finish_reason: last_finish_reason = chunk.finish_reason break break except (httpx.TransportError, httpx.ConnectTimeout) as e: # TransportError covers all network-layer issues: # - ConnectError, ReadError, WriteError (NetworkError subclasses) # - RemoteProtocolError, LocalProtocolError (ProtocolError subclasses) # The last case is common with local vLLM when the server closes # the connection mid-stream (e.g. OOM, context limit exceeded). if attempt < max_retries - 1: wait = (attempt + 1) * 1 logger.warning(f"Stream attempt {attempt + 1} failed ({type(e).__name__}: {e}), retrying in {wait}s...") await asyncio.sleep(wait) full_content = "" full_reasoning = "" tool_calls_data = [] in_think = False tag_buffer = "" json_buffer = "" else: raise LLMError(f"Connection failed after {max_retries} attempts: {type(e).__name__}: {e}") # Clean up any remaining think tags full_content = re.sub(r"[\s\S]*?\s*", "", full_content).strip() return LLMResponse( content=full_content, tool_calls=tool_calls_data, reasoning_content=full_reasoning or None, finish_reason=last_finish_reason, usage=final_usage, model=self.model, ) async def close(self) -> None: """Close the HTTP client.""" if self._client and not self._client.is_closed: try: await asyncio.wait_for(self._client.aclose(), timeout=5.0) except asyncio.TimeoutError: logger.warning("[LLM] Client close timed out, forcing") self._client = None # ============================================================================ # OpenAI Responses API Client # ============================================================================ class OpenAIResponsesClient(LLMClient): """Client for OpenAI Responses API (`/v1/responses`).""" DEFAULT_BASE_URL = "https://api.openai.com/v1" def __init__( self, api_key: str, base_url: str | None = None, model: str | None = None, timeout: float = 120.0, supports_tool_choice: bool = True, ): super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout) self.supports_tool_choice = supports_tool_choice self._client: httpx.AsyncClient | None = None async def _get_client(self) -> httpx.AsyncClient: """Get or create HTTP client.""" if self._client is None or self._client.is_closed: self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None) return self._client def _get_headers(self) -> dict[str, str]: return { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", } def _normalize_base_url(self) -> str: """Normalize base URL by stripping trailing /responses endpoint.""" url = self.base_url.rstrip("/") if url.endswith("/responses"): url = url[: -len("/responses")] return url def _format_content_for_input(self, content: Any) -> Any: """Convert OpenAI chat-style content into Responses API input content.""" if not isinstance(content, list): return content formatted: list[dict[str, Any]] = [] for part in content: if not isinstance(part, dict): continue ptype = part.get("type") if ptype == "text": formatted.append({"type": "input_text", "text": part.get("text", "")}) elif ptype == "image_url": img = part.get("image_url", {}) if isinstance(img, dict): formatted.append({"type": "input_image", "image_url": img.get("url", "")}) else: formatted.append(part) return formatted if formatted else content def _messages_to_input(self, messages: list[LLMMessage]) -> list[dict[str, Any]]: """Convert canonical message format to Responses API input format.""" input_items: list[dict[str, Any]] = [] for msg in messages: if 
msg.role in {"system", "user", "assistant"} and msg.content is not None: item: dict[str, Any] = {"role": msg.role} item["content"] = self._format_content_for_input(msg.content) input_items.append(item) if msg.role == "assistant" and msg.tool_calls: for tc in msg.tool_calls: fn = tc.get("function", {}) args = fn.get("arguments", "{}") if isinstance(args, dict): args = json.dumps(args, ensure_ascii=False) input_items.append({ "type": "function_call", "call_id": tc.get("id", ""), "name": fn.get("name", ""), "arguments": str(args or "{}"), }) if msg.role == "tool": input_items.append({ "type": "function_call_output", "call_id": msg.tool_call_id or "", "output": msg.content or "", }) return input_items def _convert_tools(self, tools: list[dict] | None) -> list[dict] | None: """Convert OpenAI tool schema to Responses API function tool schema.""" if not tools: return None converted: list[dict[str, Any]] = [] for tool in tools: if tool.get("type") != "function": continue fn = tool.get("function", {}) converted.append({ "type": "function", "name": fn.get("name", ""), "description": fn.get("description", ""), "parameters": fn.get("parameters", {"type": "object"}), }) return converted or None def _build_payload( self, messages: list[LLMMessage], tools: list[dict] | None, temperature: float | None, max_tokens: int | None, stream: bool = False, **kwargs: Any, ) -> dict[str, Any]: """Build request payload.""" payload: dict[str, Any] = { "model": self.model, "input": self._messages_to_input(messages), "stream": stream, } if temperature is not None: payload["temperature"] = temperature if max_tokens: payload["max_output_tokens"] = max_tokens converted_tools = self._convert_tools(tools) if converted_tools: payload["tools"] = converted_tools if self.supports_tool_choice: payload["tool_choice"] = "auto" payload.update(kwargs) return payload def _parse_response_data(self, data: dict[str, Any]) -> LLMResponse: """Convert Responses API payload into canonical LLMResponse.""" content_parts: list[str] = [] reasoning_parts: list[str] = [] tool_calls: list[dict[str, Any]] = [] for item in data.get("output", []) or []: item_type = item.get("type") if item_type == "message": for c in item.get("content", []) or []: c_type = c.get("type") if c_type in {"output_text", "text"}: content_parts.append(c.get("text", "")) elif c_type == "reasoning": reasoning_parts.append(c.get("summary", "") or c.get("text", "")) elif item_type == "function_call": args = item.get("arguments", "{}") if isinstance(args, dict): args = json.dumps(args, ensure_ascii=False) tool_calls.append({ "id": item.get("call_id") or item.get("id", ""), "type": "function", "function": { "name": item.get("name", ""), "arguments": str(args or "{}"), }, }) # Some Responses payloads include a pre-aggregated output_text field. # Use it as a fallback when output blocks are empty. if not content_parts and data.get("output_text"): content_parts.append(str(data.get("output_text", ""))) usage = data.get("usage") finish_reason = "tool_calls" if tool_calls else "stop" return LLMResponse( content="".join(content_parts), tool_calls=tool_calls, reasoning_content="".join(reasoning_parts) or None, finish_reason=finish_reason, usage=usage if isinstance(usage, dict) else None, model=data.get("model"), ) def _extract_api_error(self, data: dict[str, Any]) -> str | None: """Extract meaningful error message from Responses API payload.""" # OpenAI Responses often returns `"error": null` on success, # so we must only treat it as error when it's truthy. 
err = data.get("error") if err: if isinstance(err, dict): msg = err.get("message") or str(err) err_type = err.get("type") err_code = err.get("code") extra = [] if err_type: extra.append(f"type={err_type}") if err_code: extra.append(f"code={err_code}") suffix = f" ({', '.join(extra)})" if extra else "" return f"{msg}{suffix}" return str(err) status = str(data.get("status") or "").lower() if status in {"failed", "incomplete", "cancelled"}: last_error = data.get("last_error") incomplete = data.get("incomplete_details") rid = data.get("id") details: list[str] = [f"status={status}"] if rid: details.append(f"id={rid}") if last_error: details.append(f"last_error={last_error}") if incomplete: details.append(f"incomplete_details={incomplete}") return "Responses API returned non-success status: " + "; ".join(details) return None def _build_error_log_context(self, data: dict[str, Any]) -> dict[str, Any]: """Build compact context for error logs.""" return { "provider": "openai-response", "model": self.model, "response_id": data.get("id"), "status": data.get("status"), "incomplete_details": data.get("incomplete_details"), "last_error": data.get("last_error"), "has_output": bool(data.get("output")), } async def complete( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, **kwargs: Any, ) -> LLMResponse: """Non-streaming completion.""" url = f"{self._normalize_base_url()}/responses" payload = self._build_payload(messages, tools, temperature, max_tokens, stream=False, **kwargs) client = await self._get_client() response = await client.post(url, json=payload, headers=self._get_headers()) if response.status_code >= 400: error_text = response.text[:500] raise LLMError(f"HTTP {response.status_code}: {error_text}") data = response.json() api_error = self._extract_api_error(data) if api_error: ctx = self._build_error_log_context(data) logger.error( "OpenAIResponses API error: %s | context=%s", api_error, ctx, ) raise LLMError(api_error) return self._parse_response_data(data) async def stream( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, on_chunk: ChunkCallback | None = None, on_thinking: ThinkingCallback | None = None, **kwargs: Any, ) -> LLMResponse: """Streaming completion. Minimal implementation: fallback to non-streaming and forward final text. 
""" response = await self.complete( messages=messages, tools=tools, temperature=temperature, max_tokens=max_tokens, **kwargs, ) if on_chunk and response.content: await on_chunk(response.content) if on_thinking and response.reasoning_content: await on_thinking(response.reasoning_content) return response async def close(self) -> None: """Close the HTTP client.""" if self._client and not self._client.is_closed: try: await asyncio.wait_for(self._client.aclose(), timeout=5.0) except asyncio.TimeoutError: logger.warning("[LLM] Client close timed out, forcing") self._client = None # ============================================================================ # Gemini Native Client # ============================================================================ class GeminiClient(LLMClient): """Client for Gemini native API (`generateContent` / `streamGenerateContent`).""" DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta" def __init__( self, api_key: str, base_url: str | None = None, model: str | None = None, timeout: float = 120.0, supports_tool_choice: bool = True, ): super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout) self.supports_tool_choice = supports_tool_choice self._client: httpx.AsyncClient | None = None self._openai_fallback_client: OpenAICompatibleClient | None = None async def _get_client(self) -> httpx.AsyncClient: """Get or create HTTP client.""" if self._client is None or self._client.is_closed: self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None) return self._client async def _get_openai_fallback_client(self) -> OpenAICompatibleClient: """Fallback for legacy `/openai` base URL deployments.""" if self._openai_fallback_client is None: self._openai_fallback_client = OpenAICompatibleClient( api_key=self.api_key, base_url=self.base_url, model=self.model, timeout=self.timeout, supports_tool_choice=self.supports_tool_choice, ) return self._openai_fallback_client def _is_openai_compatible_base(self) -> bool: """Detect legacy OpenAI-compatible Gemini gateway endpoint.""" url = self.base_url.rstrip("/").lower() return url.endswith("/openai") or "/openai/" in url def _get_headers(self) -> dict[str, str]: return { "Content-Type": "application/json", "x-goog-api-key": self.api_key, } def _normalize_base_url(self) -> str: """Normalize base URL for Gemini native endpoints.""" url = self.base_url.rstrip("/") if "/models/" in url and (url.endswith(":generateContent") or url.endswith(":streamGenerateContent")): url = url.split("/models/")[0] return url def _normalize_model_name(self) -> str: """Normalize model id for native Gemini endpoint path.""" model = (self.model or "").strip() if model.startswith("models/"): model = model[len("models/"):] return model def _parse_data_url_image(self, data_url: str) -> tuple[str, str] | None: """Parse data URL into (mime_type, base64_data).""" m = re.match(r"^data:([^;]+);base64,([A-Za-z0-9+/=]+)$", data_url or "") if not m: return None return m.group(1), m.group(2) def _content_to_gemini_parts(self, content: Any) -> list[dict[str, Any]]: """Convert canonical content into Gemini `parts`.""" if content is None: return [] if isinstance(content, str): return [{"text": content}] if isinstance(content, list): parts: list[dict[str, Any]] = [] for part in content: if not isinstance(part, dict): continue ptype = part.get("type") if ptype == "text": text = part.get("text", "") if text: parts.append({"text": text}) elif ptype == "image_url": image_obj = part.get("image_url", {}) image_url = 
image_obj.get("url", "") if isinstance(image_obj, dict) else "" parsed = self._parse_data_url_image(image_url) if parsed: mime_type, b64_data = parsed parts.append({ "inlineData": { "mimeType": mime_type, "data": b64_data, } }) elif image_url: # Gemini native API requires uploaded files or inline data; # preserve reference in text when URL cannot be inlined. parts.append({"text": f"[image_url:{image_url}]"}) return parts return [{"text": str(content)}] def _extract_tool_name_map(self, messages: list[LLMMessage]) -> dict[str, str]: """Build tool_call_id -> function_name map from assistant messages.""" out: dict[str, str] = {} for msg in messages: if msg.role != "assistant" or not msg.tool_calls: continue for tc in msg.tool_calls: tc_id = tc.get("id") tc_name = tc.get("function", {}).get("name") if tc_id and tc_name: out[tc_id] = tc_name return out def _convert_tools(self, tools: list[dict] | None) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | None]: """Convert OpenAI-style tools to Gemini function declarations.""" if not tools: return None, None declarations: list[dict[str, Any]] = [] for tool in tools: if tool.get("type") != "function": continue fn = tool.get("function", {}) decl: dict[str, Any] = { "name": fn.get("name", ""), "description": fn.get("description", ""), } params = fn.get("parameters") if isinstance(params, dict): decl["parameters"] = params declarations.append(decl) if not declarations: return None, None tools_payload = [{"functionDeclarations": declarations}] tool_config = None if self.supports_tool_choice: tool_config = {"functionCallingConfig": {"mode": "AUTO"}} return tools_payload, tool_config def _build_payload( self, messages: list[LLMMessage], tools: list[dict] | None, temperature: float, max_tokens: int | None, **kwargs: Any, ) -> dict[str, Any]: """Build Gemini request payload.""" system_blocks: list[str] = [] contents: list[dict[str, Any]] = [] tool_name_map = self._extract_tool_name_map(messages) for msg in messages: if msg.role == "system": parts = self._content_to_gemini_parts(msg.content) text_chunks = [p.get("text", "") for p in parts if p.get("text")] if text_chunks: system_blocks.append("\n".join(text_chunks)) continue if msg.role == "user": parts = self._content_to_gemini_parts(msg.content) if parts: contents.append({"role": "user", "parts": parts}) continue if msg.role == "assistant": parts = self._content_to_gemini_parts(msg.content) if msg.tool_calls: for tc in msg.tool_calls: fn = tc.get("function", {}) args = fn.get("arguments", "{}") if isinstance(args, str): try: parsed_args = json.loads(args) except json.JSONDecodeError: parsed_args = {} elif isinstance(args, dict): parsed_args = args else: parsed_args = {} parts.append({ "functionCall": { "name": fn.get("name", ""), "args": parsed_args, } }) if parts: contents.append({"role": "model", "parts": parts}) continue if msg.role == "tool": name = tool_name_map.get(msg.tool_call_id or "", msg.tool_call_id or "tool_result") response_content = msg.content or "" if isinstance(response_content, str): try: parsed = json.loads(response_content) if isinstance(parsed, dict): response_obj: dict[str, Any] = parsed else: response_obj = {"result": parsed} except json.JSONDecodeError: response_obj = {"result": response_content} elif isinstance(response_content, dict): response_obj = response_content else: response_obj = {"result": str(response_content)} contents.append({ "role": "user", "parts": [{ "functionResponse": { "name": name, "response": response_obj, } }], }) generation_config: dict[str, Any] = {} if 
temperature is not None: generation_config["temperature"] = temperature payload: dict[str, Any] = { "contents": contents or [{"role": "user", "parts": [{"text": ""}]}], "generationConfig": generation_config, } if max_tokens: payload["generationConfig"]["maxOutputTokens"] = max_tokens if system_blocks: payload["systemInstruction"] = { "parts": [{"text": "\n\n".join(system_blocks)}] } tools_payload, tool_config = self._convert_tools(tools) if tools_payload: payload["tools"] = tools_payload if tool_config: payload["toolConfig"] = tool_config payload.update(kwargs) return payload def _normalize_usage(self, usage: dict[str, Any] | None) -> dict[str, int] | None: """Normalize Gemini usage metadata to unified usage dict.""" if not isinstance(usage, dict): return None input_tokens = int(usage.get("promptTokenCount", 0) or 0) output_tokens = int(usage.get("candidatesTokenCount", 0) or 0) total_tokens = int(usage.get("totalTokenCount", input_tokens + output_tokens) or 0) return { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens, } def _normalize_finish_reason(self, finish_reason: str | None, tool_calls: list[dict]) -> str | None: """Normalize Gemini finish reason to OpenAI-style labels.""" if tool_calls: return "tool_calls" if not finish_reason: return None mapping = { "STOP": "stop", "MAX_TOKENS": "length", "SAFETY": "content_filter", "RECITATION": "content_filter", } return mapping.get(finish_reason, "stop") def _parse_response_data(self, data: dict[str, Any]) -> LLMResponse: """Convert Gemini native response into canonical LLMResponse.""" content_chunks: list[str] = [] tool_calls: list[dict[str, Any]] = [] seen_tool_calls: set[str] = set() finish_reason = None candidates = data.get("candidates") or [] if candidates: candidate = candidates[0] finish_reason = candidate.get("finishReason") content_obj = candidate.get("content", {}) or {} for part in content_obj.get("parts", []) or []: text = part.get("text") if text: content_chunks.append(text) function_call = part.get("functionCall") if function_call: name = function_call.get("name", "") args = function_call.get("args", {}) args_str = json.dumps(args if isinstance(args, dict) else {}, ensure_ascii=False) dedup_key = f"{name}:{args_str}" if dedup_key in seen_tool_calls: continue seen_tool_calls.add(dedup_key) tool_calls.append({ "id": f"call_{len(tool_calls) + 1}", "type": "function", "function": { "name": name, "arguments": args_str, }, }) usage = self._normalize_usage(data.get("usageMetadata")) return LLMResponse( content="".join(content_chunks), tool_calls=tool_calls, finish_reason=self._normalize_finish_reason(finish_reason, tool_calls), usage=usage, model=data.get("modelVersion") or self.model, ) async def complete( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, **kwargs: Any, ) -> LLMResponse: """Non-streaming completion.""" if self._is_openai_compatible_base(): fallback = await self._get_openai_fallback_client() return await fallback.complete( messages=messages, tools=tools, temperature=temperature, max_tokens=max_tokens, **kwargs, ) model_name = self._normalize_model_name() url = f"{self._normalize_base_url()}/models/{model_name}:generateContent" payload = self._build_payload(messages, tools, temperature, max_tokens, **kwargs) client = await self._get_client() response = await client.post(url, json=payload, headers=self._get_headers()) if response.status_code >= 400: error_text = response.text[:500] raise 
LLMError(f"HTTP {response.status_code}: {error_text}") data = response.json() if isinstance(data, dict) and data.get("error"): raise LLMError(f"API error: {data['error']}") return self._parse_response_data(data) async def stream( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, on_chunk: ChunkCallback | None = None, on_thinking: ThinkingCallback | None = None, **kwargs: Any, ) -> LLMResponse: """Streaming completion using Gemini SSE endpoint.""" if self._is_openai_compatible_base(): fallback = await self._get_openai_fallback_client() return await fallback.stream( messages=messages, tools=tools, temperature=temperature, max_tokens=max_tokens, on_chunk=on_chunk, on_thinking=on_thinking, **kwargs, ) model_name = self._normalize_model_name() url = f"{self._normalize_base_url()}/models/{model_name}:streamGenerateContent" payload = self._build_payload(messages, tools, temperature, max_tokens, **kwargs) full_text = "" tool_calls: list[dict[str, Any]] = [] seen_tool_calls: set[str] = set() final_usage: dict[str, int] | None = None final_finish_reason: str | None = None client = await self._get_client() try: async with client.stream( "POST", url, params={"alt": "sse"}, json=payload, headers=self._get_headers(), ) as resp: if resp.status_code >= 400: error_body = "" async for chunk in resp.aiter_bytes(): error_body += chunk.decode(errors="replace") raise LLMError(f"HTTP {resp.status_code}: {error_body[:500]}") async for line in resp.aiter_lines(): if not line.startswith("data:"): continue data_str = line[len("data:"):].strip() if not data_str: continue if data_str == "[DONE]": break try: data = json.loads(data_str) except json.JSONDecodeError: continue if isinstance(data, dict) and data.get("error"): raise LLMError(f"API error: {data['error']}") usage = self._normalize_usage(data.get("usageMetadata")) if usage: final_usage = usage candidates = data.get("candidates") or [] if not candidates: continue candidate = candidates[0] final_finish_reason = candidate.get("finishReason") or final_finish_reason if final_finish_reason: break content_obj = candidate.get("content", {}) or {} for part in content_obj.get("parts", []) or []: text = part.get("text") if text: full_text += text if on_chunk: await on_chunk(text) function_call = part.get("functionCall") if function_call: name = function_call.get("name", "") args = function_call.get("args", {}) args_str = json.dumps(args if isinstance(args, dict) else {}, ensure_ascii=False) dedup_key = f"{name}:{args_str}" if dedup_key in seen_tool_calls: continue seen_tool_calls.add(dedup_key) tool_calls.append({ "id": f"call_{len(tool_calls) + 1}", "type": "function", "function": { "name": name, "arguments": args_str, }, }) except (httpx.TransportError, httpx.ConnectTimeout) as e: # TransportError covers NetworkError (ConnectError, ReadError) and # ProtocolError (RemoteProtocolError) — all common with local vLLM. 
raise LLMError(f"Connection failed: {type(e).__name__}: {e}") return LLMResponse( content=full_text, tool_calls=tool_calls, finish_reason=self._normalize_finish_reason(final_finish_reason, tool_calls), usage=final_usage, model=self.model, ) async def close(self) -> None: """Close the HTTP client.""" if self._openai_fallback_client: await self._openai_fallback_client.close() if self._client and not self._client.is_closed: try: await asyncio.wait_for(self._client.aclose(), timeout=5.0) except asyncio.TimeoutError: logger.warning("[LLM] Client close timed out, forcing") self._client = None # ============================================================================ # Anthropic Native Client # ============================================================================ class AnthropicClient(LLMClient): """Client for Anthropic's native Messages API. Supports Claude 3.x and Claude 3.7+ with extended thinking. """ DEFAULT_BASE_URL = "https://api.anthropic.com" API_VERSION = "2023-06-01" def __init__( self, api_key: str, base_url: str | None = None, model: str | None = None, timeout: float = 120.0, ): super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, model, timeout) self._client: httpx.AsyncClient | None = None async def _get_client(self) -> httpx.AsyncClient: """Get or create HTTP client.""" if self._client is None or self._client.is_closed: self._client = httpx.AsyncClient(timeout=self.timeout, follow_redirects=True, proxy=None) return self._client def _get_headers(self) -> dict[str, str]: return { "Content-Type": "application/json", "x-api-key": self.api_key, "anthropic-version": self.API_VERSION, "anthropic-beta": "prompt-caching-2024-07-31", } def _normalize_base_url(self) -> str: """Normalize base URL by stripping trailing API paths.""" url = self.base_url.rstrip("/") if url.endswith("/v1/messages"): url = url[: -len("/v1/messages")] elif url.endswith("/v1/chat/completions"): url = url[: -len("/v1/chat/completions")] elif url.endswith("/v1"): url = url[: -len("/v1")] return url def _build_payload( self, messages: list[LLMMessage], tools: list[dict] | None, temperature: float | None, max_tokens: int | None, stream: bool = False, **kwargs: Any, ) -> dict[str, Any]: """Build Anthropic request payload.""" system_blocks = [] anthropic_messages = [] for msg in messages: if msg.role == "system": if msg.content: system_blocks.append({ "type": "text", "text": msg.content, "cache_control": {"type": "ephemeral"} }) if msg.dynamic_content: system_blocks.append({ "type": "text", "text": f"\n{msg.dynamic_content}" }) else: formatted = msg.to_anthropic_format() if formatted: anthropic_messages.append(formatted) # In Anthropic prompt caching, we also want to cache_control the last user message # So we add cache_control to the very last message in the history if it's a user message if anthropic_messages and anthropic_messages[-1]["role"] == "user": user_msg = anthropic_messages[-1] if isinstance(user_msg["content"], list) and user_msg["content"]: # Ensure the last block of the user message has cache_control user_msg["content"][-1]["cache_control"] = {"type": "ephemeral"} elif isinstance(user_msg["content"], str): user_msg["content"] = [ { "type": "text", "text": user_msg["content"], "cache_control": {"type": "ephemeral"} } ] payload: dict[str, Any] = { "model": self.model, "messages": anthropic_messages, "max_tokens": max_tokens or 4096, "stream": stream, } if temperature is not None: payload["temperature"] = temperature if system_blocks: payload["system"] = system_blocks # Handle Extended 
Thinking thinking = kwargs.pop("thinking", None) if thinking: payload["thinking"] = thinking # For thinking models, temperature must be 1.0 or omitted in some cases # But usually it's best to let user specify or default to 1.0 if not set if "temperature" not in kwargs: payload["temperature"] = 1.0 if tools: anthropic_tools = [] for tool in tools: if tool.get("type") == "function": func = tool["function"] anthropic_tools.append({ "name": func["name"], "description": func.get("description", ""), "input_schema": func.get("parameters", {"type": "object"}), }) if anthropic_tools: anthropic_tools[-1]["cache_control"] = {"type": "ephemeral"} payload["tools"] = anthropic_tools payload.update(kwargs) return payload async def complete( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, **kwargs: Any, ) -> LLMResponse: """Non-streaming completion.""" url = f"{self._normalize_base_url()}/v1/messages" payload = self._build_payload(messages, tools, temperature, max_tokens, stream=False, **kwargs) client = await self._get_client() response = await client.post(url, json=payload, headers=self._get_headers()) if response.status_code >= 400: error_text = response.text[:500] raise LLMError(f"HTTP {response.status_code}: {error_text}") data = response.json() if data.get("type") == "error": raise LLMError(f"API error: {data.get('error', {})}") full_content = "" full_reasoning = "" full_signature = None tool_calls = [] for block in data.get("content", []): if block.get("type") == "text": full_content += block.get("text", "") elif block.get("type") == "thinking": full_reasoning += block.get("thinking", "") full_signature = block.get("signature") elif block.get("type") == "tool_use": tool_calls.append({ "id": block.get("id"), "type": "function", "function": { "name": block.get("name"), "arguments": json.dumps(block.get("input", {}), ensure_ascii=False) } }) usage = None if "usage" in data: usage = { "input_tokens": data["usage"].get("input_tokens", 0), "output_tokens": data["usage"].get("output_tokens", 0), } return LLMResponse( content=full_content, tool_calls=tool_calls, reasoning_content=full_reasoning or None, reasoning_signature=full_signature, finish_reason=data.get("stop_reason"), usage=usage, model=data.get("model"), ) async def stream( self, messages: list[LLMMessage], tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, on_chunk: ChunkCallback | None = None, on_thinking: ThinkingCallback | None = None, **kwargs: Any, ) -> LLMResponse: """Streaming completion.""" url = f"{self._normalize_base_url()}/v1/messages" payload = self._build_payload(messages, tools, temperature, max_tokens, stream=True, **kwargs) full_content = "" full_reasoning = "" full_signature = None tool_calls_data: list[dict] = [] tool_call_index_map: dict[int, int] = {} last_finish_reason: str | None = None final_usage = None final_model = self.model client = await self._get_client() try: async with client.stream("POST", url, json=payload, headers=self._get_headers()) as resp: if resp.status_code >= 400: error_body = "" async for chunk in resp.aiter_bytes(): error_body += chunk.decode(errors="replace") raise LLMError(f"HTTP {resp.status_code}: {error_body[:500]}") current_event = None async for line in resp.aiter_lines(): if not line.strip(): continue if line.startswith("event:"): current_event = line[len("event:"):].strip() logger.debug(f"[Anthropic SSE] event: {current_event}") continue if not 
line.startswith("data:"): continue data_str = line[len("data:"):].strip() if data_str == "[DONE]": logger.debug("[Anthropic SSE] received [DONE]") break try: data = json.loads(data_str) except json.JSONDecodeError: continue # Handle events if current_event == "message_start": msg = data.get("message", {}) if msg.get("model"): final_model = msg["model"] if msg.get("usage"): final_usage = msg["usage"] elif current_event == "content_block_start": block = data.get("content_block", {}) idx = data.get("index", 0) if block.get("type") == "tool_use": tool_call_index_map[idx] = len(tool_calls_data) tool_calls_data.append({ "id": block.get("id"), "type": "function", "function": {"name": block.get("name"), "arguments": ""} }) elif current_event == "content_block_delta": idx = data.get("index", 0) delta = data.get("delta", {}) delta_type = delta.get("type") if delta_type == "text_delta": text = delta.get("text", "") full_content += text if on_chunk: await on_chunk(text) elif delta_type == "thinking_delta": thought = delta.get("thinking", "") full_reasoning += thought if on_thinking: await on_thinking(thought) elif delta_type == "signature_delta": full_signature = delta.get("signature") elif delta_type == "input_json_delta": if idx in tool_call_index_map: tc_idx = tool_call_index_map[idx] tool_calls_data[tc_idx]["function"]["arguments"] += delta.get("partial_json", "") elif current_event == "message_delta": delta = data.get("delta", {}) logger.debug(f"[Anthropic SSE] message_delta: stop_reason={delta.get('stop_reason')}, usage={bool(data.get('usage'))}") if data.get("usage"): final_usage = data["usage"] if delta.get("stop_reason"): last_finish_reason = delta["stop_reason"] logger.debug("[Anthropic SSE] breaking on stop_reason") break elif current_event == "error": error_info = data.get("error", {}) raise LLMError(f"Anthropic stream error ({error_info.get('type')}): {error_info.get('message')}") elif current_event == "message_stop": break logger.debug(f"[Anthropic SSE] stream loop ended, content={len(full_content)} chars, finish={last_finish_reason}, tools={len(tool_calls_data)}") except (httpx.TransportError, httpx.ConnectTimeout) as e: # TransportError covers NetworkError (ConnectError, ReadError) and # ProtocolError (RemoteProtocolError) — all common with local vLLM. 
raise LLMError(f"Connection failed: {type(e).__name__}: {e}") # Normalize stop reason to OpenAI style (optional but helpful for consistency) if last_finish_reason == "end_turn": last_finish_reason = "stop" elif last_finish_reason == "tool_use": last_finish_reason = "tool_calls" logger.debug(f"[Anthropic SSE] returning LLMResponse: {len(full_content)} chars, finish={last_finish_reason}") return LLMResponse( content=full_content, tool_calls=tool_calls_data, reasoning_content=full_reasoning or None, reasoning_signature=full_signature, finish_reason=last_finish_reason, usage=final_usage, model=final_model, ) async def close(self) -> None: """Close the HTTP client.""" if self._client and not self._client.is_closed: try: await asyncio.wait_for(self._client.aclose(), timeout=5.0) except asyncio.TimeoutError: logger.warning("[LLM] Client close timed out, forcing") self._client = None # ============================================================================ # Factory and Utilities # ============================================================================ @dataclass(frozen=True) class ProviderSpec: """Provider registry entry.""" provider: str display_name: str protocol: Literal["openai_compatible", "anthropic", "openai_responses", "gemini"] default_base_url: str | None supports_tool_choice: bool = True default_max_tokens: int = 4096 model_max_tokens: dict[str, int] = field(default_factory=dict) # Provider aliases accepted for compatibility PROVIDER_ALIASES: dict[str, str] = { "openai_response": "openai-response", "openairesponses": "openai-response", } # Canonical provider registry (single source of truth) PROVIDER_REGISTRY: dict[str, ProviderSpec] = { "anthropic": ProviderSpec( provider="anthropic", display_name="Anthropic", protocol="anthropic", default_base_url="https://api.anthropic.com", supports_tool_choice=False, default_max_tokens=8192, ), "openai": ProviderSpec( provider="openai", display_name="OpenAI", protocol="openai_compatible", default_base_url="https://api.openai.com/v1", default_max_tokens=16384, ), "openai-response": ProviderSpec( provider="openai-response", display_name="OpenAI Responses", protocol="openai_responses", default_base_url="https://api.openai.com/v1", default_max_tokens=16384, ), "azure": ProviderSpec( provider="azure", display_name="Azure OpenAI", protocol="openai_compatible", default_base_url=None, default_max_tokens=16384, ), "deepseek": ProviderSpec( provider="deepseek", display_name="DeepSeek", protocol="openai_compatible", default_base_url="https://api.deepseek.com/v1", default_max_tokens=8192, ), "qwen": ProviderSpec( provider="qwen", display_name="Qwen (DashScope)", protocol="openai_compatible", default_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", default_max_tokens=8192, model_max_tokens={ "qwen-plus": 16384, "qwen-long": 16384, "qwen-turbo": 8192, "qwen-max": 8192, }, ), "minimax": ProviderSpec( provider="minimax", display_name="MiniMax", protocol="openai_compatible", default_base_url="https://api.minimaxi.com/v1", default_max_tokens=16384, ), "openrouter": ProviderSpec( provider="openrouter", display_name="OpenRouter", protocol="openai_compatible", default_base_url="https://openrouter.ai/api/v1", default_max_tokens=4096, ), "zhipu": ProviderSpec( provider="zhipu", display_name="Zhipu", protocol="openai_compatible", default_base_url="https://open.bigmodel.cn/api/paas/v4", default_max_tokens=8192, ), "baidu": ProviderSpec( provider="baidu", display_name="Baidu (Qianfan)", protocol="openai_compatible", 
default_base_url="https://qianfan.baidubce.com/v2", supports_tool_choice=False, default_max_tokens=4096, ), "gemini": ProviderSpec( provider="gemini", display_name="Gemini", protocol="gemini", default_base_url="https://generativelanguage.googleapis.com/v1beta", default_max_tokens=8192, ), "kimi": ProviderSpec( provider="kimi", display_name="Kimi (Moonshot)", protocol="openai_compatible", default_base_url="https://api.moonshot.cn/v1", default_max_tokens=8192, ), "vllm": ProviderSpec( provider="vllm", display_name="vLLM", protocol="openai_compatible", default_base_url="http://localhost:8000/v1", default_max_tokens=4096, ), "ollama": ProviderSpec( provider="ollama", display_name="Ollama", protocol="openai_compatible", default_base_url="http://localhost:11434/v1", default_max_tokens=4096, ), "sglang": ProviderSpec( provider="sglang", display_name="SGLang", protocol="openai_compatible", default_base_url="http://localhost:30000/v1", default_max_tokens=4096, ), "custom": ProviderSpec( provider="custom", display_name="Custom", protocol="openai_compatible", default_base_url=None, default_max_tokens=4096, ), } def normalize_provider(provider: str) -> str: """Normalize provider id with aliases and lowercase.""" p = (provider or "").strip().lower() return PROVIDER_ALIASES.get(p, p) def get_provider_spec(provider: str) -> ProviderSpec | None: """Get provider spec from registry.""" return PROVIDER_REGISTRY.get(normalize_provider(provider)) def get_provider_manifest() -> list[dict[str, Any]]: """List supported providers and capabilities for UI/config discovery.""" out: list[dict[str, Any]] = [] for spec in PROVIDER_REGISTRY.values(): out.append({ "provider": spec.provider, "display_name": spec.display_name, "protocol": spec.protocol, "default_base_url": spec.default_base_url, "supports_tool_choice": spec.supports_tool_choice, "default_max_tokens": spec.default_max_tokens, "model_max_tokens": spec.model_max_tokens, "aliases": [k for k, v in PROVIDER_ALIASES.items() if v == spec.provider], }) return out # Backward-compatible constants derived from registry PROVIDER_CLIENTS: dict[str, type[LLMClient]] = { spec.provider: ( AnthropicClient if spec.protocol == "anthropic" else OpenAIResponsesClient if spec.protocol == "openai_responses" else GeminiClient if spec.protocol == "gemini" else OpenAICompatibleClient ) for spec in PROVIDER_REGISTRY.values() } PROVIDER_URLS: dict[str, str | None] = { spec.provider: spec.default_base_url for spec in PROVIDER_REGISTRY.values() } TOOL_CHOICE_PROVIDERS = { spec.provider for spec in PROVIDER_REGISTRY.values() if spec.supports_tool_choice } MAX_TOKENS_BY_PROVIDER: dict[str, int] = { spec.provider: spec.default_max_tokens for spec in PROVIDER_REGISTRY.values() } MAX_TOKENS_BY_MODEL: dict[str, int] = { prefix: limit for spec in PROVIDER_REGISTRY.values() for prefix, limit in spec.model_max_tokens.items() } class LLMError(Exception): """Base exception for LLM client errors.""" pass def get_provider_base_url(provider: str, custom_base_url: str | None = None) -> str | None: """Return the API base URL for a provider. If a custom base_url is provided, it takes precedence. Otherwise falls back to the default URL for the provider. """ if custom_base_url: return custom_base_url spec = get_provider_spec(provider) if spec: return spec.default_base_url return PROVIDER_URLS.get(normalize_provider(provider)) def get_max_tokens(provider: str, model: str | None = None, max_output_tokens: int | None = None) -> int: """Return a safe max_tokens value for the given provider/model pair. 
Priority: max_output_tokens (DB override) > model prefix > provider default > 4096 """ spec = get_provider_spec(provider) model_limits = spec.model_max_tokens if spec else MAX_TOKENS_BY_MODEL # Highest priority: per-model DB override if max_output_tokens and max_output_tokens > 0: return max_output_tokens # Check model-specific limits if model: for prefix, limit in model_limits.items(): if model.lower().startswith(prefix): return limit if spec: return spec.default_max_tokens # Provider default, falling back to safe 4096 return MAX_TOKENS_BY_PROVIDER.get(normalize_provider(provider), 4096) def create_llm_client( provider: str, api_key: str, model: str, base_url: str | None = None, timeout: float = 120.0, ) -> LLMClient: """Create an LLM client for the given provider. Args: provider: Provider name (openai, anthropic, deepseek, etc.) api_key: API key for authentication model: Model name base_url: Optional custom base URL timeout: Request timeout in seconds Returns: An instance of the appropriate LLMClient subclass Raises: ValueError: If provider is not supported """ normalized_provider = normalize_provider(provider) spec = get_provider_spec(normalized_provider) # Get base URL final_base_url = get_provider_base_url(normalized_provider, base_url) # Create appropriate client if spec and spec.protocol == "anthropic": return AnthropicClient( api_key=api_key, base_url=final_base_url, model=model, timeout=timeout, ) elif spec and spec.protocol == "openai_responses": return OpenAIResponsesClient( api_key=api_key, base_url=final_base_url, model=model, timeout=timeout, supports_tool_choice=spec.supports_tool_choice, ) elif spec and spec.protocol == "gemini": return GeminiClient( api_key=api_key, base_url=final_base_url, model=model, timeout=timeout, supports_tool_choice=spec.supports_tool_choice, ) elif normalized_provider in PROVIDER_CLIENTS: supports_tool_choice = normalized_provider in TOOL_CHOICE_PROVIDERS return OpenAICompatibleClient( api_key=api_key, base_url=final_base_url, model=model, timeout=timeout, supports_tool_choice=supports_tool_choice, ) else: # Default to OpenAI-compatible for unknown providers return OpenAICompatibleClient( api_key=api_key, base_url=final_base_url or PROVIDER_URLS["openai"], model=model, timeout=timeout, supports_tool_choice=True, ) # ============================================================================ # High-level Convenience Functions # ============================================================================ async def chat_complete( provider: str, api_key: str, model: str, messages: list[dict], base_url: str | None = None, tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, timeout: float = 120.0, ) -> dict: """High-level function for non-streaming chat completion. Returns response in OpenAI-compatible format for backward compatibility. 
""" client = create_llm_client(provider, api_key, model, base_url, timeout) try: llm_messages = [LLMMessage(**m) for m in messages] response = await client.complete( messages=llm_messages, tools=tools, temperature=temperature, max_tokens=max_tokens or get_max_tokens(provider, model), ) return { "choices": [{ "message": { "role": "assistant", "content": response.content, "tool_calls": response.tool_calls or None, }, "finish_reason": response.finish_reason or "stop", }], "model": response.model or model, "usage": response.usage or {}, } finally: await client.close() async def chat_stream( provider: str, api_key: str, model: str, messages: list[dict], base_url: str | None = None, tools: list[dict] | None = None, temperature: float | None = None, max_tokens: int | None = None, timeout: float = 120.0, on_chunk: ChunkCallback | None = None, on_thinking: ThinkingCallback | None = None, ) -> dict: """High-level function for streaming chat completion. Returns aggregated response in OpenAI-compatible format. """ client = create_llm_client(provider, api_key, model, base_url, timeout) try: llm_messages = [LLMMessage(**m) for m in messages] response = await client.stream( messages=llm_messages, tools=tools, temperature=temperature, max_tokens=max_tokens or get_max_tokens(provider, model), on_chunk=on_chunk, on_thinking=on_thinking, ) return { "choices": [{ "message": { "role": "assistant", "content": response.content, "tool_calls": response.tool_calls or None, }, "finish_reason": response.finish_reason or "stop", }], "model": response.model or model, "usage": response.usage or {}, } finally: await client.close()