# ai-chat-ui/server/utils/glm_adapter.py
"""
GLM-4.6V 适配层(基于 zai-sdk
SDKpip install zai-sdk
模型glm-4.6v(支持文本/图像/文档/深度思考)
"""
import os
import sys
import json
import base64
import threading
from pathlib import Path
from typing import AsyncGenerator
# ── Auto-inject venv site-packages ───────────────────────────────────
def _ensure_venv():
    server_dir = Path(__file__).parent.parent
    for sp in sorted((server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True):
        if sp.exists() and str(sp) not in sys.path:
            sys.path.insert(0, str(sp))
            print(f"[GLM] venv injected: {sp}")
            break
# ── Client singleton ──────────────────────────────────────────────────
_client = None

def get_client():
    global _client
    if _client is None:
        _ensure_venv()
        try:
            from zai import ZhipuAiClient
        except ImportError:
            raise ImportError("GLM mode requires zai-sdk: .venv/bin/pip install zai-sdk")
        api_key = os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY")
        if not api_key:
            raise ValueError("GLM mode requires the ZHIPU_API_KEY environment variable")
        _client = ZhipuAiClient(api_key=api_key)
        print("[GLM] ZhipuAiClient initialized (zai-sdk)")
    return _client
# ── Model mapping ─────────────────────────────────────────────────────
DEFAULT_TEXT_MODEL = "glm-4.6v"  # glm-4.6v is a unified text + vision model
DEFAULT_VISION_MODEL = "glm-4.6v"
MODEL_MAP = {
    "qwen-max": "glm-4.6v",
    "qwen-plus": "glm-4.6v",
    "qwen-turbo": "glm-4.6v",
    "qwen-vl-max": "glm-4.6v",
    "qwen-vl-plus": "glm-4.6v",
}

def resolve_model(model: str, has_vision: bool = False) -> str:
    # has_vision is accepted for call-site symmetry; glm-4.6v handles both.
    if model.startswith("glm"):
        return model
    return MODEL_MAP.get(model, DEFAULT_TEXT_MODEL)
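
# Illustrative mapping behavior (hypothetical call sites, shown as a sketch
# rather than executable tests):
#   resolve_model("glm-4.5")      -> "glm-4.5"          (glm-* passes through)
#   resolve_model("qwen-max")     -> "glm-4.6v"         (known alias remapped)
#   resolve_model("other-model")  -> DEFAULT_TEXT_MODEL (fallback)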
# ── File upload (with file_id cache) ──────────────────────────────────
def upload_file_for_extract(local_path: Path) -> str:
    from utils.file_cache import sha256_of_file, get as cache_get, set as cache_set
    file_hash = sha256_of_file(local_path)
    cached = cache_get(file_hash)
    if cached:
        print(f"[GLM] file_id cache hit: {local_path.name} -> {cached['file_id']}")
        return cached["file_id"]
    client = get_client()
    mime_map = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ".xls": "application/vnd.ms-excel",
        ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ".ppt": "application/vnd.ms-powerpoint",
    }
    mime = mime_map.get(local_path.suffix.lower(), "application/octet-stream")
    print(f"[GLM] uploading file: {local_path.name} ({mime})")
    with open(local_path, "rb") as f:
        file_obj = client.files.create(file=(local_path.name, f, mime), purpose="file-extract")
    file_id = file_obj.id
    cache_set(file_hash, file_id)
    print(f"[GLM] upload succeeded: file_id={file_id}")
    return file_id
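
# The returned file_id is referenced from a message content part; this is
# how build_glm_messages attaches documents below:
#   {"type": "file", "file": {"file_id": file_id}}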
# ── Image encoding ────────────────────────────────────────────────────
def encode_image(image_source: str) -> dict:
    """Normalize any image source into the OpenAI image_url content format."""
    if image_source.startswith("data:image") or image_source.startswith(("http://", "https://")):
        return {"type": "image_url", "image_url": {"url": image_source}}
    # Local path -> base64 data URL
    local = Path(image_source.replace("file://", "").lstrip("/"))
    if not local.exists():
        local = Path.cwd() / local
    ext = local.suffix.lstrip(".")
    with open(local, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    return {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}}
# ── Message format conversion ─────────────────────────────────────────
def build_glm_messages(messages: list, files: list | None = None) -> tuple[list, bool]:
    """
    Convert OpenAI-format messages + files into the format zai-sdk expects.
    Returns (glm_messages, has_vision).
    """
    from urllib.parse import urlparse
    glm_messages = []
    has_vision = False
    for msg in messages:
        if not isinstance(msg, dict):
            glm_messages.append({"role": "user", "content": str(msg)})
            continue
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, str):
            glm_messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            new_content = []
            for item in content:
                if not isinstance(item, dict):
                    new_content.append({"type": "text", "text": str(item)})
                    continue
                t = item.get("type")
                if t == "text":
                    new_content.append({"type": "text", "text": item.get("text", "")})
                elif t == "image_url":
                    has_vision = True
                    img_val = item.get("image_url", "")
                    img_src = img_val.get("url", "") if isinstance(img_val, dict) else img_val
                    new_content.append(encode_image(img_src))
                else:
                    new_content.append({"type": "text", "text": str(item)})
            glm_messages.append({"role": role, "content": new_content})
        else:
            glm_messages.append({"role": role, "content": str(content)})
    # Handle the standalone attachment list
    if files:
        doc_exts = {".pdf", ".doc", ".docx", ".xlsx", ".xls", ".pptx", ".ppt"}
        img_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
        inserts = []
        for file_url in files:
            parsed = urlparse(file_url)
            filename = parsed.path.split("/")[-1]
            suffix = Path(filename).suffix.lower()
            rel = parsed.path.lstrip("/")
            local = Path(rel)
            if suffix in img_exts:
                has_vision = True
                try:
                    inserts.append(encode_image(f"file://{rel}"))
                except Exception as e:
                    print(f"[GLM] image encoding failed: {e}")
            elif suffix in doc_exts:
                has_vision = True
                if local.exists():
                    try:
                        fid = upload_file_for_extract(local)
                        inserts.append({"type": "file", "file": {"file_id": fid}})
                    except Exception as e:
                        inserts.append({"type": "text", "text": f"[File upload failed: {filename}: {e}]"})
            else:
                inserts.append({"type": "text", "text": f"[Attachment: {filename}, type: {suffix}]"})
        # Prepend attachments to the most recent user message
        if inserts:
            for i in range(len(glm_messages) - 1, -1, -1):
                if glm_messages[i].get("role") == "user":
                    old = glm_messages[i]["content"]
                    if isinstance(old, str):
                        new_content = inserts + [{"type": "text", "text": old}]
                    elif isinstance(old, list):
                        new_content = inserts + old
                    else:
                        new_content = inserts
                    glm_messages[i] = {"role": "user", "content": new_content}
                    break
    return glm_messages, has_vision
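
# A minimal sketch of the conversion (hypothetical input, assuming
# uploads/cat.png exists locally; the attachment is prepended to the last
# user message as a content part):
#   msgs, vision = build_glm_messages(
#       [{"role": "user", "content": "Describe this image"}],
#       files=["http://localhost:8000/uploads/cat.png"],
#   )
#   # msgs[-1]["content"][0] is an image_url part, [1] the original text
#   # vision is True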
# ── Sentinel object ───────────────────────────────────────────────────
_SENTINEL = object()

async def glm_stream_generator(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: bool = False,
    deep_thinking: bool = False,
) -> AsyncGenerator[str, None]:
    """
    GLM streaming SSE generator.
    Uses a queue.Queue + dedicated producer thread + asyncio consumer
    pattern, so the synchronous zai-sdk iterator runs safely on a single
    thread.
    """
    import asyncio
    import queue
    from utils.helpers import get_current_timestamp, generate_unique_id
    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)
    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [
            {"type": "web_search", "web_search": {"search_result": True}}
        ]
    if deep_thinking:
        extra_kwargs["thinking"] = {"type": "enabled"}
    print(f"[GLM] streaming request: model={actual_model} vision={has_vision} "
          f"web_search={web_search} thinking={deep_thinking}")
    chunk_queue: queue.Queue = queue.Queue(maxsize=128)

    def _producer():
        try:
            client = get_client()
            resp = client.chat.completions.create(
                model=actual_model,
                messages=glm_msgs,
                stream=True,
                temperature=temperature,
                max_tokens=max_tokens,
                **extra_kwargs,
            )
            for chunk in resp:
                chunk_queue.put(chunk)
        except Exception as exc:
            chunk_queue.put(exc)
        finally:
            chunk_queue.put(_SENTINEL)

    t = threading.Thread(target=_producer, daemon=True)
    t.start()
    loop = asyncio.get_running_loop()
    full_reasoning = ""  # accumulated reasoning (used to detect the first fragment)
    full_content = ""    # accumulated final answer (used to detect the first fragment)
    while True:
        item = await loop.run_in_executor(None, chunk_queue.get)
        if item is _SENTINEL:
            break
        if isinstance(item, Exception):
            print(f"[GLM] producer exception: {item}")
            yield f"data: {json.dumps({'error': {'message': str(item), 'type': 'glm_error'}}, ensure_ascii=False)}\n\n"
            break
        try:
            delta = item.choices[0].delta
            reasoning = getattr(delta, "reasoning_content", "") or ""
            text = getattr(delta, "content", "") or ""
            delta_str = ""
            # ── Reasoning (reasoning_content) ──────────────────────────
            if reasoning:
                if not full_reasoning:
                    # First reasoning fragment: open a Markdown blockquote
                    delta_str += "> **💭 Deep thinking:**\n> \n> "
                full_reasoning += reasoning
                # Inside a blockquote, every newline needs a `> ` prefix
                delta_str += reasoning.replace("\n", "\n> ")
            # ── Final answer (content) ─────────────────────────────────
            if text:
                if not full_content and full_reasoning:
                    # First answer text after reasoning ends: add a divider
                    delta_str += "\n\n---\n\n"
                full_content += text
                delta_str += text
            if not delta_str:
                continue
            data = {
                "id": f"chatcmpl-{generate_unique_id()}",
                "object": "chat.completion.chunk",
                "created": get_current_timestamp(),
                "model": actual_model,
                "choices": [{"index": 0, "delta": {"content": delta_str}, "finish_reason": None}],
            }
            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
        except Exception as e:
            print(f"[GLM] chunk parse error: {e}")
    finish = {
        "id": f"chatcmpl-{generate_unique_id()}",
        "object": "chat.completion.chunk",
        "created": get_current_timestamp(),
        "model": actual_model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"
# ── Non-streaming call ────────────────────────────────────────────────
def glm_chat_sync(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: bool = False,
    deep_thinking: bool = False,
) -> dict:
    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)
    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [
            {"type": "web_search", "web_search": {"search_result": True}}
        ]
    if deep_thinking:
        extra_kwargs["thinking"] = {"type": "enabled"}
    client = get_client()
    print(f"[GLM] non-streaming request: model={actual_model}")
    resp = client.chat.completions.create(
        model=actual_model,
        messages=glm_msgs,
        stream=False,
        temperature=temperature,
        max_tokens=max_tokens,
        **extra_kwargs,
    )
    content = resp.choices[0].message.content or ""
    usage = None
    if hasattr(resp, "usage") and resp.usage:
        usage = {
            "promptTokens": resp.usage.prompt_tokens,
            "completionTokens": resp.usage.completion_tokens,
            "totalTokens": resp.usage.total_tokens,
        }
    return {"content": content, "model": actual_model, "usage": usage}