ai-chat-ui/server/main.py

246 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Chat API Server — 主入口(纯基础设施层)
职责:
- 注入运行时依赖(venv site-packages)
- 支持 OpenAI 兼容 API 网关(/v1/*)和多平台路由
- 保留向后兼容的 /api/chat-ui/* 路由
平台适配器位置:
- adapters/glm_adapter.py → 智谱 GLM
- adapters/dashscope_adapter.py → 阿里云百炼
- adapters/openai_adapter.py → OpenAI / DeepSeek
API 端点:
- POST /v1/chat/completions → OpenAI 兼容网关(根据 model 自动路由)
- GET /v1/models → 所有可用模型列表
- POST /api/chat-ui/chat → 传统聊天接口(保持兼容)
- GET /api/chat-ui/models → 传统模型列表(聚合所有可用平台的模型)
- GET /health → 健康检查
"""
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
# ── Inject the venv's site-packages (lets start.sh launch with the system python3) ──
# Must run before every third-party import below.
def _find_venv_site_packages(lib_dir):
    """Return the site-packages directory under *lib_dir* to inject, or None.

    *lib_dir* is a venv ``lib`` directory holding ``python3.X/site-packages``
    sub-directories. The directory matching the running interpreter's
    ``major.minor`` version is preferred, because packages with compiled
    extensions only load under the interpreter they were built for. Otherwise
    the highest version wins, compared numerically so that ``python3.12``
    ranks above ``python3.9`` (a plain string sort gets this backwards).

    Returns None when *lib_dir* is missing or contains no site-packages dir.
    """
    if not lib_dir.is_dir():
        return None
    candidates = [p for p in lib_dir.glob("python*/site-packages") if p.is_dir()]
    if not candidates:
        return None

    running = f"python{sys.version_info.major}.{sys.version_info.minor}"
    for candidate in candidates:
        if candidate.parent.name == running:
            return candidate

    def _version_key(site_packages):
        # "python3.12" -> (3, 12); non-numeric components (e.g. "12t") rank lowest.
        suffix = site_packages.parent.name[len("python"):]
        return tuple(int(part) if part.isdigit() else -1 for part in suffix.split("."))

    return max(candidates, key=_version_key)


_venv_lib = Path(__file__).parent / ".venv" / "lib"
_sp = _find_venv_site_packages(_venv_lib)
# Inject only the single best match; if it is already importable, do nothing
# rather than falling back to a site-packages built for another Python version.
if _sp is not None and str(_sp) not in sys.path:
    sys.path.insert(0, str(_sp))
    print(f"[启动] venv 注入:{_sp}")
# ── 第三方导入 ────────────────────────────────────────────────────────
from dotenv import load_dotenv
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse
# Make this server directory importable for the local packages imported below
# (utils, api, adapters, config) regardless of the working directory. Derived
# from this file's location rather than a hard-coded developer path, so the
# project runs from any checkout.
_server_dir = str(Path(__file__).resolve().parent)
if _server_dir not in sys.path:
    sys.path.append(_server_dir)
# ── 工具/日志(与平台无关)───────────────────────────────────────────
from utils.helpers import log_response
from utils.logger import get_logger
logger = get_logger()

# ── Environment variables ─────────────────────────────────────────────
load_dotenv()

# Backends that have a dedicated platform route module.
_SUPPORTED_BACKENDS = ("dashscope", "glm")

LLM_BACKEND = os.getenv("LLM_BACKEND", "dashscope").lower().strip()
if LLM_BACKEND not in _SUPPORTED_BACKENDS:
    logger.warning(f"未知的 LLM_BACKEND='{LLM_BACKEND}',回退到 dashscope")
    LLM_BACKEND = "dashscope"

# ── Load the platform route module selected by LLM_BACKEND ────────────
if LLM_BACKEND != "glm":
    import api.chat_routes as _platform
else:
    import api.chat_routes_glm as _platform
_platform.init()  # each platform does its own initialization (API key checks, etc.)
# Platform-independent handlers (file upload, conversation management, ...).
# Every backend shares them, and they always come from the DashScope route
# module, api.chat_routes.
from api.chat_routes import (
    delete_conversation_handler,
    get_conversation_handler,
    get_conversations_handler,
    save_conversation_handler,
    serve_upload_handler,
    stop_generation_handler,
    upload_file_handler,
)

# ── OpenAI-compatible gateway setup ───────────────────────────────────
from api.openai_gateway import init_adapters
from api.openai_gateway import router as openai_router

init_adapters()  # initialize the gateway's adapters before the app mounts its router
# ── FastAPI application ───────────────────────────────────────────────
_app_metadata = {
    "title": "AI Chat API Gateway",
    "version": "4.0.0",
    "description": "统一 OpenAI 兼容 API 网关,支持多平台模型",
}
app = FastAPI(**_app_metadata)

# Mount the OpenAI-compatible routes (/v1/chat/completions, /v1/models).
app.include_router(openai_router, prefix="/v1")
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    """Log each HTTP request and its outcome; add an ``X-Process-Time`` header.

    On arrival, logs the method, path and client IP. After the downstream
    handler returns, logs a success/error marker, the status code and the
    elapsed time, reports them via ``log_response`` and sets
    ``X-Process-Time`` (milliseconds, two decimals) on the response.

    Elapsed time comes from the monotonic ``time.perf_counter`` clock, so
    wall-clock adjustments cannot produce negative or skewed durations. If the
    downstream handler raises, the failure is logged with its elapsed time and
    the exception is re-raised unchanged.
    """
    started = time.perf_counter()
    client_ip = request.client.host if request.client else "unknown"
    logger.info(f"{request.method} {request.url.path} | IP: {client_ip}")

    try:
        response = await call_next(request)
    except Exception as exc:
        ms = (time.perf_counter() - started) * 1000
        logger.error(
            f"ERR {request.method} {request.url.path} | 异常: {type(exc).__name__} | 耗时: {ms:.0f}ms"
        )
        raise

    ms = (time.perf_counter() - started) * 1000
    # Distinct ASCII markers so success and error lines are easy to tell apart.
    icon = "OK" if response.status_code < 400 else "ERR"
    logger.info(
        f"{icon} {request.method} {request.url.path} | 状态: {response.status_code} | 耗时: {ms:.0f}ms"
    )
    log_response(response.status_code, ms)
    response.headers["X-Process-Time"] = f"{ms:.2f}ms"
    return response
# ── 路由注册 ──────────────────────────────────────────────────────────
@app.get("/health")
async def health_check():
    """Return service status plus gateway metadata.

    The payload reports the version, the default backend, the providers
    returned by ``config.get_available_providers``, the main endpoint paths
    and a UTC ISO-8601 timestamp.
    """
    from config import get_available_providers

    endpoint_paths = {
        "openai_compatible": "/v1/chat/completions",
        "legacy": "/api/chat-ui/chat",
        "models": "/v1/models",
    }
    payload = {"status": "healthy", "version": "4.0.0"}
    payload["default_backend"] = LLM_BACKEND
    payload["available_providers"] = get_available_providers()
    payload["endpoints"] = endpoint_paths
    payload["timestamp"] = datetime.now(timezone.utc).isoformat()
    return payload
@app.post("/api/chat-ui/chat")
async def chat_endpoint(request: Request):
    """Legacy chat endpoint: route the request to a platform adapter by model.

    The provider is derived from the request's ``model``, and the matching
    adapter handles the chat.

    Returns:
        400 if the body is not valid JSON, is not a JSON object, cannot be
        parsed into a ``ChatCompletionRequest``, or names a model whose
        provider has no adapter; 503 if that provider's adapter reports it is
        not available (API key not configured); otherwise whatever
        ``adapter.chat`` returns.
    """
    from adapters import get_adapter, get_provider_from_model
    from adapters.base import ChatCompletionRequest

    try:
        body = await request.json()
    except Exception:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    # Valid JSON need not be an object (a list, string or number is also valid
    # JSON); reject such bodies as client errors instead of a 500 from from_dict.
    if not isinstance(body, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)

    # Malformed fields are the client's fault, so answer 400 rather than 500.
    # NOTE(review): assumes from_dict reports bad input via KeyError/TypeError/ValueError — confirm.
    try:
        chat_request = ChatCompletionRequest.from_dict(body)
    except (KeyError, TypeError, ValueError) as exc:
        logger.warning(f"[Legacy API] invalid request body: {exc!r}")
        return JSONResponse({"error": f"Invalid request body: {exc}"}, status_code=400)
    model = chat_request.model

    # Map the model name to its platform.
    provider = get_provider_from_model(model)
    logger.info(f"[Legacy API] model={model} → provider={provider}")

    adapter = get_adapter(provider)
    if adapter is None:
        return JSONResponse(
            {"error": f"Unsupported model: {model} (provider: {provider})"},
            status_code=400,
        )
    if not adapter.is_available():
        return JSONResponse(
            {"error": f"Provider '{provider}' is not available (API key not configured)"},
            status_code=503,
        )
    return await adapter.chat(chat_request)
@app.get("/api/chat-ui/models")
async def get_models():
    """List the models of every provider whose adapter reports itself available.

    Returns an OpenAI-style list object: ``{"object": "list", "data": [...]}``.
    """
    from adapters import get_all_adapters

    collected = [
        model.to_dict()
        for _provider, adapter in get_all_adapters().items()
        if adapter.is_available()
        for model in adapter.list_models()
    ]
    return {"object": "list", "data": collected}
# ── 通用路由(与平台无关)────────────────────────────────────────────
@app.get("/api/chat-ui/conversations")
async def get_conversations():
    """List saved conversations via the shared conversation handler."""
    conversations = await get_conversations_handler()
    return conversations
@app.get("/api/chat-ui/conversations/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Fetch one saved conversation by its id via the shared handler."""
    conversation = await get_conversation_handler(conversation_id)
    return conversation
@app.post("/api/chat-ui/conversations")
async def save_conversation(request: Request):
    """Save a conversation taken from the JSON request body.

    Returns 400 with ``{"error": "Invalid JSON body"}`` when the body is not
    valid JSON, the same response the legacy chat endpoint gives, instead of
    failing with a 500. Otherwise returns whatever ``save_conversation_handler``
    returns.
    """
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    return await save_conversation_handler(payload)
@app.delete("/api/chat-ui/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Delete the conversation identified by ``conversation_id`` via the shared handler."""
    outcome = await delete_conversation_handler(conversation_id)
    return outcome
@app.post("/api/chat-ui/upload")
async def upload_file(file: UploadFile = File(...)):
    """Accept a multipart upload (required form field ``file``) and pass it to the shared handler."""
    stored = await upload_file_handler(file=file)
    return stored
@app.get("/uploads/{filename}")
async def serve_upload(filename: str):
    """Serve a previously uploaded file by name.

    The shared handler is invoked synchronously; its result is returned as-is
    (not awaited).
    """
    served = serve_upload_handler(filename)
    return served
@app.post("/api/chat-ui/stop")
async def stop_generation():
    """Ask the shared handler to stop generation; no message id is passed."""
    result = await stop_generation_handler()
    return result
@app.post("/api/chat-ui/stop/{message_id}")
async def stop_generation_by_id(message_id: str):
    """Ask the shared handler to stop generation for ``message_id``."""
    result = await stop_generation_handler(message_id)
    return result
# ── 程序入口 ──────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Listening port from PORT (default 8000); exit with a readable message
    # instead of a traceback when the value is not an integer.
    try:
        port = int(os.getenv("PORT", 8000))
    except ValueError:
        sys.exit(f"[启动] PORT 必须是整数,当前值: {os.getenv('PORT')!r}")
    # Bind address from HOST; defaults to all interfaces as before.
    host = os.getenv("HOST", "0.0.0.0")

    # Providers reported as available by config, shown in the banner.
    from config import get_available_providers

    available = get_available_providers()
    print("=" * 60)
    print(" AI Chat API Gateway v4.0")
    print("=" * 60)
    print(f" OpenAI 兼容端点: http://localhost:{port}/v1/chat/completions")
    print(f" 模型列表 : http://localhost:{port}/v1/models")
    print("-" * 60)
    print(f" 可用平台 : {', '.join(available) or '无(请配置 API Key)'}")
    print(f" 默认平台 : {LLM_BACKEND} (向后兼容模式)")
    print("-" * 60)
    print(" 使用方法:")
    # Use the actual port so the example works when PORT is overridden.
    print(f" curl -X POST http://localhost:{port}/v1/chat/completions \\")
    print(' -H "Content-Type: application/json" \\')
    print(' -d \'{"model":"glm-4-flash","messages":[{"role":"user","content":"hi"}]}\'')
    print("=" * 60)
    uvicorn.run(app, host=host, port=port)