@@ -80,7 +80,7 @@ def __call__(self, request):
80
80
request .register_hook ('response' , self ._SetJSONWebToken )
81
81
return request
82
82
83
- class Subscription :
83
+ class Subscription ( object ) :
84
84
"""Subscription that contains the unique subscription id for every subscription.
85
85
"""
86
86
_subscriptionId : str # subscription id
@@ -96,8 +96,44 @@ def GetSubscriptionID(self) -> str:
96
96
def GetSubscriptionCallbackFunction (self ) -> Callable [[Optional [ControllerGraphClientException ], Optional [dict ]], None ]:
97
97
return self ._subscriptionCallbackFunction
98
98
99
- class ControllerWebClientRaw (object ):
99
+ def __repr__ (self ):
100
+ return '<Subscription(%r, %r)>' % (self ._subscriptionId , self ._subscriptionCallbackFunction )
101
+
102
+ class BackgroundThread (object ):
103
+ _thread : threading .Thread # A thread to run the event loop
104
+ _eventLoop : asyncio .AbstractEventLoop # Event loop that is running so that client can add coroutine
105
+
106
+ def __init__ (self ):
107
+ # create a new event loop in a background thread
108
+ self ._eventLoop = asyncio .new_event_loop ()
109
+ self ._thread = threading .Thread (target = self ._RunEventLoop )
110
+ self ._thread .start ()
111
+
112
+ def _RunEventLoop (self ):
113
+ # set the created loop as the current event loop for this thread
114
+ asyncio .set_event_loop (self ._eventLoop )
115
+ self ._eventLoop .run_forever ()
116
+
117
+ def RunCoroutine (self , coroutine : Callable ):
118
+ """Schedule a coroutine to run on the event loop from another thread
119
+ """
120
+ return asyncio .run_coroutine_threadsafe (coroutine , self ._eventLoop )
121
+
122
+ def __del__ (self ):
123
+ self .Destroy ()
100
124
125
+ def Destroy (self ):
126
+ if self ._eventLoop .is_closed ():
127
+ return
128
+ # cancel all tasks in the event loop
129
+ for task in asyncio .all_tasks (loop = self ._eventLoop ):
130
+ task .cancel ()
131
+ # run the loop briefly to let cancellations propagate
132
+ self ._eventLoop .call_soon_threadsafe (self ._eventLoop .stop )
133
+ self ._thread .join ()
134
+ self ._eventLoop .close ()
135
+
136
+ class ControllerWebClientRaw (object ):
101
137
_baseurl = None # Base URL of the controller
102
138
_username = None # Username to login with
103
139
_password = None # Password to login with
@@ -106,9 +142,7 @@ class ControllerWebClientRaw(object):
106
142
_session = None # Requests session object
107
143
_webSocket : websockets .asyncio .client .ClientConnection # WebSocket used to connect to WebStack for subscriptions
108
144
_subscriptions : dict [str , Subscription ] # Dictionary that stores the subscriptionId(key) and the corresponding subscription(value)
109
- _eventLoopThread : threading .Thread # A thread to run the event loop
110
- _eventLoop : asyncio .AbstractEventLoop # Event loop that is running so that client can add coroutine(a subscription in this case)
111
- _webSocketLock : threading .Lock # Lock protecting initializing the WebSocket
145
+ _subscriptionLock : threading .Lock # Lock protecting _webSocket and _subscriptions
112
146
113
147
def __init__ (self , baseurl , username , password , locale = None , author = None , userAgent = None , additionalHeaders = None , unixEndpoint = None ):
114
148
self ._baseurl = baseurl
@@ -117,11 +151,12 @@ def __init__(self, baseurl, username, password, locale=None, author=None, userAg
117
151
self ._headers = {}
118
152
self ._isok = True
119
153
120
- self ._eventLoop = None
121
- self ._eventLoopThread = None
122
154
self ._webSocket = None
123
- self ._webSocketLock = threading .Lock ()
124
155
self ._subscriptions = {}
156
+ self ._subscriptionLock = threading .Lock ()
157
+
158
+ # Create the background thread for async operations
159
+ self ._backgroundThread = BackgroundThread ()
125
160
126
161
# Create session
127
162
self ._session = requests .Session ()
@@ -162,7 +197,12 @@ def __del__(self):
162
197
163
198
def Destroy (self ):
164
199
self .SetDestroy ()
165
- self ._CloseEventLoop ()
200
+ if self ._backgroundThread is not None :
201
+ # make sure to stop subscriptions and close the websocket first
202
+ self ._backgroundThread .RunCoroutine (self ._StopAllSubscriptions ()).result ()
203
+ # next destroy the thread
204
+ self ._backgroundThread .Destroy ()
205
+ self ._backgroundThread = None
166
206
167
207
def SetDestroy (self ):
168
208
self ._isok = False
@@ -191,31 +231,6 @@ def SetUserAgent(self, userAgent=None):
191
231
else :
192
232
self ._headers .pop ('User-Agent' , None )
193
233
194
- def _InitializeEventLoopThread (self ):
195
- # Create new event loop if _eventLoop is None, otherwise, reuse the existing one
196
- if self ._eventLoop is None :
197
- self ._eventLoop = asyncio .new_event_loop ()
198
- if self ._eventLoopThread is not None and self ._eventLoopThread .is_alive ():
199
- self ._eventLoopThread .join ()
200
- self ._eventLoopThread = threading .Thread (target = self ._RunEventLoop )
201
- self ._eventLoopThread .start ()
202
-
203
- def _RunEventLoop (self ):
204
- asyncio .set_event_loop (self ._eventLoop )
205
- self ._eventLoop .run_forever ()
206
-
207
- def _StopEventLoop (self ):
208
- if self ._eventLoop is not None :
209
- self ._eventLoop .stop ()
210
-
211
- def _CloseEventLoop (self ):
212
- if self ._eventLoop is not None and self ._eventLoop .is_running ():
213
- self ._eventLoop .call_soon_threadsafe (self ._StopEventLoop )
214
- if self ._eventLoopThread is not None and self ._eventLoopThread .is_alive ():
215
- self ._eventLoopThread .join ()
216
- if self ._eventLoop is not None and not self ._eventLoop .is_closed ():
217
- self ._eventLoop .close ()
218
-
219
234
def Request (self , method , path , timeout = 5 , headers = None , ** kwargs ):
220
235
if timeout < 1e-6 :
221
236
raise WebstackClientError (_ ('Timeout value (%s sec) is too small' ) % timeout )
@@ -355,6 +370,21 @@ def CallGraphAPI(self, query, variables=None, headers=None, timeout=5.0):
355
370
356
371
return content ['data' ]
357
372
373
+ def _EnsureWebSocketConnection (self ):
374
+ if self ._webSocket is None :
375
+ # wait until the connection is established
376
+ self ._backgroundThread .RunCoroutine (self ._OpenWebSocketConnection ()).result ()
377
+ # start listening without blocking
378
+ self ._backgroundThread .RunCoroutine (self ._ListenToWebSocket ())
379
+
380
+ def _IsWebSocketConnectionOpen (self ):
381
+ return self ._webSocket is not None
382
+
383
+ async def _CloseWebSocket (self ):
384
+ if self ._webSocket is not None :
385
+ await self ._webSocket .close ()
386
+ self ._webSocket = None
387
+
358
388
async def _OpenWebSocketConnection (self ):
359
389
authorization = self ._session .auth .GetAuthorizationHeader ()
360
390
@@ -401,10 +431,8 @@ async def _OpenWebSocketConnection(self):
401
431
async def _ListenToWebSocket (self ):
402
432
try :
403
433
async for response in self ._webSocket :
404
- # stop listening if there is no subscriptions
405
- # the connection will close automatically after breaking out the loop
406
- if len (self ._subscriptions ) == 0 :
407
- await self ._webSocket .close ()
434
+ # stop if stop is requested
435
+ if not self ._isok :
408
436
break
409
437
410
438
# parse the result
@@ -434,33 +462,40 @@ async def _ListenToWebSocket(self):
434
462
# raise an error, this should never happen
435
463
raise ControllerGraphClientException (_ ('Unexpected server response, missing id: %s' ) % (response ))
436
464
437
- # select the right subscription
438
- subscriptionId = content ['id' ]
439
- subscription = self ._subscriptions .get (subscriptionId )
440
- if subscription is None :
441
- # subscriber is gone
442
- continue
443
-
444
- # return if there is an error
445
- if 'payload' in content and 'errors' in content ['payload' ] and len (content ['payload' ]['errors' ]) > 0 :
446
- message = content ['payload' ]['errors' ][0 ].get ('message' , response )
447
- errorCode = None
448
- if 'extensions' in content ['payload' ]['errors' ][0 ]:
449
- errorCode = content ['payload' ]['errors' ][0 ]['extensions' ].get ('errorCode' , None )
450
- subscription .GetSubscriptionCallbackFunction ()(error = ControllerGraphClientException (message , content = content , errorCode = errorCode ), response = None )
451
- continue
452
-
453
- # return the payload
454
- subscription .GetSubscriptionCallbackFunction ()(error = None , response = content .get ('payload' ) or {})
465
+ # reply back to subscribers
466
+ with self ._subscriptionLock :
467
+ # select the right subscription
468
+ subscriptionId = content ['id' ]
469
+ subscription = self ._subscriptions .get (subscriptionId )
470
+ if subscription is None :
471
+ # subscriber is gone
472
+ continue
473
+
474
+ # return if there is an error
475
+ if 'payload' in content and 'errors' in content ['payload' ] and len (content ['payload' ]['errors' ]) > 0 :
476
+ message = content ['payload' ]['errors' ][0 ].get ('message' , response )
477
+ errorCode = None
478
+ if 'extensions' in content ['payload' ]['errors' ][0 ]:
479
+ errorCode = content ['payload' ]['errors' ][0 ]['extensions' ].get ('errorCode' , None )
480
+ subscription .GetSubscriptionCallbackFunction ()(error = ControllerGraphClientException (message , content = content , errorCode = errorCode ), response = None )
481
+ continue
482
+
483
+ # return the payload
484
+ subscription .GetSubscriptionCallbackFunction ()(error = None , response = content .get ('payload' ) or {})
455
485
456
486
except Exception as e :
457
487
log .exception ('caught WebSocket exception: %s' , e )
458
- self ._webSocket = None
459
- with self ._webSocketLock :
460
- # send a message back to the caller using the callback function and drop all subscriptions
461
- for subscriptionId , subscription in self ._subscriptions .items ():
462
- subscription .GetSubscriptionCallbackFunction ()(error = e , response = None )
463
- self ._subscriptions .clear ()
488
+ with self ._subscriptionLock :
489
+ await self ._StopAllSubscriptions (ControllerGraphClientException (_ ('Failed to listen to WebSocket: %s' ) % (e )))
490
+
491
+ async def _StopAllSubscriptions (self , error : Optional [ControllerGraphClientException ] = None ):
492
+ # close the websocket
493
+ await self ._CloseWebSocket ()
494
+ # send a message back to the callers using the callback function and drop all subscriptions
495
+ if error is not None :
496
+ for subscriptionId , subscription in self ._subscriptions .items ():
497
+ subscription .GetSubscriptionCallbackFunction ()(error = error , response = None )
498
+ self ._subscriptions .clear ()
464
499
465
500
def SubscribeGraphAPI (self , query : str , callbackFunction : Callable [[Optional [str ], Optional [dict ]], None ], variables : Optional [dict ] = None ) -> Subscription :
466
501
""" Subscribes to changes on Mujin controller.
@@ -470,49 +505,69 @@ def SubscribeGraphAPI(self, query: str, callbackFunction: Callable[[Optional[str
470
505
variables (dict): variables that should be passed into the query if necessary
471
506
callbackFunction (func): a callback function to process the response data that is received from the subscription
472
507
"""
473
- # generate subscriptionId, an unique id to sent to the server so that we can have multiple subscriptions using the same WebSocket
508
+ # create a new subscription
474
509
subscriptionId = str (uuid .uuid4 ())
475
510
subscription = Subscription (subscriptionId , callbackFunction )
476
- with self ._webSocketLock :
477
- self ._subscriptions [subscriptionId ] = subscription
478
- if self ._eventLoop is None or not self ._eventLoop .is_running :
479
- self ._InitializeEventLoopThread ()
480
511
481
512
async def _Subscribe ():
482
- # need to ensure only one thread is initializing the WebSocket
483
- with self ._webSocketLock :
484
- # check if _webSocket exists
485
- if self ._webSocket is None :
486
- await self ._OpenWebSocketConnection ()
487
- # create a coroutine that is specially used for listening to the WebSocket
488
- asyncio .run_coroutine_threadsafe (self ._ListenToWebSocket (), self ._eventLoop )
489
-
513
+ try :
490
514
# start a new subscription on the WebSocket connection
491
- await self . _webSocket . send ( json . dumps ( {
492
- 'id' : subscriptionId ,
515
+ message = {
516
+ 'id' : subscription . GetSubscriptionID () ,
493
517
'type' : 'start' ,
494
- 'payload' : {'query' : query , 'variables' : variables or {}}
495
- }))
518
+ 'payload' : { 'query' : query }
519
+ }
520
+ if variables :
521
+ message ['payload' ]['variables' ] = variables
522
+ await self ._webSocket .send (json .dumps (message ))
523
+ except Exception as e :
524
+ log .exception ('caught WebSocket exception: %s' , e )
525
+ await self ._StopAllSubscriptions (ControllerGraphClientException (_ ('Failed to subscribe: %s' ) % (e )))
526
+
527
+ with self ._subscriptionLock :
528
+ # make sure the websocket connection is running
529
+ self ._EnsureWebSocketConnection ()
530
+
531
+ # wait until the subscription is created
532
+ self ._backgroundThread .RunCoroutine (_Subscribe ()).result ()
533
+ self ._subscriptions [subscriptionId ] = subscription
496
534
497
- asyncio .run_coroutine_threadsafe (_Subscribe (), self ._eventLoop )
498
- return subscription
535
+ return subscription
499
536
500
537
def UnsubscribeGraphAPI (self , subscription : Subscription ):
501
538
""" Unsubscribes to Mujin controller.
502
539
503
540
Args:
504
541
subscription (Subscription): the subscription that the user wants to unsubscribe
505
542
"""
506
- async def _StopSubscription ():
507
- subscriptionId = subscription .GetSubscriptionID ()
508
- with self ._webSocketLock :
543
+ subscriptionId = subscription .GetSubscriptionID ()
544
+
545
+ async def _Unsubscribe ():
546
+ try :
509
547
# check if self._subscriptionIds has subscriptionId
510
548
if subscriptionId in self ._subscriptions :
511
549
await self ._webSocket .send (json .dumps ({
512
550
'id' : subscriptionId ,
513
551
'type' : 'stop'
514
552
}))
515
553
# remove subscription
516
- self ._subscriptions .pop (subscription . GetSubscriptionID () , None )
554
+ self ._subscriptions .pop (subscriptionId , None )
517
555
518
- asyncio .run_coroutine_threadsafe (_StopSubscription (), self ._eventLoop )
556
+ # close the websocket connection if no more subscribers are left
557
+ if len (self ._subscriptions ) == 0 :
558
+ await self ._CloseWebSocket ()
559
+ except Exception as e :
560
+ log .exception ('caught WebSocket exception: %s' , e )
561
+ await self ._StopAllSubscriptions (ControllerGraphClientException (_ ('Failed to unsubscribe: %s' ) % (e )))
562
+
563
+ with self ._subscriptionLock :
564
+ # nothing to do if websocket is not established
565
+ if not self ._IsWebSocketConnectionOpen ():
566
+ return
567
+
568
+ # check if the subscription exists
569
+ if subscription .GetSubscriptionID () not in self ._subscriptions :
570
+ raise ControllerGraphClientException (_ ('Unknown subscription %r' ) % (subscription ))
571
+
572
+ # actually unsubscribe and wait until there is a result
573
+ self ._backgroundThread .RunCoroutine (_Unsubscribe ()).result ()
0 commit comments