# ai-chat-ui/server/adapters/openai_adapter.py
"""
OpenAI 适配器
支持 OpenAI 及兼容 API如 Deepseek
"""
import json
import os
from typing import Dict, List, Optional
from fastapi.responses import JSONResponse, StreamingResponse
from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
from .plugins import get_web_search_mode, build_openai_search_tool, execute_tavily_search, get_current_time_info
from core import get_logger
logger = get_logger()
# Static catalogue of OpenAI models exposed by this adapter
# (returned by OpenAIAdapter.list_models for provider_type "openai").
OPENAI_MODELS: List[ModelInfo] = [
    ModelInfo(
        id="gpt-4o",
        name="GPT-4o",
        description="最新旗舰多模态模型",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=True,
    ),
    ModelInfo(
        id="gpt-4o-mini",
        name="GPT-4o Mini",
        description="高性价比多模态模型",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=True,
    ),
    ModelInfo(
        id="gpt-4-turbo",
        name="GPT-4 Turbo",
        description="GPT-4 增强版",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=False,
    ),
    ModelInfo(
        id="gpt-3.5-turbo",
        name="GPT-3.5 Turbo",
        description="快速经济的选择",
        max_tokens=16385,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=False,
        supports_files=False,
    ),
]
# Static catalogue of Deepseek models exposed by this adapter
# (returned by OpenAIAdapter.list_models for provider_type "deepseek").
DEEPSEEK_MODELS: List[ModelInfo] = [
    ModelInfo(
        id="deepseek-chat",
        name="Deepseek Chat",
        description="Deepseek 对话模型",
        max_tokens=64000,
        provider="Deepseek",
        supports_thinking=False,
        supports_web_search=False,
        supports_vision=False,
        supports_files=False,
    ),
    ModelInfo(
        id="deepseek-reasoner",
        name="Deepseek Reasoner",
        description="Deepseek 推理模型(支持深度思考)",
        max_tokens=64000,
        provider="Deepseek",
        supports_thinking=True,
        supports_web_search=True,  # NOTE: implemented via built-in retrieval augmentation
        supports_vision=False,
        supports_files=False,
    ),
]

# DeepSeek models that support deep thinking (reasoning); checked by
# OpenAIAdapter._supports_thinking via a lowercase membership test.
DEEPSEEK_THINKING_MODELS = {"deepseek-reasoner"}
class OpenAIAdapter(BaseAdapter):
    """Adapter for the OpenAI platform and OpenAI-compatible APIs (e.g. Deepseek)."""

    # Lazily-created OpenAI SDK client; populated on first _get_client() call.
    _client = None
    # Backend discriminator: "openai" or "deepseek".
    _provider_type: str = "openai"

    def __init__(self, provider_type: str = "openai"):
        self._provider_type = provider_type

    @property
    def provider_name(self) -> str:
        """Name of the backing provider ("openai" or "deepseek")."""
        return self._provider_type

    def is_available(self) -> bool:
        """Return True if the API key for this provider is configured."""
        if self._provider_type == "deepseek":
            return bool(os.getenv("DEEPSEEK_API_KEY"))
        return bool(os.getenv("OPENAI_API_KEY"))

    def _get_client(self):
        """Return the OpenAI SDK client, creating it on first use (lazy load).

        Key and base URL are read from environment variables; for Deepseek
        the official endpoint is used when DEEPSEEK_BASE_URL is unset.
        """
        if self._client is None:
            from openai import OpenAI
            if self._provider_type == "deepseek":
                api_key = os.getenv("DEEPSEEK_API_KEY", "")
                base_url = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
            else:
                api_key = os.getenv("OPENAI_API_KEY", "")
                base_url = os.getenv("OPENAI_BASE_URL")  # optional custom endpoint
            kwargs = {"api_key": api_key}
            if base_url:
                kwargs["base_url"] = base_url
            self._client = OpenAI(**kwargs)
        return self._client

    def list_models(self) -> List[ModelInfo]:
        """Return the static model catalogue for this provider."""
        if self._provider_type == "deepseek":
            return DEEPSEEK_MODELS
        return OPENAI_MODELS

    async def chat(self, request: ChatCompletionRequest):
        """
        Handle an OpenAI-style chat request.

        Calls the OpenAI SDK directly and dispatches to streaming or
        non-streaming handling based on ``request.stream``.

        Returns:
            StreamingResponse (SSE) when streaming, otherwise JSONResponse.
        """
        client = self._get_client()
        # Log the incoming request parameters.
        provider_name = self._provider_type.upper()
        logger.info(f"[{provider_name}] 请求参数:")
        logger.info(f" - model: {request.model}")
        logger.info(f" - stream: {request.stream}")
        logger.info(f" - temperature: {request.temperature}")
        logger.info(f" - max_tokens: {request.max_tokens}")
        logger.info(f" - provider_type: {self._provider_type}")
        if self._provider_type == "deepseek":
            logger.info(f" - deep_thinking: {request.deep_thinking}")
        # Build the outgoing message list.
        messages = self._build_messages(request)
        # Uniformly attach web-search plugin parameters when enabled.
        web_search_mode = get_web_search_mode(request)
        if web_search_mode:
            # Inject the current time into the system prompt so the model
            # has time awareness when deciding what to search for.
            time_info = get_current_time_info()
            has_system = False
            for msg in messages:
                if msg.get("role") == "system":
                    msg["content"] = f"当前系统时间:{time_info}\n" + str(msg.get("content", ""))
                    has_system = True
                    break
            if not has_system:
                messages.insert(0, {"role": "system", "content": f"当前系统时间:{time_info}"})
        logger.info(
            f" - messages: {json.dumps(messages, ensure_ascii=False, indent=2)}"
        )
        # Assemble the kwargs passed to chat.completions.create.
        kwargs = {
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": request.stream,
        }
        if web_search_mode:
            search_tool = build_openai_search_tool(web_search_mode)
            kwargs["tools"] = [search_tool]
        # DeepSeek deep-thinking support (only for models in
        # DEEPSEEK_THINKING_MODELS).
        extra_body = None
        if self._provider_type == "deepseek" and request.deep_thinking:
            if self._supports_thinking(request.model):
                extra_body = {"thinking": {"type": "enabled"}}
                kwargs["extra_body"] = extra_body
                logger.info(
                    f"[{provider_name}] 深度思考已启用: extra_body = {extra_body}"
                )
        if request.stream:
            return self._stream_chat(client, kwargs, extra_body)
        else:
            return self._sync_chat(client, kwargs, extra_body)

    def _supports_thinking(self, model: str) -> bool:
        """Return True if *model* supports deep thinking (case-insensitive)."""
        return model.lower() in DEEPSEEK_THINKING_MODELS

    def _build_messages(self, request: ChatCompletionRequest) -> List[Dict]:
        """Build OpenAI-format messages from the request.

        String contents that are empty/whitespace-only are dropped; list
        contents (multimodal) keep only dict items.
        """
        messages = []
        for msg in request.messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            # OpenAI accepts the standard format directly.
            if isinstance(content, str):
                if content.strip():
                    messages.append({"role": role, "content": content})
            elif isinstance(content, list):
                # Multimodal content: pass dict items through unchanged.
                openai_content = []
                for item in content:
                    if isinstance(item, dict):
                        openai_content.append(item)
                if openai_content:
                    messages.append({"role": role, "content": openai_content})
        return messages

    def _stream_chat(
        self, client, kwargs: Dict, extra_body: Optional[Dict] = None
    ) -> StreamingResponse:
        """Streaming chat: returns an SSE StreamingResponse.

        The inner generator may run several request rounds: when the model
        emits tool calls, they are executed locally (web_search via Tavily),
        their results are appended to the conversation, and a new request is
        issued until a round finishes without tool calls.
        """
        provider_name = self._provider_type.upper()
        logger.info(f"[{provider_name}] 开始流式响应...")

        def generator():
            from utils.helpers import generate_unique_id, get_current_timestamp
            nonlocal kwargs  # kwargs["messages"] is mutated across rounds
            # Multiple conversation rounds may be needed (when tool calls occur).
            # NOTE(review): there is no cap on the number of rounds — a model
            # that keeps emitting tool calls would loop indefinitely.
            while True:
                resp = client.chat.completions.create(**kwargs)
                full_content = ""
                full_reasoning = ""
                chunk_count = 0
                tool_calls = []
                current_tool_call = None  # NOTE(review): unused local
                for chunk in resp:
                    if not chunk.choices:
                        continue
                    chunk_count += 1
                    delta = chunk.choices[0].delta
                    # 1. Collect any text / reasoning deltas.
                    delta_content = {}
                    if hasattr(delta, "content") and delta.content:
                        delta_content["content"] = delta.content
                        full_content += delta.content
                    if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                        delta_content["reasoning_content"] = delta.reasoning_content
                        full_reasoning += delta.reasoning_content
                    # 2. Accumulate streamed tool_call fragments by index.
                    if hasattr(delta, "tool_calls") and delta.tool_calls:
                        for tool_call_chunk in delta.tool_calls:
                            idx = tool_call_chunk.index
                            # Grow the list until index idx exists.
                            while len(tool_calls) <= idx:
                                tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                            if tool_call_chunk.id:
                                tool_calls[idx]["id"] += tool_call_chunk.id
                            if tool_call_chunk.type:
                                # Assign (not concatenate) the type: servers may resend
                                # it on each chunk, which would yield "functionfunction".
                                tool_calls[idx]["type"] = tool_call_chunk.type
                            if tool_call_chunk.function:
                                if tool_call_chunk.function.name:
                                    tool_calls[idx]["function"]["name"] += tool_call_chunk.function.name
                                if tool_call_chunk.function.arguments:
                                    tool_calls[idx]["function"]["arguments"] += tool_call_chunk.function.arguments
                    # 3. Forward plain text chunks to the frontend. Once any
                    #    tool call has started streaming, text is suppressed
                    #    for the rest of this round.
                    if delta_content and not tool_calls:
                        data = {
                            "id": f"chatcmpl-{generate_unique_id()}",
                            "object": "chat.completion.chunk",
                            "created": get_current_timestamp(),
                            "model": kwargs["model"],
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": delta_content,
                                    "finish_reason": None,
                                }
                            ],
                        }
                        yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                # If this round produced complete tool calls, execute the
                # search logic, append results, and continue with another
                # round instead of finishing the stream.
                if tool_calls:
                    logger.info(f"[{provider_name}] 检测到流式中包含了工具调用进行拦截并处理: {json.dumps(tool_calls, ensure_ascii=False)}")
                    # Append the model's tool-call request to the history.
                    assistant_msg = {
                        "role": "assistant",
                        "content": full_content or None,  # keep text if it coexisted with tool calls
                        "tool_calls": tool_calls
                    }
                    if full_reasoning:
                        assistant_msg["reasoning_content"] = full_reasoning
                    elif self._provider_type == "deepseek" and self._supports_thinking(kwargs["model"]):
                        # DeepSeek reasoning models require a reasoning_content
                        # field on assistant messages that carry tool calls.
                        assistant_msg["reasoning_content"] = ""
                    kwargs["messages"].append(assistant_msg)
                    for tc in tool_calls:
                        if tc["function"]["name"] == "web_search":
                            try:
                                args = json.loads(tc["function"]["arguments"])
                                query = args.get("query", "")
                                # Heuristic: the tool definition string mentions
                                # "advanced" when deep search was requested.
                                mode = "deep" if "advanced" in str(kwargs.get("tools", [])) else "simple"
                                logger.info(f"[{provider_name}] 执行搜索插件: {query}")
                                search_result = execute_tavily_search(query, mode=mode)
                            except Exception as e:
                                search_result = f"获取搜索参数或执行搜索失败: {str(e)}"
                                logger.error(search_result)
                            # Report the tool result back to the model.
                            kwargs["messages"].append({
                                "role": "tool",
                                "tool_call_id": tc["id"],
                                "name": "web_search",
                                "content": search_result
                            })
                    # Tools executed — issue the next round so the model can
                    # summarize the results.
                    continue
                # No tool calls (or all handled): finish the stream normally.
                finish = {
                    "id": f"chatcmpl-{generate_unique_id()}",
                    "object": "chat.completion.chunk",
                    "created": get_current_timestamp(),
                    "model": kwargs["model"],
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                # Log the streaming outcome.
                logger.info(f"[{provider_name}] 流式响应完成:")
                logger.info(f" - chunks: {chunk_count}")
                logger.info(f" - content_length: {len(full_content)} 字符")
                if full_reasoning:
                    logger.info(f" - reasoning_length: {len(full_reasoning)} 字符")
                logger.info(
                    f" - content_preview: {full_content[:200]}..."
                    if len(full_content) > 200
                    else f" - content: {full_content}"
                )
                # Exit the outer loop and end the generator.
                break

        return StreamingResponse(generator(), media_type="text/event-stream")

    def _sync_chat(
        self, client, kwargs: Dict, extra_body: Optional[Dict] = None
    ) -> JSONResponse:
        """Non-streaming chat: returns a JSONResponse.

        Like _stream_chat, loops while the model emits tool calls, executing
        web_search locally and re-requesting until a plain text reply arrives.
        NOTE(review): the round count is uncapped here as well.
        """
        from utils.helpers import generate_unique_id, get_current_timestamp
        while True:
            resp = client.chat.completions.create(**kwargs)
            message = resp.choices[0].message
            # Did this round involve tool calls?
            if hasattr(message, "tool_calls") and message.tool_calls:
                # Record the assistant turn for this round.
                assistant_msg = {"role": "assistant", "content": message.content or None}
                # Convert SDK tool_call objects to plain dicts for storage.
                tool_calls_dict = []
                for tc in message.tool_calls:
                    tc_dict = {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments
                        }
                    }
                    tool_calls_dict.append(tc_dict)
                assistant_msg["tool_calls"] = tool_calls_dict
                if hasattr(message, "reasoning_content") and message.reasoning_content:
                    assistant_msg["reasoning_content"] = message.reasoning_content
                elif self._provider_type == "deepseek" and self._supports_thinking(kwargs["model"]):
                    # DeepSeek reasoning models require a reasoning_content
                    # field on assistant messages that carry tool calls.
                    assistant_msg["reasoning_content"] = ""
                kwargs["messages"].append(assistant_msg)
                # Execute every tool call in this round.
                for tc in tool_calls_dict:
                    if tc["function"]["name"] == "web_search":
                        try:
                            args = json.loads(tc["function"]["arguments"])
                            query = args.get("query", "")
                            mode = "deep" if "advanced" in str(kwargs.get("tools", [])) else "simple"
                            search_result = execute_tavily_search(query, mode=mode)
                        except Exception as e:
                            search_result = f"执行搜索失败: {str(e)}"
                        # Append the tool result to the conversation.
                        kwargs["messages"].append({
                            "role": "tool",
                            "tool_call_id": tc["id"],
                            "name": "web_search",
                            "content": search_result
                        })
                # Tools done — start the next round to obtain the summary.
                continue
            # Plain text reply: build the OpenAI-compatible response body.
            content = message.content or ""
            response = {
                "id": f"chatcmpl-{generate_unique_id()}",
                "object": "chat.completion",
                "created": get_current_timestamp(),
                "model": kwargs["model"],
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": message.role,
                            "content": content,
                        },
                        "finish_reason": resp.choices[0].finish_reason,
                    }
                ],
            }
            # Include reasoning content if present.
            if hasattr(message, "reasoning_content") and message.reasoning_content:
                response["choices"][0]["message"][
                    "reasoning_content"
                ] = message.reasoning_content
            if resp.usage:
                response["usage"] = {
                    "prompt_tokens": resp.usage.prompt_tokens,
                    "completion_tokens": resp.usage.completion_tokens,
                    "total_tokens": resp.usage.total_tokens,
                }
            # Log the response summary.
            provider_name = self._provider_type.upper()
            logger.info(f"[{provider_name}] 响应结果:")
            logger.info(f" - content_length: {len(content)} 字符")
            if hasattr(message, "reasoning_content") and message.reasoning_content:
                logger.info(f" - reasoning_length: {len(message.reasoning_content)} 字符")
            logger.info(
                f" - content_preview: {content[:200]}..."
                if len(content) > 200
                else f" - content: {content}"
            )
            if resp.usage:
                logger.info(f" - usage: {response['usage']}")
            return JSONResponse(content=response)
class DeepseekAdapter(OpenAIAdapter):
    """Adapter for the Deepseek platform.

    Thin specialisation of :class:`OpenAIAdapter`: all behaviour is
    inherited, with the provider type pinned to ``"deepseek"`` so API-key
    lookup, endpoint selection, and the model catalogue resolve to the
    Deepseek configuration.
    """

    def __init__(self) -> None:
        # Everything else is handled by the OpenAI-compatible base class.
        super().__init__(provider_type="deepseek")