Skip to content

Commit f74fb1d

Browse files
authored
INTPYTHON-567 Fix AsyncMongoDBSaver setup (#109)
1 parent 8d3c8fd commit f74fb1d

File tree

3 files changed

+88
-63
lines changed

3 files changed

+88
-63
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import builtins
35
import sys
@@ -9,6 +11,8 @@
911
from langchain_core.runnables import RunnableConfig
1012
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
1113
from pymongo import UpdateOne
14+
from pymongo.asynchronous.database import AsyncDatabase
15+
from pymongo.asynchronous.mongo_client import AsyncMongoClient
1216
from pymongo.driver_info import DriverInfo
1317

1418
from langgraph.checkpoint.base import (
@@ -68,12 +72,12 @@ class AsyncMongoDBSaver(BaseCheckpointSaver):
6872
input=3, output=4
6973
"""
7074

71-
client: AsyncIOMotorClient
72-
db: AsyncIOMotorDatabase
75+
client: AsyncIOMotorClient | AsyncMongoClient
76+
db: AsyncIOMotorDatabase | AsyncDatabase
7377

7478
def __init__(
7579
self,
76-
client: AsyncIOMotorClient,
80+
client: AsyncIOMotorClient | AsyncMongoClient,
7781
db_name: str = "checkpointing_db",
7882
checkpoint_collection_name: str = "checkpoints_aio",
7983
writes_collection_name: str = "checkpoint_writes_aio",
@@ -92,12 +96,24 @@ async def _setup(self):
9296
if self._setup_future is not None:
9397
return await self._setup_future
9498
self._setup_future = asyncio.Future()
95-
if len(await self.checkpoint_collection.list_indexes().to_list()) < 2:
99+
if isinstance(self.client, AsyncMongoClient):
100+
num_indexes = len(
101+
await (await self.checkpoint_collection.list_indexes()).to_list()
102+
)
103+
else:
104+
num_indexes = len(await self.checkpoint_collection.list_indexes().to_list())
105+
if num_indexes < 2:
96106
await self.checkpoint_collection.create_index(
97107
keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
98108
unique=True,
99109
)
100-
if len(await self.writes_collection.list_indexes().to_list()) < 2:
110+
if isinstance(self.client, AsyncMongoClient):
111+
num_indexes = len(
112+
await (await self.writes_collection.list_indexes()).to_list()
113+
)
114+
else:
115+
num_indexes = len(await self.writes_collection.list_indexes().to_list())
116+
if num_indexes < 2:
101117
await self.writes_collection.create_index(
102118
keys=[
103119
("thread_id", 1),
@@ -119,7 +135,7 @@ async def from_conn_string(
119135
checkpoint_collection_name: str = "checkpoints_aio",
120136
writes_collection_name: str = "checkpoint_writes_aio",
121137
**kwargs: Any,
122-
) -> AsyncIterator["AsyncMongoDBSaver"]:
138+
) -> AsyncIterator[AsyncMongoDBSaver]:
123139
"""Create asynchronous checkpointer
124140
125141
This includes creation of collections and indexes if they don't exist

libs/langgraph-checkpoint-mongodb/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ readme = "README.md"
1010
requires-python = ">=3.9"
1111
dependencies = [
1212
"langgraph-checkpoint>=2.0.23,<3.0.0",
13-
"pymongo>=4.9,<4.12",
14-
"motor>3.5.0",
13+
"pymongo>=4.10,<4.12",
14+
"motor>3.6.0",
1515
]
1616

1717
[dependency-groups]

0 commit comments

Comments
 (0)