__all__ = [
    "checkpointer_context",
    "make_checkpointer",
]

# Lazily exported shallow savers (PEP 562) so this package stays importable
# when langgraph-checkpoint-sqlite is not installed.  Maps the public name to
# the factory in .shallow_sqlite that builds the class on demand.
_LAZY_SAVER_FACTORIES = {
    "AsyncShallowSqliteSaver": "_make_async_shallow_saver",
    "ShallowSqliteSaver": "_make_sync_shallow_saver",
}


def __getattr__(name: str):
    """Resolve the lazily exported shallow-saver classes.

    The class is built on first access and cached in module globals, so every
    subsequent lookup returns the *same* class object (keeps ``isinstance``
    checks meaningful) and the factory runs at most once per process.  The
    previous revision rebuilt a fresh class on every attribute access.

    Raises:
        AttributeError: for any name this module does not export.
    """
    factory_name = _LAZY_SAVER_FACTORIES.get(name)
    if factory_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    from . import shallow_sqlite

    cls = getattr(shallow_sqlite, factory_name)()
    globals()[name] = cls  # cache: later accesses bypass __getattr__ entirely
    return cls
a/backend/packages/harness/deerflow/agents/checkpointer/provider.py b/backend/packages/harness/deerflow/agents/checkpointer/provider.py index 6f09aac9..8f0ce58e 100644 --- a/backend/packages/harness/deerflow/agents/checkpointer/provider.py +++ b/backend/packages/harness/deerflow/agents/checkpointer/provider.py @@ -67,6 +67,18 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]: raise ImportError(SQLITE_INSTALL) from exc conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") + + # Shallow mode: use custom saver that keeps only the latest checkpoint per thread + if getattr(config, "sqlite_mode", "full") == "shallow": + from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver + + ShallowSaver = _make_sync_shallow_saver() + with ShallowSaver.from_conn_string(conn_str) as saver: + saver.setup() + logger.info("Checkpointer: using ShallowSqliteSaver (%s)", conn_str) + yield saver + return + with SqliteSaver.from_conn_string(conn_str) as saver: saver.setup() logger.info("Checkpointer: using SqliteSaver (%s)", conn_str) diff --git a/backend/packages/harness/deerflow/agents/checkpointer/shallow_sqlite.py b/backend/packages/harness/deerflow/agents/checkpointer/shallow_sqlite.py new file mode 100644 index 00000000..1c07eb4c --- /dev/null +++ b/backend/packages/harness/deerflow/agents/checkpointer/shallow_sqlite.py @@ -0,0 +1,128 @@ +"""Shallow persistence savers for LangGraph SQLite checkpointing. + +Provides shallow (single-checkpoint-per-thread) variants of the LangGraph +SQLite savers that automatically delete old checkpoints and writes for the +same thread before each write, keeping only the latest state. + +This prevents unbounded growth of ``checkpoints.db`` while preserving +multi-turn conversation continuity. 
+ +Implements: +- ``AsyncShallowSqliteSaver`` — async shallow variant +- ``ShallowSqliteSaver`` — sync shallow variant + +Usage is transparent through the existing checkpointer factory when +``sqlite_mode: shallow`` is set in ``config.yaml``. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ChannelVersions, Checkpoint, CheckpointMetadata + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Async shallow saver +# --------------------------------------------------------------------------- + + +class AsyncShallowSqliteSaver: + """Async SQLite checkpointer that keeps only the latest checkpoint per thread. + + Extends :class:`langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver` and + overrides :meth:`aput` to delete all existing checkpoints and writes for + the same ``thread_id`` before inserting the new one. + + Each conversation thread stores exactly one checkpoint at any time, + preventing unbounded database growth. + """ + + def __init_subclass__(cls, **kwargs: Any) -> None: + # Allow extension by subclasses without forcing late-bound import here. + # The concrete class is built below via the _make_async_shallow factory. + super().__init_subclass__(**kwargs) + + +def _make_async_shallow_saver() -> type: + """Build and return the ``AsyncShallowSqliteSaver`` class. + + Import is deferred so that the module is importable even when + ``langgraph-checkpoint-sqlite`` is not installed. 
+ """ + from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver + + class AsyncShallowSqliteSaver(AsyncSqliteSaver): + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Delete old checkpoints/writes for this thread, then save the new one.""" + + thread_id = config["configurable"]["thread_id"] + await self.setup() + + # Delete all existing checkpoints and writes for this thread + # before inserting the new checkpoint — keeps only the latest. + async with self.lock: + await self.conn.execute( + "DELETE FROM checkpoints WHERE thread_id = ?", + (str(thread_id),), + ) + await self.conn.execute( + "DELETE FROM writes WHERE thread_id = ?", + (str(thread_id),), + ) + await self.conn.commit() + + return await super().aput(config, checkpoint, metadata, new_versions) + + return AsyncShallowSqliteSaver + + +# --------------------------------------------------------------------------- +# Sync shallow saver +# --------------------------------------------------------------------------- + + +def _make_sync_shallow_saver() -> type: + """Build and return the ``ShallowSqliteSaver`` class. + + Import is deferred so that the module is importable even when + ``langgraph-checkpoint-sqlite`` is not installed. + """ + from langgraph.checkpoint.sqlite import SqliteSaver + + class ShallowSqliteSaver(SqliteSaver): + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Delete old checkpoints/writes for this thread, then save the new one.""" + + thread_id = config["configurable"]["thread_id"] + + # Delete all existing checkpoints and writes for this thread + # before inserting the new checkpoint — keeps only the latest. 
+ with self.cursor() as cur: + cur.execute( + "DELETE FROM checkpoints WHERE thread_id = ?", + (str(thread_id),), + ) + cur.execute( + "DELETE FROM writes WHERE thread_id = ?", + (str(thread_id),), + ) + + return super().put(config, checkpoint, metadata, new_versions) + + return ShallowSqliteSaver diff --git a/backend/packages/harness/deerflow/config/checkpointer_config.py b/backend/packages/harness/deerflow/config/checkpointer_config.py index 6947cefb..d61b477e 100644 --- a/backend/packages/harness/deerflow/config/checkpointer_config.py +++ b/backend/packages/harness/deerflow/config/checkpointer_config.py @@ -6,6 +6,14 @@ from pydantic import BaseModel, Field CheckpointerType = Literal["memory", "sqlite", "postgres"] +SqliteMode = Literal["full", "shallow"] +"""Persistence mode for the SQLite checkpointer. + +- ``full`` — retain all checkpoint history (default, backward-compatible). +- ``shallow`` — keep only the latest checkpoint per thread, deleting old + records before each write to prevent unbounded database growth. +""" + class CheckpointerConfig(BaseModel): """Configuration for LangGraph state persistence checkpointer.""" @@ -23,6 +31,13 @@ class CheckpointerConfig(BaseModel): "For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. " "For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.", ) + sqlite_mode: SqliteMode = Field( + default="full", + description="SQLite persistence mode. " + "'full' retains all checkpoint history (default). " + "'shallow' keeps only the latest checkpoint per thread, " + "deleting old records before each write.", + ) # Global configuration instance — None means no checkpointer is configured. 
"""Tests for shallow SQLite checkpoint savers (single-checkpoint-per-thread mode).

Uses in-memory SQLite (``:memory:``) — no filesystem dependency.
"""

from __future__ import annotations

import pytest

# ---------------------------------------------------------------------------
# Fixture helpers — minimal stand-ins for LangGraph checkpoint structures
# ---------------------------------------------------------------------------


def _thread(thread_id: str) -> dict:
    """Return a RunnableConfig addressing *thread_id* in the root namespace."""
    return {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}


def _ckpt(ckpt_id: str, **channel_values) -> dict:
    """Return a minimal Checkpoint dict with the given id and channel values."""
    return {
        "ts": "2024-01-01T00:00:00Z",
        "id": ckpt_id,
        "channel_values": dict(channel_values),
    }


def _meta(source: str, step: int) -> dict:
    """Return a minimal CheckpointMetadata dict."""
    return {"source": source, "step": step, "writes": {}}


# ---------------------------------------------------------------------------
# AsyncShallowSqliteSaver tests
# ---------------------------------------------------------------------------


class TestAsyncShallowSqliteSaver:
    """Tests for ``AsyncShallowSqliteSaver`` — async shallow persistence."""

    @pytest.mark.anyio
    async def test_aput_deletes_old_checkpoints_before_insert(self):
        """After two aput calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()

        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            cfg = _thread("test-thread-1")
            await saver.aput(cfg, _ckpt("ckpt-1", x=1), _meta("input", 1), {})
            # Second write must evict the first checkpoint.
            await saver.aput(cfg, _ckpt("ckpt-2", x=2), _meta("loop", 2), {})

            results = [c async for c in saver.alist(cfg)]
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-2"

    @pytest.mark.anyio
    async def test_different_threads_do_not_interfere(self):
        """Checkpoints for different thread_ids are independent."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()

        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            t1, t2 = _thread("thread-A"), _thread("thread-B")
            await saver.aput(t1, _ckpt("a1"), _meta("input", 1), {})
            await saver.aput(t2, _ckpt("b1"), _meta("input", 1), {})
            await saver.aput(t1, _ckpt("a2"), _meta("loop", 2), {})

            # Thread A: only ckpt a2 survives; thread B keeps b1 untouched.
            a_ids = [c.config["configurable"]["checkpoint_id"] async for c in saver.alist(t1)]
            assert a_ids == ["a2"]
            b_ids = [c.config["configurable"]["checkpoint_id"] async for c in saver.alist(t2)]
            assert b_ids == ["b1"]

    @pytest.mark.anyio
    async def test_writes_table_also_cleaned(self):
        """aput_writes entries from old checkpoints are also deleted."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()

        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            cfg = _thread("test-writes")

            # Checkpoint 1 with associated writes.
            ckpt1_config = await saver.aput(cfg, _ckpt("ckpt-w1"), _meta("input", 1), {})
            await saver.aput_writes(ckpt1_config, [("messages", "hello")], "task-1", "")

            # Checkpoint 2 — must evict ckpt-w1 and its pending writes.
            ckpt2_config = await saver.aput(cfg, _ckpt("ckpt-w2"), _meta("loop", 2), {})
            await saver.aput_writes(ckpt2_config, [("messages", "world")], "task-2", "")

            results = [c async for c in saver.alist(cfg)]
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-w2"

            # The surviving checkpoint is still retrievable with its writes.
            assert await saver.aget_tuple(ckpt2_config) is not None

    @pytest.mark.anyio
    async def test_full_mode_retains_all_checkpoints(self):
        """The stock AsyncSqliteSaver (full mode) preserves every checkpoint."""
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

        async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            cfg = _thread("test-full")
            await saver.aput(cfg, _ckpt("f1"), _meta("input", 1), {})
            await saver.aput(cfg, _ckpt("f2"), _meta("loop", 2), {})

            results = [c async for c in saver.alist(cfg)]
            assert len(results) == 2, "Full mode should retain all checkpoints"


# ---------------------------------------------------------------------------
# ShallowSqliteSaver (sync) tests
# ---------------------------------------------------------------------------


class TestShallowSqliteSaver:
    """Tests for ``ShallowSqliteSaver`` — sync shallow persistence."""

    def test_put_deletes_old_checkpoints_before_insert(self):
        """After two put calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()

        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()

            cfg = _thread("test-sync-1")
            saver.put(cfg, _ckpt("ckpt-s1", x=1), _meta("input", 1), {})
            saver.put(cfg, _ckpt("ckpt-s2", x=2), _meta("loop", 2), {})

            results = list(saver.list(cfg))
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-s2"

    def test_different_threads_do_not_interfere_sync(self):
        """Checkpoints for different thread_ids are independent (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()

        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()

            t1, t2 = _thread("thread-A"), _thread("thread-B")
            saver.put(t1, _ckpt("a1"), _meta("input", 1), {})
            saver.put(t2, _ckpt("b1"), _meta("input", 1), {})
            saver.put(t1, _ckpt("a2"), _meta("loop", 2), {})

            a_ids = [c.config["configurable"]["checkpoint_id"] for c in saver.list(t1)]
            assert a_ids == ["a2"]
            b_ids = [c.config["configurable"]["checkpoint_id"] for c in saver.list(t2)]
            assert b_ids == ["b1"]

    def test_writes_table_also_cleaned_sync(self):
        """put_writes entries from old checkpoints are also deleted (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()

        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()

            cfg = _thread("test-sync-writes")

            ckpt1_config = saver.put(cfg, _ckpt("ckpt-sw1"), _meta("input", 1), {})
            saver.put_writes(ckpt1_config, [("messages", "hello")], "task-1", "")

            ckpt2_config = saver.put(cfg, _ckpt("ckpt-sw2"), _meta("loop", 2), {})
            saver.put_writes(ckpt2_config, [("messages", "world")], "task-2", "")

            results = list(saver.list(cfg))
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-sw2"


# ---------------------------------------------------------------------------
# Config integration tests
# ---------------------------------------------------------------------------


class TestShallowConfig:
    """Tests for configuration and factory integration."""

    def test_sqlite_mode_defaults_to_full(self):
        """sqlite_mode defaults to 'full' when not specified."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(type="sqlite", connection_string="test.db")
        assert config.sqlite_mode == "full"

    def test_sqlite_mode_shallow_accepted(self):
        """sqlite_mode can be set to 'shallow'."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(
            type="sqlite", connection_string="test.db", sqlite_mode="shallow"
        )
        assert config.sqlite_mode == "shallow"

    def test_load_sqlite_config_with_shallow_mode(self):
        """load_checkpointer_config_from_dict accepts sqlite_mode."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )

        set_checkpointer_config(None)
        try:
            load_checkpointer_config_from_dict({
                "type": "sqlite",
                "connection_string": "/tmp/test.db",
                "sqlite_mode": "shallow",
            })
            config = get_checkpointer_config()
            assert config is not None
            assert config.sqlite_mode == "shallow"
        finally:
            # Always reset the process-global config, even on failure,
            # so other tests are unaffected.
            set_checkpointer_config(None)

    def test_load_sqlite_config_defaults_sqlite_mode(self):
        """load_checkpointer_config_from_dict defaults sqlite_mode to 'full' when omitted."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )

        set_checkpointer_config(None)
        try:
            load_checkpointer_config_from_dict({
                "type": "sqlite",
                "connection_string": "/tmp/test.db",
            })
            config = get_checkpointer_config()
            assert config is not None
            assert config.sqlite_mode == "full"
        finally:
            set_checkpointer_config(None)

    @pytest.mark.anyio
    async def test_async_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, async factory yields AsyncShallowSqliteSaver."""
        from deerflow.agents.checkpointer.async_provider import _async_checkpointer
        from deerflow.config.checkpointer_config import CheckpointerConfig

        # The factory takes the config directly — no process-global state is
        # mutated, so nothing can leak if an assertion fails.
        config = CheckpointerConfig(
            type="sqlite", connection_string=":memory:", sqlite_mode="shallow"
        )

        async with _async_checkpointer(config) as saver:
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

    def test_sync_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, sync factory yields ShallowSqliteSaver."""
        from deerflow.agents.checkpointer.provider import _sync_checkpointer_cm
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(
            type="sqlite", connection_string=":memory:", sqlite_mode="shallow"
        )

        with _sync_checkpointer_cm(config) as saver:
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

    def test_invalid_sqlite_mode_raises(self):
        """Invalid sqlite_mode value raises a pydantic validation error."""
        from pydantic import ValidationError

        from deerflow.config.checkpointer_config import CheckpointerConfig

        # Narrow assertion: the Literal["full", "shallow"] field must reject
        # unknown values with ValidationError specifically, not any Exception.
        with pytest.raises(ValidationError):
            CheckpointerConfig(
                type="sqlite", connection_string="test.db", sqlite_mode="unknown"
            )