Skip to content

Commit 170264e

Browse files
Fixes #81 - Add TTL sessions (#108)
Co-authored-by: Casey Clements <[email protected]>
1 parent ff93552 commit 170264e

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
from collections.abc import AsyncIterator, Iterator, Sequence
77
from contextlib import asynccontextmanager
8+
from datetime import datetime
89
from importlib.metadata import version
910
from typing import Any, Optional
1011

@@ -81,6 +82,7 @@ def __init__(
8182
db_name: str = "checkpointing_db",
8283
checkpoint_collection_name: str = "checkpoints_aio",
8384
writes_collection_name: str = "checkpoint_writes_aio",
85+
ttl: Optional[int] = None,
8486
**kwargs: Any,
8587
) -> None:
8688
super().__init__()
@@ -90,6 +92,7 @@ def __init__(
9092
self.writes_collection = self.db[writes_collection_name]
9193
self._setup_future: asyncio.Future | None = None
9294
self.loop = asyncio.get_running_loop()
95+
self.ttl = ttl
9396

9497
async def _setup(self) -> None:
9598
"""Create indexes if not present."""
@@ -107,6 +110,11 @@ async def _setup(self) -> None:
107110
keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
108111
unique=True,
109112
)
113+
if self.ttl:
114+
await self.checkpoint_collection.create_index(
115+
keys=[("created_at", 1)],
116+
expireAfterSeconds=self.ttl,
117+
)
110118
if isinstance(self.client, AsyncMongoClient):
111119
num_indexes = len(
112120
await (await self.writes_collection.list_indexes()).to_list() # type:ignore[misc]
@@ -124,6 +132,11 @@ async def _setup(self) -> None:
124132
],
125133
unique=True,
126134
)
135+
if self.ttl:
136+
await self.writes_collection.create_index(
137+
keys=[("created_at", 1)],
138+
expireAfterSeconds=self.ttl,
139+
)
127140
self._setup_future.set_result(None)
128141

129142
@classmethod
@@ -134,6 +147,7 @@ async def from_conn_string(
134147
db_name: str = "checkpointing_db",
135148
checkpoint_collection_name: str = "checkpoints_aio",
136149
writes_collection_name: str = "checkpoint_writes_aio",
150+
ttl: Optional[int] = None,
137151
**kwargs: Any,
138152
) -> AsyncIterator[AsyncMongoDBSaver]:
139153
"""Create asynchronous checkpointer
@@ -153,6 +167,7 @@ async def from_conn_string(
153167
db_name,
154168
checkpoint_collection_name,
155169
writes_collection_name,
170+
ttl,
156171
**kwargs,
157172
)
158173
await saver._setup()
@@ -343,6 +358,8 @@ async def aput(
343358
"checkpoint_ns": checkpoint_ns,
344359
"checkpoint_id": checkpoint_id,
345360
}
361+
if self.ttl:
362+
doc["created_at"] = datetime.now()
346363
# Perform your operations here
347364
await self.checkpoint_collection.update_one(
348365
upsert_query, {"$set": doc}, upsert=True
@@ -389,6 +406,8 @@ async def aput_writes(
389406
"task_path": task_path,
390407
"idx": WRITES_IDX_MAP.get(channel, idx),
391408
}
409+
if self.ttl:
410+
upsert_query["created_at"] = datetime.now()
392411
type_, serialized_value = self.serde.dumps_typed(value)
393412
operations.append(
394413
UpdateOne(

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from collections.abc import Iterator, Sequence
22
from contextlib import contextmanager
3+
from datetime import datetime
34
from importlib.metadata import version
45
from typing import (
56
Any,
67
Optional,
78
)
89

910
from langchain_core.runnables import RunnableConfig
10-
from pymongo import MongoClient, UpdateOne
11+
from pymongo import ASCENDING, MongoClient, UpdateOne
1112
from pymongo.database import Database as MongoDatabase
1213
from pymongo.driver_info import DriverInfo
1314

@@ -39,6 +40,7 @@ class MongoDBSaver(BaseCheckpointSaver):
3940
db_name (Optional[str]): Database name
4041
checkpoint_collection_name (Optional[str]): Name of Collection of Checkpoints
4142
writes_collection_name (Optional[str]): Name of Collection of intermediate writes.
43+
ttl (Optional[int]): Time to live in seconds. See https://www.mongodb.com/docs/manual/core/index-ttl/.
4244
4345
Examples:
4446
@@ -69,20 +71,28 @@ def __init__(
6971
db_name: str = "checkpointing_db",
7072
checkpoint_collection_name: str = "checkpoints",
7173
writes_collection_name: str = "checkpoint_writes",
74+
ttl: Optional[int] = None,
7275
**kwargs: Any,
7376
) -> None:
7477
super().__init__()
7578
self.client = client
7679
self.db = self.client[db_name]
7780
self.checkpoint_collection = self.db[checkpoint_collection_name]
7881
self.writes_collection = self.db[writes_collection_name]
82+
self.ttl = ttl
7983

8084
# Create indexes if not present
8185
if len(self.checkpoint_collection.list_indexes().to_list()) < 2:
8286
self.checkpoint_collection.create_index(
8387
keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
8488
unique=True,
8589
)
90+
if self.ttl:
91+
self.checkpoint_collection.create_index(
92+
keys=[("created_at", ASCENDING)],
93+
expireAfterSeconds=self.ttl,
94+
)
95+
8696
if len(self.writes_collection.list_indexes().to_list()) < 2:
8797
self.writes_collection.create_index(
8898
keys=[
@@ -94,6 +104,11 @@ def __init__(
94104
],
95105
unique=True,
96106
)
107+
if self.ttl:
108+
self.writes_collection.create_index(
109+
keys=[("created_at", ASCENDING)],
110+
expireAfterSeconds=self.ttl,
111+
)
97112

98113
@classmethod
99114
@contextmanager
@@ -103,6 +118,7 @@ def from_conn_string(
103118
db_name: str = "checkpointing_db",
104119
checkpoint_collection_name: str = "checkpoints",
105120
writes_collection_name: str = "checkpoint_writes",
121+
ttl: Optional[int] = None,
106122
**kwargs: Any,
107123
) -> Iterator["MongoDBSaver"]:
108124
"""Context manager to create a MongoDB checkpoint saver.
@@ -119,6 +135,7 @@ def from_conn_string(
119135
db_name: Database name. It will be created if it doesn't exist.
120136
checkpoint_collection_name: Checkpoint Collection name. Created if it doesn't exist.
121137
writes_collection_name: Collection name of intermediate writes. Created if it doesn't exist.
138+
ttl (Optional[int]): Time to live in seconds.
122139
Yields: A new MongoDBSaver.
123140
"""
124141
client: Optional[MongoClient] = None
@@ -134,6 +151,7 @@ def from_conn_string(
134151
db_name,
135152
checkpoint_collection_name,
136153
writes_collection_name,
154+
ttl,
137155
**kwargs,
138156
)
139157
finally:
@@ -361,6 +379,9 @@ def put(
361379
"checkpoint_ns": checkpoint_ns,
362380
"checkpoint_id": checkpoint_id,
363381
}
382+
if self.ttl:
383+
upsert_query["created_at"] = datetime.now()
384+
364385
self.checkpoint_collection.update_one(upsert_query, {"$set": doc}, upsert=True)
365386
return {
366387
"configurable": {
@@ -403,6 +424,9 @@ def put_writes(
403424
"task_path": task_path,
404425
"idx": WRITES_IDX_MAP.get(channel, idx),
405426
}
427+
if self.ttl:
428+
upsert_query["created_at"] = datetime.now()
429+
406430
type_, serialized_value = self.serde.dumps_typed(value)
407431
operations.append(
408432
UpdateOne(

libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_sync.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from time import sleep
23
from typing import Any
34

45
import pytest
@@ -165,3 +166,39 @@ def test_nested_filter() -> None:
165166
# Drop collections
166167
saver.checkpoint_collection.drop()
167168
saver.writes_collection.drop()
169+
170+
171+
def test_ttl(input_data: dict[str, Any]) -> None:
172+
collection_name = "ttl_test"
173+
ttl = 1
174+
175+
# Set period between background task runs
176+
monitor_period = 2
177+
client: MongoClient = MongoClient(MONGODB_URI)
178+
client.admin.command("setParameter", 1, ttlMonitorSleepSecs=monitor_period)
179+
180+
with MongoDBSaver.from_conn_string(
181+
MONGODB_URI, DB_NAME, collection_name, ttl=ttl
182+
) as saver:
183+
try:
184+
# save a checkpoint
185+
saver.put(
186+
input_data["config_2"],
187+
input_data["chkpnt_2"],
188+
input_data["metadata_2"],
189+
{},
190+
)
191+
192+
query: dict[str, Any] = {} # search by no keys, return all checkpoints
193+
search_results_2 = list(saver.list(None, filter=query))
194+
assert len(search_results_2) == 1
195+
assert search_results_2[0].metadata == input_data["metadata_2"]
196+
197+
sleep(ttl + monitor_period)
198+
assert len(list(saver.list(None, filter=query))) == 0
199+
200+
finally:
201+
saver.checkpoint_collection.delete_many({})
202+
saver.checkpoint_collection.drop_indexes()
203+
saver.writes_collection.delete_many({})
204+
saver.writes_collection.drop_indexes()

0 commit comments

Comments
 (0)