"""
SQLite database module - persistent storage for chat conversations.

Provides CRUD operations for conversations, their messages and shares.
Conversations carry a ``user_id`` column that is reserved for future
multi-user support (it currently defaults to ``'default'``).
"""

import json
import sqlite3
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

# Location of the database file.
DB_PATH = Path(__file__).parent.parent / "data" / "chat.db"

# Kept only for backward compatibility with code that may reference it.
# Connections are cached on a per-``Database``-instance thread-local
# (``Database._local``) so that two instances pointing at different files
# never end up sharing one connection.
_thread_local = threading.local()

# Process-wide database instance, created lazily by ``init_db`` / ``get_db``.
_db_instance: Optional["Database"] = None

# Default share lifetime when the caller gives no ``expiresAt``: 7 days in ms.
_DEFAULT_SHARE_TTL_MS = 7 * 24 * 60 * 60 * 1000  # == 604800000


def _now_ms() -> int:
    """Return the current UTC time as integer milliseconds since the epoch."""
    return int(datetime.now(timezone.utc).timestamp() * 1000)


class Database:
    """SQLite-backed store for conversations, messages and shares.

    Each thread gets its own ``sqlite3.Connection`` for this instance, opened
    lazily with ``row_factory = sqlite3.Row`` and foreign keys enabled (needed
    for ``ON DELETE CASCADE`` from conversations to messages). Every write runs
    inside ``with conn:`` so it is committed on success and rolled back on
    error.
    """

    def __init__(self, db_path: Union[Path, str]):
        """Open (creating if needed) the database at *db_path* and ensure the schema.

        Args:
            db_path: Path to the SQLite file; its parent directory is created.
        """
        self.db_path = Path(db_path)
        # Per-instance cache: a module-level thread-local would hand this
        # instance whatever connection another instance opened on the thread.
        self._local = threading.local()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_tables()

    def _get_connection(self) -> sqlite3.Connection:
        """Return this thread's connection to ``self.db_path``, opening it on first use."""
        conn = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            conn.row_factory = sqlite3.Row
            # SQLite ships with FK enforcement off; cascade deletes need it on.
            conn.execute("PRAGMA foreign_keys = ON")
            self._local.connection = conn
        return conn

    def _init_tables(self):
        """Create tables and indexes if missing, migrating older schemas in place."""
        conn = self._get_connection()
        cursor = conn.cursor()

        # Conversations table.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS conversations (
                id TEXT PRIMARY KEY,
                user_id TEXT DEFAULT 'default',
                title TEXT DEFAULT '新对话',
                created_at INTEGER,
                updated_at INTEGER,
                pinned INTEGER DEFAULT 0,
                archived INTEGER DEFAULT 0,
                settings TEXT
            )
        """)

        # Messages table; rows are deleted together with their conversation.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS messages (
                id TEXT PRIMARY KEY,
                conversation_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                timestamp INTEGER,
                feedback TEXT,
                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
            )
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_conversation
            ON messages(conversation_id)
        """)

        # Bring databases created by older versions up to the current columns.
        self._add_missing_columns(
            cursor,
            "conversations",
            [
                ("user_id", "TEXT DEFAULT 'default'"),
                ("pinned", "INTEGER DEFAULT 0"),
                ("archived", "INTEGER DEFAULT 0"),
                ("settings", "TEXT"),
            ],
        )
        self._add_missing_columns(
            cursor,
            "messages",
            [
                ("timestamp", "INTEGER"),
                ("feedback", "TEXT"),
            ],
        )

        # Only valid once the migration above guarantees user_id exists.
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_conversations_user
            ON conversations(user_id)
        """)

        # Shares table (snapshots of conversations behind a password).
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS shares (
                id TEXT PRIMARY KEY,
                conversation_ids TEXT NOT NULL,
                conversations TEXT NOT NULL,
                password_hash TEXT NOT NULL,
                created_at INTEGER,
                expires_at INTEGER,
                view_count INTEGER DEFAULT 0
            )
        """)

        # Supports cleanup_expired_shares' range delete.
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_shares_expires
            ON shares(expires_at)
        """)

        conn.commit()

    @staticmethod
    def _add_missing_columns(
        cursor: sqlite3.Cursor, table: str, columns: Sequence[Tuple[str, str]]
    ) -> None:
        """Add each ``(name, sql_definition)`` in *columns* that *table* lacks.

        *table* and *columns* are internal constants, so interpolating them
        into the SQL text is safe (SQLite cannot bind identifiers).
        """
        cursor.execute(f"PRAGMA table_info({table})")
        existing = {col[1] for col in cursor.fetchall()}
        for col_name, col_def in columns:
            if col_name not in existing:
                cursor.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_def}")
                print(f"[数据库] {table} 表已添加 {col_name} 列")

    # ── Conversation CRUD ─────────────────────────────────────────────

    def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a conversation (and any ``data["messages"]``) and return it.

        Args:
            data: Client-side dict; recognised keys are ``id``, ``user_id``,
                ``title``, ``createdAt``, ``pinned``, ``archived``,
                ``settings`` and ``messages``.

        Returns:
            The stored conversation including its messages.

        Raises:
            sqlite3.IntegrityError: e.g. on a duplicate conversation or message
                id; nothing from this call is persisted in that case.
        """
        conn = self._get_connection()
        now = _now_ms()
        conv_id = data.get("id") or self._generate_id()
        settings = data.get("settings")

        # ``with conn`` rolls back on error so a failing message insert cannot
        # leave a half-created conversation pending for a later commit.
        with conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    conv_id,
                    data.get("user_id", "default"),
                    data.get("title", "新对话"),
                    data.get("createdAt", now),
                    now,
                    1 if data.get("pinned") else 0,
                    1 if data.get("archived") else 0,
                    json.dumps(settings) if settings else None,
                ),
            )
            # ``or []`` also covers an explicit ``"messages": None``.
            for msg in data.get("messages") or []:
                self._insert_message(cursor, conv_id, msg)

        return self.get_conversation(conv_id)

    def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
        """Return one conversation with its messages, or ``None`` if absent."""
        cursor = self._get_connection().cursor()
        cursor.execute(
            "SELECT * FROM conversations WHERE id = ?", (conversation_id,)
        )
        row = cursor.fetchone()
        if row is None:
            return None
        return self._row_to_conversation(row, cursor)

    def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]:
        """Return *user_id*'s conversations (without messages), newest update first."""
        cursor = self._get_connection().cursor()
        cursor.execute(
            """
            SELECT * FROM conversations
            WHERE user_id = ?
            ORDER BY updated_at DESC
            """,
            (user_id,),
        )
        return [
            self._row_to_conversation(row, cursor, include_messages=False)
            for row in cursor.fetchall()
        ]

    def update_conversation(
        self, conversation_id: str, data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Apply a partial update to a conversation and return the stored result.

        Only keys present in *data* are changed (``title``, ``pinned``,
        ``archived``, ``settings``); ``updated_at`` is always refreshed. If
        ``messages`` is present, the conversation's messages are replaced
        wholesale by that list.

        Returns:
            The updated conversation, or ``None`` if it does not exist.
        """
        conn = self._get_connection()

        # Single transaction: if inserting the replacement messages fails, the
        # delete of the old ones is rolled back instead of lingering.
        with conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT id FROM conversations WHERE id = ?", (conversation_id,)
            )
            if cursor.fetchone() is None:
                return None

            assignments = ["updated_at = ?"]
            values: List[Any] = [_now_ms()]

            if "title" in data:
                assignments.append("title = ?")
                values.append(data["title"])
            if "pinned" in data:
                assignments.append("pinned = ?")
                values.append(1 if data["pinned"] else 0)
            if "archived" in data:
                assignments.append("archived = ?")
                values.append(1 if data["archived"] else 0)
            if "settings" in data:
                assignments.append("settings = ?")
                values.append(json.dumps(data["settings"]))

            values.append(conversation_id)
            # Column names come from the fixed list above, never from *data*.
            cursor.execute(
                f"UPDATE conversations SET {', '.join(assignments)} WHERE id = ?",
                values,
            )

            if "messages" in data:
                cursor.execute(
                    "DELETE FROM messages WHERE conversation_id = ?", (conversation_id,)
                )
                for msg in data["messages"]:
                    self._insert_message(cursor, conversation_id, msg)

        return self.get_conversation(conversation_id)

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation (messages cascade); return whether a row was removed."""
        conn = self._get_connection()
        with conn:
            cursor = conn.execute(
                "DELETE FROM conversations WHERE id = ?", (conversation_id,)
            )
        return cursor.rowcount > 0

    # ── Message operations ────────────────────────────────────────────

    def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]:
        """Append *message* to a conversation and bump the conversation's ``updated_at``.

        Raises:
            sqlite3.IntegrityError: if the conversation does not exist (foreign
                key) or the message id is already taken.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            msg = self._insert_message(cursor, conversation_id, message)
            cursor.execute(
                "UPDATE conversations SET updated_at = ? WHERE id = ?",
                (_now_ms(), conversation_id),
            )
        return msg

    def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a message's ``content`` and/or ``feedback``.

        Returns:
            The stored message after the update, or ``None`` if it does not exist.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,))
            if cursor.fetchone() is None:
                return None

            assignments: List[str] = []
            values: List[Any] = []
            # Both columns hold JSON text, matching _insert_message.
            if "content" in data:
                assignments.append("content = ?")
                values.append(json.dumps(data["content"]))
            if "feedback" in data:
                assignments.append("feedback = ?")
                values.append(json.dumps(data["feedback"]))

            if assignments:
                values.append(message_id)
                cursor.execute(
                    f"UPDATE messages SET {', '.join(assignments)} WHERE id = ?",
                    values,
                )

        return self._get_message_by_id(message_id, conn.cursor())

    # ── Internal helpers ──────────────────────────────────────────────

    def _insert_message(
        self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Insert one message via *cursor* (no commit) and return it as a dict.

        A missing id gets a fresh UUID; a missing/falsy timestamp becomes now
        (ms). ``content`` and a truthy ``feedback`` are stored as JSON text.
        """
        msg_id = message.get("id") or self._generate_id()
        timestamp = message.get("timestamp") or _now_ms()
        role = message.get("role", "user")
        content = message.get("content", "")
        feedback = message.get("feedback")

        cursor.execute(
            """
            INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                msg_id,
                conversation_id,
                role,
                json.dumps(content),
                timestamp,
                json.dumps(feedback) if feedback else None,
            ),
        )
        return {
            "id": msg_id,
            "role": role,
            "content": content,
            "timestamp": timestamp,
            "feedback": feedback,
        }

    def _get_message_by_id(
        self, message_id: str, cursor: sqlite3.Cursor
    ) -> Optional[Dict[str, Any]]:
        """Fetch one message by id using *cursor*; ``None`` if absent."""
        cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,))
        row = cursor.fetchone()
        return self._row_to_message(row) if row else None

    def _row_to_conversation(
        self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True
    ) -> Dict[str, Any]:
        """Convert a ``conversations`` row to the client-facing camelCase dict.

        When *include_messages* is true, *cursor* is reused to load the
        conversation's messages ordered by timestamp.
        """
        conv = {
            "id": row["id"],
            "userId": row["user_id"],
            "title": row["title"],
            "createdAt": row["created_at"],
            "updatedAt": row["updated_at"],
            "pinned": bool(row["pinned"]),
            "archived": bool(row["archived"]),
            "settings": json.loads(row["settings"]) if row["settings"] else None,
        }
        if include_messages:
            # rowid breaks timestamp ties so equal-timestamp messages keep
            # their insertion order instead of an unspecified one.
            cursor.execute(
                "SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp, rowid",
                (row["id"],),
            )
            conv["messages"] = [self._row_to_message(r) for r in cursor.fetchall()]
        return conv

    def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a ``messages`` row to a dict, decoding the JSON columns."""
        return {
            "id": row["id"],
            "role": row["role"],
            "content": json.loads(row["content"]),
            "timestamp": row["timestamp"],
            "feedback": json.loads(row["feedback"]) if row["feedback"] else None,
        }

    def _generate_id(self) -> str:
        """Return a new random UUID4 string."""
        return str(uuid.uuid4())

    # ── Share CRUD ────────────────────────────────────────────────────

    def create_share(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Store a share and return it.

        ``expiresAt`` defaults to 7 days from now when the key is absent.
        """
        conn = self._get_connection()
        now = _now_ms()
        share_id = data.get("id") or self._generate_share_id()
        with conn:
            conn.execute(
                """
                INSERT INTO shares (id, conversation_ids, conversations, password_hash, created_at, expires_at, view_count)
                VALUES (?, ?, ?, ?, ?, ?, 0)
                """,
                (
                    share_id,
                    json.dumps(data.get("conversationIds", [])),
                    json.dumps(data.get("conversations", [])),
                    data.get("passwordHash", ""),
                    now,
                    data.get("expiresAt", now + _DEFAULT_SHARE_TTL_MS),
                ),
            )
        return self.get_share(share_id)

    def get_share(self, share_id: str) -> Optional[Dict[str, Any]]:
        """Return a share by id, or ``None`` if it does not exist."""
        cursor = self._get_connection().cursor()
        cursor.execute("SELECT * FROM shares WHERE id = ?", (share_id,))
        row = cursor.fetchone()
        return self._row_to_share(row) if row else None

    def update_share_view_count(self, share_id: str) -> int:
        """Increment a share's view count and return the new value (0 if absent)."""
        conn = self._get_connection()
        with conn:
            conn.execute(
                "UPDATE shares SET view_count = view_count + 1 WHERE id = ?",
                (share_id,),
            )
        row = conn.execute(
            "SELECT view_count FROM shares WHERE id = ?", (share_id,)
        ).fetchone()
        return row["view_count"] if row else 0

    def delete_share(self, share_id: str) -> bool:
        """Delete a share; return whether a row was removed."""
        conn = self._get_connection()
        with conn:
            cursor = conn.execute("DELETE FROM shares WHERE id = ?", (share_id,))
        return cursor.rowcount > 0

    def cleanup_expired_shares(self) -> int:
        """Delete shares whose ``expires_at`` is in the past; return how many."""
        conn = self._get_connection()
        with conn:
            cursor = conn.execute(
                "DELETE FROM shares WHERE expires_at < ?", (_now_ms(),)
            )
        deleted_count = cursor.rowcount
        if deleted_count > 0:
            print(f"[数据库] 已清理 {deleted_count} 个过期分享")
        return deleted_count

    def _row_to_share(self, row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a ``shares`` row to the client-facing camelCase dict."""
        return {
            "id": row["id"],
            "conversationIds": json.loads(row["conversation_ids"]),
            "conversations": json.loads(row["conversations"]),
            "passwordHash": row["password_hash"],
            "createdAt": row["created_at"],
            "expiresAt": row["expires_at"],
            "viewCount": row["view_count"],
        }

    def _generate_share_id(self) -> str:
        """Return an 8-hex-character share id (first 8 chars of a UUID4)."""
        return uuid.uuid4().hex[:8]


def init_db():
    """Create the process-wide ``Database`` at ``DB_PATH`` if not created yet.

    Intended to be called at application startup.
    """
    global _db_instance
    if _db_instance is None:
        _db_instance = Database(DB_PATH)
        print(f"[数据库] SQLite 初始化完成: {DB_PATH}")


def get_db() -> Database:
    """Return the process-wide ``Database``, initialising it on first use."""
    if _db_instance is None:
        init_db()
    return _db_instance