"""
|
||
GLM-4.6V 适配层(基于 zai-sdk)
|
||
SDK:pip install zai-sdk
|
||
模型:glm-4.6v(支持文本/图像/文档/深度思考)
|
||
"""
|
||
|
||
import base64
|
||
import json
|
||
import os
|
||
import sys
|
||
import threading
|
||
from pathlib import Path
|
||
from typing import AsyncGenerator


# ── Auto-inject venv site-packages ───────────────────────────────────
def _ensure_venv():
    server_dir = Path(__file__).parent.parent
    for sp in sorted(
        (server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True
    ):
        if sp.exists() and str(sp) not in sys.path:
            sys.path.insert(0, str(sp))
            print(f"[GLM] venv injected: {sp}")
            break
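            # A matched path typically looks like
            # .venv/lib/python3.11/site-packages (the version is illustrative).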


# ── Client singleton ──────────────────────────────────────────────────
_client = None


def get_client():
    global _client
    if _client is None:
        _ensure_venv()
        try:
            from zai import ZhipuAiClient
        except ImportError:
            raise ImportError("GLM mode requires zai-sdk: .venv/bin/pip install zai-sdk")
        # getenv() may return None, so fall back to "" before stripping
        api_key = (os.getenv("ZHIPU_API_KEY") or "").strip() or (
            os.getenv("GLM_API_KEY") or ""
        ).strip()
        if not api_key:
            raise ValueError(
                "GLM mode requires the ZHIPU_API_KEY (or GLM_API_KEY) environment variable"
            )
        _client = ZhipuAiClient(api_key=api_key)
        print("[GLM] ZhipuAiClient initialized (zai-sdk)")
    return _client
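# Note: the lazy init above is not lock-guarded; two threads racing through
# get_client() may each build a ZhipuAiClient, but the last write wins and
# either instance works, so the race is benign.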


# ── Model mapping ──────────────────────────────────────────────────────
DEFAULT_TEXT_MODEL = "glm-4.5-Air"  # unified text model for non-GLM requests
DEFAULT_VISION_MODEL = "glm-4.6v"

MODEL_MAP = {
    "qwen-max": "glm-4.5-Air",
    "qwen-plus": "glm-4.5-Air",
    "qwen-turbo": "glm-4.5-Air",
    "qwen-vl-max": "glm-4.5-Air",
    "qwen-vl-plus": "glm-4.5-Air",
}


def resolve_model(model: str, has_vision: bool = False) -> str:
    if model.startswith("glm"):
        return model
    mapped = MODEL_MAP.get(model, DEFAULT_TEXT_MODEL)
    # When the messages contain images, force the vision model
    if has_vision and mapped != DEFAULT_VISION_MODEL:
        print(f"[GLM] images detected, switching model from {mapped} to {DEFAULT_VISION_MODEL}")
        return DEFAULT_VISION_MODEL
    return mapped
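# Mapping behavior at a glance (expected values follow from MODEL_MAP above):
#   resolve_model("qwen-max")                    -> "glm-4.5-Air"
#   resolve_model("qwen-plus", has_vision=True)  -> "glm-4.6v"
#   resolve_model("glm-4.6v")                    -> "glm-4.6v"  (GLM names pass through)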


# ── File upload (with file_id cache) ──────────────────────────────────
def upload_file_for_extract(local_path: Path) -> str:
    from utils.file_cache import get as cache_get
    from utils.file_cache import set as cache_set
    from utils.file_cache import sha256_of_file

    file_hash = sha256_of_file(local_path)
    cached = cache_get(file_hash)
    if cached:
        print(f"[GLM] file_id cache hit: {local_path.name} -> {cached['file_id']}")
        return cached["file_id"]

    client = get_client()
    mime_map = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ".xls": "application/vnd.ms-excel",
        ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ".ppt": "application/vnd.ms-powerpoint",
    }
    mime = mime_map.get(local_path.suffix.lower(), "application/octet-stream")
    print(f"[GLM] uploading file: {local_path.name} ({mime})")
    with open(local_path, "rb") as f:
        file_obj = client.files.create(
            file=(local_path.name, f, mime), purpose="file-extract"
        )
    file_id = file_obj.id
    cache_set(file_hash, file_id)
    print(f"[GLM] upload succeeded: file_id={file_id}")
    return file_id
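# Downstream, build_glm_messages() embeds the returned id into a message part
# as {"type": "file", "file": {"file_id": fid}}, so the extracted document
# content travels alongside the user's text.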


# ── Image encoding ─────────────────────────────────────────────────────
def encode_image(image_source: str) -> dict:
    """Normalize any image source into the OpenAI image_url format."""
    if image_source.startswith("data:image") or image_source.startswith(
        ("http://", "https://")
    ):
        return {"type": "image_url", "image_url": {"url": image_source}}
    # Local path -> base64 data URL
    local = Path(image_source.replace("file://", "").lstrip("/"))
    if not local.exists():
        local = Path.cwd() / local
    ext = local.suffix.lstrip(".").lower()
    if ext == "jpg":  # normalize to the registered MIME subtype
        ext = "jpeg"
    with open(local, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    return {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}}


# ── Message format conversion ──────────────────────────────────────────
def build_glm_messages(messages: list, files: list | None = None) -> tuple[list, bool]:
    """
    Convert OpenAI-format messages + files into the format zai-sdk expects.
    Returns (glm_messages, has_vision).
    """
    from urllib.parse import urlparse

    glm_messages = []
    has_vision = False

    for msg in messages:
        if not isinstance(msg, dict):
            glm_messages.append({"role": "user", "content": str(msg)})
            continue
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if isinstance(content, str):
            glm_messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            new_content = []
            for item in content:
                if not isinstance(item, dict):
                    new_content.append({"type": "text", "text": str(item)})
                    continue
                t = item.get("type")
                if t == "text":
                    new_content.append({"type": "text", "text": item.get("text", "")})
                elif t == "image_url":
                    has_vision = True
                    img_val = item.get("image_url", "")
                    img_src = (
                        img_val.get("url", "") if isinstance(img_val, dict) else img_val
                    )
                    new_content.append(encode_image(img_src))
                else:
                    new_content.append({"type": "text", "text": str(item)})
            glm_messages.append({"role": role, "content": new_content})
        else:
            glm_messages.append({"role": role, "content": str(content)})

    # Handle the standalone attachment list
    if files:
        doc_exts = {".pdf", ".doc", ".docx", ".xlsx", ".xls", ".pptx", ".ppt"}
        img_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
        inserts = []

        for file_url in files:
            parsed = urlparse(file_url)
            filename = parsed.path.split("/")[-1]
            suffix = Path(filename).suffix.lower()
            rel = parsed.path.lstrip("/")
            local = Path(rel)

            if suffix in img_exts:
                has_vision = True
                try:
                    inserts.append(encode_image(f"file://{rel}"))
                except Exception as e:
                    print(f"[GLM] image encoding failed: {e}")
            elif suffix in doc_exts:
                # Documents also route to glm-4.6v, which handles
                # file-extract content
                has_vision = True
                if local.exists():
                    try:
                        fid = upload_file_for_extract(local)
                        inserts.append({"type": "file", "file": {"file_id": fid}})
                    except Exception as e:
                        inserts.append(
                            {"type": "text", "text": f"[file upload failed: {filename}, {e}]"}
                        )
                else:
                    # Local document not found; tell the model rather than
                    # dropping the attachment silently
                    inserts.append(
                        {"type": "text", "text": f"[attachment not found: {filename}]"}
                    )
            else:
                inserts.append(
                    {"type": "text", "text": f"[attachment: {filename}, type: {suffix}]"}
                )

        if inserts:
            for i in range(len(glm_messages) - 1, -1, -1):
                if glm_messages[i].get("role") == "user":
                    old = glm_messages[i]["content"]
                    if isinstance(old, str):
                        new_content = inserts + [{"type": "text", "text": old}]
                    elif isinstance(old, list):
                        new_content = inserts + old
                    else:
                        new_content = inserts
                    glm_messages[i] = {"role": "user", "content": new_content}
                    break

    return glm_messages, has_vision
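# A minimal before/after sketch (assumes "uploads/report.pdf" exists locally
# and uploads successfully; the file_id is illustrative):
#
#   msgs = [{"role": "user", "content": "Summarize this document"}]
#   glm_msgs, has_vision = build_glm_messages(
#       msgs, files=["http://host/uploads/report.pdf"]
#   )
#   # glm_msgs == [{"role": "user", "content": [
#   #     {"type": "file", "file": {"file_id": "file-abc123"}},
#   #     {"type": "text", "text": "Summarize this document"},
#   # ]}]
#   # has_vision is True, so resolve_model() will pick glm-4.6v.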


# ── Sentinel object ────────────────────────────────────────────────────
_SENTINEL = object()
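# A dedicated sentinel (rather than None) ends the stream unambiguously: the
# producer puts arbitrary SDK chunk objects and exceptions on the queue, and
# only an identity check against this unique object can signal completion.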


# ── Debug dump (shared by streaming and non-streaming paths) ───────────
def _debug_dump_messages(glm_msgs: list, extra_kwargs: dict) -> None:
    """Print the full message structure about to be sent to GLM."""
    for i, msg in enumerate(glm_msgs):
        role = msg.get("role", "?")
        content = msg.get("content", "")
        if isinstance(content, list):
            for j, part in enumerate(content):
                if not isinstance(part, dict):
                    print(f"[GLM-DEBUG] msg[{i}].content[{j}]: {type(part).__name__}")
                    continue
                part_type = part.get("type", "?")
                if part_type == "image_url":
                    img_val = part.get("image_url", "")
                    img_url = (
                        img_val.get("url", "")
                        if isinstance(img_val, dict)
                        else str(img_val)
                    )
                    display = img_url[:120] + "..." if len(img_url) > 120 else img_url
                    print(
                        f"[GLM-DEBUG] msg[{i}].content[{j}]: type=image_url, url={display}"
                    )
                elif part_type == "text":
                    preview = (part.get("text", "") or "")[:100]
                    print(
                        f"[GLM-DEBUG] msg[{i}].content[{j}]: type=text, text={preview}"
                    )
                else:
                    print(f"[GLM-DEBUG] msg[{i}].content[{j}]: {part}")
        else:
            print(f"[GLM-DEBUG] msg[{i}]: role={role}, content={str(content)[:150]}")
    if extra_kwargs:
        print(f"[GLM-DEBUG] extra_kwargs={extra_kwargs}")
    # Raw JSON dump (for diagnosing structural issues)
    print(
        f"[GLM-RAW] messages={json.dumps(glm_msgs, ensure_ascii=False, default=str)[:2000]}"
    )


# ── Streaming call ─────────────────────────────────────────────────────
async def glm_stream_generator(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: bool = False,
    deep_thinking: bool = False,
) -> AsyncGenerator[str, None]:
    """
    GLM streaming SSE generator.
    Uses a queue.Queue + dedicated producer thread + asyncio consumer so the
    synchronous zai-sdk iterator runs safely inside a single thread.
    """
    import asyncio
    import queue

    from utils.helpers import generate_unique_id, get_current_timestamp

    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)

    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [
            {
                "type": "web_search",
                "web_search": {"enable": True, "search_result": True},
            }
        ]
    if not deep_thinking:
        # Zhipu enables thinking mode by default, so the flag is inverted:
        # when the frontend turns thinking ON we leave the default untouched,
        # and when it turns thinking OFF we disable it explicitly.
        extra_kwargs["thinking"] = {"type": "disabled"}
    print(
        f"[GLM] streaming request: model={actual_model} vision={has_vision} "
        f"web_search={web_search} thinking={deep_thinking}"
    )
    # Debug: dump the full outbound message structure
    _debug_dump_messages(glm_msgs, extra_kwargs)

    chunk_queue: queue.Queue = queue.Queue(maxsize=128)

    def _producer():
        try:
            client = get_client()
            resp = client.chat.completions.create(
                model=actual_model,
                messages=glm_msgs,
                stream=True,
                temperature=temperature,
                max_tokens=max_tokens,
                **extra_kwargs,
            )
            for chunk in resp:
                chunk_queue.put(chunk)
        except Exception as exc:
            chunk_queue.put(exc)
        finally:
            chunk_queue.put(_SENTINEL)

    t = threading.Thread(target=_producer, daemon=True)
    t.start()
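    # daemon=True is a safety valve: if the consumer below is abandoned
    # mid-stream, the lingering producer thread will not keep the interpreter
    # alive at shutdown.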

    loop = asyncio.get_running_loop()

    full_reasoning = ""  # accumulated thinking content (to detect the first fragment)
    full_content = ""  # accumulated final answer (to detect the first fragment)

    while True:
        item = await loop.run_in_executor(None, chunk_queue.get)

        if item is _SENTINEL:
            break

        if isinstance(item, Exception):
            print(f"[GLM] producer exception: {item}")
            yield f"data: {json.dumps({'error': {'message': str(item), 'type': 'glm_error'}}, ensure_ascii=False)}\n\n"
            break

        try:
            delta = item.choices[0].delta
            reasoning = getattr(delta, "reasoning_content", "") or ""
            text = getattr(delta, "content", "") or ""

            delta_str = ""

            # ── Thinking process (reasoning_content) ─────────────────
            if reasoning:
                if not full_reasoning:
                    # First thinking fragment: add a Markdown blockquote header
                    delta_str += "> **💭 Deep thinking:**\n> \n> "
                full_reasoning += reasoning
                # Inside a blockquote, every newline needs a leading `> `
                delta_str += reasoning.replace("\n", "\n> ")

            # ── Final answer (content) ────────────────────────────────
            if text:
                if not full_content and full_reasoning:
                    # First answer fragment after thinking: add a divider
                    delta_str += "\n\n---\n\n"
                full_content += text
                delta_str += text

            if not delta_str:
                continue

            data = {
                "id": f"chatcmpl-{generate_unique_id()}",
                "object": "chat.completion.chunk",
                "created": get_current_timestamp(),
                "model": actual_model,
                "choices": [
                    {"index": 0, "delta": {"content": delta_str}, "finish_reason": None}
                ],
            }
            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

        except Exception as e:
            print(f"[GLM] chunk parse error: {e}")

    finish = {
        "id": f"chatcmpl-{generate_unique_id()}",
        "object": "chat.completion.chunk",
        "created": get_current_timestamp(),
        "model": actual_model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"


# ── Non-streaming call ─────────────────────────────────────────────────
def glm_chat_sync(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: bool = False,
    deep_thinking: bool = False,
) -> dict:
    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)

    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [
            {
                "type": "web_search",
                "web_search": {"enable": True, "search_result": True},
            }
        ]
    if deep_thinking:
        extra_kwargs["thinking"] = {"type": "enabled"}
    else:
        # Zhipu enables thinking by default (see glm_stream_generator), so it
        # must be disabled explicitly when the frontend turns it off.
        extra_kwargs["thinking"] = {"type": "disabled"}

    client = get_client()
    print(f"[GLM] non-streaming request: model={actual_model}")
    # Debug: dump the full outbound message structure
    _debug_dump_messages(glm_msgs, extra_kwargs)
    resp = client.chat.completions.create(
        model=actual_model,
        messages=glm_msgs,
        stream=False,
        temperature=temperature,
        max_tokens=max_tokens,
        **extra_kwargs,
    )
    content = resp.choices[0].message.content or ""
    usage = None
    if hasattr(resp, "usage") and resp.usage:
        usage = {
            "promptTokens": resp.usage.prompt_tokens,
            "completionTokens": resp.usage.completion_tokens,
            "totalTokens": resp.usage.total_tokens,
        }
    return {"content": content, "model": actual_model, "usage": usage}
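

# Return-shape sketch (token counts illustrative):
#   glm_chat_sync([{"role": "user", "content": "hi"}], "qwen-max", 0.7, 1024)
#   -> {"content": "...", "model": "glm-4.5-Air",
#       "usage": {"promptTokens": 3, "completionTokens": 7, "totalTokens": 10}}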