Compare commits
2 Commits
main
...
backup-bef
| Author | SHA1 | Date |
|---|---|---|
|
|
06ebc8cdb2 | |
|
|
4bff571f2b |
Binary file not shown.
|
|
@ -0,0 +1,3 @@
|
|||
from .db import Database, get_db, init_db
|
||||
|
||||
__all__ = ["Database", "get_db", "init_db"]
|
||||
|
|
@ -0,0 +1,369 @@
|
|||
"""
|
||||
SQLite 数据库模块 - 会话持久化存储
|
||||
|
||||
提供会话和消息的 CRUD 操作,支持多用户(预留 user_id 字段)。
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# 数据库文件路径
|
||||
DB_PATH = Path(__file__).parent.parent / "data" / "chat.db"
|
||||
|
||||
# 线程本地存储,确保每个线程使用独立的连接
|
||||
_thread_local = threading.local()
|
||||
|
||||
# 全局数据库实例
|
||||
_db_instance: Optional["Database"] = None
|
||||
|
||||
|
||||
class Database:
|
||||
"""SQLite 数据库管理类"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = db_path
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_tables()
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""获取当前线程的数据库连接"""
|
||||
if not hasattr(_thread_local, "connection"):
|
||||
conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
# 启用外键约束
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
_thread_local.connection = conn
|
||||
return _thread_local.connection
|
||||
|
||||
def _init_tables(self):
|
||||
"""初始化数据库表结构"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建会话表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT DEFAULT 'default',
|
||||
title TEXT DEFAULT '新对话',
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
pinned INTEGER DEFAULT 0,
|
||||
archived INTEGER DEFAULT 0,
|
||||
settings TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建消息表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp INTEGER,
|
||||
feedback TEXT,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建索引
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation
|
||||
ON messages(conversation_id)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_user
|
||||
ON conversations(user_id)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
# ── 会话 CRUD ─────────────────────────────────────────────────────
|
||||
|
||||
def create_conversation(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""创建新会话"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
now = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
conv_id = data.get("id") or self._generate_id()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO conversations (id, user_id, title, created_at, updated_at, pinned, archived, settings)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
conv_id,
|
||||
data.get("user_id", "default"),
|
||||
data.get("title", "新对话"),
|
||||
data.get("createdAt", now),
|
||||
now,
|
||||
1 if data.get("pinned") else 0,
|
||||
1 if data.get("archived") else 0,
|
||||
json.dumps(data.get("settings")) if data.get("settings") else None,
|
||||
),
|
||||
)
|
||||
|
||||
# 插入消息(如果有)
|
||||
messages = data.get("messages", [])
|
||||
for msg in messages:
|
||||
self._insert_message(cursor, conv_id, msg)
|
||||
|
||||
conn.commit()
|
||||
return self.get_conversation(conv_id)
|
||||
|
||||
def get_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取单个会话(包含消息)"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT * FROM conversations WHERE id = ?", (conversation_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return self._row_to_conversation(row, cursor)
|
||||
|
||||
def list_conversations(self, user_id: str = "default") -> List[Dict[str, Any]]:
|
||||
"""列出用户的所有会话"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM conversations
|
||||
WHERE user_id = ?
|
||||
ORDER BY updated_at DESC
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
conversations = []
|
||||
for row in cursor.fetchall():
|
||||
conv = self._row_to_conversation(row, cursor, include_messages=False)
|
||||
conversations.append(conv)
|
||||
|
||||
return conversations
|
||||
|
||||
def update_conversation(
|
||||
self, conversation_id: str, data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""更新会话"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查会话是否存在
|
||||
cursor.execute(
|
||||
"SELECT id FROM conversations WHERE id = ?", (conversation_id,)
|
||||
)
|
||||
if not cursor.fetchone():
|
||||
return None
|
||||
|
||||
now = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
|
||||
# 更新会话字段
|
||||
update_fields = ["updated_at = ?"]
|
||||
update_values = [now]
|
||||
|
||||
if "title" in data:
|
||||
update_fields.append("title = ?")
|
||||
update_values.append(data["title"])
|
||||
if "pinned" in data:
|
||||
update_fields.append("pinned = ?")
|
||||
update_values.append(1 if data["pinned"] else 0)
|
||||
if "archived" in data:
|
||||
update_fields.append("archived = ?")
|
||||
update_values.append(1 if data["archived"] else 0)
|
||||
if "settings" in data:
|
||||
update_fields.append("settings = ?")
|
||||
update_values.append(json.dumps(data["settings"]))
|
||||
|
||||
update_values.append(conversation_id)
|
||||
|
||||
cursor.execute(
|
||||
f"UPDATE conversations SET {', '.join(update_fields)} WHERE id = ?",
|
||||
update_values,
|
||||
)
|
||||
|
||||
# 更新消息(如果提供了 messages 字段)
|
||||
if "messages" in data:
|
||||
# 删除旧消息
|
||||
cursor.execute(
|
||||
"DELETE FROM messages WHERE conversation_id = ?", (conversation_id,)
|
||||
)
|
||||
# 插入新消息
|
||||
for msg in data["messages"]:
|
||||
self._insert_message(cursor, conversation_id, msg)
|
||||
|
||||
conn.commit()
|
||||
return self.get_conversation(conversation_id)
|
||||
|
||||
def delete_conversation(self, conversation_id: str) -> bool:
|
||||
"""删除会话(级联删除消息)"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"DELETE FROM conversations WHERE id = ?", (conversation_id,)
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
# ── 消息操作 ───────────────────────────────────────────────────────
|
||||
|
||||
def add_message(self, conversation_id: str, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""添加消息到会话"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
msg = self._insert_message(cursor, conversation_id, message)
|
||||
|
||||
# 更新会话的 updated_at
|
||||
now = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
cursor.execute(
|
||||
"UPDATE conversations SET updated_at = ? WHERE id = ?",
|
||||
(now, conversation_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return msg
|
||||
|
||||
def update_message(self, message_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""更新消息"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT id FROM messages WHERE id = ?", (message_id,))
|
||||
if not cursor.fetchone():
|
||||
return None
|
||||
|
||||
update_fields = []
|
||||
update_values = []
|
||||
|
||||
if "content" in data:
|
||||
update_fields.append("content = ?")
|
||||
update_values.append(json.dumps(data["content"]))
|
||||
if "feedback" in data:
|
||||
update_fields.append("feedback = ?")
|
||||
update_values.append(json.dumps(data["feedback"]))
|
||||
|
||||
if not update_fields:
|
||||
return self._get_message_by_id(message_id, cursor)
|
||||
|
||||
update_values.append(message_id)
|
||||
cursor.execute(
|
||||
f"UPDATE messages SET {', '.join(update_fields)} WHERE id = ?",
|
||||
update_values,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return self._get_message_by_id(message_id, cursor)
|
||||
|
||||
# ── 内部方法 ───────────────────────────────────────────────────────
|
||||
|
||||
def _insert_message(
|
||||
self, cursor: sqlite3.Cursor, conversation_id: str, message: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""插入消息(内部方法)"""
|
||||
msg_id = message.get("id") or self._generate_id()
|
||||
timestamp = message.get("timestamp") or int(
|
||||
datetime.now(timezone.utc).timestamp() * 1000
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO messages (id, conversation_id, role, content, timestamp, feedback)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
msg_id,
|
||||
conversation_id,
|
||||
message.get("role", "user"),
|
||||
json.dumps(message.get("content", "")),
|
||||
timestamp,
|
||||
json.dumps(message.get("feedback")) if message.get("feedback") else None,
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
"id": msg_id,
|
||||
"role": message.get("role", "user"),
|
||||
"content": message.get("content", ""),
|
||||
"timestamp": timestamp,
|
||||
"feedback": message.get("feedback"),
|
||||
}
|
||||
|
||||
def _get_message_by_id(
|
||||
self, message_id: str, cursor: sqlite3.Cursor
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取消息"""
|
||||
cursor.execute("SELECT * FROM messages WHERE id = ?", (message_id,))
|
||||
row = cursor.fetchone()
|
||||
return self._row_to_message(row) if row else None
|
||||
|
||||
def _row_to_conversation(
|
||||
self, row: sqlite3.Row, cursor: sqlite3.Cursor, include_messages: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""将数据库行转换为会话字典"""
|
||||
conv = {
|
||||
"id": row["id"],
|
||||
"userId": row["user_id"],
|
||||
"title": row["title"],
|
||||
"createdAt": row["created_at"],
|
||||
"updatedAt": row["updated_at"],
|
||||
"pinned": bool(row["pinned"]),
|
||||
"archived": bool(row["archived"]),
|
||||
"settings": json.loads(row["settings"]) if row["settings"] else None,
|
||||
}
|
||||
|
||||
if include_messages:
|
||||
cursor.execute(
|
||||
"SELECT * FROM messages WHERE conversation_id = ? ORDER BY timestamp",
|
||||
(row["id"],),
|
||||
)
|
||||
conv["messages"] = [
|
||||
self._row_to_message(msg_row) for msg_row in cursor.fetchall()
|
||||
]
|
||||
|
||||
return conv
|
||||
|
||||
def _row_to_message(self, row: sqlite3.Row) -> Dict[str, Any]:
|
||||
"""将数据库行转换为消息字典"""
|
||||
return {
|
||||
"id": row["id"],
|
||||
"role": row["role"],
|
||||
"content": json.loads(row["content"]),
|
||||
"timestamp": row["timestamp"],
|
||||
"feedback": json.loads(row["feedback"]) if row["feedback"] else None,
|
||||
}
|
||||
|
||||
def _generate_id(self) -> str:
|
||||
"""生成唯一 ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库(应用启动时调用)"""
|
||||
global _db_instance
|
||||
if _db_instance is None:
|
||||
_db_instance = Database(DB_PATH)
|
||||
print(f"[数据库] SQLite 初始化完成: {DB_PATH}")
|
||||
|
||||
|
||||
def get_db() -> Database:
|
||||
"""获取数据库实例"""
|
||||
global _db_instance
|
||||
if _db_instance is None:
|
||||
init_db()
|
||||
return _db_instance
|
||||
Loading…
Reference in New Issue