format: 项目格式化
This commit is contained in:
parent
b878011a2c
commit
89b02c4c93
|
|
@ -1,35 +1,38 @@
|
||||||
"""
|
"""
|
||||||
包初始化文件
|
包初始化文件
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .api.chat_routes import (chat_endpoint_handler,
|
||||||
|
delete_conversation_handler,
|
||||||
|
get_conversation_handler,
|
||||||
|
get_conversations_handler, get_models_handler,
|
||||||
|
save_conversation_handler, serve_upload_handler,
|
||||||
|
stop_generation_handler, upload_file_handler)
|
||||||
from .models.chat_models import ChatMessage, ChatRequest, ModelInfo
|
from .models.chat_models import ChatMessage, ChatRequest, ModelInfo
|
||||||
from .utils.helpers import (
|
from .utils.helpers import (extract_delta_content, format_api_response,
|
||||||
get_current_timestamp,
|
generate_unique_id, get_current_timestamp,
|
||||||
generate_unique_id,
|
log_request, log_response)
|
||||||
format_api_response,
|
|
||||||
log_request,
|
|
||||||
log_response,
|
|
||||||
extract_delta_content
|
|
||||||
)
|
|
||||||
from .api.chat_routes import (
|
|
||||||
chat_endpoint_handler,
|
|
||||||
get_models_handler,
|
|
||||||
get_conversations_handler,
|
|
||||||
get_conversation_handler,
|
|
||||||
save_conversation_handler,
|
|
||||||
delete_conversation_handler,
|
|
||||||
upload_file_handler,
|
|
||||||
serve_upload_handler,
|
|
||||||
stop_generation_handler
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Models
|
# Models
|
||||||
'ChatMessage', 'ChatRequest', 'ModelInfo',
|
"ChatMessage",
|
||||||
|
"ChatRequest",
|
||||||
|
"ModelInfo",
|
||||||
# Utils
|
# Utils
|
||||||
'get_current_timestamp', 'generate_unique_id', 'format_api_response',
|
"get_current_timestamp",
|
||||||
'log_request', 'log_response', 'extract_delta_content',
|
"generate_unique_id",
|
||||||
|
"format_api_response",
|
||||||
|
"log_request",
|
||||||
|
"log_response",
|
||||||
|
"extract_delta_content",
|
||||||
# API Handlers
|
# API Handlers
|
||||||
'chat_endpoint_handler', 'get_models_handler', 'get_conversations_handler',
|
"chat_endpoint_handler",
|
||||||
'get_conversation_handler', 'save_conversation_handler', 'delete_conversation_handler',
|
"get_models_handler",
|
||||||
'upload_file_handler', 'serve_upload_handler', 'stop_generation_handler'
|
"get_conversations_handler",
|
||||||
|
"get_conversation_handler",
|
||||||
|
"save_conversation_handler",
|
||||||
|
"delete_conversation_handler",
|
||||||
|
"upload_file_handler",
|
||||||
|
"serve_upload_handler",
|
||||||
|
"stop_generation_handler",
|
||||||
]
|
]
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -2,13 +2,16 @@
|
||||||
GLM-4.6V 平台路由处理器(zai-sdk)
|
GLM-4.6V 平台路由处理器(zai-sdk)
|
||||||
所有智谱 GLM 相关逻辑均集中在此文件,main.py 无感知任何平台细节。
|
所有智谱 GLM 相关逻辑均集中在此文件,main.py 无感知任何平台细节。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from utils.helpers import get_current_timestamp, generate_unique_id
|
|
||||||
|
from utils.helpers import generate_unique_id, get_current_timestamp
|
||||||
from utils.logger import log_info
|
from utils.logger import log_info
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,7 +22,9 @@ def init():
|
||||||
"""
|
"""
|
||||||
api_key = os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY")
|
api_key = os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY(在 https://open.bigmodel.cn 申请)")
|
raise ValueError(
|
||||||
|
"GLM 模式需要设置环境变量 ZHIPU_API_KEY(在 https://open.bigmodel.cn 申请)"
|
||||||
|
)
|
||||||
log_info(f"[GLM] 初始化完成,ZHIPU_API_KEY 已配置")
|
log_info(f"[GLM] 初始化完成,ZHIPU_API_KEY 已配置")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,47 +33,61 @@ async def chat_handler(body: dict):
|
||||||
GLM 聊天处理器(对外接口与百炼 chat_endpoint_handler 完全兼容)。
|
GLM 聊天处理器(对外接口与百炼 chat_endpoint_handler 完全兼容)。
|
||||||
流式/非流式自动适配,支持图像、文档附件、联网搜索、深度思考。
|
流式/非流式自动适配,支持图像、文档附件、联网搜索、深度思考。
|
||||||
"""
|
"""
|
||||||
from utils.glm_adapter import glm_stream_generator, glm_chat_sync
|
from utils.glm_adapter import glm_chat_sync, glm_stream_generator
|
||||||
|
|
||||||
if not isinstance(body, dict):
|
if not isinstance(body, dict):
|
||||||
raise HTTPException(status_code=400, detail="请求体必须是 JSON 对象")
|
raise HTTPException(status_code=400, detail="请求体必须是 JSON 对象")
|
||||||
|
|
||||||
messages = body.get("messages", [])
|
messages = body.get("messages", [])
|
||||||
model = body.get("model", "glm-4.6v")
|
model = body.get("model", "glm-4.6v")
|
||||||
stream = body.get("stream", True)
|
stream = body.get("stream", True)
|
||||||
temperature = body.get("temperature", 0.7)
|
temperature = body.get("temperature", 0.7)
|
||||||
max_tokens = body.get("max_tokens", body.get("maxTokens", 2000))
|
max_tokens = body.get("max_tokens", body.get("maxTokens", 2000))
|
||||||
web_search = body.get("webSearch", False) or body.get("deepSearch", False)
|
web_search = body.get("webSearch", False) or body.get("deepSearch", False)
|
||||||
deep_think = body.get("deepThinking", False)
|
deep_think = body.get("deepThinking", False)
|
||||||
files = body.get("files", [])
|
files = body.get("files", [])
|
||||||
|
|
||||||
# 兼容前端简化格式(非 messages 结构)
|
# 兼容前端简化格式(非 messages 结构)
|
||||||
if not messages:
|
if not messages:
|
||||||
msg_text = body.get("message", "")
|
msg_text = body.get("message", "")
|
||||||
sys_prompt = body.get("systemPrompt", "你是一个智能助手。")
|
sys_prompt = body.get("systemPrompt", "你是一个智能助手。")
|
||||||
user_content = msg_text if isinstance(msg_text, list) else [{"type": "text", "text": msg_text}]
|
user_content = (
|
||||||
|
msg_text
|
||||||
|
if isinstance(msg_text, list)
|
||||||
|
else [{"type": "text", "text": msg_text}]
|
||||||
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": sys_prompt},
|
{"role": "system", "content": sys_prompt},
|
||||||
{"role": "user", "content": user_content},
|
{"role": "user", "content": user_content},
|
||||||
]
|
]
|
||||||
|
|
||||||
log_info(f"[GLM] model={model} stream={stream} web_search={web_search} "
|
log_info(
|
||||||
f"thinking={deep_think} files={len(files)} msgs={len(messages)}")
|
f"[GLM] model={model} stream={stream} web_search={web_search} "
|
||||||
|
f"thinking={deep_think} files={len(files)} msgs={len(messages)}"
|
||||||
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
glm_stream_generator(
|
glm_stream_generator(
|
||||||
messages=messages, model=model, temperature=temperature,
|
messages=messages,
|
||||||
max_tokens=max_tokens, files=files or None,
|
model=model,
|
||||||
web_search=web_search, deep_thinking=deep_think,
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
files=files or None,
|
||||||
|
web_search=web_search,
|
||||||
|
deep_thinking=deep_think,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = glm_chat_sync(
|
result = glm_chat_sync(
|
||||||
messages=messages, model=model, temperature=temperature,
|
messages=messages,
|
||||||
max_tokens=max_tokens, files=files or None,
|
model=model,
|
||||||
web_search=web_search, deep_thinking=deep_think,
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
files=files or None,
|
||||||
|
web_search=web_search,
|
||||||
|
deep_thinking=deep_think,
|
||||||
)
|
)
|
||||||
resp = {
|
resp = {
|
||||||
"id": f"chatcmpl-{generate_unique_id()}",
|
"id": f"chatcmpl-{generate_unique_id()}",
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,12 @@
|
||||||
"""
|
"""
|
||||||
初始化日志系统
|
初始化日志系统
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from utils.logger import setup_global_logger
|
from utils.logger import setup_global_logger
|
||||||
|
|
||||||
|
|
||||||
def init_logging_system():
|
def init_logging_system():
|
||||||
"""
|
"""
|
||||||
初始化日志系统
|
初始化日志系统
|
||||||
|
|
@ -26,13 +29,12 @@ def init_logging_system():
|
||||||
|
|
||||||
# 设置全局日志系统
|
# 设置全局日志系统
|
||||||
logger = setup_global_logger(
|
logger = setup_global_logger(
|
||||||
name="ai-chat-api",
|
name="ai-chat-api", log_level=log_level, log_dir=log_dir
|
||||||
log_level=log_level,
|
|
||||||
log_dir=log_dir
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger = init_logging_system()
|
logger = init_logging_system()
|
||||||
logger.info("Logging system initialized successfully")
|
logger.info("Logging system initialized successfully")
|
||||||
|
|
@ -10,6 +10,7 @@ AI Chat API Server — 主入口(纯基础设施层)
|
||||||
- 百炼 DashScope → api/chat_routes.py
|
- 百炼 DashScope → api/chat_routes.py
|
||||||
- 智谱 GLM-4.6V → api/chat_routes_glm.py + utils/glm_adapter.py
|
- 智谱 GLM-4.6V → api/chat_routes_glm.py + utils/glm_adapter.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
@ -27,10 +28,10 @@ if _venv_lib.exists():
|
||||||
|
|
||||||
# ── 第三方导入 ────────────────────────────────────────────────────────
|
# ── 第三方导入 ────────────────────────────────────────────────────────
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import FastAPI, File, UploadFile, Request
|
from fastapi import FastAPI, File, Request, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
sys.path.append('/home/mt/project/ai-chat-ui/server')
|
sys.path.append("/home/mt/project/ai-chat-ui/server")
|
||||||
|
|
||||||
# ── 工具/日志(与平台无关)───────────────────────────────────────────
|
# ── 工具/日志(与平台无关)───────────────────────────────────────────
|
||||||
from utils.helpers import log_response
|
from utils.helpers import log_response
|
||||||
|
|
@ -55,15 +56,11 @@ else:
|
||||||
_platform.init() # 各平台自行完成初始化(API Key 校验等)
|
_platform.init() # 各平台自行完成初始化(API Key 校验等)
|
||||||
|
|
||||||
# 通用路由处理器(文件上传、会话管理等,与平台无关,统一用百炼路由中的实现)
|
# 通用路由处理器(文件上传、会话管理等,与平台无关,统一用百炼路由中的实现)
|
||||||
from api.chat_routes import (
|
from api.chat_routes import (delete_conversation_handler,
|
||||||
get_conversations_handler,
|
get_conversation_handler,
|
||||||
get_conversation_handler,
|
get_conversations_handler,
|
||||||
save_conversation_handler,
|
save_conversation_handler, serve_upload_handler,
|
||||||
delete_conversation_handler,
|
stop_generation_handler, upload_file_handler)
|
||||||
upload_file_handler,
|
|
||||||
serve_upload_handler,
|
|
||||||
stop_generation_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── FastAPI 应用 ──────────────────────────────────────────────────────
|
# ── FastAPI 应用 ──────────────────────────────────────────────────────
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
|
|
@ -80,7 +77,9 @@ async def logging_middleware(request: Request, call_next):
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
|
ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
|
||||||
icon = "✅" if response.status_code < 400 else "❌"
|
icon = "✅" if response.status_code < 400 else "❌"
|
||||||
logger.info(f"{icon} {request.method} {request.url.path} | 状态: {response.status_code} | 耗时: {ms:.0f}ms")
|
logger.info(
|
||||||
|
f"{icon} {request.method} {request.url.path} | 状态: {response.status_code} | 耗时: {ms:.0f}ms"
|
||||||
|
)
|
||||||
log_response(response.status_code, ms)
|
log_response(response.status_code, ms)
|
||||||
response.headers["X-Process-Time"] = f"{ms:.2f}ms"
|
response.headers["X-Process-Time"] = f"{ms:.2f}ms"
|
||||||
return response
|
return response
|
||||||
|
|
@ -88,6 +87,7 @@ async def logging_middleware(request: Request, call_next):
|
||||||
|
|
||||||
# ── 路由注册 ──────────────────────────────────────────────────────────
|
# ── 路由注册 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
return {
|
return {
|
||||||
|
|
@ -115,6 +115,7 @@ async def get_models():
|
||||||
|
|
||||||
# ── 通用路由(与平台无关)────────────────────────────────────────────
|
# ── 通用路由(与平台无关)────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/chat-ui/conversations")
|
@app.get("/api/chat-ui/conversations")
|
||||||
async def get_conversations():
|
async def get_conversations():
|
||||||
return await get_conversations_handler()
|
return await get_conversations_handler()
|
||||||
|
|
@ -158,6 +159,7 @@ async def stop_generation_by_id(message_id: str):
|
||||||
# ── 程序入口 ──────────────────────────────────────────────────────────
|
# ── 程序入口 ──────────────────────────────────────────────────────────
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
port = int(os.getenv("PORT", 8000))
|
port = int(os.getenv("PORT", 8000))
|
||||||
print("=" * 55)
|
print("=" * 55)
|
||||||
print(f" AI Chat Server v3.0 启动中...")
|
print(f" AI Chat Server v3.0 启动中...")
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,18 @@
|
||||||
"""
|
"""
|
||||||
数据模型定义
|
数据模型定义
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Dict, List, Optional, Any, Union
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageContentItem(BaseModel):
|
class ChatMessageContentItem(BaseModel):
|
||||||
type: str # "text" or "image_url"
|
type: str # "text" or "image_url"
|
||||||
text: Optional[str] = None
|
text: Optional[str] = None
|
||||||
image_url: Optional[Dict[str, str]] = None # {"url": "...", "detail": "auto|low|high"}
|
image_url: Optional[Dict[str, str]] = (
|
||||||
|
None # {"url": "...", "detail": "auto|low|high"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
"""
|
"""
|
||||||
GLM 文件 ID 缓存(基于磁盘的简单 KV,sha256 → file_id,3天有效期)
|
GLM 文件 ID 缓存(基于磁盘的简单 KV,sha256 → file_id,3天有效期)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
_CACHE_FILE = Path(__file__).parent.parent / "uploads" / ".glm_file_cache.json"
|
_CACHE_FILE = Path(__file__).parent.parent / "uploads" / ".glm_file_cache.json"
|
||||||
|
|
@ -24,7 +25,9 @@ def _load() -> dict:
|
||||||
def _save(data: dict) -> None:
|
def _save(data: dict) -> None:
|
||||||
try:
|
try:
|
||||||
_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
_CACHE_FILE.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
_CACHE_FILE.write_text(
|
||||||
|
json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[file_cache] 写入失败:{e}")
|
print(f"[file_cache] 写入失败:{e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,11 @@ GLM-4.6V 适配层(基于 zai-sdk)
|
||||||
SDK:pip install zai-sdk
|
SDK:pip install zai-sdk
|
||||||
模型:glm-4.6v(支持文本/图像/文档/深度思考)
|
模型:glm-4.6v(支持文本/图像/文档/深度思考)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
|
||||||
import base64
|
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
@ -15,7 +16,9 @@ from typing import AsyncGenerator
|
||||||
# ── 自动注入 venv site-packages ───────────────────────────────────────
|
# ── 自动注入 venv site-packages ───────────────────────────────────────
|
||||||
def _ensure_venv():
|
def _ensure_venv():
|
||||||
server_dir = Path(__file__).parent.parent
|
server_dir = Path(__file__).parent.parent
|
||||||
for sp in sorted((server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True):
|
for sp in sorted(
|
||||||
|
(server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True
|
||||||
|
):
|
||||||
if sp.exists() and str(sp) not in sys.path:
|
if sp.exists() and str(sp) not in sys.path:
|
||||||
sys.path.insert(0, str(sp))
|
sys.path.insert(0, str(sp))
|
||||||
print(f"[GLM] venv 注入:{sp}")
|
print(f"[GLM] venv 注入:{sp}")
|
||||||
|
|
@ -34,7 +37,7 @@ def get_client():
|
||||||
from zai import ZhipuAiClient
|
from zai import ZhipuAiClient
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("GLM 模式需要安装 zai-sdk:.venv/bin/pip install zai-sdk")
|
raise ImportError("GLM 模式需要安装 zai-sdk:.venv/bin/pip install zai-sdk")
|
||||||
api_key = os.getenv("ZHIPU_API_KEY").strip() or os.getenv("GLM_API_KEY").strip()
|
api_key = os.getenv("ZHIPU_API_KEY").strip() or os.getenv("GLM_API_KEY").strip()
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY")
|
raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY")
|
||||||
_client = ZhipuAiClient(api_key=api_key)
|
_client = ZhipuAiClient(api_key=api_key)
|
||||||
|
|
@ -43,15 +46,15 @@ def get_client():
|
||||||
|
|
||||||
|
|
||||||
# ── 模型映射 ──────────────────────────────────────────────────────────
|
# ── 模型映射 ──────────────────────────────────────────────────────────
|
||||||
DEFAULT_TEXT_MODEL = "glm-4.5-Air" # glm-4.6 文本统一模型
|
DEFAULT_TEXT_MODEL = "glm-4.5-Air" # glm-4.6 文本统一模型
|
||||||
DEFAULT_VISION_MODEL = "glm-4.5-Air"
|
DEFAULT_VISION_MODEL = "glm-4.5-Air"
|
||||||
|
|
||||||
MODEL_MAP = {
|
MODEL_MAP = {
|
||||||
"qwen-max": "glm-4.5-Air",
|
"qwen-max": "glm-4.5-Air",
|
||||||
"qwen-plus": "glm-4.5-Air",
|
"qwen-plus": "glm-4.5-Air",
|
||||||
"qwen-turbo": "glm-4.5-Air",
|
"qwen-turbo": "glm-4.5-Air",
|
||||||
"qwen-vl-max": "glm-4.5-Air",
|
"qwen-vl-max": "glm-4.5-Air",
|
||||||
"qwen-vl-plus": "glm-4.5-Air",
|
"qwen-vl-plus": "glm-4.5-Air",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -63,7 +66,9 @@ def resolve_model(model: str, has_vision: bool = False) -> str:
|
||||||
|
|
||||||
# ── 文件上传(含 file_id 缓存)───────────────────────────────────────
|
# ── 文件上传(含 file_id 缓存)───────────────────────────────────────
|
||||||
def upload_file_for_extract(local_path: Path) -> str:
|
def upload_file_for_extract(local_path: Path) -> str:
|
||||||
from utils.file_cache import sha256_of_file, get as cache_get, set as cache_set
|
from utils.file_cache import get as cache_get
|
||||||
|
from utils.file_cache import set as cache_set
|
||||||
|
from utils.file_cache import sha256_of_file
|
||||||
|
|
||||||
file_hash = sha256_of_file(local_path)
|
file_hash = sha256_of_file(local_path)
|
||||||
cached = cache_get(file_hash)
|
cached = cache_get(file_hash)
|
||||||
|
|
@ -73,18 +78,20 @@ def upload_file_for_extract(local_path: Path) -> str:
|
||||||
|
|
||||||
client = get_client()
|
client = get_client()
|
||||||
mime_map = {
|
mime_map = {
|
||||||
".pdf": "application/pdf",
|
".pdf": "application/pdf",
|
||||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
".doc": "application/msword",
|
".doc": "application/msword",
|
||||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||||
".xls": "application/vnd.ms-excel",
|
".xls": "application/vnd.ms-excel",
|
||||||
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||||
".ppt": "application/vnd.ms-powerpoint",
|
".ppt": "application/vnd.ms-powerpoint",
|
||||||
}
|
}
|
||||||
mime = mime_map.get(local_path.suffix.lower(), "application/octet-stream")
|
mime = mime_map.get(local_path.suffix.lower(), "application/octet-stream")
|
||||||
print(f"[GLM] 上传文件:{local_path.name}({mime})")
|
print(f"[GLM] 上传文件:{local_path.name}({mime})")
|
||||||
with open(local_path, "rb") as f:
|
with open(local_path, "rb") as f:
|
||||||
file_obj = client.files.create(file=(local_path.name, f, mime), purpose="file-extract")
|
file_obj = client.files.create(
|
||||||
|
file=(local_path.name, f, mime), purpose="file-extract"
|
||||||
|
)
|
||||||
file_id = file_obj.id
|
file_id = file_obj.id
|
||||||
cache_set(file_hash, file_id)
|
cache_set(file_hash, file_id)
|
||||||
print(f"[GLM] 上传成功:file_id={file_id}")
|
print(f"[GLM] 上传成功:file_id={file_id}")
|
||||||
|
|
@ -94,7 +101,9 @@ def upload_file_for_extract(local_path: Path) -> str:
|
||||||
# ── 图像编码 ─────────────────────────────────────────────────────────
|
# ── 图像编码 ─────────────────────────────────────────────────────────
|
||||||
def encode_image(image_source: str) -> dict:
|
def encode_image(image_source: str) -> dict:
|
||||||
"""将图像来源统一转为 OpenAI image_url 格式"""
|
"""将图像来源统一转为 OpenAI image_url 格式"""
|
||||||
if image_source.startswith("data:image") or image_source.startswith(("http://", "https://")):
|
if image_source.startswith("data:image") or image_source.startswith(
|
||||||
|
("http://", "https://")
|
||||||
|
):
|
||||||
return {"type": "image_url", "image_url": {"url": image_source}}
|
return {"type": "image_url", "image_url": {"url": image_source}}
|
||||||
# 本地路径 → base64
|
# 本地路径 → base64
|
||||||
local = Path(image_source.replace("file://", "").lstrip("/"))
|
local = Path(image_source.replace("file://", "").lstrip("/"))
|
||||||
|
|
@ -138,7 +147,9 @@ def build_glm_messages(messages: list, files: list | None = None) -> tuple[list,
|
||||||
elif t == "image_url":
|
elif t == "image_url":
|
||||||
has_vision = True
|
has_vision = True
|
||||||
img_val = item.get("image_url", "")
|
img_val = item.get("image_url", "")
|
||||||
img_src = img_val.get("url", "") if isinstance(img_val, dict) else img_val
|
img_src = (
|
||||||
|
img_val.get("url", "") if isinstance(img_val, dict) else img_val
|
||||||
|
)
|
||||||
new_content.append(encode_image(img_src))
|
new_content.append(encode_image(img_src))
|
||||||
else:
|
else:
|
||||||
new_content.append({"type": "text", "text": str(item)})
|
new_content.append({"type": "text", "text": str(item)})
|
||||||
|
|
@ -172,9 +183,13 @@ def build_glm_messages(messages: list, files: list | None = None) -> tuple[list,
|
||||||
fid = upload_file_for_extract(local)
|
fid = upload_file_for_extract(local)
|
||||||
inserts.append({"type": "file", "file": {"file_id": fid}})
|
inserts.append({"type": "file", "file": {"file_id": fid}})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
inserts.append({"type": "text", "text": f"[文件上传失败:{filename},{e}]"})
|
inserts.append(
|
||||||
|
{"type": "text", "text": f"[文件上传失败:{filename},{e}]"}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
inserts.append({"type": "text", "text": f"[附件:{filename},类型:{suffix}]"})
|
inserts.append(
|
||||||
|
{"type": "text", "text": f"[附件:{filename},类型:{suffix}]"}
|
||||||
|
)
|
||||||
|
|
||||||
if inserts:
|
if inserts:
|
||||||
for i in range(len(glm_messages) - 1, -1, -1):
|
for i in range(len(glm_messages) - 1, -1, -1):
|
||||||
|
|
@ -195,6 +210,7 @@ def build_glm_messages(messages: list, files: list | None = None) -> tuple[list,
|
||||||
# ── 哨兵对象 ─────────────────────────────────────────────────────────
|
# ── 哨兵对象 ─────────────────────────────────────────────────────────
|
||||||
_SENTINEL = object()
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
# ── 流式调用 ────────────────────────────────────────────────────────
|
# ── 流式调用 ────────────────────────────────────────────────────────
|
||||||
async def glm_stream_generator(
|
async def glm_stream_generator(
|
||||||
messages: list,
|
messages: list,
|
||||||
|
|
@ -213,7 +229,7 @@ async def glm_stream_generator(
|
||||||
import asyncio
|
import asyncio
|
||||||
import queue
|
import queue
|
||||||
|
|
||||||
from utils.helpers import get_current_timestamp, generate_unique_id
|
from utils.helpers import generate_unique_id, get_current_timestamp
|
||||||
|
|
||||||
glm_msgs, has_vision = build_glm_messages(messages, files)
|
glm_msgs, has_vision = build_glm_messages(messages, files)
|
||||||
actual_model = resolve_model(model, has_vision)
|
actual_model = resolve_model(model, has_vision)
|
||||||
|
|
@ -221,13 +237,18 @@ async def glm_stream_generator(
|
||||||
extra_kwargs: dict = {}
|
extra_kwargs: dict = {}
|
||||||
if web_search:
|
if web_search:
|
||||||
extra_kwargs["tools"] = [
|
extra_kwargs["tools"] = [
|
||||||
{"type": "web_search", "web_search": {"enable":True,"search_result": True}}
|
{
|
||||||
|
"type": "web_search",
|
||||||
|
"web_search": {"enable": True, "search_result": True},
|
||||||
|
}
|
||||||
]
|
]
|
||||||
if not deep_thinking:
|
if not deep_thinking:
|
||||||
# 智普默认开启思考模式,所以要用非门(不知道“非门”描述是否准确。前端选择开启思考模式,这里不做变动。前端选择关闭思考模式,这里关闭。)
|
# 智普默认开启思考模式,所以要用非门(不知道“非门”描述是否准确。前端选择开启思考模式,这里不做变动。前端选择关闭思考模式,这里关闭。)
|
||||||
extra_kwargs["thinking"] = {"type": "disabled"}
|
extra_kwargs["thinking"] = {"type": "disabled"}
|
||||||
print(f"[GLM] 流式请求:model={actual_model} vision={has_vision} "
|
print(
|
||||||
f"web_search={web_search} thinking={deep_thinking}")
|
f"[GLM] 流式请求:model={actual_model} vision={has_vision} "
|
||||||
|
f"web_search={web_search} thinking={deep_thinking}"
|
||||||
|
)
|
||||||
|
|
||||||
chunk_queue: queue.Queue = queue.Queue(maxsize=128)
|
chunk_queue: queue.Queue = queue.Queue(maxsize=128)
|
||||||
|
|
||||||
|
|
@ -254,8 +275,8 @@ async def glm_stream_generator(
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
full_reasoning = "" # 累计思考内容(用于判断是否首次)
|
full_reasoning = "" # 累计思考内容(用于判断是否首次)
|
||||||
full_content = "" # 累计正式回答(用于判断是否首次)
|
full_content = "" # 累计正式回答(用于判断是否首次)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
item = await loop.run_in_executor(None, chunk_queue.get)
|
item = await loop.run_in_executor(None, chunk_queue.get)
|
||||||
|
|
@ -271,7 +292,7 @@ async def glm_stream_generator(
|
||||||
try:
|
try:
|
||||||
delta = item.choices[0].delta
|
delta = item.choices[0].delta
|
||||||
reasoning = getattr(delta, "reasoning_content", "") or ""
|
reasoning = getattr(delta, "reasoning_content", "") or ""
|
||||||
text = getattr(delta, "content", "") or ""
|
text = getattr(delta, "content", "") or ""
|
||||||
|
|
||||||
delta_str = ""
|
delta_str = ""
|
||||||
|
|
||||||
|
|
@ -300,7 +321,9 @@ async def glm_stream_generator(
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"created": get_current_timestamp(),
|
"created": get_current_timestamp(),
|
||||||
"model": actual_model,
|
"model": actual_model,
|
||||||
"choices": [{"index": 0, "delta": {"content": delta_str}, "finish_reason": None}],
|
"choices": [
|
||||||
|
{"index": 0, "delta": {"content": delta_str}, "finish_reason": None}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
@ -318,7 +341,6 @@ async def glm_stream_generator(
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ── 非流式调用 ────────────────────────────────────────────────────────
|
# ── 非流式调用 ────────────────────────────────────────────────────────
|
||||||
def glm_chat_sync(
|
def glm_chat_sync(
|
||||||
messages: list,
|
messages: list,
|
||||||
|
|
@ -334,13 +356,12 @@ def glm_chat_sync(
|
||||||
|
|
||||||
extra_kwargs: dict = {}
|
extra_kwargs: dict = {}
|
||||||
if web_search:
|
if web_search:
|
||||||
extra_kwargs["tools"] = [{
|
extra_kwargs["tools"] = [
|
||||||
"type": "web_search",
|
{
|
||||||
"web_search": {
|
"type": "web_search",
|
||||||
"enable": True,
|
"web_search": {"enable": True, "search_result": True},
|
||||||
"search_result": True
|
}
|
||||||
}
|
]
|
||||||
}]
|
|
||||||
if deep_thinking:
|
if deep_thinking:
|
||||||
extra_kwargs["thinking"] = {"type": "enabled"}
|
extra_kwargs["thinking"] = {"type": "enabled"}
|
||||||
|
|
||||||
|
|
@ -358,8 +379,8 @@ def glm_chat_sync(
|
||||||
usage = None
|
usage = None
|
||||||
if hasattr(resp, "usage") and resp.usage:
|
if hasattr(resp, "usage") and resp.usage:
|
||||||
usage = {
|
usage = {
|
||||||
"promptTokens": resp.usage.prompt_tokens,
|
"promptTokens": resp.usage.prompt_tokens,
|
||||||
"completionTokens": resp.usage.completion_tokens,
|
"completionTokens": resp.usage.completion_tokens,
|
||||||
"totalTokens": resp.usage.total_tokens,
|
"totalTokens": resp.usage.total_tokens,
|
||||||
}
|
}
|
||||||
return {"content": content, "model": actual_model, "usage": usage}
|
return {"content": content, "model": actual_model, "usage": usage}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,15 @@
|
||||||
"""
|
"""
|
||||||
通用工具函数
|
通用工具函数
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from .logger import log_request_info, log_response_info, log_error_detail, log_chat_interaction
|
from .logger import (log_chat_interaction, log_error_detail, log_request_info,
|
||||||
|
log_response_info)
|
||||||
|
|
||||||
|
|
||||||
def get_current_timestamp():
|
def get_current_timestamp():
|
||||||
|
|
@ -20,14 +22,16 @@ def generate_unique_id():
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
def format_api_response(content: str, conversation_id: str = None, model: str = "qwen-plus"):
|
def format_api_response(
|
||||||
|
content: str, conversation_id: str = None, model: str = "qwen-plus"
|
||||||
|
):
|
||||||
"""格式化API响应"""
|
"""格式化API响应"""
|
||||||
return {
|
return {
|
||||||
"id": generate_unique_id(),
|
"id": generate_unique_id(),
|
||||||
"conversationId": conversation_id or generate_unique_id(),
|
"conversationId": conversation_id or generate_unique_id(),
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": model,
|
"model": model,
|
||||||
"createdAt": get_current_timestamp()
|
"createdAt": get_current_timestamp(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,5 +48,5 @@ def log_response(status_code: int, process_time: float):
|
||||||
def extract_delta_content(full_content: str, previous_content: str) -> str:
|
def extract_delta_content(full_content: str, previous_content: str) -> str:
|
||||||
"""提取增量内容"""
|
"""提取增量内容"""
|
||||||
if len(full_content) > len(previous_content):
|
if len(full_content) > len(previous_content):
|
||||||
return full_content[len(previous_content):]
|
return full_content[len(previous_content) :]
|
||||||
return ""
|
return ""
|
||||||
|
|
@ -2,20 +2,27 @@
|
||||||
统一日志管理系统
|
统一日志管理系统
|
||||||
提供结构化日志记录功能,支持不同日志级别、文件输出、轮转等
|
提供结构化日志记录功能,支持不同日志级别、文件输出、轮转等
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
import json
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class LoggerSetup:
|
class LoggerSetup:
|
||||||
"""日志系统配置类"""
|
"""日志系统配置类"""
|
||||||
|
|
||||||
def __init__(self, name: str = "ai-chat-server", log_level: str = "INFO",
|
def __init__(
|
||||||
log_dir: str = "logs", max_bytes: int = 10 * 1024 * 1024, backup_count: int = 5):
|
self,
|
||||||
|
name: str = "ai-chat-server",
|
||||||
|
log_level: str = "INFO",
|
||||||
|
log_dir: str = "logs",
|
||||||
|
max_bytes: int = 10 * 1024 * 1024,
|
||||||
|
backup_count: int = 5,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
初始化日志系统
|
初始化日志系统
|
||||||
|
|
||||||
|
|
@ -37,7 +44,7 @@ class LoggerSetup:
|
||||||
|
|
||||||
# 设置日志格式(去掉 funcName:lineno,保持人类可读性)
|
# 设置日志格式(去掉 funcName:lineno,保持人类可读性)
|
||||||
self.formatter = logging.Formatter(
|
self.formatter = logging.Formatter(
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建logger实例
|
# 创建logger实例
|
||||||
|
|
@ -66,7 +73,7 @@ class LoggerSetup:
|
||||||
str(log_file),
|
str(log_file),
|
||||||
maxBytes=self.max_bytes,
|
maxBytes=self.max_bytes,
|
||||||
backupCount=self.backup_count,
|
backupCount=self.backup_count,
|
||||||
encoding='utf-8'
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
file_handler.setLevel(self.log_level)
|
file_handler.setLevel(self.log_level)
|
||||||
file_handler.setFormatter(self.formatter)
|
file_handler.setFormatter(self.formatter)
|
||||||
|
|
@ -83,9 +90,13 @@ class LoggerSetup:
|
||||||
_logger_instance = None
|
_logger_instance = None
|
||||||
|
|
||||||
|
|
||||||
def setup_global_logger(name: str = "ai-chat-server", log_level: str = "INFO",
|
def setup_global_logger(
|
||||||
log_dir: str = "logs", max_bytes: int = 10 * 1024 * 1024,
|
name: str = "ai-chat-server",
|
||||||
backup_count: int = 5):
|
log_level: str = "INFO",
|
||||||
|
log_dir: str = "logs",
|
||||||
|
max_bytes: int = 10 * 1024 * 1024,
|
||||||
|
backup_count: int = 5,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
设置全局日志系统
|
设置全局日志系统
|
||||||
|
|
||||||
|
|
@ -174,8 +185,13 @@ def log_structured(level: str, message: str, **details):
|
||||||
getattr(logger, level.lower())(formatted_msg)
|
getattr(logger, level.lower())(formatted_msg)
|
||||||
|
|
||||||
|
|
||||||
def log_request_info(method: str, path: str, client_ip: str = "unknown",
|
def log_request_info(
|
||||||
user_agent: str = "", referer: str = ""):
|
method: str,
|
||||||
|
path: str,
|
||||||
|
client_ip: str = "unknown",
|
||||||
|
user_agent: str = "",
|
||||||
|
referer: str = "",
|
||||||
|
):
|
||||||
"""记录请求信息日志"""
|
"""记录请求信息日志"""
|
||||||
log_structured(
|
log_structured(
|
||||||
"info",
|
"info",
|
||||||
|
|
@ -184,12 +200,17 @@ def log_request_info(method: str, path: str, client_ip: str = "unknown",
|
||||||
path=path,
|
path=path,
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
referer=referer
|
referer=referer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_response_info(status_code: int, process_time: float, path: str = "",
|
def log_response_info(
|
||||||
method: str = "", client_ip: str = ""):
|
status_code: int,
|
||||||
|
process_time: float,
|
||||||
|
path: str = "",
|
||||||
|
method: str = "",
|
||||||
|
client_ip: str = "",
|
||||||
|
):
|
||||||
"""记录响应信息日志"""
|
"""记录响应信息日志"""
|
||||||
log_structured(
|
log_structured(
|
||||||
"info",
|
"info",
|
||||||
|
|
@ -198,37 +219,52 @@ def log_response_info(status_code: int, process_time: float, path: str = "",
|
||||||
process_time_ms=process_time,
|
process_time_ms=process_time,
|
||||||
path=path,
|
path=path,
|
||||||
method=method,
|
method=method,
|
||||||
client_ip=client_ip
|
client_ip=client_ip,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_error_detail(error_type: str, error_message: str, traceback_info: str = "",
|
def log_error_detail(
|
||||||
context: dict = None):
|
error_type: str, error_message: str, traceback_info: str = "", context: dict = None
|
||||||
|
):
|
||||||
"""记录详细的错误信息"""
|
"""记录详细的错误信息"""
|
||||||
log_structured(
|
log_structured(
|
||||||
"error",
|
"error",
|
||||||
f"{error_type}: {error_message}",
|
f"{error_type}: {error_message}",
|
||||||
traceback=traceback_info,
|
traceback=traceback_info,
|
||||||
context=context or {}
|
context=context or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_chat_interaction(user_input: str, ai_response: str, model: str = "",
|
def log_chat_interaction(
|
||||||
conversation_id: str = "", tokens_used: dict = None):
|
user_input: str,
|
||||||
|
ai_response: str,
|
||||||
|
model: str = "",
|
||||||
|
conversation_id: str = "",
|
||||||
|
tokens_used: dict = None,
|
||||||
|
):
|
||||||
"""记录聊天交互日志"""
|
"""记录聊天交互日志"""
|
||||||
log_structured(
|
log_structured(
|
||||||
"info",
|
"info",
|
||||||
"Chat Interaction",
|
"Chat Interaction",
|
||||||
user_input=user_input[:100] + "..." if len(user_input) > 100 else user_input, # 截断长输入
|
user_input=(
|
||||||
ai_response=ai_response[:100] + "..." if len(ai_response) > 100 else ai_response,
|
user_input[:100] + "..." if len(user_input) > 100 else user_input
|
||||||
|
), # 截断长输入
|
||||||
|
ai_response=(
|
||||||
|
ai_response[:100] + "..." if len(ai_response) > 100 else ai_response
|
||||||
|
),
|
||||||
model=model,
|
model=model,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
tokens_used=tokens_used
|
tokens_used=tokens_used,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_system_status(status: str, uptime: float = 0, cpu_usage: float = 0,
|
def log_system_status(
|
||||||
memory_usage: float = 0, disk_usage: float = 0):
|
status: str,
|
||||||
|
uptime: float = 0,
|
||||||
|
cpu_usage: float = 0,
|
||||||
|
memory_usage: float = 0,
|
||||||
|
disk_usage: float = 0,
|
||||||
|
):
|
||||||
"""记录系统状态日志"""
|
"""记录系统状态日志"""
|
||||||
log_structured(
|
log_structured(
|
||||||
"info",
|
"info",
|
||||||
|
|
@ -237,5 +273,5 @@ def log_system_status(status: str, uptime: float = 0, cpu_usage: float = 0,
|
||||||
uptime_seconds=uptime,
|
uptime_seconds=uptime,
|
||||||
cpu_percent=cpu_usage,
|
cpu_percent=cpu_usage,
|
||||||
memory_percent=memory_usage,
|
memory_percent=memory_usage,
|
||||||
disk_percent=disk_usage
|
disk_percent=disk_usage,
|
||||||
)
|
)
|
||||||
|
|
@ -1,30 +1,36 @@
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Add project root to sys.path
|
# Add project root to sys.path
|
||||||
root_dir = Path(__file__).parent
|
root_dir = Path(__file__).parent
|
||||||
sys.path.insert(0, str(root_dir))
|
sys.path.insert(0, str(root_dir))
|
||||||
|
|
||||||
from utils.glm_adapter import glm_stream_generator, _ensure_venv, glm_chat_sync
|
|
||||||
|
|
||||||
# Set API key from .env if needed
|
# Set API key from .env if needed
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from utils.glm_adapter import _ensure_venv, glm_chat_sync, glm_stream_generator
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
async def test_stream():
|
async def test_stream():
|
||||||
msgs = [{"role": "user", "content": "今天北京天气怎样?"}]
|
msgs = [{"role": "user", "content": "今天北京天气怎样?"}]
|
||||||
print("Testing stream...")
|
print("Testing stream...")
|
||||||
async for chunk in glm_stream_generator(msgs, "glm-4.5-air", 0.7, 1024, web_search=True):
|
async for chunk in glm_stream_generator(
|
||||||
|
msgs, "glm-4.5-air", 0.7, 1024, web_search=True
|
||||||
|
):
|
||||||
print(chunk, end="")
|
print(chunk, end="")
|
||||||
|
|
||||||
|
|
||||||
def test_sync():
|
def test_sync():
|
||||||
msgs = [{"role": "user", "content": "今天几号?武汉天气怎样?"}]
|
msgs = [{"role": "user", "content": "今天几号?武汉天气怎样?"}]
|
||||||
print("Testing sync...")
|
print("Testing sync...")
|
||||||
res = glm_chat_sync(msgs, "glm-4.5-air", 0.7, 1024, web_search=True)
|
res = glm_chat_sync(msgs, "glm-4.5-air", 0.7, 1024, web_search=True)
|
||||||
print(res)
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_ensure_venv()
|
_ensure_venv()
|
||||||
# test_sync()
|
# test_sync()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue