"""
Unified base adapter built on the OpenAI SDK.

Every platform adapter inherits from this class; platforms differ only by
configuration (see ``PROVIDER_CONFIGS``).

MCP (Model Context Protocol) extension points:
- Subclasses may override ``_get_mcp_tools()`` to return MCP tool definitions.
- Subclasses may override ``_handle_mcp_tool_call()`` to execute MCP tool calls.
"""

import asyncio
import contextvars
import functools
import json
import os
from abc import abstractmethod
from typing import Any, Dict, List, Optional

from fastapi.responses import JSONResponse, StreamingResponse
from openai import OpenAI

from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
from core import get_logger

logger = get_logger()

# Per-platform configuration: OpenAI-compatible base URL, the primary
# environment variable holding the API key, and fallback variable names.
PROVIDER_CONFIGS = {
    "zhipu": {
        "base_url": "https://open.bigmodel.cn/api/paas/v4/",
        "api_key_env": "ZHIPU_API_KEY",
        "alias_env": ["GLM_API_KEY"],  # fallback environment variables
    },
    "dashscope": {
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "api_key_env": "DASHSCOPE_API_KEY",
        "alias_env": ["ALIYUN_API_KEY"],
    },
    "deepseek": {
        "base_url": "https://api.deepseek.com/v1",
        "api_key_env": "DEEPSEEK_API_KEY",
        "alias_env": [],
    },
    "openai": {
        "base_url": None,  # None -> use the OpenAI SDK's default endpoint
        "api_key_env": "OPENAI_API_KEY",
        "alias_env": [],
    },
}

# Optional OpenAI message fields that ``_build_messages`` forwards verbatim so
# tool-calling conversations (assistant ``tool_calls`` turns and the matching
# ``tool`` replies) survive the conversion. ``name`` is intentionally excluded:
# OpenAI restricts its character set, and clients may use it for display.
_PASSTHROUGH_MESSAGE_KEYS = ("tool_call_id", "tool_calls")


def _to_json_log(value: Any, indent: Optional[int] = None) -> str:
    """Serialize *value* for a log line without ever raising on odd types."""
    # default=str: a subclass may put non-JSON values (SDK objects, timeouts)
    # into its extra params; logging them must not abort the request.
    return json.dumps(value, ensure_ascii=False, indent=indent, default=str)


def _to_plain(obj: Any) -> Any:
    """Convert an OpenAI SDK (pydantic) object into plain JSON-ready data."""
    dump = getattr(obj, "model_dump", None)
    return dump(exclude_none=True) if callable(dump) else obj


class UnifiedOpenAIAdapter(BaseAdapter):
    """
    Unified adapter base class built on the OpenAI SDK.

    Subclasses only need to provide:
    - provider_name: platform name
    - list_models(): supported models
    - _get_extra_params(): provider-specific parameters (optional)

    MCP extension points:
    - _get_mcp_tools(): return MCP tool definitions
    - _handle_mcp_tool_call(): handle an MCP tool call
    """

    # Lazily created SDK client; assigned per instance by _get_client().
    _client: Optional[OpenAI] = None
    # Key into PROVIDER_CONFIGS selecting base URL and API-key env vars.
    _provider_type: str = "openai"

    def _get_api_key(self) -> Optional[str]:
        """Return the API key, trying the primary env var and then the fallbacks.

        Returns:
            The first non-empty value found, or None if none is set (or the
            provider type is not in PROVIDER_CONFIGS).
        """
        config = PROVIDER_CONFIGS.get(self._provider_type, {})
        candidates = [config.get("api_key_env", ""), *config.get("alias_env", [])]
        for env_name in candidates:
            if not env_name:
                continue
            api_key = os.getenv(env_name)
            if api_key:
                return api_key
        return None

    def _get_base_url(self) -> Optional[str]:
        """Return the configured base URL (None means the SDK default)."""
        config = PROVIDER_CONFIGS.get(self._provider_type, {})
        return config.get("base_url")

    def _get_client(self) -> OpenAI:
        """Return the OpenAI client, creating it on first use and caching it."""
        if self._client is None:
            api_key = self._get_api_key()
            base_url = self._get_base_url()
            kwargs = {"api_key": api_key or ""}
            if base_url:
                kwargs["base_url"] = base_url
            self._client = OpenAI(**kwargs)
            logger.info(f"[{self.provider_name}] 创建 OpenAI 客户端: base_url={base_url or 'default'}")
        return self._client

    def is_available(self) -> bool:
        """Report whether an API key is configured for this provider."""
        return bool(self._get_api_key())

    def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
        """
        Return provider-specific SDK parameters (subclasses may override).

        An ``extra_body`` entry is passed to the SDK as its ``extra_body``
        argument; every other entry is merged into the call kwargs. The
        returned dict is copied by ``chat`` and never mutated.

        Returns:
            Extra keyword arguments for ``chat.completions.create``.
        """
        return {}

    # ============================================================
    # MCP extension points (subclasses may override)
    # ============================================================

    def _get_mcp_tools(self, request: ChatCompletionRequest) -> List[Dict]:
        """
        Return MCP tool definitions (subclasses may override).

        Returns:
            A list of tools in OpenAI ``tools`` format, e.g.
            ``[{"type": "function", "function": {...}}]``. They are appended
            after any ``tools`` supplied via ``_get_extra_params``.

        Example:
            return [{
                "type": "function",
                "function": {
                    "name": "mcp_search",
                    "description": "Search via the MCP protocol",
                    "parameters": {...}
                }
            }]
        """
        return []

    def _handle_mcp_tool_call(
        self, tool_name: str, tool_args: Dict, request: ChatCompletionRequest
    ) -> Optional[str]:
        """
        Handle an MCP tool call (subclasses may override).

        NOTE: this base class never calls this hook itself; a subclass that
        wants server-side tool execution must invoke it.

        Args:
            tool_name: Tool name.
            tool_args: Tool arguments.
            request: The original request.

        Returns:
            The tool result as a string, or None if the tool is not an MCP tool.

        Example:
            if tool_name == "mcp_search":
                return mcp_client.call(tool_name, tool_args)  # synchronous client
            return None
        """
        return None

    # ============================================================
    # Chat handling
    # ============================================================

    async def chat(self, request: ChatCompletionRequest):
        """
        Handle a chat request through the shared flow.

        Builds the SDK arguments (messages, sampling params, subclass extra
        params, MCP tools), then dispatches to the streaming or non-streaming
        handler.

        Returns:
            A ``StreamingResponse`` (SSE) when ``request.stream`` is true,
            otherwise the ``JSONResponse`` produced by ``_sync_chat``.
        """
        client = self._get_client()

        # Log the request parameters.
        logger.info(f"[{self.provider_name}] 请求参数:")
        logger.info(f" - model: {request.model}")
        logger.info(f" - stream: {request.stream}")
        logger.info(f" - temperature: {request.temperature}")
        logger.info(f" - max_tokens: {request.max_tokens}")
        logger.info(f" - deep_thinking: {request.deep_thinking}")
        logger.info(f" - web_search: {request.web_search}")
        logger.info(f" - deep_search: {request.deep_search}")

        messages = self._build_messages(request)

        kwargs: Dict[str, Any] = {
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": request.stream,
        }

        # Work on a copy: the subclass may return a shared/cached dict, and
        # popping ``extra_body`` from it would silently drop it next time.
        extra_params = dict(self._get_extra_params(request) or {})
        # ``extra_body`` must go to the SDK as its own keyword argument.
        extra_body = extra_params.pop("extra_body", None)
        if extra_params:
            kwargs.update(extra_params)
            logger.info(f" - extra_params: {_to_json_log(extra_params)}")
        if extra_body:
            logger.info(f" - extra_body: {_to_json_log(extra_body)}")

        mcp_tools = self._get_mcp_tools(request)
        if mcp_tools:
            # Build a new list so neither the subclass-provided ``tools`` list
            # nor the MCP list is mutated; also tolerates ``tools=None``.
            kwargs["tools"] = [*(kwargs.get("tools") or []), *mcp_tools]
            logger.info(f" - mcp_tools: {len(mcp_tools)} 个工具")

        if extra_body:
            kwargs["extra_body"] = extra_body

        # Full payloads can be very large (base64 images) and contain user
        # data, so they are only emitted at debug level.
        logger.debug(f" - messages: {_to_json_log(messages, indent=2)}")

        if request.stream:
            # The SSE generator is synchronous; StreamingResponse iterates it
            # off the event loop, so no extra offloading is needed here.
            return self._stream_chat(client, kwargs)

        # The SDK call blocks; run it in a worker thread so one slow upstream
        # request does not stall the event loop. copy_context keeps any
        # contextvars (e.g. request-scoped logging data) visible in the thread.
        loop = asyncio.get_running_loop()
        ctx = contextvars.copy_context()
        return await loop.run_in_executor(
            None, functools.partial(ctx.run, self._sync_chat, client, kwargs)
        )

    def _build_messages(self, request: ChatCompletionRequest) -> List[Dict]:
        """
        Convert request messages into OpenAI chat format.

        - Non-blank string content is kept.
        - List (multimodal) content keeps only its dict parts.
        - ``tool_call_id`` / ``tool_calls`` are forwarded when present.
        - Messages with no usable content are dropped. The exception is
          messages carrying ``tool_calls``, whose content may legitimately be
          empty; dropping them would orphan the following ``tool`` replies.

        Subclasses may override this to handle provider-specific formats.
        """
        messages: List[Dict] = []
        for msg in request.messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            extras = {
                key: msg[key]
                for key in _PASSTHROUGH_MESSAGE_KEYS
                if msg.get(key)
            }

            if isinstance(content, str) and content.strip():
                converted: Any = content
            elif isinstance(content, list) and any(isinstance(part, dict) for part in content):
                # Multimodal content: keep only well-formed (dict) parts.
                converted = [part for part in content if isinstance(part, dict)]
            elif "tool_calls" in extras:
                # Tool-call turns may carry empty or null content.
                converted = content if isinstance(content, str) else None
            else:
                # Nothing usable to send for this message.
                continue

            messages.append({"role": role, "content": converted, **extras})
        return messages

    def _stream_chat(self, client: OpenAI, kwargs: Dict) -> StreamingResponse:
        """
        Stream a chat completion as OpenAI-style SSE chunks.

        All chunks of one response share a single ``id`` and ``created``. The
        closing chunk carries the upstream ``finish_reason`` (falling back to
        "stop") and is followed by ``data: [DONE]``.
        """
        logger.info(f"[{self.provider_name}] 开始流式响应...")

        # Debug aid: log the final arguments passed to the API.
        logger.info(f"[{self.provider_name}] API 调用参数:")
        for key, value in kwargs.items():
            if key == "messages":
                logger.info(f" - {key}: [{len(value)} 条消息]")
            elif key in ("extra_body", "tools"):
                logger.info(f" - {key}: {_to_json_log(value)}")
            else:
                logger.info(f" - {key}: {value}")

        def generator():
            from utils.helpers import generate_unique_id, get_current_timestamp

            completion_id = f"chatcmpl-{generate_unique_id()}"
            created = get_current_timestamp()
            model = kwargs["model"]

            def sse(delta: Dict[str, Any], finish_reason: Optional[str] = None) -> str:
                """Format one ``chat.completion.chunk`` SSE event."""
                payload = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": delta,
                        "finish_reason": finish_reason,
                    }],
                }
                return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

            chunk_count = 0
            content_len = 0
            finish_reason: Optional[str] = None
            resp = None
            try:
                resp = client.chat.completions.create(**kwargs)
                for chunk in resp:
                    if not chunk.choices:
                        continue
                    chunk_count += 1
                    choice = chunk.choices[0]
                    delta = choice.delta
                    upstream_finish = getattr(choice, "finish_reason", None)
                    if upstream_finish:
                        finish_reason = upstream_finish

                    # One outgoing delta may carry reasoning, content and tool
                    # calls together, so none of them is dropped.
                    out: Dict[str, Any] = {}
                    reasoning_content = getattr(delta, "reasoning_content", None)
                    if reasoning_content:
                        out["reasoning_content"] = reasoning_content
                    content = getattr(delta, "content", None)
                    if content:
                        content_len += len(content)
                        out["content"] = content
                    tool_calls = getattr(delta, "tool_calls", None)
                    if tool_calls:
                        out["tool_calls"] = [_to_plain(tc) for tc in tool_calls]
                    if out:
                        yield sse(out)
            except Exception:
                logger.exception(f"[{self.provider_name}] stream failed after {chunk_count} chunks")
                raise
            finally:
                # Release the upstream HTTP connection promptly, including when
                # the client disconnects and this generator is closed early.
                close = getattr(resp, "close", None)
                if callable(close):
                    close()

            yield sse({}, finish_reason or "stop")
            yield "data: [DONE]\n\n"

            logger.info(f"[{self.provider_name}] 流式响应完成: chunks={chunk_count}, content_len={content_len}")

        return StreamingResponse(generator(), media_type="text/event-stream")

    def _sync_chat(self, client: OpenAI, kwargs: Dict) -> JSONResponse:
        """Run a non-streaming completion and wrap it as a ``chat.completion`` body."""
        from utils.helpers import generate_unique_id, get_current_timestamp

        resp = client.chat.completions.create(**kwargs)
        choice = resp.choices[0]
        message = choice.message
        content = message.content or ""

        response_message: Dict[str, Any] = {"role": message.role, "content": content}

        # Deep-thinking output, when the provider returns it.
        reasoning_content = getattr(message, "reasoning_content", None)
        if reasoning_content:
            response_message["reasoning_content"] = reasoning_content

        # Forward tool calls (e.g. requested MCP tools) instead of dropping them.
        tool_calls = getattr(message, "tool_calls", None)
        if tool_calls:
            response_message["tool_calls"] = [_to_plain(tc) for tc in tool_calls]

        response = {
            "id": f"chatcmpl-{generate_unique_id()}",
            "object": "chat.completion",
            "created": get_current_timestamp(),
            "model": kwargs["model"],
            "choices": [{
                "index": 0,
                "message": response_message,
                "finish_reason": choice.finish_reason,
            }],
        }

        if resp.usage:
            response["usage"] = {
                "prompt_tokens": resp.usage.prompt_tokens,
                "completion_tokens": resp.usage.completion_tokens,
                "total_tokens": resp.usage.total_tokens,
            }

        logger.info(f"[{self.provider_name}] 响应完成: content_len={len(content)}")
        if reasoning_content:
            logger.info(f"[{self.provider_name}] reasoning_len={len(reasoning_content)}")

        return JSONResponse(content=response)