feat: 图片识别mvp

This commit is contained in:
肖应宇 2026-03-03 16:32:08 +08:00
parent c32b50584d
commit 467f38645d
1 changed file with 372 additions and 98 deletions

View File

@ -10,7 +10,7 @@ from pathlib import Path
from fastapi import HTTPException, File, UploadFile from fastapi import HTTPException, File, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
import dashscope import dashscope
from dashscope import Generation from dashscope import Generation, MultiModalConversation
# 导入模型和工具函数(使用绝对路径) # 导入模型和工具函数(使用绝对路径)
import sys import sys
@ -40,6 +40,14 @@ async def chat_endpoint_handler(body: dict):
这个端点会接收前端的聊天请求并转发到阿里云百炼API 这个端点会接收前端的聊天请求并转发到阿里云百炼API
""" """
try: try:
# 确保 body 是字典类型
if not isinstance(body, dict):
print(f"[ERROR] Request body is not a dictionary: {type(body)}")
raise HTTPException(
status_code=400,
detail=f"Request body must be a JSON object, got {type(body).__name__}: {body}"
)
# 检查请求格式并适配 # 检查请求格式并适配
# 如果是OpenAI兼容格式 (来自streamChat) # 如果是OpenAI兼容格式 (来自streamChat)
if 'messages' in body: if 'messages' in body:
@ -69,6 +77,19 @@ async def chat_endpoint_handler(body: dict):
temperature = body.get('temperature', 0.7) temperature = body.get('temperature', 0.7)
max_tokens = body.get('maxTokens', 2000) max_tokens = body.get('maxTokens', 2000)
# 检查是否包含图像内容如果是多模态请求使用MultiModalConversation
has_images = any(
isinstance(msg, dict) and
isinstance(msg.get('content'), list) and
any(isinstance(item, dict) and item.get('type') == 'image_url' for item in msg.get('content', []))
for msg in messages if isinstance(msg, dict)
)
if has_images:
# 使用多模态API处理图像
return await multimodal_chat_handler(messages, model, stream, temperature, max_tokens)
else:
# 使用常规聊天API
if stream: if stream:
# 流式响应 # 流式响应
async def event_generator(): async def event_generator():
@ -251,6 +272,245 @@ async def chat_endpoint_handler(body: dict):
except Exception as e: except Exception as e:
print(f"[ERROR] Error in chat endpoint: {str(e)}") print(f"[ERROR] Error in chat endpoint: {str(e)}")
import traceback
print(f"[ERROR] Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))
def _resolve_vl_model(model):
    """Map the requested model name to a vision-capable (VL) model name.

    Names that already denote a VL model are returned unchanged, so that
    'qwen-vl-max' is not rewritten to 'qwen-vl-vl-max'. Other 'qwen-*' names
    get the 'qwen-vl-' prefix; anything else falls back to 'qwen-vl-max'.
    """
    if not isinstance(model, str) or not model:
        return 'qwen-vl-max'
    if '-vl' in model:
        return model
    if 'qwen-' in model:
        return model.replace('qwen-', 'qwen-vl-')
    return 'qwen-vl-max'


def _to_dashscope_image(image_url_value):
    """Convert an OpenAI-style ``image_url`` value into a DashScope image reference.

    Accepts either a plain string or a dict with a 'url' key. Returns the
    reference string, or None when no usable URL can be derived.
    """
    import os
    from urllib.parse import unquote, urlparse

    if isinstance(image_url_value, dict):
        img_url = image_url_value.get('url', '')
    else:
        img_url = image_url_value

    if not isinstance(img_url, str) or not img_url:
        print(f"[WARNING] Skipping image item without a usable URL: {image_url_value!r}")
        return None

    # Inline base64 data URIs are not files on disk; hand them over untouched.
    # NOTE(review): assumes the DashScope SDK accepts data URIs in 'image' - confirm.
    if img_url.startswith('data:'):
        return img_url

    if img_url.startswith(('http://', 'https://')):
        # URLs served by this backend (e.g. http://localhost:8000/uploads/a.jpg)
        # are mapped to the local relative path (uploads/a.jpg). The path is
        # percent-decoded so file names with spaces/non-ASCII characters match.
        segments = unquote(urlparse(img_url).path).split('/')
        if 'uploads' in segments:
            segments = segments[segments.index('uploads'):]
        else:
            # NOTE(review): URLs without an 'uploads' segment (including truly
            # remote ones) are still mapped to a local relative path, as before.
            segments = [part for part in segments if part]
        # The URL comes from the client: never let it climb out of the
        # working directory via '..' segments.
        if '..' in segments:
            print(f"[WARNING] Rejecting image URL with parent-directory segments: {img_url}")
            return None
        img_url = f"file://{'/'.join(segments)}"
    elif not img_url.startswith('file://'):
        # Neither a network URL nor file:// - assume a relative path.
        # NOTE(review): client-supplied file:// and relative paths are passed
        # through unrestricted; consider confining them to the uploads directory.
        img_url = f"file://{img_url}"

    local_path = img_url[len('file://'):]
    if not os.path.exists(local_path):
        print(f"[WARNING] Image file does not exist: {local_path}")
    return img_url


def _to_dashscope_messages(messages):
    """Convert OpenAI-format chat messages to the MultiModalConversation format.

    Each output message has the form ``{'role': ..., 'content': [parts]}``
    where every part is ``{'text': ...}`` or ``{'image': ...}``.
    """
    converted = []
    for msg in messages:
        if not isinstance(msg, dict):
            # Non-dict messages are treated as plain user text.
            converted.append({'role': 'user', 'content': [{'text': str(msg)}]})
            continue

        role = msg.get('role', 'user')
        content = msg.get('content', '')

        if isinstance(content, str):
            parts = [{'text': content}]
        elif isinstance(content, list):
            parts = []
            for item in content:
                if not isinstance(item, dict):
                    parts.append({'text': str(item)})
                    continue
                item_type = item.get('type')
                if item_type == 'text':
                    parts.append({'text': item.get('text', '')})
                elif item_type == 'image_url':
                    image = _to_dashscope_image(item.get('image_url', ''))
                    if image:
                        parts.append({'image': image})
        else:
            parts = [{'text': str(content)}]

        converted.append({'role': role, 'content': parts})
    return converted


def _extract_response_text(response):
    """Return the concatenated text of the first choice, or None if absent."""
    output = getattr(response, 'output', None)
    if not output:
        return None
    choices = getattr(output, 'choices', None)
    if not choices:
        return None
    first_choice = choices[0]
    if 'message' not in first_choice:
        return None
    message = first_choice['message']
    if 'content' not in message:
        return None
    items = message['content']
    if isinstance(items, str):
        return items
    return ''.join(
        item['text']
        for item in (items or [])
        if isinstance(item, dict) and 'text' in item
    )


async def multimodal_chat_handler(messages, model, stream, temperature, max_tokens):
    """Handle a chat request whose messages contain images.

    Converts the OpenAI-style ``messages`` to the DashScope
    MultiModalConversation format, calls the API, and returns either a
    ``StreamingResponse`` of OpenAI-style SSE chunks (``stream`` truthy) or a
    ``JSONResponse`` of the form ``{"result": <text>}``.

    Raises:
        HTTPException: status 500 when the API reports an error, the response
            contains no text, or any unexpected error occurs.
    """
    import traceback

    try:
        dashscope_messages = _to_dashscope_messages(messages)
        vl_model = _resolve_vl_model(model)

        if stream:
            async def multimodal_event_generator():
                try:
                    # NOTE(review): this SDK call and the iteration below are
                    # synchronous inside an async generator; they may block the
                    # event loop while waiting on the API.
                    responses = MultiModalConversation.call(
                        model=vl_model,
                        messages=dashscope_messages,
                        stream=True,
                        max_tokens=max_tokens,
                        temperature=temperature
                    )

                    # The code treats each chunk as carrying the full text so
                    # far and emits only the newly added part.
                    full_content = ""
                    for response in responses:
                        if response.status_code != 200:
                            error_data = {
                                "error": {
                                    "message": f"Multimodal API Error: {response.code} - {response.message}",
                                    "type": "api_error",
                                    "param": None,
                                    "code": response.code
                                }
                            }
                            yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
                            break

                        content = _extract_response_text(response)
                        # Chunks without text (content is None) are skipped
                        # instead of crashing on len(None).
                        if content is None or len(content) <= len(full_content):
                            continue

                        delta_content = extract_delta_content(content, full_content)
                        full_content = content
                        # Whitespace-only deltas (e.g. newlines) are kept:
                        # dropping them would lose them for good, because
                        # full_content has already advanced.
                        if delta_content:
                            data = {
                                "id": f"chatcmpl-{generate_unique_id()}",
                                "object": "chat.completion.chunk",
                                "created": get_current_timestamp(),
                                "model": model,
                                "choices": [
                                    {
                                        "index": 0,
                                        "delta": {"content": delta_content},
                                        "finish_reason": None
                                    }
                                ]
                            }
                            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

                    finish_data = {
                        "id": f"chatcmpl-{generate_unique_id()}",
                        "object": "chat.completion.chunk",
                        "created": get_current_timestamp(),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {},
                                "finish_reason": "stop"
                            }
                        ]
                    }
                    yield f"data: {json.dumps(finish_data, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    # The response has already started, so the error is
                    # reported in-band as an SSE event.
                    error_data = {
                        "error": {
                            "message": str(e),
                            "type": "server_error"
                        }
                    }
                    yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"

            return StreamingResponse(multimodal_event_generator(), media_type="text/event-stream")

        # Non-streaming multimodal request.
        response = MultiModalConversation.call(
            model=vl_model,
            messages=dashscope_messages,
            stream=False,
            max_tokens=max_tokens,
            temperature=temperature
        )

        if response.status_code != 200:
            raise HTTPException(
                status_code=500,
                detail=f"Multimodal API Error: {response.code} - {response.message}"
            )

        content = _extract_response_text(response)
        if not content:
            raise HTTPException(
                status_code=500,
                detail="Multimodal API Response does not contain expected content"
            )

        # JSONResponse takes no `ensure_ascii` argument; passing it raised a
        # TypeError on every successful call.
        return JSONResponse(content={"result": content})

    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it.
        raise
    except Exception as e:
        print(f"[ERROR] Error in multimodal chat handler: {str(e)}")
        print(f"[ERROR] Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@ -277,6 +537,20 @@ async def get_models_handler():
description="速度更快、成本更低", description="速度更快、成本更低",
maxTokens=8192, maxTokens=8192,
provider="Aliyun" provider="Aliyun"
),
ModelInfo(
id="qwen-vl-max",
name="通义万相 VL-Max",
description="支持视觉理解的多模态模型",
maxTokens=8192,
provider="Aliyun"
),
ModelInfo(
id="qwen-vl-plus",
name="通义万相 VL-Plus",
description="支持视觉理解的多模态模型",
maxTokens=8192,
provider="Aliyun"
) )
] ]
return [model.dict() for model in models] return [model.dict() for model in models]