Skip to content

Fixes #81 - Add TTL sessions #108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager
from datetime import datetime
from importlib.metadata import version
from typing import Any, Optional

Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
db_name: str = "checkpointing_db",
checkpoint_collection_name: str = "checkpoints_aio",
writes_collection_name: str = "checkpoint_writes_aio",
ttl: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__()
Expand All @@ -90,6 +92,7 @@ def __init__(
self.writes_collection = self.db[writes_collection_name]
self._setup_future: asyncio.Future | None = None
self.loop = asyncio.get_running_loop()
self.ttl = ttl

async def _setup(self) -> None:
"""Create indexes if not present."""
Expand All @@ -107,6 +110,11 @@ async def _setup(self) -> None:
keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
unique=True,
)
if self.ttl:
await self.checkpoint_collection.create_index(
keys=[("created_at", 1)],
expireAfterSeconds=self.ttl,
)
if isinstance(self.client, AsyncMongoClient):
num_indexes = len(
await (await self.writes_collection.list_indexes()).to_list() # type:ignore[misc]
Expand All @@ -124,6 +132,11 @@ async def _setup(self) -> None:
],
unique=True,
)
if self.ttl:
await self.writes_collection.create_index(
keys=[("created_at", 1)],
expireAfterSeconds=self.ttl,
)
self._setup_future.set_result(None)

@classmethod
Expand All @@ -134,6 +147,7 @@ async def from_conn_string(
db_name: str = "checkpointing_db",
checkpoint_collection_name: str = "checkpoints_aio",
writes_collection_name: str = "checkpoint_writes_aio",
ttl: Optional[int] = None,
**kwargs: Any,
) -> AsyncIterator[AsyncMongoDBSaver]:
"""Create asynchronous checkpointer
Expand All @@ -153,6 +167,7 @@ async def from_conn_string(
db_name,
checkpoint_collection_name,
writes_collection_name,
ttl,
**kwargs,
)
await saver._setup()
Expand Down Expand Up @@ -343,6 +358,8 @@ async def aput(
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
if self.ttl:
doc["created_at"] = datetime.now()
# Perform your operations here
await self.checkpoint_collection.update_one(
upsert_query, {"$set": doc}, upsert=True
Expand Down Expand Up @@ -389,6 +406,8 @@ async def aput_writes(
"task_path": task_path,
"idx": WRITES_IDX_MAP.get(channel, idx),
}
if self.ttl:
upsert_query["created_at"] = datetime.now()
type_, serialized_value = self.serde.dumps_typed(value)
operations.append(
UpdateOne(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from datetime import datetime
from importlib.metadata import version
from typing import (
Any,
Optional,
)

from langchain_core.runnables import RunnableConfig
from pymongo import MongoClient, UpdateOne
from pymongo import ASCENDING, MongoClient, UpdateOne
from pymongo.database import Database as MongoDatabase
from pymongo.driver_info import DriverInfo

Expand Down Expand Up @@ -39,6 +40,7 @@ class MongoDBSaver(BaseCheckpointSaver):
db_name (Optional[str]): Database name
checkpoint_collection_name (Optional[str]): Name of Collection of Checkpoints
writes_collection_name (Optional[str]): Name of Collection of intermediate writes.
ttl (Optional[int]): Time to live in seconds. See https://www.mongodb.com/docs/manual/core/index-ttl/.

Examples:

Expand Down Expand Up @@ -69,20 +71,28 @@ def __init__(
db_name: str = "checkpointing_db",
checkpoint_collection_name: str = "checkpoints",
writes_collection_name: str = "checkpoint_writes",
ttl: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__()
self.client = client
self.db = self.client[db_name]
self.checkpoint_collection = self.db[checkpoint_collection_name]
self.writes_collection = self.db[writes_collection_name]
self.ttl = ttl

# Create indexes if not present
if len(self.checkpoint_collection.list_indexes().to_list()) < 2:
self.checkpoint_collection.create_index(
keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
unique=True,
)
if self.ttl:
self.checkpoint_collection.create_index(
keys=[("created_at", ASCENDING)],
expireAfterSeconds=self.ttl,
)

if len(self.writes_collection.list_indexes().to_list()) < 2:
self.writes_collection.create_index(
keys=[
Expand All @@ -94,6 +104,11 @@ def __init__(
],
unique=True,
)
if self.ttl:
self.writes_collection.create_index(
keys=[("created_at", ASCENDING)],
expireAfterSeconds=self.ttl,
)

@classmethod
@contextmanager
Expand All @@ -103,6 +118,7 @@ def from_conn_string(
db_name: str = "checkpointing_db",
checkpoint_collection_name: str = "checkpoints",
writes_collection_name: str = "checkpoint_writes",
ttl: Optional[int] = None,
**kwargs: Any,
) -> Iterator["MongoDBSaver"]:
"""Context manager to create a MongoDB checkpoint saver.
Expand All @@ -119,6 +135,7 @@ def from_conn_string(
db_name: Database name. It will be created if it doesn't exist.
checkpoint_collection_name: Checkpoint Collection name. Created if it doesn't exist.
writes_collection_name: Collection name of intermediate writes. Created if it doesn't exist.
ttl (Optional[int]): Time to live in seconds.
Yields: A new MongoDBSaver.
"""
client: Optional[MongoClient] = None
Expand All @@ -134,6 +151,7 @@ def from_conn_string(
db_name,
checkpoint_collection_name,
writes_collection_name,
ttl,
**kwargs,
)
finally:
Expand Down Expand Up @@ -361,6 +379,9 @@ def put(
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
if self.ttl:
upsert_query["created_at"] = datetime.now()

self.checkpoint_collection.update_one(upsert_query, {"$set": doc}, upsert=True)
return {
"configurable": {
Expand Down Expand Up @@ -403,6 +424,9 @@ def put_writes(
"task_path": task_path,
"idx": WRITES_IDX_MAP.get(channel, idx),
}
if self.ttl:
upsert_query["created_at"] = datetime.now()

type_, serialized_value = self.serde.dumps_typed(value)
operations.append(
UpdateOne(
Expand Down
37 changes: 37 additions & 0 deletions libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from time import sleep
from typing import Any

import pytest
Expand Down Expand Up @@ -165,3 +166,39 @@ def test_nested_filter() -> None:
# Drop collections
saver.checkpoint_collection.drop()
saver.writes_collection.drop()


def test_ttl(input_data: dict[str, Any]) -> None:
collection_name = "ttl_test"
ttl = 1

# Set period between background task runs
monitor_period = 2
client: MongoClient = MongoClient(MONGODB_URI)
client.admin.command("setParameter", 1, ttlMonitorSleepSecs=monitor_period)

with MongoDBSaver.from_conn_string(
MONGODB_URI, DB_NAME, collection_name, ttl=ttl
) as saver:
try:
# save a checkpoint
saver.put(
input_data["config_2"],
input_data["chkpnt_2"],
input_data["metadata_2"],
{},
)

query: dict[str, Any] = {} # search by no keys, return all checkpoints
search_results_2 = list(saver.list(None, filter=query))
assert len(search_results_2) == 1
assert search_results_2[0].metadata == input_data["metadata_2"]

sleep(ttl + monitor_period)
assert len(list(saver.list(None, filter=query))) == 0

finally:
saver.checkpoint_collection.delete_many({})
saver.checkpoint_collection.drop_indexes()
saver.writes_collection.delete_many({})
saver.writes_collection.drop_indexes()
Loading