119 lines
3.4 KiB
Python
119 lines
3.4 KiB
Python
"""
|
|
OpenAI 兼容 API 网关
|
|
提供统一的 /v1/chat/completions 和 /v1/models 端点
|
|
"""
|
|
|
|
from typing import Any, Dict
|
|
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from adapters import get_adapter, get_provider_from_model
|
|
from adapters.base import ChatCompletionRequest
|
|
from core import get_logger
|
|
|
|
logger = get_logger()
|
|
|
|
router = APIRouter(tags=["OpenAI Compatible API"])
|
|
|
|
|
|
@router.post("/chat/completions")
|
|
async def chat_completions(request: Request):
|
|
"""
|
|
OpenAI 兼容的聊天补全接口
|
|
|
|
根据请求中的 model 字段自动路由到对应的平台适配器:
|
|
- glm-* → 智谱 GLM
|
|
- qwen-* → 阿里云百炼
|
|
- gpt-* / o1-* / o3-* → OpenAI
|
|
- deepseek-* → Deepseek
|
|
"""
|
|
try:
|
|
body = await request.json()
|
|
except Exception:
|
|
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
|
|
|
# 创建请求对象
|
|
chat_request = ChatCompletionRequest.from_dict(body)
|
|
model = chat_request.model
|
|
|
|
# 根据模型名称确定平台
|
|
provider = get_provider_from_model(model)
|
|
logger.info(f"[Gateway] model={model} → provider={provider}")
|
|
|
|
# 获取对应平台的适配器
|
|
adapter = get_adapter(provider)
|
|
if adapter is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Unsupported model: {model} (provider: {provider})",
|
|
)
|
|
|
|
# 检查适配器是否可用
|
|
if not adapter.is_available():
|
|
raise HTTPException(
|
|
status_code=503,
|
|
detail=f"Provider '{provider}' is not available (API key not configured)",
|
|
)
|
|
|
|
# 调用适配器处理请求
|
|
return await adapter.chat(chat_request)
|
|
|
|
|
|
@router.get("/models")
|
|
async def list_models():
|
|
"""
|
|
返回所有可用平台的模型列表
|
|
|
|
聚合所有已配置 API Key 的平台模型
|
|
"""
|
|
from adapters import get_all_adapters
|
|
|
|
all_models = []
|
|
|
|
for provider, adapter in get_all_adapters().items():
|
|
if adapter.is_available():
|
|
models = adapter.list_models()
|
|
all_models.extend([m.to_dict() for m in models])
|
|
|
|
return {
|
|
"object": "list",
|
|
"data": all_models,
|
|
}
|
|
|
|
|
|
@router.get("/models/{model_id}")
|
|
async def get_model(model_id: str):
|
|
"""
|
|
获取特定模型信息
|
|
"""
|
|
from adapters import get_all_adapters
|
|
|
|
for provider, adapter in get_all_adapters().items():
|
|
if adapter.is_available():
|
|
for model in adapter.list_models():
|
|
if model.id == model_id:
|
|
return {
|
|
"object": "model",
|
|
"id": model.id,
|
|
"owned_by": model.provider,
|
|
"data": model.to_dict(),
|
|
}
|
|
|
|
raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
|
|
|
|
|
|
# Adapter registration, to be run once at application startup.
def init_adapters():
    """Register every platform adapter with the adapter registry."""
    from adapters import register_adapter
    from adapters.dashscope_adapter import DashScopeAdapter
    from adapters.glm_adapter import GLMAdapter
    from adapters.openai_adapter import DeepseekAdapter, OpenAIAdapter

    # Table-driven registration keeps the provider-key → class mapping in
    # one place.
    registry = (
        ("glm", GLMAdapter),
        ("dashscope", DashScopeAdapter),
        ("openai", OpenAIAdapter),
        ("deepseek", DeepseekAdapter),
    )
    for provider_key, adapter_cls in registry:
        register_adapter(provider_key, adapter_cls)

    logger.info("[Gateway] Adapters registered: glm, dashscope, openai, deepseek")