"""
OpenAI 适配器

支持 OpenAI 及兼容 API(如 Deepseek)
"""
|
||
|
||
import json
|
||
import os
|
||
from typing import Dict, List, Optional
|
||
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
|
||
from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
|
||
from .plugins import get_web_search_mode, build_openai_search_tool, execute_tavily_search, get_current_time_info
|
||
from core import get_logger
|
||
|
||
logger = get_logger()
|
||
|
||
# OpenAI model catalog: static ModelInfo entries this adapter advertises.
# The supports_* flags gate feature availability (thinking / web search /
# vision / file upload) for upstream callers; max_tokens is the context limit.
OPENAI_MODELS = [
    ModelInfo(
        id="gpt-4o",
        name="GPT-4o",
        description="最新旗舰多模态模型",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=True,
    ),
    ModelInfo(
        id="gpt-4o-mini",
        name="GPT-4o Mini",
        description="高性价比多模态模型",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=True,
    ),
    ModelInfo(
        id="gpt-4-turbo",
        name="GPT-4 Turbo",
        description="GPT-4 增强版",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=False,
    ),
    ModelInfo(
        id="gpt-3.5-turbo",
        name="GPT-3.5 Turbo",
        description="快速经济的选择",
        max_tokens=16385,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=False,
        supports_files=False,
    ),
]
|
||
|
||
# Deepseek model catalog: static ModelInfo entries served when the adapter
# runs with provider_type == "deepseek".
DEEPSEEK_MODELS = [
    ModelInfo(
        id="deepseek-chat",
        name="Deepseek Chat",
        description="Deepseek 对话模型",
        max_tokens=64000,
        provider="Deepseek",
        supports_thinking=False,
        supports_web_search=False,
        supports_vision=False,
        supports_files=False,
    ),
    ModelInfo(
        id="deepseek-reasoner",
        name="Deepseek Reasoner",
        description="Deepseek 推理模型(支持深度思考)",
        max_tokens=64000,
        provider="Deepseek",
        supports_thinking=True,
        supports_web_search=True,  # Note: implemented via built-in retrieval augmentation
        supports_vision=False,
        supports_files=False,
    ),
]
|
||
|
||
# DeepSeek model IDs that support deep thinking (reasoning_content output);
# checked by OpenAIAdapter._supports_thinking via a lowercase membership test.
DEEPSEEK_THINKING_MODELS = {"deepseek-reasoner"}
|
||
|
||
|
||
class OpenAIAdapter(BaseAdapter):
    """Adapter for the OpenAI platform (and OpenAI-compatible APIs).

    One class serves both providers: ``provider_type`` ("openai" or
    "deepseek") selects the API key / base URL, the model catalog, and
    DeepSeek-specific deep-thinking handling. Chat requests support
    streaming (SSE) and non-streaming responses, plus a web-search tool
    loop executed via Tavily.
    """

    # Lazily-created OpenAI SDK client; class-level default, bound per
    # instance on first use in _get_client().
    _client = None
    _provider_type: str = "openai"  # "openai" or "deepseek"

    def __init__(self, provider_type: str = "openai"):
        self._provider_type = provider_type

    @property
    def provider_name(self) -> str:
        # Provider identifier as configured ("openai" / "deepseek").
        return self._provider_type

    def is_available(self) -> bool:
        """Return True if the API key for this provider is configured."""
        if self._provider_type == "deepseek":
            return bool(os.getenv("DEEPSEEK_API_KEY"))
        return bool(os.getenv("OPENAI_API_KEY"))

    def _get_client(self):
        """Return the OpenAI SDK client, constructing it lazily on first use.

        Reads DEEPSEEK_API_KEY / DEEPSEEK_BASE_URL for the deepseek provider
        (base URL defaults to the public Deepseek endpoint), otherwise
        OPENAI_API_KEY and an optional OPENAI_BASE_URL override.
        """
        if self._client is None:
            from openai import OpenAI

            if self._provider_type == "deepseek":
                api_key = os.getenv("DEEPSEEK_API_KEY", "")
                base_url = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
            else:
                api_key = os.getenv("OPENAI_API_KEY", "")
                base_url = os.getenv("OPENAI_BASE_URL")  # optional custom endpoint

            kwargs = {"api_key": api_key}
            if base_url:
                kwargs["base_url"] = base_url

            self._client = OpenAI(**kwargs)
        return self._client

    def list_models(self) -> List[ModelInfo]:
        # Catalog depends only on provider type; lists are module-level constants.
        if self._provider_type == "deepseek":
            return DEEPSEEK_MODELS
        return OPENAI_MODELS

    async def chat(self, request: ChatCompletionRequest):
        """Handle an OpenAI chat request.

        Calls the OpenAI SDK directly; dispatches to a streaming or
        non-streaming path based on ``request.stream``. When web search is
        enabled, injects the current time into the system prompt and
        attaches the search tool definition to the request.
        """
        client = self._get_client()

        # Log the incoming request parameters.
        provider_name = self._provider_type.upper()
        logger.info(f"[{provider_name}] 请求参数:")
        logger.info(f" - model: {request.model}")
        logger.info(f" - stream: {request.stream}")
        logger.info(f" - temperature: {request.temperature}")
        logger.info(f" - max_tokens: {request.max_tokens}")
        logger.info(f" - provider_type: {self._provider_type}")
        if self._provider_type == "deepseek":
            logger.info(f" - deep_thinking: {request.deep_thinking}")

        # Convert incoming messages to OpenAI wire format.
        messages = self._build_messages(request)

        # Uniformly apply the web-search plugin parameters.
        web_search_mode = get_web_search_mode(request)
        if web_search_mode:
            # Inject the current time into the system prompt so the model
            # has time awareness (prepends to an existing system message,
            # or inserts one at the front).
            time_info = get_current_time_info()
            has_system = False
            for msg in messages:
                if msg.get("role") == "system":
                    msg["content"] = f"当前系统时间:{time_info}\n" + str(msg.get("content", ""))
                    has_system = True
                    break
            if not has_system:
                messages.insert(0, {"role": "system", "content": f"当前系统时间:{time_info}"})

        logger.info(
            f" - messages: {json.dumps(messages, ensure_ascii=False, indent=2)}"
        )

        # Build the SDK request kwargs.
        kwargs = {
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": request.stream,
        }

        if web_search_mode:
            search_tool = build_openai_search_tool(web_search_mode)
            kwargs["tools"] = [search_tool]

        # DeepSeek deep-thinking support: passed via extra_body, only for
        # models known to support it.
        extra_body = None
        if self._provider_type == "deepseek" and request.deep_thinking:
            if self._supports_thinking(request.model):
                extra_body = {"thinking": {"type": "enabled"}}
                kwargs["extra_body"] = extra_body
                logger.info(
                    f"[{provider_name}] 深度思考已启用: extra_body = {extra_body}"
                )

        if request.stream:
            return self._stream_chat(client, kwargs, extra_body)
        else:
            return self._sync_chat(client, kwargs, extra_body)

    def _supports_thinking(self, model: str) -> bool:
        """Return True if the model supports deep thinking (DeepSeek reasoner)."""
        return model.lower() in DEEPSEEK_THINKING_MODELS

    def _build_messages(self, request: ChatCompletionRequest) -> List[Dict]:
        """Build OpenAI-format messages from the request.

        String content is kept as-is (empty/whitespace-only messages are
        dropped); list content is treated as multimodal and passed through,
        keeping only dict items.
        """
        messages = []

        for msg in request.messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            # OpenAI accepts the standard format directly.
            if isinstance(content, str):
                if content.strip():
                    messages.append({"role": role, "content": content})
            elif isinstance(content, list):
                # Multimodal content: pass dict parts through unchanged.
                openai_content = []
                for item in content:
                    if isinstance(item, dict):
                        openai_content.append(item)
                if openai_content:
                    messages.append({"role": role, "content": openai_content})

        return messages

    def _stream_chat(
        self, client, kwargs: Dict, extra_body: Optional[Dict] = None
    ) -> StreamingResponse:
        """Streaming chat as a server-sent-events response.

        The inner generator may run multiple completion rounds: when the
        model emits tool calls, the web_search tool is executed and its
        result appended to the conversation before re-requesting, until a
        round finishes with no tool calls.

        NOTE(review): ``extra_body`` is unused here — the value is already
        merged into ``kwargs`` by chat(); confirm before removing.
        NOTE(review): the tool loop has no iteration cap; a model that keeps
        calling tools would loop indefinitely.
        """
        provider_name = self._provider_type.upper()
        logger.info(f"[{provider_name}] 开始流式响应...")

        def generator():
            from utils.helpers import generate_unique_id, get_current_timestamp

            # nonlocal lets the closure mutate the shared request kwargs
            # (messages are appended across tool-call rounds).
            nonlocal kwargs

            # Multiple rounds may be needed when tool calls occur.
            while True:
                resp = client.chat.completions.create(**kwargs)
                full_content = ""
                full_reasoning = ""
                chunk_count = 0

                # Accumulators for streamed tool-call fragments, indexed by
                # the chunk's tool_call index.
                tool_calls = []
                current_tool_call = None

                for chunk in resp:
                    if not chunk.choices:
                        continue

                    chunk_count += 1
                    delta = chunk.choices[0].delta

                    # 1. Collect any content / reasoning deltas.
                    delta_content = {}
                    if hasattr(delta, "content") and delta.content:
                        delta_content["content"] = delta.content
                        full_content += delta.content
                    if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                        delta_content["reasoning_content"] = delta.reasoning_content
                        full_reasoning += delta.reasoning_content

                    # 2. Collect any streamed tool_calls fragments.
                    if hasattr(delta, "tool_calls") and delta.tool_calls:
                        for tool_call_chunk in delta.tool_calls:
                            idx = tool_call_chunk.index
                            # Ensure the tool_calls list is long enough.
                            while len(tool_calls) <= idx:
                                tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})

                            if tool_call_chunk.id:
                                tool_calls[idx]["id"] += tool_call_chunk.id
                            if tool_call_chunk.type:
                                # "type" may arrive in the first chunk or in every
                                # chunk; assign instead of concatenating to avoid
                                # producing e.g. "functionfunction".
                                tool_calls[idx]["type"] = tool_call_chunk.type
                            if tool_call_chunk.function:
                                if tool_call_chunk.function.name:
                                    tool_calls[idx]["function"]["name"] += tool_call_chunk.function.name
                                if tool_call_chunk.function.arguments:
                                    tool_calls[idx]["function"]["arguments"] += tool_call_chunk.function.arguments

                    # 3. Forward plain text deltas to the client (suppressed
                    #    once tool calls start accumulating in this round).
                    if delta_content and not tool_calls:
                        data = {
                            "id": f"chatcmpl-{generate_unique_id()}",
                            "object": "chat.completion.chunk",
                            "created": get_current_timestamp(),
                            "model": kwargs["model"],
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": delta_content,
                                    "finish_reason": None,
                                }
                            ],
                        }
                        yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

                # If this round produced complete tool calls, execute the
                # search logic and continue with another round instead of
                # finishing the stream.
                if tool_calls:
                    logger.info(f"[{provider_name}] 检测到流式中包含了工具调用进行拦截并处理: {json.dumps(tool_calls, ensure_ascii=False)}")

                    # Append the model's tool-call request to the history.
                    assistant_msg = {
                        "role": "assistant",
                        "content": full_content or None,  # keep any plain content emitted alongside the tool call
                        "tool_calls": tool_calls
                    }
                    if full_reasoning:
                        assistant_msg["reasoning_content"] = full_reasoning
                    elif self._provider_type == "deepseek" and self._supports_thinking(kwargs["model"]):
                        # DeepSeek reasoner models require a reasoning_content
                        # field on assistant messages that carry tool calls.
                        assistant_msg["reasoning_content"] = ""
                    kwargs["messages"].append(assistant_msg)

                    for tc in tool_calls:
                        if tc["function"]["name"] == "web_search":
                            try:
                                args = json.loads(tc["function"]["arguments"])
                                query = args.get("query", "")
                                # Infer search depth from the serialized tool config
                                # ("advanced" appears in the deep-search variant).
                                mode = "deep" if "advanced" in str(kwargs.get("tools", [])) else "simple"
                                logger.info(f"[{provider_name}] 执行搜索插件: {query}")
                                search_result = execute_tavily_search(query, mode=mode)
                            except Exception as e:
                                search_result = f"获取搜索参数或执行搜索失败: {str(e)}"
                                logger.error(search_result)

                            # Feed the tool result back to the model.
                            kwargs["messages"].append({
                                "role": "tool",
                                "tool_call_id": tc["id"],
                                "name": "web_search",
                                "content": search_result
                            })

                    # Tools executed; start another round so the model can
                    # summarize the results.
                    continue

                # No tool calls (or all rounds done): finish the stream.
                finish = {
                    "id": f"chatcmpl-{generate_unique_id()}",
                    "object": "chat.completion.chunk",
                    "created": get_current_timestamp(),
                    "model": kwargs["model"],
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"

                # Log streaming summary.
                logger.info(f"[{provider_name}] 流式响应完成:")
                logger.info(f" - chunks: {chunk_count}")
                logger.info(f" - content_length: {len(full_content)} 字符")
                if full_reasoning:
                    logger.info(f" - reasoning_length: {len(full_reasoning)} 字符")
                logger.info(
                    f" - content_preview: {full_content[:200]}..."
                    if len(full_content) > 200
                    else f" - content: {full_content}"
                )

                # Exit the outer loop and end the generator.
                break

        return StreamingResponse(generator(), media_type="text/event-stream")

    def _sync_chat(
        self, client, kwargs: Dict, extra_body: Optional[Dict] = None
    ) -> JSONResponse:
        """Non-streaming chat.

        Runs completion rounds until the model stops requesting tool calls
        (web_search results are appended between rounds), then returns a
        chat.completion JSON payload with optional reasoning_content and
        usage information.

        NOTE(review): ``extra_body`` is unused here — already merged into
        ``kwargs`` by chat(); confirm before removing.
        """
        from utils.helpers import generate_unique_id, get_current_timestamp

        while True:
            resp = client.chat.completions.create(**kwargs)

            message = resp.choices[0].message

            # Check whether this round involves tool calls.
            if hasattr(message, "tool_calls") and message.tool_calls:
                # Record this round's assistant reply.
                assistant_msg = {"role": "assistant", "content": message.content or None}
                # Convert SDK tool_call objects to plain dicts for storage.
                tool_calls_dict = []
                for tc in message.tool_calls:
                    tc_dict = {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments
                        }
                    }
                    tool_calls_dict.append(tc_dict)
                assistant_msg["tool_calls"] = tool_calls_dict
                if hasattr(message, "reasoning_content") and message.reasoning_content:
                    assistant_msg["reasoning_content"] = message.reasoning_content
                elif self._provider_type == "deepseek" and self._supports_thinking(kwargs["model"]):
                    # DeepSeek reasoner models require a reasoning_content
                    # field on assistant messages that carry tool calls.
                    assistant_msg["reasoning_content"] = ""
                kwargs["messages"].append(assistant_msg)

                # Execute every requested tool call.
                for tc in tool_calls_dict:
                    if tc["function"]["name"] == "web_search":
                        try:
                            args = json.loads(tc["function"]["arguments"])
                            query = args.get("query", "")
                            # Infer search depth from the serialized tool config.
                            mode = "deep" if "advanced" in str(kwargs.get("tools", [])) else "simple"
                            search_result = execute_tavily_search(query, mode=mode)
                        except Exception as e:
                            search_result = f"执行搜索失败: {str(e)}"

                        # Append the tool result to the conversation.
                        kwargs["messages"].append({
                            "role": "tool",
                            "tool_call_id": tc["id"],
                            "name": "web_search",
                            "content": search_result
                        })
                # Tool calls done; request the next round for the final answer.
                continue

            # Plain text reply: build the response payload.
            content = message.content or ""
            response = {
                "id": f"chatcmpl-{generate_unique_id()}",
                "object": "chat.completion",
                "created": get_current_timestamp(),
                "model": kwargs["model"],
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": message.role,
                            "content": content,
                        },
                        "finish_reason": resp.choices[0].finish_reason,
                    }
                ],
            }

            # Attach reasoning content if present.
            if hasattr(message, "reasoning_content") and message.reasoning_content:
                response["choices"][0]["message"][
                    "reasoning_content"
                ] = message.reasoning_content

            if resp.usage:
                response["usage"] = {
                    "prompt_tokens": resp.usage.prompt_tokens,
                    "completion_tokens": resp.usage.completion_tokens,
                    "total_tokens": resp.usage.total_tokens,
                }

            # Log the response summary.
            provider_name = self._provider_type.upper()
            logger.info(f"[{provider_name}] 响应结果:")
            logger.info(f" - content_length: {len(content)} 字符")
            if hasattr(message, "reasoning_content") and message.reasoning_content:
                logger.info(f" - reasoning_length: {len(message.reasoning_content)} 字符")
            logger.info(
                f" - content_preview: {content[:200]}..."
                if len(content) > 200
                else f" - content: {content}"
            )
            if resp.usage:
                logger.info(f" - usage: {response['usage']}")

            return JSONResponse(content=response)
|
||
|
||
|
||
class DeepseekAdapter(OpenAIAdapter):
    """Deepseek platform adapter.

    A thin subclass of :class:`OpenAIAdapter` that pins the provider type
    to ``"deepseek"``, which switches the API key / base-URL lookup and
    the advertised model catalog.
    """

    def __init__(self):
        super().__init__("deepseek")
|