feat: 增加MD列
This commit is contained in:
parent
b49e838980
commit
7db468aa6f
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -16,6 +17,7 @@ from deerflow.config.paths import get_paths
|
|||||||
from deerflow.config.thread_memory_config import get_thread_memory_config
|
from deerflow.config.thread_memory_config import get_thread_memory_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
_JSON_FENCE_RE = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
class ThreadMemoryStorage(abc.ABC):
|
class ThreadMemoryStorage(abc.ABC):
|
||||||
@ -32,15 +34,82 @@ class ThreadMemoryStorage(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_to_markdown(data: dict[str, Any]) -> str:
|
||||||
|
owner = data.get("ownerId")
|
||||||
|
owner_text = "null" if owner is None else str(owner)
|
||||||
|
return (
|
||||||
|
"# Thread Memory\n\n"
|
||||||
|
f"Owner ID: {owner_text}\n\n"
|
||||||
|
"## Profile\n"
|
||||||
|
"```json\n"
|
||||||
|
f"{json.dumps(data.get('profile', {}), ensure_ascii=False, indent=2)}\n"
|
||||||
|
"```\n\n"
|
||||||
|
"## Preferences\n"
|
||||||
|
"```json\n"
|
||||||
|
f"{json.dumps(data.get('preferences', {}), ensure_ascii=False, indent=2)}\n"
|
||||||
|
"```\n\n"
|
||||||
|
"## Facts\n"
|
||||||
|
"```json\n"
|
||||||
|
f"{json.dumps(data.get('facts', []), ensure_ascii=False, indent=2)}\n"
|
||||||
|
"```"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_from_markdown(markdown: str) -> dict[str, Any]:
|
||||||
|
parsed = create_empty_thread_memory()
|
||||||
|
owner_id: str | None = None
|
||||||
|
owner_line = next((line for line in markdown.splitlines() if line.startswith("Owner ID: ")), None)
|
||||||
|
if owner_line:
|
||||||
|
owner_raw = owner_line.split("Owner ID: ", 1)[1].strip()
|
||||||
|
owner_id = None if owner_raw == "null" else owner_raw
|
||||||
|
|
||||||
|
blocks = _JSON_FENCE_RE.findall(markdown)
|
||||||
|
if len(blocks) >= 1:
|
||||||
|
try:
|
||||||
|
profile = json.loads(blocks[0])
|
||||||
|
if isinstance(profile, dict):
|
||||||
|
parsed["profile"] = profile
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if len(blocks) >= 2:
|
||||||
|
try:
|
||||||
|
preferences = json.loads(blocks[1])
|
||||||
|
if isinstance(preferences, dict):
|
||||||
|
parsed["preferences"] = preferences
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if len(blocks) >= 3:
|
||||||
|
try:
|
||||||
|
facts = json.loads(blocks[2])
|
||||||
|
if isinstance(facts, list):
|
||||||
|
parsed["facts"] = facts
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {"ownerId": owner_id, **parsed}
|
||||||
|
|
||||||
|
|
||||||
def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]:
|
def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]:
|
||||||
|
memory_md = row[2]
|
||||||
|
if isinstance(memory_md, str) and memory_md.strip():
|
||||||
|
decoded = _memory_from_markdown(memory_md)
|
||||||
|
profile = decoded.get("profile", {})
|
||||||
|
preferences = decoded.get("preferences", {})
|
||||||
|
facts = decoded.get("facts", [])
|
||||||
|
owner_id = decoded.get("ownerId")
|
||||||
|
else:
|
||||||
|
owner_id = row[1]
|
||||||
|
profile = json.loads(row[3])
|
||||||
|
preferences = json.loads(row[4])
|
||||||
|
facts = json.loads(row[5])
|
||||||
return {
|
return {
|
||||||
"threadId": row[0],
|
"threadId": row[0],
|
||||||
"ownerId": row[1],
|
"ownerId": owner_id,
|
||||||
"profile": json.loads(row[2]),
|
"profile": profile,
|
||||||
"preferences": json.loads(row[3]),
|
"preferences": preferences,
|
||||||
"facts": json.loads(row[4]),
|
"facts": facts,
|
||||||
"memoryVersion": int(row[5]),
|
"memoryVersion": int(row[6]),
|
||||||
"lastUpdated": str(row[6]),
|
"lastUpdated": str(row[7]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -58,6 +127,7 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
CREATE TABLE IF NOT EXISTS thread_memory (
|
CREATE TABLE IF NOT EXISTS thread_memory (
|
||||||
thread_id TEXT PRIMARY KEY,
|
thread_id TEXT PRIMARY KEY,
|
||||||
owner_id TEXT NULL,
|
owner_id TEXT NULL,
|
||||||
|
memory_md TEXT NOT NULL DEFAULT '',
|
||||||
profile TEXT NOT NULL DEFAULT '{}',
|
profile TEXT NOT NULL DEFAULT '{}',
|
||||||
preferences TEXT NOT NULL DEFAULT '{}',
|
preferences TEXT NOT NULL DEFAULT '{}',
|
||||||
facts TEXT NOT NULL DEFAULT '[]',
|
facts TEXT NOT NULL DEFAULT '[]',
|
||||||
@ -66,6 +136,9 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
columns = {r[1] for r in self._conn.execute("PRAGMA table_info(thread_memory)").fetchall()}
|
||||||
|
if "memory_md" not in columns:
|
||||||
|
self._conn.execute("ALTER TABLE thread_memory ADD COLUMN memory_md TEXT NOT NULL DEFAULT ''")
|
||||||
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)")
|
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)")
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
|
||||||
@ -76,7 +149,30 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
"FROM thread_memory WHERE thread_id = ?",
|
"FROM thread_memory WHERE thread_id = ?",
|
||||||
(thread_id,),
|
(thread_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
return _row_to_memory(row) if row else None
|
if row is None:
|
||||||
|
return None
|
||||||
|
row = (
|
||||||
|
row[0],
|
||||||
|
row[1],
|
||||||
|
"",
|
||||||
|
row[2],
|
||||||
|
row[3],
|
||||||
|
row[4],
|
||||||
|
row[5],
|
||||||
|
row[6],
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
row2 = self._conn.execute(
|
||||||
|
"SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated "
|
||||||
|
"FROM thread_memory WHERE thread_id = ?",
|
||||||
|
(thread_id,),
|
||||||
|
).fetchone()
|
||||||
|
if row2 is not None:
|
||||||
|
row = row2
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
# Backward compatibility when running before migration.
|
||||||
|
pass
|
||||||
|
return _row_to_memory(row)
|
||||||
|
|
||||||
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
|
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
|
||||||
now = datetime.utcnow().isoformat() + "Z"
|
now = datetime.utcnow().isoformat() + "Z"
|
||||||
@ -86,13 +182,14 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
cur = self._conn.execute(
|
cur = self._conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version, last_updated)
|
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
|
||||||
VALUES (?, ?, ?, ?, ?, 0, ?)
|
VALUES (?, ?, ?, ?, ?, ?, 0, ?)
|
||||||
ON CONFLICT(thread_id) DO NOTHING
|
ON CONFLICT(thread_id) DO NOTHING
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
thread_id,
|
thread_id,
|
||||||
owner_id,
|
owner_id,
|
||||||
|
_memory_to_markdown(data),
|
||||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||||
@ -106,11 +203,12 @@ class SqliteThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
cur = self._conn.execute(
|
cur = self._conn.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE thread_memory
|
UPDATE thread_memory
|
||||||
SET owner_id = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ?
|
SET owner_id = ?, memory_md = ?, profile = ?, preferences = ?, facts = ?, memory_version = memory_version + 1, last_updated = ?
|
||||||
WHERE thread_id = ? AND memory_version = ?
|
WHERE thread_id = ? AND memory_version = ?
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
owner_id,
|
owner_id,
|
||||||
|
_memory_to_markdown(data),
|
||||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||||
@ -140,6 +238,7 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
CREATE TABLE IF NOT EXISTS thread_memory (
|
CREATE TABLE IF NOT EXISTS thread_memory (
|
||||||
thread_id VARCHAR(64) PRIMARY KEY,
|
thread_id VARCHAR(64) PRIMARY KEY,
|
||||||
owner_id VARCHAR(64) NULL,
|
owner_id VARCHAR(64) NULL,
|
||||||
|
memory_md LONGTEXT NOT NULL,
|
||||||
profile JSON NOT NULL,
|
profile JSON NOT NULL,
|
||||||
preferences JSON NOT NULL,
|
preferences JSON NOT NULL,
|
||||||
facts JSON NOT NULL,
|
facts JSON NOT NULL,
|
||||||
@ -149,12 +248,22 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM information_schema.COLUMNS
|
||||||
|
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'thread_memory' AND COLUMN_NAME = 'memory_md'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
has_memory_md = cur.fetchone()[0] > 0
|
||||||
|
if not has_memory_md:
|
||||||
|
cur.execute("ALTER TABLE thread_memory ADD COLUMN memory_md LONGTEXT NOT NULL")
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
|
||||||
def load(self, thread_id: str) -> dict[str, Any] | None:
|
def load(self, thread_id: str) -> dict[str, Any] | None:
|
||||||
with self._conn.cursor() as cur:
|
with self._conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT thread_id, owner_id, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
|
"SELECT thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
|
||||||
(thread_id,),
|
(thread_id,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
@ -167,13 +276,14 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
with self._conn.cursor() as cur:
|
with self._conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO thread_memory (thread_id, owner_id, profile, preferences, facts, memory_version)
|
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version)
|
||||||
VALUES (%s, %s, %s, %s, %s, 0)
|
VALUES (%s, %s, %s, %s, %s, %s, 0)
|
||||||
ON DUPLICATE KEY UPDATE thread_id = thread_id
|
ON DUPLICATE KEY UPDATE thread_id = thread_id
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
thread_id,
|
thread_id,
|
||||||
owner_id,
|
owner_id,
|
||||||
|
_memory_to_markdown(data),
|
||||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||||
@ -185,11 +295,12 @@ class MysqlThreadMemoryStorage(ThreadMemoryStorage):
|
|||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE thread_memory
|
UPDATE thread_memory
|
||||||
SET owner_id = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1
|
SET owner_id = %s, memory_md = %s, profile = %s, preferences = %s, facts = %s, memory_version = memory_version + 1
|
||||||
WHERE thread_id = %s AND memory_version = %s
|
WHERE thread_id = %s AND memory_version = %s
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
owner_id,
|
owner_id,
|
||||||
|
_memory_to_markdown(data),
|
||||||
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
json.dumps(data.get("profile", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
json.dumps(data.get("preferences", {}), ensure_ascii=False),
|
||||||
json.dumps(data.get("facts", []), ensure_ascii=False),
|
json.dumps(data.get("facts", []), ensure_ascii=False),
|
||||||
|
|||||||
@ -27,3 +27,45 @@ def test_sqlite_thread_memory_compare_and_swap(tmp_path):
|
|||||||
assert loaded2 is not None
|
assert loaded2 is not None
|
||||||
assert loaded2["memoryVersion"] == 1
|
assert loaded2["memoryVersion"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite_thread_memory_saves_markdown_payload(tmp_path):
|
||||||
|
db_path = tmp_path / "thread-memory.db"
|
||||||
|
storage = SqliteThreadMemoryStorage(str(db_path))
|
||||||
|
thread_id = "thread-md"
|
||||||
|
|
||||||
|
assert storage.save(thread_id, _payload(), expected_version=0) is True
|
||||||
|
|
||||||
|
with storage._lock:
|
||||||
|
row = storage._conn.execute("SELECT memory_md FROM thread_memory WHERE thread_id = ?", (thread_id,)).fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert isinstance(row[0], str)
|
||||||
|
assert "## Profile" in row[0]
|
||||||
|
assert "## Preferences" in row[0]
|
||||||
|
assert "## Facts" in row[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite_thread_memory_loads_legacy_json_row(tmp_path):
|
||||||
|
db_path = tmp_path / "legacy-thread-memory.db"
|
||||||
|
storage = SqliteThreadMemoryStorage(str(db_path))
|
||||||
|
thread_id = "thread-legacy"
|
||||||
|
|
||||||
|
with storage._lock:
|
||||||
|
storage._conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO thread_memory (thread_id, owner_id, memory_md, profile, preferences, facts, memory_version, last_updated)
|
||||||
|
VALUES (?, ?, '', ?, ?, ?, 0, datetime('now'))
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
thread_id,
|
||||||
|
"owner-1",
|
||||||
|
'{"name":"Alice","role":null,"expertise":[],"language":null,"context":null}',
|
||||||
|
'{"tone":null,"verbosity":null,"codeStyle":null,"other":null}',
|
||||||
|
"[]",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
storage._conn.commit()
|
||||||
|
|
||||||
|
loaded = storage.load(thread_id)
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded["ownerId"] == "owner-1"
|
||||||
|
assert loaded["profile"]["name"] == "Alice"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user