401 lines
14 KiB
Python
401 lines
14 KiB
Python
"""
SQLite 数据库模块 - 会话持久化存储

提供会话和消息的 CRUD 操作，支持多用户（预留 user_id 字段）。
"""
|
|
|
|
import json
|
|
import sqlite3
|
|
import threading
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
# Location of the SQLite file: "<grandparent of this module>/data/chat.db".
DB_PATH = Path(__file__).parent.parent.joinpath("data", "chat.db")

# Thread-local storage intended to give each thread its own connection.
_thread_local = threading.local()

# Lazily created module-wide Database instance (managed by init_db / get_db).
_db_instance: Optional["Database"] = None
|
|
|
|
|
|
class Database:
    """SQLite-backed store for conversations and their messages.

    Each instance keeps its own thread-local storage, so every thread that
    uses a given ``Database`` gets a dedicated ``sqlite3.Connection`` to that
    instance's ``db_path``. Conversations carry a ``user_id`` column
    (default ``'default'``) reserved for multi-user support.
    """

    def __init__(self, db_path: Path):
        """Open the database at *db_path* and make sure the schema is current.

        Args:
            db_path: Path of the SQLite file. Its parent directory is created
                if missing.
        """
        self.db_path = db_path
        # Per-instance (not module-wide) thread-local storage: a shared
        # threading.local() would hand the first instance's connection to any
        # other instance created in the same thread, even for another file.
        self._local = threading.local()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_tables()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection to ``self.db_path``, opening it lazily."""
        conn = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            conn.row_factory = sqlite3.Row
            # Foreign keys are off by default in SQLite and must be enabled per
            # connection; ON DELETE CASCADE on messages depends on them.
            conn.execute("PRAGMA foreign_keys = ON")
            self._local.connection = conn
        return conn

    def _init_tables(self):
        """Create tables and indexes if missing, and add columns absent from older databases."""
        conn = self._get_connection()
        cursor = conn.cursor()

        # Conversations table.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS conversations (
                id TEXT PRIMARY KEY,
                user_id TEXT DEFAULT 'default',
                title TEXT DEFAULT '新对话',
                created_at INTEGER,
                updated_at INTEGER,
                pinned INTEGER DEFAULT 0,
                archived INTEGER DEFAULT 0,
                settings TEXT
            )
        """)

        # Messages table; rows are removed with their conversation.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS messages (
                id TEXT PRIMARY KEY,
                conversation_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                timestamp INTEGER,
                feedback TEXT,
                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
            )
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_conversation
            ON messages(conversation_id)
        """)

        # Migrate databases created by older versions of the schema.
        self._add_missing_columns(
            cursor,
            "conversations",
            [
                ("user_id", "TEXT DEFAULT 'default'"),
                ("pinned", "INTEGER DEFAULT 0"),
                ("archived", "INTEGER DEFAULT 0"),
                ("settings", "TEXT"),
            ],
        )
        self._add_missing_columns(
            cursor,
            "messages",
            [
                ("timestamp", "INTEGER"),
                ("feedback", "TEXT"),
            ],
        )

        # Created only after the migration guarantees the user_id column exists.
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_conversations_user
            ON conversations(user_id)
        """)

        conn.commit()

    # ── Conversation CRUD ─────────────────────────────────────────────

    def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a conversation (plus any embedded messages) and return it.

        Args:
            data: Conversation dict. Keys read: ``id`` (generated when missing
                or empty), ``user_id``, ``title``, ``createdAt``, ``pinned``,
                ``archived``, ``settings`` (JSON-serialisable) and
                ``messages`` (list of message dicts).

        Returns:
            The stored conversation, as returned by ``get_conversation``.

        Raises:
            sqlite3.IntegrityError: e.g. when the id already exists. The whole
                insert (conversation and messages) is rolled back.
        """
        conn = self._get_connection()
        now = self._now_ms()
        conv_id = data.get("id") or self._generate_id()

        # One transaction: either the conversation and all its messages are
        # stored, or nothing is.
        with conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    conv_id,
                    data.get("user_id", "default"),
                    data.get("title", "新对话"),
                    data.get("createdAt", now),
                    now,
                    1 if data.get("pinned") else 0,
                    1 if data.get("archived") else 0,
                    self._encode_json(data.get("settings")),
                ),
            )
            for msg in data.get("messages", []):
                self._insert_message(cursor, conv_id, msg)

        return self.get_conversation(conv_id)

    def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
        """Return one conversation including its messages, or None if it does not exist."""
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            "SELECT * FROM conversations WHERE id = ?", (conversation_id,)
        )
        row = cursor.fetchone()
        if not row:
            return None

        return self._row_to_conversation(row, cursor)

    def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]:
        """Return *user_id*'s conversations (without messages), most recently updated first."""
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """
            SELECT * FROM conversations
            WHERE user_id = ?
            ORDER BY updated_at DESC
            """,
            (user_id,),
        )

        return [
            self._row_to_conversation(row, cursor, include_messages=False)
            for row in cursor.fetchall()
        ]

    def update_conversation(
        self, conversation_id: str, data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Apply a partial update to a conversation and return the result.

        Only the keys present in *data* (``title``, ``pinned``, ``archived``,
        ``settings``, ``messages``) are changed; ``updated_at`` is always
        bumped. When ``messages`` is given, it replaces all existing messages.

        Returns:
            The updated conversation, or None if *conversation_id* is unknown.

        Raises:
            sqlite3.IntegrityError: e.g. on a duplicate message id. All changes
                from this call, including deletion of the old messages, are
                rolled back.
        """
        conn = self._get_connection()

        with conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT id FROM conversations WHERE id = ?", (conversation_id,)
            )
            if not cursor.fetchone():
                return None

            update_fields = ["updated_at = ?"]
            update_values: List[Any] = [self._now_ms()]

            if "title" in data:
                update_fields.append("title = ?")
                update_values.append(data["title"])
            if "pinned" in data:
                update_fields.append("pinned = ?")
                update_values.append(1 if data["pinned"] else 0)
            if "archived" in data:
                update_fields.append("archived = ?")
                update_values.append(1 if data["archived"] else 0)
            if "settings" in data:
                update_fields.append("settings = ?")
                update_values.append(self._encode_json(data["settings"]))

            update_values.append(conversation_id)
            # Column names come only from the fixed list above; values are bound.
            cursor.execute(
                f"UPDATE conversations SET {', '.join(update_fields)} WHERE id = ?",
                update_values,
            )

            if "messages" in data:
                # Replace the message list wholesale, inside the same
                # transaction, so a failed insert cannot leave it emptied.
                cursor.execute(
                    "DELETE FROM messages WHERE conversation_id = ?", (conversation_id,)
                )
                for msg in data["messages"]:
                    self._insert_message(cursor, conversation_id, msg)

        return self.get_conversation(conversation_id)

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation (its messages cascade); return True if a row was removed."""
        conn = self._get_connection()
        with conn:
            cursor = conn.execute(
                "DELETE FROM conversations WHERE id = ?", (conversation_id,)
            )
        return cursor.rowcount > 0

    # ── Message operations ────────────────────────────────────────────

    def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]:
        """Append *message* to a conversation, bump its ``updated_at``, and return the stored message.

        Raises:
            sqlite3.IntegrityError: e.g. if the conversation does not exist
                (foreign key) or the message id is taken; nothing is written.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            msg = self._insert_message(cursor, conversation_id, message)
            cursor.execute(
                "UPDATE conversations SET updated_at = ? WHERE id = ?",
                (self._now_ms(), conversation_id),
            )
        return msg

    def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a message's ``content`` and/or ``feedback``.

        Returns:
            The message after the update, or None if *message_id* is unknown.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,))
            if not cursor.fetchone():
                return None

            update_fields = []
            update_values: List[Any] = []
            if "content" in data:
                update_fields.append("content = ?")
                # content is NOT NULL, so it is always JSON-encoded (None -> "null").
                update_values.append(json.dumps(data["content"]))
            if "feedback" in data:
                update_fields.append("feedback = ?")
                update_values.append(self._encode_json(data["feedback"]))

            if update_fields:
                update_values.append(message_id)
                cursor.execute(
                    f"UPDATE messages SET {', '.join(update_fields)} WHERE id = ?",
                    update_values,
                )

            return self._get_message_by_id(message_id, cursor)

    # ── Internal helpers ──────────────────────────────────────────────

    @staticmethod
    def _now_ms() -> int:
        """Current UTC time as integer milliseconds since the Unix epoch."""
        return int(datetime.now(timezone.utc).timestamp() * 1000)

    @staticmethod
    def _encode_json(value: Any) -> Optional[str]:
        """JSON-encode *value* for a nullable TEXT column; None is stored as SQL NULL."""
        return None if value is None else json.dumps(value)

    @staticmethod
    def _add_missing_columns(
        cursor: sqlite3.Cursor, table: str, columns: List[Any]
    ) -> None:
        """Add every ``(name, definition)`` pair in *columns* that *table* lacks.

        *table* and *columns* are internal constants, never user input.
        """
        cursor.execute(f"PRAGMA table_info({table})")
        existing = {col[1] for col in cursor.fetchall()}  # index 1 = column name
        for col_name, col_def in columns:
            if col_name not in existing:
                cursor.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_def}")
                print(f"[数据库] {table} 表已添加 {col_name} 列")

    def _insert_message(
        self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Insert one message row via *cursor* (no commit) and return it as a dict.

        A missing/empty ``id`` is generated; a missing/zero ``timestamp``
        becomes the current time in ms.
        """
        msg_id = message.get("id") or self._generate_id()
        timestamp = message.get("timestamp") or self._now_ms()
        role = message.get("role", "user")
        content = message.get("content", "")
        feedback = message.get("feedback")

        cursor.execute(
            """
            INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                msg_id,
                conversation_id,
                role,
                json.dumps(content),
                timestamp,
                self._encode_json(feedback),
            ),
        )

        return {
            "id": msg_id,
            "role": role,
            "content": content,
            "timestamp": timestamp,
            "feedback": feedback,
        }

    def _get_message_by_id(
        self, message_id: str, cursor: sqlite3.Cursor
    ) -> Optional[Dict[str, Any]]:
        """Fetch one message by id, or None if absent."""
        cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,))
        row = cursor.fetchone()
        return self._row_to_message(row) if row else None

    def _row_to_conversation(
        self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True
    ) -> Dict[str, Any]:
        """Convert a ``conversations`` row into the camelCase client dict.

        When *include_messages* is true, *cursor* is reused to load the
        conversation's messages, so callers must already have fetched *row*.
        """
        conv = {
            "id": row["id"],
            "userId": row["user_id"],
            "title": row["title"],
            "createdAt": row["created_at"],
            "updatedAt": row["updated_at"],
            "pinned": bool(row["pinned"]),
            "archived": bool(row["archived"]),
            "settings": json.loads(row["settings"]) if row["settings"] else None,
        }

        if include_messages:
            # rowid breaks ties between messages sharing a timestamp (e.g. a
            # batch inserted within one millisecond), preserving insert order.
            cursor.execute(
                "SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp, rowid",
                (row["id"],),
            )
            conv["messages"] = [
                self._row_to_message(msg_row) for msg_row in cursor.fetchall()
            ]

        return conv

    def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a ``messages`` row into a dict, decoding its JSON columns."""
        return {
            "id": row["id"],
            "role": row["role"],
            "content": json.loads(row["content"]),
            "timestamp": row["timestamp"],
            "feedback": json.loads(row["feedback"]) if row["feedback"] else None,
        }

    def _generate_id(self) -> str:
        """Return a new random UUID4 string."""
        import uuid

        return str(uuid.uuid4())
|
|
|
|
|
|
def init_db():
    """Create the module-wide Database at DB_PATH unless it already exists.

    Meant to be called at application start-up; repeated calls do nothing.
    """
    global _db_instance
    if _db_instance is not None:
        return
    _db_instance = Database(DB_PATH)
    print(f"[数据库] SQLite 初始化完成: {DB_PATH}")
|
|
|
|
|
|
def get_db() -> Database:
    """Return the module-wide Database, initialising it via init_db() on first access."""
    instance = _db_instance
    if instance is None:
        init_db()
        instance = _db_instance
    return instance