"""Shallow persistence savers for LangGraph SQLite checkpointing.

Provides shallow (single-checkpoint-per-thread) variants of the LangGraph
SQLite savers that automatically prune old checkpoints and writes for the
same thread on each write, keeping only the latest state.

This prevents unbounded growth of ``checkpoints.db`` while preserving
multi-turn conversation continuity.

Implements:

- ``AsyncShallowSqliteSaver`` — async shallow variant (the concrete class
  is returned by ``_make_async_shallow_saver()``)
- ``ShallowSqliteSaver`` — sync shallow variant (the concrete class is
  returned by ``_make_sync_shallow_saver()``)

Usage is transparent through the existing checkpointer factory when
``sqlite_mode: shallow`` is set in ``config.yaml``.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from langchain_core.runnables import RunnableConfig
|
|
from langgraph.checkpoint.base import ChannelVersions, Checkpoint, CheckpointMetadata
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Async shallow saver
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class AsyncShallowSqliteSaver:
    """Module-level placeholder for the async shallow SQLite checkpointer.

    The working implementation subclasses
    :class:`langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver`, which lives in
    the optional ``langgraph-checkpoint-sqlite`` package. It therefore cannot
    be defined at import time. Call ``_make_async_shallow_saver()`` to obtain
    the concrete class. Its :meth:`aput` prunes older checkpoints and writes
    for the same thread so that each conversation keeps only its latest
    checkpoint, preventing unbounded database growth.

    This placeholder itself carries no checkpointing behaviour. It does not
    inherit from ``AsyncSqliteSaver`` and defines no ``aput``.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # No-op hook: lets subclasses be declared against this name without
        # importing the optional sqlite backend. The real class is produced
        # by the ``_make_async_shallow_saver`` factory.
        super().__init_subclass__(**kwargs)
|
|
|
|
|
|
def _make_async_shallow_saver() -> type:
    """Build and return the ``AsyncShallowSqliteSaver`` class.

    Import is deferred so that the module is importable even when
    ``langgraph-checkpoint-sqlite`` is not installed.

    Returns:
        A subclass of ``AsyncSqliteSaver`` whose :meth:`aput` keeps only the
        latest checkpoint per ``(thread_id, checkpoint_ns)``.

    Raises:
        ImportError: if ``langgraph-checkpoint-sqlite`` is not installed.
    """
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    class AsyncShallowSqliteSaver(AsyncSqliteSaver):
        """Async SQLite checkpointer that keeps only the latest checkpoint.

        Pruning is scoped to ``(thread_id, checkpoint_ns)``. A subgraph
        saving into its own namespace therefore never deletes the parent
        graph's checkpoint or pending writes, and the parent never deletes
        the subgraph's.
        """

        async def aput(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            """Save *checkpoint*, then prune older data for its namespace.

            The new checkpoint is persisted by the parent implementation
            *before* anything is deleted. If the parent call raises (for
            example, a serialization error), nothing is pruned. A crash
            between the two steps leaves a stale extra row rather than a
            thread with no checkpoint at all.

            Returns:
                The config returned by ``AsyncSqliteSaver.aput``.
            """
            configurable = config["configurable"]
            thread_id = str(configurable["thread_id"])
            checkpoint_ns = configurable.get("checkpoint_ns", "")
            new_id = checkpoint["id"]
            # Writes recorded against the parent checkpoint are kept alongside
            # the new checkpoint's own writes. ``or ""`` matters: SQL
            # ``x NOT IN (?, NULL)`` is never true, so a None parent id would
            # silently disable the writes cleanup.
            parent_id = configurable.get("checkpoint_id") or ""

            await self.setup()  # ensures tables exist before we touch them

            # Parent call stays outside ``self.lock``, as before. asyncio.Lock
            # is not re-entrant, and the parent may acquire it itself.
            next_config = await super().aput(
                config, checkpoint, metadata, new_versions
            )

            async with self.lock:
                await self.conn.execute(
                    "DELETE FROM checkpoints "
                    "WHERE thread_id = ? AND checkpoint_ns = ? "
                    "AND checkpoint_id != ?",
                    (thread_id, checkpoint_ns, new_id),
                )
                await self.conn.execute(
                    "DELETE FROM writes "
                    "WHERE thread_id = ? AND checkpoint_ns = ? "
                    "AND checkpoint_id NOT IN (?, ?)",
                    (thread_id, checkpoint_ns, new_id, parent_id),
                )
                await self.conn.commit()

            return next_config

    return AsyncShallowSqliteSaver
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sync shallow saver
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_sync_shallow_saver() -> type:
    """Build and return the ``ShallowSqliteSaver`` class.

    Import is deferred so that the module is importable even when
    ``langgraph-checkpoint-sqlite`` is not installed.

    Returns:
        A subclass of ``SqliteSaver`` whose :meth:`put` keeps only the latest
        checkpoint per ``(thread_id, checkpoint_ns)``.

    Raises:
        ImportError: if ``langgraph-checkpoint-sqlite`` is not installed.
    """
    from langgraph.checkpoint.sqlite import SqliteSaver

    class ShallowSqliteSaver(SqliteSaver):
        """Sync SQLite checkpointer that keeps only the latest checkpoint.

        Pruning is scoped to ``(thread_id, checkpoint_ns)``. A subgraph
        saving into its own namespace therefore never deletes the parent
        graph's checkpoint or pending writes, and the parent never deletes
        the subgraph's.
        """

        def put(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            """Save *checkpoint*, then prune older data for its namespace.

            The new checkpoint is persisted by the parent implementation
            *before* anything is deleted. If the parent call raises (for
            example, a serialization error), nothing is pruned. A crash
            between the two steps leaves a stale extra row rather than a
            thread with no checkpoint at all.

            Returns:
                The config returned by ``SqliteSaver.put``.
            """
            configurable = config["configurable"]
            thread_id = str(configurable["thread_id"])
            checkpoint_ns = configurable.get("checkpoint_ns", "")
            new_id = checkpoint["id"]
            # Writes recorded against the parent checkpoint are kept alongside
            # the new checkpoint's own writes. ``or ""`` matters: SQL
            # ``x NOT IN (?, NULL)`` is never true, so a None parent id would
            # silently disable the writes cleanup.
            parent_id = configurable.get("checkpoint_id") or ""

            # Parent call is made outside ``self.cursor()``, as before. The
            # cursor helper may hold a non-re-entrant lock.
            next_config = super().put(config, checkpoint, metadata, new_versions)

            with self.cursor() as cur:
                cur.execute(
                    "DELETE FROM checkpoints "
                    "WHERE thread_id = ? AND checkpoint_ns = ? "
                    "AND checkpoint_id != ?",
                    (thread_id, checkpoint_ns, new_id),
                )
                cur.execute(
                    "DELETE FROM writes "
                    "WHERE thread_id = ? AND checkpoint_ns = ? "
                    "AND checkpoint_id NOT IN (?, ?)",
                    (thread_id, checkpoint_ns, new_id, parent_id),
                )

            return next_config

    return ShallowSqliteSaver
|