feat: implement shallow SQLite checkpoint savers and update configuration for persistence modes
This commit is contained in:
parent
f209057b18
commit
f3558d6bb2
|
|
@ -7,3 +7,14 @@ __all__ = [
|
||||||
"checkpointer_context",
|
"checkpointer_context",
|
||||||
"make_checkpointer",
|
"make_checkpointer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Lazy-import shallow savers so the module is still importable without
# langgraph-checkpoint-sqlite installed.
def __getattr__(name: str):
    """PEP 562 module-level attribute hook.

    Resolves the shallow saver classes on first access by calling the
    deferred factories in ``shallow_sqlite``, so importing this package
    never requires ``langgraph-checkpoint-sqlite`` up front.
    """
    if name in ("AsyncShallowSqliteSaver", "ShallowSqliteSaver"):
        from . import shallow_sqlite

        factory = (
            shallow_sqlite._make_async_shallow_saver
            if name == "AsyncShallowSqliteSaver"
            else shallow_sqlite._make_sync_shallow_saver
        )
        return factory()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,18 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||||
|
|
||||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||||
ensure_sqlite_parent_dir(conn_str)
|
ensure_sqlite_parent_dir(conn_str)
|
||||||
|
|
||||||
|
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
|
||||||
|
if getattr(config, "sqlite_mode", "full") == "shallow":
|
||||||
|
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
|
||||||
|
|
||||||
|
ShallowSaver = _make_async_shallow_saver()
|
||||||
|
async with ShallowSaver.from_conn_string(conn_str) as saver:
|
||||||
|
await saver.setup()
|
||||||
|
logger.info("Checkpointer: using AsyncShallowSqliteSaver (%s)", conn_str)
|
||||||
|
yield saver
|
||||||
|
return
|
||||||
|
|
||||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||||
await saver.setup()
|
await saver.setup()
|
||||||
yield saver
|
yield saver
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,18 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||||
raise ImportError(SQLITE_INSTALL) from exc
|
raise ImportError(SQLITE_INSTALL) from exc
|
||||||
|
|
||||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||||
|
|
||||||
|
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
|
||||||
|
if getattr(config, "sqlite_mode", "full") == "shallow":
|
||||||
|
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
|
||||||
|
|
||||||
|
ShallowSaver = _make_sync_shallow_saver()
|
||||||
|
with ShallowSaver.from_conn_string(conn_str) as saver:
|
||||||
|
saver.setup()
|
||||||
|
logger.info("Checkpointer: using ShallowSqliteSaver (%s)", conn_str)
|
||||||
|
yield saver
|
||||||
|
return
|
||||||
|
|
||||||
with SqliteSaver.from_conn_string(conn_str) as saver:
|
with SqliteSaver.from_conn_string(conn_str) as saver:
|
||||||
saver.setup()
|
saver.setup()
|
||||||
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,128 @@
|
||||||
|
"""Shallow persistence savers for LangGraph SQLite checkpointing.
|
||||||
|
|
||||||
|
Provides shallow (single-checkpoint-per-thread) variants of the LangGraph
|
||||||
|
SQLite savers that automatically delete old checkpoints and writes for the
|
||||||
|
same thread before each write, keeping only the latest state.
|
||||||
|
|
||||||
|
This prevents unbounded growth of ``checkpoints.db`` while preserving
|
||||||
|
multi-turn conversation continuity.
|
||||||
|
|
||||||
|
Implements:
|
||||||
|
- ``AsyncShallowSqliteSaver`` — async shallow variant
|
||||||
|
- ``ShallowSqliteSaver`` — sync shallow variant
|
||||||
|
|
||||||
|
Usage is transparent through the existing checkpointer factory when
|
||||||
|
``sqlite_mode: shallow`` is set in ``config.yaml``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.checkpoint.base import ChannelVersions, Checkpoint, CheckpointMetadata
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Async shallow saver
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncShallowSqliteSaver:
    """Placeholder for the async shallow saver built by ``_make_async_shallow_saver``.

    This module-level class is only a bare stub: it does NOT inherit from
    ``AsyncSqliteSaver`` and implements no checkpointing itself. The real
    class is constructed on demand by :func:`_make_async_shallow_saver`,
    which subclasses ``langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver``
    and overrides ``aput`` to delete all existing checkpoints and writes
    for the same ``thread_id`` before inserting the new one. The import is
    deferred so this module stays importable without
    ``langgraph-checkpoint-sqlite``.

    NOTE(review): the factory-built class shadows this name for consumers
    going through the package ``__getattr__``; confirm nothing instantiates
    this stub directly before relying on it.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Allow extension by subclasses without forcing a late-bound import
        # here. The concrete class is built via the _make_async_shallow_saver
        # factory below.
        super().__init_subclass__(**kwargs)
||||||
|
|
||||||
|
def _make_async_shallow_saver() -> type:
    """Build and return the ``AsyncShallowSqliteSaver`` class.

    The ``langgraph-checkpoint-sqlite`` import happens inside this factory
    so the module can be imported even when that package is absent.

    Returns:
        The concrete async shallow saver class (an ``AsyncSqliteSaver``
        subclass that keeps only the latest checkpoint per thread).
    """
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    class AsyncShallowSqliteSaver(AsyncSqliteSaver):
        async def aput(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            """Purge this thread's stored history, then persist the new checkpoint."""
            tid = str(config["configurable"]["thread_id"])
            await self.setup()

            # Drop every stored checkpoint and pending write for this thread
            # first, so exactly one checkpoint survives per thread.
            # NOTE(review): the delete keys on thread_id only, not
            # checkpoint_ns — confirm sibling namespaces should be purged too.
            async with self.lock:
                for table in ("checkpoints", "writes"):
                    await self.conn.execute(
                        f"DELETE FROM {table} WHERE thread_id = ?",  # noqa: S608 — fixed table names
                        (tid,),
                    )
                await self.conn.commit()

            return await super().aput(config, checkpoint, metadata, new_versions)

    return AsyncShallowSqliteSaver
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sync shallow saver
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sync_shallow_saver() -> type:
    """Build and return the ``ShallowSqliteSaver`` class.

    The ``langgraph-checkpoint-sqlite`` import happens inside this factory
    so the module can be imported even when that package is absent.

    Returns:
        The concrete sync shallow saver class (a ``SqliteSaver`` subclass
        that keeps only the latest checkpoint per thread).
    """
    from langgraph.checkpoint.sqlite import SqliteSaver

    class ShallowSqliteSaver(SqliteSaver):
        def put(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            """Purge this thread's stored history, then persist the new checkpoint."""
            tid = str(config["configurable"]["thread_id"])

            # Remove every prior checkpoint and write for this thread so only
            # the latest state remains (prevents checkpoints.db bloat).
            with self.cursor() as cur:
                for table in ("checkpoints", "writes"):
                    cur.execute(
                        f"DELETE FROM {table} WHERE thread_id = ?",  # noqa: S608 — fixed table names
                        (tid,),
                    )

            return super().put(config, checkpoint, metadata, new_versions)

    return ShallowSqliteSaver
|
||||||
|
|
@ -6,6 +6,14 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
CheckpointerType = Literal["memory", "sqlite", "postgres"]
|
CheckpointerType = Literal["memory", "sqlite", "postgres"]
|
||||||
|
|
||||||
|
SqliteMode = Literal["full", "shallow"]
|
||||||
|
"""Persistence mode for the SQLite checkpointer.
|
||||||
|
|
||||||
|
- ``full`` — retain all checkpoint history (default, backward-compatible).
|
||||||
|
- ``shallow`` — keep only the latest checkpoint per thread, deleting old
|
||||||
|
records before each write to prevent unbounded database growth.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class CheckpointerConfig(BaseModel):
|
class CheckpointerConfig(BaseModel):
|
||||||
"""Configuration for LangGraph state persistence checkpointer."""
|
"""Configuration for LangGraph state persistence checkpointer."""
|
||||||
|
|
@ -23,6 +31,13 @@ class CheckpointerConfig(BaseModel):
|
||||||
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
|
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
|
||||||
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
|
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
|
||||||
)
|
)
|
||||||
|
sqlite_mode: SqliteMode = Field(
|
||||||
|
default="full",
|
||||||
|
description="SQLite persistence mode. "
|
||||||
|
"'full' retains all checkpoint history (default). "
|
||||||
|
"'shallow' keeps only the latest checkpoint per thread, "
|
||||||
|
"deleting old records before each write.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Global configuration instance — None means no checkpointer is configured.
|
# Global configuration instance — None means no checkpointer is configured.
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
|
||||||
default_estimated_output_tokens=None,
|
default_estimated_output_tokens=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
|
model_cfg = SimpleNamespace(display_name="GPT-4", model="gpt-4", model_extra={"max_tokens": 4096})
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
billing=billing,
|
billing=billing,
|
||||||
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
|
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,305 @@
|
||||||
|
"""Tests for shallow SQLite checkpoint savers (single-checkpoint-per-thread mode).
|
||||||
|
|
||||||
|
Uses in-memory SQLite (``:memory:``) — no filesystem dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AsyncShallowSqliteSaver tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncShallowSqliteSaver:
    """Tests for ``AsyncShallowSqliteSaver`` — async shallow persistence."""

    @pytest.mark.anyio
    async def test_aput_deletes_old_checkpoints_before_insert(self):
        """After two aput calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        saver_cls = _make_async_shallow_saver()

        async with saver_cls.from_conn_string(":memory:") as saver:
            await saver.setup()

            cfg = {"configurable": {"thread_id": "test-thread-1", "checkpoint_ns": ""}}
            first = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-1", "channel_values": {"x": 1}}
            second = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-2", "channel_values": {"x": 2}}

            await saver.aput(cfg, first, {"source": "input", "step": 1, "writes": {}}, {})
            # The second write must evict the first checkpoint.
            await saver.aput(cfg, second, {"source": "loop", "step": 2, "writes": {}}, {})

            stored = [c async for c in saver.alist(cfg)]
            assert len(stored) == 1, f"Expected 1 checkpoint, got {len(stored)}"
            assert stored[0].config["configurable"]["checkpoint_id"] == "ckpt-2"

    @pytest.mark.anyio
    async def test_different_threads_do_not_interfere(self):
        """Checkpoints for different thread_ids are independent."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        saver_cls = _make_async_shallow_saver()

        async with saver_cls.from_conn_string(":memory:") as saver:
            await saver.setup()

            cfg_a = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            cfg_b = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}

            await saver.aput(cfg_a, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(cfg_b, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(cfg_a, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})

            # Thread A kept only its newest checkpoint.
            kept_a = [c async for c in saver.alist(cfg_a)]
            assert len(kept_a) == 1
            assert kept_a[0].config["configurable"]["checkpoint_id"] == "a2"

            # Thread B was untouched by thread A's writes.
            kept_b = [c async for c in saver.alist(cfg_b)]
            assert len(kept_b) == 1
            assert kept_b[0].config["configurable"]["checkpoint_id"] == "b1"

    @pytest.mark.anyio
    async def test_writes_table_also_cleaned(self):
        """aput_writes entries from old checkpoints are also deleted."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        saver_cls = _make_async_shallow_saver()

        async with saver_cls.from_conn_string(":memory:") as saver:
            await saver.setup()

            cfg = {"configurable": {"thread_id": "test-writes", "checkpoint_ns": ""}}

            # Checkpoint 1 plus its pending writes.
            first_cfg = await saver.aput(
                cfg,
                {"ts": "Z", "id": "ckpt-w1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            await saver.aput_writes(first_cfg, [("messages", "hello")], "task-1", "")

            # Checkpoint 2 must evict ckpt-w1 and its writes.
            second_cfg = await saver.aput(
                cfg,
                {"ts": "Z", "id": "ckpt-w2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            await saver.aput_writes(second_cfg, [("messages", "world")], "task-2", "")

            stored = [c async for c in saver.alist(cfg)]
            assert len(stored) == 1
            assert stored[0].config["configurable"]["checkpoint_id"] == "ckpt-w2"

            # The surviving checkpoint (with its writes) is still readable.
            latest = await saver.aget_tuple(second_cfg)
            assert latest is not None

    @pytest.mark.anyio
    async def test_full_mode_retains_all_checkpoints(self):
        """In full mode (default), all checkpoints are preserved."""
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

        async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            cfg = {"configurable": {"thread_id": "test-full", "checkpoint_ns": ""}}

            await saver.aput(cfg, {"ts": "Z", "id": "f1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(cfg, {"ts": "Z", "id": "f2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})

            stored = [c async for c in saver.alist(cfg)]
            assert len(stored) == 2, "Full mode should retain all checkpoints"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ShallowSqliteSaver (sync) tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestShallowSqliteSaver:
    """Tests for ``ShallowSqliteSaver`` — sync shallow persistence."""

    def test_put_deletes_old_checkpoints_before_insert(self):
        """After two put calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        saver_cls = _make_sync_shallow_saver()

        with saver_cls.from_conn_string(":memory:") as saver:
            saver.setup()

            cfg = {"configurable": {"thread_id": "test-sync-1", "checkpoint_ns": ""}}
            first = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-s1", "channel_values": {"x": 1}}
            second = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-s2", "channel_values": {"x": 2}}

            saver.put(cfg, first, {"source": "input", "step": 1, "writes": {}}, {})
            # The second write must evict the first checkpoint.
            saver.put(cfg, second, {"source": "loop", "step": 2, "writes": {}}, {})

            stored = list(saver.list(cfg))
            assert len(stored) == 1, f"Expected 1 checkpoint, got {len(stored)}"
            assert stored[0].config["configurable"]["checkpoint_id"] == "ckpt-s2"

    def test_different_threads_do_not_interfere_sync(self):
        """Checkpoints for different thread_ids are independent (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        saver_cls = _make_sync_shallow_saver()

        with saver_cls.from_conn_string(":memory:") as saver:
            saver.setup()

            cfg_a = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            cfg_b = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}

            saver.put(cfg_a, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(cfg_b, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(cfg_a, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})

            # Thread A kept only its newest checkpoint; thread B untouched.
            kept_a = list(saver.list(cfg_a))
            assert len(kept_a) == 1
            assert kept_a[0].config["configurable"]["checkpoint_id"] == "a2"

            kept_b = list(saver.list(cfg_b))
            assert len(kept_b) == 1
            assert kept_b[0].config["configurable"]["checkpoint_id"] == "b1"

    def test_writes_table_also_cleaned_sync(self):
        """put_writes entries from old checkpoints are also deleted (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        saver_cls = _make_sync_shallow_saver()

        with saver_cls.from_conn_string(":memory:") as saver:
            saver.setup()

            cfg = {"configurable": {"thread_id": "test-sync-writes", "checkpoint_ns": ""}}

            first_cfg = saver.put(
                cfg,
                {"ts": "Z", "id": "ckpt-sw1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            saver.put_writes(first_cfg, [("messages", "hello")], "task-1", "")

            # Checkpoint 2 must evict ckpt-sw1 and its writes.
            second_cfg = saver.put(
                cfg,
                {"ts": "Z", "id": "ckpt-sw2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            saver.put_writes(second_cfg, [("messages", "world")], "task-2", "")

            stored = list(saver.list(cfg))
            assert len(stored) == 1
            assert stored[0].config["configurable"]["checkpoint_id"] == "ckpt-sw2"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestShallowConfig:
    """Tests for configuration and factory integration."""

    def test_sqlite_mode_defaults_to_full(self):
        """sqlite_mode defaults to 'full' when not specified."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        cfg = CheckpointerConfig(type="sqlite", connection_string="test.db")
        assert cfg.sqlite_mode == "full"

    def test_sqlite_mode_shallow_accepted(self):
        """sqlite_mode can be set to 'shallow'."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        cfg = CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="shallow")
        assert cfg.sqlite_mode == "shallow"

    def test_load_sqlite_config_with_shallow_mode(self):
        """load_checkpointer_config_from_dict accepts sqlite_mode."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )

        # Reset the global config before loading from a raw dict.
        set_checkpointer_config(None)
        raw = {
            "type": "sqlite",
            "connection_string": "/tmp/test.db",
            "sqlite_mode": "shallow",
        }
        load_checkpointer_config_from_dict(raw)

        loaded = get_checkpointer_config()
        assert loaded is not None
        assert loaded.sqlite_mode == "shallow"

    def test_load_sqlite_config_defaults_sqlite_mode(self):
        """load_checkpointer_config_from_dict defaults sqlite_mode to 'full' when omitted."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )

        set_checkpointer_config(None)
        load_checkpointer_config_from_dict({
            "type": "sqlite",
            "connection_string": "/tmp/test.db",
        })

        loaded = get_checkpointer_config()
        assert loaded is not None
        assert loaded.sqlite_mode == "full"

    @pytest.mark.anyio
    async def test_async_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, async factory returns AsyncShallowSqliteSaver."""
        from deerflow.agents.checkpointer.async_provider import _async_checkpointer
        from deerflow.config.checkpointer_config import CheckpointerConfig, set_checkpointer_config

        shallow_cfg = CheckpointerConfig(
            type="sqlite",
            connection_string=":memory:",
            sqlite_mode="shallow",
        )
        set_checkpointer_config(shallow_cfg)

        async with _async_checkpointer(shallow_cfg) as saver:
            # The factory must hand back the shallow variant.
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

        # Restore global state for other tests.
        set_checkpointer_config(None)

    def test_sync_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, sync factory returns ShallowSqliteSaver."""
        from deerflow.agents.checkpointer.provider import _sync_checkpointer_cm
        from deerflow.config.checkpointer_config import CheckpointerConfig

        shallow_cfg = CheckpointerConfig(
            type="sqlite", connection_string=":memory:", sqlite_mode="shallow"
        )

        with _sync_checkpointer_cm(shallow_cfg) as saver:
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

    def test_invalid_sqlite_mode_raises(self):
        """Invalid sqlite_mode value raises validation error."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        with pytest.raises(Exception):
            CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="unknown")
|
||||||
|
|
@ -709,6 +709,10 @@ memory:
|
||||||
# memory - In-process only. State is lost when the process exits. (default)
|
# memory - In-process only. State is lost when the process exits. (default)
|
||||||
# sqlite - File-based SQLite persistence. Survives restarts.
|
# sqlite - File-based SQLite persistence. Survives restarts.
|
||||||
# Requires: uv add langgraph-checkpoint-sqlite
|
# Requires: uv add langgraph-checkpoint-sqlite
|
||||||
|
# sqlite_mode: full (default) retains all history.
|
||||||
|
# sqlite_mode: shallow keeps only the latest checkpoint per
|
||||||
|
# thread, deleting old records before each write to prevent
|
||||||
|
# unbounded database growth.
|
||||||
# postgres - PostgreSQL persistence. Suitable for multi-process deployments.
|
# postgres - PostgreSQL persistence. Suitable for multi-process deployments.
|
||||||
# Requires: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool
|
# Requires: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool
|
||||||
#
|
#
|
||||||
|
|
@ -722,6 +726,8 @@ memory:
|
||||||
checkpointer:
|
checkpointer:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
connection_string: checkpoints.db
|
connection_string: checkpoints.db
|
||||||
|
# sqlite_mode: full # default — keep all checkpoint history
|
||||||
|
# sqlite_mode: shallow # keep only latest checkpoint per thread (prevents DB bloat)
|
||||||
#
|
#
|
||||||
# PostgreSQL (multi-process, production):
|
# PostgreSQL (multi-process, production):
|
||||||
# checkpointer:
|
# checkpointer:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue