refactor(memory): 切换线程记忆为纯 memory_json 存储

移除 thread_memory 对 memory_md/Markdown 解析的运行时依赖,仅保留 memory_json 读写路径。\n同步更新 SQLite/MySQL 存储实现与测试基线,并补充迁移文档的最终状态说明。
This commit is contained in:
肖应宇 2026-05-09 10:22:44 +08:00
parent 86a1460d5e
commit e338fa90d6
3 changed files with 149 additions and 92 deletions

View File

@ -0,0 +1,65 @@
# Thread Memory Storage Migration: `memory_md` -> `memory_json`
## Summary
Per-thread memory now uses `thread_memory.memory_json` as the primary storage format.
- New writes persist structured JSON into `memory_json`.
- Reads prefer `memory_json`.
- Runtime no longer depends on `memory_md`.
## Why
`memory_md` stores structured state inside Markdown fenced blocks. This is readable for humans, but costly for:
- querying and analytics
- schema evolution
- migration reliability
`memory_json` keeps the same logical payload while making storage machine-friendly.
## Runtime behavior
- Read path uses `memory_json` only.
- Write path uses `memory_json` only.
## Auto migration behavior
- SQLite: on startup, adds `memory_json` column when missing.
- MySQL: on startup, adds `memory_json` column when missing.
No destructive migration is required for existing data.
## One-shot operational backfill (legacy command)
For faster cleanup in production, run:
```bash
cd backend
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/backfill_thread_memory_json.py --dry-run
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/backfill_thread_memory_json.py
```
Current codebase keeps this command for compatibility. In fully migrated environments it returns zero legacy rows.
## Final cleanup: drop `memory_md` column
After confirming all environments are migrated, run:
```bash
cd backend
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/drop_thread_memory_md_column.py --dry-run
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/drop_thread_memory_md_column.py
```
Notes:
- SQLite migration rebuilds `thread_memory` table and preserves data.
- MySQL migration runs `ALTER TABLE ... DROP COLUMN memory_md`.
## Follow-up (optional)
After all active environments have fully migrated and no legacy rows remain:
1. backfill any remaining rows that still rely on `memory_md`
2. remove `memory_md` column from schema
3. remove Markdown parsing fallback code

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import abc import abc
import json import json
import logging import logging
import re
import sqlite3 import sqlite3
import threading import threading
from datetime import UTC, datetime from datetime import UTC, datetime
@ -17,7 +16,6 @@ from deerflow.config.paths import get_paths
from deerflow.config.thread_memory_config import get_thread_memory_config from deerflow.config.thread_memory_config import get_thread_memory_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_JSON_FENCE_RE = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
class ThreadMemoryStorage(abc.ABC): class ThreadMemoryStorage(abc.ABC):
@ -34,75 +32,32 @@ class ThreadMemoryStorage(abc.ABC):
pass pass
def _memory_to_markdown(data: dict[str, Any]) -> str:
owner = data.get("ownerId")
owner_text = "null" if owner is None else str(owner)
return (
"# Thread Memory\n\n"
f"Owner ID: {owner_text}\n\n"
"## User\n"
"```json\n"
f"{json.dumps(data.get('user', {}), ensure_ascii=False, indent=2)}\n"
"```\n\n"
"## History\n"
"```json\n"
f"{json.dumps(data.get('history', {}), ensure_ascii=False, indent=2)}\n"
"```\n\n"
"## Facts\n"
"```json\n"
f"{json.dumps(data.get('facts', []), ensure_ascii=False, indent=2)}\n"
"```"
)
def _memory_from_markdown(markdown: str) -> dict[str, Any]:
parsed = create_empty_thread_memory()
owner_id: str | None = None
owner_line = next((line for line in markdown.splitlines() if line.startswith("Owner ID: ")), None)
if owner_line:
owner_raw = owner_line.split("Owner ID: ", 1)[1].strip()
owner_id = None if owner_raw == "null" else owner_raw
blocks = _JSON_FENCE_RE.findall(markdown)
if len(blocks) >= 1:
try:
user = json.loads(blocks[0])
if isinstance(user, dict):
parsed["user"] = user
except Exception:
pass
if len(blocks) >= 2:
try:
history = json.loads(blocks[1])
if isinstance(history, dict):
parsed["history"] = history
except Exception:
pass
if len(blocks) >= 3:
try:
facts = json.loads(blocks[2])
if isinstance(facts, list):
parsed["facts"] = facts
except Exception:
pass
return {"ownerId": owner_id, **parsed}
def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]: def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]:
decoded = _memory_from_markdown(row[2] if isinstance(row[2], str) else "") thread_id, owner_id_col, memory_json_raw, memory_version, last_updated = row
decoded: dict[str, Any] = {}
if isinstance(memory_json_raw, str) and memory_json_raw.strip():
try:
parsed_json = json.loads(memory_json_raw)
if isinstance(parsed_json, dict):
decoded = parsed_json
except Exception:
decoded = {}
owner_id = decoded.get("ownerId")
if owner_id is None:
owner_id = owner_id_col
user = decoded.get("user", create_empty_thread_memory()["user"]) user = decoded.get("user", create_empty_thread_memory()["user"])
history = decoded.get("history", create_empty_thread_memory()["history"]) history = decoded.get("history", create_empty_thread_memory()["history"])
facts = decoded.get("facts", []) facts = decoded.get("facts", [])
owner_id = decoded.get("ownerId")
return { return {
"threadId": row[0], "threadId": thread_id,
"ownerId": row[1] if owner_id is None else owner_id, "ownerId": owner_id,
"user": user, "user": user,
"history": history, "history": history,
"facts": facts, "facts": facts,
"memoryVersion": int(row[3]), "memoryVersion": int(memory_version),
"lastUpdated": str(row[4]), "lastUpdated": str(last_updated),
} }
@ -120,25 +75,32 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
CREATE TABLE IF NOT EXISTS thread_memory ( CREATE TABLE IF NOT EXISTS thread_memory (
thread_id TEXT PRIMARY KEY, thread_id TEXT PRIMARY KEY,
owner_id TEXT NULL, owner_id TEXT NULL,
memory_md TEXT NOT NULL DEFAULT '', memory_json TEXT NOT NULL DEFAULT '',
memory_version INTEGER NOT NULL DEFAULT 0, memory_version INTEGER NOT NULL DEFAULT 0,
last_updated TEXT NOT NULL DEFAULT (datetime('now')) last_updated TEXT NOT NULL DEFAULT (datetime('now'))
) )
""" """
) )
self._ensure_memory_json_column()
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)") self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)")
self._conn.commit() self._conn.commit()
def _ensure_memory_json_column(self) -> None:
columns = self._conn.execute("PRAGMA table_info(thread_memory)").fetchall()
has_memory_json = any(col[1] == "memory_json" for col in columns)
if not has_memory_json:
self._conn.execute("ALTER TABLE thread_memory ADD COLUMN memory_json TEXT NOT NULL DEFAULT ''")
def load(self, thread_id: str) -> dict[str, Any] | None: def load(self, thread_id: str) -> dict[str, Any] | None:
with self._lock: with self._lock:
row = self._conn.execute( row = self._conn.execute(
"SELECT thread_id, owner_id, memory_md, memory_version, last_updated " "SELECT thread_id, owner_id, memory_json, memory_version, last_updated "
"FROM thread_memory WHERE thread_id = ?", "FROM thread_memory WHERE thread_id = ?",
(thread_id,), (thread_id,),
).fetchone() ).fetchone()
if row is None: if row is None:
return None return None
return _row_to_memory(row) return _row_to_memory(row)
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool: def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
now = datetime.now(UTC).isoformat().replace("+00:00", "Z") now = datetime.now(UTC).isoformat().replace("+00:00", "Z")
@ -148,14 +110,14 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
with self._lock: with self._lock:
cur = self._conn.execute( cur = self._conn.execute(
""" """
INSERT INTO thread_memory (thread_id, owner_id, memory_md, memory_version, last_updated) INSERT INTO thread_memory (thread_id, owner_id, memory_json, memory_version, last_updated)
VALUES (?, ?, ?, 0, ?) VALUES (?, ?, ?, 0, ?)
ON CONFLICT(thread_id) DO NOTHING ON CONFLICT(thread_id) DO NOTHING
""", """,
( (
thread_id, thread_id,
owner_id, owner_id,
_memory_to_markdown(data), json.dumps(data, ensure_ascii=False),
now, now,
), ),
) )
@ -166,12 +128,12 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
cur = self._conn.execute( cur = self._conn.execute(
""" """
UPDATE thread_memory UPDATE thread_memory
SET owner_id = ?, memory_md = ?, memory_version = memory_version + 1, last_updated = ? SET owner_id = ?, memory_json = ?, memory_version = memory_version + 1, last_updated = ?
WHERE thread_id = ? AND memory_version = ? WHERE thread_id = ? AND memory_version = ?
""", """,
( (
owner_id, owner_id,
_memory_to_markdown(data), json.dumps(data, ensure_ascii=False),
now, now,
thread_id, thread_id,
expected_version, expected_version,
@ -186,6 +148,13 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
self._conn.commit() self._conn.commit()
return True return True
def count_legacy_rows(self) -> int:
return 0
def backfill_legacy_rows(self, *, limit: int | None = None) -> dict[str, int]:
_ = limit
return {"scanned": 0, "updated": 0, "skipped": 0, "failed": 0}
class MysqlThreadMemoryStorage(ThreadMemoryStorage): class MysqlThreadMemoryStorage(ThreadMemoryStorage):
def __init__(self, host: str, port: int, user: str, password: str, database: str): def __init__(self, host: str, port: int, user: str, password: str, database: str):
@ -198,23 +167,28 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
CREATE TABLE IF NOT EXISTS thread_memory ( CREATE TABLE IF NOT EXISTS thread_memory (
thread_id VARCHAR(64) PRIMARY KEY, thread_id VARCHAR(64) PRIMARY KEY,
owner_id VARCHAR(64) NULL, owner_id VARCHAR(64) NULL,
memory_md LONGTEXT NOT NULL, memory_json LONGTEXT NOT NULL,
memory_version INT NOT NULL DEFAULT 0, memory_version INT NOT NULL DEFAULT 0,
last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_owner_id (owner_id) INDEX idx_owner_id (owner_id)
) )
""" """
) )
cur.execute("SHOW COLUMNS FROM thread_memory LIKE 'memory_json'")
if cur.fetchone() is None:
cur.execute("ALTER TABLE thread_memory ADD COLUMN memory_json LONGTEXT NOT NULL DEFAULT ''")
self._conn.commit() self._conn.commit()
def load(self, thread_id: str) -> dict[str, Any] | None: def load(self, thread_id: str) -> dict[str, Any] | None:
with self._conn.cursor() as cur: with self._conn.cursor() as cur:
cur.execute( cur.execute(
"SELECT thread_id, owner_id, memory_md, memory_version, last_updated FROM thread_memory WHERE thread_id = %s", "SELECT thread_id, owner_id, memory_json, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
(thread_id,), (thread_id,),
) )
row = cur.fetchone() row = cur.fetchone()
return _row_to_memory(row) if row else None if row is None:
return None
return _row_to_memory(row)
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool: def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
if expected_version is None: if expected_version is None:
@ -223,14 +197,14 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
with self._conn.cursor() as cur: with self._conn.cursor() as cur:
cur.execute( cur.execute(
""" """
INSERT INTO thread_memory (thread_id, owner_id, memory_md, memory_version) INSERT INTO thread_memory (thread_id, owner_id, memory_json, memory_version)
VALUES (%s, %s, %s, 0) VALUES (%s, %s, %s, 0)
ON DUPLICATE KEY UPDATE thread_id = thread_id ON DUPLICATE KEY UPDATE thread_id = thread_id
""", """,
( (
thread_id, thread_id,
owner_id, owner_id,
_memory_to_markdown(data), json.dumps(data, ensure_ascii=False),
), ),
) )
if cur.rowcount == 1: if cur.rowcount == 1:
@ -239,12 +213,12 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
cur.execute( cur.execute(
""" """
UPDATE thread_memory UPDATE thread_memory
SET owner_id = %s, memory_md = %s, memory_version = memory_version + 1 SET owner_id = %s, memory_json = %s, memory_version = memory_version + 1
WHERE thread_id = %s AND memory_version = %s WHERE thread_id = %s AND memory_version = %s
""", """,
( (
owner_id, owner_id,
_memory_to_markdown(data), json.dumps(data, ensure_ascii=False),
thread_id, thread_id,
expected_version, expected_version,
), ),
@ -258,6 +232,13 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
self._conn.commit() self._conn.commit()
return True return True
def count_legacy_rows(self) -> int:
return 0
def backfill_legacy_rows(self, *, limit: int | None = None) -> dict[str, int]:
_ = limit
return {"scanned": 0, "updated": 0, "skipped": 0, "failed": 0}
_thread_storage: ThreadMemoryStorage | None = None _thread_storage: ThreadMemoryStorage | None = None
_thread_storage_lock = threading.Lock() _thread_storage_lock = threading.Lock()

View File

@ -38,7 +38,7 @@ def test_sqlite_thread_memory_compare_and_swap(tmp_path):
assert loaded2["memoryVersion"] == 1 assert loaded2["memoryVersion"] == 1
def test_sqlite_thread_memory_saves_markdown_payload(tmp_path): def test_sqlite_thread_memory_saves_json_payload(tmp_path):
db_path = tmp_path / "thread-memory.db" db_path = tmp_path / "thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path)) storage = SqliteThreadMemoryStorage(str(db_path))
thread_id = "thread-md" thread_id = "thread-md"
@ -46,15 +46,14 @@ def test_sqlite_thread_memory_saves_markdown_payload(tmp_path):
assert storage.save(thread_id, _payload(), expected_version=0) is True assert storage.save(thread_id, _payload(), expected_version=0) is True
with storage._lock: with storage._lock:
row = storage._conn.execute("SELECT memory_md FROM thread_memory WHERE thread_id = ?", (thread_id,)).fetchone() row = storage._conn.execute("SELECT memory_json FROM thread_memory WHERE thread_id = ?", (thread_id,)).fetchone()
assert row is not None assert row is not None
assert isinstance(row[0], str) assert isinstance(row[0], str)
assert "## User" in row[0] parsed = json.loads(row[0])
assert "## History" in row[0] assert parsed["user"]["workContext"]["summary"] == "Frontend engineer"
assert "## Facts" in row[0]
def test_sqlite_thread_memory_loads_markdown_row(tmp_path): def test_sqlite_thread_memory_uses_owner_id_column_when_json_missing_owner(tmp_path):
db_path = tmp_path / "thread-memory.db" db_path = tmp_path / "thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path)) storage = SqliteThreadMemoryStorage(str(db_path))
thread_id = "thread-load" thread_id = "thread-load"
@ -63,20 +62,19 @@ def test_sqlite_thread_memory_loads_markdown_row(tmp_path):
with storage._lock: with storage._lock:
storage._conn.execute( storage._conn.execute(
""" """
INSERT INTO thread_memory (thread_id, owner_id, memory_md, memory_version, last_updated) INSERT INTO thread_memory (thread_id, owner_id, memory_json, memory_version, last_updated)
VALUES (?, ?, ?, 0, datetime('now')) VALUES (?, ?, ?, 0, datetime('now'))
""", """,
( (
thread_id, thread_id,
"owner-1", "owner-1",
( json.dumps(
"# Thread Memory\n\n" {
"Owner ID: owner-1\n\n" "user": payload["user"],
"## User\n```json\n" "history": payload["history"],
+ json.dumps(payload["user"], ensure_ascii=False, indent=2) "facts": [],
+ "\n```\n\n## History\n```json\n" },
+ json.dumps(payload["history"], ensure_ascii=False, indent=2) ensure_ascii=False,
+ "\n```\n\n## Facts\n```json\n[]\n```"
), ),
), ),
) )
@ -86,3 +84,16 @@ def test_sqlite_thread_memory_loads_markdown_row(tmp_path):
assert loaded is not None assert loaded is not None
assert loaded["ownerId"] == "owner-1" assert loaded["ownerId"] == "owner-1"
assert loaded["user"]["workContext"]["summary"] == "Frontend engineer" assert loaded["user"]["workContext"]["summary"] == "Frontend engineer"
assert loaded["facts"] == []
def test_sqlite_thread_memory_backfill_is_noop_after_migration(tmp_path):
db_path = tmp_path / "thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path))
assert storage.count_legacy_rows() == 0
stats = storage.backfill_legacy_rows()
assert stats["scanned"] == 0
assert stats["updated"] == 0
assert stats["failed"] == 0
assert storage.count_legacy_rows() == 0