343 lines
12 KiB
Python
343 lines
12 KiB
Python
"""
|
||
OpenAI 适配器
|
||
支持 OpenAI 及兼容 API(如 Deepseek)
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
from typing import Dict, List, Optional
|
||
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
|
||
from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
|
||
from core import get_logger
|
||
|
||
logger = get_logger()
|
||
|
||
# OpenAI 模型配置
|
||
OPENAI_MODELS = [
|
||
ModelInfo(
|
||
id="gpt-4o",
|
||
name="GPT-4o",
|
||
description="最新旗舰多模态模型",
|
||
max_tokens=128000,
|
||
provider="OpenAI",
|
||
supports_thinking=False,
|
||
supports_web_search=True,
|
||
supports_vision=True,
|
||
supports_files=True,
|
||
),
|
||
ModelInfo(
|
||
id="gpt-4o-mini",
|
||
name="GPT-4o Mini",
|
||
description="高性价比多模态模型",
|
||
max_tokens=128000,
|
||
provider="OpenAI",
|
||
supports_thinking=False,
|
||
supports_web_search=True,
|
||
supports_vision=True,
|
||
supports_files=True,
|
||
),
|
||
ModelInfo(
|
||
id="gpt-4-turbo",
|
||
name="GPT-4 Turbo",
|
||
description="GPT-4 增强版",
|
||
max_tokens=128000,
|
||
provider="OpenAI",
|
||
supports_thinking=False,
|
||
supports_web_search=True,
|
||
supports_vision=True,
|
||
supports_files=False,
|
||
),
|
||
ModelInfo(
|
||
id="gpt-3.5-turbo",
|
||
name="GPT-3.5 Turbo",
|
||
description="快速经济的选择",
|
||
max_tokens=16385,
|
||
provider="OpenAI",
|
||
supports_thinking=False,
|
||
supports_web_search=True,
|
||
supports_vision=False,
|
||
supports_files=False,
|
||
),
|
||
]
|
||
|
||
# Deepseek 模型配置
|
||
DEEPSEEK_MODELS = [
|
||
ModelInfo(
|
||
id="deepseek-chat",
|
||
name="Deepseek Chat",
|
||
description="Deepseek 对话模型",
|
||
max_tokens=64000,
|
||
provider="Deepseek",
|
||
supports_thinking=False,
|
||
supports_web_search=False,
|
||
supports_vision=False,
|
||
supports_files=False,
|
||
),
|
||
ModelInfo(
|
||
id="deepseek-reasoner",
|
||
name="Deepseek Reasoner",
|
||
description="Deepseek 推理模型(支持深度思考)",
|
||
max_tokens=64000,
|
||
provider="Deepseek",
|
||
supports_thinking=True,
|
||
supports_web_search=True, # 注:通过内置检索增强实现
|
||
supports_vision=False,
|
||
supports_files=False,
|
||
),
|
||
]
|
||
|
||
# DeepSeek 支持深度思考的模型
|
||
DEEPSEEK_THINKING_MODELS = {"deepseek-reasoner"}
|
||
|
||
|
||
class OpenAIAdapter(BaseAdapter):
|
||
"""OpenAI 平台适配器"""
|
||
|
||
_client = None
|
||
_provider_type: str = "openai" # openai 或 deepseek
|
||
|
||
def __init__(self, provider_type: str = "openai"):
|
||
self._provider_type = provider_type
|
||
|
||
@property
|
||
def provider_name(self) -> str:
|
||
return self._provider_type
|
||
|
||
def is_available(self) -> bool:
|
||
"""检查 API Key 是否配置"""
|
||
if self._provider_type == "deepseek":
|
||
return bool(os.getenv("DEEPSEEK_API_KEY"))
|
||
return bool(os.getenv("OPENAI_API_KEY"))
|
||
|
||
def _get_client(self):
|
||
"""获取 OpenAI 客户端(懒加载)"""
|
||
if self._client is None:
|
||
from openai import OpenAI
|
||
|
||
if self._provider_type == "deepseek":
|
||
api_key = os.getenv("DEEPSEEK_API_KEY", "")
|
||
base_url = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
|
||
else:
|
||
api_key = os.getenv("OPENAI_API_KEY", "")
|
||
base_url = os.getenv("OPENAI_BASE_URL") # 可选自定义端点
|
||
|
||
kwargs = {"api_key": api_key}
|
||
if base_url:
|
||
kwargs["base_url"] = base_url
|
||
|
||
self._client = OpenAI(**kwargs)
|
||
return self._client
|
||
|
||
def list_models(self) -> List[ModelInfo]:
|
||
if self._provider_type == "deepseek":
|
||
return DEEPSEEK_MODELS
|
||
return OPENAI_MODELS
|
||
|
||
async def chat(self, request: ChatCompletionRequest):
|
||
"""
|
||
处理 OpenAI 聊天请求
|
||
直接使用 OpenAI SDK,支持流式/非流式
|
||
"""
|
||
client = self._get_client()
|
||
|
||
# 打印请求参数
|
||
provider_name = self._provider_type.upper()
|
||
logger.info(f"[{provider_name}] 请求参数:")
|
||
logger.info(f" - model: {request.model}")
|
||
logger.info(f" - stream: {request.stream}")
|
||
logger.info(f" - temperature: {request.temperature}")
|
||
logger.info(f" - max_tokens: {request.max_tokens}")
|
||
logger.info(f" - provider_type: {self._provider_type}")
|
||
if self._provider_type == "deepseek":
|
||
logger.info(f" - deep_thinking: {request.deep_thinking}")
|
||
|
||
# 构建消息
|
||
messages = self._build_messages(request)
|
||
logger.info(
|
||
f" - messages: {json.dumps(messages, ensure_ascii=False, indent=2)}"
|
||
)
|
||
|
||
# 构建请求参数
|
||
kwargs = {
|
||
"model": request.model,
|
||
"messages": messages,
|
||
"temperature": request.temperature,
|
||
"max_tokens": request.max_tokens,
|
||
"stream": request.stream,
|
||
}
|
||
|
||
# DeepSeek 深度思考支持
|
||
extra_body = None
|
||
if self._provider_type == "deepseek" and request.deep_thinking:
|
||
if self._supports_thinking(request.model):
|
||
extra_body = {"thinking": {"type": "enabled"}}
|
||
kwargs["extra_body"] = extra_body
|
||
logger.info(
|
||
f"[{provider_name}] 深度思考已启用: extra_body = {extra_body}"
|
||
)
|
||
|
||
if request.stream:
|
||
return self._stream_chat(client, kwargs, extra_body)
|
||
else:
|
||
return self._sync_chat(client, kwargs, extra_body)
|
||
|
||
def _supports_thinking(self, model: str) -> bool:
|
||
"""检查模型是否支持深度思考"""
|
||
return model.lower() in DEEPSEEK_THINKING_MODELS
|
||
|
||
def _build_messages(self, request: ChatCompletionRequest) -> List[Dict]:
|
||
"""构建 OpenAI 格式消息"""
|
||
messages = []
|
||
|
||
for msg in request.messages:
|
||
role = msg.get("role", "user")
|
||
content = msg.get("content", "")
|
||
|
||
# OpenAI 直接支持标准格式
|
||
if isinstance(content, str):
|
||
if content.strip():
|
||
messages.append({"role": role, "content": content})
|
||
elif isinstance(content, list):
|
||
# 多模态内容
|
||
openai_content = []
|
||
for item in content:
|
||
if isinstance(item, dict):
|
||
openai_content.append(item)
|
||
if openai_content:
|
||
messages.append({"role": role, "content": openai_content})
|
||
|
||
return messages
|
||
|
||
def _stream_chat(
|
||
self, client, kwargs: Dict, extra_body: Optional[Dict] = None
|
||
) -> StreamingResponse:
|
||
"""流式聊天"""
|
||
provider_name = self._provider_type.upper()
|
||
logger.info(f"[{provider_name}] 开始流式响应...")
|
||
|
||
def generator():
|
||
from utils.helpers import generate_unique_id, get_current_timestamp
|
||
|
||
resp = client.chat.completions.create(**kwargs)
|
||
|
||
full_content = ""
|
||
full_reasoning = ""
|
||
chunk_count = 0
|
||
for chunk in resp:
|
||
if chunk.choices:
|
||
chunk_count += 1
|
||
delta = chunk.choices[0].delta
|
||
|
||
delta_content = {}
|
||
if hasattr(delta, "content") and delta.content:
|
||
delta_content["content"] = delta.content
|
||
full_content += delta.content
|
||
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
|
||
delta_content["reasoning_content"] = delta.reasoning_content
|
||
full_reasoning += delta.reasoning_content
|
||
|
||
if delta_content:
|
||
data = {
|
||
"id": f"chatcmpl-{generate_unique_id()}",
|
||
"object": "chat.completion.chunk",
|
||
"created": get_current_timestamp(),
|
||
"model": kwargs["model"],
|
||
"choices": [
|
||
{
|
||
"index": 0,
|
||
"delta": delta_content,
|
||
"finish_reason": None,
|
||
}
|
||
],
|
||
}
|
||
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||
|
||
finish = {
|
||
"id": f"chatcmpl-{generate_unique_id()}",
|
||
"object": "chat.completion.chunk",
|
||
"created": get_current_timestamp(),
|
||
"model": kwargs["model"],
|
||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||
}
|
||
yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
|
||
yield "data: [DONE]\n\n"
|
||
|
||
# 打印流式响应结果
|
||
logger.info(f"[{provider_name}] 流式响应完成:")
|
||
logger.info(f" - chunks: {chunk_count}")
|
||
logger.info(f" - content_length: {len(full_content)} 字符")
|
||
if full_reasoning:
|
||
logger.info(f" - reasoning_length: {len(full_reasoning)} 字符")
|
||
logger.info(
|
||
f" - content_preview: {full_content[:200]}..."
|
||
if len(full_content) > 200
|
||
else f" - content: {full_content}"
|
||
)
|
||
|
||
return StreamingResponse(generator(), media_type="text/event-stream")
|
||
|
||
def _sync_chat(
|
||
self, client, kwargs: Dict, extra_body: Optional[Dict] = None
|
||
) -> JSONResponse:
|
||
"""非流式聊天"""
|
||
from utils.helpers import generate_unique_id, get_current_timestamp
|
||
|
||
resp = client.chat.completions.create(**kwargs)
|
||
|
||
message = resp.choices[0].message
|
||
content = message.content or ""
|
||
response = {
|
||
"id": f"chatcmpl-{generate_unique_id()}",
|
||
"object": "chat.completion",
|
||
"created": get_current_timestamp(),
|
||
"model": kwargs["model"],
|
||
"choices": [
|
||
{
|
||
"index": 0,
|
||
"message": {
|
||
"role": message.role,
|
||
"content": content,
|
||
},
|
||
"finish_reason": resp.choices[0].finish_reason,
|
||
}
|
||
],
|
||
}
|
||
|
||
# 添加推理内容(如有)
|
||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||
response["choices"][0]["message"][
|
||
"reasoning_content"
|
||
] = message.reasoning_content
|
||
|
||
if resp.usage:
|
||
response["usage"] = {
|
||
"prompt_tokens": resp.usage.prompt_tokens,
|
||
"completion_tokens": resp.usage.completion_tokens,
|
||
"total_tokens": resp.usage.total_tokens,
|
||
}
|
||
|
||
# 打印响应结果
|
||
provider_name = self._provider_type.upper()
|
||
logger.info(f"[{provider_name}] 响应结果:")
|
||
logger.info(f" - content_length: {len(content)} 字符")
|
||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||
logger.info(f" - reasoning_length: {len(message.reasoning_content)} 字符")
|
||
logger.info(
|
||
f" - content_preview: {content[:200]}..."
|
||
if len(content) > 200
|
||
else f" - content: {content}"
|
||
)
|
||
if resp.usage:
|
||
logger.info(f" - usage: {response['usage']}")
|
||
|
||
return JSONResponse(content=response)
|
||
|
||
|
||
class DeepseekAdapter(OpenAIAdapter):
    """Adapter for the DeepSeek platform.

    A thin specialization of :class:`OpenAIAdapter`: DeepSeek exposes an
    OpenAI-compatible API, so only the provider type differs.
    """

    def __init__(self):
        super().__init__("deepseek")