feat: 增加MD列

This commit is contained in:
肖应宇 2026-05-08 10:46:43 +08:00
parent d6bba71524
commit ebd22a1a55
2 changed files with 167 additions and 14 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import abc
import json
import logging
import re
import sqlite3
import threading
from datetime import datetime
@ -16,6 +17,7 @@ from deerflow.config.paths import get_paths
from deerflow.config.thread_memory_config import get_thread_memory_config
logger = logging.getLogger(__name__)
_JSON_FENCE_RE = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
class ThreadMemoryStorage(abc.ABC):
@ -32,15 +34,82 @@ class ThreadMemoryStorage(abc.ABC):
pass
def _memory_to_markdown(data: dict[str, Any]) -> str:
owner = data.get("ownerId")
owner_text = "null" if owner is None else str(owner)
return (
"# Thread Memory\n\n"
f"Owner ID: {owner_text}\n\n"
"## Profile\n"
"```json\n"
f"{json.dumps(data.get('profile', {}), ensure_ascii=False, indent=2)}\n"
"```\n\n"
"## Preferences\n"
"```json\n"
f"{json.dumps(data.get('preferences', {}), ensure_ascii=False, indent=2)}\n"
"```\n\n"
"## Facts\n"
"```json\n"
f"{json.dumps(data.get('facts', []), ensure_ascii=False, indent=2)}\n"
"```"
)
def _memory_from_markdown(markdown: str) -> dict[str, Any]:
parsed = create_empty_thread_memory()
owner_id: str | None = None
owner_line = next((line for line in markdown.splitlines() if line.startswith("Owner ID: ")), None)
if owner_line:
owner_raw = owner_line.split("Owner ID: ", 1)[1].strip()
owner_id = None if owner_raw == "null" else owner_raw
blocks = _JSON_FENCE_RE.findall(markdown)
if len(blocks) >= 1:
try:
profile = json.loads(blocks[0])
if isinstance(profile, dict):
parsed["profile"] = profile
except Exception:
pass
if len(blocks) >= 2:
try:
preferences = json.loads(blocks[1])
if isinstance(preferences, dict):
parsed["preferences"] = preferences
except Exception:
pass
if len(blocks) >= 3:
try:
facts = json.loads(blocks[2])
if isinstance(facts, list):
parsed["facts"] = facts
except Exception:
pass
return {"ownerId": owner_id, **parsed}
def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]:
memory_md = row[2]
if isinstance(memory_md, str) and memory_md.strip():
decoded = _memory_from_markdown(memory_md)
profile = decoded.get("profile", {})
preferences = decoded.get("preferences", {})
facts = decoded.get("facts", [])
owner_id = decoded.get("ownerId")
else:
owner_id = row[1]
profile = json.loads(row[3])
preferences = json.loads(row[4])
facts = json.loads(row[5])
return {
"threadId": row[0],
"ownerId": row[1],
"profile": json.loads(row[2]),
"preferences": json.loads(row[3]),
"facts": json.loads(row[4]),
"memoryVersion": int(row[5]),
"lastUpdated": str(row[6]),
"ownerId": owner_id,
"profile": profile,
"preferences": preferences,
"facts": facts,
"memoryVersion": int(row[6]),
"lastUpdated": str(row[7]),
}
@ -58,6 +127,7 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id TEXT PRIMARY KEY,
owner_id TEXT NULL,
memory_md TEXT NOT NULL DEFAULT '',
profile TEXT NOT NULL DEFAULT '{}',
preferences TEXT NOT NULL DEFAULT '{}',
facts TEXT NOT NULL DEFAULT '[]',
@ -66,6 +136,9 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
)
"""
)
columns = {r[1] for r in self._conn.execute("PRAGMA table_info(thread_memory)").fetchall()}
if "memory_md" not in columns:
self._conn.execute("ALTER TABLE thread_memory ADD COLUMN memory_md TEXT NOT NULL DEFAULT ''")
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)")
self._conn.commit()
@ -76,7 +149,30 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
"FROM thread_memory WHERE thread_id = ?",
(thread_id,),
).fetchone()
return _row_to_memory(row) if row else None
if row is None:
return None
row = (
row[0],
row[1],
"",
row[2],
row[3],
row[4],
row[5],
row[6],
)
try:
row2 = self._conn.execute(
"SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated "
"FROM thread_memory WHERE thread_id = ?",
(thread_id,),
).fetchone()
if row2 is not None:
row = row2
except sqlite3.OperationalError:
# Backward compatibility when running before migration.
pass
return _row_to_memory(row)
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
now = datetime.utcnow().isoformat() + "Z"
@ -86,13 +182,14 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
with self._lock:
cur = self._conn.execute(
"""
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version, last_updated)
VALUES (?, ?, ?, ?, ?, 0, ?)
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
VALUES (?, ?, ?, ?, ?, ?, 0, ?)
ON CONFLICT(thread_id) DO NOTHING
""",
(
thread_id,
owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False),
@ -106,11 +203,12 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
cur = self._conn.execute(
"""
UPDATE thread_memory
SET owner_id = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ?
SET owner_id = ?, memory_md = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ?
WHERE thread_id = ? AND memory_version = ?
""",
(
owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False),
@ -140,6 +238,7 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id VARCHAR(64) PRIMARY KEY,
owner_id VARCHAR(64) NULL,
memory_md LONGTEXT NOT NULL,
profile JSON NOT NULL,
preferences JSON NOT NULL,
facts JSON NOT NULL,
@ -149,12 +248,22 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
)
"""
)
cur.execute(
"""
SELECT COUNT(*)
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'thread_memory' AND COLUMN_NAME = 'memory_md'
"""
)
has_memory_md = cur.fetchone()[0] > 0
if not has_memory_md:
cur.execute("ALTER TABLE thread_memory ADD COLUMN memory_md LONGTEXT NOT NULL")
self._conn.commit()
def load(self, thread_id: str) -> dict[str, Any] | None:
with self._conn.cursor() as cur:
cur.execute(
"SELECT thread_id, owner_id, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
"SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
(thread_id,),
)
row = cur.fetchone()
@ -167,13 +276,14 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
with self._conn.cursor() as cur:
cur.execute(
"""
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version)
VALUES (%s, %s, %s, %s, %s, 0)
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version)
VALUES (%s, %s, %s, %s, %s, %s, 0)
ON DUPLICATE KEY UPDATE thread_id = thread_id
""",
(
thread_id,
owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False),
@ -185,11 +295,12 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
cur.execute(
"""
UPDATE thread_memory
SET owner_id = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1
SET owner_id = %s, memory_md = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1
WHERE thread_id = %s AND memory_version = %s
""",
(
owner_id,
_memory_to_markdown(data),
json.dumps(data.get("profile", {}), ensure_ascii=False),
json.dumps(data.get("preferences", {}), ensure_ascii=False),
json.dumps(data.get("facts", []), ensure_ascii=False),

View File

@ -27,3 +27,45 @@ def test_sqlite_thread_memory_compare_and_swap(tmp_path):
assert loaded2 is not None
assert loaded2["memoryVersion"] == 1
def test_sqlite_thread_memory_saves_markdown_payload(tmp_path):
db_path = tmp_path / "thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path))
thread_id = "thread-md"
assert storage.save(thread_id, _payload(), expected_version=0) is True
with storage._lock:
row = storage._conn.execute("SELECT memory_md FROM thread_memory WHERE thread_id = ?", (thread_id,)).fetchone()
assert row is not None
assert isinstance(row[0], str)
assert "## Profile" in row[0]
assert "## Preferences" in row[0]
assert "## Facts" in row[0]
def test_sqlite_thread_memory_loads_legacy_json_row(tmp_path):
db_path = tmp_path / "legacy-thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path))
thread_id = "thread-legacy"
with storage._lock:
storage._conn.execute(
"""
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
VALUES (?, ?, '', ?, ?, ?, 0, datetime('now'))
""",
(
thread_id,
"owner-1",
'{"name":"Alice","role":null,"expertise":[],"language":null,"context":null}',
'{"tone":null,"verbosity":null,"codeStyle":null,"other":null}',
"[]",
),
)
storage._conn.commit()
loaded = storage.load(thread_id)
assert loaded is not None
assert loaded["ownerId"] == "owner-1"
assert loaded["profile"]["name"] == "Alice"