diff --git a/.gitignore b/.gitignore index ba00c0d..5b9e276 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,13 @@ uploads .venv __pycache__ .claude +<<<<<<< HEAD *.db +======= +.trae +.agent +.agents +>>>>>>> feat/database # Editor directories and files .vscode/* @@ -28,6 +34,7 @@ __pycache__ *.njsproj *.sln *.sw? +<<<<<<< HEAD # Skills .skills @@ -35,3 +42,9 @@ __pycache__ .agents .trae skills-lock.json +======= +*.db +server/data/*.db + +skills-lock.json +>>>>>>> feat/database diff --git a/server/adapters/base.py b/server/adapters/base.py index a7ea570..13f2b6b 100644 --- a/server/adapters/base.py +++ b/server/adapters/base.py @@ -30,7 +30,7 @@ class ModelInfo: "maxTokens": self.max_tokens, "provider": self.provider, "supports_thinking": self.supports_thinking, - "supports_web_Search": self.supports_web_search, + "supports_web_search": self.supports_web_search, "supports_vision": self.supports_vision, "supports_files": self.supports_files, } diff --git a/server/adapters/glm_adapter.py b/server/adapters/glm_adapter.py index 59cc659..d39b8b9 100644 --- a/server/adapters/glm_adapter.py +++ b/server/adapters/glm_adapter.py @@ -138,6 +138,11 @@ class GLMAdapter(BaseAdapter): logger.info( f"[GLM] 深度思考已启用: extra_kwargs['thinking'] = {extra_kwargs['thinking']}" ) + else: + extra_kwargs["thinking"] = {"type": "disabled"} + logger.info( + f"[GLM] 深度思考已禁用: extra_kwargs['thinking'] = {extra_kwargs['thinking']}" + ) if extra_kwargs: logger.info( diff --git a/server/api/conversation_routes.py b/server/api/conversation_routes.py index 60ae8ef..5319831 100644 --- a/server/api/conversation_routes.py +++ b/server/api/conversation_routes.py @@ -28,10 +28,10 @@ upload_dir.mkdir(exist_ok=True) # ── 会话管理 ───────────────────────────────────────────────────── -async def get_conversations_handler(): +async def get_conversations_handler(user_id: str = "default"): """获取所有对话处理器""" db = get_db() - return db.list_conversations() + return db.list_conversations(user_id) async def get_conversation_handler(conversation_id: str): @@ -65,8 +65,38 @@ async def save_conversation_handler(data: dict): async def delete_conversation_handler(conversation_id: str): - """删除对话处理器""" + """删除对话处理器(同时删除关联的 OSS 文件)""" db = get_db() + + # 先获取会话数据,提取 OSS 文件 URL + conversation = db.get_conversation(conversation_id) + if not conversation: + raise HTTPException(status_code=404, detail="对话不存在") + + # 提取所有 OSS 文件 URL + oss_urls = _extract_oss_urls_from_conversation(conversation) + + # 删除 OSS 文件 + if oss_urls: + try: + from utils.oss_uploader import delete_files, extract_object_key_from_url + + object_keys = [] + for url in oss_urls: + key = extract_object_key_from_url(url) + if key: + object_keys.append(key) + + if object_keys: + result = delete_files(object_keys) + log_info(f"[删除会话] OSS 文件清理结果: 删除 {len(result['deleted'])} 个, 失败 {len(result['failed'])} 个") + if result['failed']: + log_error(f"[删除会话] OSS 文件删除失败: {result['failed']}") + except Exception as e: + log_error(f"[删除会话] OSS 文件删除异常: {e}") + # 继续删除会话,即使 OSS 删除失败 + + # 删除数据库记录 success = db.delete_conversation(conversation_id) if success: return {"success": True, "message": "删除成功"} @@ -74,6 +104,85 @@ async def delete_conversation_handler(conversation_id: str): raise HTTPException(status_code=404, detail="对话不存在") +def _extract_oss_urls_from_conversation(conversation: dict) -> list: + """ + 从会话消息中提取所有 OSS 文件 URL + + 消息结构: + - content.images: 图片附件列表 + - content.files: 文件附件列表 + 每个附件包含 url 字段 + """ + urls = [] + messages = conversation.get("messages", []) + + for message in messages: + content = message.get("content") + if not content: + continue + + # content 可能是字符串(需要解析)或已解析的字典 + if isinstance(content, str): + try: + content = json.loads(content) + except json.JSONDecodeError: + continue + + # 提取图片附件 + images = content.get("images", []) + for img in images: + url = img.get("url") + if url and url not in urls: + urls.append(url) + + # 提取文件附件 + files = content.get("files", []) + for f in files: + url = f.get("url") + if url and url not in urls: + urls.append(url) + + return urls + + +async def update_conversation_handler(conversation_id: str, data: dict): + """部分更新对话处理器""" + db = get_db() + result = db.update_conversation(conversation_id, data) + if result: + return result + else: + raise HTTPException(status_code=404, detail="对话不存在") + + +# ── 消息管理 ───────────────────────────────────────────────────── + + +async def add_message_handler(conversation_id: str, message: dict): + """添加消息到对话处理器""" + db = get_db() + # 检查对话是否存在 + existing = db.get_conversation(conversation_id) + if not existing: + raise HTTPException(status_code=404, detail="对话不存在") + return db.add_message(conversation_id, message) + + +async def update_message_handler(conversation_id: str, message_id: str, data: dict): + """更新消息处理器""" + db = get_db() + # 检查对话是否存在 + existing = db.get_conversation(conversation_id) + if not existing: + raise HTTPException(status_code=404, detail="对话不存在") + + result = db.update_message(message_id, data) + if result: + return result + else: + raise HTTPException(status_code=404, detail="消息不存在") + + # ── 文件上传 ───────────────────────────────────────────────────── diff --git a/server/database/__init__.py b/server/database/__init__.py index db014e9..9aa8dbe 100644 --- a/server/database/__init__.py +++ b/server/database/__init__.py @@ -1,91 +1,3 @@ -""" -数据库模块 +from .db import Database, get_db, init_db -提供 SQLite 数据库连接和会话管理功能。 -""" - -import os -import sqlite3 -from pathlib import Path -from contextlib import contextmanager -from typing import Optional - -# 默认数据库路径 -DEFAULT_DB_PATH = Path(__file__).parent.parent / "data" / "chat.db" - - -def init_db(db_path: Optional[str] = None): - """ - 初始化数据库 - 创建必要的表结构 - """ - if db_path is None: - db_path = os.getenv("DB_PATH", str(DEFAULT_DB_PATH)) - - # 确保数据目录存在 - Path(db_path).parent.mkdir(parents=True, exist_ok=True) - - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - # 创建会话表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS conversations ( - id TEXT PRIMARY KEY, - title TEXT, - model TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - - # 创建消息表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - role TEXT, - content TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - ) - """) - - # 创建文件表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS files ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - filename TEXT, - file_path TEXT, - file_type TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - ) - """) - - conn.commit() - conn.close() - - print(f"[数据库] 初始化完成: {db_path}") - - -@contextmanager -def get_db(db_path: Optional[str] = None): - """ - 获取数据库连接的上下文管理器 - - 用法: - with get_db() as db: - cursor = db.execute("SELECT * FROM conversations") - rows = cursor.fetchall() - """ - if db_path is None: - db_path = os.getenv("DB_PATH", str(DEFAULT_DB_PATH)) - - conn = sqlite3.connect(db_path) - conn.row_factory = sqlite3.Row - try: - yield conn - finally: - conn.close() +__all__ = ["Database", "get_db", "init_db"] \ No newline at end of file diff --git a/server/database/db.py b/server/database/db.py new file mode 100644 index 0000000..7915ae1 --- /dev/null +++ b/server/database/db.py @@ -0,0 +1,401 @@ +""" +SQLite 数据库模块 - 会话持久化存储 + +提供会话和消息的 CRUD 操作,支持多用户(预留 user_id 字段)。 +""" + +import json +import sqlite3 +import threading +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional + +# 数据库文件路径 +DB_PATH = Path(__file__).parent.parent / "data" / "chat.db" + +# 线程本地存储,确保每个线程使用独立的连接 +_thread_local = threading.local() + +# 全局数据库实例 +_db_instance: Optional["Database"] = None + + +class Database: + """SQLite 数据库管理类""" + + def __init__(self, db_path: Path): + self.db_path = db_path + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_tables() + + def _get_connection(self) -> sqlite3.Connection: + """获取当前线程的数据库连接""" + if not hasattr(_thread_local, "connection"): + conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + conn.row_factory = sqlite3.Row + # 启用外键约束 + conn.execute("PRAGMA foreign_keys = ON") + _thread_local.connection = conn + return _thread_local.connection + + def _init_tables(self): + """初始化数据库表结构""" + conn = self._get_connection() + cursor = conn.cursor() + + # 创建会话表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + user_id TEXT DEFAULT 'default', + title TEXT DEFAULT '新对话', + created_at INTEGER, + updated_at INTEGER, + pinned INTEGER DEFAULT 0, + archived INTEGER DEFAULT 0, + settings TEXT + ) + """) + + # 创建消息表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp INTEGER, + feedback TEXT, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + ) + """) + + # 创建索引 + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_conversation + ON messages(conversation_id) + """) + + # 检查并添加缺失的列(迁移旧数据库 - conversations 表) + cursor.execute("PRAGMA table_info(conversations)") + conv_columns = [col[1] for col in cursor.fetchall()] + + conv_migrations = [ + ('user_id', "TEXT DEFAULT 'default'"), + ('pinned', "INTEGER DEFAULT 0"), + ('archived', "INTEGER DEFAULT 0"), + ('settings', "TEXT"), + ] + + for col_name, col_def in conv_migrations: + if col_name not in conv_columns: + cursor.execute(f"ALTER TABLE conversations ADD COLUMN {col_name} {col_def}") + print(f"[数据库] conversations 表已添加 {col_name} 列") + + # 检查并添加缺失的列(迁移旧数据库 - messages 表) + cursor.execute("PRAGMA table_info(messages)") + msg_columns = [col[1] for col in cursor.fetchall()] + + msg_migrations = [ + ('timestamp', "INTEGER"), + ('feedback', "TEXT"), + ] + + for col_name, col_def in msg_migrations: + if col_name not in msg_columns: + cursor.execute(f"ALTER TABLE messages ADD COLUMN {col_name} {col_def}") + print(f"[数据库] messages 表已添加 {col_name} 列") + + # 创建 user_id 索引(在确保列存在后) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_conversations_user + ON conversations(user_id) + """) + + conn.commit() + + # ── 会话 CRUD ───────────────────────────────────────────────────── + + def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]: + """创建新会话""" + conn = self._get_connection() + cursor = conn.cursor() + + now = int(datetime.now(timezone.utc).timestamp() * 1000) + conv_id = data.get("id") or self._generate_id() + + cursor.execute( + """ + INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + conv_id, + data.get("user_id", "default"), + data.get("title", "新对话"), + data.get("createdAt", now), + now, + 1 if data.get("pinned") else 0, + 1 if data.get("archived") else 0, + json.dumps(data.get("settings")) if data.get("settings") else None, + ), + ) + + # 插入消息(如果有) + messages = data.get("messages", []) + for msg in messages: + self._insert_message(cursor, conv_id, msg) + + conn.commit() + return self.get_conversation(conv_id) + + def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]: + """获取单个会话(包含消息)""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + "SELECT * FROM conversations WHERE id = ?", (conversation_id,) + ) + row = cursor.fetchone() + + if not row: + return None + + return self._row_to_conversation(row, cursor) + + def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]: + """列出用户的所有会话""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + SELECT * FROM conversations + WHERE user_id = ? + ORDER BY updated_at DESC + """, + (user_id,), + ) + + conversations = [] + for row in cursor.fetchall(): + conv = self._row_to_conversation(row, cursor, include_messages=False) + conversations.append(conv) + + return conversations + + def update_conversation( + self, conversation_id: str, data: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """更新会话""" + conn = self._get_connection() + cursor = conn.cursor() + + # 检查会话是否存在 + cursor.execute( + "SELECT id FROM conversations WHERE id = ?", (conversation_id,) + ) + if not cursor.fetchone(): + return None + + now = int(datetime.now(timezone.utc).timestamp() * 1000) + + # 更新会话字段 + update_fields = ["updated_at = ?"] + update_values = [now] + + if "title" in data: + update_fields.append("title = ?") + update_values.append(data["title"]) + if "pinned" in data: + update_fields.append("pinned = ?") + update_values.append(1 if data["pinned"] else 0) + if "archived" in data: + update_fields.append("archived = ?") + update_values.append(1 if data["archived"] else 0) + if "settings" in data: + update_fields.append("settings = ?") + update_values.append(json.dumps(data["settings"])) + + update_values.append(conversation_id) + + cursor.execute( + f"UPDATE conversations SET {', '.join(update_fields)} WHERE id = ?", + update_values, + ) + + # 更新消息(如果提供了 messages 字段) + if "messages" in data: + # 删除旧消息 + cursor.execute( + "DELETE FROM messages WHERE conversation_id = ?", (conversation_id,) + ) + # 插入新消息 + for msg in data["messages"]: + self._insert_message(cursor, conversation_id, msg) + + conn.commit() + return self.get_conversation(conversation_id) + + def delete_conversation(self, conversation_id: str) -> bool: + """删除会话(级联删除消息)""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + "DELETE FROM conversations WHERE id = ?", (conversation_id,) + ) + + conn.commit() + return cursor.rowcount > 0 + + # ── 消息操作 ─────────────────────────────────────────────────────── + + def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]: + """添加消息到会话""" + conn = self._get_connection() + cursor = conn.cursor() + + msg = self._insert_message(cursor, conversation_id, message) + + # 更新会话的 updated_at + now = int(datetime.now(timezone.utc).timestamp() * 1000) + cursor.execute( + "UPDATE conversations SET updated_at = ? WHERE id = ?", + (now, conversation_id), + ) + + conn.commit() + return msg + + def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """更新消息""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,)) + if not cursor.fetchone(): + return None + + update_fields = [] + update_values = [] + + if "content" in data: + update_fields.append("content = ?") + update_values.append(json.dumps(data["content"])) + if "feedback" in data: + update_fields.append("feedback = ?") + update_values.append(json.dumps(data["feedback"])) + + if not update_fields: + return self._get_message_by_id(message_id, cursor) + + update_values.append(message_id) + cursor.execute( + f"UPDATE messages SET {', '.join(update_fields)} WHERE id = ?", + update_values, + ) + + conn.commit() + return self._get_message_by_id(message_id, cursor) + + # ── 内部方法 ─────────────────────────────────────────────────────── + + def _insert_message( + self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any] + ) -> Dict[str, Any]: + """插入消息(内部方法)""" + msg_id = message.get("id") or self._generate_id() + timestamp = message.get("timestamp") or int( + datetime.now(timezone.utc).timestamp() * 1000 + ) + + cursor.execute( + """ + INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + msg_id, + conversation_id, + message.get("role", "user"), + json.dumps(message.get("content", "")), + timestamp, + json.dumps(message.get("feedback")) if message.get("feedback") else None, + ), + ) + + return { + "id": msg_id, + "role": message.get("role", "user"), + "content": message.get("content", ""), + "timestamp": timestamp, + "feedback": message.get("feedback"), + } + + def _get_message_by_id( + self, message_id: str, cursor: sqlite3.Cursor + ) -> Optional[Dict[str, Any]]: + """根据 ID 获取消息""" + cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,)) + row = cursor.fetchone() + return self._row_to_message(row) if row else None + + def _row_to_conversation( + self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True + ) -> Dict[str, Any]: + """将数据库行转换为会话字典""" + conv = { + "id": row["id"], + "userId": row["user_id"], + "title": row["title"], + "createdAt": row["created_at"], + "updatedAt": row["updated_at"], + "pinned": bool(row["pinned"]), + "archived": bool(row["archived"]), + "settings": json.loads(row["settings"]) if row["settings"] else None, + } + + if include_messages: + cursor.execute( + "SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp", + (row["id"],), + ) + conv["messages"] = [ + self._row_to_message(msg_row) for msg_row in cursor.fetchall() + ] + + return conv + + def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]: + """将数据库行转换为消息字典""" + return { + "id": row["id"], + "role": row["role"], + "content": json.loads(row["content"]), + "timestamp": row["timestamp"], + "feedback": json.loads(row["feedback"]) if row["feedback"] else None, + } + + def _generate_id(self) -> str: + """生成唯一 ID""" + import uuid + return str(uuid.uuid4()) + + +def init_db(): + """初始化数据库(应用启动时调用)""" + global _db_instance + if _db_instance is None: + _db_instance = Database(DB_PATH) + print(f"[数据库] SQLite 初始化完成: {DB_PATH}") + + +def get_db() -> Database: + """获取数据库实例""" + global _db_instance + if _db_instance is None: + init_db() + return _db_instance \ No newline at end of file diff --git a/server/main.py b/server/main.py index f2513ee..39cf28f 100644 --- a/server/main.py +++ b/server/main.py @@ -54,12 +54,15 @@ init_db() load_dotenv() # ── 会话管理路由处理器 ──────────────────────────────────────────────── -from api.conversation_routes import (delete_conversation_handler, +from api.conversation_routes import (add_message_handler, + delete_conversation_handler, get_conversation_handler, get_conversations_handler, save_conversation_handler, serve_upload_handler, stop_generation_handler, + update_conversation_handler, + update_message_handler, upload_file_handler) # ── OpenAI 兼容网关初始化 ─────────────────────────────────────────────── @@ -170,8 +173,8 @@ async def get_models(): @app.get("/api/chat-ui/conversations") -async def get_conversations(): - return await get_conversations_handler() +async def get_conversations(user_id: str = "default"): + return await get_conversations_handler(user_id) @app.get("/api/chat-ui/conversations/{conversation_id}") @@ -189,6 +192,21 @@ async def delete_conversation(conversation_id: str): return await delete_conversation_handler(conversation_id) +@app.put("/api/chat-ui/conversations/{conversation_id}") +async def update_conversation(conversation_id: str, request: Request): + return await update_conversation_handler(conversation_id, await request.json()) + + +@app.post("/api/chat-ui/conversations/{conversation_id}/messages") +async def add_message(conversation_id: str, request: Request): + return await add_message_handler(conversation_id, await request.json()) + + +@app.put("/api/chat-ui/conversations/{conversation_id}/messages/{message_id}") +async def update_message(conversation_id: str, message_id: str, request: Request): + return await update_message_handler(conversation_id, message_id, await request.json()) + + @app.post("/api/chat-ui/upload") async def upload_file(file: UploadFile = File(...)): return await upload_file_handler(file=file) diff --git a/server/middleware/auth.py b/server/middleware/auth.py new file mode 100644 index 0000000..190bdda --- /dev/null +++ b/server/middleware/auth.py @@ -0,0 +1,45 @@ +""" +认证中间件 - 预留接口 + +当前返回默认用户,未来可集成 JWT、OAuth 等认证系统。 +""" + +from typing import Optional + + +def get_current_user_id(request) -> str: + """ + 从请求中获取当前用户 ID(预留) + + 当前返回默认用户 'default' + 未来可集成 JWT、OAuth 等 + + Args: + request: FastAPI Request 对象 + + Returns: + 用户 ID 字符串 + """ + # TODO: 实现 token 验证逻辑 + # 示例: + # auth_header = request.headers.get("Authorization") + # if auth_header and auth_header.startswith("Bearer "): + # token = auth_header[7:] + # user_id = verify_token(token) + # return user_id + + return "default" + + +def get_current_user(request) -> dict: + """ + 获取当前用户完整信息(预留) + + Returns: + 用户信息字典 + """ + return { + "id": get_current_user_id(request), + "name": None, + "email": None + } \ No newline at end of file diff --git a/server/utils/oss_uploader.py b/server/utils/oss_uploader.py index e47706f..e8e552f 100644 --- a/server/utils/oss_uploader.py +++ b/server/utils/oss_uploader.py @@ -57,11 +57,12 @@ def _get_client() -> oss.Client: return oss.Client(cfg) -def _generate_object_key(filename: str, prefix: str = "uploads") -> str: +def _generate_object_key(filename: str, prefix: str = "chat-ui") -> str: """ 根据文件名生成唯一的 OSS 对象 Key 格式: {prefix}/{日期}/{uuid}_{原始文件名} """ + # TODO: 需要按用户ID分目录 date_str = datetime.now().strftime("%Y%m%d") unique_id = uuid.uuid4().hex[:8] safe_name = Path(filename).name # 只取文件名,去掉路径 @@ -80,7 +81,7 @@ def _build_url(object_key: str) -> str: def upload_file( file_path: str, object_key: Optional[str] = None, - prefix: str = "uploads", + prefix: str = "chat-ui", ) -> dict: """ 上传本地文件到 OSS @@ -204,6 +205,99 @@ def upload_fileobj( ) +def delete_file(object_key: str) -> bool: + """ + 删除 OSS 上的单个文件 + + 参数: + object_key: OSS 对象路径(如 "uploads/20240301/abc123_file.jpg") + + 返回: + True 表示删除成功,False 表示失败 + """ + try: + client = _get_client() + result = client.delete_object( + oss.DeleteObjectRequest( + bucket=OSS_BUCKET_NAME, + key=object_key, + ) + ) + return result.status_code == 204 + except Exception as e: + print(f"[OSS] 删除文件失败: {object_key}, 错误: {e}") + return False + + +def delete_files(object_keys: list) -> dict: + """ + 批量删除 OSS 上的文件 + + 参数: + object_keys: OSS 对象路径列表 + + 返回: + { + "deleted": ["成功删除的 object_key 列表"], + "failed": ["删除失败的 object_key 列表"], + } + """ + deleted = [] + failed = [] + + for key in object_keys: + if delete_file(key): + deleted.append(key) + else: + failed.append(key) + + return {"deleted": deleted, "failed": failed} + + +def extract_object_key_from_url(url: str) -> Optional[str]: + """ + 从 OSS URL 中提取 object_key + + 参数: + url: OSS 文件的完整 URL + + 返回: + object_key 或 None(如果不是有效的 OSS URL) + """ + if not url: + return None + + # 支持两种 URL 格式: + # 1. 自定义域名: OSS_URL_PREFIX/object_key + # 2. 默认域名: https://bucket.endpoint/object_key + + try: + # 移除查询参数 + url_path = url.split("?")[0] + + if OSS_URL_PREFIX: + # 自定义域名格式 + prefix = OSS_URL_PREFIX.rstrip("/") + if url_path.startswith(prefix): + return url_path[len(prefix) + 1:] # +1 去掉开头的 / + + # 默认域名格式: https://bucket.endpoint/object_key + endpoint = OSS_ENDPOINT.replace("https://", "").replace("http://", "") + default_prefix = f"https://{OSS_BUCKET_NAME}.{endpoint}/" + + if url_path.startswith(default_prefix): + return url_path[len(default_prefix):] + + # 也尝试匹配 http 版本 + http_prefix = f"http://{OSS_BUCKET_NAME}.{endpoint}/" + if url_path.startswith(http_prefix): + return url_path[len(http_prefix):] + + return None + except Exception: + return None + + # ──────────────────────────────────────────────────────────────── # 命令行入口:python -m utils.oss_uploader --file <路径> # ──────────────────────────────────────────────────────────────── diff --git a/src/App.vue b/src/App.vue index c33840c..f0605b3 100644 --- a/src/App.vue +++ b/src/App.vue @@ -44,7 +44,8 @@ import ShortcutsModal from "@/components/modals/ShortcutsModal.vue"; import SettingsModal from "@/components/modals/SettingsModal.vue"; import ConversationSettingsModal from "@/components/modals/ConversationSettingsModal.vue"; import { Check, AlertCircle, Info } from "@/components/icons"; - +import { useAuthStore } from "./stores/auth"; +const authStore = useAuthStore(); // Stores const chatStore = useChatStore(); const settingsStore = useSettingsStore(); @@ -126,10 +127,13 @@ useKeyboard( // 初始化 onMounted(() => { - // 如果没有对话,创建一个 - if (chatStore.conversations.length === 0) { - chatStore.createConversation(); - } + authStore.init(); + console.log(authStore.token); + + // // 如果没有对话,创建一个 + // if (chatStore.conversations.length === 0) { + // chatStore.createConversation(); + // } }); // 暴露给全局使用 diff --git a/src/components/chat/ChatMain.vue b/src/components/chat/ChatMain.vue index 8d6f5f7..a48f475 100644 --- a/src/components/chat/ChatMain.vue +++ b/src/components/chat/ChatMain.vue @@ -52,6 +52,7 @@ import { ref, computed, watch, nextTick, onMounted } from "vue"; import { storeToRefs } from "pinia"; import { useChatStore } from "@/stores/chat"; import { useSettingsStore } from "@/stores/settings"; +import { useAuthStore } from "@/stores/auth"; import ChatHeader from "./ChatHeader.vue"; import MessageList from "./MessageList.vue"; import ChatInput from "@/components/input/ChatInput.vue"; @@ -65,6 +66,7 @@ defineEmits<{ const chatStore = useChatStore(); const settingsStore = useSettingsStore(); +const authStore = useAuthStore(); const { currentConversation, isStreaming } = storeToRefs(chatStore); const { settings, sidebarCollapsed } = storeToRefs(settingsStore); @@ -164,6 +166,12 @@ async function handleSend( systemPrompt?: string; }, ) { + // 检查认证状态 + if (!authStore.isAuthenticated) { + window.$toast?.('请先登录', 'error'); + return; + } + console.log("handleSend", text, attachments, options); // 检查是否还有正在上传的附件 const uploadingAttachments = attachments.filter((a) => a.uploading); @@ -196,7 +204,7 @@ async function handleSend( // 如果没有当前对话,创建新对话 if (!currentConversation.value) { - chatStore.createConversation(); + await chatStore.createConversation(); } // 从当前会话中提取历史消息(用于上下文记忆),在添加新消息之前提取 @@ -212,7 +220,7 @@ async function handleSend( .map((m: any) => ({ role: m.role, content: m.content.text })); // 添加用户消息 - chatStore.addMessage(MessageRole.USER, { + await chatStore.addMessage(MessageRole.USER, { type: MessageType.TEXT, text, images: attachments.filter((a) => a.type === "image"), @@ -220,7 +228,7 @@ async function handleSend( }); // 添加 AI 消息占位符 - const aiMessage = chatStore.addMessage(MessageRole.ASSISTANT, { + const aiMessage = await chatStore.addMessage(MessageRole.ASSISTANT, { type: MessageType.TEXT, text: "", }); @@ -337,6 +345,12 @@ function handleStop() { // 重试 async function handleRetry(messageId: string) { + // 检查认证状态 + if (!authStore.isAuthenticated) { + window.$toast?.('请先登录', 'error'); + return; + } + const message = messages.value.find((m: any) => m.id === messageId); if (!message || message.role !== MessageRole.ASSISTANT) return; diff --git a/src/components/input/ChatInput.vue b/src/components/input/ChatInput.vue index 9418a52..cac180f 100644 --- a/src/components/input/ChatInput.vue +++ b/src/components/input/ChatInput.vue @@ -1,734 +1,740 @@ -