feat: 增加MD列
This commit is contained in:
parent
d6bba71524
commit
ebd22a1a55
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
@ -16,6 +17,7 @@ from deerflow.config.paths import get_paths
|
||||
from deerflow.config.thread_memory_config import get_thread_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_JSON_FENCE_RE = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
|
||||
|
||||
|
||||
class ThreadMemoryStorage(abc.ABC):
|
||||
@ -32,15 +34,82 @@ class ThreadMemoryStorage(abc.ABC):
|
||||
pass
|
||||
|
||||
|
||||
def _memory_to_markdown(data: dict[str, Any]) -> str:
|
||||
owner = data.get("ownerId")
|
||||
owner_text = "null" if owner is None else str(owner)
|
||||
return (
|
||||
"# Thread Memory\n\n"
|
||||
f"Owner ID: {owner_text}\n\n"
|
||||
"## Profile\n"
|
||||
"```json\n"
|
||||
f"{json.dumps(data.get('profile', {}), ensure_ascii=False, indent=2)}\n"
|
||||
"```\n\n"
|
||||
"## Preferences\n"
|
||||
"```json\n"
|
||||
f"{json.dumps(data.get('preferences', {}), ensure_ascii=False, indent=2)}\n"
|
||||
"```\n\n"
|
||||
"## Facts\n"
|
||||
"```json\n"
|
||||
f"{json.dumps(data.get('facts', []), ensure_ascii=False, indent=2)}\n"
|
||||
"```"
|
||||
)
|
||||
|
||||
|
||||
def _memory_from_markdown(markdown: str) -> dict[str, Any]:
|
||||
parsed = create_empty_thread_memory()
|
||||
owner_id: str | None = None
|
||||
owner_line = next((line for line in markdown.splitlines() if line.startswith("Owner ID: ")), None)
|
||||
if owner_line:
|
||||
owner_raw = owner_line.split("Owner ID: ", 1)[1].strip()
|
||||
owner_id = None if owner_raw == "null" else owner_raw
|
||||
|
||||
blocks = _JSON_FENCE_RE.findall(markdown)
|
||||
if len(blocks) >= 1:
|
||||
try:
|
||||
profile = json.loads(blocks[0])
|
||||
if isinstance(profile, dict):
|
||||
parsed["profile"] = profile
|
||||
except Exception:
|
||||
pass
|
||||
if len(blocks) >= 2:
|
||||
try:
|
||||
preferences = json.loads(blocks[1])
|
||||
if isinstance(preferences, dict):
|
||||
parsed["preferences"] = preferences
|
||||
except Exception:
|
||||
pass
|
||||
if len(blocks) >= 3:
|
||||
try:
|
||||
facts = json.loads(blocks[2])
|
||||
if isinstance(facts, list):
|
||||
parsed["facts"] = facts
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"ownerId": owner_id, **parsed}
|
||||
|
||||
|
||||
def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]:
|
||||
memory_md = row[2]
|
||||
if isinstance(memory_md, str) and memory_md.strip():
|
||||
decoded = _memory_from_markdown(memory_md)
|
||||
profile = decoded.get("profile", {})
|
||||
preferences = decoded.get("preferences", {})
|
||||
facts = decoded.get("facts", [])
|
||||
owner_id = decoded.get("ownerId")
|
||||
else:
|
||||
owner_id = row[1]
|
||||
profile = json.loads(row[3])
|
||||
preferences = json.loads(row[4])
|
||||
facts = json.loads(row[5])
|
||||
return {
|
||||
"threadId": row[0],
|
||||
"ownerId": row[1],
|
||||
"profile": json.loads(row[2]),
|
||||
"preferences": json.loads(row[3]),
|
||||
"facts": json.loads(row[4]),
|
||||
"memoryVersion": int(row[5]),
|
||||
"lastUpdated": str(row[6]),
|
||||
"ownerId": owner_id,
|
||||
"profile": profile,
|
||||
"preferences": preferences,
|
||||
"facts": facts,
|
||||
"memoryVersion": int(row[6]),
|
||||
"lastUpdated": str(row[7]),
|
||||
}
|
||||
|
||||
|
||||
@ -58,6 +127,7 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
||||
CREATE TABLE IF NOT EXISTS thread_memory (
|
||||
thread_id TEXT PRIMARY KEY,
|
||||
owner_id TEXT NULL,
|
||||
memory_md TEXT NOT NULL DEFAULT '',
|
||||
profile TEXT NOT NULL DEFAULT '{}',
|
||||
preferences TEXT NOT NULL DEFAULT '{}',
|
||||
facts TEXT NOT NULL DEFAULT '[]',
|
||||
@ -66,6 +136,9 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
||||
)
|
||||
"""
|
||||
)
|
||||
columns = {r[1] for r in self._conn.execute("PRAGMA table_info(thread_memory)").fetchall()}
|
||||
if "memory_md" not in columns:
|
||||
self._conn.execute("ALTER TABLE thread_memory ADD COLUMN memory_md TEXT NOT NULL DEFAULT ''")
|
||||
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)")
|
||||
self._conn.commit()
|
||||
|
||||
@ -76,7 +149,30 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
||||
"FROM thread_memory WHERE thread_id = ?",
|
||||
(thread_id,),
|
||||
).fetchone()
|
||||
return _row_to_memory(row) if row else None
|
||||
if row is None:
|
||||
return None
|
||||
row = (
|
||||
row[0],
|
||||
row[1],
|
||||
"",
|
||||
row[2],
|
||||
row[3],
|
||||
row[4],
|
||||
row[5],
|
||||
row[6],
|
||||
)
|
||||
try:
|
||||
row2 = self._conn.execute(
|
||||
"SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated "
|
||||
"FROM thread_memory WHERE thread_id = ?",
|
||||
(thread_id,),
|
||||
).fetchone()
|
||||
if row2 is not None:
|
||||
row = row2
|
||||
except sqlite3.OperationalError:
|
||||
# Backward compatibility when running before migration.
|
||||
pass
|
||||
return _row_to_memory(row)
|
||||
|
||||
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
|
||||
now = datetime.utcnow().isoformat() + "Z"
|
||||
@ -86,13 +182,14 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
||||
with self._lock:
|
||||
cur = self._conn.execute(
|
||||
"""
|
||||
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version, last_updated)
|
||||
VALUES (?, ?, ?, ?, ?, 0, ?)
|
||||
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 0, ?)
|
||||
ON CONFLICT(thread_id) DO NOTHING
|
||||
""",
|
||||
(
|
||||
thread_id,
|
||||
owner_id,
|
||||
_memory_to_markdown(data),
|
||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||
@ -106,11 +203,12 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
||||
cur = self._conn.execute(
|
||||
"""
|
||||
UPDATE thread_memory
|
||||
SET owner_id = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ?
|
||||
SET owner_id = ?, memory_md = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ?
|
||||
WHERE thread_id = ? AND memory_version = ?
|
||||
""",
|
||||
(
|
||||
owner_id,
|
||||
_memory_to_markdown(data),
|
||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||
@ -140,6 +238,7 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
||||
CREATE TABLE IF NOT EXISTS thread_memory (
|
||||
thread_id VARCHAR(64) PRIMARY KEY,
|
||||
owner_id VARCHAR(64) NULL,
|
||||
memory_md LONGTEXT NOT NULL,
|
||||
profile JSON NOT NULL,
|
||||
preferences JSON NOT NULL,
|
||||
facts JSON NOT NULL,
|
||||
@ -149,12 +248,22 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'thread_memory' AND COLUMN_NAME = 'memory_md'
|
||||
"""
|
||||
)
|
||||
has_memory_md = cur.fetchone()[0] > 0
|
||||
if not has_memory_md:
|
||||
cur.execute("ALTER TABLE thread_memory ADD COLUMN memory_md LONGTEXT NOT NULL")
|
||||
self._conn.commit()
|
||||
|
||||
def load(self, thread_id: str) -> dict[str, Any] | None:
|
||||
with self._conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT thread_id, owner_id, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
|
||||
"SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
|
||||
(thread_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
@ -167,13 +276,14 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
||||
with self._conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version)
|
||||
VALUES (%s, %s, %s, %s, %s, 0)
|
||||
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, 0)
|
||||
ON DUPLICATE KEY UPDATE thread_id = thread_id
|
||||
""",
|
||||
(
|
||||
thread_id,
|
||||
owner_id,
|
||||
_memory_to_markdown(data),
|
||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||
@ -185,11 +295,12 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE thread_memory
|
||||
SET owner_id = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1
|
||||
SET owner_id = %s, memory_md = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1
|
||||
WHERE thread_id = %s AND memory_version = %s
|
||||
""",
|
||||
(
|
||||
owner_id,
|
||||
_memory_to_markdown(data),
|
||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||
|
||||
@ -27,3 +27,45 @@ def test_sqlite_thread_memory_compare_and_swap(tmp_path):
|
||||
assert loaded2 is not None
|
||||
assert loaded2["memoryVersion"] == 1
|
||||
|
||||
|
||||
def test_sqlite_thread_memory_saves_markdown_payload(tmp_path):
|
||||
db_path = tmp_path / "thread-memory.db"
|
||||
storage = SqliteThreadMemoryStorage(str(db_path))
|
||||
thread_id = "thread-md"
|
||||
|
||||
assert storage.save(thread_id, _payload(), expected_version=0) is True
|
||||
|
||||
with storage._lock:
|
||||
row = storage._conn.execute("SELECT memory_md FROM thread_memory WHERE thread_id = ?", (thread_id,)).fetchone()
|
||||
assert row is not None
|
||||
assert isinstance(row[0], str)
|
||||
assert "## Profile" in row[0]
|
||||
assert "## Preferences" in row[0]
|
||||
assert "## Facts" in row[0]
|
||||
|
||||
|
||||
def test_sqlite_thread_memory_loads_legacy_json_row(tmp_path):
|
||||
db_path = tmp_path / "legacy-thread-memory.db"
|
||||
storage = SqliteThreadMemoryStorage(str(db_path))
|
||||
thread_id = "thread-legacy"
|
||||
|
||||
with storage._lock:
|
||||
storage._conn.execute(
|
||||
"""
|
||||
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
|
||||
VALUES (?, ?, '', ?, ?, ?, 0, datetime('now'))
|
||||
""",
|
||||
(
|
||||
thread_id,
|
||||
"owner-1",
|
||||
'{"name":"Alice","role":null,"expertise":[],"language":null,"context":null}',
|
||||
'{"tone":null,"verbosity":null,"codeStyle":null,"other":null}',
|
||||
"[]",
|
||||
),
|
||||
)
|
||||
storage._conn.commit()
|
||||
|
||||
loaded = storage.load(thread_id)
|
||||
assert loaded is not None
|
||||
assert loaded["ownerId"] == "owner-1"
|
||||
assert loaded["profile"]["name"] == "Alice"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user