feat: implement shallow SQLite checkpoint savers and update configuration for persistence modes

This commit is contained in:
Titan 2026-05-08 10:58:09 +08:00
parent f209057b18
commit f3558d6bb2
8 changed files with 490 additions and 1 deletion

View File

@ -7,3 +7,14 @@ __all__ = [
"checkpointer_context",
"make_checkpointer",
]
# Lazy-import shallow savers so the module is still importable without
# langgraph-checkpoint-sqlite installed.
def __getattr__(name: str):
    """Lazily resolve the shallow saver classes (PEP 562 module __getattr__).

    The shallow savers are built by factories in ``shallow_sqlite`` so this
    package stays importable when ``langgraph-checkpoint-sqlite`` is not
    installed.  The built class is cached in the module globals so that
    repeated attribute access returns the *same* class object (identity-stable
    for ``isinstance`` checks) and the factory runs only once.

    Raises:
        AttributeError: for any name other than the two saver classes.
    """
    if name == "AsyncShallowSqliteSaver":
        from .shallow_sqlite import _make_async_shallow_saver
        cls = _make_async_shallow_saver()
    elif name == "ShallowSqliteSaver":
        from .shallow_sqlite import _make_sync_shallow_saver
        cls = _make_sync_shallow_saver()
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # Cache so subsequent lookups bypass __getattr__ entirely.
    globals()[name] = cls
    return cls

View File

@ -55,6 +55,18 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str)
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
if getattr(config, "sqlite_mode", "full") == "shallow":
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
ShallowSaver = _make_async_shallow_saver()
async with ShallowSaver.from_conn_string(conn_str) as saver:
await saver.setup()
logger.info("Checkpointer: using AsyncShallowSqliteSaver (%s)", conn_str)
yield saver
return
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver

View File

@ -67,6 +67,18 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
if getattr(config, "sqlite_mode", "full") == "shallow":
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
ShallowSaver = _make_sync_shallow_saver()
with ShallowSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using ShallowSqliteSaver (%s)", conn_str)
yield saver
return
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)

View File

@ -0,0 +1,128 @@
"""Shallow persistence savers for LangGraph SQLite checkpointing.
Provides shallow (single-checkpoint-per-thread) variants of the LangGraph
SQLite savers that automatically delete old checkpoints and writes for the
same thread before each write, keeping only the latest state.
This prevents unbounded growth of ``checkpoints.db`` while preserving
multi-turn conversation continuity.
Implements:
- ``AsyncShallowSqliteSaver`` — async shallow variant
- ``ShallowSqliteSaver`` — sync shallow variant
Usage is transparent through the existing checkpointer factory when
``sqlite_mode: shallow`` is set in ``config.yaml``.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import ChannelVersions, Checkpoint, CheckpointMetadata
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Async shallow saver
# ---------------------------------------------------------------------------
class AsyncShallowSqliteSaver:
    """Module-level placeholder for the async shallow saver.

    NOTE(review): this class does NOT inherit from ``AsyncSqliteSaver`` and
    contains no persistence logic — the functional class of the same name is
    built lazily by :func:`_make_async_shallow_saver` below, so that the
    optional ``langgraph-checkpoint-sqlite`` dependency is not imported at
    module load time.  This stub only provides a subclassing hook; consumers
    should obtain the real saver through the factory (or the package's lazy
    ``__getattr__``).
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Allow extension by subclasses without forcing late-bound import here.
        # The concrete class is built below via the _make_async_shallow factory.
        super().__init_subclass__(**kwargs)
def _make_async_shallow_saver() -> type:
    """Construct and return the ``AsyncShallowSqliteSaver`` class on demand.

    The ``langgraph-checkpoint-sqlite`` import happens inside this factory so
    that importing this module never fails when the optional dependency is
    absent.
    """
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    class AsyncShallowSqliteSaver(AsyncSqliteSaver):
        async def aput(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            """Purge this thread's prior checkpoints and writes, then store the new one."""
            tid = str(config["configurable"]["thread_id"])
            await self.setup()
            # Clear every earlier row for the thread first so the database
            # holds at most one checkpoint per conversation at any time.
            async with self.lock:
                await self.conn.execute(
                    "DELETE FROM checkpoints WHERE thread_id = ?",
                    (tid,),
                )
                await self.conn.execute(
                    "DELETE FROM writes WHERE thread_id = ?",
                    (tid,),
                )
                await self.conn.commit()
            # Delegate the actual insert to the stock saver.
            return await super().aput(config, checkpoint, metadata, new_versions)

    return AsyncShallowSqliteSaver
# ---------------------------------------------------------------------------
# Sync shallow saver
# ---------------------------------------------------------------------------
def _make_sync_shallow_saver() -> type:
    """Construct and return the ``ShallowSqliteSaver`` class on demand.

    The ``langgraph-checkpoint-sqlite`` import is deferred into this factory
    so the module itself imports cleanly when the optional dependency is
    missing.
    """
    from langgraph.checkpoint.sqlite import SqliteSaver

    class ShallowSqliteSaver(SqliteSaver):
        def put(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            """Purge this thread's prior checkpoints and writes, then store the new one."""
            tid = str(config["configurable"]["thread_id"])
            # Remove every earlier row for the thread first so at most one
            # checkpoint per conversation survives in the database.
            with self.cursor() as cur:
                cur.execute(
                    "DELETE FROM checkpoints WHERE thread_id = ?",
                    (tid,),
                )
                cur.execute(
                    "DELETE FROM writes WHERE thread_id = ?",
                    (tid,),
                )
            # Delegate the actual insert to the stock saver.
            return super().put(config, checkpoint, metadata, new_versions)

    return ShallowSqliteSaver

View File

@ -6,6 +6,14 @@ from pydantic import BaseModel, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"]
SqliteMode = Literal["full", "shallow"]
"""Persistence mode for the SQLite checkpointer.
- ``full`` retain all checkpoint history (default, backward-compatible).
- ``shallow`` keep only the latest checkpoint per thread, deleting old
records before each write to prevent unbounded database growth.
"""
class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer."""
@ -23,6 +31,13 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
sqlite_mode: SqliteMode = Field(
default="full",
description="SQLite persistence mode. "
"'full' retains all checkpoint history (default). "
"'shallow' keeps only the latest checkpoint per thread, "
"deleting old records before each write.",
)
# Global configuration instance — None means no checkpointer is configured.

View File

@ -23,7 +23,7 @@ def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
default_estimated_output_tokens=None,
)
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
model_cfg = SimpleNamespace(display_name="GPT-4", model="gpt-4", model_extra={"max_tokens": 4096})
return SimpleNamespace(
billing=billing,
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,

View File

@ -0,0 +1,305 @@
"""Tests for shallow SQLite checkpoint savers (single-checkpoint-per-thread mode).
Uses in-memory SQLite (``:memory:``) — no filesystem dependency.
"""
from __future__ import annotations
import pytest
# ---------------------------------------------------------------------------
# AsyncShallowSqliteSaver tests
# ---------------------------------------------------------------------------
class TestAsyncShallowSqliteSaver:
    """Tests for ``AsyncShallowSqliteSaver`` — async shallow persistence."""

    @pytest.mark.anyio
    async def test_aput_deletes_old_checkpoints_before_insert(self):
        """After two aput calls for the same thread, only 1 checkpoint remains."""
        # Deferred import: the class is produced by a factory so this test file
        # stays importable when langgraph-checkpoint-sqlite is not installed.
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
        ShallowSaver = _make_async_shallow_saver()
        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            thread_config = {"configurable": {"thread_id": "test-thread-1", "checkpoint_ns": ""}}
            checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-1", "channel_values": {"x": 1}}
            checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-2", "channel_values": {"x": 2}}
            # Write first checkpoint
            await saver.aput(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
            # Write second checkpoint — should delete the first
            await saver.aput(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
            # Verify only 1 checkpoint remains
            results = []
            async for ckpt in saver.alist(thread_config):
                results.append(ckpt)
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-2"

    @pytest.mark.anyio
    async def test_different_threads_do_not_interfere(self):
        """Checkpoints for different thread_ids are independent."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
        ShallowSaver = _make_async_shallow_saver()
        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
            await saver.aput(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
            # Thread A: only ckpt a2 — the second aput for A evicted a1.
            a_results = [c async for c in saver.alist(t1)]
            assert len(a_results) == 1
            assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
            # Thread B: still has b1 — A's eviction must not touch B's rows.
            b_results = [c async for c in saver.alist(t2)]
            assert len(b_results) == 1
            assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"

    @pytest.mark.anyio
    async def test_writes_table_also_cleaned(self):
        """aput_writes entries from old checkpoints are also deleted."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
        ShallowSaver = _make_async_shallow_saver()
        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            thread_config = {"configurable": {"thread_id": "test-writes", "checkpoint_ns": ""}}
            # Write checkpoint 1 with associated writes
            ckpt1_config = await saver.aput(
                thread_config,
                {"ts": "Z", "id": "ckpt-w1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            await saver.aput_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
            # Write checkpoint 2 — should delete ckpt1 + its writes
            ckpt2_config = await saver.aput(
                thread_config,
                {"ts": "Z", "id": "ckpt-w2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            await saver.aput_writes(ckpt2_config, [("messages", "world")], "task-2", "")
            # Verify only 1 checkpoint remains
            results = [c async for c in saver.alist(thread_config)]
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-w2"
            # Verify only the latest writes exist (get_tuple returns writes)
            latest = await saver.aget_tuple(ckpt2_config)
            assert latest is not None

    @pytest.mark.anyio
    async def test_full_mode_retains_all_checkpoints(self):
        """In full mode (default), all checkpoints are preserved."""
        # Baseline check against the stock saver: it keeps every checkpoint,
        # which is exactly the behavior the shallow variant changes.
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
        async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            thread_config = {"configurable": {"thread_id": "test-full", "checkpoint_ns": ""}}
            await saver.aput(thread_config, {"ts": "Z", "id": "f1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(thread_config, {"ts": "Z", "id": "f2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
            results = [c async for c in saver.alist(thread_config)]
            assert len(results) == 2, "Full mode should retain all checkpoints"
# ---------------------------------------------------------------------------
# ShallowSqliteSaver (sync) tests
# ---------------------------------------------------------------------------
class TestShallowSqliteSaver:
    """Tests for ``ShallowSqliteSaver`` — sync shallow persistence."""

    def test_put_deletes_old_checkpoints_before_insert(self):
        """After two put calls for the same thread, only 1 checkpoint remains."""
        # Deferred import: the class is produced by a factory so this test file
        # stays importable when langgraph-checkpoint-sqlite is not installed.
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
        ShallowSaver = _make_sync_shallow_saver()
        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()
            thread_config = {"configurable": {"thread_id": "test-sync-1", "checkpoint_ns": ""}}
            checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-s1", "channel_values": {"x": 1}}
            checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-s2", "channel_values": {"x": 2}}
            # The second put must evict the first checkpoint for this thread.
            saver.put(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
            results = list(saver.list(thread_config))
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-s2"

    def test_different_threads_do_not_interfere_sync(self):
        """Checkpoints for different thread_ids are independent (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
        ShallowSaver = _make_sync_shallow_saver()
        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()
            t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
            saver.put(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
            # Thread A keeps only its latest checkpoint (a2 evicted a1) ...
            a_results = list(saver.list(t1))
            assert len(a_results) == 1
            assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
            # ... while thread B's single checkpoint is untouched.
            b_results = list(saver.list(t2))
            assert len(b_results) == 1
            assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"

    def test_writes_table_also_cleaned_sync(self):
        """put_writes entries from old checkpoints are also deleted (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
        ShallowSaver = _make_sync_shallow_saver()
        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()
            thread_config = {"configurable": {"thread_id": "test-sync-writes", "checkpoint_ns": ""}}
            # Checkpoint 1 plus pending writes attached to it.
            ckpt1_config = saver.put(
                thread_config,
                {"ts": "Z", "id": "ckpt-sw1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            saver.put_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
            # Checkpoint 2 should evict checkpoint 1 along with its writes.
            ckpt2_config = saver.put(
                thread_config,
                {"ts": "Z", "id": "ckpt-sw2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            saver.put_writes(ckpt2_config, [("messages", "world")], "task-2", "")
            results = list(saver.list(thread_config))
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-sw2"
# ---------------------------------------------------------------------------
# Config integration tests
# ---------------------------------------------------------------------------
class TestShallowConfig:
    """Tests for configuration and factory integration.

    Tests that mutate the module-global checkpointer config reset it in a
    ``finally`` block so a failing assertion cannot leak state into other
    tests.
    """

    def test_sqlite_mode_defaults_to_full(self):
        """sqlite_mode defaults to 'full' when not specified."""
        from deerflow.config.checkpointer_config import CheckpointerConfig
        config = CheckpointerConfig(type="sqlite", connection_string="test.db")
        assert config.sqlite_mode == "full"

    def test_sqlite_mode_shallow_accepted(self):
        """sqlite_mode can be set to 'shallow'."""
        from deerflow.config.checkpointer_config import CheckpointerConfig
        config = CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="shallow")
        assert config.sqlite_mode == "shallow"

    def test_load_sqlite_config_with_shallow_mode(self):
        """load_checkpointer_config_from_dict accepts sqlite_mode."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )
        set_checkpointer_config(None)
        try:
            load_checkpointer_config_from_dict({
                "type": "sqlite",
                "connection_string": "/tmp/test.db",
                "sqlite_mode": "shallow",
            })
            config = get_checkpointer_config()
            assert config is not None
            assert config.sqlite_mode == "shallow"
        finally:
            # Always restore the global config, even on assertion failure.
            set_checkpointer_config(None)

    def test_load_sqlite_config_defaults_sqlite_mode(self):
        """load_checkpointer_config_from_dict defaults sqlite_mode to 'full' when omitted."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )
        set_checkpointer_config(None)
        try:
            load_checkpointer_config_from_dict({
                "type": "sqlite",
                "connection_string": "/tmp/test.db",
            })
            config = get_checkpointer_config()
            assert config is not None
            assert config.sqlite_mode == "full"
        finally:
            # Always restore the global config, even on assertion failure.
            set_checkpointer_config(None)

    @pytest.mark.anyio
    async def test_async_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, async factory returns AsyncShallowSqliteSaver."""
        from deerflow.agents.checkpointer.async_provider import _async_checkpointer
        from deerflow.config.checkpointer_config import CheckpointerConfig, set_checkpointer_config
        set_checkpointer_config(CheckpointerConfig(
            type="sqlite",
            connection_string=":memory:",
            sqlite_mode="shallow",
        ))
        try:
            config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
            async with _async_checkpointer(config) as saver:
                # Should be an instance of the shallow saver
                cls_name = type(saver).__name__
                assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"
        finally:
            # Original code only reset on success, leaking the global config
            # when the assertion failed; finally guarantees cleanup.
            set_checkpointer_config(None)

    def test_sync_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, sync factory returns ShallowSqliteSaver."""
        from deerflow.agents.checkpointer.provider import _sync_checkpointer_cm
        from deerflow.config.checkpointer_config import CheckpointerConfig
        config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
        with _sync_checkpointer_cm(config) as saver:
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

    def test_invalid_sqlite_mode_raises(self):
        """Invalid sqlite_mode value raises validation error."""
        from deerflow.config.checkpointer_config import CheckpointerConfig
        # Pydantic's ValidationError subclasses ValueError — assert the
        # narrower type instead of a blanket Exception.
        with pytest.raises(ValueError):
            CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="unknown")

View File

@ -709,6 +709,10 @@ memory:
# memory - In-process only. State is lost when the process exits. (default)
# sqlite - File-based SQLite persistence. Survives restarts.
# Requires: uv add langgraph-checkpoint-sqlite
# sqlite_mode: full (default) retains all history.
# sqlite_mode: shallow keeps only the latest checkpoint per
# thread, deleting old records before each write to prevent
# unbounded database growth.
# postgres - PostgreSQL persistence. Suitable for multi-process deployments.
# Requires: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool
#
@ -722,6 +726,8 @@ memory:
checkpointer:
type: sqlite
connection_string: checkpoints.db
# sqlite_mode: full # default — keep all checkpoint history
# sqlite_mode: shallow # keep only latest checkpoint per thread (prevents DB bloat)
#
# PostgreSQL (multi-process, production):
# checkpointer: