ai-chat-ui/server/adapters/openai_adapter.py

"""
OpenAI 适配器
支持 OpenAI 及兼容 API如 Deepseek
"""
import json
import os
from typing import Dict, List, Optional
from fastapi.responses import JSONResponse, StreamingResponse
from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
from utils.logger import get_logger
logger = get_logger()
# OpenAI model catalog
OPENAI_MODELS = [
ModelInfo(
id="gpt-4o",
name="GPT-4o",
description="最新旗舰多模态模型",
max_tokens=128000,
provider="OpenAI",
supports_thinking=False,
supports_web_search=True,
supports_vision=True,
supports_files=True,
),
ModelInfo(
id="gpt-4o-mini",
name="GPT-4o Mini",
description="高性价比多模态模型",
max_tokens=128000,
provider="OpenAI",
supports_thinking=False,
supports_web_search=True,
supports_vision=True,
supports_files=True,
),
ModelInfo(
id="gpt-4-turbo",
name="GPT-4 Turbo",
description="GPT-4 增强版",
max_tokens=128000,
provider="OpenAI",
supports_thinking=False,
supports_web_search=True,
supports_vision=True,
supports_files=False,
),
ModelInfo(
id="gpt-3.5-turbo",
name="GPT-3.5 Turbo",
description="快速经济的选择",
max_tokens=16385,
provider="OpenAI",
supports_thinking=False,
supports_web_search=True,
supports_vision=False,
supports_files=False,
),
]
# Deepseek model catalog
DEEPSEEK_MODELS = [
ModelInfo(
id="deepseek-chat",
name="Deepseek Chat",
description="Deepseek 对话模型",
max_tokens=64000,
provider="Deepseek",
supports_thinking=False,
supports_web_search=False,
supports_vision=False,
supports_files=False,
),
ModelInfo(
id="deepseek-reasoner",
name="Deepseek Reasoner",
description="Deepseek 推理模型(支持深度思考)",
max_tokens=64000,
provider="Deepseek",
supports_thinking=True,
        supports_web_search=True,  # Note: provided via built-in retrieval augmentation
supports_vision=False,
supports_files=False,
),
]
# DeepSeek models that support deep thinking
DEEPSEEK_THINKING_MODELS = {"deepseek-reasoner"}
class OpenAIAdapter(BaseAdapter):
"""OpenAI 平台适配器"""
_client = None
_provider_type: str = "openai" # openai 或 deepseek
def __init__(self, provider_type: str = "openai"):
self._provider_type = provider_type
@property
def provider_name(self) -> str:
return self._provider_type
def is_available(self) -> bool:
"""检查 API Key 是否配置"""
if self._provider_type == "deepseek":
return bool(os.getenv("DEEPSEEK_API_KEY"))
return bool(os.getenv("OPENAI_API_KEY"))
def _get_client(self):
"""获取 OpenAI 客户端(懒加载)"""
if self._client is None:
from openai import OpenAI
if self._provider_type == "deepseek":
api_key = os.getenv("DEEPSEEK_API_KEY", "")
base_url = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
else:
api_key = os.getenv("OPENAI_API_KEY", "")
base_url = os.getenv("OPENAI_BASE_URL") # 可选自定义端点
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
self._client = OpenAI(**kwargs)
return self._client
def list_models(self) -> List[ModelInfo]:
if self._provider_type == "deepseek":
return DEEPSEEK_MODELS
return OPENAI_MODELS
async def chat(self, request: ChatCompletionRequest):
"""
处理 OpenAI 聊天请求
直接使用 OpenAI SDK支持流式/非流式
"""
client = self._get_client()
        # Log request parameters
provider_name = self._provider_type.upper()
logger.info(f"[{provider_name}] 请求参数:")
logger.info(f" - model: {request.model}")
logger.info(f" - stream: {request.stream}")
logger.info(f" - temperature: {request.temperature}")
logger.info(f" - max_tokens: {request.max_tokens}")
logger.info(f" - provider_type: {self._provider_type}")
if self._provider_type == "deepseek":
logger.info(f" - deep_thinking: {request.deep_thinking}")
        # Build the message list
messages = self._build_messages(request)
logger.info(
f" - messages: {json.dumps(messages, ensure_ascii=False, indent=2)}"
)
        # Build request arguments
kwargs = {
"model": request.model,
"messages": messages,
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"stream": request.stream,
}
        # DeepSeek deep-thinking support
extra_body = None
if self._provider_type == "deepseek" and request.deep_thinking:
if self._supports_thinking(request.model):
extra_body = {"thinking": {"type": "enabled"}}
kwargs["extra_body"] = extra_body
logger.info(
f"[{provider_name}] 深度思考已启用: extra_body = {extra_body}"
)
if request.stream:
return self._stream_chat(client, kwargs, extra_body)
else:
return self._sync_chat(client, kwargs, extra_body)
def _supports_thinking(self, model: str) -> bool:
"""检查模型是否支持深度思考"""
return model.lower() in DEEPSEEK_THINKING_MODELS
def _build_messages(self, request: ChatCompletionRequest) -> List[Dict]:
"""构建 OpenAI 格式消息"""
messages = []
for msg in request.messages:
role = msg.get("role", "user")
content = msg.get("content", "")
            # OpenAI accepts the standard format directly
if isinstance(content, str):
if content.strip():
messages.append({"role": role, "content": content})
elif isinstance(content, list):
                # Multimodal content
openai_content = []
for item in content:
if isinstance(item, dict):
openai_content.append(item)
if openai_content:
messages.append({"role": role, "content": openai_content})
return messages
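    # Illustrative message shapes accepted by _build_messages (a sketch; the OpenAI-style
    # multimodal items shown below are an assumption about what the client sends, since
    # list items are forwarded to the API unchanged):
    #   {"role": "user", "content": "Hello"}
    #   {"role": "user", "content": [
    #       {"type": "text", "text": "Describe this image"},
    #       {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    #   ]}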
def _stream_chat(
self, client, kwargs: Dict, extra_body: Optional[Dict] = None
) -> StreamingResponse:
"""流式聊天"""
provider_name = self._provider_type.upper()
logger.info(f"[{provider_name}] 开始流式响应...")
def generator():
from utils.helpers import generate_unique_id, get_current_timestamp
resp = client.chat.completions.create(**kwargs)
full_content = ""
full_reasoning = ""
chunk_count = 0
for chunk in resp:
if chunk.choices:
chunk_count += 1
delta = chunk.choices[0].delta
delta_content = {}
if hasattr(delta, "content") and delta.content:
delta_content["content"] = delta.content
full_content += delta.content
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
delta_content["reasoning_content"] = delta.reasoning_content
full_reasoning += delta.reasoning_content
if delta_content:
data = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion.chunk",
"created": get_current_timestamp(),
"model": kwargs["model"],
"choices": [
{
"index": 0,
"delta": delta_content,
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
finish = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion.chunk",
"created": get_current_timestamp(),
"model": kwargs["model"],
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
            # Log the streaming response summary
            logger.info(f"[{provider_name}] Streaming response finished:")
            logger.info(f" - chunks: {chunk_count}")
            logger.info(f" - content_length: {len(full_content)} chars")
            if full_reasoning:
                logger.info(f" - reasoning_length: {len(full_reasoning)} chars")
logger.info(
f" - content_preview: {full_content[:200]}..."
if len(full_content) > 200
else f" - content: {full_content}"
)
return StreamingResponse(generator(), media_type="text/event-stream")
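    # Shape of the SSE events emitted by the generator above (illustrative values):
    #   data: {"id": "chatcmpl-<uid>", "object": "chat.completion.chunk", "created": <ts>,
    #          "model": "<model>", "choices": [{"index": 0,
    #          "delta": {"content": "...", "reasoning_content": "..."}, "finish_reason": null}]}
    # A final chunk with an empty delta and finish_reason "stop" is sent, then "data: [DONE]".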
def _sync_chat(
self, client, kwargs: Dict, extra_body: Optional[Dict] = None
) -> JSONResponse:
"""非流式聊天"""
from utils.helpers import generate_unique_id, get_current_timestamp
resp = client.chat.completions.create(**kwargs)
message = resp.choices[0].message
content = message.content or ""
response = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion",
"created": get_current_timestamp(),
"model": kwargs["model"],
"choices": [
{
"index": 0,
"message": {
"role": message.role,
"content": content,
},
"finish_reason": resp.choices[0].finish_reason,
}
],
}
        # Attach reasoning content if present
if hasattr(message, "reasoning_content") and message.reasoning_content:
response["choices"][0]["message"][
"reasoning_content"
] = message.reasoning_content
if resp.usage:
response["usage"] = {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens,
}
        # Log the response summary
        provider_name = self._provider_type.upper()
        logger.info(f"[{provider_name}] Response:")
        logger.info(f" - content_length: {len(content)} chars")
        if hasattr(message, "reasoning_content") and message.reasoning_content:
            logger.info(f" - reasoning_length: {len(message.reasoning_content)} chars")
logger.info(
f" - content_preview: {content[:200]}..."
if len(content) > 200
else f" - content: {content}"
)
if resp.usage:
logger.info(f" - usage: {response['usage']}")
return JSONResponse(content=response)
class DeepseekAdapter(OpenAIAdapter):
"""Deepseek 平台适配器(继承 OpenAI 适配器)"""
def __init__(self):
super().__init__(provider_type="deepseek")