deerflow2/backend/tests/test_checkpointer_shallow.py

306 lines
13 KiB
Python

"""Tests for shallow SQLite checkpoint savers (single-checkpoint-per-thread mode).
Uses in-memory SQLite (``:memory:``) — no filesystem dependency.
"""
from __future__ import annotations
import pytest
# ---------------------------------------------------------------------------
# AsyncShallowSqliteSaver tests
# ---------------------------------------------------------------------------
class TestAsyncShallowSqliteSaver:
    """Tests for ``AsyncShallowSqliteSaver`` — async shallow persistence."""

    @staticmethod
    async def _count_writes(saver, thread_id: str, checkpoint_id: str) -> int:
        """Return the number of ``writes`` rows stored for one checkpoint.

        Queries the saver's connection directly so the test checks what is
        actually persisted, not only what the public read API surfaces.
        NOTE(review): assumes the LangGraph SQLite schema keeps pending writes
        in a table named ``writes`` with ``thread_id``/``checkpoint_id``
        columns — confirm if the upstream schema changes.
        """
        async with saver.conn.execute(
            "SELECT COUNT(*) FROM writes WHERE thread_id = ? AND checkpoint_id = ?",
            (thread_id, checkpoint_id),
        ) as cursor:
            row = await cursor.fetchone()
        return row[0]

    @pytest.mark.anyio
    async def test_aput_deletes_old_checkpoints_before_insert(self):
        """After two aput calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()
        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            thread_config = {"configurable": {"thread_id": "test-thread-1", "checkpoint_ns": ""}}
            checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-1", "channel_values": {"x": 1}}
            checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-2", "channel_values": {"x": 2}}
            # Write first checkpoint
            await saver.aput(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
            # Write second checkpoint — should delete the first
            await saver.aput(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
            # Verify only 1 checkpoint remains
            results = [ckpt async for ckpt in saver.alist(thread_config)]
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-2"

    @pytest.mark.anyio
    async def test_different_threads_do_not_interfere(self):
        """Checkpoints for different thread_ids are independent."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()
        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
            await saver.aput(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
            # Thread A: only ckpt a2
            a_results = [c async for c in saver.alist(t1)]
            assert len(a_results) == 1
            assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
            # Thread B: still has b1 — pruning thread A must not touch it
            b_results = [c async for c in saver.alist(t2)]
            assert len(b_results) == 1
            assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"

    @pytest.mark.anyio
    async def test_writes_table_also_cleaned(self):
        """aput_writes entries from old checkpoints are also deleted."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()
        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            thread_config = {"configurable": {"thread_id": "test-writes", "checkpoint_ns": ""}}
            # Write checkpoint 1 with associated writes
            ckpt1_config = await saver.aput(
                thread_config,
                {"ts": "Z", "id": "ckpt-w1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            await saver.aput_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
            # Write checkpoint 2 — should delete ckpt1 + its writes
            ckpt2_config = await saver.aput(
                thread_config,
                {"ts": "Z", "id": "ckpt-w2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            await saver.aput_writes(ckpt2_config, [("messages", "world")], "task-2", "")
            # Verify only 1 checkpoint remains
            results = [c async for c in saver.alist(thread_config)]
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-w2"
            # The superseded checkpoint is no longer retrievable ...
            assert await saver.aget_tuple(ckpt1_config) is None
            # ... and its pending writes were removed from storage, not just hidden.
            assert await self._count_writes(saver, "test-writes", "ckpt-w1") == 0
            # The latest checkpoint still carries exactly its own write.
            latest = await saver.aget_tuple(ckpt2_config)
            assert latest is not None
            assert latest.pending_writes == [("task-2", "messages", "world")]

    @pytest.mark.anyio
    async def test_full_mode_retains_all_checkpoints(self):
        """In full mode (default), all checkpoints are preserved."""
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

        async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
            await saver.setup()
            thread_config = {"configurable": {"thread_id": "test-full", "checkpoint_ns": ""}}
            await saver.aput(thread_config, {"ts": "Z", "id": "f1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(thread_config, {"ts": "Z", "id": "f2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
            results = [c async for c in saver.alist(thread_config)]
            assert len(results) == 2, "Full mode should retain all checkpoints"
# ---------------------------------------------------------------------------
# ShallowSqliteSaver (sync) tests
# ---------------------------------------------------------------------------
class TestShallowSqliteSaver:
    """Tests for ``ShallowSqliteSaver`` — sync shallow persistence."""

    @staticmethod
    def _count_writes(saver, thread_id: str, checkpoint_id: str) -> int:
        """Return the number of ``writes`` rows stored for one checkpoint.

        Queries the saver's connection directly so the test checks what is
        actually persisted, not only what the public read API surfaces.
        NOTE(review): assumes the LangGraph SQLite schema keeps pending writes
        in a table named ``writes`` with ``thread_id``/``checkpoint_id``
        columns — confirm if the upstream schema changes.
        """
        row = saver.conn.execute(
            "SELECT COUNT(*) FROM writes WHERE thread_id = ? AND checkpoint_id = ?",
            (thread_id, checkpoint_id),
        ).fetchone()
        return row[0]

    def test_put_deletes_old_checkpoints_before_insert(self):
        """After two put calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()
        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()
            thread_config = {"configurable": {"thread_id": "test-sync-1", "checkpoint_ns": ""}}
            checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-s1", "channel_values": {"x": 1}}
            checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-s2", "channel_values": {"x": 2}}
            saver.put(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
            # Second put on the same thread should replace the first checkpoint.
            saver.put(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
            results = list(saver.list(thread_config))
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-s2"

    def test_different_threads_do_not_interfere_sync(self):
        """Checkpoints for different thread_ids are independent (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()
        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()
            t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
            saver.put(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
            # Thread A: only ckpt a2
            a_results = list(saver.list(t1))
            assert len(a_results) == 1
            assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
            # Thread B: still has b1 — pruning thread A must not touch it
            b_results = list(saver.list(t2))
            assert len(b_results) == 1
            assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"

    def test_writes_table_also_cleaned_sync(self):
        """put_writes entries from old checkpoints are also deleted (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()
        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()
            thread_config = {"configurable": {"thread_id": "test-sync-writes", "checkpoint_ns": ""}}
            # Write checkpoint 1 with associated writes
            ckpt1_config = saver.put(
                thread_config,
                {"ts": "Z", "id": "ckpt-sw1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            saver.put_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
            # Write checkpoint 2 — should delete ckpt1 + its writes
            ckpt2_config = saver.put(
                thread_config,
                {"ts": "Z", "id": "ckpt-sw2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            saver.put_writes(ckpt2_config, [("messages", "world")], "task-2", "")
            # Verify only 1 checkpoint remains
            results = list(saver.list(thread_config))
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-sw2"
            # The superseded checkpoint is no longer retrievable ...
            assert saver.get_tuple(ckpt1_config) is None
            # ... and its pending writes were removed from storage, not just hidden.
            assert self._count_writes(saver, "test-sync-writes", "ckpt-sw1") == 0
            # The latest checkpoint still carries exactly its own write.
            latest = saver.get_tuple(ckpt2_config)
            assert latest is not None
            assert latest.pending_writes == [("task-2", "messages", "world")]
# ---------------------------------------------------------------------------
# Config integration tests
# ---------------------------------------------------------------------------
class TestShallowConfig:
    """Tests for configuration and factory integration."""

    @pytest.fixture(autouse=True)
    def _reset_global_checkpointer_config(self):
        """Clear the process-wide checkpointer config around every test.

        Tests in this class mutate the global config through
        ``set_checkpointer_config`` / ``load_checkpointer_config_from_dict``.
        Resetting it on teardown (which runs even when an assertion fails)
        keeps that state from leaking into other tests.
        """
        from deerflow.config.checkpointer_config import set_checkpointer_config

        set_checkpointer_config(None)
        yield
        set_checkpointer_config(None)

    def test_sqlite_mode_defaults_to_full(self):
        """sqlite_mode defaults to 'full' when not specified."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(type="sqlite", connection_string="test.db")
        assert config.sqlite_mode == "full"

    def test_sqlite_mode_shallow_accepted(self):
        """sqlite_mode can be set to 'shallow'."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="shallow")
        assert config.sqlite_mode == "shallow"

    def test_load_sqlite_config_with_shallow_mode(self):
        """load_checkpointer_config_from_dict accepts sqlite_mode."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
        )

        load_checkpointer_config_from_dict({
            "type": "sqlite",
            "connection_string": "/tmp/test.db",
            "sqlite_mode": "shallow",
        })
        config = get_checkpointer_config()
        assert config is not None
        assert config.sqlite_mode == "shallow"

    def test_load_sqlite_config_defaults_sqlite_mode(self):
        """load_checkpointer_config_from_dict defaults sqlite_mode to 'full' when omitted."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
        )

        load_checkpointer_config_from_dict({
            "type": "sqlite",
            "connection_string": "/tmp/test.db",
        })
        config = get_checkpointer_config()
        assert config is not None
        assert config.sqlite_mode == "full"

    @pytest.mark.anyio
    async def test_async_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, async factory returns AsyncShallowSqliteSaver."""
        from deerflow.agents.checkpointer.async_provider import _async_checkpointer
        from deerflow.config.checkpointer_config import CheckpointerConfig, set_checkpointer_config

        config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
        # Also install it globally in case the factory consults the global
        # config; the autouse fixture clears it again on teardown.
        set_checkpointer_config(config)
        async with _async_checkpointer(config) as saver:
            # Should be an instance of the shallow saver
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

    def test_sync_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, sync factory returns ShallowSqliteSaver."""
        from deerflow.agents.checkpointer.provider import _sync_checkpointer_cm
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
        with _sync_checkpointer_cm(config) as saver:
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

    def test_invalid_sqlite_mode_raises(self):
        """Invalid sqlite_mode value raises validation error."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        # pydantic's ValidationError subclasses ValueError; catching the
        # narrower type keeps unrelated errors (e.g. a NameError/TypeError in
        # the config class) from making this test pass by accident.
        # NOTE(review): assumes CheckpointerConfig validation raises a
        # ValueError subclass — confirm against its definition.
        with pytest.raises(ValueError):
            CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="unknown")