"""Tests for shallow SQLite checkpoint savers (single-checkpoint-per-thread mode).

Uses in-memory SQLite (``:memory:``) — no filesystem dependency.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
# ---------------------------------------------------------------------------
# AsyncShallowSqliteSaver tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAsyncShallowSqliteSaver:
    """Tests for ``AsyncShallowSqliteSaver`` — async shallow persistence.

    Every test opens its own in-memory database via ``from_conn_string(":memory:")``
    and performs all assertions inside the ``async with`` block, while the
    connection is still open.
    """

    @pytest.mark.anyio
    async def test_aput_deletes_old_checkpoints_before_insert(self):
        """After two aput calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()

        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            thread_config = {"configurable": {"thread_id": "test-thread-1", "checkpoint_ns": ""}}
            checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-1", "channel_values": {"x": 1}}
            checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-2", "channel_values": {"x": 2}}

            # Write first checkpoint
            await saver.aput(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})

            # Write second checkpoint — should delete the first
            await saver.aput(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})

            # Verify only 1 checkpoint remains, and that it is the newest one
            results = [ckpt async for ckpt in saver.alist(thread_config)]
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-2"

    @pytest.mark.anyio
    async def test_different_threads_do_not_interfere(self):
        """Checkpoints for different thread_ids are independent."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()

        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}

            # Interleave writes: A, B, then A again. Only A's first checkpoint
            # should be replaced by the third call.
            await saver.aput(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})

            # Thread A: only ckpt a2
            a_results = [c async for c in saver.alist(t1)]
            assert len(a_results) == 1
            assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"

            # Thread B: still has b1
            b_results = [c async for c in saver.alist(t2)]
            assert len(b_results) == 1
            assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"

    @pytest.mark.anyio
    async def test_writes_table_also_cleaned(self):
        """aput_writes entries from old checkpoints are also deleted."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver

        ShallowSaver = _make_async_shallow_saver()

        async with ShallowSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            thread_config = {"configurable": {"thread_id": "test-writes", "checkpoint_ns": ""}}

            # Write checkpoint 1 with associated writes
            ckpt1_config = await saver.aput(
                thread_config,
                {"ts": "Z", "id": "ckpt-w1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            await saver.aput_writes(ckpt1_config, [("messages", "hello")], "task-1", "")

            # Write checkpoint 2 — should delete ckpt1 + its writes
            ckpt2_config = await saver.aput(
                thread_config,
                {"ts": "Z", "id": "ckpt-w2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            await saver.aput_writes(ckpt2_config, [("messages", "world")], "task-2", "")

            # Verify only 1 checkpoint remains
            results = [c async for c in saver.alist(thread_config)]
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-w2"

            # The surviving checkpoint must carry only its own pending writes.
            latest = await saver.aget_tuple(ckpt2_config)
            assert latest is not None
            pending_task_ids = {write[0] for write in (latest.pending_writes or [])}
            assert pending_task_ids == {"task-2"}, f"Unexpected pending writes: {latest.pending_writes}"

            # aget_tuple only returns writes for the requested checkpoint id, so
            # inspect the ``writes`` table directly to prove that the rows of the
            # replaced checkpoint were actually removed.
            # NOTE(review): relies on the upstream ``writes`` table schema and the
            # saver's ``conn`` attribute — confirm if the langgraph version changes.
            async with saver.conn.execute(
                "SELECT DISTINCT checkpoint_id FROM writes WHERE thread_id = ?",
                ("test-writes",),
            ) as cursor:
                rows = await cursor.fetchall()
            remaining_ids = {row[0] for row in rows}
            assert remaining_ids == {"ckpt-w2"}, f"Stale writes left behind: {remaining_ids}"

    @pytest.mark.anyio
    async def test_full_mode_retains_all_checkpoints(self):
        """In full mode (default), all checkpoints are preserved."""
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

        # Baseline using the stock saver: contrasts with the shallow tests above.
        async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
            await saver.setup()

            thread_config = {"configurable": {"thread_id": "test-full", "checkpoint_ns": ""}}

            await saver.aput(thread_config, {"ts": "Z", "id": "f1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            await saver.aput(thread_config, {"ts": "Z", "id": "f2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})

            results = [c async for c in saver.alist(thread_config)]
            assert len(results) == 2, "Full mode should retain all checkpoints"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# ShallowSqliteSaver (sync) tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestShallowSqliteSaver:
    """Tests for ``ShallowSqliteSaver`` — sync shallow persistence.

    Every test opens its own in-memory database via ``from_conn_string(":memory:")``
    and performs all assertions inside the ``with`` block, while the connection
    is still open.
    """

    def test_put_deletes_old_checkpoints_before_insert(self):
        """After two put calls for the same thread, only 1 checkpoint remains."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()

        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()

            thread_config = {"configurable": {"thread_id": "test-sync-1", "checkpoint_ns": ""}}
            checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-s1", "channel_values": {"x": 1}}
            checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-s2", "channel_values": {"x": 2}}

            # The second put should replace the first checkpoint.
            saver.put(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})

            results = list(saver.list(thread_config))
            assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-s2"

    def test_different_threads_do_not_interfere_sync(self):
        """Checkpoints for different thread_ids are independent (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()

        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()

            t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
            t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}

            # Interleave writes: A, B, then A again. Only A's first checkpoint
            # should be replaced by the third call.
            saver.put(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
            saver.put(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})

            # Thread A: only ckpt a2
            a_results = list(saver.list(t1))
            assert len(a_results) == 1
            assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"

            # Thread B: still has b1
            b_results = list(saver.list(t2))
            assert len(b_results) == 1
            assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"

    def test_writes_table_also_cleaned_sync(self):
        """put_writes entries from old checkpoints are also deleted (sync)."""
        from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver

        ShallowSaver = _make_sync_shallow_saver()

        with ShallowSaver.from_conn_string(":memory:") as saver:
            saver.setup()

            thread_config = {"configurable": {"thread_id": "test-sync-writes", "checkpoint_ns": ""}}

            # Write checkpoint 1 with associated writes
            ckpt1_config = saver.put(
                thread_config,
                {"ts": "Z", "id": "ckpt-sw1", "channel_values": {}},
                {"source": "input", "step": 1, "writes": {}},
                {},
            )
            saver.put_writes(ckpt1_config, [("messages", "hello")], "task-1", "")

            # Write checkpoint 2 — should delete ckpt1 + its writes
            ckpt2_config = saver.put(
                thread_config,
                {"ts": "Z", "id": "ckpt-sw2", "channel_values": {}},
                {"source": "loop", "step": 2, "writes": {}},
                {},
            )
            saver.put_writes(ckpt2_config, [("messages", "world")], "task-2", "")

            # Verify only 1 checkpoint remains
            results = list(saver.list(thread_config))
            assert len(results) == 1
            assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-sw2"

            # Inspect the ``writes`` table directly to prove that the rows of the
            # replaced checkpoint were actually removed, not just hidden.
            # NOTE(review): relies on the upstream ``writes`` table schema and the
            # saver's ``conn`` attribute — confirm if the langgraph version changes.
            rows = saver.conn.execute(
                "SELECT DISTINCT checkpoint_id FROM writes WHERE thread_id = ?",
                ("test-sync-writes",),
            ).fetchall()
            remaining_ids = {row[0] for row in rows}
            assert remaining_ids == {"ckpt-sw2"}, f"Stale writes left behind: {remaining_ids}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Config integration tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestShallowConfig:
    """Tests for configuration and factory integration.

    Tests that change the shared checkpointer config (through
    ``set_checkpointer_config`` / ``load_checkpointer_config_from_dict``) reset
    it to ``None`` in a ``finally`` block, so a failing assertion cannot leak
    state into later tests.
    """

    def test_sqlite_mode_defaults_to_full(self):
        """sqlite_mode defaults to 'full' when not specified."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(type="sqlite", connection_string="test.db")
        assert config.sqlite_mode == "full"

    def test_sqlite_mode_shallow_accepted(self):
        """sqlite_mode can be set to 'shallow'."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="shallow")
        assert config.sqlite_mode == "shallow"

    def test_load_sqlite_config_with_shallow_mode(self):
        """load_checkpointer_config_from_dict accepts sqlite_mode."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )

        # Start from a clean slate, and always restore it afterwards.
        set_checkpointer_config(None)
        try:
            load_checkpointer_config_from_dict({
                "type": "sqlite",
                "connection_string": "/tmp/test.db",
                "sqlite_mode": "shallow",
            })
            config = get_checkpointer_config()
            assert config is not None
            assert config.sqlite_mode == "shallow"
        finally:
            set_checkpointer_config(None)

    def test_load_sqlite_config_defaults_sqlite_mode(self):
        """load_checkpointer_config_from_dict defaults sqlite_mode to 'full' when omitted."""
        from deerflow.config.checkpointer_config import (
            get_checkpointer_config,
            load_checkpointer_config_from_dict,
            set_checkpointer_config,
        )

        # Start from a clean slate, and always restore it afterwards.
        set_checkpointer_config(None)
        try:
            load_checkpointer_config_from_dict({
                "type": "sqlite",
                "connection_string": "/tmp/test.db",
            })
            config = get_checkpointer_config()
            assert config is not None
            assert config.sqlite_mode == "full"
        finally:
            set_checkpointer_config(None)

    @pytest.mark.anyio
    async def test_async_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, async factory returns AsyncShallowSqliteSaver."""
        from deerflow.agents.checkpointer.async_provider import _async_checkpointer
        from deerflow.config.checkpointer_config import CheckpointerConfig, set_checkpointer_config

        config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
        set_checkpointer_config(config)
        try:
            async with _async_checkpointer(config) as saver:
                # Should be an instance of the shallow saver
                cls_name = type(saver).__name__
                assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"
        finally:
            # Reset even when the assertion above fails.
            set_checkpointer_config(None)

    def test_sync_factory_uses_shallow_saver(self):
        """When sqlite_mode=shallow, sync factory returns ShallowSqliteSaver."""
        from deerflow.agents.checkpointer.provider import _sync_checkpointer_cm
        from deerflow.config.checkpointer_config import CheckpointerConfig

        config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")

        with _sync_checkpointer_cm(config) as saver:
            cls_name = type(saver).__name__
            assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"

    def test_invalid_sqlite_mode_raises(self):
        """Invalid sqlite_mode value raises validation error."""
        from deerflow.config.checkpointer_config import CheckpointerConfig

        # NOTE(review): ``Exception`` is deliberately left broad because the
        # concrete validation error type is not visible from this file; narrow
        # it once the type raised by CheckpointerConfig is confirmed.
        with pytest.raises(Exception):
            CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="unknown")
|