"""
SQLite database module - persistent storage for chat conversations.

Provides CRUD operations for conversations and their messages. A ``user_id``
column is reserved on conversations for future multi-user support.
"""
import json
import sqlite3
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

# Location of the SQLite database file.
DB_PATH = Path(__file__).parent.parent / "data" / "chat.db"

# Thread-local storage so each thread uses its own connection(s). Connections
# are cached in a per-thread dict keyed by database path, so Database
# instances pointing at different files never share a connection.
_thread_local = threading.local()

# Process-wide Database instance, created lazily by init_db() / get_db().
_db_instance: Optional["Database"] = None


def _now_ms() -> int:
    """Return the current UTC time as integer milliseconds since the epoch."""
    return int(datetime.now(timezone.utc).timestamp() * 1000)


class Database:
    """SQLite-backed store for conversations and their messages.

    Each thread lazily opens its own ``sqlite3.Connection`` to ``db_path``.
    Every write operation runs inside ``with conn:`` so that an exception
    rolls the whole operation back instead of leaving a half-applied
    transaction open on the cached connection.
    """

    def __init__(self, db_path: Path):
        """Ensure the parent directory exists and create tables if missing.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_tables()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection to ``self.db_path``.

        The connection is opened on first use and cached per thread, keyed
        by the database path.
        """
        connections = getattr(_thread_local, "connections", None)
        if connections is None:
            connections = _thread_local.connections = {}
        key = str(self.db_path)
        conn = connections.get(key)
        if conn is None:
            conn = sqlite3.connect(key, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            # Foreign keys are off by default in SQLite; they are needed for
            # the ON DELETE CASCADE from conversations to messages.
            conn.execute("PRAGMA foreign_keys = ON")
            connections[key] = conn
        return conn

    def _init_tables(self):
        """Create the conversations/messages tables and their indexes if absent."""
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()

            # Conversations table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    id TEXT PRIMARY KEY,
                    user_id TEXT DEFAULT 'default',
                    title TEXT DEFAULT '新对话',
                    created_at INTEGER,
                    updated_at INTEGER,
                    pinned INTEGER DEFAULT 0,
                    archived INTEGER DEFAULT 0,
                    settings TEXT
                )
            """)

            # Messages table; rows are removed with their conversation.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id TEXT PRIMARY KEY,
                    conversation_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp INTEGER,
                    feedback TEXT,
                    FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
                )
            """)

            # Indexes
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_messages_conversation
                ON messages(conversation_id)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_conversations_user
                ON conversations(user_id)
            """)

    # ── Conversation CRUD ─────────────────────────────────────────────

    def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a new conversation, plus any embedded messages, atomically.

        Args:
            data: Conversation payload. Keys read: ``id`` (generated UUID if
                missing), ``user_id``, ``title``, ``createdAt``, ``pinned``,
                ``archived``, ``settings`` and ``messages`` (list of message
                dicts).

        Returns:
            The stored conversation, as returned by ``get_conversation``.

        Raises:
            sqlite3.IntegrityError: e.g. on a duplicate conversation or
                message id. Nothing is persisted in that case.
        """
        conn = self._get_connection()
        now = _now_ms()
        conv_id = data.get("id") or self._generate_id()

        with conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO conversations
                (id, user_id, title, created_at, updated_at, pinned, archived, settings)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    conv_id,
                    data.get("user_id", "default"),
                    data.get("title", "新对话"),
                    data.get("createdAt", now),
                    now,
                    1 if data.get("pinned") else 0,
                    1 if data.get("archived") else 0,
                    json.dumps(data.get("settings")) if data.get("settings") else None,
                ),
            )

            # Insert embedded messages, if any, in the same transaction.
            for msg in data.get("messages") or []:
                self._insert_message(cursor, conv_id, msg)

        return self.get_conversation(conv_id)

    def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
        """Return one conversation including its messages, or None if absent."""
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            "SELECT * FROM conversations WHERE id = ?", (conversation_id,)
        )
        row = cursor.fetchone()
        if not row:
            return None
        return self._row_to_conversation(row, cursor)

    def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]:
        """List a user's conversations (without messages), newest update first."""
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT * FROM conversations
            WHERE user_id = ?
            ORDER BY updated_at DESC
            """,
            (user_id,),
        )
        return [
            self._row_to_conversation(row, cursor, include_messages=False)
            for row in cursor.fetchall()
        ]

    def update_conversation(
        self, conversation_id: str, data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Update a conversation's fields and optionally replace its messages.

        Only keys present in ``data`` are changed (``title``, ``pinned``,
        ``archived``, ``settings``); ``updated_at`` is always bumped. If
        ``messages`` is present, all existing messages are deleted and
        replaced. This happens in one transaction, so a failed re-insert
        keeps the old messages.

        Returns:
            The updated conversation, or None if it does not exist.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()

            cursor.execute(
                "SELECT id FROM conversations WHERE id = ?", (conversation_id,)
            )
            if not cursor.fetchone():
                return None

            # Column names come from this fixed list only; values are bound.
            update_fields = ["updated_at = ?"]
            update_values: List[Any] = [_now_ms()]

            if "title" in data:
                update_fields.append("title = ?")
                update_values.append(data["title"])
            if "pinned" in data:
                update_fields.append("pinned = ?")
                update_values.append(1 if data["pinned"] else 0)
            if "archived" in data:
                update_fields.append("archived = ?")
                update_values.append(1 if data["archived"] else 0)
            if "settings" in data:
                update_fields.append("settings = ?")
                update_values.append(json.dumps(data["settings"]))

            update_values.append(conversation_id)
            cursor.execute(
                f"UPDATE conversations SET {', '.join(update_fields)} WHERE id = ?",
                update_values,
            )

            # Replace messages wholesale when a "messages" key is supplied.
            if "messages" in data:
                cursor.execute(
                    "DELETE FROM messages WHERE conversation_id = ?", (conversation_id,)
                )
                for msg in data["messages"]:
                    self._insert_message(cursor, conversation_id, msg)

        return self.get_conversation(conversation_id)

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation; its messages go via ON DELETE CASCADE.

        Returns:
            True if a conversation row was deleted, False otherwise.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.execute(
                "DELETE FROM conversations WHERE id = ?", (conversation_id,)
            )
        return cursor.rowcount > 0

    # ── Message operations ────────────────────────────────────────────

    def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]:
        """Append a message to a conversation and bump its ``updated_at``.

        Returns:
            The inserted message dict (see ``_insert_message``).

        Raises:
            sqlite3.IntegrityError: if the conversation does not exist
                (foreign keys are enabled) or the message id is a duplicate.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            msg = self._insert_message(cursor, conversation_id, message)
            cursor.execute(
                "UPDATE conversations SET updated_at = ? WHERE id = ?",
                (_now_ms(), conversation_id),
            )
        return msg

    def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a message's ``content`` and/or ``feedback`` (JSON-encoded).

        The parent conversation's ``updated_at`` is not touched.

        Returns:
            The message after the update, or None if it does not exist.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()

            cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,))
            if not cursor.fetchone():
                return None

            update_fields = []
            update_values = []
            if "content" in data:
                update_fields.append("content = ?")
                update_values.append(json.dumps(data["content"]))
            if "feedback" in data:
                update_fields.append("feedback = ?")
                update_values.append(json.dumps(data["feedback"]))

            if not update_fields:
                return self._get_message_by_id(message_id, cursor)

            update_values.append(message_id)
            cursor.execute(
                f"UPDATE messages SET {', '.join(update_fields)} WHERE id = ?",
                update_values,
            )

        return self._get_message_by_id(message_id, cursor)

    # ── Internal helpers ──────────────────────────────────────────────

    def _insert_message(
        self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Insert one message row using ``cursor``; the caller commits.

        ``content`` is always JSON-encoded. ``feedback`` is JSON-encoded only
        when truthy (stored as NULL otherwise). A missing id gets a UUID, and
        a missing or falsy timestamp gets the current time in ms.

        Returns:
            The message as a dict with its original (un-encoded) values.
        """
        msg_id = message.get("id") or self._generate_id()
        timestamp = message.get("timestamp") or _now_ms()
        cursor.execute(
            """
            INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                msg_id,
                conversation_id,
                message.get("role", "user"),
                json.dumps(message.get("content", "")),
                timestamp,
                json.dumps(message.get("feedback")) if message.get("feedback") else None,
            ),
        )
        return {
            "id": msg_id,
            "role": message.get("role", "user"),
            "content": message.get("content", ""),
            "timestamp": timestamp,
            "feedback": message.get("feedback"),
        }

    def _get_message_by_id(
        self, message_id: str, cursor: sqlite3.Cursor
    ) -> Optional[Dict[str, Any]]:
        """Fetch one message by id, or None if it does not exist."""
        cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,))
        row = cursor.fetchone()
        return self._row_to_message(row) if row else None

    def _row_to_conversation(
        self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True
    ) -> Dict[str, Any]:
        """Convert a conversations row into an API dict (camelCase keys).

        When ``include_messages`` is true, ``cursor`` is reused to load the
        conversation's messages. It is therefore consumed and must not be
        mid-iteration.
        """
        conv = {
            "id": row["id"],
            "userId": row["user_id"],
            "title": row["title"],
            "createdAt": row["created_at"],
            "updatedAt": row["updated_at"],
            "pinned": bool(row["pinned"]),
            "archived": bool(row["archived"]),
            "settings": json.loads(row["settings"]) if row["settings"] else None,
        }
        if include_messages:
            # rowid breaks timestamp ties so messages inserted in the same
            # millisecond keep their insertion order.
            cursor.execute(
                "SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp, rowid",
                (row["id"],),
            )
            conv["messages"] = [
                self._row_to_message(msg_row) for msg_row in cursor.fetchall()
            ]
        return conv

    def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a messages row into a dict, JSON-decoding content/feedback."""
        return {
            "id": row["id"],
            "role": row["role"],
            "content": json.loads(row["content"]),
            "timestamp": row["timestamp"],
            "feedback": json.loads(row["feedback"]) if row["feedback"] else None,
        }

    def _generate_id(self) -> str:
        """Return a new random UUID4 string."""
        return str(uuid.uuid4())


def init_db():
    """Create the global Database instance if needed (call at app startup)."""
    global _db_instance
    if _db_instance is None:
        _db_instance = Database(DB_PATH)
        print(f"[数据库] SQLite 初始化完成: {DB_PATH}")


def get_db() -> Database:
    """Return the global Database instance, initialising it on first use."""
    if _db_instance is None:
        init_db()
    return _db_instance