1
+ from __future__ import annotations
2
+
1
3
import asyncio
2
4
import builtins
3
5
import sys
9
11
from langchain_core .runnables import RunnableConfig
10
12
from motor .motor_asyncio import AsyncIOMotorClient , AsyncIOMotorDatabase
11
13
from pymongo import UpdateOne
14
+ from pymongo .asynchronous .database import AsyncDatabase
15
+ from pymongo .asynchronous .mongo_client import AsyncMongoClient
12
16
from pymongo .driver_info import DriverInfo
13
17
14
18
from langgraph .checkpoint .base import (
@@ -68,12 +72,12 @@ class AsyncMongoDBSaver(BaseCheckpointSaver):
68
72
input=3, output=4
69
73
"""
70
74
71
- client : AsyncIOMotorClient
72
- db : AsyncIOMotorDatabase
75
+ client : AsyncIOMotorClient | AsyncMongoClient
76
+ db : AsyncIOMotorDatabase | AsyncDatabase
73
77
74
78
def __init__ (
75
79
self ,
76
- client : AsyncIOMotorClient ,
80
+ client : AsyncIOMotorClient | AsyncMongoClient ,
77
81
db_name : str = "checkpointing_db" ,
78
82
checkpoint_collection_name : str = "checkpoints_aio" ,
79
83
writes_collection_name : str = "checkpoint_writes_aio" ,
@@ -92,12 +96,24 @@ async def _setup(self):
92
96
if self ._setup_future is not None :
93
97
return await self ._setup_future
94
98
self ._setup_future = asyncio .Future ()
95
- if len (await self .checkpoint_collection .list_indexes ().to_list ()) < 2 :
99
+ if isinstance (self .client , AsyncMongoClient ):
100
+ num_indexes = len (
101
+ await (await self .checkpoint_collection .list_indexes ()).to_list ()
102
+ )
103
+ else :
104
+ num_indexes = len (await self .checkpoint_collection .list_indexes ().to_list ())
105
+ if num_indexes < 2 :
96
106
await self .checkpoint_collection .create_index (
97
107
keys = [("thread_id" , 1 ), ("checkpoint_ns" , 1 ), ("checkpoint_id" , - 1 )],
98
108
unique = True ,
99
109
)
100
- if len (await self .writes_collection .list_indexes ().to_list ()) < 2 :
110
+ if isinstance (self .client , AsyncMongoClient ):
111
+ num_indexes = len (
112
+ await (await self .writes_collection .list_indexes ()).to_list ()
113
+ )
114
+ else :
115
+ num_indexes = len (await self .writes_collection .list_indexes ().to_list ())
116
+ if num_indexes < 2 :
101
117
await self .writes_collection .create_index (
102
118
keys = [
103
119
("thread_id" , 1 ),
@@ -119,7 +135,7 @@ async def from_conn_string(
119
135
checkpoint_collection_name : str = "checkpoints_aio" ,
120
136
writes_collection_name : str = "checkpoint_writes_aio" ,
121
137
** kwargs : Any ,
122
- ) -> AsyncIterator [" AsyncMongoDBSaver" ]:
138
+ ) -> AsyncIterator [AsyncMongoDBSaver ]:
123
139
"""Create asynchronous checkpointer
124
140
125
141
This includes creation of collections and indexes if they don't exist
0 commit comments