343 lines
11 KiB
Python
343 lines
11 KiB
Python
"""
|
||
Python FastAPI 服务器实现,用于替代 Node.js 服务器
|
||
使用 DashScope Python SDK 连接阿里云百炼平台 API
|
||
"""
|
||
import asyncio
import json
import os
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import dashscope
import uvicorn
from dashscope import Generation
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
|
||
|
||
# Pull variables from a local .env file into the process environment first,
# so the os.environ lookup below can see them.
load_dotenv()

# The DashScope SDK reads its credential from this module attribute; the value
# comes from the ALIYUN_API_KEY environment variable (None when it is unset).
dashscope.api_key = os.environ.get("ALIYUN_API_KEY")

# The FastAPI application that every route and middleware below registers on.
app = FastAPI(
    title="AI Chat API Server",
    version="1.0.0",
)
|
||
|
||
# 数据模型定义
|
||
class ChatMessage(BaseModel):
    """One chat message: a role, its text, and optional image/file references.

    NOTE(review): not referenced by any endpoint visible in this file.
    """

    role: str  # presumably "system" / "user" / "assistant" — TODO confirm with the frontend
    content: str  # message text
    images: Optional[List[str]] = None  # optional image references (format not validated)
    files: Optional[List[str]] = None  # optional file references (format not validated)
|
||
|
||
class ChatRequest(BaseModel):
    """JSON request body for the chat endpoints (non-streaming and streaming).

    Field names are camelCase and no aliases are configured, so clients send
    exactly these keys.
    """

    # Returned by the non-streaming endpoint; a new UUID is generated there when omitted.
    conversationId: Optional[str] = None
    message: str  # the user's message text
    images: Optional[List[str]] = None  # passed to the model as "image_url" parts next to the text
    files: Optional[List[str]] = None  # NOTE(review): not read by the chat endpoints in this file
    model: Optional[str] = "qwen-plus"  # DashScope model id forwarded to Generation.call
    temperature: Optional[float] = 0.7  # forwarded as `temperature`
    maxTokens: Optional[int] = 2000  # forwarded as `max_tokens`
    systemPrompt: Optional[str] = "你是一个支持视觉理解的助手。"  # used as the system message
    # NOTE(review): not read by either endpoint; streaming is selected by which URL is called.
    stream: Optional[bool] = True
    # Extension options (feature toggles sent by the client)
    deepSearch: Optional[bool] = False
    webSearch: Optional[bool] = False
    deepThinking: Optional[bool] = False
|
||
|
||
class ModelInfo(BaseModel):
    """Description of a selectable model, as listed by GET /api/chat-ui/models."""

    id: str  # model id, e.g. "qwen-plus"
    name: str  # human-readable display name
    description: str  # short description shown to users
    maxTokens: int  # token limit advertised to the client (8192 for the built-in entries)
    provider: str  # vendor label, e.g. "Aliyun"
|
||
|
||
# In-memory conversation store, keyed by conversation id. It is only a stand-in
# for a database: the contents vanish when the process exits, so production
# deployments should use persistent storage instead.
conversations_db: Dict[str, dict] = dict()

# Directory where uploads are written and served from. The path is relative,
# so it resolves against the working directory the server was started in.
upload_dir = Path("uploads")
upload_dir.mkdir(exist_ok=True)
|
||
|
||
@app.middleware("http")
async def add_process_time_header(request, call_next):
    """HTTP middleware that measures how long each request takes.

    The elapsed time (milliseconds, two decimals) is added to the response as
    the ``X-Process-Time`` header and printed as a one-line access log entry.
    """
    # perf_counter is monotonic and high-resolution, so the measurement cannot
    # be skewed by wall-clock adjustments the way datetime differences can.
    start = time.perf_counter()
    response = await call_next(request)
    process_time = (time.perf_counter() - start) * 1000

    response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
    print(f"HTTP {request.method} {request.url.path} {response.status_code} {process_time:.2f}ms")
    return response
|
||
|
||
@app.get("/health")
async def health_check():
    """Liveness probe: report "healthy" together with the current naive-UTC time."""
    now = datetime.utcnow()
    payload = {"status": "healthy", "timestamp": now.isoformat()}
    return payload
|
||
|
||
def _build_messages(request: ChatRequest) -> List[dict]:
    """Build the DashScope ``messages`` list for a single-turn chat request.

    The system prompt (when non-empty) comes first, followed by one user
    message. Text-only requests send the user content as a plain string;
    requests with images send a list of typed parts, text first.
    """
    if request.images:
        user_content = [{"type": "text", "text": request.message}]
        # NOTE(review): DashScope's native multimodal API expects parts shaped
        # like {"image": url}; confirm this "image_url" form against the model
        # actually used.
        user_content.extend(
            {"type": "image_url", "image_url": image_url} for image_url in request.images
        )
    else:
        # Text generation models take a plain string as message content.
        user_content = request.message

    messages = []
    # Skip an empty/None system prompt instead of sending a null-content message.
    if request.systemPrompt:
        messages.append({"role": "system", "content": request.systemPrompt})
    messages.append({"role": "user", "content": user_content})
    return messages


def _generation_kwargs(request: ChatRequest) -> dict:
    """Return the ``Generation.call`` keyword arguments shared by both chat endpoints."""
    return {
        "model": request.model,
        "messages": _build_messages(request),
        # "message" format makes the SDK populate output.choices[i]["message"],
        # which is what both endpoints read (the default "text" format does not).
        "result_format": "message",
        "max_tokens": request.maxTokens,
        "temperature": request.temperature,
        # Wire the client's webSearch toggle to DashScope's built-in web search.
        "enable_search": bool(request.webSearch),
    }


@app.post("/api/chat-ui/chat")
async def chat_endpoint(request: ChatRequest):
    """Non-streaming chat: send one request to DashScope and return the full reply.

    Returns a JSON object with ``id``, ``conversationId`` (the request's, or a
    new UUID), ``content``, ``model``, ``createdAt`` (Unix seconds) and, when
    the SDK reports it, ``usage``.

    Raises:
        HTTPException(500): the API returned a non-200 status, or any other
            error occurred while calling it or reading its response.
    """
    try:
        kwargs = _generation_kwargs(request)
        # Generation.call performs a blocking HTTP request; run it in the
        # default thread pool so the event loop keeps serving other clients.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: Generation.call(stream=False, **kwargs)
        )

        if response.status_code != 200:
            raise HTTPException(
                status_code=500,
                detail=f"API Error: {response.code} - {response.message}",
            )

        content = response.output.choices[0]["message"]["content"]
        result = {
            "id": str(uuid.uuid4()),
            "conversationId": request.conversationId or str(uuid.uuid4()),
            "content": content,
            "model": request.model,
            # time.time() is a true Unix timestamp; a naive utcnow().timestamp()
            # would be shifted by the server's local UTC offset.
            "createdAt": int(time.time()),
        }

        # The attribute may exist but be None, so test the value, not its presence.
        usage = getattr(response, "usage", None)
        if usage is not None:
            result["usage"] = {
                "promptTokens": usage.input_tokens,
                "completionTokens": usage.output_tokens,
                "totalTokens": usage.total_tokens,
            }

        return JSONResponse(content=result)

    except HTTPException:
        # Already carries the intended status and detail; do not re-wrap it.
        raise
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/api/chat-ui/chat/stream")
async def chat_stream_endpoint(request: ChatRequest):
    """Streaming chat: relay DashScope output to the client as Server-Sent Events.

    Each non-empty text chunk is sent as an OpenAI-style
    ``choices[0].delta.content`` event, followed by ``data: [DONE]``. A non-200
    chunk is reported as an ``api_error`` event (then ``[DONE]``); any other
    exception is reported as a ``server_error`` event and ends the stream.
    """

    # Deliberately a plain (sync) generator: StreamingResponse iterates sync
    # iterators in a thread pool, so the blocking SDK stream does not stall
    # the event loop.
    def event_generator():
        try:
            responses = Generation.call(
                stream=True,
                # Emit only newly generated text per chunk; without this the SDK
                # resends the whole text so far, duplicating it on the client.
                incremental_output=True,
                **_generation_kwargs(request),
            )

            for response in responses:
                if response.status_code != 200:
                    error_data = {
                        "error": {
                            "message": f"API Error: {response.code} - {response.message}",
                            "type": "api_error",
                            "param": None,
                            "code": response.code,
                        }
                    }
                    yield f"data: {json.dumps(error_data)}\n\n"
                    break

                content = response.output.choices[0]["message"]["content"]
                if content:
                    data = {
                        "choices": [
                            {
                                "delta": {"content": content},
                                "index": 0,
                                "finish_reason": None,
                            }
                        ]
                    }
                    yield f"data: {json.dumps(data)}\n\n"

            # End-of-stream marker (also sent after an api_error event).
            yield "data: [DONE]\n\n"

        except Exception as e:
            error_data = {
                "error": {
                    "message": str(e),
                    "type": "server_error",
                }
            }
            yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||
|
||
@app.get("/api/chat-ui/models")
async def get_models():
    """Return the built-in model catalogue as a list of plain dicts."""
    # (id, display name, description); every entry shares the same limit and provider.
    catalogue = (
        ("qwen-max", "通义千问 Max", "最强大的模型"),
        ("qwen-plus", "通义千问 Plus", "能力均衡"),
    )
    return [
        ModelInfo(
            id=model_id,
            name=display_name,
            description=summary,
            maxTokens=8192,
            provider="Aliyun",
        ).dict()
        for model_id, display_name, summary in catalogue
    ]
|
||
|
||
@app.get("/api/chat-ui/conversations")
async def get_conversations():
    """Return every stored conversation, in the store's iteration order."""
    return [*conversations_db.values()]
|
||
|
||
@app.get("/api/chat-ui/conversations/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Return the conversation stored under ``conversation_id``.

    Raises:
        HTTPException(404): no (truthy) conversation is stored under that id.
    """
    found = conversations_db.get(conversation_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="对话不存在")
|
||
|
||
@app.post("/api/chat-ui/conversations")
async def save_conversation(
    id: Optional[str] = Form(None),
    title: str = Form(...),
    messages: str = Form(...),
):
    """Create or replace a conversation in the in-memory store.

    Form fields:
        id: id of the conversation to overwrite; a new UUID is generated when
            omitted or empty.
        title: conversation title.
        messages: JSON string, stored as the parsed value (presumably a list
            of messages — its shape is not validated).

    Returns the stored dict with ``id``, ``title``, ``messages`` and
    ``updatedAt`` (naive-UTC ISO-8601 string).

    Raises:
        HTTPException(400): ``messages`` is not valid JSON.
    """
    try:
        parsed_messages = json.loads(messages)
    except json.JSONDecodeError as err:
        # Chain the parse error so server-side tracebacks show what was wrong.
        raise HTTPException(status_code=400, detail="Invalid messages JSON") from err

    conversation_id = id or str(uuid.uuid4())
    conversation = {
        "id": conversation_id,
        "title": title,
        "messages": parsed_messages,
        "updatedAt": datetime.utcnow().isoformat(),
    }

    conversations_db[conversation_id] = conversation
    return conversation
|
||
|
||
@app.delete("/api/chat-ui/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Remove the conversation stored under ``conversation_id``.

    Raises:
        HTTPException(404): no conversation is stored under that id.
    """
    if conversation_id not in conversations_db:
        raise HTTPException(status_code=404, detail="对话不存在")

    conversations_db.pop(conversation_id)
    return {"success": True, "message": "删除成功"}
|
||
|
||
@app.post("/api/chat-ui/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded file under a server-generated name and describe it.

    Only the extension of the client-supplied filename is kept; the stored name
    is ``<unix-seconds>-<uuid4><ext>`` inside ``upload_dir``, so clients cannot
    choose paths on disk.

    Returns a dict with ``url``, ``name`` (original filename), ``size`` (bytes)
    and ``mimeType``. The URL is built from the PUBLIC_BASE_URL environment
    variable when set, otherwise from ``http://localhost:<PORT>`` (PORT
    defaults to 8000, matching the server's own startup code).

    Raises:
        HTTPException(500): any error while reading or saving the file.
    """
    try:
        # filename may be None for some clients; treat that as "no extension".
        file_extension = Path(file.filename or "").suffix
        unique_filename = f"{int(time.time())}-{uuid.uuid4()}{file_extension}"
        file_path = upload_dir / unique_filename

        # Read before opening the destination so a failed read leaves no empty file.
        content = await file.read()
        with open(file_path, "wb") as f:
            f.write(content)

        base_url = os.getenv("PUBLIC_BASE_URL") or f"http://localhost:{os.getenv('PORT', 8000)}"
        file_url = f"{base_url.rstrip('/')}/uploads/{unique_filename}"
        return {
            "url": file_url,
            "name": file.filename,
            "size": len(content),
            "mimeType": file.content_type,
        }

    except Exception as e:
        print(f"Upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
|
||
|
||
@app.get("/uploads/{filename}")
async def serve_upload(filename: str):
    """Serve a previously uploaded file from ``upload_dir``.

    Only regular files located directly inside ``upload_dir`` are served. Names
    such as ``..`` or ``.``, directories, and symlinks pointing elsewhere all
    resolve outside that check.

    Raises:
        HTTPException(404): the name does not refer to such a file.
    """
    from fastapi.responses import FileResponse

    base_dir = upload_dir.resolve()
    file_path = (upload_dir / filename).resolve()
    # exists() alone also accepts directories (FileResponse then fails) and
    # paths like ".." that leave the upload directory.
    if file_path.parent != base_dir or not file_path.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(file_path)
|
||
|
||
@app.post("/api/chat-ui/stop")
async def stop_generation():
    """Acknowledge a request to stop generation.

    Placeholder only: no running generation is tracked or cancelled yet (that
    would need a registry of in-flight task ids), so success is always reported.
    """
    acknowledgement = dict(success=True, message="已发出停止指令")
    return acknowledgement
|
||
|
||
@app.post("/api/chat-ui/stop/{message_id}")
async def stop_generation_by_id(message_id: str):
    """Acknowledge a request to stop generation of the message ``message_id``.

    Placeholder only: ``message_id`` is accepted but not used, and success is
    always reported.
    """
    acknowledgement = dict(success=True, message="已发出停止指令")
    return acknowledgement
|
||
|
||
if __name__ == "__main__":
    # Listen on every interface; the PORT environment variable overrides 8000.
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)