feat: GLM basic chat feature tested successfully

This commit is contained in:
parent 4ba245706d
commit 7f8043003c
server/api/chat_routes.py

@@ -1,5 +1,6 @@
 """
-API route definitions
+API route definitions (Aliyun Bailian / DashScope platform)
+All DashScope logic is concentrated in this file; main.py is unaware of any platform details.
 """
 import os
 import json
@@ -12,6 +13,19 @@ from fastapi.responses import JSONResponse, StreamingResponse
 import dashscope
 from dashscope import Generation, MultiModalConversation

+
+def init():
+    """
+    Initialize the Bailian backend: set the DashScope API key.
+    Called by main.py at startup (when LLM_BACKEND=dashscope).
+    """
+    api_key = os.getenv("ALIYUN_API_KEY")
+    if not api_key:
+        raise ValueError("dashscope mode requires the ALIYUN_API_KEY environment variable")
+    dashscope.api_key = api_key
+    print("[DashScope] initialization complete, ALIYUN_API_KEY configured")
+
+
 # Import models and helper functions (absolute paths)
 import sys
 from pathlib import Path
@@ -936,4 +950,11 @@ def serve_upload_handler(filename: str):
 async def stop_generation_handler(message_id: str = None):
     """Stop-generation handler"""
     message = f"Stop signal issued, message ID: {message_id}" if message_id else "Stop signal issued"
     return {"success": True, "message": message}
+
+
+# ── Unified platform interface aliases (for dynamic dispatch via _platform in main.py) ──
+# main.py calls _platform.chat_handler / _platform.models_handler;
+# every platform module must expose functions under these names.
+chat_handler = chat_endpoint_handler   # chat endpoint alias
+models_handler = get_models_handler    # model-list alias
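Together with init(), these aliases define the implicit contract every platform module has to satisfy before main.py can dispatch to it. A minimal sketch of what a hypothetical additional backend would need to expose (module name, env var, and handler bodies are illustrative only, not part of this commit):

# api/chat_routes_example.py (hypothetical platform module)
import os


def init():
    """Validate credentials at startup; raise early if missing."""
    if not os.getenv("EXAMPLE_API_KEY"):
        raise ValueError("example mode requires the EXAMPLE_API_KEY environment variable")


async def chat_handler(body: dict):
    """Accept the same JSON body as the other platforms."""
    return {"id": "chatcmpl-0", "object": "chat.completion", "choices": []}


def models_handler():
    """Return the model list in the shared {"data": [...], "object": "list"} shape."""
    return {"data": [], "object": "list"}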
server/api/chat_routes_glm.py (new file)

@@ -0,0 +1,125 @@
"""
|
||||
GLM-4.6V 平台路由处理器(zai-sdk)
|
||||
所有智谱 GLM 相关逻辑均集中在此文件,main.py 无感知任何平台细节。
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from utils.helpers import get_current_timestamp, generate_unique_id
|
||||
from utils.logger import log_info
|
||||
|
||||
|
||||
def init():
|
||||
"""
|
||||
初始化 GLM 后端:验证 API Key 是否配置。
|
||||
由 main.py 在启动时调用(若 LLM_BACKEND=glm)。
|
||||
"""
|
||||
api_key = os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY(在 https://open.bigmodel.cn 申请)")
|
||||
log_info(f"[GLM] 初始化完成,ZHIPU_API_KEY 已配置")
|
||||
|
||||
|
||||
async def chat_handler(body: dict):
    """
    GLM chat handler (external interface fully compatible with the Bailian chat_endpoint_handler).
    Auto-adapts between streaming and non-streaming; supports images, document
    attachments, web search, and deep thinking.
    """
    from utils.glm_adapter import glm_stream_generator, glm_chat_sync

    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    messages = body.get("messages", [])
    model = body.get("model", "glm-4.6v")
    stream = body.get("stream", True)
    temperature = body.get("temperature", 0.7)
    max_tokens = body.get("max_tokens", body.get("maxTokens", 2000))
    web_search = body.get("webSearch", False) or body.get("deepSearch", False)
    deep_think = body.get("deepThinking", False)
    files = body.get("files", [])

    # Compatibility with the simplified frontend format (no messages structure)
    if not messages:
        msg_text = body.get("message", "")
        sys_prompt = body.get("systemPrompt", "You are a helpful assistant.")
        user_content = msg_text if isinstance(msg_text, list) else [{"type": "text", "text": msg_text}]
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_content},
        ]

    log_info(f"[GLM] model={model} stream={stream} web_search={web_search} "
             f"thinking={deep_think} files={len(files)} msgs={len(messages)}")

    if stream:
        return StreamingResponse(
            glm_stream_generator(
                messages=messages, model=model, temperature=temperature,
                max_tokens=max_tokens, files=files or None,
                web_search=web_search, deep_thinking=deep_think,
            ),
            media_type="text/event-stream",
        )

    result = glm_chat_sync(
        messages=messages, model=model, temperature=temperature,
        max_tokens=max_tokens, files=files or None,
        web_search=web_search, deep_thinking=deep_think,
    )
    resp = {
        "id": f"chatcmpl-{generate_unique_id()}",
        "object": "chat.completion",
        "created": get_current_timestamp(),
        "model": result["model"],
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": result["content"]},
                "finish_reason": "stop",
            }
        ],
    }
    if result.get("usage"):
        resp["usage"] = result["usage"]
    return JSONResponse(content=resp)

def models_handler():
    """Return the list of available GLM models"""
    return {
        "data": [
            {
                "id": "glm-4.6v",
                "name": "GLM-4.6V (recommended)",
                "description": "Latest flagship model; supports text/images/documents/deep thinking",
                "maxTokens": 128000,
                "provider": "ZhipuAI",
            },
            {
                "id": "glm-4-flash",
                "name": "GLM-4 Flash",
                "description": "Cost-effective text model (0.2 CNY / 1k tokens)",
                "maxTokens": 128000,
                "provider": "ZhipuAI",
            },
            {
                "id": "glm-4v-plus-0111",
                "name": "GLM-4V Plus",
                "description": "Native multimodal: images + PDF/DOCX",
                "maxTokens": 128000,
                "provider": "ZhipuAI",
            },
            {
                "id": "glm-z1-flash",
                "name": "GLM-Z1 Flash",
                "description": "Deep-thinking reasoning model",
                "maxTokens": 128000,
                "provider": "ZhipuAI",
            },
        ],
        "object": "list",
    }
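For reference, a minimal client sketch against this handler (assuming the server runs locally on port 8000 and httpx is installed; the endpoint path comes from main.py below):

# client_demo.py (hypothetical smoke test for the streaming chat route)
import json
import httpx


def stream_chat(prompt: str) -> None:
    body = {"message": prompt, "model": "glm-4.6v", "stream": True}
    with httpx.stream("POST", "http://localhost:8000/api/chat-ui/chat", json=body, timeout=60) as r:
        for line in r.iter_lines():
            if not line.startswith("data: "):
                continue
            payload = line[len("data: "):]
            if payload == "[DONE]":
                break
            chunk = json.loads(payload)
            delta = chunk.get("choices", [{}])[0].get("delta", {})
            print(delta.get("content", ""), end="", flush=True)


if __name__ == "__main__":
    stream_chat("Hello, GLM")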
server/main.py

@@ -1,158 +1,168 @@
"""
|
||||
改进版Python FastAPI服务器实现,使用DashScope Python SDK连接阿里云百炼平台API
|
||||
拆分模块版本
|
||||
AI Chat API Server — 主入口(纯基础设施层)
|
||||
|
||||
职责:
|
||||
- 注入运行时依赖(venv site-packages)
|
||||
- 读取 LLM_BACKEND 环境变量,动态加载对应平台模块
|
||||
- 注册 FastAPI 路由和中间件
|
||||
|
||||
平台代码位置(main.py 中不包含任何平台逻辑):
|
||||
- 百炼 DashScope → api/chat_routes.py
|
||||
- 智谱 GLM-4.6V → api/chat_routes_glm.py + utils/glm_adapter.py
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
import dashscope
|
||||
|
||||
# ── 注入 venv site-packages(兼容 start.sh 用系统 python3 启动)────────
|
||||
# 必须在所有第三方 import 前执行
|
||||
_venv_lib = Path(__file__).parent / ".venv" / "lib"
|
||||
if _venv_lib.exists():
|
||||
for _sp in sorted(_venv_lib.glob("python*/site-packages"), reverse=True):
|
||||
if _sp.exists() and str(_sp) not in sys.path:
|
||||
sys.path.insert(0, str(_sp))
|
||||
print(f"[启动] venv 注入:{_sp}")
|
||||
break
|
||||
|
||||
# ── 第三方导入 ────────────────────────────────────────────────────────
|
||||
 from dotenv import load_dotenv
-from fastapi import FastAPI, HTTPException, File, UploadFile, Request
+from fastapi import FastAPI, File, UploadFile, Request
 from fastapi.responses import JSONResponse

-# Import modules
-import sys
-sys.path.append('/home/mt/project/ai-chat-ui/server')
-
+# ── Utilities/logging (platform-agnostic) ────────────────────────────
+from utils.helpers import log_response
+from utils.logger import get_logger
+
+logger = get_logger()
+
+# ── Load environment variables ───────────────────────────────────────
+load_dotenv()
+
+LLM_BACKEND = os.getenv("LLM_BACKEND", "dashscope").lower().strip()
+if LLM_BACKEND not in {"dashscope", "glm"}:
+    logger.warning(f"Unknown LLM_BACKEND='{LLM_BACKEND}', falling back to dashscope")
+    LLM_BACKEND = "dashscope"
+
+# ── Dynamically load the platform module ─────────────────────────────
+if LLM_BACKEND == "glm":
+    import api.chat_routes_glm as _platform
+else:
+    import api.chat_routes as _platform
+
+_platform.init()  # each platform performs its own initialization (API key checks, etc.)
+
+# Shared route handlers (file upload, conversation management, etc.; platform-agnostic, reusing the Bailian route implementations)
 from api.chat_routes import (
     chat_endpoint_handler,
     get_models_handler,
     get_conversations_handler,
     get_conversation_handler,
     save_conversation_handler,
     delete_conversation_handler,
     upload_file_handler,
     serve_upload_handler,
-    stop_generation_handler
+    stop_generation_handler,
 )
-from models.chat_models import ChatRequest, ModelInfo
-from utils.helpers import log_request, log_response
-from utils.logger import get_logger
-
-logger = get_logger()
-
-
-# Load environment variables
-load_dotenv()
-
-# Set the DashScope API key
-api_key = os.getenv("ALIYUN_API_KEY")
-if not api_key:
-    raise ValueError("Please set ALIYUN_API_KEY in the environment")
-
-dashscope.api_key = api_key
-
-# Create the FastAPI app
-app = FastAPI(title="AI Chat API Server (Python)", version="2.0.0")
+# ── FastAPI app ──────────────────────────────────────────────────────
+app = FastAPI(
+    title=f"AI Chat API (LLM_BACKEND={LLM_BACKEND})",
+    version="3.0.0",
+)

@app.middleware("http")
|
||||
async def logging_middleware(request: Request, call_next):
|
||||
"""中间件:记录请求日志并美化输出"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
client_ip = request.client.host if request.client else 'unknown'
|
||||
|
||||
# 请求日志
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
logger.info(f"→ {request.method} {request.url.path} | IP: {client_ip}")
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
process_time = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
|
||||
status_emoji = "✅" if response.status_code < 400 else "❌"
|
||||
|
||||
# 响应日志(包含端点、状态码、耗时)
|
||||
logger.info(f"{status_emoji} {request.method} {request.url.path} | 状态: {response.status_code} | 耗时: {process_time:.0f}ms")
|
||||
|
||||
# 记录结构化日志(写入日志文件)
|
||||
log_response(response.status_code, process_time)
|
||||
|
||||
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
|
||||
ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
|
||||
icon = "✅" if response.status_code < 400 else "❌"
|
||||
logger.info(f"{icon} {request.method} {request.url.path} | 状态: {response.status_code} | 耗时: {ms:.0f}ms")
|
||||
log_response(response.status_code, ms)
|
||||
response.headers["X-Process-Time"] = f"{ms:.2f}ms"
|
||||
return response
|
||||
|
||||
|
||||
# ── 路由注册 ──────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
return {
|
||||
"status": "healthy",
|
||||
"backend": LLM_BACKEND,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
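The new "backend" field makes it easy to confirm which platform is live; a minimal check, assuming a local run on port 8000 with httpx installed:

import httpx

print(httpx.get("http://localhost:8000/health").json())
# e.g. {"status": "healthy", "backend": "glm", "timestamp": "..."}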
@app.post("/api/chat-ui/chat")
|
||||
async def chat_endpoint(request: Request):
|
||||
"""聊天接口 - 与阿里云百炼API兼容的接口"""
|
||||
body = await request.json()
|
||||
return await chat_endpoint_handler(body)
|
||||
"""聊天接口(自动路由到当前平台)"""
|
||||
return await _platform.chat_handler(await request.json())
|
||||
|
||||
|
||||
@app.get("/api/chat-ui/models")
|
||||
async def get_models():
|
||||
"""获取模型列表"""
|
||||
return await get_models_handler()
|
||||
"""模型列表(由当前平台返回)"""
|
||||
result = _platform.models_handler()
|
||||
# 支持同步和异步两种返回
|
||||
if hasattr(result, "__await__"):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
# ── 通用路由(与平台无关)────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/chat-ui/conversations")
|
||||
async def get_conversations():
|
||||
"""获取所有对话"""
|
||||
return await get_conversations_handler()
|
||||
|
||||
|
||||
@app.get("/api/chat-ui/conversations/{conversation_id}")
|
||||
async def get_conversation(conversation_id: str):
|
||||
"""获取特定对话"""
|
||||
return await get_conversation_handler(conversation_id)
|
||||
|
||||
|
||||
@app.post("/api/chat-ui/conversations")
|
||||
async def save_conversation(request: Request):
|
||||
"""保存或更新对话"""
|
||||
data = await request.json()
|
||||
return await save_conversation_handler(data)
|
||||
return await save_conversation_handler(await request.json())
|
||||
|
||||
|
||||
@app.delete("/api/chat-ui/conversations/{conversation_id}")
|
||||
async def delete_conversation(conversation_id: str):
|
||||
"""删除对话"""
|
||||
return await delete_conversation_handler(conversation_id)
|
||||
|
||||
|
||||
@app.post("/api/chat-ui/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
"""文件上传接口"""
|
||||
return await upload_file_handler(file=file)
|
||||
|
||||
|
||||
@app.get("/uploads/{filename}")
|
||||
async def serve_upload(filename: str):
|
||||
"""提供上传文件的访问"""
|
||||
return serve_upload_handler(filename)
|
||||
|
||||
|
||||
@app.post("/api/chat-ui/stop")
|
||||
async def stop_generation():
|
||||
"""停止生成接口"""
|
||||
return await stop_generation_handler()
|
||||
|
||||
|
||||
@app.post("/api/chat-ui/stop/{message_id}")
|
||||
async def stop_generation_by_id(message_id: str):
|
||||
"""根据消息ID停止生成"""
|
||||
return await stop_generation_handler(message_id)
|
||||
|
||||
|
||||
# ── 程序入口 ──────────────────────────────────────────────────────────
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
port = int(os.getenv("PORT", 8000))
|
||||
print("="*50)
|
||||
print(f"Python AI Chat Server 启动中...")
|
||||
print(f"监听端口: {port}")
|
||||
print(f"API Key 状态: {'已配置' if api_key else '未配置'}")
|
||||
print("="*50)
|
||||
|
||||
if not api_key:
|
||||
print("警告: 未在环境变量中检测到 ALIYUN_API_KEY!")
|
||||
print("请在 .env 文件中添加您的百炼 API Key。")
|
||||
else:
|
||||
print("API Key 已检测到。")
|
||||
|
||||
print("=" * 55)
|
||||
print(f" AI Chat Server v3.0 启动中...")
|
||||
print(f" 后端平台 : {LLM_BACKEND.upper()} [LLM_BACKEND={LLM_BACKEND}]")
|
||||
print(f" 监听端口 : {port}")
|
||||
print(f" 切换平台 : 修改 .env 中 LLM_BACKEND=glm|dashscope,重启")
|
||||
print("=" * 55)
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
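In practice, switching platforms is the one-line .env change the startup banner describes (placeholder keys shown):

# .env
LLM_BACKEND=glm        # or: dashscope
ZHIPU_API_KEY=...      # required when LLM_BACKEND=glm
ALIYUN_API_KEY=...     # required when LLM_BACKEND=dashscope

Then restart the server; main.py re-imports the matching platform module and calls its init().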
server/utils/file_cache.py (new file)

@@ -0,0 +1,64 @@
"""
|
||||
GLM 文件 ID 缓存(基于磁盘的简单 KV,sha256 → file_id,3天有效期)
|
||||
"""
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
_CACHE_FILE = Path(__file__).parent.parent / "uploads" / ".glm_file_cache.json"
|
||||
_lock = threading.Lock()
|
||||
_TTL = 3 * 24 * 3600 # 3天
|
||||
|
||||
|
||||
def _load() -> dict:
|
||||
try:
|
||||
if _CACHE_FILE.exists():
|
||||
return json.loads(_CACHE_FILE.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save(data: dict) -> None:
|
||||
try:
|
||||
_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_CACHE_FILE.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
except Exception as e:
|
||||
print(f"[file_cache] 写入失败:{e}")
|
||||
|
||||
|
||||
def sha256_of_file(file_path: Path) -> str:
|
||||
h = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def get(file_hash: str) -> dict | None:
|
||||
with _lock:
|
||||
data = _load()
|
||||
entry = data.get(file_hash)
|
||||
if not entry:
|
||||
return None
|
||||
if entry.get("expires_at", 0) <= time.time():
|
||||
data.pop(file_hash, None)
|
||||
_save(data)
|
||||
return None
|
||||
return entry
|
||||
|
||||
|
||||
def set(file_hash: str, file_id: str) -> None:
|
||||
with _lock:
|
||||
data = _load()
|
||||
data[file_hash] = {"file_id": file_id, "expires_at": time.time() + _TTL}
|
||||
_save(data)
|
||||
|
||||
|
||||
def delete(file_hash: str) -> None:
|
||||
with _lock:
|
||||
data = _load()
|
||||
data.pop(file_hash, None)
|
||||
_save(data)
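A usage sketch for the cache (it is imported as utils.file_cache by glm_adapter.py below; the path and file_id here are illustrative):

from pathlib import Path
from utils import file_cache

path = Path("uploads/report.pdf")          # hypothetical local upload
digest = file_cache.sha256_of_file(path)
entry = file_cache.get(digest)             # None on a miss or after the 3-day TTL
if entry is None:
    file_cache.set(digest, "file-abc123")  # hypothetical file_id from the upload API
else:
    print(entry["file_id"])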
server/utils/glm_adapter.py (new file)

@@ -0,0 +1,335 @@
"""
|
||||
GLM-4.6V 适配层(基于 zai-sdk)
|
||||
SDK:pip install zai-sdk
|
||||
模型:glm-4.6v(支持文本/图像/文档/深度思考)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import base64
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
|
||||
# ── 自动注入 venv site-packages ───────────────────────────────────────
|
||||
def _ensure_venv():
|
||||
server_dir = Path(__file__).parent.parent
|
||||
for sp in sorted((server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True):
|
||||
if sp.exists() and str(sp) not in sys.path:
|
||||
sys.path.insert(0, str(sp))
|
||||
print(f"[GLM] venv 注入:{sp}")
|
||||
break
|
||||
|
||||
|
||||
# ── 客户端单例 ────────────────────────────────────────────────────────
|
||||
_client = None
|
||||
|
||||
|
||||
def get_client():
|
||||
global _client
|
||||
if _client is None:
|
||||
_ensure_venv()
|
||||
try:
|
||||
from zai import ZhipuAiClient
|
||||
except ImportError:
|
||||
raise ImportError("GLM 模式需要安装 zai-sdk:.venv/bin/pip install zai-sdk")
|
||||
api_key = os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY")
|
||||
_client = ZhipuAiClient(api_key=api_key)
|
||||
print("[GLM] ZhipuAiClient 初始化完成(zai-sdk)")
|
||||
return _client
|
||||
|
||||
|
||||
# ── 模型映射 ──────────────────────────────────────────────────────────
|
||||
DEFAULT_TEXT_MODEL = "glm-4.6v" # glm-4.6v 文本+视觉统一模型
|
||||
DEFAULT_VISION_MODEL = "glm-4.6v"
|
||||
|
||||
MODEL_MAP = {
|
||||
"qwen-max": "glm-4.6v",
|
||||
"qwen-plus": "glm-4.6v",
|
||||
"qwen-turbo": "glm-4.6v",
|
||||
"qwen-vl-max": "glm-4.6v",
|
||||
"qwen-vl-plus": "glm-4.6v",
|
||||
}
|
||||
|
||||
|
||||
def resolve_model(model: str, has_vision: bool = False) -> str:
|
||||
if model.startswith("glm"):
|
||||
return model
|
||||
return MODEL_MAP.get(model, DEFAULT_TEXT_MODEL)
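The mapping means legacy Qwen model IDs coming from the old frontend silently resolve to GLM, for example:

# Illustrative behavior of resolve_model:
assert resolve_model("qwen-max") == "glm-4.6v"        # mapped legacy ID
assert resolve_model("glm-4-flash") == "glm-4-flash"  # GLM IDs pass through unchanged
assert resolve_model("unknown-model") == "glm-4.6v"   # anything else falls back to the default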
# ── File upload (with file_id caching) ───────────────────────────────
def upload_file_for_extract(local_path: Path) -> str:
    from utils.file_cache import sha256_of_file, get as cache_get, set as cache_set

    file_hash = sha256_of_file(local_path)
    cached = cache_get(file_hash)
    if cached:
        print(f"[GLM] file_id cache hit: {local_path.name} → {cached['file_id']}")
        return cached["file_id"]

    client = get_client()
    mime_map = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ".xls": "application/vnd.ms-excel",
        ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ".ppt": "application/vnd.ms-powerpoint",
    }
    mime = mime_map.get(local_path.suffix.lower(), "application/octet-stream")
    print(f"[GLM] uploading file: {local_path.name} ({mime})")
    with open(local_path, "rb") as f:
        file_obj = client.files.create(file=(local_path.name, f, mime), purpose="file-extract")
    file_id = file_obj.id
    cache_set(file_hash, file_id)
    print(f"[GLM] upload succeeded: file_id={file_id}")
    return file_id


# ── Image encoding ───────────────────────────────────────────────────
def encode_image(image_source: str) -> dict:
    """Normalize any image source into the OpenAI image_url format"""
    if image_source.startswith("data:image") or image_source.startswith(("http://", "https://")):
        return {"type": "image_url", "image_url": {"url": image_source}}
    # Local path → base64
    local = Path(image_source.replace("file://", "").lstrip("/"))
    if not local.exists():
        local = Path.cwd() / local
    ext = local.suffix.lstrip(".")
    with open(local, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    return {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}}


# ── Message format conversion ────────────────────────────────────────
def build_glm_messages(messages: list, files: list | None = None) -> tuple[list, bool]:
    """
    Convert OpenAI-format messages + files into the format zai-sdk expects.
    Returns (glm_messages, has_vision).
    """
    from urllib.parse import urlparse

    glm_messages = []
    has_vision = False

    for msg in messages:
        if not isinstance(msg, dict):
            glm_messages.append({"role": "user", "content": str(msg)})
            continue
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if isinstance(content, str):
            glm_messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            new_content = []
            for item in content:
                if not isinstance(item, dict):
                    new_content.append({"type": "text", "text": str(item)})
                    continue
                t = item.get("type")
                if t == "text":
                    new_content.append({"type": "text", "text": item.get("text", "")})
                elif t == "image_url":
                    has_vision = True
                    img_val = item.get("image_url", "")
                    img_src = img_val.get("url", "") if isinstance(img_val, dict) else img_val
                    new_content.append(encode_image(img_src))
                else:
                    new_content.append({"type": "text", "text": str(item)})
            glm_messages.append({"role": role, "content": new_content})
        else:
            glm_messages.append({"role": role, "content": str(content)})

    # Handle the standalone attachment list
    if files:
        doc_exts = {".pdf", ".doc", ".docx", ".xlsx", ".xls", ".pptx", ".ppt"}
        img_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
        inserts = []

        for file_url in files:
            parsed = urlparse(file_url)
            filename = parsed.path.split("/")[-1]
            suffix = Path(filename).suffix.lower()
            rel = parsed.path.lstrip("/")
            local = Path(rel)

            if suffix in img_exts:
                has_vision = True
                try:
                    inserts.append(encode_image(f"file://{rel}"))
                except Exception as e:
                    print(f"[GLM] image encoding failed: {e}")
            elif suffix in doc_exts:
                has_vision = True
                if local.exists():
                    try:
                        fid = upload_file_for_extract(local)
                        inserts.append({"type": "file", "file": {"file_id": fid}})
                    except Exception as e:
                        inserts.append({"type": "text", "text": f"[File upload failed: {filename}, {e}]"})
                else:
                    inserts.append({"type": "text", "text": f"[Attachment: {filename}, type: {suffix}]"})

        if inserts:
            for i in range(len(glm_messages) - 1, -1, -1):
                if glm_messages[i].get("role") == "user":
                    old = glm_messages[i]["content"]
                    if isinstance(old, str):
                        new_content = inserts + [{"type": "text", "text": old}]
                    elif isinstance(old, list):
                        new_content = inserts + old
                    else:
                        new_content = inserts
                    glm_messages[i] = {"role": "user", "content": new_content}
                    break

    return glm_messages, has_vision
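An example of the transformation (hypothetical input; the document must exist locally for the upload branch to run, otherwise a text stub is inserted):

msgs = [{"role": "user", "content": "Summarize this file"}]
files = ["http://localhost:8000/uploads/report.pdf"]
glm_msgs, has_vision = build_glm_messages(msgs, files)
# The last user message's content becomes a list, e.g.:
#   [{"type": "file", "file": {"file_id": "..."}},
#    {"type": "text", "text": "Summarize this file"}]
# has_vision is True, so resolve_model() can pick a vision-capable model.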
# ── Sentinel object ──────────────────────────────────────────────────
_SENTINEL = object()


# ── Streaming SSE generator ──────────────────────────────────────────
async def glm_stream_generator(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: bool = False,
    deep_thinking: bool = False,
) -> AsyncGenerator[str, None]:
    """
    GLM streaming SSE generator.
    Uses a queue.Queue + dedicated producer thread + asyncio consumer pattern
    so the synchronous zai-sdk iterator runs safely on a single thread.
    """
    import asyncio
    import queue

    from utils.helpers import get_current_timestamp, generate_unique_id

    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)

    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [
            {"type": "web_search", "web_search": {"search_result": True}}
        ]
    if deep_thinking:
        extra_kwargs["thinking"] = {"type": "enabled"}

    print(f"[GLM] streaming request: model={actual_model} vision={has_vision} "
          f"web_search={web_search} thinking={deep_thinking}")

    chunk_queue: queue.Queue = queue.Queue(maxsize=128)

    def _producer():
        try:
            client = get_client()
            resp = client.chat.completions.create(
                model=actual_model,
                messages=glm_msgs,
                stream=True,
                temperature=temperature,
                max_tokens=max_tokens,
                **extra_kwargs,
            )
            for chunk in resp:
                chunk_queue.put(chunk)
        except Exception as exc:
            chunk_queue.put(exc)
        finally:
            chunk_queue.put(_SENTINEL)

    t = threading.Thread(target=_producer, daemon=True)
    t.start()

    loop = asyncio.get_running_loop()

    while True:
        item = await loop.run_in_executor(None, chunk_queue.get)

        if item is _SENTINEL:
            break

        if isinstance(item, Exception):
            print(f"[GLM] producer exception: {item}")
            yield f"data: {json.dumps({'error': {'message': str(item), 'type': 'glm_error'}}, ensure_ascii=False)}\n\n"
            break

        try:
            delta = item.choices[0].delta
            text = getattr(delta, "content", "") or ""
            if not text:
                continue
            data = {
                "id": f"chatcmpl-{generate_unique_id()}",
                "object": "chat.completion.chunk",
                "created": get_current_timestamp(),
                "model": actual_model,
                "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}],
            }
            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
        except Exception as e:
            print(f"[GLM] chunk parse error: {e}")

    finish = {
        "id": f"chatcmpl-{generate_unique_id()}",
        "object": "chat.completion.chunk",
        "created": get_current_timestamp(),
        "model": actual_model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"
# ── Non-streaming call ───────────────────────────────────────────────
def glm_chat_sync(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: bool = False,
    deep_thinking: bool = False,
) -> dict:
    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)

    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [
            {"type": "web_search", "web_search": {"search_result": True}}
        ]
    if deep_thinking:
        extra_kwargs["thinking"] = {"type": "enabled"}

    client = get_client()
    print(f"[GLM] non-streaming request: model={actual_model}")
    resp = client.chat.completions.create(
        model=actual_model,
        messages=glm_msgs,
        stream=False,
        temperature=temperature,
        max_tokens=max_tokens,
        **extra_kwargs,
    )
    content = resp.choices[0].message.content or ""
    usage = None
    if hasattr(resp, "usage") and resp.usage:
        usage = {
            "promptTokens": resp.usage.prompt_tokens,
            "completionTokens": resp.usage.completion_tokens,
            "totalTokens": resp.usage.total_tokens,
        }
    return {"content": content, "model": actual_model, "usage": usage}