From 89b02c4c93beafbc3cb7fb1d59888a6d62e25f63 Mon Sep 17 00:00:00 2001
From: MT-Fire <798521692@qq.com>
Date: Wed, 4 Mar 2026 16:25:16 +0800
Subject: [PATCH] format: 项目格式化 (project-wide formatting)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/__init__.py              |  55 +--
 server/api/__init__.py          |   2 +-
 server/api/chat_routes.py       | 582 +++++++++++++++++++-------------
 server/api/chat_routes_glm.py   |  63 ++--
 server/init_logging.py          |  10 +-
 server/main.py                  |  28 +-
 server/models/__init__.py       |   2 +-
 server/models/chat_models.py    |  10 +-
 server/utils/__init__.py        |   2 +-
 server/utils/file_cache.py      |   7 +-
 server/utils/glm_adapter.py     |  97 +++---
 server/utils/helpers.py         |  16 +-
 server/utils/logger.py          |  92 +++--
 server/utils/test_glm_search.py |  14 +-
 14 files changed, 599 insertions(+), 381 deletions(-)

diff --git a/server/__init__.py b/server/__init__.py
index 7ad0dc9..7475162 100644
--- a/server/__init__.py
+++ b/server/__init__.py
@@ -1,35 +1,38 @@
 """
 包初始化文件
 """
+
+from .api.chat_routes import (chat_endpoint_handler,
+                              delete_conversation_handler,
+                              get_conversation_handler,
+                              get_conversations_handler, get_models_handler,
+                              save_conversation_handler, serve_upload_handler,
+                              stop_generation_handler, upload_file_handler)
 from .models.chat_models import ChatMessage, ChatRequest, ModelInfo
-from .utils.helpers import (
-    get_current_timestamp,
-    generate_unique_id,
-    format_api_response,
-    log_request,
-    log_response,
-    extract_delta_content
-)
-from .api.chat_routes import (
-    chat_endpoint_handler,
-    get_models_handler,
-    get_conversations_handler,
-    get_conversation_handler,
-    save_conversation_handler,
-    delete_conversation_handler,
-    upload_file_handler,
-    serve_upload_handler,
-    stop_generation_handler
-)
+from .utils.helpers import (extract_delta_content, format_api_response,
+                            generate_unique_id, get_current_timestamp,
+                            log_request, log_response)

 __all__ = [
     # Models
-    'ChatMessage', 'ChatRequest', 'ModelInfo',
+    "ChatMessage",
+    "ChatRequest",
+    "ModelInfo",
     # Utils
-    'get_current_timestamp', 'generate_unique_id', 'format_api_response',
-    'log_request', 'log_response', 'extract_delta_content',
+    "get_current_timestamp",
+    "generate_unique_id",
+    "format_api_response",
+    "log_request",
+    "log_response",
+    "extract_delta_content",
     # API Handlers
-    'chat_endpoint_handler', 'get_models_handler', 'get_conversations_handler',
-    'get_conversation_handler', 'save_conversation_handler', 'delete_conversation_handler',
-    'upload_file_handler', 'serve_upload_handler', 'stop_generation_handler'
-]
\ No newline at end of file
+    "chat_endpoint_handler",
+    "get_models_handler",
+    "get_conversations_handler",
+    "get_conversation_handler",
+    "save_conversation_handler",
+    "delete_conversation_handler",
+    "upload_file_handler",
+    "serve_upload_handler",
+    "stop_generation_handler",
+]
diff --git a/server/api/__init__.py b/server/api/__init__.py
index 5dc4203..2f82817 100644
--- a/server/api/__init__.py
+++ b/server/api/__init__.py
@@ -1 +1 @@
-# api/__init__.py
\ No newline at end of file
+# api/__init__.py
diff --git a/server/api/chat_routes.py b/server/api/chat_routes.py
index 0899440..ca91c7f 100644
--- a/server/api/chat_routes.py
+++ b/server/api/chat_routes.py
@@ -2,16 +2,18 @@
 API 路由定义(阿里云百炼 / DashScope 平台)
 所有 DashScope 相关逻辑均集中在此文件,main.py 无感知任何平台细节。
 """
-import os
+
 import json
+import os
 import uuid
 from datetime import datetime
-from typing import Dict, List
 from pathlib import Path
-from fastapi import HTTPException, File, UploadFile
-from fastapi.responses import JSONResponse, StreamingResponse
+from typing import Dict, List
+
 import dashscope
 from dashscope import Generation, MultiModalConversation
+from fastapi import File, HTTPException, UploadFile
+from fastapi.responses import JSONResponse, StreamingResponse


 def init():
@@ -29,18 +31,14 @@ def init():
 # 导入模型和工具函数(使用绝对路径)
 import sys
 from pathlib import Path
+
 sys.path.append(str(Path(__file__).parent.parent))
 from models.chat_models import ChatRequest, ModelInfo
-from utils.helpers import (
-    get_current_timestamp,
-    generate_unique_id,
-    format_api_response,
-    extract_delta_content
-)
+from utils.helpers import (extract_delta_content, format_api_response,
+                           generate_unique_id, get_current_timestamp)
 from utils.logger import log_error, log_exception, log_info

-
 # 模拟数据库 - 实际应用中应使用持久化存储
 conversations_db: Dict[str, dict] = {}

@@ -60,7 +58,7 @@ def _extract_text_from_docmind(obj, depth: int = 0) -> str:
     if isinstance(obj, str):
         s = obj.strip()
         # 过滤极短、URL、base64 等非正文字符串
-        if len(s) > 3 and not s.startswith(('http://', 'https://', 'data:', 'oss://')):
+        if len(s) > 3 and not s.startswith(("http://", "https://", "data:", "oss://")):
             return s
         return ""

@@ -70,10 +68,23 @@ def _extract_text_from_docmind(obj, depth: int = 0) -> str:

     if isinstance(obj, dict):
         # 优先处理文本相关字段
-        priority_keys = ['content', 'text', 'paragraph', 'caption', 'value', 'title']
+        priority_keys = ["content", "text", "paragraph", "caption", "value", "title"]
         # 跳过纯元数据字段
-        skip_keys = {'backlink', 'pos', 'index', 'style', 'font', 'color',
-                     'size', 'hash', 'id_', 'id', 'layouts', 'type', 'link'}
+        skip_keys = {
+            "backlink",
+            "pos",
+            "index",
+            "style",
+            "font",
+            "color",
+            "size",
+            "hash",
+            "id_",
+            "id",
+            "layouts",
+            "type",
+            "link",
+        }
         parts = []
         for key in priority_keys:
             if key in obj:
@@ -102,8 +113,9 @@ def _read_file_content(file_url: str):
     """
     try:
         from urllib.parse import urlparse
+
         parsed = urlparse(file_url)
-        relative_path = parsed.path.lstrip('/')
+        relative_path = parsed.path.lstrip("/")
         local_path = Path(relative_path)

         if not local_path.exists():
@@ -112,18 +124,34 @@ def _read_file_content(file_url: str):
         suffix = local_path.suffix.lower()

         # 路线一:纯文本格式直接读取
-        text_extensions = {'.txt', '.md', '.csv', '.json', '.xml',
-                           '.yaml', '.yml', '.log', '.py', '.js', '.ts', '.html', '.css'}
+        text_extensions = {
+            ".txt",
+            ".md",
+            ".csv",
+            ".json",
+            ".xml",
+            ".yaml",
+            ".yml",
+            ".log",
+            ".py",
+            ".js",
+            ".ts",
+            ".html",
+            ".css",
+        }
         if suffix in text_extensions:
-            with open(local_path, 'r', encoding='utf-8', errors='replace') as f:
+            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                 content = f.read()
             max_len = 8000
             if len(content) > max_len:
-                content = content[:max_len] + f"\n\n[...文件内容过长,已截断,共 {len(content)} 字符]"
+                content = (
+                    content[:max_len]
+                    + f"\n\n[...文件内容过长,已截断,共 {len(content)} 字符]"
+                )
             return content

         # 路线二:doc/docx/pdf 使用 DashScopeParse 云端解析
-        dashscope_extensions = {'.doc', '.docx', '.pdf'}
+        dashscope_extensions = {".doc", ".docx", ".pdf"}
         if suffix in dashscope_extensions:
             return (local_path, suffix)  # 交给异步函数处理

@@ -135,7 +163,6 @@ def _read_file_content(file_url: str):
         return f"[文件读取失败: {str(e)}]"

-
 async def _parse_with_dashscope(local_path: Path) -> str:
     """
     【路线二:DashScopeParse】使用阿里云文档智能解析 doc/docx/pdf 文件。
@@ -146,9 +173,10 @@ async def _parse_with_dashscope(local_path: Path) -> str:

     def _sync_parse():
         try:
+            import json
+
             from llama_index.readers.dashscope.base import DashScopeParse
             from llama_index.readers.dashscope.utils import ResultType
-            import json

             api_key = os.getenv("ALIYUN_API_KEY")
             parser = DashScopeParse(
@@ -172,7 +200,9 @@ async def _parse_with_dashscope(local_path: Path) -> str:
                     texts.append(doc.text[:6000] if doc.text else "")

                 result = "\n\n".join(t for t in texts if t)
-                print(f"[INFO] DashScopeParse: {local_path.name} 解析完成,提取 {len(result)} 字符")
+                print(
+                    f"[INFO] DashScopeParse: {local_path.name} 解析完成,提取 {len(result)} 字符"
+                )
                 return result or f"[DashScopeParse: {local_path.name} 未能提取到文本内容]"

             except ImportError:
@@ -198,8 +228,9 @@ async def _inject_files_into_messages(messages: list, files: list) -> list:
     file_context_parts = []
     for file_url in files:
         from urllib.parse import urlparse
+
         parsed = urlparse(file_url)
-        filename = parsed.path.split('/')[-1]
+        filename = parsed.path.split("/")[-1]
         suffix = Path(filename).suffix.lower()

         result = _read_file_content(file_url)
@@ -233,8 +264,8 @@ async def _inject_files_into_messages(messages: list, files: list) -> list:
     messages = list(messages)  # 复制,避免修改原始列表
     for i in range(len(messages) - 1, -1, -1):
         msg = messages[i]
-        if isinstance(msg, dict) and msg.get('role') == 'user':
-            content = msg.get('content', '')
+        if isinstance(msg, dict) and msg.get("role") == "user":
+            content = msg.get("content", "")
             if isinstance(content, str):
                 messages[i] = dict(msg, content=content + file_context_text)
             elif isinstance(content, list):
@@ -242,12 +273,14 @@ async def _inject_files_into_messages(messages: list, files: list) -> list:
                 new_content = list(content)
                 appended = False
                 for j, item in enumerate(new_content):
-                    if isinstance(item, dict) and item.get('type') == 'text':
-                        new_content[j] = dict(item, text=item['text'] + file_context_text)
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        new_content[j] = dict(
+                            item, text=item["text"] + file_context_text
+                        )
                         appended = True
                         break
                 if not appended:
-                    new_content.append({'type': 'text', 'text': file_context_text})
+                    new_content.append({"type": "text", "text": file_context_text})
                 messages[i] = dict(msg, content=new_content)
             break

@@ -265,36 +298,40 @@ async def chat_endpoint_handler(body: dict):
         print(f"[ERROR] Request body is not a dictionary: {type(body)}")
         raise HTTPException(
             status_code=400,
-            detail=f"Request body must be a JSON object, got {type(body).__name__}: {body}"
+            detail=f"Request body must be a JSON object, got {type(body).__name__}: {body}",
         )

     # 检查请求格式并适配
     # 如果是OpenAI兼容格式 (来自streamChat)
-    if 'messages' in body:
-        messages = body.get('messages', [])
-        model = body.get('model', 'qwen-plus')
-        stream = body.get('stream', True)
-        temperature = body.get('temperature', 0.7)
-        max_tokens = body.get('max_tokens', 2000)
-        deepSearch = body.get('deepSearch', False)
-        webSearch = body.get('webSearch', False)
-        deepThinking = body.get('deepThinking', False)
+    if "messages" in body:
+        messages = body.get("messages", [])
+        model = body.get("model", "qwen-plus")
+        stream = body.get("stream", True)
+        temperature = body.get("temperature", 0.7)
+        max_tokens = body.get("max_tokens", 2000)
+        deepSearch = body.get("deepSearch", False)
+        webSearch = body.get("webSearch", False)
+        deepThinking = body.get("deepThinking", False)

-        log_info(f"POST /api/chat-ui/chat | 模型: {model} | 流式: {stream} | 联网搜索: {webSearch} | 深度搜索: {deepSearch} | 深度思考: {deepThinking}")
+        log_info(
+            f"POST /api/chat-ui/chat | 模型: {model} | 流式: {stream} | 联网搜索: {webSearch} | 深度搜索: {deepSearch} | 深度思考: {deepThinking}"
+        )

         # 处理 files 附件:将文件内容注入到最后一条 user 消息中
-        files = body.get('files', [])
+        files = body.get("files", [])
         if files:
            messages = await _inject_files_into_messages(messages, files)

             # 调试:打印注入后最后一条 user 消息的内容(截断显示 500 字)
             for msg in reversed(messages):
-                if isinstance(msg, dict) and msg.get('role') == 'user':
-                    content_preview = str(msg.get('content', ''))[:500]
-                    print(f"[DEBUG] 注入文件后 user 消息内容预览: {content_preview}")
+                if isinstance(msg, dict) and msg.get("role") == "user":
+                    content_preview = str(msg.get("content", ""))[:500]
+                    print(
+                        f"[DEBUG] 注入文件后 user 消息内容预览: {content_preview}"
+                    )
                     break
     else:
         # 否则是前端简化格式 (来自chat函数)
-        message_text = body.get('message', '')
+        message_text = body.get("message", "")

         # 检查message是否已经是格式化的列表(带图片的情况)
         if isinstance(message_text, list):
@@ -303,28 +340,40 @@ async def chat_endpoint_handler(body: dict):
             user_content = [{"type": "text", "text": message_text}]

         messages = [
-            {"role": "system", "content": body.get('systemPrompt', '你是一个智能助手,可以分析用户发送的文本和文件内容。')},
-            {"role": "user", "content": user_content}
+            {
+                "role": "system",
+                "content": body.get(
+                    "systemPrompt",
+                    "你是一个智能助手,可以分析用户发送的文本和文件内容。",
+                ),
+            },
+            {"role": "user", "content": user_content},
         ]
-        model = body.get('model', 'qwen-plus')
-        stream = body.get('stream', False)
-        temperature = body.get('temperature', 0.7)
-        max_tokens = body.get('maxTokens', 2000)
-        deepSearch = body.get('deepSearch', False)
-        webSearch = body.get('webSearch', False)
-        deepThinking = body.get('deepThinking', False)
+        model = body.get("model", "qwen-plus")
+        stream = body.get("stream", False)
+        temperature = body.get("temperature", 0.7)
+        max_tokens = body.get("maxTokens", 2000)
+        deepSearch = body.get("deepSearch", False)
+        webSearch = body.get("webSearch", False)
+        deepThinking = body.get("deepThinking", False)

     # 检查是否包含图像内容,如果是多模态请求,使用MultiModalConversation
     has_images = any(
-        isinstance(msg, dict) and
-        isinstance(msg.get('content'), list) and
-        any(isinstance(item, dict) and item.get('type') == 'image_url' for item in msg.get('content', []))
-        for msg in messages if isinstance(msg, dict)
+        isinstance(msg, dict)
+        and isinstance(msg.get("content"), list)
+        and any(
+            isinstance(item, dict) and item.get("type") == "image_url"
+            for item in msg.get("content", [])
+        )
+        for msg in messages
+        if isinstance(msg, dict)
     )

     if has_images:
         # 使用多模态API处理图像
-        return await multimodal_chat_handler(messages, model, stream, temperature, max_tokens)
+        return await multimodal_chat_handler(
+            messages, model, stream, temperature, max_tokens
+        )
     else:
         # 构建 DashScope 额外参数
         dashscope_kwargs = {}
@@ -335,11 +384,15 @@ async def chat_endpoint_handler(body: dict):
         elif webSearch:
             dashscope_kwargs["enable_search"] = True
             dashscope_kwargs["search_options"] = {"search_strategy": "turbo"}
-
+
         if deepThinking:
             dashscope_kwargs["enable_thinking"] = True
-            dashscope_kwargs["result_format"] = "message"  # enable_thinking 必须配合 result_format=message
-            dashscope_kwargs["incremental_output"] = True  # 流式模式下 enable_thinking 还必须配合 incremental_output=True
+            dashscope_kwargs["result_format"] = (
+                "message"  # enable_thinking 必须配合 result_format=message
+            )
+            dashscope_kwargs["incremental_output"] = (
+                True  # 流式模式下 enable_thinking 还必须配合 incremental_output=True
+            )

         # 使用常规聊天API
         if stream:
@@ -352,7 +405,7 @@ async def chat_endpoint_handler(body: dict):
                         stream=True,
                         max_tokens=max_tokens,
                         temperature=temperature,
-                        **dashscope_kwargs
+                        **dashscope_kwargs,
                     )

                     full_content = ""  # 用于累计完整内容
@@ -363,18 +416,22 @@ async def chat_endpoint_handler(body: dict):
                             content = None

                             # 尝试从 output.choices 获取内容
-                            if (hasattr(response, 'output') and
-                                response.output and
-                                hasattr(response.output, 'choices') and
-                                response.output.choices is not None and
-                                len(response.output.choices) > 0 and
-                                'message' in response.output.choices[0]):
+                            if (
+                                hasattr(response, "output")
+                                and response.output
+                                and hasattr(response.output, "choices")
+                                and response.output.choices is not None
+                                and len(response.output.choices) > 0
+                                and "message" in response.output.choices[0]
+                            ):

-                                msg_dict = response.output.choices[0]['message']
+                                msg_dict = response.output.choices[0]["message"]
                                 # incremental_output=True 时,每个 chunk 的 content/reasoning_content 已是增量片段
                                 # 直接使用,无需与 full_* 做对比
-                                content = msg_dict.get('content') or ''
-                                reasoning_content = msg_dict.get('reasoning_content') or ''
+                                content = msg_dict.get("content") or ""
+                                reasoning_content = (
+                                    msg_dict.get("reasoning_content") or ""
+                                )

                                 delta_str = ""

@@ -382,10 +439,14 @@ async def chat_endpoint_handler(body: dict):
                                 if reasoning_content:
                                     if not full_reasoning_content:
                                         # 第一个思考片段,加标题前缀
-                                        delta_str += "> **💭 深度思考过程:**\n> \n> "
+                                        delta_str += (
+                                            "> **💭 深度思考过程:**\n> \n> "
+                                        )
                                     full_reasoning_content += reasoning_content
                                     # markdown 引用块内换行需加 >
-                                    delta_str += reasoning_content.replace("\n", "\n> ")
+                                    delta_str += reasoning_content.replace(
+                                        "\n", "\n> "
+                                    )

                                 # 处理正式回复片段
                                 if content:
@@ -405,25 +466,31 @@ async def chat_endpoint_handler(body: dict):
                                         {
                                             "index": 0,
                                             "delta": {"content": delta_str},
-                                            "finish_reason": None
+                                            "finish_reason": None,
                                         }
-                                    ]
+                                    ],
                                 }
                                 yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

                             # 否则尝试从 output.text 获取内容(DashScope特定格式)
-                            elif (hasattr(response, 'output') and
-                                  response.output and
-                                  'text' in response.output):
+                            elif (
+                                hasattr(response, "output")
+                                and response.output
+                                and "text" in response.output
+                            ):

-                                content = response.output.get('text')
+                                content = response.output.get("text")

                                 # 只有当内容发生变化时才发送增量
                                 if len(content) > len(full_content):
-                                    delta_content = extract_delta_content(content, full_content)
+                                    delta_content = extract_delta_content(
+                                        content, full_content
+                                    )
                                     full_content = content

-                                    if delta_content.strip():  # 只有当有非空白新内容时才发送
+                                    if (
+                                        delta_content.strip()
+                                    ):  # 只有当有非空白新内容时才发送
                                         # 构建 SSE 数据块
                                         data = {
                                             "id": f"chatcmpl-{generate_unique_id()}",
@@ -433,22 +500,26 @@ async def chat_endpoint_handler(body: dict):
                                             "choices": [
                                                 {
                                                     "index": 0,
-                                                    "delta": {"content": delta_content},
-                                                    "finish_reason": None
+                                                    "delta": {
+                                                        "content": delta_content
+                                                    },
+                                                    "finish_reason": None,
                                                 }
-                                            ]
+                                            ],
                                         }
                                         yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

                         else:
                             # 错误处理:写入 logger,方便排查
-                            log_error(f"DashScope API 返回错误: chunk status={response.status_code}, code={response.code}, msg={response.message}")
+                            log_error(
+                                f"DashScope API 返回错误: chunk status={response.status_code}, code={response.code}, msg={response.message}"
+                            )
                             error_data = {
                                 "error": {
                                     "message": f"API Error: {response.code} - {response.message}",
                                     "type": "api_error",
                                     "param": None,
-                                    "code": response.code
+                                    "code": response.code,
                                 }
                             }
                             yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
@@ -461,26 +532,21 @@ async def chat_endpoint_handler(body: dict):
                         "created": get_current_timestamp(),
                         "model": model,
                         "choices": [
-                            {
-                                "index": 0,
-                                "delta": {},
-                                "finish_reason": "stop"
-                            }
-                        ]
+                            {"index": 0, "delta": {}, "finish_reason": "stop"}
+                        ],
                     }
                     yield f"data: {json.dumps(finish_data, ensure_ascii=False)}\n\n"
                     yield "data: [DONE]\n\n"
                 except Exception as e:
                     log_exception(f"流式生成器异常: {e}")
                     error_data = {
-                        "error": {
-                            "message": str(e),
-                            "type": "server_error"
-                        }
+                        "error": {"message": str(e), "type": "server_error"}
                     }
                     yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"

-            return StreamingResponse(event_generator(), media_type="text/event-stream")
+            return StreamingResponse(
+                event_generator(), media_type="text/event-stream"
+            )
         else:
             # 非流式响应
             response = Generation.call(
@@ -489,7 +555,7 @@ async def chat_endpoint_handler(body: dict):
                 stream=False,
                 max_tokens=max_tokens,
                 temperature=temperature,
-                **dashscope_kwargs
+                **dashscope_kwargs,
             )

             if response.status_code == 200:
@@ -498,56 +564,61 @@ async def chat_endpoint_handler(body: dict):
                 content = None

                 # 尝试从 output.choices 获取内容
-                if (hasattr(response, 'output') and
-                    response.output and
-                    hasattr(response.output, 'choices') and
-                    response.output.choices is not None and
-                    len(response.output.choices) > 0 and
-                    'message' in response.output.choices[0]):
+                if (
+                    hasattr(response, "output")
+                    and response.output
+                    and hasattr(response.output, "choices")
+                    and response.output.choices is not None
+                    and len(response.output.choices) > 0
+                    and "message" in response.output.choices[0]
+                ):

-                    msg_dict = response.output.choices[0]['message']
-                    content = msg_dict.get('content', '')
-                    rc = msg_dict.get('reasoning_content', '')
+                    msg_dict = response.output.choices[0]["message"]
+                    content = msg_dict.get("content", "")
+                    rc = msg_dict.get("reasoning_content", "")
                     if rc:
-                        rc_formatted = rc.replace('\n', '\n> ')
+                        rc_formatted = rc.replace("\n", "\n> ")
                         content = f"> **💭 深度思考过程:**\n> \n> {rc_formatted}\n\n---\n\n{content}"

                 # 否则尝试从 output.text 获取内容(DashScope特定格式)
-                elif (hasattr(response, 'output') and
-                      response.output and
-                      'text' in response.output):
+                elif (
+                    hasattr(response, "output")
+                    and response.output
+                    and "text" in response.output
+                ):

-                    content = response.output.get('text')
+                    content = response.output.get("text")

                 if content:
                     # 构建前端期望的响应格式
                     chat_response = format_api_response(
                         content=content,
-                        conversation_id=body.get('conversationId'),
-                        model=model
+                        conversation_id=body.get("conversationId"),
+                        model=model,
                     )

-                    if hasattr(response, 'usage') and response.usage:
+                    if hasattr(response, "usage") and response.usage:
                         chat_response["usage"] = {
                             "promptTokens": response.usage.input_tokens,
                             "completionTokens": response.usage.output_tokens,
-                            "totalTokens": response.usage.total_tokens
+                            "totalTokens": response.usage.total_tokens,
                         }

                    # starlette 的 JSONResponse 不接受 ensure_ascii 参数(默认已按 ensure_ascii=False 序列化)
                    return JSONResponse(content=chat_response)
                 else:
                     raise HTTPException(
                         status_code=500,
-                        detail="API Response does not contain expected content"
+                        detail="API Response does not contain expected content",
                     )
             else:
                 raise HTTPException(
                     status_code=500,
-                    detail=f"API Error: {response.code} - {response.message}"
+                    detail=f"API Error: {response.code} - {response.message}",
                 )

     except Exception as e:
         print(f"[ERROR] Error in chat endpoint: {str(e)}")
         import traceback
+
         print(f"[ERROR] Traceback: {traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=str(e))

@@ -563,94 +634,99 @@ async def multimodal_chat_handler(messages, model, stream, temperature, max_toke
             # 验证 msg 是否为字典类型,如果不是则跳过或处理为字符串
             if not isinstance(msg, dict):
                 # 如果消息不是字典,将其作为纯文本处理
-                dashscope_content = [
-                    {'text': str(msg)}
-                ]
-                dashscope_messages.append({
-                    'role': 'user',
-                    'content': dashscope_content
-                })
+                dashscope_content = [{"text": str(msg)}]
+                dashscope_messages.append(
+                    {"role": "user", "content": dashscope_content}
+                )
                 continue

-            role = msg.get('role', 'user')
-            content = msg.get('content', '')
+            role = msg.get("role", "user")
+            content = msg.get("content", "")

             if isinstance(content, str):
                 # 纯文本内容
-                dashscope_content = [
-                    {'text': content}
-                ]
+                dashscope_content = [{"text": content}]
             elif isinstance(content, list):
                 # 包含图像和文本的内容
                 dashscope_content = []
                 for j, item in enumerate(content):
                     if isinstance(item, dict):
-                        if item.get('type') == 'text':
-                            dashscope_content.append({'text': item.get('text', '')})
-                        elif item.get('type') == 'image_url':
+                        if item.get("type") == "text":
+                            dashscope_content.append({"text": item.get("text", "")})
+                        elif item.get("type") == "image_url":
                             # 处理 image_url 可能是字符串或字典两种情况
-                            image_url_value = item.get('image_url', '')
+                            image_url_value = item.get("image_url", "")

                             if isinstance(image_url_value, str):
                                 # 如果 image_url 是字符串,直接使用
                                 img_url = image_url_value
-                            elif isinstance(image_url_value, dict) and 'url' in image_url_value:
+                            elif (
+                                isinstance(image_url_value, dict)
+                                and "url" in image_url_value
+                            ):
                                 # 如果 image_url 是字典,从中获取 url
-                                img_url = image_url_value.get('url', '')
+                                img_url = image_url_value.get("url", "")
                             else:
                                 # 其他情况视为错误或空值
-                                img_url = ''
+                                img_url = ""

                             # 如果URL是http格式,提取文件名并转换为file://格式
-                            if img_url.startswith('http://') or img_url.startswith('https://'):
+                            if img_url.startswith("http://") or img_url.startswith(
+                                "https://"
+                            ):
                                 # 提取URL中的文件名部分 (例如从 http://localhost:8000/uploads/filename.jpg 提取 uploads/filename.jpg)
                                 from urllib.parse import urlparse
+
                                 parsed_url = urlparse(img_url)
-                                path_parts = parsed_url.path.split('/')
+                                path_parts = parsed_url.path.split("/")

                                 # 从路径中找到uploads部分及后面的文件名
                                 try:
-                                    uploads_index = path_parts.index('uploads')
-                                    filename = '/'.join(path_parts[uploads_index:])  # 例如: uploads/filename.jpg
+                                    uploads_index = path_parts.index("uploads")
+                                    filename = "/".join(
+                                        path_parts[uploads_index:]
+                                    )  # 例如: uploads/filename.jpg
                                     img_url = f"file://{filename}"
                                 except ValueError:
                                     # 如果路径中没有uploads部分,使用原始路径
                                     img_url = f"file://{parsed_url.path.lstrip('/')}"
-                            elif not img_url.startswith('file://'):
+                            elif not img_url.startswith("file://"):
                                 # 如果既不是网络URL也不是file://协议,假设是相对路径
                                 img_url = f"file://{img_url}"

-                            if img_url.startswith('file://'):
+                            if img_url.startswith("file://"):
                                 # 确保本地文件存在
                                 import os
+
                                 local_path = img_url[7:]  # 移除 "file://" 前缀
                                 if not os.path.exists(local_path):
-                                    print(f"[WARNING] Image file does not exist: {local_path}")
+                                    print(
+                                        f"[WARNING] Image file does not exist: {local_path}"
+                                    )

-                            dashscope_content.append({'image': img_url})
+                            dashscope_content.append({"image": img_url})
                     else:
                         # 将非字典内容转换为文本
-                        dashscope_content.append({'text': str(item)})
+                        dashscope_content.append({"text": str(item)})
             else:
                 # 其他情况转换为文本
-                dashscope_content = [
-                    {'text': str(content)}
-                ]
+                dashscope_content = [{"text": str(content)}]

-            dashscope_messages.append({
-                'role': role,
-                'content': dashscope_content
-            })
+            dashscope_messages.append({"role": role, "content": dashscope_content})

     if stream:
         # 多模态流式响应
         async def multimodal_event_generator():
             try:
                 responses = MultiModalConversation.call(
-                    model=model.replace('qwen-', 'qwen-vl-') if 'qwen-' in model else 'qwen-vl-max',
+                    model=(
+                        model.replace("qwen-", "qwen-vl-")
+                        if "qwen-" in model
+                        else "qwen-vl-max"
+                    ),
                     messages=dashscope_messages,
                     stream=True,
                     max_tokens=max_tokens,
-                    temperature=temperature
+                    temperature=temperature,
                 )

                 full_content = ""
@@ -660,28 +736,32 @@ async def multimodal_chat_handler(messages, model, stream, temperature, max_toke
                         content = None

                         # 从多模态响应中提取内容
-                        if (hasattr(response, 'output') and
-                            response.output and
-                            hasattr(response.output, 'choices') and
-                            response.output.choices is not None and
-                            len(response.output.choices) > 0 and
-                            'message' in response.output.choices[0]):
+                        if (
+                            hasattr(response, "output")
+                            and response.output
"choices") + and response.output.choices is not None + and len(response.output.choices) > 0 + and "message" in response.output.choices[0] + ): - message = response.output.choices[0]['message'] - if 'content' in message: - content_items = message['content'] + message = response.output.choices[0]["message"] + if "content" in message: + content_items = message["content"] # 从内容项中提取文本 extracted_text = "" for item in content_items: - if isinstance(item, dict) and 'text' in item: - extracted_text += item['text'] + if isinstance(item, dict) and "text" in item: + extracted_text += item["text"] content = extracted_text # 只有当内容发生变化时才发送增量 if len(content) > len(full_content): - delta_content = extract_delta_content(content, full_content) + delta_content = extract_delta_content( + content, full_content + ) full_content = content if delta_content.strip(): @@ -693,10 +773,12 @@ async def multimodal_chat_handler(messages, model, stream, temperature, max_toke "choices": [ { "index": 0, - "delta": {"content": delta_content}, - "finish_reason": None + "delta": { + "content": delta_content + }, + "finish_reason": None, } - ] + ], } yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" @@ -706,7 +788,7 @@ async def multimodal_chat_handler(messages, model, stream, temperature, max_toke "message": f"Multimodal API Error: {response.code} - {response.message}", "type": "api_error", "param": None, - "code": response.code + "code": response.code, } } yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" @@ -717,55 +799,52 @@ async def multimodal_chat_handler(messages, model, stream, temperature, max_toke "object": "chat.completion.chunk", "created": get_current_timestamp(), "model": model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop" - } - ] + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } yield f"data: {json.dumps(finish_data, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" except Exception as e: - error_data = { - "error": { - "message": str(e), - "type": "server_error" - } - } + error_data = {"error": {"message": str(e), "type": "server_error"}} yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" - return StreamingResponse(multimodal_event_generator(), media_type="text/event-stream") + return StreamingResponse( + multimodal_event_generator(), media_type="text/event-stream" + ) else: # 多模态非流式响应 response = MultiModalConversation.call( - model=model.replace('qwen-', 'qwen-vl-') if 'qwen-' in model else 'qwen-vl-max', + model=( + model.replace("qwen-", "qwen-vl-") + if "qwen-" in model + else "qwen-vl-max" + ), messages=dashscope_messages, stream=False, max_tokens=max_tokens, - temperature=temperature + temperature=temperature, ) if response.status_code == 200: content = None - if (hasattr(response, 'output') and - response.output and - hasattr(response.output, 'choices') and - response.output.choices is not None and - len(response.output.choices) > 0 and - 'message' in response.output.choices[0]): + if ( + hasattr(response, "output") + and response.output + and hasattr(response.output, "choices") + and response.output.choices is not None + and len(response.output.choices) > 0 + and "message" in response.output.choices[0] + ): - message = response.output.choices[0]['message'] - if 'content' in message: - content_items = message['content'] + message = response.output.choices[0]["message"] + if "content" in message: + content_items = message["content"] # 从内容项中提取文本 extracted_text = "" for item in content_items: - if isinstance(item, dict) and 
-                        if isinstance(item, dict) and 'text' in item:
-                            extracted_text += item['text']
+                        if isinstance(item, dict) and "text" in item:
+                            extracted_text += item["text"]

                     content = extracted_text

@@ -774,17 +853,18 @@ async def multimodal_chat_handler(messages, model, stream, temperature, max_toke
             else:
                 raise HTTPException(
                     status_code=500,
-                    detail="Multimodal API Response does not contain expected content"
+                    detail="Multimodal API Response does not contain expected content",
                 )
         else:
             raise HTTPException(
                 status_code=500,
-                detail=f"Multimodal API Error: {response.code} - {response.message}"
+                detail=f"Multimodal API Error: {response.code} - {response.message}",
             )

     except Exception as e:
         print(f"[ERROR] Error in multimodal chat handler: {str(e)}")
         import traceback
+
         print(f"[ERROR] Traceback: {traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=str(e))

@@ -797,36 +877,36 @@ async def get_models_handler():
             name="通义千问 Max",
             description="最强大的模型",
             maxTokens=8192,
-            provider="Aliyun"
+            provider="Aliyun",
         ),
         ModelInfo(
             id="qwen-plus",
             name="通义千问 Plus",
             description="能力均衡",
             maxTokens=8192,
-            provider="Aliyun"
+            provider="Aliyun",
         ),
         ModelInfo(
             id="qwen-turbo",
             name="通义千问 Turbo",
             description="速度更快、成本更低",
             maxTokens=8192,
-            provider="Aliyun"
+            provider="Aliyun",
         ),
         ModelInfo(
             id="qwen-vl-max",
             name="通义万相 VL-Max",
             description="支持视觉理解的多模态模型",
             maxTokens=8192,
-            provider="Aliyun"
+            provider="Aliyun",
         ),
         ModelInfo(
             id="qwen-vl-plus",
             name="通义万相 VL-Plus",
             description="支持视觉理解的多模态模型",
             maxTokens=8192,
-            provider="Aliyun"
-        )
+            provider="Aliyun",
+        ),
     ]

     return [model.dict() for model in models]
@@ -847,14 +927,14 @@ async def get_conversation_handler(conversation_id: str):
 async def save_conversation_handler(data: dict):
     """保存或更新对话处理器"""
     try:
-        conversation_id = data.get('id') or generate_unique_id()
+        conversation_id = data.get("id") or generate_unique_id()

         conversation = {
             "id": conversation_id,
-            "title": data.get('title', '新对话'),
-            "messages": data.get('messages', []),
+            "title": data.get("title", "新对话"),
+            "messages": data.get("messages", []),
             "updatedAt": datetime.utcnow().isoformat(),
-            "createdAt": data.get('createdAt', datetime.utcnow().isoformat())
+            "createdAt": data.get("createdAt", datetime.utcnow().isoformat()),
         }

         conversations_db[conversation_id] = conversation
@@ -880,35 +960,70 @@ async def upload_file_handler(file: UploadFile = File(...)):
         # 允许的 MIME 类型(宽松策略)
         allowed_types = {
             # 图片
-            'image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/bmp', 'image/svg+xml',
+            "image/jpeg",
+            "image/png",
+            "image/gif",
+            "image/webp",
+            "image/bmp",
+            "image/svg+xml",
             # 文本类
-            'text/plain', 'text/csv', 'text/markdown', 'text/html', 'text/xml',
-            'application/json', 'application/xml',
+            "text/plain",
+            "text/csv",
+            "text/markdown",
+            "text/html",
+            "text/xml",
+            "application/json",
+            "application/xml",
             # PDF
-            'application/pdf',
+            "application/pdf",
             # Office 文档
-            'application/msword',
-            'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
-            'application/vnd.ms-excel',
-            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
-            'application/vnd.ms-powerpoint',
-            'application/vnd.openxmlformats-officedocument.presentationml.presentation',
+            "application/msword",
+            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+            "application/vnd.ms-excel",
+            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+            "application/vnd.ms-powerpoint",
+            "application/vnd.openxmlformats-officedocument.presentationml.presentation",
         }

         # 允许的扩展名(兜底:MIME 类型可能被浏览器误判)
         allowed_extensions = {
-            '.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp',
-            '.txt', '.md', '.csv', '.json', '.xml', '.yaml', '.yml', '.log',
-            '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx',
-            '.py', '.js', '.ts', '.html', '.css'
+            ".jpg",
+            ".jpeg",
+            ".png",
+            ".gif",
+            ".webp",
+            ".bmp",
+            ".txt",
+            ".md",
+            ".csv",
+            ".json",
+            ".xml",
+            ".yaml",
+            ".yml",
+            ".log",
+            ".pdf",
+            ".doc",
+            ".docx",
+            ".xls",
+            ".xlsx",
+            ".ppt",
+            ".pptx",
+            ".py",
+            ".js",
+            ".ts",
+            ".html",
+            ".css",
         }

         file_extension = Path(file.filename).suffix.lower()
-        if file.content_type not in allowed_types and file_extension not in allowed_extensions:
+        if (
+            file.content_type not in allowed_types
+            and file_extension not in allowed_extensions
+        ):
             raise HTTPException(
                 status_code=400,
-                detail=f"不支持的文件类型: {file.content_type}({file_extension})"
+                detail=f"不支持的文件类型: {file.content_type}({file_extension})",
             )

         # 生成唯一文件名
@@ -926,7 +1041,7 @@ async def upload_file_handler(file: UploadFile = File(...)):
             "url": file_url,
             "name": file.filename,
             "size": len(content),
-            "mimeType": file.content_type
+            "mimeType": file.content_type,
         }

         print(f"[INFO] File uploaded: {result}")
@@ -944,17 +1059,20 @@ def serve_upload_handler(filename: str):
         raise HTTPException(status_code=404, detail="文件不存在")

     from fastapi.responses import FileResponse
+
     return FileResponse(str(file_path))


 async def stop_generation_handler(message_id: str = None):
     """停止生成处理器"""
-    message = f"已发出停止指令,消息ID: {message_id}" if message_id else "已发出停止指令"
+    message = (
+        f"已发出停止指令,消息ID: {message_id}" if message_id else "已发出停止指令"
+    )
     return {"success": True, "message": message}


 # ── 平台统一接口别名(供 main.py 的 _platform 动态调用)─────────────
 # main.py 通过 _platform.chat_handler / _platform.models_handler 调用,
 # 各平台模块需暴露相同名称的函数。
-chat_handler = chat_endpoint_handler # 聊天接口别名
-models_handler = get_models_handler # 模型列表别名
+chat_handler = chat_endpoint_handler  # 聊天接口别名
+models_handler = get_models_handler  # 模型列表别名
diff --git a/server/api/chat_routes_glm.py b/server/api/chat_routes_glm.py
index 57cb1b9..c0c10ca 100644
--- a/server/api/chat_routes_glm.py
+++ b/server/api/chat_routes_glm.py
@@ -2,13 +2,16 @@
 GLM-4.6V 平台路由处理器(zai-sdk)
 所有智谱 GLM 相关逻辑均集中在此文件,main.py 无感知任何平台细节。
 """
+
+import json
 import os
 import sys
-import json
 from pathlib import Path
+
 from fastapi import HTTPException
 from fastapi.responses import JSONResponse, StreamingResponse
-from utils.helpers import get_current_timestamp, generate_unique_id
+
+from utils.helpers import generate_unique_id, get_current_timestamp
 from utils.logger import log_info


@@ -19,7 +22,9 @@ def init():
     """
     api_key = os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY")
     if not api_key:
-        raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY(在 https://open.bigmodel.cn 申请)")
+        raise ValueError(
+            "GLM 模式需要设置环境变量 ZHIPU_API_KEY(在 https://open.bigmodel.cn 申请)"
+        )
     log_info(f"[GLM] 初始化完成,ZHIPU_API_KEY 已配置")


@@ -28,47 +33,61 @@ async def chat_handler(body: dict):
     GLM 聊天处理器(对外接口与百炼 chat_endpoint_handler 完全兼容)。
     流式/非流式自动适配,支持图像、文档附件、联网搜索、深度思考。
     """
-    from utils.glm_adapter import glm_stream_generator, glm_chat_sync
+    from utils.glm_adapter import glm_chat_sync, glm_stream_generator

     if not isinstance(body, dict):
         raise HTTPException(status_code=400, detail="请求体必须是 JSON 对象")

-    messages = body.get("messages", [])
-    model = body.get("model", "glm-4.6v")
-    stream = body.get("stream", True)
+    messages = body.get("messages", [])
+    model = body.get("model", "glm-4.6v")
+    stream = body.get("stream", True)
     temperature = body.get("temperature", 0.7)
-    max_tokens = body.get("max_tokens", body.get("maxTokens", 2000))
-    web_search = body.get("webSearch", False) or body.get("deepSearch", False)
-    deep_think = body.get("deepThinking", False)
-    files = body.get("files", [])
+    max_tokens = body.get("max_tokens", body.get("maxTokens", 2000))
+    web_search = body.get("webSearch", False) or body.get("deepSearch", False)
+    deep_think = body.get("deepThinking", False)
+    files = body.get("files", [])

     # 兼容前端简化格式(非 messages 结构)
     if not messages:
-        msg_text = body.get("message", "")
+        msg_text = body.get("message", "")
         sys_prompt = body.get("systemPrompt", "你是一个智能助手。")
-        user_content = msg_text if isinstance(msg_text, list) else [{"type": "text", "text": msg_text}]
+        user_content = (
+            msg_text
+            if isinstance(msg_text, list)
+            else [{"type": "text", "text": msg_text}]
+        )
         messages = [
             {"role": "system", "content": sys_prompt},
-            {"role": "user", "content": user_content},
+            {"role": "user", "content": user_content},
         ]

-    log_info(f"[GLM] model={model} stream={stream} web_search={web_search} "
-             f"thinking={deep_think} files={len(files)} msgs={len(messages)}")
+    log_info(
+        f"[GLM] model={model} stream={stream} web_search={web_search} "
+        f"thinking={deep_think} files={len(files)} msgs={len(messages)}"
+    )

     if stream:
         return StreamingResponse(
             glm_stream_generator(
-                messages=messages, model=model, temperature=temperature,
-                max_tokens=max_tokens, files=files or None,
-                web_search=web_search, deep_thinking=deep_think,
+                messages=messages,
+                model=model,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                files=files or None,
+                web_search=web_search,
+                deep_thinking=deep_think,
             ),
             media_type="text/event-stream",
         )

     result = glm_chat_sync(
-        messages=messages, model=model, temperature=temperature,
-        max_tokens=max_tokens, files=files or None,
-        web_search=web_search, deep_thinking=deep_think,
+        messages=messages,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        files=files or None,
+        web_search=web_search,
+        deep_thinking=deep_think,
     )
     resp = {
         "id": f"chatcmpl-{generate_unique_id()}",
diff --git a/server/init_logging.py b/server/init_logging.py
index 06ae23d..6b71ff8 100644
--- a/server/init_logging.py
+++ b/server/init_logging.py
@@ -2,9 +2,12 @@
 """
 初始化日志系统
 """
+
 import os
+
 from utils.logger import setup_global_logger

+
 def init_logging_system():
     """
     初始化日志系统
@@ -26,13 +29,12 @@ def init_logging_system():

     # 设置全局日志系统
     logger = setup_global_logger(
-        name="ai-chat-api",
-        log_level=log_level,
-        log_dir=log_dir
+        name="ai-chat-api", log_level=log_level, log_dir=log_dir
     )

     return logger

+
 if __name__ == "__main__":
     logger = init_logging_system()
-    logger.info("Logging system initialized successfully")
\ No newline at end of file
+    logger.info("Logging system initialized successfully")
diff --git a/server/main.py b/server/main.py
index f94379e..e56880c 100644
--- a/server/main.py
+++ b/server/main.py
@@ -10,6 +10,7 @@ AI Chat API Server — 主入口(纯基础设施层)
 - 百炼 DashScope → api/chat_routes.py
 - 智谱 GLM-4.6V → api/chat_routes_glm.py + utils/glm_adapter.py
 """
+
 import os
 import sys
 from datetime import datetime, timezone
@@ -27,10 +28,10 @@ if _venv_lib.exists():

 # ── 第三方导入 ────────────────────────────────────────────────────────
 from dotenv import load_dotenv
-from fastapi import FastAPI, File, UploadFile, Request
+from fastapi import FastAPI, File, Request, UploadFile
 from fastapi.responses import JSONResponse

-sys.path.append('/home/mt/project/ai-chat-ui/server')
+sys.path.append("/home/mt/project/ai-chat-ui/server")

 # ── 工具/日志(与平台无关)───────────────────────────────────────────
 from utils.helpers import log_response

@@ -55,15 +56,11 @@ else:
 _platform.init()  # 各平台自行完成初始化(API Key 校验等)

 # 通用路由处理器(文件上传、会话管理等,与平台无关,统一用百炼路由中的实现)
-from api.chat_routes import (
-    get_conversations_handler,
-    get_conversation_handler,
-    save_conversation_handler,
-    delete_conversation_handler,
-    upload_file_handler,
-    serve_upload_handler,
-    stop_generation_handler,
-)
+from api.chat_routes import (delete_conversation_handler,
+                             get_conversation_handler,
+                             get_conversations_handler,
+                             save_conversation_handler, serve_upload_handler,
+                             stop_generation_handler, upload_file_handler)

 # ── FastAPI 应用 ──────────────────────────────────────────────────────
 app = FastAPI(
@@ -80,7 +77,9 @@ async def logging_middleware(request: Request, call_next):
     response = await call_next(request)
     ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
     icon = "✅" if response.status_code < 400 else "❌"
-    logger.info(f"{icon} {request.method} {request.url.path} | 状态: {response.status_code} | 耗时: {ms:.0f}ms")
+    logger.info(
+        f"{icon} {request.method} {request.url.path} | 状态: {response.status_code} | 耗时: {ms:.0f}ms"
+    )
     log_response(response.status_code, ms)
     response.headers["X-Process-Time"] = f"{ms:.2f}ms"
     return response
@@ -88,6 +87,7 @@ async def logging_middleware(request: Request, call_next):

 # ── 路由注册 ──────────────────────────────────────────────────────────

+
 @app.get("/health")
 async def health_check():
     return {
@@ -115,6 +115,7 @@ async def get_models():

 # ── 通用路由(与平台无关)────────────────────────────────────────────

+
 @app.get("/api/chat-ui/conversations")
 async def get_conversations():
     return await get_conversations_handler()
@@ -158,6 +159,7 @@ async def stop_generation_by_id(message_id: str):
 # ── 程序入口 ──────────────────────────────────────────────────────────
 if __name__ == "__main__":
     import uvicorn
+
     port = int(os.getenv("PORT", 8000))
     print("=" * 55)
     print(f" AI Chat Server v3.0 启动中...")
@@ -165,4 +167,4 @@ if __name__ == "__main__":
     print(f" 监听端口 : {port}")
     print(f" 切换平台 : 修改 .env 中 LLM_BACKEND=glm|dashscope,重启")
     print("=" * 55)
-    uvicorn.run(app, host="0.0.0.0", port=port)
\ No newline at end of file
+    uvicorn.run(app, host="0.0.0.0", port=port)
diff --git a/server/models/__init__.py b/server/models/__init__.py
index a3bef1d..c09de5a 100644
--- a/server/models/__init__.py
+++ b/server/models/__init__.py
@@ -1 +1 @@
-# models/__init__.py
\ No newline at end of file
+# models/__init__.py
diff --git a/server/models/chat_models.py b/server/models/chat_models.py
index 5bf81bc..b0a21e8 100644
--- a/server/models/chat_models.py
+++ b/server/models/chat_models.py
@@ -1,14 +1,18 @@
 """
 数据模型定义
 """
+
+from typing import Any, Dict, List, Optional, Union
+
 from pydantic import BaseModel
-from typing import Dict, List, Optional, Any, Union


 class ChatMessageContentItem(BaseModel):
     type: str  # "text" or "image_url"
     text: Optional[str] = None
-    image_url: Optional[Dict[str, str]] = None  # {"url": "...", "detail": "auto|low|high"}
+    image_url: Optional[Dict[str, str]] = (
+        None  # {"url": "...", "detail": "auto|low|high"}
+    )


 class ChatMessage(BaseModel):
@@ -35,4 +39,4 @@ class ModelInfo(BaseModel):
     name: str
     description: str
     maxTokens: int
-    provider: str
\ No newline at end of file
+    provider: str
diff --git a/server/utils/__init__.py b/server/utils/__init__.py
index 3c74ea6..083ac2d 100644
--- a/server/utils/__init__.py
+++ b/server/utils/__init__.py
@@ -1 +1 @@
-# utils/__init__.py
\ No newline at end of file
+# utils/__init__.py
diff --git a/server/utils/file_cache.py b/server/utils/file_cache.py
index 619035e..390a163 100644
--- a/server/utils/file_cache.py
+++ b/server/utils/file_cache.py
@@ -1,10 +1,11 @@
 """
 GLM 文件 ID 缓存(基于磁盘的简单 KV,sha256 → file_id,3天有效期)
 """
+
 import hashlib
 import json
-import time
 import threading
+import time
 from pathlib import Path

 _CACHE_FILE = Path(__file__).parent.parent / "uploads" / ".glm_file_cache.json"
@@ -24,7 +25,9 @@ def _load() -> dict:
 def _save(data: dict) -> None:
     try:
         _CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
-        _CACHE_FILE.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
+        _CACHE_FILE.write_text(
+            json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
+        )
     except Exception as e:
         print(f"[file_cache] 写入失败:{e}")

diff --git a/server/utils/glm_adapter.py b/server/utils/glm_adapter.py
index cbfa796..3040cc3 100644
--- a/server/utils/glm_adapter.py
+++ b/server/utils/glm_adapter.py
@@ -3,10 +3,11 @@ GLM-4.6V 适配层(基于 zai-sdk)
 SDK:pip install zai-sdk
 模型:glm-4.6v(支持文本/图像/文档/深度思考)
 """
+
+import base64
+import json
 import os
 import sys
-import json
-import base64
 import threading
 from pathlib import Path
 from typing import AsyncGenerator

@@ -15,7 +16,9 @@ from typing import AsyncGenerator
 # ── 自动注入 venv site-packages ───────────────────────────────────────
 def _ensure_venv():
     server_dir = Path(__file__).parent.parent
-    for sp in sorted((server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True):
+    for sp in sorted(
+        (server_dir / ".venv" / "lib").glob("python*/site-packages"), reverse=True
+    ):
         if sp.exists() and str(sp) not in sys.path:
             sys.path.insert(0, str(sp))
             print(f"[GLM] venv 注入:{sp}")
@@ -34,7 +37,7 @@ def get_client():
         from zai import ZhipuAiClient
     except ImportError:
         raise ImportError("GLM 模式需要安装 zai-sdk:.venv/bin/pip install zai-sdk")
-    api_key = os.getenv("ZHIPU_API_KEY").strip() or os.getenv("GLM_API_KEY").strip()
+    api_key = (os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY") or "").strip()
     if not api_key:
         raise ValueError("GLM 模式需要设置环境变量 ZHIPU_API_KEY")
     _client = ZhipuAiClient(api_key=api_key)
@@ -43,15 +46,15 @@ def get_client():


 # ── 模型映射 ──────────────────────────────────────────────────────────
-DEFAULT_TEXT_MODEL = "glm-4.5-Air" # glm-4.6 文本统一模型
+DEFAULT_TEXT_MODEL = "glm-4.5-Air"  # glm-4.6 文本统一模型
 DEFAULT_VISION_MODEL = "glm-4.5-Air"

 MODEL_MAP = {
-    "qwen-max": "glm-4.5-Air",
-    "qwen-plus": "glm-4.5-Air",
-    "qwen-turbo": "glm-4.5-Air",
-    "qwen-vl-max": "glm-4.5-Air",
-    "qwen-vl-plus": "glm-4.5-Air",
+    "qwen-max": "glm-4.5-Air",
+    "qwen-plus": "glm-4.5-Air",
+    "qwen-turbo": "glm-4.5-Air",
+    "qwen-vl-max": "glm-4.5-Air",
+    "qwen-vl-plus": "glm-4.5-Air",
 }


@@ -63,7 +66,9 @@ def resolve_model(model: str, has_vision: bool = False) -> str:

 # ── 文件上传(含 file_id 缓存)───────────────────────────────────────
 def upload_file_for_extract(local_path: Path) -> str:
-    from utils.file_cache import sha256_of_file, get as cache_get, set as cache_set
+    from utils.file_cache import get as cache_get
+    from utils.file_cache import set as cache_set
+    from utils.file_cache import sha256_of_file

     file_hash = sha256_of_file(local_path)
     cached = cache_get(file_hash)
@@ -73,18 +78,20 @@ def upload_file_for_extract(local_path: Path) -> str:

     client = get_client()
     mime_map = {
-        ".pdf": "application/pdf",
+        ".pdf": "application/pdf",
         ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
-        ".doc": "application/msword",
+        ".doc": "application/msword",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - ".xls": "application/vnd.ms-excel", + ".xls": "application/vnd.ms-excel", ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", - ".ppt": "application/vnd.ms-powerpoint", + ".ppt": "application/vnd.ms-powerpoint", } mime = mime_map.get(local_path.suffix.lower(), "application/octet-stream") print(f"[GLM] 上传文件:{local_path.name}({mime})") with open(local_path, "rb") as f: - file_obj = client.files.create(file=(local_path.name, f, mime), purpose="file-extract") + file_obj = client.files.create( + file=(local_path.name, f, mime), purpose="file-extract" + ) file_id = file_obj.id cache_set(file_hash, file_id) print(f"[GLM] 上传成功:file_id={file_id}") @@ -94,7 +101,9 @@ def upload_file_for_extract(local_path: Path) -> str: # ── 图像编码 ───────────────────────────────────────────────────────── def encode_image(image_source: str) -> dict: """将图像来源统一转为 OpenAI image_url 格式""" - if image_source.startswith("data:image") or image_source.startswith(("http://", "https://")): + if image_source.startswith("data:image") or image_source.startswith( + ("http://", "https://") + ): return {"type": "image_url", "image_url": {"url": image_source}} # 本地路径 → base64 local = Path(image_source.replace("file://", "").lstrip("/")) @@ -138,7 +147,9 @@ def build_glm_messages(messages: list, files: list | None = None) -> tuple[list, elif t == "image_url": has_vision = True img_val = item.get("image_url", "") - img_src = img_val.get("url", "") if isinstance(img_val, dict) else img_val + img_src = ( + img_val.get("url", "") if isinstance(img_val, dict) else img_val + ) new_content.append(encode_image(img_src)) else: new_content.append({"type": "text", "text": str(item)}) @@ -172,9 +183,13 @@ def build_glm_messages(messages: list, files: list | None = None) -> tuple[list, fid = upload_file_for_extract(local) inserts.append({"type": "file", "file": {"file_id": fid}}) except Exception as e: - inserts.append({"type": "text", "text": f"[文件上传失败:{filename},{e}]"}) + inserts.append( + {"type": "text", "text": f"[文件上传失败:{filename},{e}]"} + ) else: - inserts.append({"type": "text", "text": f"[附件:{filename},类型:{suffix}]"}) + inserts.append( + {"type": "text", "text": f"[附件:{filename},类型:{suffix}]"} + ) if inserts: for i in range(len(glm_messages) - 1, -1, -1): @@ -195,6 +210,7 @@ def build_glm_messages(messages: list, files: list | None = None) -> tuple[list, # ── 哨兵对象 ───────────────────────────────────────────────────────── _SENTINEL = object() + # ── 流式调用 ──────────────────────────────────────────────────────── async def glm_stream_generator( messages: list, @@ -213,7 +229,7 @@ async def glm_stream_generator( import asyncio import queue - from utils.helpers import get_current_timestamp, generate_unique_id + from utils.helpers import generate_unique_id, get_current_timestamp glm_msgs, has_vision = build_glm_messages(messages, files) actual_model = resolve_model(model, has_vision) @@ -221,13 +237,18 @@ async def glm_stream_generator( extra_kwargs: dict = {} if web_search: extra_kwargs["tools"] = [ - {"type": "web_search", "web_search": {"enable":True,"search_result": True}} + { + "type": "web_search", + "web_search": {"enable": True, "search_result": True}, + } ] if not deep_thinking: # 智普默认开启思考模式,所以要用非门(不知道“非门”描述是否准确。前端选择开启思考模式,这里不做变动。前端选择关闭思考模式,这里关闭。) extra_kwargs["thinking"] = {"type": "disabled"} - print(f"[GLM] 流式请求:model={actual_model} vision={has_vision} " - f"web_search={web_search} thinking={deep_thinking}") + print( + f"[GLM] 
-    print(f"[GLM] 流式请求:model={actual_model} vision={has_vision} "
-          f"web_search={web_search} thinking={deep_thinking}")
+    print(
+        f"[GLM] 流式请求:model={actual_model} vision={has_vision} "
+        f"web_search={web_search} thinking={deep_thinking}"
+    )

     chunk_queue: queue.Queue = queue.Queue(maxsize=128)

@@ -254,8 +275,8 @@ async def glm_stream_generator(

     loop = asyncio.get_running_loop()

-    full_reasoning = ""  # 累计思考内容(用于判断是否首次)
-    full_content = ""  # 累计正式回答(用于判断是否首次)
+    full_reasoning = ""  # 累计思考内容(用于判断是否首次)
+    full_content = ""  # 累计正式回答(用于判断是否首次)

     while True:
         item = await loop.run_in_executor(None, chunk_queue.get)
@@ -271,7 +292,7 @@ async def glm_stream_generator(
         try:
             delta = item.choices[0].delta
             reasoning = getattr(delta, "reasoning_content", "") or ""
-            text = getattr(delta, "content", "") or ""
+            text = getattr(delta, "content", "") or ""

             delta_str = ""

@@ -300,7 +321,9 @@ async def glm_stream_generator(
                 "object": "chat.completion.chunk",
                 "created": get_current_timestamp(),
                 "model": actual_model,
-                "choices": [{"index": 0, "delta": {"content": delta_str}, "finish_reason": None}],
+                "choices": [
+                    {"index": 0, "delta": {"content": delta_str}, "finish_reason": None}
+                ],
             }
             yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"

@@ -318,7 +341,6 @@ async def glm_stream_generator(

     yield "data: [DONE]\n\n"

-
 # ── 非流式调用 ────────────────────────────────────────────────────────
 def glm_chat_sync(
     messages: list,
@@ -334,13 +356,12 @@ def glm_chat_sync(

     extra_kwargs: dict = {}
     if web_search:
-        extra_kwargs["tools"] = [{
-            "type": "web_search",
-            "web_search": {
-                "enable": True,
-                "search_result": True
-            }
-}]
+        extra_kwargs["tools"] = [
+            {
+                "type": "web_search",
+                "web_search": {"enable": True, "search_result": True},
+            }
+        ]
     if deep_thinking:
         extra_kwargs["thinking"] = {"type": "enabled"}

@@ -358,8 +379,8 @@ def glm_chat_sync(
     usage = None
     if hasattr(resp, "usage") and resp.usage:
         usage = {
-            "promptTokens": resp.usage.prompt_tokens,
+            "promptTokens": resp.usage.prompt_tokens,
             "completionTokens": resp.usage.completion_tokens,
-            "totalTokens": resp.usage.total_tokens,
+            "totalTokens": resp.usage.total_tokens,
         }
     return {"content": content, "model": actual_model, "usage": usage}
diff --git a/server/utils/helpers.py b/server/utils/helpers.py
index c890048..b136eb1 100644
--- a/server/utils/helpers.py
+++ b/server/utils/helpers.py
@@ -1,13 +1,15 @@
 """
 通用工具函数
 """
-import os
+
 import json
+import os
 import uuid
 from datetime import datetime
 from typing import Dict

-from .logger import log_request_info, log_response_info, log_error_detail, log_chat_interaction
+from .logger import (log_chat_interaction, log_error_detail, log_request_info,
+                     log_response_info)


 def get_current_timestamp():
@@ -20,14 +22,16 @@ def generate_unique_id():
     return str(uuid.uuid4())


-def format_api_response(content: str, conversation_id: str = None, model: str = "qwen-plus"):
+def format_api_response(
+    content: str, conversation_id: str = None, model: str = "qwen-plus"
+):
     """格式化API响应"""
     return {
         "id": generate_unique_id(),
         "conversationId": conversation_id or generate_unique_id(),
         "content": content,
         "model": model,
-        "createdAt": get_current_timestamp()
+        "createdAt": get_current_timestamp(),
     }


@@ -44,5 +48,5 @@ def log_response(status_code: int, process_time: float):
 def extract_delta_content(full_content: str, previous_content: str) -> str:
     """提取增量内容"""
     if len(full_content) > len(previous_content):
-        return full_content[len(previous_content):]
-    return ""
\ No newline at end of file
+        return full_content[len(previous_content) :]
+    return ""
diff --git a/server/utils/logger.py b/server/utils/logger.py
index 09a3a68..9e68d00 100644
--- a/server/utils/logger.py
+++ b/server/utils/logger.py
@@ -2,20 +2,27 @@
 统一日志管理系统
 提供结构化日志记录功能,支持不同日志级别、文件输出、轮转等
 """
+
+import json
 import logging
 import os
 import sys
 from datetime import datetime
-from pathlib import Path
 from logging.handlers import RotatingFileHandler
-import json
+from pathlib import Path


 class LoggerSetup:
     """日志系统配置类"""

-    def __init__(self, name: str = "ai-chat-server", log_level: str = "INFO",
-                 log_dir: str = "logs", max_bytes: int = 10 * 1024 * 1024, backup_count: int = 5):
+    def __init__(
+        self,
+        name: str = "ai-chat-server",
+        log_level: str = "INFO",
+        log_dir: str = "logs",
+        max_bytes: int = 10 * 1024 * 1024,
+        backup_count: int = 5,
+    ):
         """
         初始化日志系统

@@ -37,7 +44,7 @@ class LoggerSetup:

         # 设置日志格式(去掉 funcName:lineno,保持人类可读性)
         self.formatter = logging.Formatter(
-            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
         )

         # 创建logger实例
@@ -66,7 +73,7 @@ class LoggerSetup:
             str(log_file),
             maxBytes=self.max_bytes,
             backupCount=self.backup_count,
-            encoding='utf-8'
+            encoding="utf-8",
         )
         file_handler.setLevel(self.log_level)
         file_handler.setFormatter(self.formatter)
@@ -83,9 +90,13 @@ class LoggerSetup:
 _logger_instance = None


-def setup_global_logger(name: str = "ai-chat-server", log_level: str = "INFO",
-                        log_dir: str = "logs", max_bytes: int = 10 * 1024 * 1024,
-                        backup_count: int = 5):
+def setup_global_logger(
+    name: str = "ai-chat-server",
+    log_level: str = "INFO",
+    log_dir: str = "logs",
+    max_bytes: int = 10 * 1024 * 1024,
+    backup_count: int = 5,
+):
     """
     设置全局日志系统

@@ -170,12 +181,17 @@ def log_structured(level: str, message: str, **details):
     # 转换为更易读的格式
     detail_str = ", ".join(f"{k}={v}" for k, v in details.items() if v)
     formatted_msg = f"[{message}] {detail_str}"
-
+
     getattr(logger, level.lower())(formatted_msg)


-def log_request_info(method: str, path: str, client_ip: str = "unknown",
-                     user_agent: str = "", referer: str = ""):
+def log_request_info(
+    method: str,
+    path: str,
+    client_ip: str = "unknown",
+    user_agent: str = "",
+    referer: str = "",
+):
     """记录请求信息日志"""
     log_structured(
         "info",
@@ -184,12 +200,17 @@ def log_request_info(method: str, path: str, client_ip: str = "unknown",
         path=path,
         client_ip=client_ip,
         user_agent=user_agent,
-        referer=referer
+        referer=referer,
     )


-def log_response_info(status_code: int, process_time: float, path: str = "",
-                      method: str = "", client_ip: str = ""):
+def log_response_info(
+    status_code: int,
+    process_time: float,
+    path: str = "",
+    method: str = "",
+    client_ip: str = "",
+):
     """记录响应信息日志"""
     log_structured(
         "info",
@@ -198,37 +219,52 @@ def log_response_info(status_code: int, process_time: float, path: str = "",
         process_time_ms=process_time,
         path=path,
         method=method,
-        client_ip=client_ip
+        client_ip=client_ip,
     )


-def log_error_detail(error_type: str, error_message: str, traceback_info: str = "",
-                     context: dict = None):
+def log_error_detail(
+    error_type: str, error_message: str, traceback_info: str = "", context: dict = None
+):
     """记录详细的错误信息"""
     log_structured(
         "error",
         f"{error_type}: {error_message}",
         traceback=traceback_info,
-        context=context or {}
+        context=context or {},
     )


-def log_chat_interaction(user_input: str, ai_response: str, model: str = "",
-                         conversation_id: str = "", tokens_used: dict = None):
+def log_chat_interaction(
+    user_input: str,
+    ai_response: str,
+    model: str = "",
+    conversation_id: str = "",
+    tokens_used: dict = None,
+):
     """记录聊天交互日志"""
     log_structured(
         "info",
         "Chat Interaction",
-        user_input=user_input[:100] + "..." if len(user_input) > 100 else user_input,  # 截断长输入
-        ai_response=ai_response[:100] + "..." if len(ai_response) > 100 else ai_response,
+        user_input=(
+            user_input[:100] + "..." if len(user_input) > 100 else user_input
+        ),  # 截断长输入
+        ai_response=(
+            ai_response[:100] + "..." if len(ai_response) > 100 else ai_response
+        ),
         model=model,
         conversation_id=conversation_id,
-        tokens_used=tokens_used
+        tokens_used=tokens_used,
     )


-def log_system_status(status: str, uptime: float = 0, cpu_usage: float = 0,
-                      memory_usage: float = 0, disk_usage: float = 0):
+def log_system_status(
+    status: str,
+    uptime: float = 0,
+    cpu_usage: float = 0,
+    memory_usage: float = 0,
+    disk_usage: float = 0,
+):
     """记录系统状态日志"""
     log_structured(
         "info",
@@ -237,5 +273,5 @@ def log_system_status(status: str, uptime: float = 0, cpu_usage: float = 0,
         uptime_seconds=uptime,
         cpu_percent=cpu_usage,
         memory_percent=memory_usage,
-        disk_percent=disk_usage
-    )
\ No newline at end of file
+        disk_percent=disk_usage,
+    )
diff --git a/server/utils/test_glm_search.py b/server/utils/test_glm_search.py
index d6f8729..7acb47b 100644
--- a/server/utils/test_glm_search.py
+++ b/server/utils/test_glm_search.py
@@ -1,30 +1,36 @@
+import asyncio
 import os
 import sys
-import asyncio
 from pathlib import Path

 # Add project root to sys.path
 root_dir = Path(__file__).parent
 sys.path.insert(0, str(root_dir))

-from utils.glm_adapter import glm_stream_generator, _ensure_venv, glm_chat_sync
-
 # Set API key from .env if needed
 from dotenv import load_dotenv
+
+from utils.glm_adapter import _ensure_venv, glm_chat_sync, glm_stream_generator
+
 load_dotenv()

+
 async def test_stream():
     msgs = [{"role": "user", "content": "今天北京天气怎样?"}]
     print("Testing stream...")
-    async for chunk in glm_stream_generator(msgs, "glm-4.5-air", 0.7, 1024, web_search=True):
+    async for chunk in glm_stream_generator(
+        msgs, "glm-4.5-air", 0.7, 1024, web_search=True
+    ):
         print(chunk, end="")

+
 def test_sync():
     msgs = [{"role": "user", "content": "今天几号?武汉天气怎样?"}]
     print("Testing sync...")
     res = glm_chat_sync(msgs, "glm-4.5-air", 0.7, 1024, web_search=True)
     print(res)

+
 if __name__ == "__main__":
     _ensure_venv()
     # test_sync()
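
Both platform routes emit the same OpenAI-style SSE stream (`data: {...}` chunks with `choices[0].delta.content`, closed by `data: [DONE]`), so any client can consume either backend identically. Below is a minimal smoke-test sketch of such a client; it assumes the server is running locally on port 8000 (the default in main.py), that `requests` is installed, and that the chat route is `POST /api/chat-ui/chat` as shown in the handler's log line — adjust those if your deployment differs.

```python
"""Minimal SSE consumer sketch for the chat endpoint (not part of the patch)."""
import json

import requests

resp = requests.post(
    "http://localhost:8000/api/chat-ui/chat",  # path taken from the log_info line above
    json={
        "model": "qwen-plus",
        "stream": True,
        "messages": [{"role": "user", "content": "你好"}],
    },
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    # SSE events are "data: <payload>" lines separated by blank lines
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":  # stream terminator emitted by the generators
        break
    chunk = json.loads(payload)
    if "error" in chunk:  # error chunks carry "error" instead of "choices"
        print("stream error:", chunk["error"].get("message"))
        break
    delta = chunk["choices"][0]["delta"].get("content", "")
    print(delta, end="", flush=True)
```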
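The alias block at the end of chat_routes.py documents the whole backend-switching contract: main.py only ever calls `_platform.init()`, `_platform.chat_handler`, and `_platform.models_handler`, selected via `LLM_BACKEND` in `.env`. As a sketch of what a third backend would have to provide, here is a hypothetical `server/api/chat_routes_echo.py` (the module name and its echo behaviour are invented for illustration; only the three exported names are dictated by the patch):

```python
"""Hypothetical minimal platform module implementing the _platform contract."""
from fastapi.responses import JSONResponse


def init():
    """Validate configuration at startup; no external dependencies here."""


async def chat_handler(body: dict):
    """Echo the last user message back, in the simplified response shape."""
    messages = body.get("messages", [])
    last = messages[-1].get("content", "") if messages else ""
    return JSONResponse(content={"content": f"echo: {last}", "model": "echo-1"})


async def models_handler():
    # Same field names as ModelInfo in chat_models.py
    return [
        {
            "id": "echo-1",
            "name": "Echo",
            "description": "Debug-only echo model",
            "maxTokens": 8192,
            "provider": "local",
        }
    ]
```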