# ai-chat-ui/server/adapters/openai_adapter.py

"""
OpenAI 适配器
支持 OpenAI 及兼容 API如 Deepseek
"""
import json
import os
from typing import Any, Dict, List, Optional

from fastapi.responses import StreamingResponse

from .base import ChatCompletionRequest, ModelInfo
from .unified_adapter import UnifiedOpenAIAdapter
from .plugins import (
    get_web_search_mode,
    build_openai_search_tool,
    execute_tavily_search,
    get_current_time_info,
)
from core import get_logger

logger = get_logger()

# OpenAI model configuration
OPENAI_MODELS = [
    ModelInfo(
        id="gpt-4o",
        name="GPT-4o",
        description="Latest flagship multimodal model",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=True,
    ),
    ModelInfo(
        id="gpt-4o-mini",
        name="GPT-4o Mini",
        description="Cost-effective multimodal model",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=True,
    ),
    ModelInfo(
        id="gpt-4-turbo",
        name="GPT-4 Turbo",
        description="Enhanced GPT-4",
        max_tokens=128000,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=True,
        supports_files=False,
    ),
    ModelInfo(
        id="gpt-3.5-turbo",
        name="GPT-3.5 Turbo",
        description="Fast and economical option",
        max_tokens=16385,
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        supports_vision=False,
        supports_files=False,
    ),
]

# Deepseek model configuration
DEEPSEEK_MODELS = [
    ModelInfo(
        id="deepseek-chat",
        name="Deepseek Chat",
        description="Deepseek chat model",
        max_tokens=64000,
        provider="Deepseek",
        supports_thinking=False,
        supports_web_search=False,
        supports_vision=False,
        supports_files=False,
    ),
    ModelInfo(
        id="deepseek-reasoner",
        name="Deepseek Reasoner",
        description="Deepseek reasoning model (supports deep thinking)",
        max_tokens=64000,
        provider="Deepseek",
        supports_thinking=True,
        supports_web_search=True,
        supports_vision=False,
        supports_files=False,
    ),
]

# Computed automatically from DEEPSEEK_MODELS
DEEPSEEK_THINKING_MODELS = {m.id.lower() for m in DEEPSEEK_MODELS if m.supports_thinking}
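# With the catalog above, this evaluates to {"deepseek-reasoner"}.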


class OpenAIAdapter(UnifiedOpenAIAdapter):
    """OpenAI platform adapter."""

    _provider_type = "openai"

    @property
    def provider_name(self) -> str:
        return "openai"

    def list_models(self) -> List[ModelInfo]:
        return OPENAI_MODELS

    def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
        """Collect OpenAI-specific request parameters."""
        extra_params = {}
        # Web search, implemented via function calling
        web_search_mode = get_web_search_mode(request)
        if web_search_mode:
            extra_params["tools"] = [build_openai_search_tool(web_search_mode)]
            logger.info(f"[OpenAI] Web search enabled: mode={web_search_mode}")
        return extra_params
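
    # Illustrative sketch only (an assumption, not build_openai_search_tool's
    # actual output): a standard OpenAI function-calling tool spec of the form
    #
    #     {
    #         "type": "function",
    #         "function": {
    #             "name": "web_search",
    #             "description": "...",
    #             "parameters": {
    #                 "type": "object",
    #                 "properties": {"query": {"type": "string"}},
    #                 "required": ["query"],
    #             },
    #         },
    #     }
    #
    # The streaming handler below depends on the "web_search" name and the
    # "query" argument when dispatching tool calls.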

    def _stream_chat(self, client, kwargs: Dict) -> StreamingResponse:
        """
        Streaming chat, handling function-calling rounds for web search.

        The first round may stream only tool-call fragments; once the search
        tool has run, another round is requested so the model can stream the
        final answer.
        """
        logger.info("[OpenAI] Starting streaming response...")

        def generator():
            from utils.helpers import generate_unique_id, get_current_timestamp

            # Tool calls may require multiple request rounds.
            while True:
                resp = client.chat.completions.create(**kwargs)
                full_content = ""
                full_reasoning = ""
                chunk_count = 0
                tool_calls = []

                for chunk in resp:
                    if not chunk.choices:
                        continue
                    chunk_count += 1
                    delta = chunk.choices[0].delta
                    # Accumulate streamed content
                    delta_content = {}
                    if hasattr(delta, "content") and delta.content:
                        delta_content["content"] = delta.content
                        full_content += delta.content
                    if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                        delta_content["reasoning_content"] = delta.reasoning_content
                        full_reasoning += delta.reasoning_content
                    # Accumulate streamed tool_calls fragments
                    if hasattr(delta, "tool_calls") and delta.tool_calls:
                        for tool_call_chunk in delta.tool_calls:
                            idx = tool_call_chunk.index
                            while len(tool_calls) <= idx:
                                tool_calls.append({
                                    "id": "",
                                    "type": "function",
                                    "function": {"name": "", "arguments": ""},
                                })
                            if tool_call_chunk.id:
                                tool_calls[idx]["id"] += tool_call_chunk.id
                            if tool_call_chunk.type:
                                tool_calls[idx]["type"] = tool_call_chunk.type
                            if tool_call_chunk.function:
                                if tool_call_chunk.function.name:
                                    tool_calls[idx]["function"]["name"] += tool_call_chunk.function.name
                                if tool_call_chunk.function.arguments:
                                    tool_calls[idx]["function"]["arguments"] += tool_call_chunk.function.arguments
                    # Emit ordinary content (suppressed once a tool call has started)
                    if delta_content and not tool_calls:
                        data = {
                            "id": f"chatcmpl-{generate_unique_id()}",
                            "object": "chat.completion.chunk",
                            "created": get_current_timestamp(),
                            "model": kwargs["model"],
                            "choices": [{
                                "index": 0,
                                "delta": delta_content,
                                "finish_reason": None,
                            }],
                        }
                        yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                # A complete tool call arrived: execute it, then request
                # another round so the model can summarize the results
                if tool_calls:
                    logger.info(f"[OpenAI] Tool calls detected: {json.dumps(tool_calls, ensure_ascii=False)}")
                    # Append the assistant turn that requested the tools
                    assistant_msg = {
                        "role": "assistant",
                        "content": full_content or None,
                        "tool_calls": tool_calls,
                    }
                    if full_reasoning:
                        assistant_msg["reasoning_content"] = full_reasoning
                    kwargs["messages"].append(assistant_msg)
                    # Execute the search tool (web_search is the only tool registered)
                    for tc in tool_calls:
                        if tc["function"]["name"] == "web_search":
                            try:
                                args = json.loads(tc["function"]["arguments"])
                                query = args.get("query", "")
                                # Crude heuristic: "advanced" in the tool spec implies deep search
                                mode = "deep" if "advanced" in str(kwargs.get("tools", [])) else "simple"
                                logger.info(f"[OpenAI] Running search: {query}")
                                search_result = execute_tavily_search(query, mode=mode)
                            except Exception as e:
                                search_result = f"Search failed: {str(e)}"
                                logger.error(search_result)
                            kwargs["messages"].append({
                                "role": "tool",
                                "tool_call_id": tc["id"],
                                "name": "web_search",
                                "content": search_result,
                            })
                    # Start another round so the model can compose the final answer
                    continue
                # No tool calls: close out the stream
                finish = {
                    "id": f"chatcmpl-{generate_unique_id()}",
                    "object": "chat.completion.chunk",
                    "created": get_current_timestamp(),
                    "model": kwargs["model"],
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                logger.info(f"[OpenAI] Streaming response finished: chunks={chunk_count}, content_len={len(full_content)}")
                break

        return StreamingResponse(generator(), media_type="text/event-stream")


class DeepseekAdapter(UnifiedOpenAIAdapter):
    """Deepseek platform adapter."""

    _provider_type = "deepseek"

    @property
    def provider_name(self) -> str:
        return "deepseek"

    def list_models(self) -> List[ModelInfo]:
        return DEEPSEEK_MODELS

    def _supports_thinking(self, model: str) -> bool:
        """Check whether the model supports deep thinking."""
        return model.lower() in DEEPSEEK_THINKING_MODELS

    def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
        """Collect Deepseek-specific request parameters."""
        extra_params = {}
        # Deep thinking: always passed, explicitly enabled or disabled
        logger.info(f"[Deepseek] Deep thinking request: deep_thinking={request.deep_thinking}, model={request.model}")
        supports_thinking = self._supports_thinking(request.model)
        logger.info(f"[Deepseek] Model {request.model} supports deep thinking: {supports_thinking}")
        thinking_enabled = request.deep_thinking and supports_thinking
        thinking_type = "enabled" if thinking_enabled else "disabled"
        extra_params["extra_body"] = {"thinking": {"type": thinking_type}}
        logger.info(f"[Deepseek] Deep thinking final state: {thinking_type}")
        # Web search, implemented via function calling
        web_search_mode = get_web_search_mode(request)
        if web_search_mode:
            extra_params["tools"] = [build_openai_search_tool(web_search_mode)]
            logger.info(f"[Deepseek] Web search enabled: mode={web_search_mode}")
        return extra_params
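
    # For example, a deep-thinking request against deepseek-reasoner with web
    # search enabled would produce (tool spec shape elided):
    #     {"extra_body": {"thinking": {"type": "enabled"}},
    #      "tools": [<web_search tool spec>]}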

    def _stream_chat(self, client, kwargs: Dict) -> StreamingResponse:
        """Streaming chat, reusing OpenAI's tool-calling logic."""
        # Deepseek uses the same tool-call handling as OpenAI
        return OpenAIAdapter._stream_chat(self, client, kwargs)
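

# Minimal usage sketch (hypothetical wiring: construction and the upstream
# client come from UnifiedOpenAIAdapter, which is not shown in this file):
#
#     adapter = DeepseekAdapter()
#     models = adapter.list_models()      # deepseek-chat, deepseek-reasoner
#     extras = adapter._get_extra_params(request)
#     # `extras` is merged into the chat.completions.create(**kwargs) call
#     # before _stream_chat replays it with tool-calling support.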