"""
|
||
API 路由定义(阿里云百炼 / DashScope 平台)
|
||
所有 DashScope 相关逻辑均集中在此文件,main.py 无感知任何平台细节。
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import uuid
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Dict, List
|
||
|
||
import dashscope
|
||
from dashscope import Generation, MultiModalConversation
|
||
from fastapi import File, HTTPException, UploadFile
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
|
||
|
||
def init():
    """
    Initialize the Bailian backend: set the DashScope API key.

    Called by main.py at startup (when LLM_BACKEND=dashscope).
    """
    api_key = os.getenv("ALIYUN_API_KEY")
    if not api_key:
        raise ValueError(
            "The dashscope backend requires the ALIYUN_API_KEY environment variable"
        )
    dashscope.api_key = api_key
    print("[DashScope] initialized; ALIYUN_API_KEY configured")


# Import models and utility functions (via an absolute path).
# Path is already imported at the top of this module.
import sys

sys.path.append(str(Path(__file__).parent.parent))

from models.chat_models import ChatRequest, ModelInfo
from utils.helpers import (
    extract_delta_content,
    format_api_response,
    generate_unique_id,
    get_current_timestamp,
)
from utils.logger import log_error, log_exception, log_info

# Simulated database - a real application should use persistent storage
conversations_db: Dict[str, dict] = {}

# Configure the upload directory
upload_dir = Path("uploads")
upload_dir.mkdir(exist_ok=True)


def _extract_text_from_docmind(obj, depth: int = 0) -> str:
    """
    Recursively extract readable text from a DocMind JSON structure.

    The `text` field returned by DashScopeParse is a JSON string whose
    contents are the tree structure produced by document intelligence parsing.
    """
    if depth > 15:
        return ""

    if isinstance(obj, str):
        s = obj.strip()
        # Filter out very short strings, URLs, base64, and other non-body text
        if len(s) > 3 and not s.startswith(("http://", "https://", "data:", "oss://")):
            return s
        return ""

    if isinstance(obj, list):
        parts = [_extract_text_from_docmind(item, depth + 1) for item in obj]
        return "\n".join(p for p in parts if p)

    if isinstance(obj, dict):
        # Handle text-bearing fields first
        priority_keys = ["content", "text", "paragraph", "caption", "value", "title"]
        # Skip pure metadata fields
        skip_keys = {
            "backlink",
            "pos",
            "index",
            "style",
            "font",
            "color",
            "size",
            "hash",
            "id_",
            "id",
            "layouts",
            "type",
            "link",
        }
        parts = []
        for key in priority_keys:
            if key in obj:
                t = _extract_text_from_docmind(obj[key], depth + 1)
                if t:
                    parts.append(t)
        for key, val in obj.items():
            if key not in priority_keys and key not in skip_keys:
                t = _extract_text_from_docmind(val, depth + 1)
                if t:
                    parts.append(t)
        return "\n".join(parts)

    return ""


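# Illustrative sketch of the traversal (the input below is an assumption for
# demonstration, not a verbatim DocMind payload): priority keys are emitted
# first, metadata keys ("type", "pos", ...) are dropped, and the surviving
# strings are joined with "\n".
#
#   sample = {
#       "title": "Q3 Report",
#       "children": [
#           {"type": "paragraph", "text": "Revenue grew 12%.", "pos": [0, 0]}
#       ],
#   }
#   assert _extract_text_from_docmind(sample) == "Q3 Report\nRevenue grew 12%."

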
def _read_file_content(file_url: str):
    """
    [Route 1: local text extraction] For plain-text formats, read the content
    directly so it can be injected into the message.
    [Route 2: DashScopeParse] For doc/docx/pdf files, return (local_path,
    suffix) for the async caller.

    Returns:
        - str: text content (route 1 succeeded)
        - tuple(Path, str): (local path, extension); the caller must invoke
          DashScopeParse asynchronously (route 2)
        - None: unsupported file type
    """
    try:
        from urllib.parse import urlparse

        parsed = urlparse(file_url)
        relative_path = parsed.path.lstrip("/")
        local_path = Path(relative_path)

        if not local_path.exists():
            return f"[File not found: {local_path}]"

        suffix = local_path.suffix.lower()

        # Route 1: read plain-text formats directly
        text_extensions = {
            ".txt",
            ".md",
            ".csv",
            ".json",
            ".xml",
            ".yaml",
            ".yml",
            ".log",
            ".py",
            ".js",
            ".ts",
            ".html",
            ".css",
        }
        if suffix in text_extensions:
            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
            max_len = 8000
            if len(content) > max_len:
                content = (
                    content[:max_len]
                    + f"\n\n[...content too long, truncated; {len(content)} characters total]"
                )
            return content

        # Route 2: parse doc/docx/pdf in the cloud via DashScopeParse
        dashscope_extensions = {".doc", ".docx", ".pdf"}
        if suffix in dashscope_extensions:
            return (local_path, suffix)  # handed off to the async function

        # Other formats (xlsx, pptx, etc.) do not support content reading yet
        return None

    except Exception as e:
        print(f"[WARNING] Failed to read file content: {e}")
        return f"[File read failed: {str(e)}]"


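# Path-mapping note (the URL below is an illustrative assumption; uploads are
# served from the local "uploads/" directory by serve_upload_handler further
# down in this file):
#
#   _read_file_content("http://localhost:8000/uploads/notes.md")
#   # -> urlparse(...).path == "/uploads/notes.md" -> Path("uploads/notes.md"),
#   #    which is read from the local disk; no HTTP request is made.

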
async def _parse_with_dashscope(local_path: Path) -> str:
    """
    [Route 2: DashScopeParse] Parse doc/docx/pdf files with Alibaba Cloud
    document intelligence.
    Runs in a thread pool (to avoid blocking the FastAPI event loop).
    Only .doc/.docx/.pdf are supported; file size <= 100 MB, <= 1000 pages.
    """
    import asyncio

    def _sync_parse():
        try:
            from llama_index.readers.dashscope.base import DashScopeParse
            from llama_index.readers.dashscope.utils import ResultType

            api_key = os.getenv("ALIYUN_API_KEY")
            parser = DashScopeParse(
                result_type=ResultType.DASHSCOPE_DOCMIND,
                api_key=api_key,
                num_workers=1,
            )
            print(f"[INFO] DashScopeParse: parsing {local_path.name} ...")
            documents = parser.load_data(file_path=[str(local_path)])

            if not documents:
                return f"[DashScopeParse: empty parse result for {local_path.name}]"

            texts = []
            for doc in documents:
                try:
                    content = json.loads(doc.text)
                    extracted = _extract_text_from_docmind(content)
                    texts.append(extracted if extracted else doc.text[:6000])
                except Exception:
                    texts.append(doc.text[:6000] if doc.text else "")

            result = "\n\n".join(t for t in texts if t)
            print(
                f"[INFO] DashScopeParse: finished {local_path.name}; extracted {len(result)} characters"
            )
            return result or f"[DashScopeParse: no text extracted from {local_path.name}]"

        except ImportError:
            return "[Error: DashScopeParse is not installed; run: pip install llama-index-core llama-index-readers-dashscope]"
        except Exception as e:
            print(f"[ERROR] DashScopeParse failed: {e}")
            return f"[DashScopeParse failed: {str(e)}]"

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _sync_parse)


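# Usage sketch (as called from _inject_files_into_messages below): the await
# yields the event loop while the blocking SDK call runs on a thread from the
# default executor.
#
#   parsed_text = await _parse_with_dashscope(Path("uploads/report.pdf"))  # hypothetical path

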
async def _inject_files_into_messages(messages: list, files: list) -> list:
    """
    Asynchronously inject file contents into the message list.

    - Text-like files (route 1): read the content and append it to the last
      user message
    - doc/docx/pdf (route 2): parse in the cloud via DashScopeParse, then inject
    - Other binary files: only tell the AI the file name and type
    """
    if not files:
        return messages

    file_context_parts = []
    for file_url in files:
        from urllib.parse import urlparse

        parsed = urlparse(file_url)
        filename = parsed.path.split("/")[-1]
        suffix = Path(filename).suffix.lower()

        result = _read_file_content(file_url)

        if isinstance(result, str):
            # Route 1: text content, embed it directly
            file_context_parts.append(
                f"--- Attached file content ({filename}) ---\n{result}\n--- End of attachment ---"
            )
        elif isinstance(result, tuple):
            # Route 2: doc/docx/pdf -> call DashScopeParse
            local_path, _ = result
            print(f"[INFO] Route 2: parsing {filename} with DashScopeParse")
            parsed_text = await _parse_with_dashscope(local_path)
            file_context_parts.append(
                f"--- Attached file content ({filename}, parsed by Alibaba Cloud document intelligence) ---\n"
                f"{parsed_text}\n--- End of attachment ---"
            )
        else:
            # Unsupported format: only report the file info
            file_context_parts.append(
                f"[The user uploaded a file: {filename}, type: {suffix}; automatic content reading is not supported, please inform the user.]"
            )

    if not file_context_parts:
        return messages

    file_context_text = "\n\n" + "\n\n".join(file_context_parts)

    # Append the file content to the last user message
    messages = list(messages)  # copy, to avoid mutating the original list
    for i in range(len(messages) - 1, -1, -1):
        msg = messages[i]
        if isinstance(msg, dict) and msg.get("role") == "user":
            content = msg.get("content", "")
            if isinstance(content, str):
                messages[i] = dict(msg, content=content + file_context_text)
            elif isinstance(content, list):
                # Find the existing text item and append to it
                new_content = list(content)
                appended = False
                for j, item in enumerate(new_content):
                    if isinstance(item, dict) and item.get("type") == "text":
                        new_content[j] = dict(
                            item, text=item["text"] + file_context_text
                        )
                        appended = True
                        break
                if not appended:
                    new_content.append({"type": "text", "text": file_context_text})
                messages[i] = dict(msg, content=new_content)
            break

    return messages


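# Before/after sketch (hypothetical message list, assuming a .txt attachment
# that route 1 reads as "hello"):
#
#   msgs = [{"role": "user", "content": "Summarize the file."}]
#   msgs = await _inject_files_into_messages(msgs, ["http://host/uploads/a.txt"])
#   # msgs[0]["content"] is now:
#   # "Summarize the file.\n\n--- Attached file content (a.txt) ---\nhello\n--- End of attachment ---"

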
async def chat_endpoint_handler(body: dict):
    """
    Chat endpoint handler - an interface compatible with the Alibaba Cloud
    Bailian API.
    This endpoint receives chat requests from the frontend and forwards them
    to the Bailian API.
    """
    try:
        # Make sure body is a dictionary
        if not isinstance(body, dict):
            print(f"[ERROR] Request body is not a dictionary: {type(body)}")
            raise HTTPException(
                status_code=400,
                detail=f"Request body must be a JSON object, got {type(body).__name__}: {body}",
            )

        # Detect the request format and adapt.
        # OpenAI-compatible format (from streamChat)
        if "messages" in body:
            messages = body.get("messages", [])
            model = body.get("model", "qwen-plus")
            stream = body.get("stream", True)
            temperature = body.get("temperature", 0.7)
            max_tokens = body.get("max_tokens", 2000)
            deepSearch = body.get("deepSearch", False)
            webSearch = body.get("webSearch", False)
            deepThinking = body.get("deepThinking", False)

            log_info(
                f"POST /api/chat-ui/chat | model: {model} | stream: {stream} | web search: {webSearch} | deep search: {deepSearch} | deep thinking: {deepThinking}"
            )

            # Handle file attachments: inject file contents into the last user message
            files = body.get("files", [])
            if files:
                messages = await _inject_files_into_messages(messages, files)
                # Debug: print the last user message after injection (truncated to 500 chars)
                for msg in reversed(messages):
                    if isinstance(msg, dict) and msg.get("role") == "user":
                        content_preview = str(msg.get("content", ""))[:500]
                        print(
                            f"[DEBUG] user message preview after file injection: {content_preview}"
                        )
                        break
        else:
            # Otherwise it is the simplified frontend format (from the chat function)
            message_text = body.get("message", "")

            # Check whether message is already a formatted list (the with-images case)
            if isinstance(message_text, list):
                user_content = message_text
            else:
                user_content = [{"type": "text", "text": message_text}]

            messages = [
                {
                    "role": "system",
                    "content": body.get(
                        "systemPrompt",
                        "You are an intelligent assistant that can analyze text and file contents sent by the user.",
                    ),
                },
                {"role": "user", "content": user_content},
            ]
            model = body.get("model", "qwen-plus")
            stream = body.get("stream", False)
            temperature = body.get("temperature", 0.7)
            max_tokens = body.get("maxTokens", 2000)
            deepSearch = body.get("deepSearch", False)
            webSearch = body.get("webSearch", False)
            deepThinking = body.get("deepThinking", False)

        # Check for image content; if this is a multimodal request, use
        # MultiModalConversation
        has_images = any(
            isinstance(msg.get("content"), list)
            and any(
                isinstance(item, dict) and item.get("type") == "image_url"
                for item in msg.get("content", [])
            )
            for msg in messages
            if isinstance(msg, dict)
        )

        if has_images:
            # Use the multimodal API to handle images
            return await multimodal_chat_handler(
                messages, model, stream, temperature, max_tokens
            )
        else:
            # Build extra DashScope parameters
            dashscope_kwargs = {}
            if deepSearch:
                dashscope_kwargs["enable_search"] = True
                dashscope_kwargs["search_options"] = {"search_strategy": "max"}
                # Only certain thinking-model versions support some advanced
                # agents; for now we stick with the base model + the "max" strategy
            elif webSearch:
                dashscope_kwargs["enable_search"] = True
                dashscope_kwargs["search_options"] = {"search_strategy": "turbo"}

            if deepThinking:
                dashscope_kwargs["enable_thinking"] = True
                # enable_thinking must be combined with result_format="message"
                dashscope_kwargs["result_format"] = "message"
                # In streaming mode enable_thinking also requires incremental_output=True
                dashscope_kwargs["incremental_output"] = True

            # Use the regular chat API
            if stream:
                # Streaming response
                async def event_generator():
                    try:
                        responses = Generation.call(
                            model=model,
                            messages=messages,
                            stream=True,
                            max_tokens=max_tokens,
                            temperature=temperature,
                            **dashscope_kwargs,
                        )

                        full_content = ""  # accumulates the full reply
                        full_reasoning_content = ""  # accumulates the full reasoning text

                        for response in responses:
                            if response.status_code == 200:
                                content = None

                                # Try to get content from output.choices
                                if (
                                    hasattr(response, "output")
                                    and response.output
                                    and hasattr(response.output, "choices")
                                    and response.output.choices is not None
                                    and len(response.output.choices) > 0
                                    and "message" in response.output.choices[0]
                                ):
                                    msg_dict = response.output.choices[0]["message"]
                                    # With incremental_output=True, each chunk's
                                    # content/reasoning_content is already an
                                    # incremental fragment; use it directly, no
                                    # diffing against full_* needed
                                    content = msg_dict.get("content") or ""
                                    reasoning_content = (
                                        msg_dict.get("reasoning_content") or ""
                                    )

                                    delta_str = ""

                                    # Handle reasoning fragments
                                    if reasoning_content:
                                        if not full_reasoning_content:
                                            # First reasoning fragment: open the <think> tag
                                            delta_str += "<think>"
                                        full_reasoning_content += reasoning_content
                                        delta_str += reasoning_content

                                    # Handle reply fragments
                                    if content:
                                        if not full_content and full_reasoning_content:
                                            # First reply fragment after reasoning:
                                            # close the </think> tag
                                            delta_str += "</think>\n\n"
                                        full_content += content
                                        delta_str += content

                                    if delta_str:
                                        data = {
                                            "id": f"chatcmpl-{generate_unique_id()}",
                                            "object": "chat.completion.chunk",
                                            "created": get_current_timestamp(),
                                            "model": model,
                                            "choices": [
                                                {
                                                    "index": 0,
                                                    "delta": {"content": delta_str},
                                                    "finish_reason": None,
                                                }
                                            ],
                                        }

                                        yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                                # Otherwise try output.text (DashScope-specific format)
                                elif (
                                    hasattr(response, "output")
                                    and response.output
                                    and "text" in response.output
                                ):
                                    content = response.output.get("text")

                                    # Only send a delta when the content has grown
                                    if len(content) > len(full_content):
                                        delta_content = extract_delta_content(
                                            content, full_content
                                        )
                                        full_content = content

                                        # Only send when there is non-whitespace new content
                                        if delta_content.strip():
                                            # Build the SSE data chunk
                                            data = {
                                                "id": f"chatcmpl-{generate_unique_id()}",
                                                "object": "chat.completion.chunk",
                                                "created": get_current_timestamp(),
                                                "model": model,
                                                "choices": [
                                                    {
                                                        "index": 0,
                                                        "delta": {
                                                            "content": delta_content
                                                        },
                                                        "finish_reason": None,
                                                    }
                                                ],
                                            }

                                            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                            else:
                                # Error handling: write to the logger for troubleshooting
                                log_error(
                                    f"DashScope API error: chunk status={response.status_code}, code={response.code}, msg={response.message}"
                                )
                                error_data = {
                                    "error": {
                                        "message": f"API Error: {response.code} - {response.message}",
                                        "type": "api_error",
                                        "param": None,
                                        "code": response.code,
                                    }
                                }
                                yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
                                break

                        # Send the finish signal
                        finish_data = {
                            "id": f"chatcmpl-{generate_unique_id()}",
                            "object": "chat.completion.chunk",
                            "created": get_current_timestamp(),
                            "model": model,
                            "choices": [
                                {"index": 0, "delta": {}, "finish_reason": "stop"}
                            ],
                        }
                        yield f"data: {json.dumps(finish_data, ensure_ascii=False)}\n\n"
                        yield "data: [DONE]\n\n"
                    except Exception as e:
                        log_exception(f"Streaming generator exception: {e}")
                        error_data = {
                            "error": {"message": str(e), "type": "server_error"}
                        }
                        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"

                return StreamingResponse(
                    event_generator(), media_type="text/event-stream"
                )
            else:
                # Non-streaming response
                response = Generation.call(
                    model=model,
                    messages=messages,
                    stream=False,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    **dashscope_kwargs,
                )

                if response.status_code == 200:
                    # Check whether the response contains the expected content.
                    # The DashScope response structure may be output.choices or output.text
                    content = None

                    # Try to get content from output.choices
                    if (
                        hasattr(response, "output")
                        and response.output
                        and hasattr(response.output, "choices")
                        and response.output.choices is not None
                        and len(response.output.choices) > 0
                        and "message" in response.output.choices[0]
                    ):
                        msg_dict = response.output.choices[0]["message"]
                        content = msg_dict.get("content", "")
                        rc = msg_dict.get("reasoning_content", "")
                        if rc:
                            content = f"<think>{rc}</think>\n\n{content}"
                    # Otherwise try output.text (DashScope-specific format)
                    elif (
                        hasattr(response, "output")
                        and response.output
                        and "text" in response.output
                    ):
                        content = response.output.get("text")

                    if content:
                        # Build the response format the frontend expects
                        chat_response = format_api_response(
                            content=content,
                            conversation_id=body.get("conversationId"),
                            model=model,
                        )

                        if hasattr(response, "usage") and response.usage:
                            chat_response["usage"] = {
                                "promptTokens": response.usage.input_tokens,
                                "completionTokens": response.usage.output_tokens,
                                "totalTokens": response.usage.total_tokens,
                            }

                        # Note: JSONResponse takes no ensure_ascii kwarg;
                        # starlette already serializes with ensure_ascii=False
                        return JSONResponse(content=chat_response)
                    else:
                        raise HTTPException(
                            status_code=500,
                            detail="API Response does not contain expected content",
                        )
                else:
                    raise HTTPException(
                        status_code=500,
                        detail=f"API Error: {response.code} - {response.message}",
                    )

    except HTTPException:
        # Re-raise as-is so e.g. the 400 above is not rewrapped as a 500
        raise
    except Exception as e:
        print(f"[ERROR] Error in chat endpoint: {str(e)}")
        import traceback

        print(f"[ERROR] Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))


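# Wire-format sketch (hypothetical values) of what the streaming branch emits:
# one OpenAI-style chunk per SSE event, then the [DONE] sentinel.
#
#   data: {"id": "chatcmpl-abc123", "object": "chat.completion.chunk",
#          "created": 1700000000, "model": "qwen-plus",
#          "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]}
#
#   data: [DONE]

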
async def multimodal_chat_handler(messages, model, stream, temperature, max_tokens):
    """
    Multimodal chat handler - handles messages that contain images.
    """
    try:
        # Convert OpenAI-format messages to the DashScope MultiModalConversation format
        dashscope_messages = []
        for msg in messages:
            # Verify that msg is a dict; if not, treat it as plain text
            if not isinstance(msg, dict):
                dashscope_content = [{"text": str(msg)}]
                dashscope_messages.append(
                    {"role": "user", "content": dashscope_content}
                )
                continue

            role = msg.get("role", "user")
            content = msg.get("content", "")

            if isinstance(content, str):
                # Plain-text content
                dashscope_content = [{"text": content}]
            elif isinstance(content, list):
                # Content containing images and text
                dashscope_content = []
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "text":
                            dashscope_content.append({"text": item.get("text", "")})
                        elif item.get("type") == "image_url":
                            # image_url may be either a string or a dict
                            image_url_value = item.get("image_url", "")
                            if isinstance(image_url_value, str):
                                # image_url is a string: use it directly
                                img_url = image_url_value
                            elif (
                                isinstance(image_url_value, dict)
                                and "url" in image_url_value
                            ):
                                # image_url is a dict: take its url field
                                img_url = image_url_value.get("url", "")
                            else:
                                # Anything else is treated as an error / empty value
                                img_url = ""

                            # If the URL is http(s), extract the file name and
                            # convert to file:// form
                            if img_url.startswith("http://") or img_url.startswith(
                                "https://"
                            ):
                                # Extract the path part of the URL (e.g. take
                                # uploads/filename.jpg out of
                                # http://localhost:8000/uploads/filename.jpg)
                                from urllib.parse import urlparse

                                parsed_url = urlparse(img_url)
                                path_parts = parsed_url.path.split("/")

                                # Find the "uploads" segment and everything after it
                                try:
                                    uploads_index = path_parts.index("uploads")
                                    filename = "/".join(
                                        path_parts[uploads_index:]
                                    )  # e.g. uploads/filename.jpg
                                    img_url = f"file://{filename}"
                                except ValueError:
                                    # No "uploads" segment: use the original path
                                    img_url = f"file://{parsed_url.path.lstrip('/')}"
                            elif not img_url.startswith("file://"):
                                # Neither a web URL nor file://: assume a relative path
                                img_url = f"file://{img_url}"

                            if img_url.startswith("file://"):
                                # Make sure the local file exists
                                local_path = img_url[7:]  # strip the "file://" prefix
                                if not os.path.exists(local_path):
                                    print(
                                        f"[WARNING] Image file does not exist: {local_path}"
                                    )

                            dashscope_content.append({"image": img_url})
                    else:
                        # Convert non-dict content to text
                        dashscope_content.append({"text": str(item)})
            else:
                # Convert anything else to text
                dashscope_content = [{"text": str(content)}]

            dashscope_messages.append({"role": role, "content": dashscope_content})

        # Map the chat model to its vision-language counterpart. A bare
        # model.replace("qwen-", "qwen-vl-") would turn "qwen-vl-max" into
        # "qwen-vl-vl-max", so already-VL models pass through unchanged.
        if model.startswith("qwen-vl"):
            vl_model = model
        elif model.startswith("qwen-"):
            vl_model = model.replace("qwen-", "qwen-vl-", 1)
        else:
            vl_model = "qwen-vl-max"

        if stream:
            # Multimodal streaming response
            async def multimodal_event_generator():
                try:
                    responses = MultiModalConversation.call(
                        model=vl_model,
                        messages=dashscope_messages,
                        stream=True,
                        max_tokens=max_tokens,
                        temperature=temperature,
                    )

                    full_content = ""

                    for response in responses:
                        if response.status_code == 200:
                            content = None

                            # Extract content from the multimodal response
                            if (
                                hasattr(response, "output")
                                and response.output
                                and hasattr(response.output, "choices")
                                and response.output.choices is not None
                                and len(response.output.choices) > 0
                                and "message" in response.output.choices[0]
                            ):
                                message = response.output.choices[0]["message"]
                                if "content" in message:
                                    content_items = message["content"]

                                    # Extract text from the content items
                                    extracted_text = ""
                                    for item in content_items:
                                        if isinstance(item, dict) and "text" in item:
                                            extracted_text += item["text"]

                                    content = extracted_text

                                    # Only send a delta when the content has grown
                                    if len(content) > len(full_content):
                                        delta_content = extract_delta_content(
                                            content, full_content
                                        )
                                        full_content = content

                                        if delta_content.strip():
                                            data = {
                                                "id": f"chatcmpl-{generate_unique_id()}",
                                                "object": "chat.completion.chunk",
                                                "created": get_current_timestamp(),
                                                "model": model,
                                                "choices": [
                                                    {
                                                        "index": 0,
                                                        "delta": {
                                                            "content": delta_content
                                                        },
                                                        "finish_reason": None,
                                                    }
                                                ],
                                            }

                                            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                        else:
                            error_data = {
                                "error": {
                                    "message": f"Multimodal API Error: {response.code} - {response.message}",
                                    "type": "api_error",
                                    "param": None,
                                    "code": response.code,
                                }
                            }
                            yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
                            break

                    finish_data = {
                        "id": f"chatcmpl-{generate_unique_id()}",
                        "object": "chat.completion.chunk",
                        "created": get_current_timestamp(),
                        "model": model,
                        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                    }
                    yield f"data: {json.dumps(finish_data, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    error_data = {"error": {"message": str(e), "type": "server_error"}}
                    yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"

            return StreamingResponse(
                multimodal_event_generator(), media_type="text/event-stream"
            )
        else:
            # Multimodal non-streaming response
            response = MultiModalConversation.call(
                model=vl_model,
                messages=dashscope_messages,
                stream=False,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            if response.status_code == 200:
                content = None

                if (
                    hasattr(response, "output")
                    and response.output
                    and hasattr(response.output, "choices")
                    and response.output.choices is not None
                    and len(response.output.choices) > 0
                    and "message" in response.output.choices[0]
                ):
                    message = response.output.choices[0]["message"]
                    if "content" in message:
                        content_items = message["content"]

                        # Extract text from the content items
                        extracted_text = ""
                        for item in content_items:
                            if isinstance(item, dict) and "text" in item:
                                extracted_text += item["text"]

                        content = extracted_text

                if content:
                    # JSONResponse takes no ensure_ascii kwarg
                    return JSONResponse(content={"result": content})
                else:
                    raise HTTPException(
                        status_code=500,
                        detail="Multimodal API Response does not contain expected content",
                    )
            else:
                raise HTTPException(
                    status_code=500,
                    detail=f"Multimodal API Error: {response.code} - {response.message}",
                )

    except HTTPException:
        raise
    except Exception as e:
        print(f"[ERROR] Error in multimodal chat handler: {str(e)}")
        import traceback

        print(f"[ERROR] Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))


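# Conversion sketch (hypothetical input): one OpenAI-style user message with an
# image_url item becomes a DashScope message whose content list holds
# {"text": ...} and {"image": "file://..."} items.
#
#   [{"role": "user", "content": [
#       {"type": "text", "text": "What is in this picture?"},
#       {"type": "image_url", "image_url": {"url": "http://localhost:8000/uploads/cat.jpg"}}]}]
#   # -> [{"role": "user", "content": [
#   #        {"text": "What is in this picture?"},
#   #        {"image": "file://uploads/cat.jpg"}]}]

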
async def get_models_handler():
    """Model list handler."""
    models = [
        ModelInfo(
            id="qwen-max",
            name="通义千问 Max",
            description="The most capable model",
            maxTokens=8192,
            provider="Aliyun",
        ),
        ModelInfo(
            id="qwen-plus",
            name="通义千问 Plus",
            description="Balanced capabilities",
            maxTokens=8192,
            provider="Aliyun",
        ),
        ModelInfo(
            id="qwen-turbo",
            name="通义千问 Turbo",
            description="Faster and lower-cost",
            maxTokens=8192,
            provider="Aliyun",
        ),
        ModelInfo(
            id="qwen-vl-max",
            name="通义千问 VL-Max",
            description="Multimodal model with visual understanding",
            maxTokens=8192,
            provider="Aliyun",
        ),
        ModelInfo(
            id="qwen-vl-plus",
            name="通义千问 VL-Plus",
            description="Multimodal model with visual understanding",
            maxTokens=8192,
            provider="Aliyun",
        ),
    ]
    return [model.dict() for model in models]


async def get_conversations_handler():
    """Handler that returns all conversations."""
    return list(conversations_db.values())


async def get_conversation_handler(conversation_id: str):
    """Handler that returns a specific conversation."""
    conversation = conversations_db.get(conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return conversation


async def save_conversation_handler(data: dict):
    """Handler that creates or updates a conversation."""
    try:
        conversation_id = data.get("id") or generate_unique_id()

        conversation = {
            "id": conversation_id,
            "title": data.get("title", "New conversation"),
            "messages": data.get("messages", []),
            "updatedAt": datetime.utcnow().isoformat(),
            "createdAt": data.get("createdAt", datetime.utcnow().isoformat()),
        }

        conversations_db[conversation_id] = conversation
        return conversation

    except Exception as e:
        print(f"[ERROR] Error saving conversation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


async def delete_conversation_handler(conversation_id: str):
    """Handler that deletes a conversation."""
    if conversation_id in conversations_db:
        del conversations_db[conversation_id]
        return {"success": True, "message": "Deleted successfully"}
    else:
        raise HTTPException(status_code=404, detail="Conversation not found")


async def upload_file_handler(file: UploadFile = File(...)):
    """File upload handler."""
    try:
        # Allowed MIME types (lenient policy)
        allowed_types = {
            # Images
            "image/jpeg",
            "image/png",
            "image/gif",
            "image/webp",
            "image/bmp",
            "image/svg+xml",
            # Text
            "text/plain",
            "text/csv",
            "text/markdown",
            "text/html",
            "text/xml",
            "application/json",
            "application/xml",
            # PDF
            "application/pdf",
            # Office documents
            "application/msword",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "application/vnd.ms-excel",
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            "application/vnd.ms-powerpoint",
            "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        }

        # Allowed extensions (fallback: browsers may misreport MIME types)
        allowed_extensions = {
            ".jpg",
            ".jpeg",
            ".png",
            ".gif",
            ".webp",
            ".bmp",
            ".txt",
            ".md",
            ".csv",
            ".json",
            ".xml",
            ".yaml",
            ".yml",
            ".log",
            ".pdf",
            ".doc",
            ".docx",
            ".xls",
            ".xlsx",
            ".ppt",
            ".pptx",
            ".py",
            ".js",
            ".ts",
            ".html",
            ".css",
        }

        file_extension = Path(file.filename).suffix.lower()

        if (
            file.content_type not in allowed_types
            and file_extension not in allowed_extensions
        ):
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {file.content_type} ({file_extension})",
            )

        # Generate a unique file name
        unique_filename = f"{int(datetime.utcnow().timestamp())}_{generate_unique_id()}{file_extension}"
        file_path = upload_dir / unique_filename

        # Save the file locally (temporary cache)
        content = await file.read()
        with open(file_path, "wb") as f:
            f.write(content)

        # Upload to OSS after the file is closed
        from utils.oss_uploader import upload_file as oss_upload

        oss_result = oss_upload(str(file_path))
        file_url = oss_result["url"]

        # Return the file info
        result = {
            "url": file_url,
            "name": file.filename,
            "size": len(content),
            "mimeType": file.content_type,
        }

        print(f"[INFO] File uploaded: {result}")
        return result

    except HTTPException:
        # Re-raise as-is so the 400 above is not rewrapped as a 500
        raise
    except Exception as e:
        print(f"[ERROR] Upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")


def serve_upload_handler(filename: str):
    """Serve an uploaded file."""
    file_path = upload_dir / filename
    if not file_path.exists():
        raise HTTPException(status_code=404, detail="File not found")

    from fastapi.responses import FileResponse

    return FileResponse(str(file_path))


async def stop_generation_handler(message_id: str = None):
    """Stop-generation handler."""
    message = (
        f"Stop signal issued, message ID: {message_id}"
        if message_id
        else "Stop signal issued"
    )
    return {"success": True, "message": message}


# ── Unified platform interface aliases (invoked dynamically via main.py's _platform) ──
# main.py calls _platform.chat_handler / _platform.models_handler, so every
# platform module must expose functions with these exact names.
chat_handler = chat_endpoint_handler  # chat endpoint alias
models_handler = get_models_handler  # model list alias
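
# Dispatch sketch (assumption: module path and variable names below are
# illustrative, not the verbatim main.py code). main.py would select the
# backend module by LLM_BACKEND and wire these aliases to its routes:
#
#   import importlib, os
#   backend = os.getenv("LLM_BACKEND", "dashscope")
#   _platform = importlib.import_module(f"routes.{backend}_routes")
#   _platform.init()
#   app.post("/api/chat-ui/chat")(_platform.chat_handler)
#   app.get("/api/chat-ui/models")(_platform.models_handler)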