382 lines
13 KiB
Python
382 lines
13 KiB
Python
"""
|
|
统一 OpenAI SDK 适配器基类
|
|
|
|
所有平台适配器继承此类,通过配置区分不同平台。
|
|
|
|
MCP (Model Context Protocol) 支持:
|
|
- 子类可覆盖 _get_mcp_tools() 返回 MCP 工具定义
|
|
- 子类可覆盖 _handle_mcp_tool_call() 处理 MCP 工具调用
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
from abc import abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from openai import OpenAI
|
|
|
|
from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
|
|
from core import get_logger
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
# 平台配置
|
|
PROVIDER_CONFIGS = {
|
|
"zhipu": {
|
|
"base_url": "https://open.bigmodel.cn/api/paas/v4/",
|
|
"api_key_env": "ZHIPU_API_KEY",
|
|
"alias_env": ["GLM_API_KEY"], # 备选环境变量
|
|
},
|
|
"dashscope": {
|
|
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
|
"api_key_env": "DASHSCOPE_API_KEY",
|
|
"alias_env": ["ALIYUN_API_KEY"],
|
|
},
|
|
"deepseek": {
|
|
"base_url": "https://api.deepseek.com/v1",
|
|
"api_key_env": "DEEPSEEK_API_KEY",
|
|
"alias_env": [],
|
|
},
|
|
"openai": {
|
|
"base_url": None, # 使用 OpenAI 默认值
|
|
"api_key_env": "OPENAI_API_KEY",
|
|
"alias_env": [],
|
|
},
|
|
}
|
|
|
|
|
|
class UnifiedOpenAIAdapter(BaseAdapter):
    """Unified adapter base class built on the OpenAI SDK.

    All platform adapters inherit from this class; platforms are
    distinguished purely by configuration (see PROVIDER_CONFIGS).

    Subclasses only need to provide:
    - provider_name: platform name
    - list_models(): list of supported models
    - _get_extra_params(): provider-specific parameters (optional)

    MCP (Model Context Protocol) extension points:
    - _get_mcp_tools(): return MCP tool definitions
    - _handle_mcp_tool_call(): execute an MCP tool call
    """

    # Lazily-created OpenAI client; populated on first _get_client() call.
    _client: Optional[OpenAI] = None
    # Key into PROVIDER_CONFIGS selecting base_url and API-key env vars.
    _provider_type: str = "openai"

    def _get_api_key(self) -> Optional[str]:
        """Resolve the API key from the environment.

        Checks the provider's primary env var first, then each alias
        in declared order.

        Returns:
            The first non-empty key found, or None if none is configured.
        """
        config = PROVIDER_CONFIGS.get(self._provider_type, {})
        # Primary env var first, then the declared fallbacks, in order.
        candidates = [config.get("api_key_env", ""), *config.get("alias_env", [])]
        for env_name in candidates:
            api_key = os.getenv(env_name)
            if api_key:
                return api_key
        return None

    def _get_base_url(self) -> Optional[str]:
        """Return the provider's base URL (None means the SDK default)."""
        config = PROVIDER_CONFIGS.get(self._provider_type, {})
        return config.get("base_url")

    def _get_client(self) -> OpenAI:
        """Return the OpenAI client, creating it lazily on first use."""
        if self._client is None:
            api_key = self._get_api_key()
            base_url = self._get_base_url()

            kwargs = {"api_key": api_key or ""}
            if base_url:
                kwargs["base_url"] = base_url

            self._client = OpenAI(**kwargs)
            logger.info(f"[{self.provider_name}] 创建 OpenAI 客户端: base_url={base_url or 'default'}")

        return self._client

    def is_available(self) -> bool:
        """An adapter is available iff an API key can be resolved."""
        return bool(self._get_api_key())

    def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
        """Return provider-specific extra parameters (subclasses may override).

        Returns:
            Extra parameters forwarded to the OpenAI API. If the dict
            contains an "extra_body" key, that value is passed as the
            SDK's dedicated ``extra_body`` argument instead of being
            merged into the top-level kwargs.
        """
        return {}

    # ============================================================
    # MCP extension points (subclasses may override)
    # ============================================================

    def _get_mcp_tools(self, request: ChatCompletionRequest) -> List[Dict]:
        """Return MCP tool definitions (subclasses may override).

        Returns:
            A list of tools in the same format as OpenAI tools, e.g.
            [{"type": "function", "function": {...}}].

        Example:
            return [{
                "type": "function",
                "function": {
                    "name": "mcp_search",
                    "description": "Search via the MCP protocol",
                    "parameters": {...}
                }
            }]
        """
        return []

    def _handle_mcp_tool_call(
        self,
        tool_name: str,
        tool_args: Dict,
        request: ChatCompletionRequest
    ) -> Optional[str]:
        """Handle an MCP tool call (subclasses may override).

        Args:
            tool_name: Name of the tool being invoked.
            tool_args: Arguments supplied for the tool.
            request: The original chat request.

        Returns:
            The tool's result as a string, or None when the tool is not
            an MCP tool handled by this adapter.
        """
        return None

    # ============================================================
    # Chat handling
    # ============================================================

    async def chat(self, request: ChatCompletionRequest):
        """Handle a chat completion request (shared flow for all providers).

        Returns a StreamingResponse for stream requests, otherwise a
        JSONResponse in OpenAI chat-completion format.

        NOTE(review): the non-stream path calls the SDK synchronously
        inside this coroutine, which blocks the event loop for the
        duration of the upstream call — consider run_in_executor.
        """
        client = self._get_client()

        # Log the incoming request parameters.
        logger.info(f"[{self.provider_name}] 请求参数:")
        logger.info(f" - model: {request.model}")
        logger.info(f" - stream: {request.stream}")
        logger.info(f" - temperature: {request.temperature}")
        logger.info(f" - max_tokens: {request.max_tokens}")
        logger.info(f" - deep_thinking: {request.deep_thinking}")
        logger.info(f" - web_search: {request.web_search}")
        logger.info(f" - deep_search: {request.deep_search}")

        # Convert messages to OpenAI format.
        messages = self._build_messages(request)

        # Base request parameters.
        kwargs: Dict[str, Any] = {
            "model": request.model,
            "messages": messages,
            "stream": request.stream,
        }
        # Fix: only forward sampling parameters that are actually set.
        # Passing temperature/max_tokens as None serializes explicit JSON
        # nulls, which some OpenAI-compatible providers reject.
        if request.temperature is not None:
            kwargs["temperature"] = request.temperature
        if request.max_tokens is not None:
            kwargs["max_tokens"] = request.max_tokens

        # Provider-specific parameters (supplied by subclasses).
        extra_params = self._get_extra_params(request)

        # extra_body must be passed as a dedicated OpenAI SDK argument,
        # so split it out from the remaining parameters.
        extra_body = None
        if extra_params:
            if "extra_body" in extra_params:
                extra_body = extra_params.pop("extra_body")
            kwargs.update(extra_params)
            logger.info(f" - extra_params: {json.dumps(extra_params, ensure_ascii=False)}")
            if extra_body:
                logger.info(f" - extra_body: {json.dumps(extra_body, ensure_ascii=False)}")

        # MCP tools (supplied by subclasses); appended to any tools the
        # extra params may already have contributed.
        mcp_tools = self._get_mcp_tools(request)
        if mcp_tools:
            kwargs.setdefault("tools", []).extend(mcp_tools)
            logger.info(f" - mcp_tools: {len(mcp_tools)} 个工具")

        # Pass extra_body separately (OpenAI SDK convention).
        if extra_body:
            kwargs["extra_body"] = extra_body

        logger.info(f" - messages: {json.dumps(messages, ensure_ascii=False, indent=2)}")

        if request.stream:
            return self._stream_chat(client, kwargs)
        return self._sync_chat(client, kwargs)

    def _build_messages(self, request: ChatCompletionRequest) -> List[Dict]:
        """Convert request messages to OpenAI format.

        Subclasses may override to handle special formats (e.g. multimodal).
        Messages with empty/blank string content are dropped; list content
        keeps only dict-shaped parts.
        """
        messages: List[Dict] = []

        for msg in request.messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if isinstance(content, str):
                if content.strip():
                    messages.append({"role": role, "content": content})
            elif isinstance(content, list):
                # Multimodal content: keep only well-formed dict parts.
                openai_content = [item for item in content if isinstance(item, dict)]
                if openai_content:
                    messages.append({"role": role, "content": openai_content})

        return messages

    def _stream_chat(self, client: OpenAI, kwargs: Dict) -> StreamingResponse:
        """Stream the chat completion as Server-Sent Events.

        Upstream failures are reported to the client as an SSE error
        event instead of silently truncating the stream.
        """
        logger.info(f"[{self.provider_name}] 开始流式响应...")

        # Debug: log the final parameters passed to the API.
        logger.info(f"[{self.provider_name}] API 调用参数:")
        for key, value in kwargs.items():
            if key == "messages":
                logger.info(f" - {key}: [{len(value)} 条消息]")
            elif key in ("extra_body", "tools"):
                logger.info(f" - {key}: {json.dumps(value, ensure_ascii=False)}")
            else:
                logger.info(f" - {key}: {value}")

        def generator():
            from utils.helpers import generate_unique_id, get_current_timestamp

            def sse_chunk(delta: Dict[str, Any], finish_reason: Optional[str] = None) -> str:
                """Build one OpenAI-format chunk and frame it as an SSE event."""
                data = {
                    "id": f"chatcmpl-{generate_unique_id()}",
                    "object": "chat.completion.chunk",
                    "created": get_current_timestamp(),
                    "model": kwargs["model"],
                    "choices": [{
                        "index": 0,
                        "delta": delta,
                        "finish_reason": finish_reason,
                    }],
                }
                return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

            full_content = ""
            full_reasoning = ""
            chunk_count = 0

            try:
                resp = client.chat.completions.create(**kwargs)

                for chunk in resp:
                    if not chunk.choices:
                        continue

                    chunk_count += 1
                    delta = chunk.choices[0].delta

                    # Deep-thinking content (providers that emit reasoning_content).
                    reasoning_content = getattr(delta, "reasoning_content", None)
                    if reasoning_content:
                        full_reasoning += reasoning_content
                        yield sse_chunk({"reasoning_content": reasoning_content})
                        continue

                    # Regular content.
                    content = getattr(delta, "content", None)
                    if content:
                        full_content += content
                        yield sse_chunk({"content": content})
            except Exception as e:
                # Fix: surface upstream failures to the SSE client; previously
                # an exception here truncated the stream with no signal at all.
                logger.exception(f"[{self.provider_name}] 流式响应异常: {e}")
                error_event = {"error": {"message": str(e), "type": type(e).__name__}}
                yield f"data: {json.dumps(error_event, ensure_ascii=False)}\n\n"

            # Finish marker + SSE terminator.
            yield sse_chunk({}, finish_reason="stop")
            yield "data: [DONE]\n\n"

            logger.info(f"[{self.provider_name}] 流式响应完成: chunks={chunk_count}, content_len={len(full_content)}")

        return StreamingResponse(generator(), media_type="text/event-stream")

    def _sync_chat(self, client: OpenAI, kwargs: Dict) -> JSONResponse:
        """Non-streaming chat: return one OpenAI-format completion response."""
        from utils.helpers import generate_unique_id, get_current_timestamp

        resp = client.chat.completions.create(**kwargs)

        message = resp.choices[0].message
        content = message.content or ""

        # Build the response message.
        response_message = {"role": message.role, "content": content}

        # Deep-thinking content, if the provider returned any.
        reasoning_content = getattr(message, "reasoning_content", None)
        if reasoning_content:
            response_message["reasoning_content"] = reasoning_content

        response = {
            "id": f"chatcmpl-{generate_unique_id()}",
            "object": "chat.completion",
            "created": get_current_timestamp(),
            "model": kwargs["model"],
            "choices": [{
                "index": 0,
                "message": response_message,
                "finish_reason": resp.choices[0].finish_reason,
            }],
        }

        if resp.usage:
            response["usage"] = {
                "prompt_tokens": resp.usage.prompt_tokens,
                "completion_tokens": resp.usage.completion_tokens,
                "total_tokens": resp.usage.total_tokens,
            }

        logger.info(f"[{self.provider_name}] 响应完成: content_len={len(content)}")
        if reasoning_content:
            logger.info(f"[{self.provider_name}] reasoning_len={len(reasoning_content)}")

        return JSONResponse(content=response)