diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py index dac12e78..d7a82dba 100644 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py +++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py @@ -5,6 +5,7 @@ import sys from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager +from datetime import datetime from importlib.metadata import version from typing import Any, Optional @@ -81,6 +82,7 @@ def __init__( db_name: str = "checkpointing_db", checkpoint_collection_name: str = "checkpoints_aio", writes_collection_name: str = "checkpoint_writes_aio", + ttl: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__() @@ -90,6 +92,7 @@ def __init__( self.writes_collection = self.db[writes_collection_name] self._setup_future: asyncio.Future | None = None self.loop = asyncio.get_running_loop() + self.ttl = ttl async def _setup(self) -> None: """Create indexes if not present.""" @@ -107,6 +110,11 @@ async def _setup(self) -> None: keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)], unique=True, ) + if self.ttl: + await self.checkpoint_collection.create_index( + keys=[("created_at", 1)], + expireAfterSeconds=self.ttl, + ) if isinstance(self.client, AsyncMongoClient): num_indexes = len( await (await self.writes_collection.list_indexes()).to_list() # type:ignore[misc] @@ -124,6 +132,11 @@ async def _setup(self) -> None: ], unique=True, ) + if self.ttl: + await self.writes_collection.create_index( + keys=[("created_at", 1)], + expireAfterSeconds=self.ttl, + ) self._setup_future.set_result(None) @classmethod @@ -134,6 +147,7 @@ async def from_conn_string( db_name: str = "checkpointing_db", checkpoint_collection_name: str = "checkpoints_aio", writes_collection_name: str = "checkpoint_writes_aio", + ttl: Optional[int] = None, **kwargs: Any, ) -> AsyncIterator[AsyncMongoDBSaver]: """Create asynchronous checkpointer @@ -153,6 +167,7 @@ async def from_conn_string( db_name, checkpoint_collection_name, writes_collection_name, + ttl, **kwargs, ) await saver._setup() @@ -343,6 +358,8 @@ async def aput( "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } + if self.ttl: + doc["created_at"] = datetime.now() # Perform your operations here await self.checkpoint_collection.update_one( upsert_query, {"$set": doc}, upsert=True @@ -389,6 +406,8 @@ async def aput_writes( "task_path": task_path, "idx": WRITES_IDX_MAP.get(channel, idx), } + if self.ttl: + upsert_query["created_at"] = datetime.now() type_, serialized_value = self.serde.dumps_typed(value) operations.append( UpdateOne( diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py index 0a94702d..b4dea225 100644 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py +++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py @@ -1,5 +1,6 @@ from collections.abc import Iterator, Sequence from contextlib import contextmanager +from datetime import datetime from importlib.metadata import version from typing import ( Any, @@ -7,7 +8,7 @@ ) from langchain_core.runnables import RunnableConfig -from pymongo import MongoClient, UpdateOne +from pymongo import ASCENDING, MongoClient, UpdateOne from pymongo.database import Database as MongoDatabase from pymongo.driver_info import DriverInfo @@ -39,6 +40,7 @@ class MongoDBSaver(BaseCheckpointSaver): db_name (Optional[str]): Database name checkpoint_collection_name (Optional[str]): Name of Collection of Checkpoints writes_collection_name (Optional[str]): Name of Collection of intermediate writes. + ttl (Optional[int]): Time to live in seconds. See https://www.mongodb.com/docs/manual/core/index-ttl/. Examples: @@ -69,6 +71,7 @@ def __init__( db_name: str = "checkpointing_db", checkpoint_collection_name: str = "checkpoints", writes_collection_name: str = "checkpoint_writes", + ttl: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__() @@ -76,6 +79,7 @@ def __init__( self.db = self.client[db_name] self.checkpoint_collection = self.db[checkpoint_collection_name] self.writes_collection = self.db[writes_collection_name] + self.ttl = ttl # Create indexes if not present if len(self.checkpoint_collection.list_indexes().to_list()) < 2: @@ -83,6 +87,12 @@ def __init__( keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)], unique=True, ) + if self.ttl: + self.checkpoint_collection.create_index( + keys=[("created_at", ASCENDING)], + expireAfterSeconds=self.ttl, + ) + if len(self.writes_collection.list_indexes().to_list()) < 2: self.writes_collection.create_index( keys=[ @@ -94,6 +104,11 @@ def __init__( ], unique=True, ) + if self.ttl: + self.writes_collection.create_index( + keys=[("created_at", ASCENDING)], + expireAfterSeconds=self.ttl, + ) @classmethod @contextmanager @@ -103,6 +118,7 @@ def from_conn_string( db_name: str = "checkpointing_db", checkpoint_collection_name: str = "checkpoints", writes_collection_name: str = "checkpoint_writes", + ttl: Optional[int] = None, **kwargs: Any, ) -> Iterator["MongoDBSaver"]: """Context manager to create a MongoDB checkpoint saver. @@ -119,6 +135,7 @@ def from_conn_string( db_name: Database name. It will be created if it doesn't exist. checkpoint_collection_name: Checkpoint Collection name. Created if it doesn't exist. writes_collection_name: Collection name of intermediate writes. Created if it doesn't exist. + ttl (Optional[int]): Time to live in seconds. Yields: A new MongoDBSaver. """ client: Optional[MongoClient] = None @@ -134,6 +151,7 @@ def from_conn_string( db_name, checkpoint_collection_name, writes_collection_name, + ttl, **kwargs, ) finally: @@ -361,6 +379,9 @@ def put( "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } + if self.ttl: + upsert_query["created_at"] = datetime.now() + self.checkpoint_collection.update_one(upsert_query, {"$set": doc}, upsert=True) return { "configurable": { @@ -403,6 +424,9 @@ def put_writes( "task_path": task_path, "idx": WRITES_IDX_MAP.get(channel, idx), } + if self.ttl: + upsert_query["created_at"] = datetime.now() + type_, serialized_value = self.serde.dumps_typed(value) operations.append( UpdateOne( diff --git a/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_sync.py b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_sync.py index e958c0f8..f3431386 100644 --- a/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_sync.py +++ b/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_sync.py @@ -1,4 +1,5 @@ import os +from time import sleep from typing import Any import pytest @@ -165,3 +166,39 @@ def test_nested_filter() -> None: # Drop collections saver.checkpoint_collection.drop() saver.writes_collection.drop() + + +def test_ttl(input_data: dict[str, Any]) -> None: + collection_name = "ttl_test" + ttl = 1 + + # Set period between background task runs + monitor_period = 2 + client: MongoClient = MongoClient(MONGODB_URI) + client.admin.command("setParameter", 1, ttlMonitorSleepSecs=monitor_period) + + with MongoDBSaver.from_conn_string( + MONGODB_URI, DB_NAME, collection_name, ttl=ttl + ) as saver: + try: + # save a checkpoint + saver.put( + input_data["config_2"], + input_data["chkpnt_2"], + input_data["metadata_2"], + {}, + ) + + query: dict[str, Any] = {} # search by no keys, return all checkpoints + search_results_2 = list(saver.list(None, filter=query)) + assert len(search_results_2) == 1 + assert search_results_2[0].metadata == input_data["metadata_2"] + + sleep(ttl + monitor_period) + assert len(list(saver.list(None, filter=query))) == 0 + + finally: + saver.checkpoint_collection.delete_many({}) + saver.checkpoint_collection.drop_indexes() + saver.writes_collection.delete_many({}) + saver.writes_collection.drop_indexes()