# ai-chat-ui/server/database/db.py
# NOTE: this module intentionally contains non-ASCII (Chinese) text in comments
# and in runtime strings.

"""
SQLite 数据库模块 - 会话持久化存储
提供会话和消息的 CRUD 操作,支持多用户(预留 user_id 字段)。
"""
import json
import sqlite3
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
# SQLite database file: <directory two levels above this module>/data/chat.db
DB_PATH: Path = Path(__file__).parent.parent.joinpath("data", "chat.db")

# Thread-local namespace for per-thread connection state.
_thread_local: threading.local = threading.local()

# Lazily created process-wide Database singleton (managed by init_db / get_db).
_db_instance: Optional["Database"] = None
class Database:
    """SQLite-backed persistence for conversations, messages and shares.

    Each thread lazily opens its own ``sqlite3.Connection`` to ``db_path``
    (cached in a per-instance ``threading.local``), using ``sqlite3.Row`` rows
    and with foreign-key enforcement switched on. Every write runs inside a
    ``with conn:`` block, so a failure part-way through is rolled back.
    Otherwise the half-finished changes would remain pending on the reused
    connection and be committed by a later, unrelated write.

    Timestamps are integer milliseconds since the Unix epoch (UTC). The
    ``user_id`` column is reserved for multi-user support and defaults to
    ``"default"``.
    """

    # Lifetime used by create_share when the caller omits "expiresAt": 7 days in ms.
    DEFAULT_SHARE_TTL_MS = 7 * 24 * 60 * 60 * 1000

    def __init__(self, db_path: Path):
        """Open (or create) the database at *db_path* and ensure the schema exists.

        Args:
            db_path: Location of the SQLite file. Its parent directory is
                created if it does not exist.
        """
        self.db_path = db_path
        # Per-instance (not module-global) thread-local storage, so two
        # Database objects pointing at different files never share a connection.
        self._local = threading.local()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_tables()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        conn = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            conn.row_factory = sqlite3.Row
            # SQLite enforces foreign keys only when enabled per connection;
            # the messages table relies on ON DELETE CASCADE.
            conn.execute("PRAGMA foreign_keys = ON")
            self._local.connection = conn
        return conn

    @staticmethod
    def _now_ms() -> int:
        """Return the current UTC time as integer milliseconds since the epoch."""
        return int(datetime.now(timezone.utc).timestamp() * 1000)

    @staticmethod
    def _dumps_optional(value: Any) -> Optional[str]:
        """JSON-encode *value*, mapping ``None`` to SQL NULL."""
        return None if value is None else json.dumps(value)

    @staticmethod
    def _loads_optional(raw: Optional[str]) -> Any:
        """Decode a nullable JSON column; NULL or empty text yields ``None``."""
        return json.loads(raw) if raw else None

    def _init_tables(self):
        """Create missing tables/indexes and add columns absent from older databases."""
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            # Conversations table.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    id TEXT PRIMARY KEY,
                    user_id TEXT DEFAULT 'default',
                    title TEXT DEFAULT '新对话',
                    created_at INTEGER,
                    updated_at INTEGER,
                    pinned INTEGER DEFAULT 0,
                    archived INTEGER DEFAULT 0,
                    settings TEXT
                )
            """)
            # Messages table; rows are deleted together with their conversation.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id TEXT PRIMARY KEY,
                    conversation_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp INTEGER,
                    feedback TEXT,
                    FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
                )
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_messages_conversation
                ON messages(conversation_id)
            """)

            # Migrate databases created by older versions: CREATE TABLE IF NOT
            # EXISTS never adds columns to an existing table, so add them here.
            # Table and column names come from this fixed list, which keeps the
            # f-string SQL below safe.
            migrations = {
                "conversations": [
                    ("user_id", "TEXT DEFAULT 'default'"),
                    ("pinned", "INTEGER DEFAULT 0"),
                    ("archived", "INTEGER DEFAULT 0"),
                    ("settings", "TEXT"),
                ],
                "messages": [
                    ("timestamp", "INTEGER"),
                    ("feedback", "TEXT"),
                ],
            }
            for table, columns in migrations.items():
                cursor.execute(f"PRAGMA table_info({table})")
                existing = {col[1] for col in cursor.fetchall()}  # col[1] is the column name
                for col_name, col_def in columns:
                    if col_name not in existing:
                        cursor.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_def}")
                        print(f"[数据库] {table} 表已添加 {col_name}")

            # Only safe to create once the migration above guarantees user_id exists.
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_conversations_user
                ON conversations(user_id)
            """)
            # Shares table.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS shares (
                    id TEXT PRIMARY KEY,
                    conversation_ids TEXT NOT NULL,
                    conversations TEXT NOT NULL,
                    password_hash TEXT NOT NULL,
                    created_at INTEGER,
                    expires_at INTEGER,
                    view_count INTEGER DEFAULT 0
                )
            """)
            # Supports the range delete in cleanup_expired_shares.
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_shares_expires
                ON shares(expires_at)
            """)

    # ── Conversation CRUD ─────────────────────────────────────────────

    def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a conversation and its embedded messages in one transaction.

        Args:
            data: Client payload. Keys read: ``id``, ``user_id``, ``title``,
                ``createdAt`` (ms), ``pinned``, ``archived``, ``settings``
                (JSON-serialisable) and ``messages`` (list of message dicts
                as accepted by ``_insert_message``).

        Returns:
            The stored conversation including its messages.

        Raises:
            sqlite3.IntegrityError: e.g. a duplicate conversation or message
                id. Nothing is persisted in that case.
        """
        conn = self._get_connection()
        now = self._now_ms()
        conv_id = data.get("id") or self._generate_id()
        created_at = data.get("createdAt")
        with conn:  # commit on success, roll back everything on error
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    conv_id,
                    data.get("user_id", "default"),
                    data.get("title", "新对话"),
                    now if created_at is None else created_at,
                    now,
                    1 if data.get("pinned") else 0,
                    1 if data.get("archived") else 0,
                    self._dumps_optional(data.get("settings")),
                ),
            )
            for msg in data.get("messages") or []:
                self._insert_message(cursor, conv_id, msg)
        return self.get_conversation(conv_id)

    def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
        """Return one conversation with its messages, or ``None`` if it does not exist."""
        cursor = self._get_connection().cursor()
        cursor.execute("SELECT * FROM conversations WHERE id = ?", (conversation_id,))
        row = cursor.fetchone()
        if row is None:
            return None
        return self._row_to_conversation(row, cursor)

    def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]:
        """Return a user's conversations (without messages), most recently updated first."""
        cursor = self._get_connection().cursor()
        cursor.execute(
            """
            SELECT * FROM conversations
            WHERE user_id = ?
            ORDER BY updated_at DESC
            """,
            (user_id,),
        )
        return [
            self._row_to_conversation(row, cursor, include_messages=False)
            for row in cursor.fetchall()
        ]

    def update_conversation(
        self, conversation_id: str, data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Update the fields present in *data* and bump ``updated_at``.

        Only the keys ``title``, ``pinned``, ``archived``, ``settings`` and
        ``messages`` are applied. A ``messages`` key replaces the whole
        message list, and the delete plus re-insert is atomic.

        Returns:
            The updated conversation, or ``None`` if it does not exist.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute("SELECT id FROM conversations WHERE id = ?", (conversation_id,))
        if cursor.fetchone() is None:
            return None

        # Column names are fixed literals; only values are bound parameters.
        assignments = ["updated_at = ?"]
        values: List[Any] = [self._now_ms()]
        if "title" in data:
            assignments.append("title = ?")
            values.append(data["title"])
        if "pinned" in data:
            assignments.append("pinned = ?")
            values.append(1 if data["pinned"] else 0)
        if "archived" in data:
            assignments.append("archived = ?")
            values.append(1 if data["archived"] else 0)
        if "settings" in data:
            assignments.append("settings = ?")
            values.append(self._dumps_optional(data["settings"]))
        values.append(conversation_id)

        with conn:
            cursor.execute(
                f"UPDATE conversations SET {', '.join(assignments)} WHERE id = ?",
                values,
            )
            if "messages" in data:
                cursor.execute(
                    "DELETE FROM messages WHERE conversation_id = ?", (conversation_id,)
                )
                for msg in data["messages"]:
                    self._insert_message(cursor, conversation_id, msg)
        return self.get_conversation(conversation_id)

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation; its messages go with it via ON DELETE CASCADE.

        Returns:
            True if a conversation row was deleted.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,))
        return cursor.rowcount > 0

    # ── Message operations ────────────────────────────────────────────

    def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]:
        """Append a message to a conversation and bump the conversation's ``updated_at``.

        Raises:
            sqlite3.IntegrityError: e.g. unknown conversation (foreign key) or
                duplicate message id. Nothing is persisted in that case.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            msg = self._insert_message(cursor, conversation_id, message)
            cursor.execute(
                "UPDATE conversations SET updated_at = ? WHERE id = ?",
                (self._now_ms(), conversation_id),
            )
        return msg

    def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a message's ``content`` and/or ``feedback``.

        Returns:
            The stored message after the update, or ``None`` if it does not exist.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,))
        if cursor.fetchone() is None:
            return None

        assignments: List[str] = []
        values: List[Any] = []
        if "content" in data:
            assignments.append("content = ?")
            values.append(json.dumps(data["content"]))
        if "feedback" in data:
            assignments.append("feedback = ?")
            values.append(self._dumps_optional(data["feedback"]))

        if assignments:
            values.append(message_id)
            with conn:
                cursor.execute(
                    f"UPDATE messages SET {', '.join(assignments)} WHERE id = ?",
                    values,
                )
        return self._get_message_by_id(message_id, cursor)

    # ── Internal helpers ──────────────────────────────────────────────

    def _insert_message(
        self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Insert one message using *cursor*. The caller owns the transaction.

        ``content`` is stored JSON-encoded. A missing/falsy ``timestamp``
        defaults to now (ms) and a missing ``role`` defaults to ``"user"``.

        Returns:
            The message as stored, with ``content``/``feedback`` decoded.
        """
        msg_id = message.get("id") or self._generate_id()
        timestamp = message.get("timestamp") or self._now_ms()
        role = message.get("role", "user")
        content = message.get("content", "")
        feedback = message.get("feedback")
        cursor.execute(
            """
            INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                msg_id,
                conversation_id,
                role,
                json.dumps(content),
                timestamp,
                self._dumps_optional(feedback),
            ),
        )
        return {
            "id": msg_id,
            "role": role,
            "content": content,
            "timestamp": timestamp,
            "feedback": feedback,
        }

    def _get_message_by_id(
        self, message_id: str, cursor: sqlite3.Cursor
    ) -> Optional[Dict[str, Any]]:
        """Fetch one message by id, or ``None`` if absent."""
        cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,))
        row = cursor.fetchone()
        return self._row_to_message(row) if row else None

    def _row_to_conversation(
        self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True
    ) -> Dict[str, Any]:
        """Convert a conversations row into the camelCase API dict.

        When *include_messages* is true, *cursor* is reused to load the
        conversation's messages in chronological order.
        """
        conv = {
            "id": row["id"],
            "userId": row["user_id"],
            "title": row["title"],
            "createdAt": row["created_at"],
            "updatedAt": row["updated_at"],
            "pinned": bool(row["pinned"]),
            "archived": bool(row["archived"]),
            "settings": self._loads_optional(row["settings"]),
        }
        if include_messages:
            # rowid breaks timestamp ties (e.g. messages inserted in the same
            # millisecond) so they come back in insertion order.
            cursor.execute(
                "SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp, rowid",
                (row["id"],),
            )
            conv["messages"] = [self._row_to_message(r) for r in cursor.fetchall()]
        return conv

    def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a messages row into the API dict, decoding JSON columns."""
        return {
            "id": row["id"],
            "role": row["role"],
            "content": json.loads(row["content"]),
            "timestamp": row["timestamp"],
            "feedback": self._loads_optional(row["feedback"]),
        }

    def _generate_id(self) -> str:
        """Return a new random UUID4 string."""
        import uuid
        return str(uuid.uuid4())

    # ── Share CRUD ────────────────────────────────────────────────────

    def create_share(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a share record and return it.

        Keys read from *data*: ``id``, ``conversationIds``, ``conversations``,
        ``passwordHash`` and ``expiresAt`` (ms). If ``expiresAt`` is absent it
        defaults to now + 7 days. An explicit ``None`` is stored as NULL,
        which ``cleanup_expired_shares`` never removes.
        """
        conn = self._get_connection()
        now = self._now_ms()
        share_id = data.get("id") or self._generate_share_id()
        with conn:
            conn.execute(
                """
                INSERT INTO shares (id, conversation_ids, conversations, password_hash, created_at, expires_at, view_count)
                VALUES (?, ?, ?, ?, ?, ?, 0)
                """,
                (
                    share_id,
                    json.dumps(data.get("conversationIds", [])),
                    json.dumps(data.get("conversations", [])),
                    data.get("passwordHash", ""),
                    now,
                    data.get("expiresAt", now + self.DEFAULT_SHARE_TTL_MS),
                ),
            )
        return self.get_share(share_id)

    def get_share(self, share_id: str) -> Optional[Dict[str, Any]]:
        """Return a share by id, or ``None``. Expiry is not checked here."""
        cursor = self._get_connection().cursor()
        cursor.execute("SELECT * FROM shares WHERE id = ?", (share_id,))
        row = cursor.fetchone()
        return self._row_to_share(row) if row else None

    def update_share_view_count(self, share_id: str) -> int:
        """Increment a share's view counter and return the new value (0 if unknown)."""
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE shares SET view_count = view_count + 1 WHERE id = ?",
                (share_id,),
            )
        cursor.execute("SELECT view_count FROM shares WHERE id = ?", (share_id,))
        row = cursor.fetchone()
        return row["view_count"] if row else 0

    def delete_share(self, share_id: str) -> bool:
        """Delete a share; return True if a row was removed."""
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM shares WHERE id = ?", (share_id,))
        return cursor.rowcount > 0

    def cleanup_expired_shares(self) -> int:
        """Delete shares whose ``expires_at`` lies in the past; return how many."""
        conn = self._get_connection()
        with conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM shares WHERE expires_at < ?", (self._now_ms(),))
        deleted_count = cursor.rowcount
        if deleted_count > 0:
            print(f"[数据库] 已清理 {deleted_count} 个过期分享")
        return deleted_count

    def _row_to_share(self, row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a shares row into the camelCase API dict, decoding JSON columns."""
        return {
            "id": row["id"],
            "conversationIds": json.loads(row["conversation_ids"]),
            "conversations": json.loads(row["conversations"]),
            "passwordHash": row["password_hash"],
            "createdAt": row["created_at"],
            "expiresAt": row["expires_at"],
            "viewCount": row["view_count"],
        }

    def _generate_share_id(self) -> str:
        """Return an 8-character hex share id (short link), taken from a UUID4."""
        import uuid
        return uuid.uuid4().hex[:8]
def init_db():
    """Create the process-wide Database singleton unless it already exists.

    Intended to be called at application startup; repeated calls do nothing.
    """
    global _db_instance
    if _db_instance is not None:
        return
    _db_instance = Database(DB_PATH)
    print(f"[数据库] SQLite 初始化完成: {DB_PATH}")
def get_db() -> Database:
    """Return the process-wide Database, initialising it on first access."""
    instance = _db_instance
    if instance is None:
        init_db()
        instance = _db_instance
    return instance