293 lines
11 KiB
Python
293 lines
11 KiB
Python
"""
|
||
OpenAI 适配器
|
||
支持 OpenAI 及兼容 API(如 Deepseek)
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
from .base import ChatCompletionRequest, ModelInfo
|
||
from .unified_adapter import UnifiedOpenAIAdapter
|
||
from .plugins import (
|
||
get_web_search_mode,
|
||
build_openai_search_tool,
|
||
execute_tavily_search,
|
||
get_current_time_info,
|
||
)
|
||
from core import get_logger
|
||
|
||
logger = get_logger()
|
||
|
||
# OpenAI model catalogue.
# Every entry shares the same provider, has no thinking support and has web
# search enabled; only the per-model fields are spelled out in the specs below.
OPENAI_MODELS = [
    ModelInfo(
        provider="OpenAI",
        supports_thinking=False,
        supports_web_search=True,
        **spec,
    )
    for spec in (
        {
            "id": "gpt-4o",
            "name": "GPT-4o",
            "description": "最新旗舰多模态模型",
            "max_tokens": 128000,
            "supports_vision": True,
            "supports_files": True,
        },
        {
            "id": "gpt-4o-mini",
            "name": "GPT-4o Mini",
            "description": "高性价比多模态模型",
            "max_tokens": 128000,
            "supports_vision": True,
            "supports_files": True,
        },
        {
            "id": "gpt-4-turbo",
            "name": "GPT-4 Turbo",
            "description": "GPT-4 增强版",
            "max_tokens": 128000,
            "supports_vision": True,
            "supports_files": False,
        },
        {
            "id": "gpt-3.5-turbo",
            "name": "GPT-3.5 Turbo",
            "description": "快速经济的选择",
            "max_tokens": 16385,
            "supports_vision": False,
            "supports_files": False,
        },
    )
]
|
||
|
||
# Deepseek model catalogue.
# Both entries share provider, context size and the lack of vision/file
# support; thinking and web-search capabilities differ per model.
DEEPSEEK_MODELS = [
    ModelInfo(
        max_tokens=64000,
        provider="Deepseek",
        supports_vision=False,
        supports_files=False,
        **spec,
    )
    for spec in (
        {
            "id": "deepseek-chat",
            "name": "Deepseek Chat",
            "description": "Deepseek 对话模型",
            "supports_thinking": False,
            "supports_web_search": False,
        },
        {
            "id": "deepseek-reasoner",
            "name": "Deepseek Reasoner",
            "description": "Deepseek 推理模型(支持深度思考)",
            "supports_thinking": True,
            "supports_web_search": True,
        },
    )
]
|
||
|
||
# Derived automatically from DEEPSEEK_MODELS: the lower-cased ids of every
# model whose supports_thinking flag is set.
DEEPSEEK_THINKING_MODELS = {
    model.id.lower()
    for model in DEEPSEEK_MODELS
    if model.supports_thinking
}
|
||
|
||
|
||
class OpenAIAdapter(UnifiedOpenAIAdapter):
    """Adapter for the OpenAI platform.

    Extends the shared OpenAI-compatible adapter with web search implemented
    through function calling, and streams the reply as server-sent events.
    """

    _provider_type = "openai"

    # Upper bound on "model asks for a tool -> we answer -> model continues"
    # round trips inside a single streamed reply.
    _MAX_TOOL_ROUNDS = 5

    @property
    def provider_name(self) -> str:
        """Return the provider identifier string."""
        return "openai"

    def list_models(self) -> List[ModelInfo]:
        """Return the module-level OpenAI model catalogue."""
        return OPENAI_MODELS

    def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
        """Build the OpenAI-specific request parameters.

        Args:
            request: The incoming chat completion request.

        Returns:
            A dict that contains a ``tools`` entry with the search tool when
            ``get_web_search_mode`` reports a truthy mode, otherwise an empty
            dict.
        """
        extra_params: Dict[str, Any] = {}

        # Web search is exposed to the model via function calling.
        web_search_mode = get_web_search_mode(request)
        if web_search_mode:
            extra_params["tools"] = [build_openai_search_tool(web_search_mode)]
            logger.info(f"[OpenAI] 联网搜索已启用: mode={web_search_mode}")

        return extra_params

    def _stream_chat(self, client, kwargs: Dict) -> StreamingResponse:
        """Stream a chat completion as SSE, resolving web-search tool calls.

        The upstream stream is forwarded chunk by chunk. When the model emits
        tool calls, the assistant message and one ``tool`` reply per call are
        appended to the conversation and the request is re-issued so the model
        can summarise the results. Content deltas are not forwarded once a
        tool call has started in the current round.

        Args:
            client: An OpenAI-style client exposing
                ``chat.completions.create``.
            kwargs: Keyword arguments for ``chat.completions.create``; must
                contain ``model`` and ``messages``. The caller's dict and its
                message list are not modified.

        Returns:
            A ``StreamingResponse`` with media type ``text/event-stream`` that
            ends with a ``finish_reason="stop"`` chunk followed by
            ``data: [DONE]``.
        """
        logger.info("[OpenAI] 开始流式响应...")

        # Read through the class rather than ``self``: DeepseekAdapter reuses
        # this function with its own instance, which is not an OpenAIAdapter.
        max_tool_rounds = OpenAIAdapter._MAX_TOOL_ROUNDS

        def generator():
            from utils.helpers import generate_unique_id, get_current_timestamp

            # Work on a private copy so tool-call bookkeeping never leaks into
            # the caller's kwargs or message list.
            request_kwargs = dict(kwargs)
            request_kwargs["messages"] = list(kwargs["messages"])
            model = request_kwargs["model"]

            # The search depth is inferred from the tool definition that was
            # attached to the request; it cannot change between rounds.
            search_mode = (
                "deep"
                if "advanced" in str(request_kwargs.get("tools", []))
                else "simple"
            )

            # All chunks of one streamed completion share the same id and
            # creation timestamp, as in the upstream chunk format.
            stream_id = f"chatcmpl-{generate_unique_id()}"
            created = get_current_timestamp()

            def sse_event(delta, finish_reason=None):
                """Serialise one chat.completion.chunk as an SSE data line."""
                payload = {
                    "id": stream_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": delta,
                        "finish_reason": finish_reason,
                    }],
                }
                return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

            tool_rounds = 0

            # Several upstream requests may be needed when tools are called.
            while True:
                resp = client.chat.completions.create(**request_kwargs)
                full_content = ""
                full_reasoning = ""
                chunk_count = 0
                tool_calls = []

                for chunk in resp:
                    if not chunk.choices:
                        continue

                    chunk_count += 1
                    delta = chunk.choices[0].delta

                    # Collect text and reasoning deltas.
                    delta_content = {}
                    content_piece = getattr(delta, "content", None)
                    if content_piece:
                        delta_content["content"] = content_piece
                        full_content += content_piece
                    reasoning_piece = getattr(delta, "reasoning_content", None)
                    if reasoning_piece:
                        delta_content["reasoning_content"] = reasoning_piece
                        full_reasoning += reasoning_piece

                    # Tool calls arrive fragmented; merge fragments by index.
                    for tool_call_chunk in getattr(delta, "tool_calls", None) or []:
                        idx = tool_call_chunk.index
                        while len(tool_calls) <= idx:
                            tool_calls.append({
                                "id": "",
                                "type": "function",
                                "function": {"name": "", "arguments": ""},
                            })

                        merged = tool_calls[idx]
                        if tool_call_chunk.id:
                            merged["id"] += tool_call_chunk.id
                        if tool_call_chunk.type:
                            merged["type"] = tool_call_chunk.type
                        if tool_call_chunk.function:
                            if tool_call_chunk.function.name:
                                merged["function"]["name"] += tool_call_chunk.function.name
                            if tool_call_chunk.function.arguments:
                                merged["function"]["arguments"] += tool_call_chunk.function.arguments

                    # Forward ordinary content only while no tool call is pending.
                    if delta_content and not tool_calls:
                        yield sse_event(delta_content)

                if tool_calls and tool_rounds < max_tool_rounds:
                    tool_rounds += 1
                    logger.info(f"[OpenAI] 检测到工具调用: {json.dumps(tool_calls, ensure_ascii=False)}")

                    # Record the assistant turn that requested the tools.
                    assistant_msg = {
                        "role": "assistant",
                        "content": full_content or None,
                        "tool_calls": tool_calls,
                    }
                    if full_reasoning:
                        assistant_msg["reasoning_content"] = full_reasoning
                    request_kwargs["messages"].append(assistant_msg)

                    # Every tool call must be answered by a tool message,
                    # otherwise the follow-up request is rejected upstream.
                    for tc in tool_calls:
                        tool_name = tc["function"]["name"]
                        if tool_name == "web_search":
                            try:
                                args = json.loads(tc["function"]["arguments"])
                                query = args.get("query", "")
                                logger.info(f"[OpenAI] 执行搜索: {query}")
                                tool_output = execute_tavily_search(query, mode=search_mode)
                            except Exception as e:
                                # Best effort: report the failure to the model
                                # instead of aborting the stream.
                                tool_output = f"搜索失败: {str(e)}"
                                logger.error(tool_output)
                        else:
                            tool_output = f"不支持的工具调用: {tool_name}"
                            logger.error(f"[OpenAI] {tool_output}")

                        request_kwargs["messages"].append({
                            "role": "tool",
                            "tool_call_id": tc["id"],
                            "name": tool_name,
                            "content": tool_output,
                        })

                    if tool_rounds >= max_tool_rounds:
                        # Last allowed round: forbid further tool calls so the
                        # next request has to produce the final answer.
                        request_kwargs["tool_choice"] = "none"

                    # Ask the model again so it can summarise the results.
                    continue

                if tool_calls:
                    # Tool calls were emitted although they were forbidden;
                    # end the stream rather than loop forever.
                    logger.error(f"[OpenAI] 工具调用轮次超过上限({max_tool_rounds}),终止流式响应")

                # No (further) tool calls: close the stream.
                yield sse_event({}, "stop")
                yield "data: [DONE]\n\n"

                logger.info(f"[OpenAI] 流式响应完成: chunks={chunk_count}, content_len={len(full_content)}")
                break

        return StreamingResponse(generator(), media_type="text/event-stream")
|
||
|
||
|
||
class DeepseekAdapter(UnifiedOpenAIAdapter):
    """Adapter for the Deepseek platform (OpenAI-compatible API)."""

    _provider_type = "deepseek"

    @property
    def provider_name(self) -> str:
        """Return the provider identifier string."""
        return "deepseek"

    def list_models(self) -> List[ModelInfo]:
        """Return the module-level Deepseek model catalogue."""
        return DEEPSEEK_MODELS

    def _supports_thinking(self, model: str) -> bool:
        """Report whether ``model`` is listed as thinking-capable.

        The comparison is case-insensitive because the lookup set holds
        lower-cased ids.
        """
        normalized = model.lower()
        return normalized in DEEPSEEK_THINKING_MODELS

    def _get_extra_params(self, request: ChatCompletionRequest) -> Dict[str, Any]:
        """Build the Deepseek-specific request parameters.

        The thinking switch is always sent, explicitly enabled or disabled.
        It is enabled only when the request asks for deep thinking and the
        model supports it. A search tool is attached when
        ``get_web_search_mode`` reports a truthy mode.

        Args:
            request: The incoming chat completion request.

        Returns:
            A dict with an ``extra_body`` entry and, optionally, ``tools``.
        """
        logger.info(f"[Deepseek] 深度思考请求: deep_thinking={request.deep_thinking}, model={request.model}")

        model_can_think = self._supports_thinking(request.model)
        logger.info(f"[Deepseek] 模型 {request.model} 支持深度思考: {model_can_think}")

        if request.deep_thinking and model_can_think:
            thinking_state = "enabled"
        else:
            thinking_state = "disabled"

        params: Dict[str, Any] = {
            "extra_body": {"thinking": {"type": thinking_state}},
        }
        logger.info(f"[Deepseek] 深度思考最终状态: {thinking_state}")

        # Web search is exposed to the model via function calling.
        search_mode = get_web_search_mode(request)
        if not search_mode:
            return params

        params["tools"] = [build_openai_search_tool(search_mode)]
        logger.info(f"[Deepseek] 联网搜索已启用: mode={search_mode}")
        return params

    def _stream_chat(self, client, kwargs: Dict) -> StreamingResponse:
        """Stream a chat completion using OpenAIAdapter's tool-call handling."""
        # Deepseek uses the same tool-call flow, so borrow the implementation
        # directly from OpenAIAdapter (this class does not inherit from it).
        shared_stream_chat = OpenAIAdapter._stream_chat
        return shared_stream_chat(self, client, kwargs)