ai-chat-ui/server/adapters/unified_adapter.py

"""
统一 OpenAI SDK 适配器基类
所有平台适配器继承此类,通过配置区分不同平台。
MCP (Model Context Protocol) 支持:
- 子类可覆盖 _get_mcp_tools() 返回 MCP 工具定义
- 子类可覆盖 _handle_mcp_tool_call() 处理 MCP 工具调用
"""
import json
import os
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from fastapi.responses import JSONResponse, StreamingResponse
from openai import OpenAI
from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
from core import get_logger
logger = get_logger()

# Provider configuration
PROVIDER_CONFIGS = {
    "zhipu": {
        "base_url": "https://open.bigmodel.cn/api/paas/v4/",
        "api_key_env": "ZHIPU_API_KEY",
        "alias_env": ["GLM_API_KEY"],  # fallback environment variables
    },
    "dashscope": {
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "api_key_env": "DASHSCOPE_API_KEY",
        "alias_env": ["ALIYUN_API_KEY"],
    },
    "deepseek": {
        "base_url": "https://api.deepseek.com/v1",
        "api_key_env": "DEEPSEEK_API_KEY",
        "alias_env": [],
    },
    "openai": {
        "base_url": None,  # use the OpenAI SDK default
        "api_key_env": "OPENAI_API_KEY",
        "alias_env": [],
    },
}
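
# Illustrative only: a new OpenAI-compatible provider could be registered by
# adding an entry of the same shape. "example" and EXAMPLE_API_KEY are
# hypothetical names, not part of this project.
#
# PROVIDER_CONFIGS["example"] = {
#     "base_url": "https://api.example.com/v1",
#     "api_key_env": "EXAMPLE_API_KEY",
#     "alias_env": [],
# }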


class UnifiedOpenAIAdapter(BaseAdapter):
    """
    Unified adapter base class built on the OpenAI SDK.

    Subclasses only need to provide:
    - provider_name: provider name
    - list_models(): list of supported models
    - _get_extra_params(): provider-specific parameters (optional)

    MCP extension points:
    - _get_mcp_tools(): return MCP tool definitions
    - _handle_mcp_tool_call(): handle MCP tool calls
    """

    _client: Optional[OpenAI] = None
    _provider_type: str = "openai"

    def _get_api_key(self) -> Optional[str]:
        """Resolve the API key from environment variables."""
        config = PROVIDER_CONFIGS.get(self._provider_type, {})
        api_key_env = config.get("api_key_env", "")
        alias_env = config.get("alias_env", [])
        # Prefer the primary environment variable
        api_key = os.getenv(api_key_env)
        if api_key:
            return api_key
        # Fall back to the alias environment variables
        for env_name in alias_env:
            api_key = os.getenv(env_name)
            if api_key:
                return api_key
        return None

    def _get_base_url(self) -> Optional[str]:
        """Resolve the base URL for the provider."""
        config = PROVIDER_CONFIGS.get(self._provider_type, {})
        return config.get("base_url")

    def _get_client(self) -> OpenAI:
        """Get the OpenAI client (lazily created)."""
        if self._client is None:
            api_key = self._get_api_key()
            base_url = self._get_base_url()
            kwargs = {"api_key": api_key or ""}
            if base_url:
                kwargs["base_url"] = base_url
            self._client = OpenAI(**kwargs)
            logger.info(f"[{self.provider_name}] Created OpenAI client: base_url={base_url or 'default'}")
        return self._client

    def is_available(self) -> bool:
        """Check whether the adapter is usable (an API key is configured)."""
        return bool(self._get_api_key())

    def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
        """
        Get extra parameters (subclasses may override).

        Returns:
            Extra parameters passed to the OpenAI API, e.g. extra_body.
        """
        return {}
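
    # Illustrative only: a subclass might surface a provider-specific flag via
    # extra_body. "enable_thinking" is a hypothetical parameter name here, not
    # something every provider accepts.
    #
    # def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
    #     if request.deep_thinking:
    #         return {"extra_body": {"enable_thinking": True}}
    #     return {}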

    # ============================================================
    # MCP extension points (subclasses may override)
    # ============================================================
    def _get_mcp_tools(self, request: ChatCompletionRequest) -> List[Dict]:
        """
        Get MCP tool definitions (subclasses may override).

        Returns:
            A list of MCP tools in the same format as OpenAI tools,
            e.g. [{"type": "function", "function": {...}}]

        Example:
            return [{
                "type": "function",
                "function": {
                    "name": "mcp_search",
                    "description": "Search via the MCP protocol",
                    "parameters": {...}
                }
            }]
        """
        return []

    def _handle_mcp_tool_call(
        self,
        tool_name: str,
        tool_args: Dict,
        request: ChatCompletionRequest
    ) -> Optional[str]:
        """
        Handle an MCP tool call (subclasses may override).

        Args:
            tool_name: tool name
            tool_args: tool arguments
            request: the original request

        Returns:
            The tool execution result as a string, or None if the tool is not
            an MCP tool.

        Example:
            if tool_name == "mcp_search":
                # Call the MCP client
                result = mcp_client.call(tool_name, tool_args)
                return result
            return None
        """
        return None

    # ============================================================
    # Chat handling
    # ============================================================
    async def chat(self, request: ChatCompletionRequest):
        """
        Handle a chat request (shared flow for all providers).
        """
        client = self._get_client()

        # Log the request parameters
        logger.info(f"[{self.provider_name}] Request parameters:")
        logger.info(f" - model: {request.model}")
        logger.info(f" - stream: {request.stream}")
        logger.info(f" - temperature: {request.temperature}")
        logger.info(f" - max_tokens: {request.max_tokens}")
        logger.info(f" - deep_thinking: {request.deep_thinking}")
        logger.info(f" - web_search: {request.web_search}")
        logger.info(f" - deep_search: {request.deep_search}")

        # Build the messages
        messages = self._build_messages(request)

        # Build the request parameters
        kwargs: Dict[str, Any] = {
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": request.stream,
        }

        # Add provider-specific parameters (implemented by subclasses).
        extra_params = self._get_extra_params(request)
        # Split extra_body off from the other parameters: extra_body must be
        # passed to the OpenAI SDK as its own keyword argument.
        extra_body = None
        if extra_params:
            if "extra_body" in extra_params:
                extra_body = extra_params.pop("extra_body")
            kwargs.update(extra_params)
            logger.info(f" - extra_params: {json.dumps(extra_params, ensure_ascii=False)}")
            if extra_body:
                logger.info(f" - extra_body: {json.dumps(extra_body, ensure_ascii=False)}")

        # Add MCP tools (implemented by subclasses).
        mcp_tools = self._get_mcp_tools(request)
        if mcp_tools:
            if "tools" not in kwargs:
                kwargs["tools"] = []
            kwargs["tools"].extend(mcp_tools)
            logger.info(f" - mcp_tools: {len(mcp_tools)} tools")

        # Pass extra_body as a separate keyword argument
        if extra_body:
            kwargs["extra_body"] = extra_body

        logger.info(f" - messages: {json.dumps(messages, ensure_ascii=False, indent=2)}")

        if request.stream:
            return self._stream_chat(client, kwargs)
        else:
            return self._sync_chat(client, kwargs)

    def _build_messages(self, request: ChatCompletionRequest) -> List[Dict]:
        """
        Build messages in the OpenAI format.

        Subclasses may override this to handle special formats
        (e.g. multimodal content).
        """
        messages = []
        for msg in request.messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if isinstance(content, str):
                if content.strip():
                    messages.append({"role": role, "content": content})
            elif isinstance(content, list):
                # Multimodal content
                openai_content = []
                for item in content:
                    if isinstance(item, dict):
                        openai_content.append(item)
                if openai_content:
                    messages.append({"role": role, "content": openai_content})
        return messages
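
    # Illustrative only: a multimodal message in the OpenAI-compatible format
    # typically mixes text and image parts; whether a given provider accepts
    # image_url parts depends on the model.
    #
    # {
    #     "role": "user",
    #     "content": [
    #         {"type": "text", "text": "What is in this picture?"},
    #         {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    #     ],
    # }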

    def _stream_chat(self, client: OpenAI, kwargs: Dict) -> StreamingResponse:
        """Streaming chat."""
        logger.info(f"[{self.provider_name}] Starting streaming response...")
        # Debug: log the final parameters passed to the API
        logger.info(f"[{self.provider_name}] API call parameters:")
        for key, value in kwargs.items():
            if key == "messages":
                logger.info(f" - {key}: [{len(value)} messages]")
            elif key == "extra_body":
                logger.info(f" - {key}: {json.dumps(value, ensure_ascii=False)}")
            elif key == "tools":
                logger.info(f" - {key}: {json.dumps(value, ensure_ascii=False)}")
            else:
                logger.info(f" - {key}: {value}")

        def generator():
            from utils.helpers import generate_unique_id, get_current_timestamp
            full_content = ""
            full_reasoning = ""
            chunk_count = 0
            resp = client.chat.completions.create(**kwargs)
            for chunk in resp:
                if not chunk.choices:
                    continue
                chunk_count += 1
                delta = chunk.choices[0].delta
                # Forward deep-thinking (reasoning) content
                reasoning_content = getattr(delta, "reasoning_content", None)
                if reasoning_content:
                    full_reasoning += reasoning_content
                    data = {
                        "id": f"chatcmpl-{generate_unique_id()}",
                        "object": "chat.completion.chunk",
                        "created": get_current_timestamp(),
                        "model": kwargs["model"],
                        "choices": [{
                            "index": 0,
                            "delta": {"reasoning_content": reasoning_content},
                            "finish_reason": None,
                        }],
                    }
                    yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                    continue
                # Forward regular content
                content = getattr(delta, "content", None)
                if content:
                    full_content += content
                    data = {
                        "id": f"chatcmpl-{generate_unique_id()}",
                        "object": "chat.completion.chunk",
                        "created": get_current_timestamp(),
                        "model": kwargs["model"],
                        "choices": [{
                            "index": 0,
                            "delta": {"content": content},
                            "finish_reason": None,
                        }],
                    }
                    yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
            # Final chunk marking the end of the stream
            finish = {
                "id": f"chatcmpl-{generate_unique_id()}",
                "object": "chat.completion.chunk",
                "created": get_current_timestamp(),
                "model": kwargs["model"],
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }
            yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"
            logger.info(f"[{self.provider_name}] Streaming response finished: chunks={chunk_count}, content_len={len(full_content)}")

        return StreamingResponse(generator(), media_type="text/event-stream")

    def _sync_chat(self, client: OpenAI, kwargs: Dict) -> JSONResponse:
        """Non-streaming chat."""
        from utils.helpers import generate_unique_id, get_current_timestamp

        resp = client.chat.completions.create(**kwargs)
        message = resp.choices[0].message
        content = message.content or ""

        # Build the response message
        response_message = {"role": message.role, "content": content}
        # Include deep-thinking (reasoning) content if present
        reasoning_content = getattr(message, "reasoning_content", None)
        if reasoning_content:
            response_message["reasoning_content"] = reasoning_content

        response = {
            "id": f"chatcmpl-{generate_unique_id()}",
            "object": "chat.completion",
            "created": get_current_timestamp(),
            "model": kwargs["model"],
            "choices": [{
                "index": 0,
                "message": response_message,
                "finish_reason": resp.choices[0].finish_reason,
            }],
        }
        if resp.usage:
            response["usage"] = {
                "prompt_tokens": resp.usage.prompt_tokens,
                "completion_tokens": resp.usage.completion_tokens,
                "total_tokens": resp.usage.total_tokens,
            }
        logger.info(f"[{self.provider_name}] Response finished: content_len={len(content)}")
        if reasoning_content:
            logger.info(f"[{self.provider_name}] reasoning_len={len(reasoning_content)}")
        return JSONResponse(content=response)
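

# ------------------------------------------------------------
# Illustrative only: a minimal sketch of a concrete adapter built on this base
# class. "ExampleAdapter", the model id, and the "mcp_search" tool are
# hypothetical; the actual abstract surface (provider_name, list_models) is
# defined in .base and may differ in detail.
#
# class ExampleAdapter(UnifiedOpenAIAdapter):
#     _provider_type = "deepseek"
#     provider_name = "deepseek"
#
#     def list_models(self) -> List[ModelInfo]:
#         return [ModelInfo(id="deepseek-chat", name="DeepSeek Chat")]
#
#     def _get_mcp_tools(self, request: ChatCompletionRequest) -> List[Dict]:
#         return [{
#             "type": "function",
#             "function": {
#                 "name": "mcp_search",
#                 "description": "Search via the MCP protocol",
#                 "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
#             },
#         }]
#
#     def _handle_mcp_tool_call(self, tool_name, tool_args, request):
#         if tool_name == "mcp_search":
#             return json.dumps({"results": []})  # placeholder result
#         return None
# ------------------------------------------------------------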