refactor: improve backend directory structure and code organization

- Restructure directories, splitting configuration and logging into dedicated packages
- Add config/ to centralize platform configuration and API key management
- Add core/ to centralize the logging system
- Add database/ with SQLite initialization and management
- Remove unneeded files: test files, cache files, duplicated code
- Update all import paths so module references resolve correctly

Key changes:
- config.py → config/settings.py
- utils/logger.py → core/logger.py
- init_logging.py → core/init.py
- Remove logging.conf (its settings now live in code)
- Add database/__init__.py providing database connection management

Improvements:
- Clearer module boundaries, easier to maintain and extend
- Avoids the name clash between a local logging module and the Python standard library
- Unified interfaces for configuration and logging
This commit is contained in:
parent 547ba742b7
commit d8a6f696e7

README.md · 271 changed lines

@@ -1,6 +1,6 @@
# AI-CHAT-UI

A modern AI chat interface built with Vue and markstream-vue, offering rich interaction features and a polished visual experience. Supports multiple model integrations, including Zhipu GLM, Aliyun Bailian, and OpenAI/Deepseek.

## Screenshots
@@ -14,6 +14,8 @@

## ✨ Core Features

### Frontend Features

| Feature | Description |
|-------|-------------------------------------|
| Conversation history | Multi-conversation management with pin, rename, and delete |
@@ -36,34 +38,106 @@

| Data management | Import/export settings and clear data, keeping your data safe |
| Preset prompts | Quickly apply common role presets to speed up conversations |

### Backend Features

| Feature | Description |
|-------|-------------------------------------|
| OpenAI-compatible API | Standard OpenAI-compatible endpoints with multi-model routing |
| Multi-model support | Integrates Zhipu GLM, Aliyun Bailian, OpenAI/Deepseek, and more |
| Session management | Save, load, and delete conversation history |
| File upload | Upload and manage images, documents, and other files |
| Health check | System status and available-model checks |
| Logging | Detailed request logs and response-time records |

## 🛠 Tech Stack

### Frontend
- **Core framework**: Vue 3
- **Streaming rendering**: markstream-vue
- **Type system**: TypeScript
- **UI design**: Modern responsive design with dark-theme support
- **Interaction**: Rich keyboard shortcuts and message actions

### Backend
- **Core framework**: FastAPI
- **Language**: Python 3.12+
- **Database**: SQLite
- **API design**: RESTful API + OpenAI-compatible endpoints
- **File storage**: Local filesystem

### Supported Models
- **Zhipu GLM**: glm-4-flash, glm-4, etc.
- **Aliyun Bailian**: qwen-turbo, qwen-plus, etc.
- **OpenAI/Deepseek**: gpt-3.5-turbo, gpt-4, deepseek-chat, etc.
## 🚀 Quick Start

### 1. Clone the Repository

```bash
git clone https://github.com/zll-it/ai-chat-ui.git
cd ai-chat-ui
```

### 2. Install Frontend Dependencies

```bash
npm install
```

### 3. Install Backend Dependencies

```bash
cd server
python3 -m venv .venv
source .venv/bin/activate  # Linux/Mac
# .venv\Scripts\activate   # Windows
pip install -r requirements.txt
```

### 4. Configure API Keys

Create a `.env` file in the `server` directory with the following settings:

```env
# Zhipu GLM API
ZHIPUAI_API_KEY=your_zhipuai_api_key

# Aliyun Bailian (DashScope) API
DASHSCOPE_API_KEY=your_dashscope_api_key

# OpenAI API (also works with Deepseek and other compatible providers)
OPENAI_API_KEY=your_openai_api_key
OPENAI_API_BASE=https://api.openai.com/v1  # or another compatible endpoint

# Server
PORT=8000
```

### 5. Start the Services

#### Start the backend server

```bash
# from the server directory
python main.py
```

#### Start the frontend dev server

```bash
# from the project root
npm run dev
```

### 6. Build for Production

```bash
# build the frontend
npm run build

# the backend can be run directly
python server/main.py
```
## 📋 Usage

@@ -71,11 +145,8 @@ npm run build

### Basics

- **New conversation**: `Ctrl+N`, or the "+" button in the top-right corner
- **Switch layout**: The layout toggle in the bottom-right corner
- **Theme**: Choose light/dark/system in the settings panel
- **Search conversations**: The search box at the top, or `Ctrl+K`

### Keyboard Shortcuts

@@ -87,7 +158,146 @@ npm run build

| Copy current message | Ctrl+C (while hovering a message) |
| Switch layout | Ctrl+Shift+L |

### Model Selection

In the conversation settings you can choose between models:

- **Zhipu GLM**: High-performance domestic models with fast responses
- **Aliyun Bailian**: Multilingual and multimodal support
- **OpenAI**: The leading GPT family of models
- **Deepseek**: Models focused on code and technical domains
## 📡 API Documentation

### OpenAI-Compatible Endpoints

#### POST /v1/chat/completions

Standard OpenAI-compatible chat endpoint with streaming support.

**Request example**:

```json
{
  "model": "glm-4-flash",
  "messages": [
    {"role": "user", "content": "你好,请介绍一下你自己"}
  ],
  "stream": true,
  "temperature": 0.7
}
```

**Response example**:

```json
{
  "id": "chatcmpl-123",
  "object": "chat.completion",
  "created": 1677858242,
  "model": "glm-4-flash",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "你好!我是一个基于智谱 GLM 模型的AI助手..."
      },
      "finish_reason": "stop"
    }
  ]
}
```

#### GET /v1/models

Lists all available models.

**Response example**:

```json
{
  "object": "list",
  "data": [
    {
      "id": "glm-4-flash",
      "object": "model",
      "created": 1677825464,
      "owned_by": "zhipuai"
    },
    {
      "id": "gpt-3.5-turbo",
      "object": "model",
      "created": 1677610602,
      "owned_by": "openai"
    }
  ]
}
```
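The request body above can be built and sent from Python with only the standard library; a minimal sketch, assuming the backend is running locally on the default `PORT=8000` (the helper name `build_chat_request` is illustrative, not part of the project):

```python
import json

API_BASE = "http://localhost:8000/v1"  # assumed local server address


def build_chat_request(model: str, prompt: str,
                       stream: bool = True, temperature: float = 0.7) -> dict:
    """Build the JSON body for POST /v1/chat/completions."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,
        "temperature": temperature,
    }


payload = build_chat_request("glm-4-flash", "你好,请介绍一下你自己")
print(json.dumps(payload, ensure_ascii=False))

# To actually send it (requires the backend to be running):
#   import urllib.request
#   req = urllib.request.Request(
#       f"{API_BASE}/chat/completions",
#       data=json.dumps(payload).encode("utf-8"),
#       headers={"Content-Type": "application/json"},
#   )
#   with urllib.request.urlopen(req) as resp:
#       print(resp.read().decode("utf-8"))
```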
### Legacy Endpoints

#### POST /api/chat-ui/chat

Legacy chat endpoint, kept for backwards compatibility.

#### GET /api/chat-ui/models

List models (aggregated across all available platforms).

#### GET /api/chat-ui/conversations

Fetch all conversation history.

#### POST /api/chat-ui/upload

Upload a file (images, documents, etc.).

## 🔧 Configuration

### Frontend

The frontend configuration lives in `src/config.ts`; the main options are:

- **API_BASE_URL**: Backend API address
- **DEFAULT_MODEL**: Default model
- **THEME**: Default theme
- **FONT_SIZE**: Default font size

### Backend

The backend is configured through the `.env` file and `server/config.py`:

- **PORT**: Server port
- **API keys**: Per-platform API keys
- **Upload directory**: Where uploaded files are stored
- **Database settings**: SQLite configuration
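The backend settings above can be sketched as a small env-driven registry. The names mirror the `config/settings.py` exports added in this commit (`ProviderConfig`, `PROVIDERS`, `get_provider_config`, `is_provider_available`, `get_available_providers`), but the exact fields and default models here are assumptions, not the project's actual definitions:

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class ProviderConfig:
    name: str
    api_key_env: str      # environment variable that holds the API key
    default_model: str


# Hypothetical registry; the real PROVIDERS mapping lives in config/settings.py.
PROVIDERS = {
    "glm": ProviderConfig("glm", "ZHIPUAI_API_KEY", "glm-4-flash"),
    "dashscope": ProviderConfig("dashscope", "DASHSCOPE_API_KEY", "qwen-turbo"),
    "openai": ProviderConfig("openai", "OPENAI_API_KEY", "gpt-3.5-turbo"),
}


def get_provider_config(name: str) -> Optional[ProviderConfig]:
    return PROVIDERS.get(name)


def is_provider_available(name: str) -> bool:
    # A provider is usable only if its API key is present in the environment.
    cfg = PROVIDERS.get(name)
    return bool(cfg and os.getenv(cfg.api_key_env))


def get_available_providers() -> list:
    return [n for n in PROVIDERS if is_provider_available(n)]
```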
## 📁 Project Structure

```
ai-chat-ui/
├── src/                  # Frontend source
│   ├── components/       # Vue components
│   ├── services/         # API services
│   ├── utils/            # Utilities
│   ├── App.vue           # Root component
│   └── main.ts           # Entry point
├── server/               # Backend source
│   ├── adapters/         # Model adapters
│   ├── api/              # API routes
│   ├── database/         # Database layer
│   ├── utils/            # Utilities
│   ├── main.py           # Entry point
│   └── requirements.txt  # Dependencies
├── public/               # Static assets
├── screenshots/          # Screenshots
├── package.json          # Frontend dependencies
├── tsconfig.json         # TypeScript config
└── README.md             # Project readme
```

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details
@@ -95,10 +305,35 @@ npm run build

Issues and pull requests to improve this project are welcome!

### Development Workflow

1. Fork this repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add amazing feature'`)
4. Push the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request

## 🌟 Highlights

- **Multi-model integration**: Models from multiple AI platforms, switchable on the fly
- **Modern interface**: Polished UI with dark-theme support
- **Smooth experience**: Streaming output and rich interactions
- **Feature-rich**: Full conversation management and settings
- **Easy to deploy**: Simple configuration and startup
- **OpenAI-compatible**: Standard OpenAI-compatible endpoints

## 📞 Support

If you have questions or suggestions, feel free to:

- Open an Issue
- Email: contact@example.com
- Join the project discussions

---

<div align="center">

<sub>Made with ❤️ using Vue & FastAPI</sub>

</div>
@@ -10,7 +10,7 @@ from typing import Dict, List
 from fastapi.responses import JSONResponse, StreamingResponse

 from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
-from utils.logger import get_logger
+from core import get_logger

 logger = get_logger()
@@ -22,7 +22,7 @@ DASHSCOPE_MODELS = [
         description="最强大的模型",
         max_tokens=8192,
         provider="Aliyun",
-        supports_thinking=True,
+        supports_thinking=False,
         supports_web_search=False,
         supports_vision=False,
         supports_files=False,

@@ -33,7 +33,7 @@ DASHSCOPE_MODELS = [
         description="能力均衡",
         max_tokens=8192,
         provider="Aliyun",
-        supports_thinking=False,
+        supports_thinking=True,
         supports_web_search=False,
         supports_vision=False,
         supports_files=False,

@@ -44,7 +44,7 @@ DASHSCOPE_MODELS = [
         description="速度更快、成本更低",
         max_tokens=8192,
         provider="Aliyun",
-        supports_thinking=False,
+        supports_thinking=True,
         supports_web_search=False,
         supports_vision=False,
         supports_files=False,
@@ -11,7 +11,7 @@ from typing import Dict, List, Optional
 from fastapi.responses import JSONResponse, StreamingResponse

 from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
-from utils.logger import get_logger
+from core import get_logger

 logger = get_logger()
@@ -10,7 +10,7 @@ from typing import Dict, List, Optional
 from fastapi.responses import JSONResponse, StreamingResponse

 from .base import BaseAdapter, ChatCompletionRequest, ModelInfo
-from utils.logger import get_logger
+from core import get_logger

 logger = get_logger()
@@ -18,7 +18,7 @@ sys.path.append(str(Path(__file__).parent.parent))

 from database import get_db
 from utils.helpers import generate_unique_id
-from utils.logger import log_error, log_exception, log_info
+from core import log_error, log_exception, log_info

 # Upload directory
 upload_dir = Path(__file__).parent.parent / "uploads"
@@ -10,7 +10,7 @@ from fastapi.responses import JSONResponse

 from adapters import get_adapter, get_provider_from_model
 from adapters.base import ChatCompletionRequest
-from utils.logger import get_logger
+from core import get_logger

 logger = get_logger()
@@ -0,0 +1,23 @@
"""
Configuration module.

Provides unified configuration management: platform configs, API key handling, etc.
"""

from .settings import (
    ProviderConfig,
    PROVIDERS,
    get_provider_config,
    is_provider_available,
    get_available_providers,
    DEFAULT_PROVIDER,
)

__all__ = [
    "ProviderConfig",
    "PROVIDERS",
    "get_provider_config",
    "is_provider_available",
    "get_available_providers",
    "DEFAULT_PROVIDER",
]
@@ -56,4 +56,4 @@ def get_available_providers() -> list:


 # Default platform
-DEFAULT_PROVIDER = os.getenv("DEFAULT_PROVIDER", "glm")
\ No newline at end of file
+DEFAULT_PROVIDER = os.getenv("DEFAULT_PROVIDER", "glm")
@@ -0,0 +1,41 @@
"""
Logging module.

Provides unified log management: structured logging, file rotation, multiple log levels, etc.
"""

from .logger import (
    LoggerSetup,
    setup_global_logger,
    get_logger,
    log_debug,
    log_info,
    log_warning,
    log_error,
    log_critical,
    log_exception,
    log_structured,
    log_request_info,
    log_response_info,
    log_error_detail,
    log_chat_interaction,
    log_system_status,
)

__all__ = [
    "LoggerSetup",
    "setup_global_logger",
    "get_logger",
    "log_debug",
    "log_info",
    "log_warning",
    "log_error",
    "log_critical",
    "log_exception",
    "log_structured",
    "log_request_info",
    "log_response_info",
    "log_error_detail",
    "log_chat_interaction",
    "log_system_status",
]
@@ -1,40 +1,40 @@
 #!/usr/bin/env python
 """
 Initialize the logging system.
 """

 import os

-from utils.logger import setup_global_logger
+from .logger import setup_global_logger


 def init_logging_system():
     """
     Initialize the logging system.
     """
     # Read log settings from the environment, falling back to defaults
     log_level = os.getenv("LOG_LEVEL", "INFO")
     log_dir = os.getenv("LOG_DIR", "logs")

     # Try to read values from the config file
     try:
         with open("logging.conf", "r", encoding="utf-8") as f:
             for line in f:
                 if line.startswith("LOG_LEVEL="):
                     log_level = line.split("=", 1)[1].strip()
                 elif line.startswith("LOG_DIR="):
                     log_dir = line.split("=", 1)[1].strip()
     except FileNotFoundError:
         pass  # fall back to env vars/defaults if the config file is missing

     # Set up the global logging system
     logger = setup_global_logger(
         name="ai-chat-api", log_level=log_level, log_dir=log_dir
     )

     return logger


 if __name__ == "__main__":
     logger = init_logging_system()
     logger.info("Logging system initialized successfully")
@@ -1,277 +1,277 @@
"""
Unified logging system.
Provides structured logging with multiple log levels, file output, rotation, etc.
"""

import json
import logging
import os
import sys
from datetime import datetime
from logging.handlers import RotatingFileHandler
from pathlib import Path


class LoggerSetup:
    """Logger configuration class."""

    def __init__(
        self,
        name: str = "ai-chat-server",
        log_level: str = "INFO",
        log_dir: str = "logs",
        max_bytes: int = 10 * 1024 * 1024,
        backup_count: int = 5,
    ):
        """
        Initialize the logging system.

        Args:
            name: Logger name
            log_level: Log level ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL')
            log_dir: Directory for log files
            max_bytes: Maximum size of a single log file (bytes)
            backup_count: Number of backup files to keep
        """
        self.name = name
        self.log_level = getattr(logging, log_level.upper(), logging.INFO)
        self.log_dir = Path(log_dir)
        self.max_bytes = max_bytes
        self.backup_count = backup_count

        # Create the log directory
        self.log_dir.mkdir(exist_ok=True)

        # Log format (funcName:lineno dropped for human readability)
        self.formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )

        # Create the logger instance
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Configure the logger instance."""
        logger = logging.getLogger(self.name)
        logger.setLevel(self.log_level)

        # Avoid adding duplicate handlers
        if logger.handlers:
            logger.handlers.clear()

        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(self.log_level)
        console_handler.setFormatter(self.formatter)
        logger.addHandler(console_handler)

        # File handler - split by date
        date_str = datetime.now().strftime("%Y-%m-%d")
        log_file = self.log_dir / f"{self.name}_{date_str}.log"

        file_handler = RotatingFileHandler(
            str(log_file),
            maxBytes=self.max_bytes,
            backupCount=self.backup_count,
            encoding="utf-8",
        )
        file_handler.setLevel(self.log_level)
        file_handler.setFormatter(self.formatter)
        logger.addHandler(file_handler)

        return logger

    def get_logger(self):
        """Return the configured logger instance."""
        return self.logger


# Global logger instance
_logger_instance = None


def setup_global_logger(
    name: str = "ai-chat-server",
    log_level: str = "INFO",
    log_dir: str = "logs",
    max_bytes: int = 10 * 1024 * 1024,
    backup_count: int = 5,
):
    """
    Set up the global logging system.

    Args:
        name: Logger name
        log_level: Log level
        log_dir: Log file directory
        max_bytes: Maximum file size
        backup_count: Number of backup files
    """
    global _logger_instance
    logger_setup = LoggerSetup(name, log_level, log_dir, max_bytes, backup_count)
    _logger_instance = logger_setup.get_logger()
    return _logger_instance


def get_logger(name: str = None):
    """
    Get a logger instance.

    Args:
        name: If given, return a child logger; otherwise the global logger
    """
    global _logger_instance
    if _logger_instance is None:
        # Create one lazily if not initialized yet
        _logger_instance = setup_global_logger()

    if name and name != _logger_instance.name:
        return _logger_instance.getChild(name)
    return _logger_instance


# Convenience logging functions
def log_debug(message: str, *args, **kwargs):
    """Log at DEBUG level."""
    logger = get_logger()
    logger.debug(message, *args, **kwargs)


def log_info(message: str, *args, **kwargs):
    """Log at INFO level."""
    logger = get_logger()
    logger.info(message, *args, **kwargs)


def log_warning(message: str, *args, **kwargs):
    """Log at WARNING level."""
    logger = get_logger()
    logger.warning(message, *args, **kwargs)


def log_error(message: str, *args, **kwargs):
    """Log at ERROR level."""
    logger = get_logger()
    logger.error(message, *args, **kwargs)


def log_critical(message: str, *args, **kwargs):
    """Log at CRITICAL level."""
    logger = get_logger()
    logger.critical(message, *args, **kwargs)


def log_exception(message: str = ""):
    """Log an exception with traceback."""
    logger = get_logger()
    logger.exception(message)


def log_structured(level: str, message: str, **details):
    """
    Log structured data.

    Args:
        level: Log level
        message: Log message
        **details: Additional structured fields
    """
    logger = get_logger()
    # For development-time readability, no longer printed as single-line JSON;
    # converted to a more readable form instead
    detail_str = ", ".join(f"{k}={v}" for k, v in details.items() if v)
    formatted_msg = f"[{message}] {detail_str}"

    getattr(logger, level.lower())(formatted_msg)


def log_request_info(
    method: str,
    path: str,
    client_ip: str = "unknown",
    user_agent: str = "",
    referer: str = "",
):
    """Log request information."""
    log_structured(
        "info",
        "API Request",
        method=method,
        path=path,
        client_ip=client_ip,
        user_agent=user_agent,
        referer=referer,
    )


def log_response_info(
    status_code: int,
    process_time: float,
    path: str = "",
    method: str = "",
    client_ip: str = "",
):
    """Log response information."""
    log_structured(
        "info",
        "API Response",
        status_code=status_code,
        process_time_ms=process_time,
        path=path,
        method=method,
        client_ip=client_ip,
    )


def log_error_detail(
    error_type: str, error_message: str, traceback_info: str = "", context: dict = None
):
    """Log detailed error information."""
    log_structured(
        "error",
        f"{error_type}: {error_message}",
        traceback=traceback_info,
        context=context or {},
    )


def log_chat_interaction(
    user_input: str,
    ai_response: str,
    model: str = "",
    conversation_id: str = "",
    tokens_used: dict = None,
):
    """Log a chat interaction."""
    log_structured(
        "info",
        "Chat Interaction",
        user_input=(
            user_input[:100] + "..." if len(user_input) > 100 else user_input
        ),  # truncate long input
        ai_response=(
            ai_response[:100] + "..." if len(ai_response) > 100 else ai_response
        ),
        model=model,
        conversation_id=conversation_id,
        tokens_used=tokens_used,
    )


def log_system_status(
    status: str,
    uptime: float = 0,
    cpu_usage: float = 0,
    memory_usage: float = 0,
    disk_usage: float = 0,
):
    """Log system status."""
    log_structured(
        "info",
        "System Status",
        status=status,
        uptime_seconds=uptime,
        cpu_percent=cpu_usage,
        memory_percent=memory_usage,
        disk_percent=disk_usage,
    )
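`get_logger(name)` above hands out children of one globally configured logger; with stdlib `logging`, child records propagate up to the parent's handlers, which is what lets a single `setup_global_logger()` call serve the whole app. A standalone sketch using plain `logging` (the `-demo` logger name is arbitrary):

```python
import io
import logging

# Build a parent logger with one stream handler, as LoggerSetup does.
buf = io.StringIO()
parent = logging.getLogger("ai-chat-server-demo")
parent.setLevel(logging.INFO)
parent.handlers.clear()
handler = logging.StreamHandler(buf)
handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s"))
parent.addHandler(handler)

# A child logger (what get_logger("adapters") would return) has no handlers
# of its own, but its records propagate to the parent's handlers.
child = parent.getChild("adapters")
child.info("dashscope adapter loaded")

print(buf.getvalue().strip())
# → ai-chat-server-demo.adapters - INFO - dashscope adapter loaded
```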
@@ -0,0 +1,91 @@
"""
Database module.

Provides SQLite connection and session management.
"""

import os
import sqlite3
from pathlib import Path
from contextlib import contextmanager
from typing import Optional

# Default database path
DEFAULT_DB_PATH = Path(__file__).parent.parent / "data" / "chat.db"


def init_db(db_path: Optional[str] = None):
    """
    Initialize the database.
    Creates the required table structure.
    """
    if db_path is None:
        db_path = os.getenv("DB_PATH", str(DEFAULT_DB_PATH))

    # Make sure the data directory exists
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Conversations table
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS conversations (
            id TEXT PRIMARY KEY,
            title TEXT,
            model TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)

    # Messages table
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS messages (
            id TEXT PRIMARY KEY,
            conversation_id TEXT,
            role TEXT,
            content TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
        )
    """)

    # Files table
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS files (
            id TEXT PRIMARY KEY,
            conversation_id TEXT,
            filename TEXT,
            file_path TEXT,
            file_type TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
        )
    """)

    conn.commit()
    conn.close()

    print(f"[database] initialized: {db_path}")


@contextmanager
def get_db(db_path: Optional[str] = None):
    """
    Context manager yielding a database connection.

    Usage:
        with get_db() as db:
            cursor = db.execute("SELECT * FROM conversations")
            rows = cursor.fetchall()
    """
    if db_path is None:
        db_path = os.getenv("DB_PATH", str(DEFAULT_DB_PATH))

    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
    finally:
        conn.close()
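One caveat with the schema above: SQLite only enforces `ON DELETE CASCADE` when `PRAGMA foreign_keys = ON` is set on each connection, and neither `init_db` nor `get_db` sets it, so cascading deletes would silently not fire there. A self-contained demonstration with the pragma enabled (trimmed schema, in-memory database):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # required for FK enforcement in SQLite
conn.execute("CREATE TABLE conversations (id TEXT PRIMARY KEY, title TEXT)")
conn.execute("""
    CREATE TABLE messages (
        id TEXT PRIMARY KEY,
        conversation_id TEXT,
        content TEXT,
        FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
    )
""")
conn.execute("INSERT INTO conversations VALUES ('c1', 'demo')")
conn.execute("INSERT INTO messages VALUES ('m1', 'c1', 'hello')")

# Deleting the conversation cascades to its messages.
conn.execute("DELETE FROM conversations WHERE id = 'c1'")
remaining = conn.execute("SELECT COUNT(*) FROM messages").fetchone()[0]
print(remaining)  # → 0
```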
@@ -1,14 +0,0 @@
# Logging configuration file
# The following environment variables can be set in .env to control logging

# Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL
LOG_LEVEL=INFO

# Log file directory
LOG_DIR=logs

# Maximum log file size (bytes)
LOG_MAX_BYTES=10485760  # 10MB

# Number of backup log files to keep
LOG_BACKUP_COUNT=5
@@ -41,7 +41,7 @@ sys.path.append("/home/mt/project/ai-chat-ui/server")

 # ── Utilities/logging (platform-agnostic) ───────────────────────────────
 from utils.helpers import log_response
-from utils.logger import get_logger
+from core import get_logger

 logger = get_logger()
@@ -99,7 +99,7 @@ async def logging_middleware(request: Request, call_next):

 @app.get("/health")
 async def health_check():
-    from config import get_available_providers
+    from config.settings import get_available_providers

     return {
         "status": "healthy",
@@ -216,7 +216,7 @@ if __name__ == "__main__":
     port = int(os.getenv("PORT", 8000))

     # Get the available platforms
-    from config import get_available_providers
+    from config.settings import get_available_providers

     available = get_available_providers()
@@ -1,67 +0,0 @@
"""
GLM file-ID cache (simple disk-backed KV, sha256 → file_id, 3-day TTL)
"""

import hashlib
import json
import threading
import time
from pathlib import Path

_CACHE_FILE = Path(__file__).parent.parent / "uploads" / ".glm_file_cache.json"
_lock = threading.Lock()
_TTL = 3 * 24 * 3600  # 3 days


def _load() -> dict:
    try:
        if _CACHE_FILE.exists():
            return json.loads(_CACHE_FILE.read_text(encoding="utf-8"))
    except Exception:
        pass
    return {}


def _save(data: dict) -> None:
    try:
        _CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
        _CACHE_FILE.write_text(
            json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
        )
    except Exception as e:
        print(f"[file_cache] 写入失败:{e}")


def sha256_of_file(file_path: Path) -> str:
    h = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()


def get(file_hash: str) -> dict | None:
    with _lock:
        data = _load()
        entry = data.get(file_hash)
        if not entry:
            return None
        if entry.get("expires_at", 0) <= time.time():
            data.pop(file_hash, None)
            _save(data)
            return None
        return entry


def set(file_hash: str, file_id: str) -> None:
    with _lock:
        data = _load()
        data[file_hash] = {"file_id": file_id, "expires_at": time.time() + _TTL}
        _save(data)


def delete(file_hash: str) -> None:
    with _lock:
        data = _load()
        data.pop(file_hash, None)
        _save(data)

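The deleted module implements a common pattern: a lock-guarded, disk-backed dict whose entries carry an absolute expiry timestamp and are evicted lazily on read. A minimal in-memory sketch of the same TTL semantics (the `cache_set`/`cache_get` names here are illustrative, not the module's API; time is passed in explicitly to keep the sketch deterministic):

```python
import time

_TTL = 3 * 24 * 3600  # 3 days, matching the deleted module


def cache_set(cache: dict, key: str, file_id: str, now: float) -> None:
    # Store the value together with its absolute expiry timestamp.
    cache[key] = {"file_id": file_id, "expires_at": now + _TTL}


def cache_get(cache: dict, key: str, now: float):
    entry = cache.get(key)
    if not entry:
        return None
    if entry["expires_at"] <= now:
        # Expired: evict lazily on read, exactly like the deleted get().
        cache.pop(key, None)
        return None
    return entry


cache: dict = {}
t0 = time.time()
cache_set(cache, "abc123", "file-1", now=t0)
hit = cache_get(cache, "abc123", now=t0 + 60)         # fresh → hit
miss = cache_get(cache, "abc123", now=t0 + _TTL + 1)  # past TTL → evicted
```

Lazy eviction keeps writes cheap: expired entries cost nothing until someone asks for them, at which point the read both misses and cleans up.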
@@ -1,523 +0,0 @@
"""
GLM-4.6V adapter layer (built on zai-sdk)
SDK:   pip install zai-sdk
Model: glm-4.6v (text / images / documents / deep thinking)
"""

import base64
import json
import os
import sys
import threading
from pathlib import Path
from typing import AsyncGenerator


# ── Auto-inject the venv site-packages ────────────────────────────────
def _ensure_venv():
    server_dir = Path(__file__).parent.parent
    for sp in sorted(
        (server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True
    ):
        if sp.exists() and str(sp) not in sys.path:
            sys.path.insert(0, str(sp))
            print(f"[GLM] venv 注入:{sp}")
            break


# ── Client singleton ──────────────────────────────────────────────────
_client = None


def get_client():
    global _client
    if _client is None:
        _ensure_venv()
        try:
            from zai import ZhipuAiClient
        except ImportError:
            raise ImportError("GLM 模式需要安装 zai-sdk:.venv/bin/pip install zai-sdk")
        # Guard against unset env vars: os.getenv() may return None, and
        # calling .strip() on None would raise AttributeError.
        api_key = (os.getenv("ZHIPU_API_KEY") or "").strip() or (
            os.getenv("GLM_API_KEY") or ""
        ).strip()
        if not api_key:
            raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY")
        _client = ZhipuAiClient(api_key=api_key)
        print("[GLM] ZhipuAiClient 初始化完成(zai-sdk)")
    return _client


# ── Model mapping ─────────────────────────────────────────────────────
DEFAULT_TEXT_MODEL = "glm-4-flash"  # default text model
DEFAULT_VISION_MODEL = "glm-4.6v"   # images/attachments go through glm-4.6v


def resolve_model(model: str, has_vision: bool = False) -> str:
    # Use the vision model whenever the messages contain images or attachments.
    if has_vision:
        print(f"[GLM] 检测到图片/附件,使用视觉模型:{model} → {DEFAULT_VISION_MODEL}")
        return DEFAULT_VISION_MODEL
    # Plain text conversation: keep the requested model unchanged.
    print(f"[GLM] 使用模型:{model}")
    return model


# ── File upload (with file_id caching) ────────────────────────────────
def upload_file_for_extract(local_path: Path) -> str:
    from utils.file_cache import get as cache_get
    from utils.file_cache import set as cache_set
    from utils.file_cache import sha256_of_file

    file_hash = sha256_of_file(local_path)
    cached = cache_get(file_hash)
    if cached:
        print(f"[GLM] file_id 缓存命中:{local_path.name} → {cached['file_id']}")
        return cached["file_id"]

    client = get_client()
    mime_map = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ".xls": "application/vnd.ms-excel",
        ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ".ppt": "application/vnd.ms-powerpoint",
    }
    mime = mime_map.get(local_path.suffix.lower(), "application/octet-stream")
    print(f"[GLM] 上传文件:{local_path.name}({mime})")
    with open(local_path, "rb") as f:
        file_obj = client.files.create(
            file=(local_path.name, f, mime), purpose="file-extract"
        )
    file_id = file_obj.id
    cache_set(file_hash, file_id)
    print(f"[GLM] 上传成功:file_id={file_id}")
    return file_id

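The hard-coded `mime_map` above could equally be served by the standard library; a small sketch of the same lookup-with-fallback pattern using `mimetypes` (shown for comparison — it is not what the deleted module did):

```python
import mimetypes
from pathlib import Path


def guess_mime(path: Path) -> str:
    # guess_type() keys off the filename suffix; fall back to generic binary.
    mime, _encoding = mimetypes.guess_type(path.name)
    return mime or "application/octet-stream"


print(guess_mime(Path("report.pdf")))  # application/pdf
print(guess_mime(Path("blob.zzz")))    # application/octet-stream
```

An explicit map, as in the deleted code, does have one advantage: it pins exact MIME strings for Office formats regardless of what the host OS has registered.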
# ── Image encoding ────────────────────────────────────────────────────
def encode_image(image_source: str) -> dict:
    """Normalize any image source into the OpenAI image_url format."""
    if image_source.startswith("data:image") or image_source.startswith(
        ("http://", "https://")
    ):
        return {"type": "image_url", "image_url": {"url": image_source}}
    # Local path → base64
    local = Path(image_source.replace("file://", "").lstrip("/"))
    if not local.exists():
        local = Path.cwd() / local
    ext = local.suffix.lstrip(".")
    with open(local, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    return {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}}

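Building a base64 `data:` URL, as `encode_image` does for local files, can be sketched standalone (the payload below is dummy bytes, not a real image):

```python
import base64


def to_data_url(raw: bytes, ext: str) -> str:
    # data:image/<ext>;base64,<payload> — the shape encode_image produces.
    b64 = base64.b64encode(raw).decode()
    return f"data:image/{ext};base64,{b64}"


url = to_data_url(b"\x89PNG fake bytes", "png")
print(url[:22])  # data:image/png;base64,
```

The round trip is lossless: `base64.b64decode` of everything after the comma yields the original bytes, which is why inlining images this way is safe (if verbose — base64 inflates size by about a third).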
# ── Message format conversion ─────────────────────────────────────────
def build_glm_messages(messages: list, files: list | None = None) -> tuple[list, bool]:
    """
    Convert OpenAI-style messages + files into the format zai-sdk expects.
    Returns (glm_messages, has_vision).
    """
    from urllib.parse import urlparse

    glm_messages = []
    has_vision = False

    for msg in messages:
        if not isinstance(msg, dict):
            glm_messages.append({"role": "user", "content": str(msg)})
            continue
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if isinstance(content, str):
            glm_messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            new_content = []
            for item in content:
                if not isinstance(item, dict):
                    new_content.append({"type": "text", "text": str(item)})
                    continue
                t = item.get("type")
                if t == "text":
                    new_content.append({"type": "text", "text": item.get("text", "")})
                elif t == "image_url":
                    has_vision = True
                    img_val = item.get("image_url", "")
                    img_src = (
                        img_val.get("url", "") if isinstance(img_val, dict) else img_val
                    )
                    new_content.append(encode_image(img_src))
                elif t == "file_url":
                    # file_url items (PDF/DOCX/TXT document links) pass through as-is
                    has_vision = True
                    new_content.append(item)
                else:
                    new_content.append({"type": "text", "text": str(item)})
            glm_messages.append({"role": role, "content": new_content})
        else:
            glm_messages.append({"role": role, "content": str(content)})

    # Handle the standalone attachment list
    if files:
        doc_exts = {
            ".pdf",
            ".doc",
            ".docx",
            ".xlsx",
            ".xls",
            ".pptx",
            ".ppt",
            ".txt",
            ".md",
            ".csv",
            ".json",
            ".log",
        }
        img_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
        inserts = []

        for file_url in files:
            parsed = urlparse(file_url)
            filename = parsed.path.split("/")[-1]
            suffix = Path(filename).suffix.lower()

            # ── Remote URL (OSS etc.) → pass through directly ─────────
            if file_url.startswith(("http://", "https://")):
                has_vision = True
                if suffix in img_exts:
                    inserts.append(
                        {"type": "image_url", "image_url": {"url": file_url}}
                    )
                else:
                    # Documents/text all go through file_url
                    inserts.append({"type": "file_url", "file_url": {"url": file_url}})
                continue

            # ── Local-file fallback logic ─────────────────────────────
            rel = parsed.path.lstrip("/")
            local = Path(rel)

            if suffix in img_exts:
                has_vision = True
                try:
                    inserts.append(encode_image(f"file://{rel}"))
                except Exception as e:
                    print(f"[GLM] 图像编码失败:{e}")
            elif suffix in doc_exts:
                has_vision = True
                if local.exists():
                    try:
                        fid = upload_file_for_extract(local)
                        inserts.append({"type": "file", "file": {"file_id": fid}})
                    except Exception as e:
                        inserts.append(
                            {"type": "text", "text": f"[文件上传失败:(unknown),{e}]"}
                        )
                else:
                    inserts.append(
                        {"type": "text", "text": f"[附件:(unknown),类型:{suffix}]"}
                    )

        if inserts:
            for i in range(len(glm_messages) - 1, -1, -1):
                if glm_messages[i].get("role") == "user":
                    old = glm_messages[i]["content"]
                    if isinstance(old, str):
                        new_content = inserts + [{"type": "text", "text": old}]
                    elif isinstance(old, list):
                        new_content = inserts + old
                    else:
                        new_content = inserts
                    glm_messages[i] = {"role": "user", "content": new_content}
                    break

    return glm_messages, has_vision

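The tail of `build_glm_messages` prepends attachment parts to the most recent user message; that walk-backwards-and-merge step can be sketched in isolation (simplified names, no GLM-specific types):

```python
def prepend_to_last_user(messages: list, inserts: list) -> list:
    # Walk backwards to find the latest user message, then merge the
    # attachment parts in front of its existing content.
    for i in range(len(messages) - 1, -1, -1):
        if messages[i].get("role") == "user":
            old = messages[i]["content"]
            parts = inserts + (
                [{"type": "text", "text": old}] if isinstance(old, str) else old
            )
            messages[i] = {"role": "user", "content": parts}
            break
    return messages


msgs = [
    {"role": "system", "content": "be terse"},
    {"role": "user", "content": "summarize this file"},
]
attachment = [{"type": "file_url", "file_url": {"url": "https://example.com/a.pdf"}}]
out = prepend_to_last_user(msgs, attachment)
```

Attaching to the *last* user turn (rather than appending a new message) keeps the file and the question about it inside a single turn, which is what multimodal chat APIs generally expect.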
# ── Web-search tool construction ──────────────────────────────────────
def _build_web_search_tool(mode: str | bool) -> dict:
    """
    Build the web_search tool configuration for the given search mode.

    mode:
        - True / "simple" : basic search (search_std + medium, 10 results)
        - "deep"          : deep search (search_pro + high, 20 results)
    """
    if mode == "deep":
        # Deep search: higher-tier engine + detailed content + more results
        return {
            "type": "web_search",
            "web_search": {
                "enable": True,
                "search_result": True,
                "search_engine": "search_pro",
                "content_size": "high",
                "count": 20,
            },
        }
    # Simple search (default): basic engine + summary-level content
    return {
        "type": "web_search",
        "web_search": {
            "enable": True,
            "search_result": True,
            "search_engine": "search_std",
            "content_size": "medium",
            "count": 10,
        },
    }


# ── Sentinel object ───────────────────────────────────────────────────
_SENTINEL = object()

# ── Streaming call ────────────────────────────────────────────────────
async def glm_stream_generator(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: str | bool = False,
    deep_thinking: bool = False,
) -> AsyncGenerator[str, None]:
    """
    GLM streaming SSE generator.
    Uses a queue.Queue + dedicated producer thread + asyncio consumer so
    the synchronous zai-sdk iterator runs safely inside a single thread.

    web_search:
        - False / ""      : web search disabled
        - True / "simple" : basic search (search_std + medium)
        - "deep"          : deep search (search_pro + high + more results)
    """
    import asyncio
    import queue

    from utils.helpers import generate_unique_id, get_current_timestamp

    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)

    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [_build_web_search_tool(web_search)]
    if not deep_thinking:
        # Zhipu enables thinking mode by default, so the flag is inverted:
        # if the frontend turned deep thinking on, leave the default as-is;
        # if it turned it off, disable it explicitly here.
        extra_kwargs["thinking"] = {"type": "disabled"}
    print(
        f"[GLM] 流式请求:model={actual_model} vision={has_vision} "
        f"web_search={web_search} thinking={deep_thinking}"
    )
    # ── Debug: dump the full message structure sent to GLM ──
    for i, msg in enumerate(glm_msgs):
        role = msg.get("role", "?")
        content = msg.get("content", "")
        if isinstance(content, list):
            for j, part in enumerate(content):
                if not isinstance(part, dict):
                    print(f"[GLM-DEBUG] msg[{i}].content[{j}]: {type(part).__name__}")
                    continue
                part_type = part.get("type", "?")
                if part_type == "image_url":
                    img_val = part.get("image_url", "")
                    img_url = (
                        img_val.get("url", "")
                        if isinstance(img_val, dict)
                        else str(img_val)
                    )
                    display = img_url[:120] + "..." if len(img_url) > 120 else img_url
                    print(
                        f"[GLM-DEBUG] msg[{i}].content[{j}]: type=image_url, url={display}"
                    )
                elif part_type == "text":
                    preview = (part.get("text", "") or "")[:100]
                    print(
                        f"[GLM-DEBUG] msg[{i}].content[{j}]: type=text, text={preview}"
                    )
                else:
                    print(f"[GLM-DEBUG] msg[{i}].content[{j}]: {part}")
        else:
            print(f"[GLM-DEBUG] msg[{i}]: role={role}, content={str(content)[:150]}")
    if extra_kwargs:
        print(f"[GLM-DEBUG] extra_kwargs={extra_kwargs}")
    # Raw JSON dump (for diagnosing structural problems)
    import json as _json

    print(
        f"[GLM-RAW] messages={_json.dumps(glm_msgs, ensure_ascii=False, default=str)[:2000]}"
    )

    chunk_queue: queue.Queue = queue.Queue(maxsize=128)

    def _producer():
        try:
            client = get_client()
            resp = client.chat.completions.create(
                model=actual_model,
                messages=glm_msgs,
                stream=True,
                temperature=temperature,
                max_tokens=max_tokens,
                **extra_kwargs,
            )
            for chunk in resp:
                chunk_queue.put(chunk)
        except Exception as exc:
            chunk_queue.put(exc)
        finally:
            chunk_queue.put(_SENTINEL)

    t = threading.Thread(target=_producer, daemon=True)
    t.start()

    loop = asyncio.get_running_loop()

    full_reasoning = ""  # accumulated thinking content (detects the first fragment)
    full_content = ""    # accumulated final answer (detects the first fragment)

    while True:
        item = await loop.run_in_executor(None, chunk_queue.get)

        if item is _SENTINEL:
            break

        if isinstance(item, Exception):
            print(f"[GLM] 生产者异常:{item}")
            yield f"data: {json.dumps({'error': {'message': str(item), 'type': 'glm_error'}}, ensure_ascii=False)}\n\n"
            break

        try:
            delta = item.choices[0].delta
            reasoning = getattr(delta, "reasoning_content", "") or ""
            text = getattr(delta, "content", "") or ""

            delta_str = ""

            # ── Thinking process (reasoning_content) ────────────────
            if reasoning:
                if not full_reasoning:
                    # First thinking fragment: emit the opening <think> tag
                    delta_str += "<think>"
                full_reasoning += reasoning
                delta_str += reasoning

            # ── Final answer (content) ──────────────────────────────
            if text:
                if not full_content and full_reasoning:
                    # First answer fragment after thinking: close the </think> tag
                    delta_str += "</think>\n\n"
                full_content += text
                delta_str += text

            if not delta_str:
                continue

            data = {
                "id": f"chatcmpl-{generate_unique_id()}",
                "object": "chat.completion.chunk",
                "created": get_current_timestamp(),
                "model": actual_model,
                "choices": [
                    {"index": 0, "delta": {"content": delta_str}, "finish_reason": None}
                ],
            }
            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

        except Exception as e:
            print(f"[GLM] chunk 解析异常:{e}")

    finish = {
        "id": f"chatcmpl-{generate_unique_id()}",
        "object": "chat.completion.chunk",
        "created": get_current_timestamp(),
        "model": actual_model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"

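The thread-plus-queue bridge used above — a dedicated producer thread feeding a blocking `queue.Queue`, drained from asyncio via `run_in_executor`, with a sentinel marking the end of the stream — can be reduced to a self-contained sketch (the producer here is a stand-in for the synchronous SDK iterator):

```python
import asyncio
import queue
import threading

_SENTINEL = object()


def producer(q: queue.Queue) -> None:
    # Stand-in for the synchronous SDK iterator: emit a few chunks, then stop.
    for chunk in ["hel", "lo"]:
        q.put(chunk)
    q.put(_SENTINEL)  # always signal completion, even on error paths


async def consume() -> str:
    q: queue.Queue = queue.Queue(maxsize=128)
    threading.Thread(target=producer, args=(q,), daemon=True).start()
    loop = asyncio.get_running_loop()
    out = ""
    while True:
        # q.get() blocks, so hand it to the default executor to keep the loop free.
        item = await loop.run_in_executor(None, q.get)
        if item is _SENTINEL:
            break
        out += item
    return out


result = asyncio.run(consume())
print(result)  # hello
```

The sentinel is an anonymous `object()` compared with `is`, so no legitimate chunk value (including `None` or an empty string) can be mistaken for end-of-stream; the bounded queue additionally back-pressures a fast producer against a slow consumer.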
# ── Non-streaming call ────────────────────────────────────────────────
def glm_chat_sync(
    messages: list,
    model: str,
    temperature: float,
    max_tokens: int,
    files: list | None = None,
    web_search: str | bool = False,
    deep_thinking: bool = False,
) -> dict:
    """
    web_search:
        - False / ""      : web search disabled
        - True / "simple" : basic search (search_std + medium)
        - "deep"          : deep search (search_pro + high + more results)
    """
    glm_msgs, has_vision = build_glm_messages(messages, files)
    actual_model = resolve_model(model, has_vision)

    extra_kwargs: dict = {}
    if web_search:
        extra_kwargs["tools"] = [_build_web_search_tool(web_search)]
    if deep_thinking:
        extra_kwargs["thinking"] = {"type": "enabled"}

    client = get_client()
    print(f"[GLM] 非流式请求:model={actual_model}")
    # ── Debug: dump the full message structure sent to GLM ──
    for i, msg in enumerate(glm_msgs):
        role = msg.get("role", "?")
        content = msg.get("content", "")
        if isinstance(content, list):
            for j, part in enumerate(content):
                if not isinstance(part, dict):
                    print(f"[GLM-DEBUG] msg[{i}].content[{j}]: {type(part).__name__}")
                    continue
                part_type = part.get("type", "?")
                if part_type == "image_url":
                    img_val = part.get("image_url", "")
                    img_url = (
                        img_val.get("url", "")
                        if isinstance(img_val, dict)
                        else str(img_val)
                    )
                    display = img_url[:120] + "..." if len(img_url) > 120 else img_url
                    print(
                        f"[GLM-DEBUG] msg[{i}].content[{j}]: type=image_url, url={display}"
                    )
                elif part_type == "text":
                    preview = (part.get("text", "") or "")[:100]
                    print(
                        f"[GLM-DEBUG] msg[{i}].content[{j}]: type=text, text={preview}"
                    )
                else:
                    print(f"[GLM-DEBUG] msg[{i}].content[{j}]: {part}")
        else:
            print(f"[GLM-DEBUG] msg[{i}]: role={role}, content={str(content)[:150]}")
    if extra_kwargs:
        print(f"[GLM-DEBUG] extra_kwargs={extra_kwargs}")
    # Raw JSON dump (for diagnosing structural problems)
    import json as _json

    print(
        f"[GLM-RAW] messages={_json.dumps(glm_msgs, ensure_ascii=False, default=str)[:2000]}"
    )
    resp = client.chat.completions.create(
        model=actual_model,
        messages=glm_msgs,
        stream=False,
        temperature=temperature,
        max_tokens=max_tokens,
        **extra_kwargs,
    )
    content = resp.choices[0].message.content or ""
    usage = None
    if hasattr(resp, "usage") and resp.usage:
        usage = {
            "promptTokens": resp.usage.prompt_tokens,
            "completionTokens": resp.usage.completion_tokens,
            "totalTokens": resp.usage.total_tokens,
        }
    return {"content": content, "model": actual_model, "usage": usage}

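On the consuming side, a client of the SSE frames this adapter emits has to strip the `data: ` prefix, stop at `[DONE]`, and concatenate the `delta` content pieces; a minimal sketch with hand-built frames shaped like the chunks above:

```python
import json


def join_sse(frames: list[str]) -> str:
    out = ""
    for frame in frames:
        payload = frame.strip().removeprefix("data: ")
        if payload == "[DONE]":
            break  # terminal marker, not JSON
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"]
        out += delta.get("content", "")  # the finish chunk has an empty delta
    return out


frames = [
    'data: {"choices": [{"index": 0, "delta": {"content": "he"}, "finish_reason": null}]}\n\n',
    'data: {"choices": [{"index": 0, "delta": {"content": "llo"}, "finish_reason": null}]}\n\n',
    'data: {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n',
    "data: [DONE]\n\n",
]
print(join_sse(frames))  # hello
```

Note that `[DONE]` must be special-cased before `json.loads`, since it is a bare string literal rather than a JSON document.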
@@ -8,7 +8,7 @@ import uuid
 from datetime import datetime
 from typing import Dict

-from .logger import (log_chat_interaction, log_error_detail, log_request_info,
-                     log_response_info)
+from core import (log_chat_interaction, log_error_detail, log_request_info,
+                  log_response_info)

@@ -1,37 +0,0 @@
import asyncio
import os
import sys
from pathlib import Path

# Add project root to sys.path
root_dir = Path(__file__).parent
sys.path.insert(0, str(root_dir))

# Set API key from .env if needed
from dotenv import load_dotenv

from utils.glm_adapter import _ensure_venv, glm_chat_sync, glm_stream_generator

load_dotenv()


async def test_stream():
    msgs = [{"role": "user", "content": "今天北京天气怎样?"}]
    print("Testing stream...")
    async for chunk in glm_stream_generator(
        msgs, "glm-4.5-air", 0.7, 1024, web_search=True
    ):
        print(chunk, end="")


def test_sync():
    msgs = [{"role": "user", "content": "今天几号?武汉天气怎样?"}]
    print("Testing sync...")
    res = glm_chat_sync(msgs, "glm-4.5-air", 0.7, 1024, web_search=True)
    print(res)


if __name__ == "__main__":
    _ensure_venv()
    # test_sync()
    asyncio.run(test_stream())

@@ -1,171 +0,0 @@
"""
Test script: upload a PDF / DOCX / TXT file to Aliyun OSS → get its URL → send it to GLM-4.6V for recognition

Supported file types:
- .pdf  → uploaded to OSS, then sent to GLM as a file_url item
- .docx → uploaded to OSS, then sent to GLM as a file_url item
- .txt  → uploaded to OSS, then sent to GLM as a file_url item

Usage:
    cd server
    source ~/.bashrc && source .venv/bin/activate
    python -m utils.test_oss_doc_glm --file <local file path> [--prompt "请总结这份文件"]
"""

import argparse
import sys
from pathlib import Path

# Make sure the server directory is on sys.path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from utils.oss_uploader import upload_file
from utils.glm_adapter import glm_chat_sync


# File type groups
DOC_EXTS = {".pdf", ".doc", ".docx", ".xlsx", ".xls", ".pptx", ".ppt"}
TXT_EXTS = {".txt", ".md", ".csv", ".json", ".log"}
IMG_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
# All types that can be sent via file_url
FILE_URL_EXTS = DOC_EXTS | TXT_EXTS


def detect_file_type(suffix: str) -> str:
    """Classify a suffix as 'file_url' / 'image' / 'unknown'."""
    suffix = suffix.lower()
    if suffix in FILE_URL_EXTS:
        return "file_url"
    elif suffix in IMG_EXTS:
        return "image"
    return "unknown"


def build_messages_for_file_url(file_url: str, prompt: str) -> list:
    """
    Build the message list for a document/text file.
    Uses the file_url item type to pass the OSS URL straight to GLM.
    """
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "file_url",
                    "file_url": {"url": file_url},
                },
                {
                    "type": "text",
                    "text": prompt,
                },
            ],
        }
    ]


def build_messages_for_image(file_url: str, prompt: str) -> list:
    """Build the message list for an image file, using the image_url item type."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": file_url}},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def main():
    parser = argparse.ArgumentParser(
        description="上传 PDF/DOCX/TXT 文件到 OSS 并让 GLM-4.6V 识别"
    )
    parser.add_argument("--file", required=True, help="要上传的本地文件路径")
    parser.add_argument(
        "--prompt", default="请总结这份文件的主要内容", help="发给 GLM 的提示词"
    )
    parser.add_argument(
        "--model", default="glm-4.6v", help="GLM 模型名称(默认: glm-4.6v)"
    )
    args = parser.parse_args()

    file_path = Path(args.file).resolve()
    if not file_path.exists():
        print(f"❌ 文件不存在: {file_path}")
        sys.exit(1)

    suffix = file_path.suffix.lower()
    file_type = detect_file_type(suffix)
    print("📂 文件信息:")
    print(f"   路径: {file_path}")
    print(f"   后缀: {suffix}")
    print(f"   类型: {file_type}")
    print(f"   大小: {file_path.stat().st_size / 1024:.1f} KB")
    print()

    # ── Step 1: upload the file to OSS ────────────────────────
    print("📤 正在上传文件到阿里云 OSS...")
    oss_result = upload_file(str(file_path))
    file_url = oss_result["url"]
    print("✅ OSS 上传成功!")
    print(f"   URL: {file_url}")
    print(f"   ETag: {oss_result['etag']}")
    print()

    # ── Step 2: build the messages for the file type ──────────
    print("🔧 正在构建 GLM 消息...")

    if file_type == "file_url":
        # PDF / DOCX / TXT etc.: send the OSS URL as a file_url item
        print("   策略: 使用 file_url 发送 OSS 链接")
        messages = build_messages_for_file_url(file_url, args.prompt)
    elif file_type == "image":
        # Images: use image_url
        print("   策略: 使用 image_url 发送 OSS 链接")
        messages = build_messages_for_image(file_url, args.prompt)
    else:
        print(f"❌ 不支持的文件类型: {suffix}")
        print(f"   支持: {', '.join(sorted(FILE_URL_EXTS | IMG_EXTS))}")
        sys.exit(1)

    print()

    # ── Step 3: send to GLM for recognition ───────────────────
    print(f"🤖 正在请求 GLM ({args.model}) 识别文件...")
    print(f"   提示词: {args.prompt}")
    print()

    try:
        result = glm_chat_sync(
            messages=messages,
            model=args.model,
            temperature=0.7,
            max_tokens=4096,
        )

        print("─" * 60)
        print("📝 GLM 回复:")
        print("─" * 60)
        print(result["content"])
        print("─" * 60)

        if result.get("usage"):
            usage = result["usage"]
            print(
                f"\n📊 Token 用量: 输入 {usage['promptTokens']} | "
                f"输出 {usage['completionTokens']} | "
                f"总计 {usage['totalTokens']}"
            )

        print(f"\n✅ 测试完成! 使用模型: {result.get('model', args.model)}")

    except Exception as e:
        print(f"\n❌ GLM 请求失败: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

@@ -1,89 +0,0 @@
"""
Test script: upload a file to OSS → get its URL → send it to GLM for recognition

Usage:
    cd server
    source ~/.bashrc && source .venv/bin/activate
    python -m utils.test_oss_glm --file <local file path> [--prompt "描述一下这张图片"]
"""

import argparse
import sys
from pathlib import Path

# Make sure the server directory is on sys.path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from utils.oss_uploader import upload_file
from utils.glm_adapter import glm_chat_sync


def main():
    parser = argparse.ArgumentParser(description="上传文件到 OSS 并让 GLM 读取")
    parser.add_argument("--file", required=True, help="要上传的本地文件路径")
    parser.add_argument(
        "--prompt", default="请描述一下这张图片的内容", help="发给 GLM 的提示词"
    )
    parser.add_argument(
        "--model", default="glm-4.6v", help="GLM 模型名称(默认: glm-4.6v)"
    )
    args = parser.parse_args()

    # ── Step 1: upload the file to OSS ────────────────────────
    file_path = args.file
    if not Path(file_path).exists():
        print(f"❌ 文件不存在: {file_path}")
        sys.exit(1)

    print(f"📤 正在上传文件: {file_path}")
    oss_result = upload_file(file_path)
    file_url = oss_result["url"]
    print("✅ 上传成功!")
    print(f"   URL: {file_url}")
    print()

    # ── Step 2: build the messages and send the URL to GLM ────
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": file_url},
                },
                {
                    "type": "text",
                    "text": args.prompt,
                },
            ],
        }
    ]

    print(f"🤖 正在请求 GLM ({args.model}) 识别图片...")
    print(f"   提示词: {args.prompt}")
    print()

    result = glm_chat_sync(
        messages=messages,
        model=args.model,
        temperature=0.7,
        max_tokens=2048,
    )

    print("─" * 60)
    print("📝 GLM 回复:")
    print("─" * 60)
    print(result["content"])
    print("─" * 60)

    if result.get("usage"):
        usage = result["usage"]
        print(
            f"\n📊 Token 用量: 输入 {usage['promptTokens']} | "
            f"输出 {usage['completionTokens']} | "
            f"总计 {usage['totalTokens']}"
        )


if __name__ == "__main__":
    main()