diff --git a/.gitignore b/.gitignore index ba00c0d..5b9e276 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,13 @@ uploads .venv __pycache__ .claude +<<<<<<< HEAD *.db +======= +.trae +.agent +.agents +>>>>>>> feat/database # Editor directories and files .vscode/* @@ -28,6 +34,7 @@ __pycache__ *.njsproj *.sln *.sw? +<<<<<<< HEAD # Skills .skills @@ -35,3 +42,9 @@ __pycache__ .agents .trae skills-lock.json +======= +*.db +server/data/*.db + +skills-lock.json +>>>>>>> feat/database diff --git a/server/adapters/base.py b/server/adapters/base.py index a7ea570..13f2b6b 100644 --- a/server/adapters/base.py +++ b/server/adapters/base.py @@ -30,7 +30,7 @@ class ModelInfo: "maxTokens": self.max_tokens, "provider": self.provider, "supports_thinking": self.supports_thinking, - "supports_web_Search": self.supports_web_search, + "supports_web_search": self.supports_web_search, "supports_vision": self.supports_vision, "supports_files": self.supports_files, } diff --git a/server/adapters/glm_adapter.py b/server/adapters/glm_adapter.py index 59cc659..d39b8b9 100644 --- a/server/adapters/glm_adapter.py +++ b/server/adapters/glm_adapter.py @@ -138,6 +138,11 @@ class GLMAdapter(BaseAdapter): logger.info( f"[GLM] 深度思考已启用: extra_kwargs['thinking'] = {extra_kwargs['thinking']}" ) + else: + extra_kwargs["thinking"] = {"type": "disabled"} + logger.info( + f"[GLM] 深度思考已禁用: extra_kwargs['thinking'] = {extra_kwargs['thinking']}" + ) if extra_kwargs: logger.info( diff --git a/server/api/conversation_routes.py b/server/api/conversation_routes.py index 60ae8ef..5319831 100644 --- a/server/api/conversation_routes.py +++ b/server/api/conversation_routes.py @@ -28,10 +28,10 @@ upload_dir.mkdir(exist_ok=True) # ── 会话管理 ───────────────────────────────────────────────────── -async def get_conversations_handler(): +async def get_conversations_handler(user_id: str = "default"): """获取所有对话处理器""" db = get_db() - return db.list_conversations() + return db.list_conversations(user_id) async def get_conversation_handler(conversation_id: str): @@ -65,8 +65,38 @@ async def save_conversation_handler(data: dict): async def delete_conversation_handler(conversation_id: str): - """删除对话处理器""" + """删除对话处理器(同时删除关联的 OSS 文件)""" db = get_db() + + # 先获取会话数据,提取 OSS 文件 URL + conversation = db.get_conversation(conversation_id) + if not conversation: + raise HTTPException(status_code=404, detail="对话不存在") + + # 提取所有 OSS 文件 URL + oss_urls = _extract_oss_urls_from_conversation(conversation) + + # 删除 OSS 文件 + if oss_urls: + try: + from utils.oss_uploader import delete_files, extract_object_key_from_url + + object_keys = [] + for url in oss_urls: + key = extract_object_key_from_url(url) + if key: + object_keys.append(key) + + if object_keys: + result = delete_files(object_keys) + log_info(f"[删除会话] OSS 文件清理结果: 删除 {len(result['deleted'])} 个, 失败 {len(result['failed'])} 个") + if result['failed']: + log_error(f"[删除会话] OSS 文件删除失败: {result['failed']}") + except Exception as e: + log_error(f"[删除会话] OSS 文件删除异常: {e}") + # 继续删除会话,即使 OSS 删除失败 + + # 删除数据库记录 success = db.delete_conversation(conversation_id) if success: return {"success": True, "message": "删除成功"} @@ -74,6 +104,85 @@ async def delete_conversation_handler(conversation_id: str): raise HTTPException(status_code=404, detail="对话不存在") +def _extract_oss_urls_from_conversation(conversation: dict) -> list: + """ + 从会话消息中提取所有 OSS 文件 URL + + 消息结构: + - content.images: 图片附件列表 + - content.files: 文件附件列表 + 每个附件包含 url 字段 + """ + urls = [] + messages = conversation.get("messages", []) + + for message in messages: + content = message.get("content") + if not content: + continue + + # content 可能是字符串(需要解析)或已解析的字典 + if isinstance(content, str): + try: + content = json.loads(content) + except json.JSONDecodeError: + continue + + # 提取图片附件 + images = content.get("images", []) + for img in images: + url = img.get("url") + if url and url not in urls: + urls.append(url) + + # 提取文件附件 + files = content.get("files", []) + for f in files: + url = f.get("url") + if url and url not in urls: + urls.append(url) + + return urls + + +async def update_conversation_handler(conversation_id: str, data: dict): + """部分更新对话处理器""" + db = get_db() + result = db.update_conversation(conversation_id, data) + if result: + return result + else: + raise HTTPException(status_code=404, detail="对话不存在") + + +# ── 消息管理 ───────────────────────────────────────────────────── + + +async def add_message_handler(conversation_id: str, message: dict): + """添加消息到对话处理器""" + db = get_db() + # 检查对话是否存在 + existing = db.get_conversation(conversation_id) + if not existing: + raise HTTPException(status_code=404, detail="对话不存在") + return db.add_message(conversation_id, message) + + +async def update_message_handler(conversation_id: str, message_id: str, data: dict): + """更新消息处理器""" + db = get_db() + # 检查对话是否存在 + existing = db.get_conversation(conversation_id) + if not existing: + raise HTTPException(status_code=404, detail="对话不存在") + + result = db.update_message(message_id, data) + if result: + return result + else: + raise HTTPException(status_code=404, detail="消息不存在") + + # ── 文件上传 ───────────────────────────────────────────────────── diff --git a/server/database/__init__.py b/server/database/__init__.py index db014e9..9aa8dbe 100644 --- a/server/database/__init__.py +++ b/server/database/__init__.py @@ -1,91 +1,3 @@ -""" -数据库模块 +from .db import Database, get_db, init_db -提供 SQLite 数据库连接和会话管理功能。 -""" - -import os -import sqlite3 -from pathlib import Path -from contextlib import contextmanager -from typing import Optional - -# 默认数据库路径 -DEFAULT_DB_PATH = Path(__file__).parent.parent / "data" / "chat.db" - - -def init_db(db_path: Optional[str] = None): - """ - 初始化数据库 - 创建必要的表结构 - """ - if db_path is None: - db_path = os.getenv("DB_PATH", str(DEFAULT_DB_PATH)) - - # 确保数据目录存在 - Path(db_path).parent.mkdir(parents=True, exist_ok=True) - - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - # 创建会话表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS conversations ( - id TEXT PRIMARY KEY, - title TEXT, - model TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - - # 创建消息表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - role TEXT, - content TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - ) - """) - - # 创建文件表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS files ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - filename TEXT, - file_path TEXT, - file_type TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - ) - """) - - conn.commit() - conn.close() - - print(f"[数据库] 初始化完成: {db_path}") - - -@contextmanager -def get_db(db_path: Optional[str] = None): - """ - 获取数据库连接的上下文管理器 - - 用法: - with get_db() as db: - cursor = db.execute("SELECT * FROM conversations") - rows = cursor.fetchall() - """ - if db_path is None: - db_path = os.getenv("DB_PATH", str(DEFAULT_DB_PATH)) - - conn = sqlite3.connect(db_path) - conn.row_factory = sqlite3.Row - try: - yield conn - finally: - conn.close() +__all__ = ["Database", "get_db", "init_db"] \ No newline at end of file diff --git a/server/database/db.py b/server/database/db.py new file mode 100644 index 0000000..7915ae1 --- /dev/null +++ b/server/database/db.py @@ -0,0 +1,401 @@ +""" +SQLite 数据库模块 - 会话持久化存储 + +提供会话和消息的 CRUD 操作,支持多用户(预留 user_id 字段)。 +""" + +import json +import sqlite3 +import threading +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional + +# 数据库文件路径 +DB_PATH = Path(__file__).parent.parent / "data" / "chat.db" + +# 线程本地存储,确保每个线程使用独立的连接 +_thread_local = threading.local() + +# 全局数据库实例 +_db_instance: Optional["Database"] = None + + +class Database: + """SQLite 数据库管理类""" + + def __init__(self, db_path: Path): + self.db_path = db_path + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_tables() + + def _get_connection(self) -> sqlite3.Connection: + """获取当前线程的数据库连接""" + if not hasattr(_thread_local, "connection"): + conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + conn.row_factory = sqlite3.Row + # 启用外键约束 + conn.execute("PRAGMA foreign_keys = ON") + _thread_local.connection = conn + return _thread_local.connection + + def _init_tables(self): + """初始化数据库表结构""" + conn = self._get_connection() + cursor = conn.cursor() + + # 创建会话表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + user_id TEXT DEFAULT 'default', + title TEXT DEFAULT '新对话', + created_at INTEGER, + updated_at INTEGER, + pinned INTEGER DEFAULT 0, + archived INTEGER DEFAULT 0, + settings TEXT + ) + """) + + # 创建消息表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp INTEGER, + feedback TEXT, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + ) + """) + + # 创建索引 + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_conversation + ON messages(conversation_id) + """) + + # 检查并添加缺失的列(迁移旧数据库 - conversations 表) + cursor.execute("PRAGMA table_info(conversations)") + conv_columns = [col[1] for col in cursor.fetchall()] + + conv_migrations = [ + ('user_id', "TEXT DEFAULT 'default'"), + ('pinned', "INTEGER DEFAULT 0"), + ('archived', "INTEGER DEFAULT 0"), + ('settings', "TEXT"), + ] + + for col_name, col_def in conv_migrations: + if col_name not in conv_columns: + cursor.execute(f"ALTER TABLE conversations ADD COLUMN {col_name} {col_def}") + print(f"[数据库] conversations 表已添加 {col_name} 列") + + # 检查并添加缺失的列(迁移旧数据库 - messages 表) + cursor.execute("PRAGMA table_info(messages)") + msg_columns = [col[1] for col in cursor.fetchall()] + + msg_migrations = [ + ('timestamp', "INTEGER"), + ('feedback', "TEXT"), + ] + + for col_name, col_def in msg_migrations: + if col_name not in msg_columns: + cursor.execute(f"ALTER TABLE messages ADD COLUMN {col_name} {col_def}") + print(f"[数据库] messages 表已添加 {col_name} 列") + + # 创建 user_id 索引(在确保列存在后) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_conversations_user + ON conversations(user_id) + """) + + conn.commit() + + # ── 会话 CRUD ───────────────────────────────────────────────────── + + def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]: + """创建新会话""" + conn = self._get_connection() + cursor = conn.cursor() + + now = int(datetime.now(timezone.utc).timestamp() * 1000) + conv_id = data.get("id") or self._generate_id() + + cursor.execute( + """ + INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + conv_id, + data.get("user_id", "default"), + data.get("title", "新对话"), + data.get("createdAt", now), + now, + 1 if data.get("pinned") else 0, + 1 if data.get("archived") else 0, + json.dumps(data.get("settings")) if data.get("settings") else None, + ), + ) + + # 插入消息(如果有) + messages = data.get("messages", []) + for msg in messages: + self._insert_message(cursor, conv_id, msg) + + conn.commit() + return self.get_conversation(conv_id) + + def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]: + """获取单个会话(包含消息)""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + "SELECT * FROM conversations WHERE id = ?", (conversation_id,) + ) + row = cursor.fetchone() + + if not row: + return None + + return self._row_to_conversation(row, cursor) + + def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]: + """列出用户的所有会话""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + SELECT * FROM conversations + WHERE user_id = ? + ORDER BY updated_at DESC + """, + (user_id,), + ) + + conversations = [] + for row in cursor.fetchall(): + conv = self._row_to_conversation(row, cursor, include_messages=False) + conversations.append(conv) + + return conversations + + def update_conversation( + self, conversation_id: str, data: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """更新会话""" + conn = self._get_connection() + cursor = conn.cursor() + + # 检查会话是否存在 + cursor.execute( + "SELECT id FROM conversations WHERE id = ?", (conversation_id,) + ) + if not cursor.fetchone(): + return None + + now = int(datetime.now(timezone.utc).timestamp() * 1000) + + # 更新会话字段 + update_fields = ["updated_at = ?"] + update_values = [now] + + if "title" in data: + update_fields.append("title = ?") + update_values.append(data["title"]) + if "pinned" in data: + update_fields.append("pinned = ?") + update_values.append(1 if data["pinned"] else 0) + if "archived" in data: + update_fields.append("archived = ?") + update_values.append(1 if data["archived"] else 0) + if "settings" in data: + update_fields.append("settings = ?") + update_values.append(json.dumps(data["settings"])) + + update_values.append(conversation_id) + + cursor.execute( + f"UPDATE conversations SET {', '.join(update_fields)} WHERE id = ?", + update_values, + ) + + # 更新消息(如果提供了 messages 字段) + if "messages" in data: + # 删除旧消息 + cursor.execute( + "DELETE FROM messages WHERE conversation_id = ?", (conversation_id,) + ) + # 插入新消息 + for msg in data["messages"]: + self._insert_message(cursor, conversation_id, msg) + + conn.commit() + return self.get_conversation(conversation_id) + + def delete_conversation(self, conversation_id: str) -> bool: + """删除会话(级联删除消息)""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + "DELETE FROM conversations WHERE id = ?", (conversation_id,) + ) + + conn.commit() + return cursor.rowcount > 0 + + # ── 消息操作 ─────────────────────────────────────────────────────── + + def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]: + """添加消息到会话""" + conn = self._get_connection() + cursor = conn.cursor() + + msg = self._insert_message(cursor, conversation_id, message) + + # 更新会话的 updated_at + now = int(datetime.now(timezone.utc).timestamp() * 1000) + cursor.execute( + "UPDATE conversations SET updated_at = ? WHERE id = ?", + (now, conversation_id), + ) + + conn.commit() + return msg + + def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """更新消息""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,)) + if not cursor.fetchone(): + return None + + update_fields = [] + update_values = [] + + if "content" in data: + update_fields.append("content = ?") + update_values.append(json.dumps(data["content"])) + if "feedback" in data: + update_fields.append("feedback = ?") + update_values.append(json.dumps(data["feedback"])) + + if not update_fields: + return self._get_message_by_id(message_id, cursor) + + update_values.append(message_id) + cursor.execute( + f"UPDATE messages SET {', '.join(update_fields)} WHERE id = ?", + update_values, + ) + + conn.commit() + return self._get_message_by_id(message_id, cursor) + + # ── 内部方法 ─────────────────────────────────────────────────────── + + def _insert_message( + self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any] + ) -> Dict[str, Any]: + """插入消息(内部方法)""" + msg_id = message.get("id") or self._generate_id() + timestamp = message.get("timestamp") or int( + datetime.now(timezone.utc).timestamp() * 1000 + ) + + cursor.execute( + """ + INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + msg_id, + conversation_id, + message.get("role", "user"), + json.dumps(message.get("content", "")), + timestamp, + json.dumps(message.get("feedback")) if message.get("feedback") else None, + ), + ) + + return { + "id": msg_id, + "role": message.get("role", "user"), + "content": message.get("content", ""), + "timestamp": timestamp, + "feedback": message.get("feedback"), + } + + def _get_message_by_id( + self, message_id: str, cursor: sqlite3.Cursor + ) -> Optional[Dict[str, Any]]: + """根据 ID 获取消息""" + cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,)) + row = cursor.fetchone() + return self._row_to_message(row) if row else None + + def _row_to_conversation( + self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True + ) -> Dict[str, Any]: + """将数据库行转换为会话字典""" + conv = { + "id": row["id"], + "userId": row["user_id"], + "title": row["title"], + "createdAt": row["created_at"], + "updatedAt": row["updated_at"], + "pinned": bool(row["pinned"]), + "archived": bool(row["archived"]), + "settings": json.loads(row["settings"]) if row["settings"] else None, + } + + if include_messages: + cursor.execute( + "SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp", + (row["id"],), + ) + conv["messages"] = [ + self._row_to_message(msg_row) for msg_row in cursor.fetchall() + ] + + return conv + + def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]: + """将数据库行转换为消息字典""" + return { + "id": row["id"], + "role": row["role"], + "content": json.loads(row["content"]), + "timestamp": row["timestamp"], + "feedback": json.loads(row["feedback"]) if row["feedback"] else None, + } + + def _generate_id(self) -> str: + """生成唯一 ID""" + import uuid + return str(uuid.uuid4()) + + +def init_db(): + """初始化数据库(应用启动时调用)""" + global _db_instance + if _db_instance is None: + _db_instance = Database(DB_PATH) + print(f"[数据库] SQLite 初始化完成: {DB_PATH}") + + +def get_db() -> Database: + """获取数据库实例""" + global _db_instance + if _db_instance is None: + init_db() + return _db_instance \ No newline at end of file diff --git a/server/main.py b/server/main.py index f2513ee..39cf28f 100644 --- a/server/main.py +++ b/server/main.py @@ -54,12 +54,15 @@ init_db() load_dotenv() # ── 会话管理路由处理器 ──────────────────────────────────────────────── -from api.conversation_routes import (delete_conversation_handler, +from api.conversation_routes import (add_message_handler, + delete_conversation_handler, get_conversation_handler, get_conversations_handler, save_conversation_handler, serve_upload_handler, stop_generation_handler, + update_conversation_handler, + update_message_handler, upload_file_handler) # ── OpenAI 兼容网关初始化 ─────────────────────────────────────────────── @@ -170,8 +173,8 @@ async def get_models(): @app.get("/api/chat-ui/conversations") -async def get_conversations(): - return await get_conversations_handler() +async def get_conversations(user_id: str = "default"): + return await get_conversations_handler(user_id) @app.get("/api/chat-ui/conversations/{conversation_id}") @@ -189,6 +192,21 @@ async def delete_conversation(conversation_id: str): return await delete_conversation_handler(conversation_id) +@app.put("/api/chat-ui/conversations/{conversation_id}") +async def update_conversation(conversation_id: str, request: Request): + return await update_conversation_handler(conversation_id, await request.json()) + + +@app.post("/api/chat-ui/conversations/{conversation_id}/messages") +async def add_message(conversation_id: str, request: Request): + return await add_message_handler(conversation_id, await request.json()) + + +@app.put("/api/chat-ui/conversations/{conversation_id}/messages/{message_id}") +async def update_message(conversation_id: str, message_id: str, request: Request): + return await update_message_handler(conversation_id, message_id, await request.json()) + + @app.post("/api/chat-ui/upload") async def upload_file(file: UploadFile = File(...)): return await upload_file_handler(file=file) diff --git a/server/middleware/auth.py b/server/middleware/auth.py new file mode 100644 index 0000000..190bdda --- /dev/null +++ b/server/middleware/auth.py @@ -0,0 +1,45 @@ +""" +认证中间件 - 预留接口 + +当前返回默认用户,未来可集成 JWT、OAuth 等认证系统。 +""" + +from typing import Optional + + +def get_current_user_id(request) -> str: + """ + 从请求中获取当前用户 ID(预留) + + 当前返回默认用户 'default' + 未来可集成 JWT、OAuth 等 + + Args: + request: FastAPI Request 对象 + + Returns: + 用户 ID 字符串 + """ + # TODO: 实现 token 验证逻辑 + # 示例: + # auth_header = request.headers.get("Authorization") + # if auth_header and auth_header.startswith("Bearer "): + # token = auth_header[7:] + # user_id = verify_token(token) + # return user_id + + return "default" + + +def get_current_user(request) -> dict: + """ + 获取当前用户完整信息(预留) + + Returns: + 用户信息字典 + """ + return { + "id": get_current_user_id(request), + "name": None, + "email": None + } \ No newline at end of file diff --git a/server/utils/oss_uploader.py b/server/utils/oss_uploader.py index e47706f..e8e552f 100644 --- a/server/utils/oss_uploader.py +++ b/server/utils/oss_uploader.py @@ -57,11 +57,12 @@ def _get_client() -> oss.Client: return oss.Client(cfg) -def _generate_object_key(filename: str, prefix: str = "uploads") -> str: +def _generate_object_key(filename: str, prefix: str = "chat-ui") -> str: """ 根据文件名生成唯一的 OSS 对象 Key 格式: {prefix}/{日期}/{uuid}_{原始文件名} """ + # TODO: 需要按用户ID分目录 date_str = datetime.now().strftime("%Y%m%d") unique_id = uuid.uuid4().hex[:8] safe_name = Path(filename).name # 只取文件名,去掉路径 @@ -80,7 +81,7 @@ def _build_url(object_key: str) -> str: def upload_file( file_path: str, object_key: Optional[str] = None, - prefix: str = "uploads", + prefix: str = "chat-ui", ) -> dict: """ 上传本地文件到 OSS @@ -204,6 +205,99 @@ def upload_fileobj( ) +def delete_file(object_key: str) -> bool: + """ + 删除 OSS 上的单个文件 + + 参数: + object_key: OSS 对象路径(如 "uploads/20240301/abc123_file.jpg") + + 返回: + True 表示删除成功,False 表示失败 + """ + try: + client = _get_client() + result = client.delete_object( + oss.DeleteObjectRequest( + bucket=OSS_BUCKET_NAME, + key=object_key, + ) + ) + return result.status_code == 204 + except Exception as e: + print(f"[OSS] 删除文件失败: {object_key}, 错误: {e}") + return False + + +def delete_files(object_keys: list) -> dict: + """ + 批量删除 OSS 上的文件 + + 参数: + object_keys: OSS 对象路径列表 + + 返回: + { + "deleted": ["成功删除的 object_key 列表"], + "failed": ["删除失败的 object_key 列表"], + } + """ + deleted = [] + failed = [] + + for key in object_keys: + if delete_file(key): + deleted.append(key) + else: + failed.append(key) + + return {"deleted": deleted, "failed": failed} + + +def extract_object_key_from_url(url: str) -> Optional[str]: + """ + 从 OSS URL 中提取 object_key + + 参数: + url: OSS 文件的完整 URL + + 返回: + object_key 或 None(如果不是有效的 OSS URL) + """ + if not url: + return None + + # 支持两种 URL 格式: + # 1. 自定义域名: OSS_URL_PREFIX/object_key + # 2. 默认域名: https://bucket.endpoint/object_key + + try: + # 移除查询参数 + url_path = url.split("?")[0] + + if OSS_URL_PREFIX: + # 自定义域名格式 + prefix = OSS_URL_PREFIX.rstrip("/") + if url_path.startswith(prefix): + return url_path[len(prefix) + 1:] # +1 去掉开头的 / + + # 默认域名格式: https://bucket.endpoint/object_key + endpoint = OSS_ENDPOINT.replace("https://", "").replace("http://", "") + default_prefix = f"https://{OSS_BUCKET_NAME}.{endpoint}/" + + if url_path.startswith(default_prefix): + return url_path[len(default_prefix):] + + # 也尝试匹配 http 版本 + http_prefix = f"http://{OSS_BUCKET_NAME}.{endpoint}/" + if url_path.startswith(http_prefix): + return url_path[len(http_prefix):] + + return None + except Exception: + return None + + # ──────────────────────────────────────────────────────────────── # 命令行入口:python -m utils.oss_uploader --file <路径> # ──────────────────────────────────────────────────────────────── diff --git a/src/App.vue b/src/App.vue index c33840c..f0605b3 100644 --- a/src/App.vue +++ b/src/App.vue @@ -44,7 +44,8 @@ import ShortcutsModal from "@/components/modals/ShortcutsModal.vue"; import SettingsModal from "@/components/modals/SettingsModal.vue"; import ConversationSettingsModal from "@/components/modals/ConversationSettingsModal.vue"; import { Check, AlertCircle, Info } from "@/components/icons"; - +import { useAuthStore } from "./stores/auth"; +const authStore = useAuthStore(); // Stores const chatStore = useChatStore(); const settingsStore = useSettingsStore(); @@ -126,10 +127,13 @@ useKeyboard( // 初始化 onMounted(() => { - // 如果没有对话,创建一个 - if (chatStore.conversations.length === 0) { - chatStore.createConversation(); - } + authStore.init(); + console.log(authStore.token); + + // // 如果没有对话,创建一个 + // if (chatStore.conversations.length === 0) { + // chatStore.createConversation(); + // } }); // 暴露给全局使用 diff --git a/src/components/chat/ChatMain.vue b/src/components/chat/ChatMain.vue index 8d6f5f7..a48f475 100644 --- a/src/components/chat/ChatMain.vue +++ b/src/components/chat/ChatMain.vue @@ -52,6 +52,7 @@ import { ref, computed, watch, nextTick, onMounted } from "vue"; import { storeToRefs } from "pinia"; import { useChatStore } from "@/stores/chat"; import { useSettingsStore } from "@/stores/settings"; +import { useAuthStore } from "@/stores/auth"; import ChatHeader from "./ChatHeader.vue"; import MessageList from "./MessageList.vue"; import ChatInput from "@/components/input/ChatInput.vue"; @@ -65,6 +66,7 @@ defineEmits<{ const chatStore = useChatStore(); const settingsStore = useSettingsStore(); +const authStore = useAuthStore(); const { currentConversation, isStreaming } = storeToRefs(chatStore); const { settings, sidebarCollapsed } = storeToRefs(settingsStore); @@ -164,6 +166,12 @@ async function handleSend( systemPrompt?: string; }, ) { + // 检查认证状态 + if (!authStore.isAuthenticated) { + window.$toast?.('请先登录', 'error'); + return; + } + console.log("handleSend", text, attachments, options); // 检查是否还有正在上传的附件 const uploadingAttachments = attachments.filter((a) => a.uploading); @@ -196,7 +204,7 @@ async function handleSend( // 如果没有当前对话,创建新对话 if (!currentConversation.value) { - chatStore.createConversation(); + await chatStore.createConversation(); } // 从当前会话中提取历史消息(用于上下文记忆),在添加新消息之前提取 @@ -212,7 +220,7 @@ async function handleSend( .map((m: any) => ({ role: m.role, content: m.content.text })); // 添加用户消息 - chatStore.addMessage(MessageRole.USER, { + await chatStore.addMessage(MessageRole.USER, { type: MessageType.TEXT, text, images: attachments.filter((a) => a.type === "image"), @@ -220,7 +228,7 @@ async function handleSend( }); // 添加 AI 消息占位符 - const aiMessage = chatStore.addMessage(MessageRole.ASSISTANT, { + const aiMessage = await chatStore.addMessage(MessageRole.ASSISTANT, { type: MessageType.TEXT, text: "", }); @@ -337,6 +345,12 @@ function handleStop() { // 重试 async function handleRetry(messageId: string) { + // 检查认证状态 + if (!authStore.isAuthenticated) { + window.$toast?.('请先登录', 'error'); + return; + } + const message = messages.value.find((m: any) => m.id === messageId); if (!message || message.role !== MessageRole.ASSISTANT) return; diff --git a/src/components/input/ChatInput.vue b/src/components/input/ChatInput.vue index 9418a52..cac180f 100644 --- a/src/components/input/ChatInput.vue +++ b/src/components/input/ChatInput.vue @@ -1,734 +1,740 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 深度思考 - - - - - - 深度搜索 - - - - - - 联网搜索 - - - - - - - - - - - - {{ charCount }} / {{ maxChars }} - - - {{ sendOnEnter ? "Enter 发送, Shift+Enter 换行" : "Ctrl+Enter 发送" }} - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 深度思考 + + + + + + 深度搜索 + + + + + + 联网搜索 + + + + + + + + + + + + {{ charCount }} / {{ maxChars }} + + + {{ sendOnEnter ? "Enter 发送, Shift+Enter 换行" : "Ctrl+Enter 发送" }} + + + + + + + + + diff --git a/src/components/modals/ConversationSettingsModal.vue b/src/components/modals/ConversationSettingsModal.vue index 5d7bd31..4b17db0 100644 --- a/src/components/modals/ConversationSettingsModal.vue +++ b/src/components/modals/ConversationSettingsModal.vue @@ -215,20 +215,20 @@ const modelSelect = ref(localStorage.getItem("modelSelect") || ""); const currentModelId = ref(settingsStore.getSelectedModelId()); onMounted(() => { - chatApi.getModels().then((res: any) => { - availableModels.value = res; - // 初始化模型显示名称 - const model = availableModels.value?.find( - (m: any) => m.id === currentModelId.value, - ); - if (model) { - modelSelect.value = model.name; - } else if (availableModels.value.length > 0) { - modelSelect.value = availableModels.value[0].name; - currentModelId.value = availableModels.value[0].id; - } - localStorage.setItem("modelSelect", modelSelect.value); - }); + // chatApi.getModels().then((res: any) => { + // availableModels.value = res; + // // 初始化模型显示名称 + // const model = availableModels.value?.find( + // (m: any) => m.id === currentModelId.value, + // ); + // if (model) { + // modelSelect.value = model.name; + // } else if (availableModels.value.length > 0) { + // modelSelect.value = availableModels.value[0].name; + // currentModelId.value = availableModels.value[0].id; + // } + // localStorage.setItem("modelSelect", modelSelect.value); + // }); }); // 本地设置副本 diff --git a/src/components/modals/SettingsModal.vue b/src/components/modals/SettingsModal.vue index 07e3868..a44c063 100644 --- a/src/components/modals/SettingsModal.vue +++ b/src/components/modals/SettingsModal.vue @@ -409,10 +409,10 @@ const availableModels: any = ref([]); const defaultModel: any = ref(localStorage.getItem("defaultModel")); onMounted(() => { - chatApi.getModels().then((res: any) => { - availableModels.value = res; - if (!defaultModel.value) defaultModel.value = res[0].name; - }); + // chatApi.getModels().then((res: any) => { + // availableModels.value = res; + // if (!defaultModel.value) defaultModel.value = res[0].name; + // }); }); const activeTab = ref("appearance"); diff --git a/src/services/api.ts b/src/services/api.ts index 2742c0a..09827c2 100644 --- a/src/services/api.ts +++ b/src/services/api.ts @@ -2,6 +2,8 @@ * Chat UI API 服务 * 所有端点都是固定的,后端需要实现这些端点 */ +import { getAuthHeaders } from './request'; + // API 端点定义(固定) const API_ENDPOINTS = { // 发送消息(流式) @@ -153,7 +155,7 @@ class ChatApi { { method: "POST", headers: { - "Content-Type": "application/json", + ...getAuthHeaders(), Accept: "text/event-stream", }, body: JSON.stringify(openAiRequest), @@ -244,9 +246,7 @@ class ChatApi { const response = await fetch(`${this.baseUrl}${API_ENDPOINTS.CHAT}`, { method: "POST", - headers: { - "Content-Type": "application/json", - }, + headers: getAuthHeaders(), body: JSON.stringify(requestBody), }); @@ -264,9 +264,7 @@ class ChatApi { async stopChat(messageId?: string) { await fetch(`${this.baseUrl}${API_ENDPOINTS.STOP}/${messageId}`, { method: "POST", - headers: { - "Content-Type": "application/json", - }, + headers: getAuthHeaders(), }); } @@ -326,8 +324,13 @@ class ChatApi { const formData = new FormData(); formData.append("file", file); + // 获取认证 headers,但不包含 Content-Type(让浏览器为 FormData 自动设置) + const authHeaders = getAuthHeaders(); + const { 'Content-Type': _, ...headersWithoutContentType } = authHeaders; + const response = await fetch(`${this.baseUrl}${API_ENDPOINTS.UPLOAD}`, { method: "POST", + headers: headersWithoutContentType, body: formData, }); diff --git a/src/services/authService.ts b/src/services/authService.ts new file mode 100644 index 0000000..dcc6518 --- /dev/null +++ b/src/services/authService.ts @@ -0,0 +1,64 @@ +/** + * 认证服务模块 - 预留接口 + * + * 当前返回默认用户,未来可集成 JWT、OAuth 等认证系统 + */ + +export interface AuthUser { + id: string; + name?: string; + email?: string; +} + +// Token 存储 key +const AUTH_TOKEN_KEY = 'auth_token'; + +export const authService = { + /** + * 获取当前用户(预留,目前返回默认用户) + */ + getCurrentUser(): AuthUser | null { + // TODO: 从 token 解析用户信息 + return { id: 'default' }; + }, + + /** + * 获取认证 token(预留) + */ + getToken(): string | null { + return localStorage.getItem(AUTH_TOKEN_KEY); + }, + + /** + * 设置 token(预留) + */ + setToken(token: string): void { + localStorage.setItem(AUTH_TOKEN_KEY, token); + }, + + /** + * 清除认证信息(预留) + */ + clearAuth(): void { + localStorage.removeItem(AUTH_TOKEN_KEY); + }, + + /** + * 检查是否已认证(预留,目前始终返回 true) + */ + isAuthenticated(): boolean { + // TODO: 实现真实的认证检查 + return true; + }, + + /** + * 获取 Authorization header 值 + */ + getAuthHeader(): Record { + const token = this.getToken(); + if (token) { + return { Authorization: `Bearer ${token}` }; + } + return {}; + } +}; \ No newline at end of file diff --git a/src/services/conversationApi.ts b/src/services/conversationApi.ts new file mode 100644 index 0000000..9ac3aca --- /dev/null +++ b/src/services/conversationApi.ts @@ -0,0 +1,302 @@ +/** + * 对话 API 服务层 + * + * 封装所有对话相关的后端 API 调用 + */ + +import { getAuthHeaders } from './request'; +import { useAuthStore } from '@/stores/auth'; +import type { Conversation, Message, MessageContent, ConversationSettings } from '@/types/chat'; + +// API 端点 +const API_BASE = '/api/chat-ui'; +const ENDPOINTS = { + CONVERSATIONS: `${API_BASE}/conversations`, + CONVERSATION: (id: string) => `${API_BASE}/conversations/${id}`, + CONVERSATION_MESSAGES: (id: string) => `${API_BASE}/conversations/${id}/messages`, +}; + +// 后端返回的对话数据格式 +interface BackendConversation { + id: string; + userId?: string; + title: string; + createdAt: number; + updatedAt: number; + pinned: boolean; + archived: boolean; + settings?: ConversationSettings; + messages?: BackendMessage[]; +} + +// 后端返回的消息数据格式 +interface BackendMessage { + id: string; + role: string; + content: MessageContent; + timestamp: number; + feedback?: { + liked?: boolean; + disliked?: boolean; + copied?: boolean; + }; +} + +/** + * 获取请求头(包含认证信息) + */ +function getHeaders(): Record { + return getAuthHeaders(); +} + +/** + * 将后端对话格式转换为前端格式 + */ +function transformConversation(backendConv: BackendConversation): Conversation { + return { + id: backendConv.id, + title: backendConv.title, + createdAt: backendConv.createdAt, + updatedAt: backendConv.updatedAt, + pinned: backendConv.pinned, + archived: backendConv.archived, + settings: backendConv.settings, + messages: (backendConv.messages || []).map(transformMessage), + }; +} + +/** + * 将后端消息格式转换为前端格式 + */ +function transformMessage(backendMsg: BackendMessage): Message { + return { + id: backendMsg.id, + role: backendMsg.role as 'user' | 'assistant' | 'system', + content: backendMsg.content, + timestamp: backendMsg.timestamp, + feedback: backendMsg.feedback, + isStreaming: false, + } as Message; +} + +/** + * 将前端对话格式转换为后端格式 + */ +function toBackendFormat(conversation: Partial, userId?: string): Record { + const data: Record = {}; + + if (conversation.id !== undefined) data.id = conversation.id; + if (userId !== undefined) data.user_id = userId; // 后端使用下划线命名 + if (conversation.title !== undefined) data.title = conversation.title; + if (conversation.createdAt !== undefined) data.createdAt = conversation.createdAt; + if (conversation.updatedAt !== undefined) data.updatedAt = conversation.updatedAt; + if (conversation.pinned !== undefined) data.pinned = conversation.pinned; + if (conversation.archived !== undefined) data.archived = conversation.archived; + if (conversation.settings !== undefined) data.settings = conversation.settings; + if (conversation.messages !== undefined) { + data.messages = conversation.messages.map(msg => ({ + id: msg.id, + role: msg.role, + content: msg.content, + timestamp: msg.timestamp, + feedback: msg.feedback, + })); + } + + return data; +} + +/** + * 对话 API 服务 + */ +export const conversationApi = { + /** + * 获取所有对话列表(不含消息内容) + */ + async fetchConversations(): Promise { + const authStore = useAuthStore(); + const userId = authStore.userId; + + // 构建 URL,添加 user_id 查询参数 + const url = userId + ? `${ENDPOINTS.CONVERSATIONS}?user_id=${encodeURIComponent(userId)}` + : ENDPOINTS.CONVERSATIONS; + + const response = await fetch(url, { + method: 'GET', + headers: getHeaders(), + }); + + if (!response.ok) { + throw new Error(`获取对话列表失败: HTTP ${response.status}`); + } + + const data: BackendConversation[] = await response.json(); + return data.map(transformConversation); + }, + + /** + * 获取单个对话(含消息内容) + */ + async fetchConversation(id: string): Promise { + const response = await fetch(ENDPOINTS.CONVERSATION(id), { + method: 'GET', + headers: getHeaders(), + }); + + if (!response.ok) { + if (response.status === 404) { + throw new Error('对话不存在'); + } + throw new Error(`获取对话失败: HTTP ${response.status}`); + } + + const data: BackendConversation = await response.json(); + return transformConversation(data); + }, + + /** + * 创建新对话 + */ + async createConversation(data: Partial): Promise { + const authStore = useAuthStore(); + const userId = authStore.userId || undefined; + + const response = await fetch(ENDPOINTS.CONVERSATIONS, { + method: 'POST', + headers: getHeaders(), + body: JSON.stringify(toBackendFormat(data, userId)), + }); + + if (!response.ok) { + throw new Error(`创建对话失败: HTTP ${response.status}`); + } + + const result: BackendConversation = await response.json(); + return transformConversation(result); + }, + + /** + * 更新对话(部分更新) + */ + async updateConversation(id: string, data: Partial): Promise { + const response = await fetch(ENDPOINTS.CONVERSATION(id), { + method: 'PUT', + headers: getHeaders(), + body: JSON.stringify(toBackendFormat(data)), + }); + + if (!response.ok) { + if (response.status === 404) { + throw new Error('对话不存在'); + } + throw new Error(`更新对话失败: HTTP ${response.status}`); + } + + const result: BackendConversation = await response.json(); + return transformConversation(result); + }, + + /** + * 保存对话(创建或更新) + */ + async saveConversation(conversation: Conversation): Promise { + const data = toBackendFormat(conversation); + + const response = await fetch(ENDPOINTS.CONVERSATIONS, { + method: 'POST', + headers: getHeaders(), + body: JSON.stringify(data), + }); + + if (!response.ok) { + throw new Error(`保存对话失败: HTTP ${response.status}`); + } + + const result: BackendConversation = await response.json(); + return transformConversation(result); + }, + + /** + * 删除对话 + */ + async deleteConversation(id: string): Promise { + const response = await fetch(ENDPOINTS.CONVERSATION(id), { + method: 'DELETE', + headers: getHeaders(), + }); + + if (!response.ok) { + if (response.status === 404) { + // 对话已不存在,视为成功 + return; + } + throw new Error(`删除对话失败: HTTP ${response.status}`); + } + }, + + /** + * 添加消息到对话(增量更新) + */ + async addMessage(conversationId: string, message: Partial): Promise { + const response = await fetch(ENDPOINTS.CONVERSATION_MESSAGES(conversationId), { + method: 'POST', + headers: getHeaders(), + body: JSON.stringify({ + id: message.id, + role: message.role, + content: message.content, + timestamp: message.timestamp, + feedback: message.feedback, + }), + }); + + if (!response.ok) { + throw new Error(`添加消息失败: HTTP ${response.status}`); + } + + const result: BackendMessage = await response.json(); + return transformMessage(result); + }, + + /** + * 更新消息 + */ + async updateMessage(conversationId: string, messageId: string, data: Partial): Promise { + const response = await fetch(`${ENDPOINTS.CONVERSATION(conversationId)}/messages/${messageId}`, { + method: 'PUT', + headers: getHeaders(), + body: JSON.stringify({ + content: data.content, + feedback: data.feedback, + }), + }); + + if (!response.ok) { + throw new Error(`更新消息失败: HTTP ${response.status}`); + } + + const result: BackendMessage = await response.json(); + return transformMessage(result); + }, + + /** + * 批量迁移对话数据 + */ + async migrateConversations(conversations: Conversation[]): Promise<{ success: number; failed: number }> { + let success = 0; + let failed = 0; + + for (const conversation of conversations) { + try { + await this.saveConversation(conversation); + success++; + } catch (e) { + console.error(`迁移对话失败 [${conversation.id}]:`, e); + failed++; + } + } + + return { success, failed }; + }, +}; \ No newline at end of file diff --git a/src/services/request.ts b/src/services/request.ts new file mode 100644 index 0000000..06658ee --- /dev/null +++ b/src/services/request.ts @@ -0,0 +1,100 @@ +/** + * 统一请求封装 + * + * 自动从 Pinia store 获取认证 token + */ +import { useAuthStore } from '@/stores/auth'; + +/** + * 获取认证 token(从 Pinia store) + */ +function getToken(): string | null { + const authStore = useAuthStore(); + return authStore.token; +} + +/** + * 统一的请求封装函数 + * + * @param url - 请求地址 + * @param options - fetch 选项 + * @returns Response 对象 + * + * @example + * // GET 请求 + * const response = await apiRequest('/api/users'); + * const data = await response.json(); + * + * // POST 请求 + * const response = await apiRequest('/api/users', { + * method: 'POST', + * body: JSON.stringify({ name: 'John' }) + * }); + */ +export async function apiRequest( + url: string, + options: RequestInit = {} +): Promise { + const token = getToken(); + + // 判断是否为 FormData,不设置 Content-Type 让浏览器自动处理 + const isFormData = options.body instanceof FormData; + + // 合并默认配置 + const config: RequestInit = { + ...options, + headers: { + ...(isFormData ? {} : { 'Content-Type': 'application/json' }), + ...(token ? { 'Authorization': `Bearer ${token}` } : {}), + ...options.headers, + }, + }; + + const response = await fetch(url, config); + + // 401 认证失败提示 + if (response.status === 401) { + window.$toast?.('认证失败,请重新登录', 'error'); + } + + return response; +} + +/** + * JSON 请求封装(自动解析响应) + * + * @param url - 请求地址 + * @param options - fetch 选项 + * @returns 解析后的 JSON 数据 + * + * @example + * const users = await apiRequestJson('/api/users'); + */ +export async function apiRequestJson( + url: string, + options: RequestInit = {} +): Promise { + const response = await apiRequest(url, options); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(errorText || `HTTP ${response.status}`); + } + + return response.json(); +} + +/** + * 获取带认证的 headers + * 用于需要手动构建 headers 的场景 + */ +export function getAuthHeaders(): Record { + const token = getToken(); + const headers: Record = { + 'Content-Type': 'application/json', + }; + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } + return headers; +} \ No newline at end of file diff --git a/src/stores/auth.ts b/src/stores/auth.ts new file mode 100644 index 0000000..07af4b4 --- /dev/null +++ b/src/stores/auth.ts @@ -0,0 +1,127 @@ +/** + * 用户认证状态管理 + */ +import { defineStore } from 'pinia'; +import { ref, computed } from 'vue'; +import type { UserInfo } from '@/types/chat'; + +// MARK: dev 默认 token(当 URL 无 token 参数时使用) +const DEV_DEFAULT_TOKEN = ''; + +// 认证接口返回格式 +interface AuthResponse { + code: string; + msg: string; + success: boolean; + timestamp: number; + data: UserInfo | null; +} + +// 认证接口 +const AUTH_CHECK_URL = '/api/auth/check/checkTokenRn'; + +export const useAuthStore = defineStore('auth', () => { + // 状态 + const token = ref(null); + const user = ref(null); + const isInitialized = ref(false); + + // 计算属性 + const isAuthenticated = computed(() => !!token.value); + const userId = computed(() => user.value?.username || null); // username 用于 OSS 路径和数据库 user_id + + /** + * 验证 token 并获取用户信息 + */ + async function checkToken(tokenToCheck: string): Promise { + try { + const response = await fetch(`${AUTH_CHECK_URL}/${tokenToCheck}`); + if (!response.ok) { + return null; + } + const data: AuthResponse = await response.json(); + + + if (data.success && data.data) { + window.$toast?.(`登录成功, 欢迎 ${data.data.nickname || data.data.username}`, 'success'); + + return data.data; + }else{ + window.$toast?.('[Auth] Token 验证失败:Token无效'); + } + return null; + } catch (error) { + + console.error('[Auth] Token 验证失败:', error); + return null; + } + } + + /** + * 初始化 - 从 URL 参数获取 token,验证后设置用户信息 + */ + async function init() { + const searchParams = new URLSearchParams(window.location.search); + const urlToken = searchParams.get('token'); + + // 获取 token:URL > localStorage > 默认值 + const tokenValue = urlToken + || localStorage.getItem('DEV_DEFAULT_TOKEN') + || DEV_DEFAULT_TOKEN; + + if (!tokenValue) { + isInitialized.value = true; + window.$toast?.('未登录,请先登录', 'error'); + return; + } + + // 验证 token + const userInfo = await checkToken(tokenValue); + + if (userInfo) { + + token.value = tokenValue; + user.value = userInfo; + } else { + // 验证失败,清空 + token.value = null; + user.value = null; + } + + isInitialized.value = true; + } + + /** + * 设置用户信息 + */ + function setUser(userInfo: UserInfo) { + user.value = userInfo; + } + + /** + * 获取认证 header + */ + function getAuthHeader(): Record { + if (token.value) { + return { Authorization: `Bearer ${token.value}` }; + } + return {}; + } + + // 初始化(不等待,让调用方通过 isInitialized 判断) + init(); + + return { + // 状态 + token, + user, + isAuthenticated, + userId, + isInitialized, + + // 方法 + setUser, + getAuthHeader, + init, + }; +}); \ No newline at end of file diff --git a/src/stores/chat.ts b/src/stores/chat.ts index 598d45a..c189be1 100644 --- a/src/stores/chat.ts +++ b/src/stores/chat.ts @@ -8,6 +8,7 @@ import type { } from "@/types/chat"; import { MessageRole } from "@/types/chat"; import { generateId, extractTitleFromMessage } from "@/utils/helpers"; +import { conversationApi } from "@/services/conversationApi"; export const useChatStore = defineStore("chat", () => { // 状态 @@ -15,6 +16,8 @@ export const useChatStore = defineStore("chat", () => { const currentConversationId = ref(null); const isStreaming = ref(false); const streamController = ref(null); + const isInitialized = ref(false); + const isLoading = ref(false); // 计算属性 const currentConversation = computed(() => { @@ -40,8 +43,43 @@ export const useChatStore = defineStore("chat", () => { return sortedConversations.value.filter((c) => !c.pinned && !c.archived); }); - // 方法 - function createConversation(): string { + // 初始化方法 - 从后端 API 加载数据 + async function initializeFromApi() { + if (isInitialized.value || isLoading.value) return; + + isLoading.value = true; + try { + const loadedConversations = await conversationApi.fetchConversations(); + conversations.value = loadedConversations; + + // 恢复当前对话 ID(从 localStorage 或选择第一个) + const storedId = localStorage.getItem("chat-current-id"); + if (storedId && conversations.value.find((c) => c.id === storedId)) { + currentConversationId.value = storedId; + } else if (conversations.value.length > 0) { + currentConversationId.value = conversations.value[0].id; + } + + isInitialized.value = true; + } catch (error) { + console.error("Failed to initialize from API:", error); + // 如果 API 失败,尝试从 localStorage 加载(降级处理) + loadFromStorage(); + } finally { + isLoading.value = false; + } + } + + // 保存当前对话 ID 到 localStorage + function saveCurrentId() { + localStorage.setItem( + "chat-current-id", + currentConversationId.value || "" + ); + } + + // 创建对话 + async function createConversation(): Promise { const newConversation: Conversation = { id: generateId(), title: "新对话", @@ -53,89 +91,171 @@ export const useChatStore = defineStore("chat", () => { settings: undefined, }; + // 乐观更新 conversations.value.unshift(newConversation); currentConversationId.value = newConversation.id; - saveToStorage(); + saveCurrentId(); + + // 异步保存到后端 + try { + const saved = await conversationApi.createConversation(newConversation); + // 更新本地数据(以防后端修改了某些字段) + const index = conversations.value.findIndex((c) => c.id === newConversation.id); + if (index !== -1) { + conversations.value[index] = saved; + } + } catch (error) { + console.error("Failed to create conversation:", error); + // 回滚乐观更新 + const index = conversations.value.findIndex((c) => c.id === newConversation.id); + if (index !== -1) { + conversations.value.splice(index, 1); + } + throw error; + } return newConversation.id; } - function deleteConversation(id: string) { + // 删除对话 + async function deleteConversation(id: string) { const index = conversations.value.findIndex((c) => c.id === id); - if (index !== -1) { - conversations.value.splice(index, 1); + if (index === -1) return; - if (currentConversationId.value === id) { - currentConversationId.value = conversations.value[0]?.id || null; - } + // 保存引用以便回滚 + const deletedConversation = conversations.value[index]; - saveToStorage(); + // 乐观更新 + conversations.value.splice(index, 1); + if (currentConversationId.value === id) { + currentConversationId.value = conversations.value[0]?.id || null; + saveCurrentId(); + } + + // 异步删除 + try { + await conversationApi.deleteConversation(id); + } catch (error) { + console.error("Failed to delete conversation:", error); + // 回滚 + conversations.value.splice(index, 0, deletedConversation); + throw error; } } - function selectConversation(id: string) { + // 选择对话 + async function selectConversation(id: string) { currentConversationId.value = id; + saveCurrentId(); + + // 如果对话没有加载消息,从后端加载 + const conversation = conversations.value.find((c) => c.id === id); + if (conversation && (!conversation.messages || conversation.messages.length === 0)) { + try { + const loaded = await conversationApi.fetchConversation(id); + const index = conversations.value.findIndex((c) => c.id === id); + if (index !== -1) { + conversations.value[index] = loaded; + } + } catch (error) { + console.error("Failed to load conversation:", error); + } + } } - function togglePinConversation(id: string) { + // 置顶对话 + async function togglePinConversation(id: string) { const conversation = conversations.value.find((c) => c.id === id); - if (conversation) { + if (!conversation) return; + + // 乐观更新 + conversation.pinned = !conversation.pinned; + + // 异步保存 + try { + await conversationApi.updateConversation(id, { pinned: conversation.pinned }); + } catch (error) { + console.error("Failed to toggle pin:", error); + // 回滚 conversation.pinned = !conversation.pinned; - saveToStorage(); + throw error; } } - function renameConversation(id: string, newTitle: string) { + // 重命名对话 + async function renameConversation(id: string, newTitle: string) { const conversation = conversations.value.find((c) => c.id === id); - if (conversation) { - conversation.title = newTitle; - conversation.updatedAt = Date.now(); - saveToStorage(); + if (!conversation) return; + + const oldTitle = conversation.title; + conversation.title = newTitle; + conversation.updatedAt = Date.now(); + + // 异步保存 + try { + await conversationApi.updateConversation(id, { title: newTitle }); + } catch (error) { + console.error("Failed to rename conversation:", error); + // 回滚 + conversation.title = oldTitle; + throw error; } } - function updateConversationSettings( + // 更新对话设置 + async function updateConversationSettings( id: string, - convSettings: ConversationSettings, + convSettings: ConversationSettings ) { const conversation = conversations.value.find((c) => c.id === id); - if (conversation) { - conversation.settings = { ...conversation.settings, ...convSettings }; - conversation.updatedAt = Date.now(); - saveToStorage(); + if (!conversation) return; + + const oldSettings = conversation.settings; + conversation.settings = { ...conversation.settings, ...convSettings }; + conversation.updatedAt = Date.now(); + + // 异步保存 + try { + await conversationApi.updateConversation(id, { settings: conversation.settings }); + } catch (error) { + console.error("Failed to update settings:", error); + // 回滚 + conversation.settings = oldSettings; + throw error; } } - function addMessage( + // 添加消息 + async function addMessage( role: MessageRole, content: MessageContent, - conversationId?: string, - ): Message { - const targetId = conversationId || currentConversationId.value; + conversationId?: string + ): Promise { + let targetId = conversationId || currentConversationId.value; if (!targetId) { - createConversation(); + await createConversation(); + targetId = currentConversationId.value; } - const conversation = conversations.value.find( - (c) => c.id === (targetId || currentConversationId.value), - ); - + const conversation = conversations.value.find((c) => c.id === targetId); if (!conversation) { throw new Error("Conversation not found"); } - const message: any = { + const message: Message = { id: generateId(), role, content, timestamp: Date.now(), isStreaming: false, - }; + } as Message; + // 乐观更新 conversation.messages.push(message); conversation.updatedAt = Date.now(); + // 如果是第一条用户消息,更新标题 if ( role === MessageRole.USER && conversation.messages.length === 1 && @@ -144,21 +264,64 @@ export const useChatStore = defineStore("chat", () => { conversation.title = extractTitleFromMessage(content.text); } - saveToStorage(); + // 异步保存(使用增量更新) + try { + // 确保 targetId 不为空 + if (targetId) { + // 发送消息到后端,不等待完成 + conversationApi.addMessage(targetId, message).catch((error) => { + console.error("Failed to save message:", error); + }); + + // 如果标题更新了,也保存标题 + if ( + role === MessageRole.USER && + conversation.messages.length === 1 + ) { + conversationApi.updateConversation(targetId, { title: conversation.title }).catch((error) => { + console.error("Failed to update title:", error); + }); + } + } + } catch (error) { + console.error("Failed to add message:", error); + } + return message; } - function updateMessage(messageId: string, updates: Partial) { + // 更新消息 + async function updateMessage(messageId: string, updates: Partial) { const conversation = currentConversation.value; - if (!conversation) return; + if (!conversation) { + console.warn("[updateMessage] No current conversation"); + return; + } const message = conversation.messages.find((m) => m.id === messageId); - if (message) { - Object.assign(message, updates); - saveToStorage(); + if (!message) { + console.warn("[updateMessage] Message not found:", messageId); + return; + } + + // 乐观更新 + Object.assign(message, updates); + + // 异步保存 + try { + console.log("[updateMessage] Saving to backend:", { + conversationId: conversation.id, + messageId, + content: updates.content, + }); + await conversationApi.updateMessage(conversation.id, messageId, updates); + console.log("[updateMessage] Save successful"); + } catch (error) { + console.error("Failed to update message:", error); } } + // 更新消息内容(流式更新时使用,不触发 API 调用) function updateMessageContent(messageId: string, text: string) { const conversation = currentConversation.value; if (!conversation) return; @@ -169,24 +332,49 @@ export const useChatStore = defineStore("chat", () => { } } - function setMessageFeedback( + // 保存整个对话(用于流式结束后) + async function saveConversation(conversationId: string) { + const conversation = conversations.value.find((c) => c.id === conversationId); + if (!conversation) return; + + try { + await conversationApi.updateConversation(conversationId, { + messages: conversation.messages, + updatedAt: Date.now() + }); + } catch (error) { + console.error("Failed to save conversation:", error); + } + } + + // 设置消息反馈 + async function setMessageFeedback( messageId: string, - feedback: "like" | "dislike" | null, + feedback: "like" | "dislike" | null ) { const conversation = currentConversation.value; if (!conversation) return; const message = conversation.messages.find((m) => m.id === messageId); - if (message) { - message.feedback = { - liked: feedback === "like", - disliked: feedback === "dislike", - copied: message.feedback?.copied, - }; - saveToStorage(); + if (!message) return; + + message.feedback = { + liked: feedback === "like", + disliked: feedback === "dislike", + copied: message.feedback?.copied, + }; + + // 异步保存 + try { + await conversationApi.updateMessage(conversation.id, messageId, { + feedback: message.feedback + }); + } catch (error) { + console.error("Failed to save feedback:", error); } } + // 设置消息已复制 function setMessageCopied(messageId: string) { const conversation = currentConversation.value; if (!conversation) return; @@ -200,11 +388,13 @@ export const useChatStore = defineStore("chat", () => { } } + // 开始流式输出 function startStreaming() { isStreaming.value = true; streamController.value = new AbortController(); } + // 停止流式输出 function stopStreaming() { isStreaming.value = false; if (streamController.value) { @@ -213,30 +403,23 @@ export const useChatStore = defineStore("chat", () => { } } - function clearConversation(id: string) { + // 清空对话消息 + async function clearConversation(id: string) { const conversation = conversations.value.find((c) => c.id === id); - if (conversation) { - conversation.messages = []; - conversation.updatedAt = Date.now(); - saveToStorage(); - } - } + if (!conversation) return; - function saveToStorage() { + conversation.messages = []; + conversation.updatedAt = Date.now(); + + // 异步保存 try { - localStorage.setItem( - "chat-conversations", - JSON.stringify(conversations.value), - ); - localStorage.setItem( - "chat-current-id", - currentConversationId.value || "", - ); - } catch (e) { - console.error("Failed to save to storage:", e); + await conversationApi.updateConversation(id, { messages: [] }); + } catch (error) { + console.error("Failed to clear conversation:", error); } } + // 降级:从 localStorage 加载(仅在 API 不可用时使用) function loadFromStorage() { try { const stored = localStorage.getItem("chat-conversations"); @@ -255,17 +438,40 @@ export const useChatStore = defineStore("chat", () => { } } - loadFromStorage(); + // 保存到 localStorage(降级模式使用) + function saveToStorage() { + try { + localStorage.setItem( + "chat-conversations", + JSON.stringify(conversations.value) + ); + localStorage.setItem( + "chat-current-id", + currentConversationId.value || "" + ); + } catch (e) { + console.error("Failed to save to storage:", e); + } + } + + // 初始化 + initializeFromApi(); return { + // 状态 conversations, currentConversationId, isStreaming, streamController, + isInitialized, + isLoading, + // 计算属性 currentConversation, sortedConversations, pinnedConversations, recentConversations, + // 方法 + initializeFromApi, createConversation, deleteConversation, selectConversation, @@ -275,11 +481,13 @@ export const useChatStore = defineStore("chat", () => { addMessage, updateMessage, updateMessageContent, + saveConversation, setMessageFeedback, setMessageCopied, startStreaming, stopStreaming, clearConversation, loadFromStorage, + saveToStorage, }; -}); +}); \ No newline at end of file diff --git a/src/stores/settings.ts b/src/stores/settings.ts index 54376b0..7192e93 100644 --- a/src/stores/settings.ts +++ b/src/stores/settings.ts @@ -16,7 +16,7 @@ export const useSettingsStore = defineStore("settings", () => { compactMode: false, // AI 默认设置 - defaultModel: "glm-4.6", + defaultModel: "glm-4.6v", defaultTemperature: 0.7, defaultMaxTokens: 4096, defaultSystemPrompt: "你是一个有帮助的 AI 助手。", diff --git a/src/types/chat.ts b/src/types/chat.ts index a217b8b..a34c850 100644 --- a/src/types/chat.ts +++ b/src/types/chat.ts @@ -146,3 +146,13 @@ export interface AIModel { provider: string; icon?: string; } + +// 用户信息 +export interface UserInfo { + id: string; + username?: string; + nickname?: string; + email?: string; + avatar?: string; + [key: string]: unknown; +} diff --git a/src/utils/migrateData.ts b/src/utils/migrateData.ts new file mode 100644 index 0000000..ef5c235 --- /dev/null +++ b/src/utils/migrateData.ts @@ -0,0 +1,140 @@ +/** + * 数据迁移工具 + * + * 将 localStorage 中的旧对话数据迁移到后端 SQLite + */ + +import { conversationApi } from '@/services/conversationApi'; +import type { Conversation } from '@/types/chat'; + +const OLD_CONVERSATIONS_KEY = 'chat-conversations'; +const MIGRATION_FLAG_KEY = 'chat-migration-completed'; + +export interface MigrationResult { + success: boolean; + total: number; + migrated: number; + failed: number; + message: string; +} + +/** + * 检查是否已完成迁移 + */ +export function isMigrationCompleted(): boolean { + return localStorage.getItem(MIGRATION_FLAG_KEY) === 'true'; +} + +/** + * 标记迁移已完成 + */ +function markMigrationCompleted() { + localStorage.setItem(MIGRATION_FLAG_KEY, 'true'); +} + +/** + * 从 localStorage 读取旧数据 + */ +function getOldConversations(): Conversation[] { + try { + const stored = localStorage.getItem(OLD_CONVERSATIONS_KEY); + if (stored) { + return JSON.parse(stored); + } + } catch (e) { + console.error('Failed to read old conversations:', e); + } + return []; +} + +/** + * 迁移单个对话 + */ +async function migrateConversation(conversation: Conversation): Promise { + try { + await conversationApi.saveConversation(conversation); + return true; + } catch (error) { + console.error(`Failed to migrate conversation ${conversation.id}:`, error); + return false; + } +} + +/** + * 执行迁移 + */ +export async function migrateData(): Promise { + // 检查是否已迁移 + if (isMigrationCompleted()) { + return { + success: true, + total: 0, + migrated: 0, + failed: 0, + message: '迁移已完成,无需重复执行', + }; + } + + // 读取旧数据 + const oldConversations = getOldConversations(); + + if (oldConversations.length === 0) { + markMigrationCompleted(); + return { + success: true, + total: 0, + migrated: 0, + failed: 0, + message: '没有需要迁移的数据', + }; + } + + // 迁移数据 + let migrated = 0; + let failed = 0; + + for (const conversation of oldConversations) { + const success = await migrateConversation(conversation); + if (success) { + migrated++; + } else { + failed++; + } + } + + // 迁移完成后清理 + if (migrated === oldConversations.length) { + // 全部成功,清理旧数据 + localStorage.removeItem(OLD_CONVERSATIONS_KEY); + markMigrationCompleted(); + } + + return { + success: failed === 0, + total: oldConversations.length, + migrated, + failed, + message: failed === 0 + ? `成功迁移 ${migrated} 条对话` + : `迁移完成:成功 ${migrated} 条,失败 ${failed} 条`, + }; +} + +/** + * 清理 localStorage 中的旧数据 + */ +export function cleanupOldData() { + localStorage.removeItem(OLD_CONVERSATIONS_KEY); + // 保留 chat-current-id,因为它仍在使用 +} + +/** + * 导出迁移状态 + */ +export function getMigrationStatus() { + return { + completed: isMigrationCompleted(), + hasOldData: localStorage.getItem(OLD_CONVERSATIONS_KEY) !== null, + oldDataCount: getOldConversations().length, + }; +} \ No newline at end of file diff --git a/tsconfig.tsbuildinfo b/tsconfig.tsbuildinfo index 185dfda..b217ae1 100644 --- a/tsconfig.tsbuildinfo +++ b/tsconfig.tsbuildinfo @@ -1 +1 @@ -{"root":["./src/main.ts","./src/components/icons/index.ts","./src/composables/useKeyboard.ts","./src/services/api.ts","./src/stores/chat.ts","./src/stores/settings.ts","./src/types/chat.ts","./src/utils/helpers.ts","./src/App.vue","./src/components/chat/ChatHeader.vue","./src/components/chat/ChatMain.vue","./src/components/chat/MessageList.vue","./src/components/chat/WelcomeScreen.vue","./src/components/input/AttachmentPreview.vue","./src/components/input/ChatInput.vue","./src/components/message/CodeBlock.vue","./src/components/message/MessageActions.vue","./src/components/message/MessageBubble.vue","./src/components/message/components/EChartsContainerNode.vue","./src/components/message/components/Loading.vue","./src/components/message/components/ThinkingNode.vue","./src/components/modals/ConversationSettingsModal.vue","./src/components/modals/SearchModal.vue","./src/components/modals/SettingsModal.vue","./src/components/modals/ShortcutsModal.vue","./src/components/sidebar/ChatSidebar.vue","./src/components/sidebar/ConversationItem.vue","./src/components/ui/FormSelect.vue","./src/components/ui/FormSlider.vue","./src/components/ui/FormSwitch.vue"],"errors":true,"version":"5.9.3"} \ No newline at end of file +{"root":["./src/main.ts","./src/components/icons/index.ts","./src/composables/useKeyboard.ts","./src/services/api.ts","./src/services/authService.ts","./src/services/conversationApi.ts","./src/services/request.ts","./src/stores/auth.ts","./src/stores/chat.ts","./src/stores/settings.ts","./src/types/chat.ts","./src/utils/helpers.ts","./src/utils/migrateData.ts","./src/App.vue","./src/components/chat/ChatHeader.vue","./src/components/chat/ChatMain.vue","./src/components/chat/MessageList.vue","./src/components/chat/WelcomeScreen.vue","./src/components/input/AttachmentPreview.vue","./src/components/input/ChatInput.vue","./src/components/message/CodeBlock.vue","./src/components/message/MessageActions.vue","./src/components/message/MessageBubble.vue","./src/components/message/components/EChartsContainerNode.vue","./src/components/message/components/Loading.vue","./src/components/message/components/ThinkingNode.vue","./src/components/modals/ConversationSettingsModal.vue","./src/components/modals/SearchModal.vue","./src/components/modals/SettingsModal.vue","./src/components/modals/ShortcutsModal.vue","./src/components/sidebar/ChatSidebar.vue","./src/components/sidebar/ConversationItem.vue","./src/components/ui/FormSelect.vue","./src/components/ui/FormSlider.vue","./src/components/ui/FormSwitch.vue"],"errors":true,"version":"5.9.3"} \ No newline at end of file diff --git a/vite.config.ts b/vite.config.ts index af95ead..ce319d6 100644 --- a/vite.config.ts +++ b/vite.config.ts @@ -21,6 +21,10 @@ export default defineConfig({ target: "http://localhost:8000", // Python服务器端口 changeOrigin: true, }, + "/api/auth": { + target: "https://sxwz.xueai.art", + changeOrigin: true, + }, }, }, build: {