Skip to content

Commit 268c81e

Browse files
committed
Adds delete_thread and adelete_thread methods to checkpointers.
1 parent 32f68b8 commit 268c81e

File tree

3 files changed

+239
-0
lines changed

3 files changed

+239
-0
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,34 @@ async def aput_writes(
424424
)
425425
await self.writes_collection.bulk_write(operations)
426426

427+
async def adelete_thread(
428+
self,
429+
thread_id: str,
430+
) -> None:
431+
"""Delete all checkpoints and writes associated with a specific thread ID asynchronously.
432+
433+
Args:
434+
thread_id (str): The thread ID whose checkpoints should be deleted.
435+
"""
436+
# Delete all checkpoints associated with the thread ID
437+
await self.checkpoint_collection.delete_many({"thread_id": thread_id})
438+
439+
# Delete all writes associated with the thread ID
440+
await self.writes_collection.delete_many({"thread_id": thread_id})
441+
442+
def delete_thread(
443+
self,
444+
thread_id: str,
445+
) -> None:
446+
"""Delete all checkpoints and writes associated with a specific thread ID.
447+
448+
Args:
449+
thread_id (str): The thread ID whose checkpoints should be deleted.
450+
"""
451+
return asyncio.run_coroutine_threadsafe(
452+
self.adelete_thread(thread_id), self.loop
453+
).result()
454+
427455
def list(
428456
self,
429457
config: Optional[RunnableConfig],

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,18 @@ def put_writes(
442442
)
443443
)
444444
self.writes_collection.bulk_write(operations)
445+
446+
def delete_thread(
447+
self,
448+
thread_id: str,
449+
) -> None:
450+
"""Delete all checkpoints and writes associated with a specific thread ID.
451+
452+
Args:
453+
thread_id (str): The thread ID whose checkpoints should be deleted.
454+
"""
455+
# Delete all checkpoints associated with the thread ID
456+
self.checkpoint_collection.delete_many({"thread_id": thread_id})
457+
458+
# Delete all writes associated with the thread ID
459+
self.writes_collection.delete_many({"thread_id": thread_id})
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import os
2+
3+
from langchain_core.runnables import RunnableConfig
4+
from pymongo import MongoClient
5+
6+
from langgraph.checkpoint.base import CheckpointMetadata, empty_checkpoint
7+
from langgraph.checkpoint.mongodb import MongoDBSaver
8+
from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
9+
10+
# Setup:
11+
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
12+
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
13+
CHKPT_COLLECTION_NAME = "delete_thread_chkpts"
14+
WRITES_COLLECTION_NAME = "delete_thread_writes"
15+
16+
17+
def test_delete_thread() -> None:
18+
# Clear collections if they exist
19+
client: MongoClient = MongoClient(MONGODB_URI)
20+
db = client[DB_NAME]
21+
db[CHKPT_COLLECTION_NAME].delete_many({})
22+
db[WRITES_COLLECTION_NAME].delete_many({})
23+
24+
with MongoDBSaver.from_conn_string(
25+
MONGODB_URI, DB_NAME, CHKPT_COLLECTION_NAME, WRITES_COLLECTION_NAME
26+
) as saver:
27+
# Thread 1 data
28+
chkpnt_1 = empty_checkpoint()
29+
thread_1_id = "thread-1"
30+
config_1 = RunnableConfig(
31+
configurable=dict(
32+
thread_id=thread_1_id, checkpoint_ns="", checkpoint_id=chkpnt_1["id"]
33+
)
34+
)
35+
metadata_1: CheckpointMetadata = {
36+
"source": "input",
37+
"step": 1,
38+
"writes": {"foo": "bar"},
39+
}
40+
41+
# Thread 2 data
42+
chkpnt_2 = empty_checkpoint()
43+
thread_2_id = "thread-2"
44+
config_2 = RunnableConfig(
45+
configurable=dict(
46+
thread_id=thread_2_id, checkpoint_ns="", checkpoint_id=chkpnt_2["id"]
47+
)
48+
)
49+
metadata_2: CheckpointMetadata = {
50+
"source": "output",
51+
"step": 1,
52+
"writes": {"baz": "qux"},
53+
}
54+
55+
# Save checkpoints for both threads
56+
saver.put(config_1, chkpnt_1, metadata_1, {})
57+
saver.put(config_2, chkpnt_2, metadata_2, {})
58+
59+
# Add some writes
60+
saver.put_writes(config_1, [("channel1", "value1")], "task1")
61+
saver.put_writes(config_2, [("channel2", "value2")], "task2")
62+
63+
# Verify we have data for both threads
64+
assert saver.get_tuple(config_1) is not None
65+
assert saver.get_tuple(config_2) is not None
66+
67+
# Verify we have write data
68+
assert (
69+
saver.checkpoint_collection.count_documents({"thread_id": thread_1_id}) > 0
70+
)
71+
assert saver.writes_collection.count_documents({"thread_id": thread_1_id}) > 0
72+
assert (
73+
saver.checkpoint_collection.count_documents({"thread_id": thread_2_id}) > 0
74+
)
75+
assert saver.writes_collection.count_documents({"thread_id": thread_2_id}) > 0
76+
77+
# Delete thread 1
78+
saver.delete_thread(thread_1_id)
79+
80+
# Verify thread 1 data is gone
81+
assert saver.get_tuple(config_1) is None
82+
assert (
83+
saver.checkpoint_collection.count_documents({"thread_id": thread_1_id}) == 0
84+
)
85+
assert saver.writes_collection.count_documents({"thread_id": thread_1_id}) == 0
86+
87+
# Verify thread 2 data still exists
88+
assert saver.get_tuple(config_2) is not None
89+
assert (
90+
saver.checkpoint_collection.count_documents({"thread_id": thread_2_id}) > 0
91+
)
92+
assert saver.writes_collection.count_documents({"thread_id": thread_2_id}) > 0
93+
94+
95+
async def test_adelete_thread() -> None:
96+
# Clear collections if they exist
97+
client: MongoClient = MongoClient(MONGODB_URI)
98+
db = client[DB_NAME]
99+
db[CHKPT_COLLECTION_NAME].delete_many({})
100+
db[WRITES_COLLECTION_NAME].delete_many({})
101+
102+
async with AsyncMongoDBSaver.from_conn_string(
103+
MONGODB_URI, DB_NAME, CHKPT_COLLECTION_NAME, WRITES_COLLECTION_NAME
104+
) as saver:
105+
# Thread 1 data
106+
chkpnt_1 = empty_checkpoint()
107+
thread_1_id = "thread-1"
108+
config_1 = RunnableConfig(
109+
configurable=dict(
110+
thread_id=thread_1_id, checkpoint_ns="", checkpoint_id=chkpnt_1["id"]
111+
)
112+
)
113+
metadata_1: CheckpointMetadata = {
114+
"source": "input",
115+
"step": 1,
116+
"writes": {"foo": "bar"},
117+
}
118+
119+
# Thread 2 data
120+
chkpnt_2 = empty_checkpoint()
121+
thread_2_id = "thread-2"
122+
config_2 = RunnableConfig(
123+
configurable=dict(
124+
thread_id=thread_2_id, checkpoint_ns="", checkpoint_id=chkpnt_2["id"]
125+
)
126+
)
127+
metadata_2: CheckpointMetadata = {
128+
"source": "output",
129+
"step": 1,
130+
"writes": {"baz": "qux"},
131+
}
132+
133+
assert await saver.checkpoint_collection.count_documents({}) == 0
134+
135+
# Save checkpoints for both threads
136+
await saver.aput(config_1, chkpnt_1, metadata_1, {})
137+
await saver.aput(config_2, chkpnt_2, metadata_2, {})
138+
139+
# Add some writes
140+
await saver.aput_writes(config_1, [("channel1", "value1")], "task1")
141+
await saver.aput_writes(config_2, [("channel2", "value2")], "task2")
142+
143+
# Verify we have data for both threads
144+
assert await saver.aget_tuple(config_1) is not None
145+
assert await saver.aget_tuple(config_2) is not None
146+
147+
# Verify we have write data
148+
assert (
149+
await saver.checkpoint_collection.count_documents(
150+
{"thread_id": thread_1_id}
151+
)
152+
> 0
153+
)
154+
assert (
155+
await saver.writes_collection.count_documents({"thread_id": thread_1_id})
156+
> 0
157+
)
158+
assert (
159+
await saver.checkpoint_collection.count_documents(
160+
{"thread_id": thread_2_id}
161+
)
162+
> 0
163+
)
164+
assert (
165+
await saver.writes_collection.count_documents({"thread_id": thread_2_id})
166+
> 0
167+
)
168+
169+
# Delete thread 1
170+
await saver.adelete_thread(thread_1_id)
171+
172+
# Verify thread 1 data is gone
173+
assert await saver.aget_tuple(config_1) is None
174+
assert (
175+
await saver.checkpoint_collection.count_documents(
176+
{"thread_id": thread_1_id}
177+
)
178+
== 0
179+
)
180+
assert (
181+
await saver.writes_collection.count_documents({"thread_id": thread_1_id})
182+
== 0
183+
)
184+
185+
# Verify thread 2 data still exists
186+
assert await saver.aget_tuple(config_2) is not None
187+
assert (
188+
await saver.checkpoint_collection.count_documents(
189+
{"thread_id": thread_2_id}
190+
)
191+
> 0
192+
)
193+
assert (
194+
await saver.writes_collection.count_documents({"thread_id": thread_2_id})
195+
> 0
196+
)

0 commit comments

Comments
 (0)