feat: 增加MD列

This commit is contained in:
肖应宇 2026-05-08 10:46:43 +08:00
parent b49e838980
commit 7db468aa6f
2 changed files with 167 additions and 14 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import abc import abc
import json import json
import logging import logging
import re
import sqlite3 import sqlite3
import threading import threading
from datetime import datetime from datetime import datetime
@ -16,6 +17,7 @@ from deerflow.config.paths import get_paths
from deerflow.config.thread_memory_config import get_thread_memory_config from deerflow.config.thread_memory_config import get_thread_memory_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Captures the payload of every fenced ```json block in a Markdown document.
# DOTALL lets the lazy payload group span multiple lines; IGNORECASE also
# accepts an upper-case fence tag. Whitespace next to the fences is excluded
# from the captured group.
_JSON_FENCE_RE = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
class ThreadMemoryStorage(abc.ABC): class ThreadMemoryStorage(abc.ABC):
@ -32,15 +34,82 @@ class ThreadMemoryStorage(abc.ABC):
pass pass
def _memory_to_markdown(data: dict[str, Any]) -> str:
owner = data.get("ownerId")
owner_text = "null" if owner is None else str(owner)
return (
"# Thread Memory\n\n"
f"Owner ID: {owner_text}\n\n"
"## Profile\n"
"```json\n"
f"{json.dumps(data.get('profile', {}), ensure_ascii=False, indent=2)}\n"
"```\n\n"
"## Preferences\n"
"```json\n"
f"{json.dumps(data.get('preferences', {}), ensure_ascii=False, indent=2)}\n"
"```\n\n"
"## Facts\n"
"```json\n"
f"{json.dumps(data.get('facts', []), ensure_ascii=False, indent=2)}\n"
"```"
)
# Fence matcher anchored to whole lines. json.dumps() escapes newlines inside
# strings, so a serialized JSON document never contains a line that starts
# with a fence; anchoring therefore stops backticks embedded in stored text
# from terminating a block early.
_LINE_ANCHORED_JSON_FENCE_RE = re.compile(
    r"^```json[ \t]*\r?\n(.*?)\r?\n```[ \t]*\r?$",
    re.DOTALL | re.IGNORECASE | re.MULTILINE,
)


def _memory_from_markdown(markdown: str) -> dict[str, Any]:
    """Parse a thread-memory Markdown document back into a memory dict.

    The owner is read from the first line starting with ``Owner ID: `` (the
    literal ``null`` maps to ``None``). The fenced JSON blocks are interpreted
    positionally as profile (object), preferences (object) and facts (list).
    A block that is missing, unparsable or of the wrong JSON type leaves the
    corresponding default from ``create_empty_thread_memory()`` in place.

    Args:
        markdown: Document text, normally produced by ``_memory_to_markdown``.

    Returns:
        A dict with ``ownerId`` plus the entries of the empty-memory template,
        overridden by whatever sections could be decoded.
    """
    parsed = create_empty_thread_memory()

    owner_id: str | None = None
    owner_prefix = "Owner ID: "
    owner_line = next((line for line in markdown.splitlines() if line.startswith(owner_prefix)), None)
    if owner_line is not None:
        owner_raw = owner_line[len(owner_prefix):].strip()
        owner_id = None if owner_raw == "null" else owner_raw

    # Prefer line-anchored fences; fall back to the lenient pattern so that
    # documents with inline fences keep parsing as before.
    blocks = _LINE_ANCHORED_JSON_FENCE_RE.findall(markdown)
    if not blocks:
        blocks = _JSON_FENCE_RE.findall(markdown)

    expected_sections = (("profile", dict), ("preferences", dict), ("facts", list))
    for (key, expected_type), raw in zip(expected_sections, blocks):
        try:
            value = json.loads(raw)
        except (ValueError, RecursionError):
            # Best effort: keep the default for this section, but leave a trace
            # instead of hiding the corruption.
            logging.getLogger(__name__).warning(
                "Ignoring unparsable %s block in thread memory markdown", key
            )
            continue
        if isinstance(value, expected_type):
            parsed[key] = value

    # The owner parsed from the document must win even if the empty-memory
    # template happens to carry its own "ownerId" entry (template contents are
    # not visible here); "ownerId" stays the first key as before.
    result: dict[str, Any] = {"ownerId": owner_id}
    result.update(parsed)
    result["ownerId"] = owner_id
    return result
def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]: def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]:
memory_md = row[2]
if isinstance(memory_md, str) and memory_md.strip():
decoded = _memory_from_markdown(memory_md)
profile = decoded.get("profile", {})
preferences = decoded.get("preferences", {})
facts = decoded.get("facts", [])
owner_id = decoded.get("ownerId")
else:
owner_id = row[1]
profile = json.loads(row[3])
preferences = json.loads(row[4])
facts = json.loads(row[5])
return { return {
"threadId": row[0], "threadId": row[0],
"ownerId": row[1], "ownerId": owner_id,
"profile": json.loads(row[2]), "profile": profile,
"preferences": json.loads(row[3]), "preferences": preferences,
"facts": json.loads(row[4]), "facts": facts,
"memoryVersion": int(row[5]), "memoryVersion": int(row[6]),
"lastUpdated": str(row[6]), "lastUpdated": str(row[7]),
} }
@ -58,6 +127,7 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
CREATE TABLE IF NOT EXISTS thread_memory ( CREATE TABLE IF NOT EXISTS thread_memory (
thread_id TEXT PRIMARY KEY, thread_id TEXT PRIMARY KEY,
owner_id TEXT NULL, owner_id TEXT NULL,
memory_md TEXT NOT NULL DEFAULT '',
profile TEXT NOT NULL DEFAULT '{}', profile TEXT NOT NULL DEFAULT '{}',
preferences TEXT NOT NULL DEFAULT '{}', preferences TEXT NOT NULL DEFAULT '{}',
facts TEXT NOT NULL DEFAULT '[]', facts TEXT NOT NULL DEFAULT '[]',
@ -66,6 +136,9 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
) )
""" """
) )
columns = {r[1] for r in self._conn.execute("PRAGMA table_info(thread_memory)").fetchall()}
if "memory_md" not in columns:
self._conn.execute("ALTER TABLE thread_memory ADD COLUMN memory_md TEXT NOT NULL DEFAULT ''")
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)") self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)")
self._conn.commit() self._conn.commit()
@ -76,7 +149,30 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
"FROM thread_memory WHERE thread_id = ?", "FROM thread_memory WHERE thread_id = ?",
(thread_id,), (thread_id,),
).fetchone() ).fetchone()
return _row_to_memory(row) if row else None if row is None:
return None
row = (
row[0],
row[1],
"",
row[2],
row[3],
row[4],
row[5],
row[6],
)
try:
row2 = self._conn.execute(
"SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated "
"FROM thread_memory WHERE thread_id = ?",
(thread_id,),
).fetchone()
if row2 is not None:
row = row2
except sqlite3.OperationalError:
# Backward compatibility when running before migration.
pass
return _row_to_memory(row)
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool: def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
now = datetime.utcnow().isoformat() + "Z" now = datetime.utcnow().isoformat() + "Z"
@ -86,13 +182,14 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
with self._lock: with self._lock:
cur = self._conn.execute( cur = self._conn.execute(
""" """
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version, last_updated) INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
VALUES (?, ?, ?, ?, ?, 0, ?) VALUES (?, ?, ?, ?, ?, ?, 0, ?)
ON CONFLICT(thread_id) DO NOTHING ON CONFLICT(thread_id) DO NOTHING
""", """,
( (
thread_id, thread_id,
owner_id, owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False), json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False), json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False), json.dumps(data.get("facts", []), ensure_ascii=False),
@ -106,11 +203,12 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
cur = self._conn.execute( cur = self._conn.execute(
""" """
UPDATE thread_memory UPDATE thread_memory
SET owner_id = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ? SET owner_id = ?, memory_md = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ?
WHERE thread_id = ? AND memory_version = ? WHERE thread_id = ? AND memory_version = ?
""", """,
( (
owner_id, owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False), json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False), json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False), json.dumps(data.get("facts", []), ensure_ascii=False),
@ -140,6 +238,7 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
CREATE TABLE IF NOT EXISTS thread_memory ( CREATE TABLE IF NOT EXISTS thread_memory (
thread_id VARCHAR(64) PRIMARY KEY, thread_id VARCHAR(64) PRIMARY KEY,
owner_id VARCHAR(64) NULL, owner_id VARCHAR(64) NULL,
memory_md LONGTEXT NOT NULL,
profile JSON NOT NULL, profile JSON NOT NULL,
preferences JSON NOT NULL, preferences JSON NOT NULL,
facts JSON NOT NULL, facts JSON NOT NULL,
@ -149,12 +248,22 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
) )
""" """
) )
cur.execute(
"""
SELECT COUNT(*)
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'thread_memory' AND COLUMN_NAME = 'memory_md'
"""
)
has_memory_md = cur.fetchone()[0] > 0
if not has_memory_md:
cur.execute("ALTER TABLE thread_memory ADD COLUMN memory_md LONGTEXT NOT NULL")
self._conn.commit() self._conn.commit()
def load(self, thread_id: str) -> dict[str, Any] | None: def load(self, thread_id: str) -> dict[str, Any] | None:
with self._conn.cursor() as cur: with self._conn.cursor() as cur:
cur.execute( cur.execute(
"SELECT thread_id, owner_id, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s", "SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
(thread_id,), (thread_id,),
) )
row = cur.fetchone() row = cur.fetchone()
@ -167,13 +276,14 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
with self._conn.cursor() as cur: with self._conn.cursor() as cur:
cur.execute( cur.execute(
""" """
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version) INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version)
VALUES (%s, %s, %s, %s, %s, 0) VALUES (%s, %s, %s, %s, %s, %s, 0)
ON DUPLICATE KEY UPDATE thread_id = thread_id ON DUPLICATE KEY UPDATE thread_id = thread_id
""", """,
( (
thread_id, thread_id,
owner_id, owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False), json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False), json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False), json.dumps(data.get("facts", []), ensure_ascii=False),
@ -185,11 +295,12 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
cur.execute( cur.execute(
""" """
UPDATE thread_memory UPDATE thread_memory
SET owner_id = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1 SET owner_id = %s, memory_md = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1
WHERE thread_id = %s AND memory_version = %s WHERE thread_id = %s AND memory_version = %s
""", """,
( (
owner_id, owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False), json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False), json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False), json.dumps(data.get("facts", []), ensure_ascii=False),

View File

@ -27,3 +27,45 @@ def test_sqlite_thread_memory_compare_and_swap(tmp_path):
assert loaded2 is not None assert loaded2 is not None
assert loaded2["memoryVersion"] == 1 assert loaded2["memoryVersion"] == 1
def test_sqlite_thread_memory_saves_markdown_payload(tmp_path):
    """Saving a payload also stores a Markdown rendering in ``memory_md``."""
    storage = SqliteThreadMemoryStorage(str(tmp_path / "thread-memory.db"))
    thread_id = "thread-md"

    saved = storage.save(thread_id, _payload(), expected_version=0)
    assert saved is True

    # Read the raw column directly to inspect what was persisted.
    with storage._lock:
        cursor = storage._conn.execute("SELECT memory_md FROM thread_memory WHERE thread_id = ?", (thread_id,))
        row = cursor.fetchone()

    assert row is not None
    stored_markdown = row[0]
    assert isinstance(stored_markdown, str)
    for heading in ("## Profile", "## Preferences", "## Facts"):
        assert heading in stored_markdown
def test_sqlite_thread_memory_loads_legacy_json_row(tmp_path):
    """A row with an empty ``memory_md`` is loaded from its JSON columns."""
    storage = SqliteThreadMemoryStorage(str(tmp_path / "legacy-thread-memory.db"))
    thread_id = "thread-legacy"

    legacy_profile = '{"name":"Alice","role":null,"expertise":[],"language":null,"context":null}'
    legacy_preferences = '{"tone":null,"verbosity":null,"codeStyle":null,"other":null}'
    legacy_facts = "[]"

    # Insert the row by hand so that memory_md stays empty, as it would for
    # data written before the column existed.
    with storage._lock:
        storage._conn.execute(
            """
            INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
            VALUES (?, ?, '', ?, ?, ?, 0, datetime('now'))
            """,
            (thread_id, "owner-1", legacy_profile, legacy_preferences, legacy_facts),
        )
        storage._conn.commit()

    loaded = storage.load(thread_id)

    assert loaded is not None
    assert loaded["ownerId"] == "owner-1"
    assert loaded["profile"]["name"] == "Alice"