1
1
import asyncio
2
2
import builtins
3
+ import logging
3
4
import sys
5
+ import threading
6
+ import time
4
7
from collections .abc import AsyncIterator , Iterator , Sequence
5
8
from contextlib import asynccontextmanager
6
9
from typing import Any , Optional
@@ -75,6 +78,10 @@ def __init__(
75
78
db_name : str = "checkpointing_db" ,
76
79
checkpoint_collection_name : str = "checkpoints_aio" ,
77
80
writes_collection_name : str = "checkpoint_writes_aio" ,
81
+ auto_delete_expired_threads : bool = False ,
82
+ thread_status_collection_name : str = "checkpoint_thread_status_aio" ,
83
+ thread_expire_time_second : int = 2592000 , #30 days
84
+ thread_expire_check_time_second : int = 86400 , # 1 day
78
85
** kwargs : Any ,
79
86
) -> None :
80
87
super ().__init__ ()
@@ -83,6 +90,14 @@ def __init__(
83
90
self .checkpoint_collection = self .db [checkpoint_collection_name ]
84
91
self .writes_collection = self .db [writes_collection_name ]
85
92
self .loop = asyncio .get_running_loop ()
93
+ self .auto_delete_expired_threads = auto_delete_expired_threads
94
+ if self .auto_delete_expired_threads :
95
+ self .thread_expire_time_second = thread_expire_time_second
96
+ self .thread_expire_check_time_second = thread_expire_check_time_second
97
+ self .thread_status_collection = self .db [thread_status_collection_name ]
98
+ self .last_thread_expire_check_time_second = 0
99
+ self .thread_expire_check_lock = threading .RLock ()
100
+ self .try_delete_expired_threads_from_checkpoints ()
86
101
87
102
@classmethod
88
103
@asynccontextmanager
@@ -290,6 +305,18 @@ async def aput(
290
305
await self .checkpoint_collection .update_one (
291
306
upsert_query , {"$set" : doc }, upsert = True
292
307
)
308
+ if self .auto_delete_expired_threads :
309
+ thread_status_collection_doc = {
310
+ "thread_id" : thread_id ,
311
+ "last_update" : int (time .time ())
312
+ }
313
+ thread_status_collection_upsert_query = {
314
+ "thread_id" : thread_id
315
+ }
316
+ await self .thread_status_collection .update_one (
317
+ thread_status_collection_upsert_query , {"$set" : thread_status_collection_doc }, upsert = True
318
+ )
319
+ await self .try_delete_expired_threads_from_checkpoints ()
293
320
return {
294
321
"configurable" : {
295
322
"thread_id" : thread_id ,
@@ -450,3 +477,34 @@ def put_writes(
450
477
return asyncio .run_coroutine_threadsafe (
451
478
self .aput_writes (config , writes , task_id ), self .loop
452
479
).result ()
480
+
481
async def try_delete_expired_threads_from_checkpoints(self):
    """Purge checkpoints of threads whose last update is older than the expiry window.

    The sweep is rate-limited: it only runs when more than
    ``self.thread_expire_check_time_second`` seconds have passed since the
    previous sweep. A thread counts as expired when its ``last_update`` entry
    in ``self.thread_status_collection`` is older than
    ``self.thread_expire_time_second`` seconds. Expired thread ids are removed
    from ``self.checkpoint_collection`` and then from
    ``self.thread_status_collection``.

    This is best-effort housekeeping: any exception raised while querying or
    deleting is logged and swallowed, and the next attempt happens only after
    the check interval has elapsed again.

    Returns:
        None.
    """
    # The expiry attributes are only assigned when the feature is switched on,
    # so a direct call on a saver without it must be a no-op, not an
    # AttributeError.
    if not getattr(self, "auto_delete_expired_threads", False):
        return

    now = int(time.time())

    # Claim this sweep slot *before* the first await. The lock guards only the
    # synchronous check-and-set: holding a threading lock across an await
    # could stall the event loop, and an RLock would not keep out other
    # coroutines running on the same thread anyway.
    with self.thread_expire_check_lock:
        next_check_due = (
            self.last_thread_expire_check_time_second
            + self.thread_expire_check_time_second
        )
        if next_check_due >= now:
            return
        self.last_thread_expire_check_time_second = now

    try:
        expired_cursor = self.thread_status_collection.find(
            {"last_update": {"$lt": now - self.thread_expire_time_second}}
        )
        expired_thread_ids = [doc["thread_id"] async for doc in expired_cursor]
        if not expired_thread_ids:
            # Nothing expired: skip the two delete round trips.
            return

        expired_filter = {"thread_id": {"$in": expired_thread_ids}}
        # Checkpoints go first so that a failure in between leaves the status
        # rows in place and the same threads are picked up by a later sweep.
        await self.checkpoint_collection.delete_many(expired_filter)
        await self.thread_status_collection.delete_many(expired_filter)
        # NOTE(review): self.writes_collection is not purged here - confirm
        # whether pending writes of expired threads should be removed as well.
    except Exception as e:
        # Use a %s placeholder so the exception text actually reaches the log
        # record; logging.exception also attaches the traceback.
        logging.exception("delete_expired_threads_from_checkpoints error: %s", e)
0 commit comments