From 4bff571f2b019efd8ef32feba0240f3158201d5a Mon Sep 17 00:00:00 2001 From: SuperManTouX <93423476+SuperManTouX@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:59:03 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9B=E5=BB=BASQLite=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/data/chat.db | Bin 0 -> 28672 bytes server/database/__init__.py | 3 + server/database/db.py | 369 ++++++++++++++++++++++++++++++++++++ 3 files changed, 372 insertions(+) create mode 100644 server/data/chat.db create mode 100644 server/database/__init__.py create mode 100644 server/database/db.py diff --git a/server/data/chat.db b/server/data/chat.db new file mode 100644 index 0000000000000000000000000000000000000000..fe2326518184651d1126edbce24c6dd991aecd00 GIT binary patch literal 28672 zcmeI%F>ljA6bJBgn9(X20wSgk2uyORBGMEbS86BAjW5L_aYB5Fx<*w8 zz5pAaA*>ASe2Y#XBo-C~BWI_!a&=1(qxC<@i5)+m-}^l)(#h?YO;>Q*9gO<6ppVEF z(KPauQbNeO+;Vb@Mn>)gTQsz2=P%#uv00bZa z0SG_<0uX>eY6Ye%*<4|#q)jWX=kQb4Ikh?iZ^B2uE!=_UTVtP(lwWdHgXt!tjaH2v z((5l+rwe1hMU@H4!IYNkY{ya4>{_nyq^M0F$4T_L@7pKbSF(wI%qCKZ&M4vW{6R7@ zizukPn2wUj8ClI0ibd_bE$rhSk1mM5Gne-@^=C~MPe;qjje?uG4mH`KNnaR^J>58> z&)LzvM9`Hi=2x`#P1-tWHk09_L60wsky8rpiDhBJl@$rQKa@>0*)B7Zck6QQ9NV4O z;Z(`f^?idic3VMVG{H#?R%ZrlRayH==VjA_M$}l7$tS3)x2t-MDKOg)GTB14sFCnV z@Za?08ql)GVh|o%Q75cQw6}J4SIZSjCGETu_H0qrR}NN`?kT7Kw}qEa^4z<|>iR*` zq$P)U?Qu_(l5D~iQ4{0nukSv5nSJ~=dw)%QXT)v69ZN2NZ|u}~=q!yHx}L|KXb$N= z&-7t3LwnSDA#rn5d&1px>^00Izz S00bZa0SG_<0uX=zA%6gj3|v(J literal 0 HcmV?d00001 diff --git a/server/database/__init__.py b/server/database/__init__.py new file mode 100644 index 0000000..0b859a3 --- /dev/null +++ b/server/database/__init__.py @@ -0,0 +1,3 @@ +from .db import Database, get_db, init_db + +__all__ = ["Database", "get_db", "init_db"] \ No newline at end of file diff --git a/server/database/db.py b/server/database/db.py new file mode 100644 index 0000000..3886e77 --- /dev/null +++ b/server/database/db.py @@ -0,0 +1,369 @@ +""" +SQLite 数据库模块 - 会话持久化存储 + +提供会话和消息的 CRUD 操作,支持多用户(预留 user_id 字段)。 +""" + +import json +import sqlite3 +import threading +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional + +# 数据库文件路径 +DB_PATH = Path(__file__).parent.parent / "data" / "chat.db" + +# 线程本地存储,确保每个线程使用独立的连接 +_thread_local = threading.local() + +# 全局数据库实例 +_db_instance: Optional["Database"] = None + + +class Database: + """SQLite 数据库管理类""" + + def __init__(self, db_path: Path): + self.db_path = db_path + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_tables() + + def _get_connection(self) -> sqlite3.Connection: + """获取当前线程的数据库连接""" + if not hasattr(_thread_local, "connection"): + conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + conn.row_factory = sqlite3.Row + # 启用外键约束 + conn.execute("PRAGMA foreign_keys = ON") + _thread_local.connection = conn + return _thread_local.connection + + def _init_tables(self): + """初始化数据库表结构""" + conn = self._get_connection() + cursor = conn.cursor() + + # 创建会话表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + user_id TEXT DEFAULT 'default', + title TEXT DEFAULT '新对话', + created_at INTEGER, + updated_at INTEGER, + pinned INTEGER DEFAULT 0, + archived INTEGER DEFAULT 0, + settings TEXT + ) + """) + + # 创建消息表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp INTEGER, + feedback TEXT, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + ) + """) + + # 创建索引 + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_conversation + ON messages(conversation_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_conversations_user + ON conversations(user_id) + """) + + conn.commit() + + # ── 会话 CRUD ───────────────────────────────────────────────────── + + def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]: + """创建新会话""" + conn = self._get_connection() + cursor = conn.cursor() + + now = int(datetime.now(timezone.utc).timestamp() * 1000) + conv_id = data.get("id") or self._generate_id() + + cursor.execute( + """ + INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + conv_id, + data.get("user_id", "default"), + data.get("title", "新对话"), + data.get("createdAt", now), + now, + 1 if data.get("pinned") else 0, + 1 if data.get("archived") else 0, + json.dumps(data.get("settings")) if data.get("settings") else None, + ), + ) + + # 插入消息(如果有) + messages = data.get("messages", []) + for msg in messages: + self._insert_message(cursor, conv_id, msg) + + conn.commit() + return self.get_conversation(conv_id) + + def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]: + """获取单个会话(包含消息)""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + "SELECT * FROM conversations WHERE id = ?", (conversation_id,) + ) + row = cursor.fetchone() + + if not row: + return None + + return self._row_to_conversation(row, cursor) + + def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]: + """列出用户的所有会话""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + SELECT * FROM conversations + WHERE user_id = ? + ORDER BY updated_at DESC + """, + (user_id,), + ) + + conversations = [] + for row in cursor.fetchall(): + conv = self._row_to_conversation(row, cursor, include_messages=False) + conversations.append(conv) + + return conversations + + def update_conversation( + self, conversation_id: str, data: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """更新会话""" + conn = self._get_connection() + cursor = conn.cursor() + + # 检查会话是否存在 + cursor.execute( + "SELECT id FROM conversations WHERE id = ?", (conversation_id,) + ) + if not cursor.fetchone(): + return None + + now = int(datetime.now(timezone.utc).timestamp() * 1000) + + # 更新会话字段 + update_fields = ["updated_at = ?"] + update_values = [now] + + if "title" in data: + update_fields.append("title = ?") + update_values.append(data["title"]) + if "pinned" in data: + update_fields.append("pinned = ?") + update_values.append(1 if data["pinned"] else 0) + if "archived" in data: + update_fields.append("archived = ?") + update_values.append(1 if data["archived"] else 0) + if "settings" in data: + update_fields.append("settings = ?") + update_values.append(json.dumps(data["settings"])) + + update_values.append(conversation_id) + + cursor.execute( + f"UPDATE conversations SET {', '.join(update_fields)} WHERE id = ?", + update_values, + ) + + # 更新消息(如果提供了 messages 字段) + if "messages" in data: + # 删除旧消息 + cursor.execute( + "DELETE FROM messages WHERE conversation_id = ?", (conversation_id,) + ) + # 插入新消息 + for msg in data["messages"]: + self._insert_message(cursor, conversation_id, msg) + + conn.commit() + return self.get_conversation(conversation_id) + + def delete_conversation(self, conversation_id: str) -> bool: + """删除会话(级联删除消息)""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + "DELETE FROM conversations WHERE id = ?", (conversation_id,) + ) + + conn.commit() + return cursor.rowcount > 0 + + # ── 消息操作 ─────────────────────────────────────────────────────── + + def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]: + """添加消息到会话""" + conn = self._get_connection() + cursor = conn.cursor() + + msg = self._insert_message(cursor, conversation_id, message) + + # 更新会话的 updated_at + now = int(datetime.now(timezone.utc).timestamp() * 1000) + cursor.execute( + "UPDATE conversations SET updated_at = ? WHERE id = ?", + (now, conversation_id), + ) + + conn.commit() + return msg + + def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """更新消息""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,)) + if not cursor.fetchone(): + return None + + update_fields = [] + update_values = [] + + if "content" in data: + update_fields.append("content = ?") + update_values.append(json.dumps(data["content"])) + if "feedback" in data: + update_fields.append("feedback = ?") + update_values.append(json.dumps(data["feedback"])) + + if not update_fields: + return self._get_message_by_id(message_id, cursor) + + update_values.append(message_id) + cursor.execute( + f"UPDATE messages SET {', '.join(update_fields)} WHERE id = ?", + update_values, + ) + + conn.commit() + return self._get_message_by_id(message_id, cursor) + + # ── 内部方法 ─────────────────────────────────────────────────────── + + def _insert_message( + self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any] + ) -> Dict[str, Any]: + """插入消息(内部方法)""" + msg_id = message.get("id") or self._generate_id() + timestamp = message.get("timestamp") or int( + datetime.now(timezone.utc).timestamp() * 1000 + ) + + cursor.execute( + """ + INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + msg_id, + conversation_id, + message.get("role", "user"), + json.dumps(message.get("content", "")), + timestamp, + json.dumps(message.get("feedback")) if message.get("feedback") else None, + ), + ) + + return { + "id": msg_id, + "role": message.get("role", "user"), + "content": message.get("content", ""), + "timestamp": timestamp, + "feedback": message.get("feedback"), + } + + def _get_message_by_id( + self, message_id: str, cursor: sqlite3.Cursor + ) -> Optional[Dict[str, Any]]: + """根据 ID 获取消息""" + cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,)) + row = cursor.fetchone() + return self._row_to_message(row) if row else None + + def _row_to_conversation( + self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True + ) -> Dict[str, Any]: + """将数据库行转换为会话字典""" + conv = { + "id": row["id"], + "userId": row["user_id"], + "title": row["title"], + "createdAt": row["created_at"], + "updatedAt": row["updated_at"], + "pinned": bool(row["pinned"]), + "archived": bool(row["archived"]), + "settings": json.loads(row["settings"]) if row["settings"] else None, + } + + if include_messages: + cursor.execute( + "SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp", + (row["id"],), + ) + conv["messages"] = [ + self._row_to_message(msg_row) for msg_row in cursor.fetchall() + ] + + return conv + + def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]: + """将数据库行转换为消息字典""" + return { + "id": row["id"], + "role": row["role"], + "content": json.loads(row["content"]), + "timestamp": row["timestamp"], + "feedback": json.loads(row["feedback"]) if row["feedback"] else None, + } + + def _generate_id(self) -> str: + """生成唯一 ID""" + import uuid + return str(uuid.uuid4()) + + +def init_db(): + """初始化数据库(应用启动时调用)""" + global _db_instance + if _db_instance is None: + _db_instance = Database(DB_PATH) + print(f"[数据库] SQLite 初始化完成: {DB_PATH}") + + +def get_db() -> Database: + """获取数据库实例""" + global _db_instance + if _db_instance is None: + init_db() + return _db_instance \ No newline at end of file