ai-chat-ui/server/api/chat_routes.py

"""
API route definitions
"""
import os
import json
import uuid
from datetime import datetime
from typing import Dict, List
from pathlib import Path
from fastapi import HTTPException, File, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
import dashscope
from dashscope import Generation, MultiModalConversation
# Import models and helper functions (resolve the server package via an absolute path)
import sys
sys.path.append(str(Path(__file__).parent.parent))
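# The sys.path entry added above lets the absolute imports below ("models.*", "utils.*")
# resolve whether this module is imported as a package or run directly.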
from models.chat_models import ChatRequest, ModelInfo
from utils.helpers import (
get_current_timestamp,
generate_unique_id,
format_api_response,
extract_delta_content
)
# Mock database - a real application should use persistent storage
conversations_db: Dict[str, dict] = {}
# Configure the upload directory
upload_dir = Path("uploads")
upload_dir.mkdir(exist_ok=True)
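# NOTE: "uploads" is created relative to the current working directory; files stored here
# are served back by serve_upload_handler below.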
async def chat_endpoint_handler(body: dict):
"""
    Chat endpoint handler - an interface compatible with the Alibaba Cloud Bailian (DashScope) API.
    This endpoint receives chat requests from the frontend and forwards them to the Bailian API.
"""
try:
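        # Two request shapes are accepted (see the branches below), e.g.:
        #   OpenAI-compatible: {"model": "qwen-plus", "messages": [...], "stream": true}
        #   simplified frontend: {"message": "...", "systemPrompt": "...", "maxTokens": 2000, "stream": false}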
        # Ensure the body is a dictionary
if not isinstance(body, dict):
print(f"[ERROR] Request body is not a dictionary: {type(body)}")
raise HTTPException(
status_code=400,
detail=f"Request body must be a JSON object, got {type(body).__name__}: {body}"
)
        # Inspect the request format and adapt it.
        # OpenAI-compatible format (from streamChat)
if 'messages' in body:
messages = body.get('messages', [])
model = body.get('model', 'qwen-plus')
stream = body.get('stream', True)
temperature = body.get('temperature', 0.7)
max_tokens = body.get('max_tokens', 2000)
else:
            # Otherwise it is the simplified frontend format (from the chat function)
            # and needs to be converted to the OpenAI-compatible format
message_text = body.get('message', '')
            # Check whether message is already a formatted list (the case that carries images)
if isinstance(message_text, list):
user_content = message_text
else:
                # If it is a string, convert it to the standard content format
user_content = [{"type": "text", "text": message_text}]
messages = [
{"role": "system", "content": body.get('systemPrompt', '你是一个支持视觉理解的助手。')},
{"role": "user", "content": user_content}
]
model = body.get('model', 'qwen-plus')
            stream = body.get('stream', False)  # defaults to non-streaming
temperature = body.get('temperature', 0.7)
max_tokens = body.get('maxTokens', 2000)
        # Check whether any message contains image content; multimodal requests go through MultiModalConversation
has_images = any(
isinstance(msg, dict) and
isinstance(msg.get('content'), list) and
any(isinstance(item, dict) and item.get('type') == 'image_url' for item in msg.get('content', []))
for msg in messages if isinstance(msg, dict)
)
if has_images:
            # Route image-bearing requests through the multimodal API
return await multimodal_chat_handler(messages, model, stream, temperature, max_tokens)
else:
            # Use the regular chat API
if stream:
                # Streaming response
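                # Chunks are written as SSE "data:" lines in OpenAI chat.completion.chunk format, e.g.
                #   data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "..."}, ...}]}
                # and the stream ends with a finish_reason "stop" chunk followed by "data: [DONE]".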
async def event_generator():
try:
responses = Generation.call(
model=model,
messages=messages,
stream=True,
max_tokens=max_tokens,
temperature=temperature
)
                        full_content = ""  # accumulates the complete content streamed so far
for idx, response in enumerate(responses):
if response.status_code == 200:
                                # Check whether the response contains the expected content.
                                # The DashScope response may expose it as output.choices or output.text
content = None
                                # Try to get the content from output.choices
if (hasattr(response, 'output') and
response.output and
hasattr(response.output, 'choices') and
response.output.choices is not None and
len(response.output.choices) > 0 and
'message' in response.output.choices[0] and
'content' in response.output.choices[0]['message']):
content = response.output.choices[0]['message']['content']
                                    # Only emit a delta when the content has actually grown
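                                    # extract_delta_content is assumed to return only the newly
                                    # appended suffix (content minus the already-sent full_content).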
if len(content) > len(full_content):
delta_content = extract_delta_content(content, full_content)
full_content = content
                                        if delta_content.strip():  # only send when there is non-whitespace new content
                                            # Build the SSE data chunk
data = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion.chunk",
"created": get_current_timestamp(),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": delta_content},
"finish_reason": None
}
]
}
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                                # Otherwise try to get the content from output.text (a DashScope-specific format)
elif (hasattr(response, 'output') and
response.output and
'text' in response.output):
content = response.output.get('text')
                                    # Only emit a delta when the content has actually grown
if len(content) > len(full_content):
delta_content = extract_delta_content(content, full_content)
full_content = content
                                        if delta_content.strip():  # only send when there is non-whitespace new content
                                            # Build the SSE data chunk
data = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion.chunk",
"created": get_current_timestamp(),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": delta_content},
"finish_reason": None
}
]
}
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
else:
                                # Error handling
error_data = {
"error": {
"message": f"API Error: {response.code} - {response.message}",
"type": "api_error",
"param": None,
"code": response.code
}
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
break
                        # Send the end-of-stream signal
finish_data = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion.chunk",
"created": get_current_timestamp(),
"model": model,
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "stop"
}
]
}
yield f"data: {json.dumps(finish_data, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
error_data = {
"error": {
"message": str(e),
"type": "server_error"
}
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
else:
                # Non-streaming response
response = Generation.call(
model=model,
messages=messages,
stream=False,
max_tokens=max_tokens,
temperature=temperature
)
if response.status_code == 200:
                    # Check whether the response contains the expected content.
                    # The DashScope response may expose it as output.choices or output.text
content = None
                    # Try to get the content from output.choices
if (hasattr(response, 'output') and
response.output and
hasattr(response.output, 'choices') and
response.output.choices is not None and
len(response.output.choices) > 0 and
'message' in response.output.choices[0] and
'content' in response.output.choices[0]['message']):
content = response.output.choices[0]['message']['content']
                    # Otherwise try to get the content from output.text (a DashScope-specific format)
elif (hasattr(response, 'output') and
response.output and
'text' in response.output):
content = response.output.get('text')
if content:
                        # Build the response format the frontend expects
chat_response = format_api_response(
content=content,
conversation_id=body.get('conversationId'),
model=model
)
if hasattr(response, 'usage') and response.usage:
chat_response["usage"] = {
"promptTokens": response.usage.input_tokens,
"completionTokens": response.usage.output_tokens,
"totalTokens": response.usage.total_tokens
}
                        return JSONResponse(content=chat_response)  # JSONResponse has no ensure_ascii parameter; it already serializes UTF-8 without escaping
else:
raise HTTPException(
status_code=500,
detail="API Response does not contain expected content"
)
else:
raise HTTPException(
status_code=500,
detail=f"API Error: {response.code} - {response.message}"
)
except Exception as e:
print(f"[ERROR] Error in chat endpoint: {str(e)}")
import traceback
print(f"[ERROR] Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))
async def multimodal_chat_handler(messages, model, stream, temperature, max_tokens):
"""
    Multimodal chat handler - processes messages that contain images
"""
try:
        # Convert OpenAI-format messages to the DashScope MultiModalConversation format
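        # Target shape produced by the loop below, e.g.
        #   {'role': 'user', 'content': [{'text': 'describe this image'}, {'image': 'file://uploads/xxx.jpg'}]}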
dashscope_messages = []
for i, msg in enumerate(messages):
            # Validate that msg is a dict; if not, fall back to treating it as plain text
if not isinstance(msg, dict):
                # The message is not a dict, so handle it as plain text
dashscope_content = [
{'text': str(msg)}
]
dashscope_messages.append({
'role': 'user',
'content': dashscope_content
})
continue
role = msg.get('role', 'user')
content = msg.get('content', '')
if isinstance(content, str):
                # Plain text content
dashscope_content = [
{'text': content}
]
elif isinstance(content, list):
                # Content that mixes images and text
dashscope_content = []
for j, item in enumerate(content):
if isinstance(item, dict):
if item.get('type') == 'text':
dashscope_content.append({'text': item.get('text', '')})
elif item.get('type') == 'image_url':
                            # image_url may be either a string or a dict
image_url_value = item.get('image_url', '')
if isinstance(image_url_value, str):
                                # image_url is a string, use it directly
img_url = image_url_value
elif isinstance(image_url_value, dict) and 'url' in image_url_value:
                                # image_url is a dict, take its url field
img_url = image_url_value.get('url', '')
else:
                                # Anything else is treated as missing/empty
img_url = ''
                            # If the URL is an http(s) URL, extract the file name and convert it to file:// form
if img_url.startswith('http://') or img_url.startswith('https://'):
                                # Extract the file-name part of the URL (e.g. uploads/filename.jpg from http://localhost:8000/uploads/filename.jpg)
from urllib.parse import urlparse
parsed_url = urlparse(img_url)
path_parts = parsed_url.path.split('/')
                                # Find the uploads segment and everything after it in the path
try:
uploads_index = path_parts.index('uploads')
                                    filename = '/'.join(path_parts[uploads_index:])  # e.g. uploads/filename.jpg
img_url = f"file://{filename}"
except ValueError:
                                    # The path has no uploads segment, so use the original path
img_url = f"file://{parsed_url.path.lstrip('/')}"
elif not img_url.startswith('file://'):
                                # Neither a web URL nor a file:// URL, so assume it is a relative path
img_url = f"file://{img_url}"
if img_url.startswith('file://'):
                                # Make sure the local file exists
import os
                                local_path = img_url[7:]  # strip the "file://" prefix
if not os.path.exists(local_path):
print(f"[WARNING] Image file does not exist: {local_path}")
dashscope_content.append({'image': img_url})
else:
                            # Convert non-dict items to text
dashscope_content.append({'text': str(item)})
else:
                # Anything else becomes text
dashscope_content = [
{'text': str(content)}
]
dashscope_messages.append({
'role': role,
'content': dashscope_content
})
if stream:
            # Multimodal streaming response
async def multimodal_event_generator():
try:
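                    # Chat model names are mapped to their vision counterparts
                    # (e.g. qwen-plus -> qwen-vl-plus); names without a 'qwen-' prefix
                    # fall back to qwen-vl-max.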
responses = MultiModalConversation.call(
                        model=model if model.startswith('qwen-vl-') else (model.replace('qwen-', 'qwen-vl-') if 'qwen-' in model else 'qwen-vl-max'),
messages=dashscope_messages,
stream=True,
max_tokens=max_tokens,
temperature=temperature
)
full_content = ""
for response in responses:
if response.status_code == 200:
content = None
                            # Extract the content from the multimodal response
if (hasattr(response, 'output') and
response.output and
hasattr(response.output, 'choices') and
response.output.choices is not None and
len(response.output.choices) > 0 and
'message' in response.output.choices[0]):
message = response.output.choices[0]['message']
if 'content' in message:
content_items = message['content']
                                        # Pull the text out of the content items
extracted_text = ""
for item in content_items:
if isinstance(item, dict) and 'text' in item:
extracted_text += item['text']
content = extracted_text
                                    # Only emit a delta when the content has actually grown
if len(content) > len(full_content):
delta_content = extract_delta_content(content, full_content)
full_content = content
if delta_content.strip():
data = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion.chunk",
"created": get_current_timestamp(),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": delta_content},
"finish_reason": None
}
]
}
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
else:
error_data = {
"error": {
"message": f"Multimodal API Error: {response.code} - {response.message}",
"type": "api_error",
"param": None,
"code": response.code
}
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
break
finish_data = {
"id": f"chatcmpl-{generate_unique_id()}",
"object": "chat.completion.chunk",
"created": get_current_timestamp(),
"model": model,
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "stop"
}
]
}
yield f"data: {json.dumps(finish_data, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
error_data = {
"error": {
"message": str(e),
"type": "server_error"
}
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
return StreamingResponse(multimodal_event_generator(), media_type="text/event-stream")
else:
            # Multimodal non-streaming response
response = MultiModalConversation.call(
                model=model if model.startswith('qwen-vl-') else (model.replace('qwen-', 'qwen-vl-') if 'qwen-' in model else 'qwen-vl-max'),
messages=dashscope_messages,
stream=False,
max_tokens=max_tokens,
temperature=temperature
)
if response.status_code == 200:
content = None
if (hasattr(response, 'output') and
response.output and
hasattr(response.output, 'choices') and
response.output.choices is not None and
len(response.output.choices) > 0 and
'message' in response.output.choices[0]):
message = response.output.choices[0]['message']
if 'content' in message:
content_items = message['content']
                        # Pull the text out of the content items
extracted_text = ""
for item in content_items:
if isinstance(item, dict) and 'text' in item:
extracted_text += item['text']
content = extracted_text
if content:
                    return JSONResponse(content={"result": content})  # JSONResponse has no ensure_ascii parameter
else:
raise HTTPException(
status_code=500,
detail="Multimodal API Response does not contain expected content"
)
else:
raise HTTPException(
status_code=500,
detail=f"Multimodal API Error: {response.code} - {response.message}"
)
except Exception as e:
print(f"[ERROR] Error in multimodal chat handler: {str(e)}")
import traceback
print(f"[ERROR] Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))
async def get_models_handler():
"""获取模型列表处理器"""
models = [
ModelInfo(
id="qwen-max",
name="通义千问 Max",
description="最强大的模型",
maxTokens=8192,
provider="Aliyun"
),
ModelInfo(
id="qwen-plus",
name="通义千问 Plus",
description="能力均衡",
maxTokens=8192,
provider="Aliyun"
),
ModelInfo(
id="qwen-turbo",
name="通义千问 Turbo",
description="速度更快、成本更低",
maxTokens=8192,
provider="Aliyun"
),
ModelInfo(
id="qwen-vl-max",
            name="通义千问 VL-Max",
description="支持视觉理解的多模态模型",
maxTokens=8192,
provider="Aliyun"
),
ModelInfo(
id="qwen-vl-plus",
            name="通义千问 VL-Plus",
description="支持视觉理解的多模态模型",
maxTokens=8192,
provider="Aliyun"
)
]
return [model.dict() for model in models]
async def get_conversations_handler():
"""获取所有对话处理器"""
return list(conversations_db.values())
async def get_conversation_handler(conversation_id: str):
"""获取特定对话处理器"""
conversation = conversations_db.get(conversation_id)
if not conversation:
raise HTTPException(status_code=404, detail="对话不存在")
return conversation
async def save_conversation_handler(data: dict):
"""保存或更新对话处理器"""
try:
conversation_id = data.get('id') or generate_unique_id()
conversation = {
"id": conversation_id,
"title": data.get('title', '新对话'),
"messages": data.get('messages', []),
"updatedAt": datetime.utcnow().isoformat(),
"createdAt": data.get('createdAt', datetime.utcnow().isoformat())
}
conversations_db[conversation_id] = conversation
return conversation
except Exception as e:
print(f"[ERROR] Error saving conversation: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
async def delete_conversation_handler(conversation_id: str):
"""删除对话处理器"""
if conversation_id in conversations_db:
del conversations_db[conversation_id]
return {"success": True, "message": "删除成功"}
else:
raise HTTPException(status_code=404, detail="对话不存在")
async def upload_file_handler(file: UploadFile = File(...)):
"""文件上传处理器"""
try:
        # Check the file type
allowed_types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'text/plain', 'application/pdf']
if file.content_type not in allowed_types:
raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.content_type}")
        # Generate a unique file name
file_extension = Path(file.filename).suffix.lower()
unique_filename = f"{int(datetime.utcnow().timestamp())}_{generate_unique_id()}{file_extension}"
file_path = upload_dir / unique_filename
        # Save the file
with open(file_path, "wb") as f:
content = await file.read()
f.write(content)
        # Return the file info
file_url = f"http://localhost:8000/uploads/{unique_filename}"
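        # NOTE: the public URL is hard-coded to http://localhost:8000; adjust it if the server
        # is exposed on a different host or port.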
result = {
"url": file_url,
"name": file.filename,
"size": len(content),
"mimeType": file.content_type
}
print(f"[INFO] File uploaded: {result}")
return result
    except HTTPException:
        # Let explicit HTTP errors (e.g. the 400 for unsupported file types) pass through unchanged
        raise
    except Exception as e:
print(f"[ERROR] Upload error: {str(e)}")
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
def serve_upload_handler(filename: str):
"""提供上传文件访问处理器"""
file_path = upload_dir / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail="文件不存在")
from fastapi.responses import FileResponse
return FileResponse(str(file_path))
async def stop_generation_handler(message_id: str = None):
"""停止生成处理器"""
    message = f"已发出停止指令,消息ID: {message_id}" if message_id else "已发出停止指令"
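    # NOTE: this only acknowledges the stop request; it does not cancel any in-flight
    # generation stream on the server side.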
return {"success": True, "message": message}