"""
API route handlers (API 路由定义).
"""
import os
import json
import sys
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, List
from urllib.parse import unquote_plus, urlparse

import dashscope
from dashscope import Generation, MultiModalConversation
from fastapi import HTTPException, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse

# Make the project root importable so the model/utility imports below resolve
# as absolute imports (导入模型和工具函数, 使用绝对路径).
sys.path.append(str(Path(__file__).parent.parent))

from models.chat_models import ChatRequest, ModelInfo
from utils.helpers import (
    get_current_timestamp,
    generate_unique_id,
    format_api_response,
    extract_delta_content,
)


# In-memory conversation store, keyed by conversation id.
# NOTE: contents are lost when the process restarts; a real deployment should
# use persistent storage instead.
conversations_db: Dict[str, dict] = dict()

# Directory for uploaded files. The path is relative, so it resolves against
# the process's current working directory; it is created at import time.
upload_dir = Path("uploads")
upload_dir.mkdir(exist_ok=True)


def _sse_event(payload: dict) -> str:
    """Serialise *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _chat_chunk(model: str, delta: dict, finish_reason=None) -> dict:
    """Build an OpenAI-style ``chat.completion.chunk`` payload with a single choice."""
    return {
        "id": f"chatcmpl-{generate_unique_id()}",
        "object": "chat.completion.chunk",
        "created": get_current_timestamp(),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }
        ],
    }


def _extract_generation_text(response):
    """Return the generated text of a DashScope ``Generation`` response, or None.

    Two result layouts are handled: ``output.choices[0]['message']['content']``
    (checked first) and ``output['text']``.
    """
    output = getattr(response, 'output', None)
    if not output:
        return None
    choices = getattr(output, 'choices', None)
    if choices:
        first = choices[0]
        if 'message' in first and 'content' in first['message']:
            return first['message']['content']
    if 'text' in output:
        return output.get('text')
    return None


def _flatten_text_only_content(messages):
    """Return *messages* with text-only list content collapsed into a plain string.

    The simplified request format wraps user text as ``[{"type": "text", "text": ...}]``.
    That list form is only meaningful for multimodal input, while DashScope's text
    ``Generation`` models take string ``content``. A list whose parts are all ``text``
    parts is therefore joined with newlines. Every other message is passed through
    unchanged, and neither the input list nor its dicts are mutated.
    """
    flattened = []
    for msg in messages:
        content = msg.get('content') if isinstance(msg, dict) else None
        if isinstance(content, list) and all(
            isinstance(part, dict) and part.get('type') == 'text' for part in content
        ):
            msg = {**msg, 'content': '\n'.join(str(part.get('text') or '') for part in content)}
        flattened.append(msg)
    return flattened


async def chat_endpoint_handler(body: dict):
    """Handle a chat request and forward it to Alibaba Cloud Bailian (DashScope).

    Two request shapes are accepted:

    * OpenAI-compatible (sent by ``streamChat``): ``messages``, ``model``, ``stream``
      (default True), ``temperature``, ``max_tokens``.
    * Simplified frontend shape (sent by ``chat``): ``message`` (str, or an already
      formatted content list when images are attached), ``systemPrompt``, ``model``,
      ``stream`` (default False), ``temperature``, ``maxTokens``, ``conversationId``.

    Requests containing ``image_url`` parts are delegated to ``multimodal_chat_handler``.

    Returns:
        A ``StreamingResponse`` of OpenAI-style SSE chunks (terminated by ``[DONE]``)
        when streaming. Otherwise a ``JSONResponse`` built by ``format_api_response``,
        with ``usage`` added when the API reports it.

    Raises:
        HTTPException: 400 if the body is not a JSON object; 500 on API or internal errors.
    """
    try:
        if not isinstance(body, dict):
            print(f"[ERROR] Request body is not a dictionary: {type(body)}")
            raise HTTPException(
                status_code=400,
                detail=f"Request body must be a JSON object, got {type(body).__name__}: {body}"
            )

        if 'messages' in body:
            # OpenAI-compatible format.
            messages = body.get('messages', [])
            stream = body.get('stream', True)
            max_tokens = body.get('max_tokens', 2000)
        else:
            # Simplified frontend format: build OpenAI-style messages from it.
            message_text = body.get('message', '')
            if isinstance(message_text, list):
                user_content = message_text  # already formatted (e.g. with images)
            else:
                user_content = [{"type": "text", "text": message_text}]
            messages = [
                {"role": "system", "content": body.get('systemPrompt', '你是一个支持视觉理解的助手。')},
                {"role": "user", "content": user_content}
            ]
            stream = body.get('stream', False)  # non-streaming by default
            max_tokens = body.get('maxTokens', 2000)
        model = body.get('model', 'qwen-plus')
        temperature = body.get('temperature', 0.7)

        has_images = any(
            isinstance(msg, dict)
            and isinstance(msg.get('content'), list)
            and any(isinstance(item, dict) and item.get('type') == 'image_url'
                    for item in msg['content'])
            for msg in messages
        )
        if has_images:
            return await multimodal_chat_handler(messages, model, stream, temperature, max_tokens)

        call_kwargs = dict(
            model=model,
            messages=_flatten_text_only_content(messages),
            max_tokens=max_tokens,
            temperature=temperature,
        )

        if stream:
            # A plain (synchronous) generator: StreamingResponse iterates it in a worker
            # thread, so the blocking DashScope stream does not stall the event loop.
            def event_generator():
                try:
                    full_content = ""  # cumulative text already forwarded to the client
                    for response in Generation.call(stream=True, **call_kwargs):
                        if response.status_code != 200:
                            yield _sse_event({
                                "error": {
                                    "message": f"API Error: {response.code} - {response.message}",
                                    "type": "api_error",
                                    "param": None,
                                    "code": response.code
                                }
                            })
                            break

                        content = _extract_generation_text(response)
                        # Each chunk is treated as the cumulative text so far; forward only
                        # the part that has not been sent yet.
                        if content and len(content) > len(full_content):
                            delta_content = extract_delta_content(content, full_content)
                            full_content = content
                            # Whitespace-only deltas (e.g. "\n\n") are real output. full_content
                            # has already advanced, so skipping them would lose them for good.
                            if delta_content:
                                yield _sse_event(_chat_chunk(model, {"content": delta_content}))

                    yield _sse_event(_chat_chunk(model, {}, "stop"))
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    yield _sse_event({"error": {"message": str(e), "type": "server_error"}})

            return StreamingResponse(event_generator(), media_type="text/event-stream")

        # Non-streaming: run the blocking SDK call off the event loop.
        response = await run_in_threadpool(Generation.call, stream=False, **call_kwargs)
        if response.status_code != 200:
            raise HTTPException(
                status_code=500,
                detail=f"API Error: {response.code} - {response.message}"
            )

        content = _extract_generation_text(response)
        if not content:
            raise HTTPException(
                status_code=500,
                detail="API Response does not contain expected content"
            )

        chat_response = format_api_response(
            content=content,
            conversation_id=body.get('conversationId'),
            model=model
        )
        usage = getattr(response, 'usage', None)
        if usage:
            chat_response["usage"] = {
                "promptTokens": usage.input_tokens,
                "completionTokens": usage.output_tokens,
                "totalTokens": usage.total_tokens
            }
        # Starlette's JSONResponse already renders with ensure_ascii=False and has no
        # such constructor parameter; passing it raised TypeError on every success.
        return JSONResponse(content=chat_response)

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) instead of turning them into 500s.
        raise
    except Exception as e:
        print(f"[ERROR] Error in chat endpoint: {str(e)}")
        print(f"[ERROR] Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))


def _to_vl_model(model) -> str:
    """Pick the vision-language model used for a request that contains images.

    Models that are already VL models (a ``vl`` dash-separated segment, e.g.
    ``qwen-vl-max``) are used as-is. The previous ``str.replace`` turned them into the
    invalid ``qwen-vl-vl-max``. Other ``qwen-*`` ids get ``qwen-`` rewritten to
    ``qwen-vl-``, and anything else falls back to ``qwen-vl-max``.
    """
    if not isinstance(model, str):
        return 'qwen-vl-max'
    if 'vl' in model.split('-'):
        return model
    if 'qwen-' in model:
        return model.replace('qwen-', 'qwen-vl-')
    return 'qwen-vl-max'


def _to_dashscope_image(img_url):
    """Translate a client-supplied image reference into a DashScope ``image`` value.

    * Empty references yield None, and the image part is skipped.
    * ``data:`` URIs are passed through unchanged.
    * ``http(s)`` URLs whose path contains an ``uploads`` segment (files served by this
      app) are mapped to ``file://uploads/...``. Other remote URLs are passed through
      for the API to fetch, instead of being turned into nonexistent local paths.
    * ``file://`` references and bare paths are treated as local files.

    Local ``file://`` references are read from this server's filesystem and sent
    upstream. They are therefore only accepted when they resolve inside ``upload_dir``;
    otherwise a client could submit arbitrary server files such as ``/etc/passwd``.

    Raises:
        HTTPException: 400 if a local path resolves outside ``upload_dir``.
    """
    if not img_url:
        return None
    if img_url.startswith('data:'):
        return img_url

    if img_url.startswith(('http://', 'https://')):
        # e.g. http://localhost:8000/uploads/filename.jpg -> uploads/filename.jpg
        path_parts = urlparse(img_url).path.split('/')
        if 'uploads' not in path_parts:
            return img_url
        local_path = '/'.join(path_parts[path_parts.index('uploads'):])
    elif img_url.startswith('file://'):
        local_path = img_url[len('file://'):]
    else:
        local_path = img_url  # assume a relative path

    # Check both the raw and the percent-decoded form, so "%2e%2e" cannot smuggle "..".
    upload_root = upload_dir.resolve()
    for candidate in (local_path, unquote_plus(local_path)):
        if upload_root not in Path(candidate).resolve().parents:
            raise HTTPException(
                status_code=400,
                detail=f"Image path is not an uploaded file: {img_url}"
            )

    if not os.path.exists(local_path):
        print(f"[WARNING] Image file does not exist: {local_path}")
    return f"file://{local_path}"


def _to_dashscope_messages(messages):
    """Convert OpenAI-style chat messages into DashScope MultiModalConversation messages.

    String content becomes ``[{'text': ...}]``. In list content, ``text`` parts become
    ``{'text': ...}`` and ``image_url`` parts (a URL string, or a ``{'url': ...}`` dict)
    become ``{'image': ...}``. Non-dict parts are stringified as text, and other dict
    part types are ignored. Non-dict messages are sent as plain user text.
    """
    converted = []
    for msg in messages:
        if not isinstance(msg, dict):
            converted.append({'role': 'user', 'content': [{'text': str(msg)}]})
            continue

        content = msg.get('content', '')
        if isinstance(content, str):
            parts = [{'text': content}]
        elif isinstance(content, list):
            parts = []
            for item in content:
                if not isinstance(item, dict):
                    parts.append({'text': str(item)})
                elif item.get('type') == 'text':
                    parts.append({'text': item.get('text', '')})
                elif item.get('type') == 'image_url':
                    raw = item.get('image_url')
                    if isinstance(raw, dict):
                        raw = raw.get('url')
                    image = _to_dashscope_image(raw if isinstance(raw, str) else '')
                    if image:
                        parts.append({'image': image})
        else:
            parts = [{'text': str(content)}]

        converted.append({'role': msg.get('role', 'user'), 'content': parts})
    return converted


def _extract_multimodal_text(response):
    """Concatenate the text parts of a MultiModalConversation response, or return None."""
    output = getattr(response, 'output', None)
    choices = getattr(output, 'choices', None) if output else None
    if not choices or 'message' not in choices[0]:
        return None
    message = choices[0]['message']
    if 'content' not in message:
        return None
    items = message['content']
    if isinstance(items, str):
        return items
    if not isinstance(items, list):
        return None
    return ''.join(item['text'] for item in items if isinstance(item, dict) and 'text' in item)


async def multimodal_chat_handler(messages, model, stream, temperature, max_tokens):
    """Handle a chat request that contains images via DashScope ``MultiModalConversation``.

    Args:
        messages: OpenAI-style messages (see ``_to_dashscope_messages``).
        model: Requested model id; mapped to a VL model by ``_to_vl_model``. The
            requested id is what is reported back in stream chunks.
        stream: Whether to return an SSE stream.
        temperature: Sampling temperature passed to the API.
        max_tokens: Generation limit passed to the API.

    Returns:
        A ``StreamingResponse`` of OpenAI-style SSE chunks (terminated by ``[DONE]``) when
        streaming, otherwise a ``JSONResponse`` of ``{"result": <text>}``.

    Raises:
        HTTPException: 400 for image paths outside the upload directory; 500 on API or
        internal errors.
    """
    try:
        call_kwargs = dict(
            model=_to_vl_model(model),
            messages=_to_dashscope_messages(messages),
            max_tokens=max_tokens,
            temperature=temperature,
        )

        if stream:
            # A plain (synchronous) generator: StreamingResponse iterates it in a worker
            # thread, so the blocking DashScope stream does not stall the event loop.
            def multimodal_event_generator():
                def frame(payload):
                    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

                def chunk(delta, finish_reason=None):
                    return frame({
                        "id": f"chatcmpl-{generate_unique_id()}",
                        "object": "chat.completion.chunk",
                        "created": get_current_timestamp(),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": delta,
                                "finish_reason": finish_reason
                            }
                        ]
                    })

                try:
                    full_content = ""  # cumulative text already forwarded to the client
                    for response in MultiModalConversation.call(stream=True, **call_kwargs):
                        if response.status_code != 200:
                            yield frame({
                                "error": {
                                    "message": f"Multimodal API Error: {response.code} - {response.message}",
                                    "type": "api_error",
                                    "param": None,
                                    "code": response.code
                                }
                            })
                            break

                        content = _extract_multimodal_text(response)
                        # Each chunk is treated as the cumulative text so far; send only the new part.
                        if content and len(content) > len(full_content):
                            delta_content = extract_delta_content(content, full_content)
                            full_content = content
                            # Whitespace-only deltas are real output; skipping them would lose them.
                            if delta_content:
                                yield chunk({"content": delta_content})

                    yield chunk({}, "stop")
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    yield frame({"error": {"message": str(e), "type": "server_error"}})

            return StreamingResponse(multimodal_event_generator(), media_type="text/event-stream")

        # Non-streaming: run the blocking SDK call off the event loop.
        response = await run_in_threadpool(MultiModalConversation.call, stream=False, **call_kwargs)
        if response.status_code != 200:
            raise HTTPException(
                status_code=500,
                detail=f"Multimodal API Error: {response.code} - {response.message}"
            )

        content = _extract_multimodal_text(response)
        if not content:
            raise HTTPException(
                status_code=500,
                detail="Multimodal API Response does not contain expected content"
            )
        # JSONResponse has no ensure_ascii parameter (passing it raised TypeError);
        # it already renders non-ASCII text unescaped.
        return JSONResponse(content={"result": content})

    except HTTPException:
        raise
    except Exception as e:
        print(f"[ERROR] Error in multimodal chat handler: {str(e)}")
        print(f"[ERROR] Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))


async def get_models_handler():
    """Return the catalogue of selectable models as a list of plain dicts.

    Every entry is a ``ModelInfo`` serialised with ``.dict()``. All models share an
    8192-token limit and the ``"Aliyun"`` provider.
    """
    # (id, display name, description)
    catalogue = (
        ("qwen-max", "通义千问 Max", "最强大的模型"),
        ("qwen-plus", "通义千问 Plus", "能力均衡"),
        ("qwen-turbo", "通义千问 Turbo", "速度更快、成本更低"),
        ("qwen-vl-max", "通义万相 VL-Max", "支持视觉理解的多模态模型"),
        ("qwen-vl-plus", "通义万相 VL-Plus", "支持视觉理解的多模态模型"),
    )
    return [
        ModelInfo(
            id=model_id,
            name=display_name,
            description=summary,
            maxTokens=8192,
            provider="Aliyun",
        ).dict()
        for model_id, display_name, summary in catalogue
    ]


async def get_conversations_handler():
    """Return every stored conversation, in insertion order, as a new list."""
    return [*conversations_db.values()]


async def get_conversation_handler(conversation_id: str):
    """Return the conversation stored under *conversation_id*.

    Raises:
        HTTPException: 404 if the id is unknown (or its stored value is falsy).
    """
    found = conversations_db.get(conversation_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="对话不存在")


async def save_conversation_handler(data: dict):
    """Create or update a conversation in the in-memory store.

    Args:
        data: Conversation payload. Recognised keys:
            ``id``: a new id is generated when missing or empty.
            ``title``: defaults to ``'新对话'``.
            ``messages``: defaults to ``[]``.
            ``createdAt``: if absent, the stored creation time is kept on updates,
            or "now" is used for new conversations.

    Returns:
        The stored conversation dict. ``updatedAt`` is a naive UTC ISO-8601 timestamp.

    Raises:
        HTTPException: 500 if building or storing the conversation fails.
    """
    try:
        conversation_id = data.get('id') or generate_unique_id()
        # One clock read, so a new conversation gets identical createdAt/updatedAt values.
        now = datetime.utcnow().isoformat()
        previous = conversations_db.get(conversation_id) or {}

        conversation = {
            "id": conversation_id,
            "title": data.get('title', '新对话'),
            "messages": data.get('messages', []),
            "updatedAt": now,
            # Updating without createdAt used to reset it to "now"; keep the stored value.
            "createdAt": data.get('createdAt') or previous.get('createdAt') or now
        }

        conversations_db[conversation_id] = conversation
        return conversation

    except Exception as e:
        print(f"[ERROR] Error saving conversation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


async def delete_conversation_handler(conversation_id: str):
    """Remove the conversation stored under *conversation_id*.

    Returns:
        ``{"success": True, "message": "删除成功"}`` on success.

    Raises:
        HTTPException: 404 if no conversation has that id.
    """
    if conversation_id not in conversations_db:
        raise HTTPException(status_code=404, detail="对话不存在")
    conversations_db.pop(conversation_id)
    return {"success": True, "message": "删除成功"}


async def upload_file_handler(file: UploadFile = File(...)):
    """Store an uploaded file under ``upload_dir`` and describe where it is served.

    The stored name is ``<unix-seconds>_<unique-id><ext>``. The extension is derived
    from the validated content type rather than the client's filename, so a file
    declared as ``image/png`` can never be stored (and later served) as e.g. ``.html``.

    Returns:
        ``{"url", "name", "size", "mimeType"}`` describing the stored file.

    Raises:
        HTTPException: 400 for an unsupported content type; 500 if saving fails.
    """
    # Accepted MIME types mapped to the extension used for the stored file.
    extension_by_type = {
        'image/jpeg': '.jpg',
        'image/png': '.png',
        'image/gif': '.gif',
        'image/webp': '.webp',
        'text/plain': '.txt',
        'application/pdf': '.pdf',
    }
    try:
        file_extension = extension_by_type.get(file.content_type)
        if file_extension is None:
            raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.content_type}")

        # Read before creating the target, so a failed read leaves no empty file behind.
        content = await file.read()

        # datetime.now().timestamp() is the true epoch; utcnow().timestamp() treats the
        # naive UTC time as local time and is skewed by the UTC offset.
        unique_filename = f"{int(datetime.now().timestamp())}_{generate_unique_id()}{file_extension}"
        file_path = upload_dir / unique_filename
        file_path.write_bytes(content)

        file_url = f"http://localhost:8000/uploads/{unique_filename}"
        result = {
            "url": file_url,
            "name": file.filename,
            "size": len(content),
            "mimeType": file.content_type
        }

        print(f"[INFO] File uploaded: {result}")
        return result

    except HTTPException:
        # Keep the 400 above instead of rewrapping it as a 500 upload failure.
        raise
    except Exception as e:
        print(f"[ERROR] Upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")


def serve_upload_handler(filename: str):
    """Serve a previously uploaded file from ``upload_dir``.

    *filename* comes from the request, so it is resolved and must land strictly inside
    ``upload_dir``. Without that check, ``..`` segments or an absolute path (pathlib's
    ``/`` discards the left side for absolute operands) could read any file on the server.

    Raises:
        HTTPException: 404 if the name escapes ``upload_dir`` or is not a regular file.
    """
    upload_root = upload_dir.resolve()
    file_path = (upload_root / filename).resolve()
    if upload_root not in file_path.parents or not file_path.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(str(file_path))


async def stop_generation_handler(message_id: str = None):
    """Acknowledge a stop request, mentioning *message_id* when a non-empty one is given.

    NOTE(review): nothing here interrupts an in-flight generation; it only returns a
    confirmation payload.
    """
    if not message_id:
        return {"success": True, "message": "已发出停止指令"}
    return {"success": True, "message": f"已发出停止指令,消息ID: {message_id}"}