ai-chat-ui/server_python/main.py

343 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Python FastAPI server implementation, used as a replacement for the Node.js server.
Connects to the Alibaba Cloud Bailian platform API through the DashScope Python SDK.
"""
import asyncio
import json
import os
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import dashscope
import uvicorn
from dashscope import Generation
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
# Load environment variables (via python-dotenv)
load_dotenv()
# Configure the DashScope SDK with the API key read from ALIYUN_API_KEY
dashscope.api_key = os.getenv("ALIYUN_API_KEY")
# Create the FastAPI application
app = FastAPI(title="AI Chat API Server", version="1.0.0")
# Data model definitions
class ChatMessage(BaseModel):
    """One chat message: a role, its text content, and optional image/file lists."""

    # Required fields
    role: str
    content: str
    # Optional lists of strings; both default to None when omitted
    images: Optional[List[str]] = Field(default=None)
    files: Optional[List[str]] = Field(default=None)
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoints."""

    # Conversation identifier and the user's message text (the only required field)
    conversationId: Optional[str] = Field(default=None)
    message: str
    # Optional lists of strings attached to the message
    images: Optional[List[str]] = Field(default=None)
    files: Optional[List[str]] = Field(default=None)
    # Generation settings and their defaults
    model: Optional[str] = Field(default="qwen-plus")
    temperature: Optional[float] = Field(default=0.7)
    maxTokens: Optional[int] = Field(default=2000)
    systemPrompt: Optional[str] = Field(default="你是一个支持视觉理解的助手。")
    stream: Optional[bool] = Field(default=True)
    # Extended options
    deepSearch: Optional[bool] = Field(default=False)
    webSearch: Optional[bool] = Field(default=False)
    deepThinking: Optional[bool] = Field(default=False)
class ModelInfo(BaseModel):
    """Descriptor for one selectable model; every field is required."""

    # Model identifier and its display name
    id: str
    name: str
    # Short description of the model
    description: str
    # Token limit reported for the model
    maxTokens: int
    # Name of the model provider
    provider: str
# In-memory stand-in for a database - a real deployment should use persistent storage
conversations_db: Dict[str, dict] = {}
# Configure the upload directory (created at import time if it does not exist yet)
upload_dir = Path("uploads")
upload_dir.mkdir(exist_ok=True)
@app.middleware("http")
async def add_process_time_header(request, call_next):
    """Middleware: time each HTTP request and report the duration.

    The elapsed time in milliseconds is attached to the response as the
    ``X-Process-Time`` header and printed together with the request method,
    path and response status code.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that passes the request on and returns the response.

    Returns:
        The downstream response with the ``X-Process-Time`` header added.
    """
    # perf_counter() is monotonic, so the measurement cannot go negative or
    # jump when the system clock is adjusted (unlike wall-clock timestamps).
    started = time.perf_counter()
    response = await call_next(request)
    process_time = (time.perf_counter() - started) * 1000
    # Expose the processing time in a response header
    response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
    # Log the request summary
    print(f"HTTP {request.method} {request.url.path} {response.status_code} {process_time:.2f}ms")
    return response
@app.get("/health")
async def health_check():
    """Health-check endpoint: report a healthy status with the current UTC timestamp."""
    now_iso = datetime.utcnow().isoformat()
    return {"status": "healthy", "timestamp": now_iso}
@app.post("/api/chat-ui/chat")
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint for regular (non-streaming) requests.

    Builds a system + user message pair from the request, calls the DashScope
    generation API once, and returns the reply as JSON.

    Args:
        request: The chat request (message text, optional images, model settings).

    Returns:
        JSONResponse with ``id``, ``conversationId``, ``content``, ``model``,
        ``createdAt`` (Unix epoch seconds) and, when available, ``usage``.

    Raises:
        HTTPException: 500 when the upstream API reports a non-200 status or
            any other error occurs while handling the request.
    """
    # User content: the text part first, then one part per attached image.
    # NOTE(review): confirm this content-part shape is what the chosen
    # DashScope model expects for image input.
    user_content = [{"type": "text", "text": request.message}]
    user_content.extend(
        {"type": "image_url", "image_url": image_url}
        for image_url in (request.images or [])
    )
    # Messages sent to the Bailian platform
    messages = [
        {"role": "system", "content": request.systemPrompt},
        {"role": "user", "content": user_content},
    ]

    def _call_model():
        # result_format="message" makes the SDK populate output.choices,
        # which is the structure read below.
        return Generation.call(
            model=request.model,
            messages=messages,
            stream=False,  # non-streaming response
            result_format="message",
            max_tokens=request.maxTokens,
            temperature=request.temperature,
        )

    try:
        # The SDK call is blocking; run it in the default executor so the
        # event loop keeps serving other requests while it waits.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, _call_model)

        if response.status_code != 200:
            raise HTTPException(status_code=500, detail=f"API Error: {response.code} - {response.message}")

        content = response.output.choices[0]['message']['content']
        # Build the response payload
        result = {
            "id": str(uuid.uuid4()),
            "conversationId": request.conversationId or str(uuid.uuid4()),
            "content": content,
            "model": request.model,
            # time.time() is true epoch seconds; a naive utcnow().timestamp()
            # would be shifted by the server's UTC offset.
            "createdAt": int(time.time()),
        }
        usage = getattr(response, "usage", None)
        if usage:
            result["usage"] = {
                "promptTokens": usage.input_tokens,
                "completionTokens": usage.output_tokens,
                "totalTokens": usage.total_tokens,
            }
        return JSONResponse(content=result)
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it below.
        raise
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/chat-ui/chat/stream")
async def chat_stream_endpoint(request: ChatRequest):
    """Streaming chat endpoint: relay model output as server-sent events.

    Each non-empty chunk is emitted as a ``data:`` line holding a JSON object
    with a ``choices[0].delta.content`` field. An upstream error is emitted as
    a JSON ``error`` object and ends the stream. ``data: [DONE]`` is sent after
    the stream finishes or after an upstream API error; it is not sent when an
    exception is raised while streaming.

    Args:
        request: The chat request (message text, optional images, model settings).

    Returns:
        StreamingResponse with media type ``text/event-stream``.
    """

    def event_generator():
        # Deliberately a plain (sync) generator: the DashScope stream blocks
        # while waiting for chunks, and StreamingResponse iterates sync
        # generators in a worker thread instead of on the event loop.
        try:
            # User content: the text part first, then one part per attached image.
            user_content = [{"type": "text", "text": request.message}]
            user_content.extend(
                {"type": "image_url", "image_url": image_url}
                for image_url in (request.images or [])
            )
            # Messages sent to the Bailian platform
            messages = [
                {"role": "system", "content": request.systemPrompt},
                {"role": "user", "content": user_content},
            ]
            # Call the DashScope API in streaming mode
            responses = Generation.call(
                model=request.model,
                messages=messages,
                stream=True,  # streaming response
                # "message" format populates output.choices, read below.
                result_format="message",
                # Each chunk carries only the new text, matching the "delta"
                # payload emitted to the client.
                incremental_output=True,
                max_tokens=request.maxTokens,
                temperature=request.temperature,
            )
            for response in responses:
                if response.status_code != 200:
                    error_data = {
                        "error": {
                            "message": f"API Error: {response.code} - {response.message}",
                            "type": "api_error",
                            "param": None,
                            "code": response.code,
                        }
                    }
                    yield f"data: {json.dumps(error_data)}\n\n"
                    break
                content = response.output.choices[0]['message']['content']
                if content:
                    # Emit one streamed chunk
                    data = {
                        "choices": [
                            {
                                "delta": {"content": content},
                                "index": 0,
                                "finish_reason": None,
                            }
                        ]
                    }
                    yield f"data: {json.dumps(data)}\n\n"
            # Send the end-of-stream marker
            yield "data: [DONE]\n\n"
        except Exception as e:
            # Log server-side as well, like the non-streaming endpoint does.
            print(f"Error in chat stream endpoint: {str(e)}")
            error_data = {
                "error": {
                    "message": str(e),
                    "type": "server_error",
                }
            }
            yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.get("/api/chat-ui/models")
async def get_models():
    """Return the list of available models as plain dictionaries."""
    # (id, display name, description) per model; the token limit and the
    # provider are the same for every entry.
    catalog = (
        ("qwen-max", "通义千问 Max", "最强大的模型"),
        ("qwen-plus", "通义千问 Plus", "能力均衡"),
    )
    listing = []
    for model_id, display_name, summary in catalog:
        info = ModelInfo(
            id=model_id,
            name=display_name,
            description=summary,
            maxTokens=8192,
            provider="Aliyun",
        )
        listing.append(info.dict())
    return listing
@app.get("/api/chat-ui/conversations")
async def get_conversations():
    """Return every stored conversation."""
    return [*conversations_db.values()]
@app.get("/api/chat-ui/conversations/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Return one conversation by id, or respond with 404 when it is not stored."""
    found = conversations_db.get(conversation_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="对话不存在")
@app.post("/api/chat-ui/conversations")
async def save_conversation(
    id: str = Form(None),
    title: str = Form(...),
    messages: str = Form(...),
):
    """Create or update a conversation from form fields.

    ``messages`` must be a JSON string; a 400 error is returned when it cannot
    be parsed. When ``id`` is empty a new UUID is generated. The stored record
    is returned.
    """
    # Decode the JSON-encoded messages field
    try:
        decoded = json.loads(messages)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid messages JSON")

    record_id = id if id else str(uuid.uuid4())
    record = {
        "id": record_id,
        "title": title,
        "messages": decoded,
        "updatedAt": datetime.utcnow().isoformat(),
    }
    conversations_db[record_id] = record
    return record
@app.delete("/api/chat-ui/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Delete a stored conversation, or respond with 404 when the id is unknown."""
    if conversation_id not in conversations_db:
        raise HTTPException(status_code=404, detail="对话不存在")
    del conversations_db[conversation_id]
    return {"success": True, "message": "删除成功"}
@app.post("/api/chat-ui/upload")
async def upload_file(file: UploadFile = File(...)):
    """File upload endpoint: store the file under a unique name.

    Args:
        file: The uploaded file.

    Returns:
        A dict with the file's ``url``, original ``name``, ``size`` in bytes
        and ``mimeType``.

    Raises:
        HTTPException: 500 when reading or saving the upload fails.
    """
    try:
        # Keep only the extension of the client-supplied name; the filename
        # may be missing, in which case no extension is used.
        file_extension = Path(file.filename or "").suffix
        # Unique name: epoch seconds + UUID. time.time() is true epoch time;
        # a naive utcnow().timestamp() would be shifted by the UTC offset.
        unique_filename = f"{int(time.time())}-{uuid.uuid4()}{file_extension}"
        file_path = upload_dir / unique_filename
        # Read the upload BEFORE creating the destination file so a failed
        # read does not leave an empty file behind.
        content = await file.read()
        file_path.write_bytes(content)
        # Use the same PORT setting (default 8000) the server binds to, so
        # the returned URL stays valid when PORT is overridden.
        port = os.getenv("PORT", 8000)
        file_url = f"http://localhost:{port}/uploads/{unique_filename}"
        return {
            "url": file_url,
            "name": file.filename,
            "size": len(content),
            "mimeType": file.content_type,
        }
    except Exception as e:
        print(f"Upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@app.get("/uploads/{filename}")
async def serve_upload(filename: str):
    """Serve a previously uploaded file.

    Args:
        filename: Name of the file inside the upload directory.

    Returns:
        FileResponse for the requested file.

    Raises:
        HTTPException: 404 when the name does not refer to a regular file
            located inside the upload directory.
    """
    from fastapi.responses import FileResponse

    # The filename comes from the URL and is untrusted: resolve it and make
    # sure the result is still inside the upload directory, so names such as
    # ".." cannot reach other paths.
    root = upload_dir.resolve()
    file_path = (upload_dir / filename).resolve()
    # is_file() (rather than exists()) also rejects directories.
    if root not in file_path.parents or not file_path.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")
    return FileResponse(file_path)
@app.post("/api/chat-ui/stop")
async def stop_generation():
    """Stop-generation endpoint.

    Placeholder: it only returns a success message. A real implementation
    may need to keep track of the IDs of running tasks.
    """
    acknowledgement = {"success": True, "message": "已发出停止指令"}
    return acknowledgement
@app.post("/api/chat-ui/stop/{message_id}")
async def stop_generation_by_id(message_id: str):
    """Stop generation for a given message id.

    Placeholder: ``message_id`` is not used; a success message is always returned.
    """
    acknowledgement = {"success": True, "message": "已发出停止指令"}
    return acknowledgement
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from PORT and defaults to 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8000)))