ai-chat-ui/server/database/db.py

401 lines
14 KiB
Python

"""
SQLite database module - persistent storage for conversations.

Provides CRUD operations for conversations and messages, with multi-user
support (a ``user_id`` field is reserved for it).
"""
import json
import sqlite3
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
# Path to the database file
DB_PATH = Path(__file__).parent.parent / "data" / "chat.db"
# Thread-local storage, intended to give each thread its own connection
_thread_local = threading.local()
# Global database instance
_db_instance: Optional["Database"] = None
class Database:
    """SQLite-backed store for conversations and their messages.

    Each thread that uses an instance lazily opens its own connection to
    ``db_path``. Every write runs inside a transaction that is rolled back
    if any statement in it fails, so a failed call never leaves partial
    changes pending on the connection.
    """

    def __init__(self, db_path: Path):
        """Create the parent directory if needed and ensure the schema exists.

        Args:
            db_path: Location of the SQLite database file.
        """
        self.db_path = db_path
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        # The connection cache is per instance (and per thread) so that two
        # Database objects pointing at different files never share a
        # connection.
        self._local = threading.local()
        self._init_tables()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        conn = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            conn.row_factory = sqlite3.Row
            # Foreign keys are off by default in SQLite; ON DELETE CASCADE
            # on messages depends on this pragma.
            conn.execute("PRAGMA foreign_keys = ON")
            self._local.connection = conn
        return conn

    def close(self) -> None:
        """Close the calling thread's connection, if one is open."""
        conn = getattr(self._local, "connection", None)
        if conn is not None:
            conn.close()
            del self._local.connection

    def _init_tables(self):
        """Create tables and indexes, and add columns missing from old schemas."""
        conn = self._get_connection()
        cursor = conn.cursor()
        # Conversations table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS conversations (
                id TEXT PRIMARY KEY,
                user_id TEXT DEFAULT 'default',
                title TEXT DEFAULT '新对话',
                created_at INTEGER,
                updated_at INTEGER,
                pinned INTEGER DEFAULT 0,
                archived INTEGER DEFAULT 0,
                settings TEXT
            )
        """)
        # Messages table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS messages (
                id TEXT PRIMARY KEY,
                conversation_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                timestamp INTEGER,
                feedback TEXT,
                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
            )
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_conversation
            ON messages(conversation_id)
        """)
        # Migrate databases created by an older schema.
        self._add_missing_columns(
            cursor,
            "conversations",
            [
                ("user_id", "TEXT DEFAULT 'default'"),
                ("pinned", "INTEGER DEFAULT 0"),
                ("archived", "INTEGER DEFAULT 0"),
                ("settings", "TEXT"),
            ],
        )
        self._add_missing_columns(
            cursor,
            "messages",
            [
                ("timestamp", "INTEGER"),
                ("feedback", "TEXT"),
            ],
        )
        # The user_id index is created only after the column is guaranteed
        # to exist.
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_conversations_user
            ON conversations(user_id)
        """)
        conn.commit()

    @staticmethod
    def _add_missing_columns(
        cursor: sqlite3.Cursor, table: str, migrations: List[tuple]
    ) -> None:
        """Add each ``(name, definition)`` column that ``table`` lacks.

        ``table`` and ``migrations`` are hard-coded by the caller, never
        user input, so interpolating them into the SQL text is safe.
        """
        cursor.execute(f"PRAGMA table_info({table})")
        existing = {col[1] for col in cursor.fetchall()}
        for col_name, col_def in migrations:
            if col_name not in existing:
                cursor.execute(
                    f"ALTER TABLE {table} ADD COLUMN {col_name} {col_def}"
                )
                print(f"[数据库] {table} 表已添加 {col_name}")

    @staticmethod
    def _now_ms() -> int:
        """Return the current UTC time as integer milliseconds since the epoch."""
        return int(datetime.now(timezone.utc).timestamp() * 1000)

    # ── Conversation CRUD ─────────────────────────────────────────────
    def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a conversation and any messages supplied with it.

        Args:
            data: Conversation fields. Recognised keys are ``id``,
                ``user_id``, ``title``, ``createdAt``, ``pinned``,
                ``archived``, ``settings`` and ``messages``; all are
                optional.

        Returns:
            The stored conversation, including its messages.

        Raises:
            sqlite3.IntegrityError: If the conversation id or a message id
                already exists. Nothing is written in that case.
        """
        conn = self._get_connection()
        now = self._now_ms()
        conv_id = data.get("id") or self._generate_id()
        created_at = data.get("createdAt")
        if created_at is None:
            created_at = now
        settings = data.get("settings")
        # "with conn" commits on success and rolls back on error, so a
        # failing message insert cannot leave a half-created conversation
        # pending on this connection.
        with conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    conv_id,
                    data.get("user_id", "default"),
                    data.get("title", "新对话"),
                    created_at,
                    now,
                    1 if data.get("pinned") else 0,
                    1 if data.get("archived") else 0,
                    json.dumps(settings) if settings else None,
                ),
            )
            for msg in data.get("messages") or []:
                self._insert_message(cursor, conv_id, msg)
        return self.get_conversation(conv_id)

    def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
        """Return one conversation with its messages, or None if not found."""
        cursor = self._get_connection().cursor()
        cursor.execute(
            "SELECT * FROM conversations WHERE id = ?", (conversation_id,)
        )
        row = cursor.fetchone()
        if row is None:
            return None
        return self._row_to_conversation(row, cursor)

    def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]:
        """Return a user's conversations, most recently updated first.

        The returned dictionaries do not contain a ``messages`` key.
        """
        cursor = self._get_connection().cursor()
        cursor.execute(
            """
            SELECT * FROM conversations
            WHERE user_id = ?
            ORDER BY updated_at DESC
            """,
            (user_id,),
        )
        return [
            self._row_to_conversation(row, cursor, include_messages=False)
            for row in cursor.fetchall()
        ]

    def update_conversation(
        self, conversation_id: str, data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Update conversation fields and optionally replace its messages.

        Args:
            conversation_id: Id of the conversation to update.
            data: Any of ``title``, ``pinned``, ``archived``, ``settings``
                and ``messages``. A ``messages`` list replaces all stored
                messages; a missing or None ``messages`` leaves them as is.

        Returns:
            The updated conversation with messages, or None if it does not
            exist.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id FROM conversations WHERE id = ?", (conversation_id,)
        )
        if cursor.fetchone() is None:
            return None

        # Column names below are fixed literals; only values are bound.
        update_fields = ["updated_at = ?"]
        update_values: List[Any] = [self._now_ms()]
        if "title" in data:
            update_fields.append("title = ?")
            update_values.append(data["title"])
        if "pinned" in data:
            update_fields.append("pinned = ?")
            update_values.append(1 if data["pinned"] else 0)
        if "archived" in data:
            update_fields.append("archived = ?")
            update_values.append(1 if data["archived"] else 0)
        if "settings" in data:
            update_fields.append("settings = ?")
            update_values.append(json.dumps(data["settings"]))
        update_values.append(conversation_id)

        messages = data.get("messages")
        # One transaction: if re-inserting messages fails, the DELETE and
        # the field update are rolled back with it.
        with conn:
            cursor.execute(
                f"UPDATE conversations SET {', '.join(update_fields)} WHERE id = ?",
                update_values,
            )
            if messages is not None:
                cursor.execute(
                    "DELETE FROM messages WHERE conversation_id = ?",
                    (conversation_id,),
                )
                for msg in messages:
                    self._insert_message(cursor, conversation_id, msg)
        return self.get_conversation(conversation_id)

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation; its messages are removed by the cascade.

        Returns:
            True if a conversation row was deleted, False otherwise.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM conversations WHERE id = ?", (conversation_id,)
            )
        return cursor.rowcount > 0

    # ── Message operations ────────────────────────────────────────────
    def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]:
        """Append a message to a conversation and bump its ``updated_at``.

        Returns:
            The message as stored (with generated id/timestamp if missing).

        Raises:
            sqlite3.IntegrityError: If the conversation does not exist or
                the message id is already used.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            msg = self._insert_message(cursor, conversation_id, message)
            cursor.execute(
                "UPDATE conversations SET updated_at = ? WHERE id = ?",
                (self._now_ms(), conversation_id),
            )
        return msg

    def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a message's ``content`` and/or ``feedback``.

        Returns:
            The message after the update, or None if it does not exist.
            If ``data`` has neither key, the message is returned unchanged.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,))
        if cursor.fetchone() is None:
            return None

        update_fields = []
        update_values: List[Any] = []
        # Both columns hold JSON text, matching _insert_message.
        for key in ("content", "feedback"):
            if key in data:
                update_fields.append(f"{key} = ?")
                update_values.append(json.dumps(data[key]))
        if update_fields:
            update_values.append(message_id)
            with conn:
                cursor.execute(
                    f"UPDATE messages SET {', '.join(update_fields)} WHERE id = ?",
                    update_values,
                )
        return self._get_message_by_id(message_id, cursor)

    # ── Internal helpers ──────────────────────────────────────────────
    def _insert_message(
        self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Insert one message row using ``cursor``; the caller commits.

        A missing id or timestamp is generated. ``content`` is stored as
        JSON text; ``feedback`` is stored as JSON text only when truthy.

        Returns:
            The message dictionary with the id and timestamp actually used.
        """
        msg_id = message.get("id") or self._generate_id()
        timestamp = message.get("timestamp") or self._now_ms()
        role = message.get("role", "user")
        content = message.get("content", "")
        feedback = message.get("feedback")
        cursor.execute(
            """
            INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                msg_id,
                conversation_id,
                role,
                json.dumps(content),
                timestamp,
                json.dumps(feedback) if feedback else None,
            ),
        )
        return {
            "id": msg_id,
            "role": role,
            "content": content,
            "timestamp": timestamp,
            "feedback": feedback,
        }

    def _get_message_by_id(
        self, message_id: str, cursor: sqlite3.Cursor
    ) -> Optional[Dict[str, Any]]:
        """Fetch a single message by id, or None if it does not exist."""
        cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,))
        row = cursor.fetchone()
        return self._row_to_message(row) if row else None

    def _row_to_conversation(
        self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True
    ) -> Dict[str, Any]:
        """Convert a conversations row to a dictionary with camelCase keys.

        When ``include_messages`` is true, ``cursor`` is reused to load the
        conversation's messages into a ``messages`` list.
        """
        conv = {
            "id": row["id"],
            "userId": row["user_id"],
            "title": row["title"],
            "createdAt": row["created_at"],
            "updatedAt": row["updated_at"],
            "pinned": bool(row["pinned"]),
            "archived": bool(row["archived"]),
            "settings": json.loads(row["settings"]) if row["settings"] else None,
        }
        if include_messages:
            # rowid breaks ties so messages sharing a timestamp (e.g. a
            # bulk insert) come back in insertion order.
            cursor.execute(
                "SELECT * FROM messages WHERE conversation_id = ? "
                "ORDER BY timestamp, rowid",
                (row["id"],),
            )
            conv["messages"] = [
                self._row_to_message(msg_row) for msg_row in cursor.fetchall()
            ]
        return conv

    def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a messages row to a dictionary, decoding the JSON columns."""
        return {
            "id": row["id"],
            "role": row["role"],
            "content": json.loads(row["content"]),
            "timestamp": row["timestamp"],
            "feedback": json.loads(row["feedback"]) if row["feedback"] else None,
        }

    def _generate_id(self) -> str:
        """Return a new random UUID4 string."""
        import uuid

        return str(uuid.uuid4())
def init_db():
    """Create the module-wide Database instance once (call at application startup).

    Later calls are no-ops while an instance already exists.
    """
    global _db_instance
    if _db_instance is not None:
        return
    _db_instance = Database(DB_PATH)
    print(f"[数据库] SQLite 初始化完成: {DB_PATH}")
def get_db() -> Database:
    """Return the module-wide Database instance, initialising it on first use."""
    instance = _db_instance
    if instance is None:
        init_db()
        # init_db rebinds the module global, so read it again.
        instance = _db_instance
    return instance