Skip to content

Commit cc67090

Browse files
author
Barkin Simsek
committed
Simplify event loop and websocket handling.
1 parent 37cd36e commit cc67090

File tree

2 files changed

+144
-86
lines changed

2 files changed

+144
-86
lines changed

bin/mujin_webstackclientpy_runshell.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@ def _Main():
3232

3333
self = WebstackClient(options.url, options.username, options.password)
3434

35+
# launch interactive shell
3536
from IPython.terminal import embed
3637
ipshell = embed.InteractiveShellEmbed(config=embed.load_default_config())(local_ns=locals())
3738

39+
# destroy the client
40+
self.Destroy()
3841

3942
if __name__ == "__main__":
4043
_Main()

python/mujinwebstackclient/controllerwebclientraw.py

Lines changed: 141 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __call__(self, request):
8080
request.register_hook('response', self._SetJSONWebToken)
8181
return request
8282

83-
class Subscription:
83+
class Subscription(object):
8484
"""Subscription that contains the unique subscription id for every subscription.
8585
"""
8686
_subscriptionId: str # subscription id
@@ -96,8 +96,44 @@ def GetSubscriptionID(self) -> str:
9696
def GetSubscriptionCallbackFunction(self) -> Callable[[Optional[ControllerGraphClientException], Optional[dict]], None]:
9797
return self._subscriptionCallbackFunction
9898

99-
class ControllerWebClientRaw(object):
99+
def __repr__(self):
100+
return '<Subscription(%r, %r)>' % (self._subscriptionId, self._subscriptionCallbackFunction)
101+
102+
class BackgroundThread(object):
103+
_thread: threading.Thread # A thread to run the event loop
104+
_eventLoop: asyncio.AbstractEventLoop # Event loop that is running so that client can add coroutine
105+
106+
def __init__(self):
107+
# create a new event loop in a background thread
108+
self._eventLoop = asyncio.new_event_loop()
109+
self._thread = threading.Thread(target=self._RunEventLoop)
110+
self._thread.start()
111+
112+
def _RunEventLoop(self):
113+
# set the created loop as the current event loop for this thread
114+
asyncio.set_event_loop(self._eventLoop)
115+
self._eventLoop.run_forever()
116+
117+
def RunCoroutine(self, coroutine: Callable):
118+
"""Schedule a coroutine to run on the event loop from another thread
119+
"""
120+
return asyncio.run_coroutine_threadsafe(coroutine, self._eventLoop)
121+
122+
def __del__(self):
123+
self.Destroy()
100124

125+
def Destroy(self):
126+
if self._eventLoop.is_closed():
127+
return
128+
# cancel all tasks in the event loop
129+
for task in asyncio.all_tasks(loop=self._eventLoop):
130+
task.cancel()
131+
# run the loop briefly to let cancellations propagate
132+
self._eventLoop.call_soon_threadsafe(self._eventLoop.stop)
133+
self._thread.join()
134+
self._eventLoop.close()
135+
136+
class ControllerWebClientRaw(object):
101137
_baseurl = None # Base URL of the controller
102138
_username = None # Username to login with
103139
_password = None # Password to login with
@@ -106,9 +142,7 @@ class ControllerWebClientRaw(object):
106142
_session = None # Requests session object
107143
_webSocket: websockets.asyncio.client.ClientConnection # WebSocket used to connect to WebStack for subscriptions
108144
_subscriptions: dict[str, Subscription] # Dictionary that stores the subscriptionId(key) and the corresponding subscription(value)
109-
_eventLoopThread: threading.Thread # A thread to run the event loop
110-
_eventLoop: asyncio.AbstractEventLoop # Event loop that is running so that client can add coroutine(a subscription in this case)
111-
_webSocketLock: threading.Lock # Lock protecting initializing the WebSocket
145+
_subscriptionLock: threading.Lock # Lock protecting _webSocket and _subscriptions
112146

113147
def __init__(self, baseurl, username, password, locale=None, author=None, userAgent=None, additionalHeaders=None, unixEndpoint=None):
114148
self._baseurl = baseurl
@@ -117,11 +151,12 @@ def __init__(self, baseurl, username, password, locale=None, author=None, userAg
117151
self._headers = {}
118152
self._isok = True
119153

120-
self._eventLoop = None
121-
self._eventLoopThread = None
122154
self._webSocket = None
123-
self._webSocketLock = threading.Lock()
124155
self._subscriptions = {}
156+
self._subscriptionLock = threading.Lock()
157+
158+
# Create the background thread for async operations
159+
self._backgroundThread = BackgroundThread()
125160

126161
# Create session
127162
self._session = requests.Session()
@@ -162,7 +197,12 @@ def __del__(self):
162197

163198
def Destroy(self):
164199
self.SetDestroy()
165-
self._CloseEventLoop()
200+
if self._backgroundThread is not None:
201+
# make sure to stop subscriptions and close the websocket first
202+
self._backgroundThread.RunCoroutine(self._StopAllSubscriptions()).result()
203+
# next destroy the thread
204+
self._backgroundThread.Destroy()
205+
self._backgroundThread = None
166206

167207
def SetDestroy(self):
168208
self._isok = False
@@ -191,31 +231,6 @@ def SetUserAgent(self, userAgent=None):
191231
else:
192232
self._headers.pop('User-Agent', None)
193233

194-
def _InitializeEventLoopThread(self):
195-
# Create new event loop if _eventLoop is None, otherwise, reuse the existing one
196-
if self._eventLoop is None:
197-
self._eventLoop = asyncio.new_event_loop()
198-
if self._eventLoopThread is not None and self._eventLoopThread.is_alive():
199-
self._eventLoopThread.join()
200-
self._eventLoopThread = threading.Thread(target=self._RunEventLoop)
201-
self._eventLoopThread.start()
202-
203-
def _RunEventLoop(self):
204-
asyncio.set_event_loop(self._eventLoop)
205-
self._eventLoop.run_forever()
206-
207-
def _StopEventLoop(self):
208-
if self._eventLoop is not None:
209-
self._eventLoop.stop()
210-
211-
def _CloseEventLoop(self):
212-
if self._eventLoop is not None and self._eventLoop.is_running():
213-
self._eventLoop.call_soon_threadsafe(self._StopEventLoop)
214-
if self._eventLoopThread is not None and self._eventLoopThread.is_alive():
215-
self._eventLoopThread.join()
216-
if self._eventLoop is not None and not self._eventLoop.is_closed():
217-
self._eventLoop.close()
218-
219234
def Request(self, method, path, timeout=5, headers=None, **kwargs):
220235
if timeout < 1e-6:
221236
raise WebstackClientError(_('Timeout value (%s sec) is too small') % timeout)
@@ -355,6 +370,21 @@ def CallGraphAPI(self, query, variables=None, headers=None, timeout=5.0):
355370

356371
return content['data']
357372

373+
def _EnsureWebSocketConnection(self):
374+
if self._webSocket is None:
375+
# wait until the connection is established
376+
self._backgroundThread.RunCoroutine(self._OpenWebSocketConnection()).result()
377+
# start listening without blocking
378+
self._backgroundThread.RunCoroutine(self._ListenToWebSocket())
379+
380+
def _IsWebSocketConnectionOpen(self):
381+
return self._webSocket is not None
382+
383+
async def _CloseWebSocket(self):
384+
if self._webSocket is not None:
385+
await self._webSocket.close()
386+
self._webSocket = None
387+
358388
async def _OpenWebSocketConnection(self):
359389
authorization = self._session.auth.GetAuthorizationHeader()
360390

@@ -401,10 +431,8 @@ async def _OpenWebSocketConnection(self):
401431
async def _ListenToWebSocket(self):
402432
try:
403433
async for response in self._webSocket:
404-
# stop listening if there is no subscriptions
405-
# the connection will close automatically after breaking out the loop
406-
if len(self._subscriptions) == 0:
407-
await self._webSocket.close()
434+
# stop if stop is requested
435+
if not self._isok:
408436
break
409437

410438
# parse the result
@@ -434,33 +462,40 @@ async def _ListenToWebSocket(self):
434462
# raise an error, this should never happen
435463
raise ControllerGraphClientException(_('Unexpected server response, missing id: %s') % (response))
436464

437-
# select the right subscription
438-
subscriptionId = content['id']
439-
subscription = self._subscriptions.get(subscriptionId)
440-
if subscription is None:
441-
# subscriber is gone
442-
continue
443-
444-
# return if there is an error
445-
if 'payload' in content and 'errors' in content['payload'] and len(content['payload']['errors']) > 0:
446-
message = content['payload']['errors'][0].get('message', response)
447-
errorCode = None
448-
if 'extensions' in content['payload']['errors'][0]:
449-
errorCode = content['payload']['errors'][0]['extensions'].get('errorCode', None)
450-
subscription.GetSubscriptionCallbackFunction()(error=ControllerGraphClientException(message, content=content, errorCode=errorCode), response=None)
451-
continue
452-
453-
# return the payload
454-
subscription.GetSubscriptionCallbackFunction()(error=None, response=content.get('payload') or {})
465+
# reply back to subscribers
466+
with self._subscriptionLock:
467+
# select the right subscription
468+
subscriptionId = content['id']
469+
subscription = self._subscriptions.get(subscriptionId)
470+
if subscription is None:
471+
# subscriber is gone
472+
continue
473+
474+
# return if there is an error
475+
if 'payload' in content and 'errors' in content['payload'] and len(content['payload']['errors']) > 0:
476+
message = content['payload']['errors'][0].get('message', response)
477+
errorCode = None
478+
if 'extensions' in content['payload']['errors'][0]:
479+
errorCode = content['payload']['errors'][0]['extensions'].get('errorCode', None)
480+
subscription.GetSubscriptionCallbackFunction()(error=ControllerGraphClientException(message, content=content, errorCode=errorCode), response=None)
481+
continue
482+
483+
# return the payload
484+
subscription.GetSubscriptionCallbackFunction()(error=None, response=content.get('payload') or {})
455485

456486
except Exception as e:
457487
log.exception('caught WebSocket exception: %s', e)
458-
self._webSocket = None
459-
with self._webSocketLock:
460-
# send a message back to the caller using the callback function and drop all subscriptions
461-
for subscriptionId, subscription in self._subscriptions.items():
462-
subscription.GetSubscriptionCallbackFunction()(error=e, response=None)
463-
self._subscriptions.clear()
488+
with self._subscriptionLock:
489+
await self._StopAllSubscriptions(ControllerGraphClientException(_('Failed to listen to WebSocket: %s') % (e)))
490+
491+
async def _StopAllSubscriptions(self, error: Optional[ControllerGraphClientException] = None):
492+
# close the websocket
493+
await self._CloseWebSocket()
494+
# send a message back to the callers using the callback function and drop all subscriptions
495+
if error is not None:
496+
for subscriptionId, subscription in self._subscriptions.items():
497+
subscription.GetSubscriptionCallbackFunction()(error=error, response=None)
498+
self._subscriptions.clear()
464499

465500
def SubscribeGraphAPI(self, query: str, callbackFunction: Callable[[Optional[str], Optional[dict]], None], variables: Optional[dict] = None) -> Subscription:
466501
""" Subscribes to changes on Mujin controller.
@@ -470,49 +505,69 @@ def SubscribeGraphAPI(self, query: str, callbackFunction: Callable[[Optional[str
470505
variables (dict): variables that should be passed into the query if necessary
471506
callbackFunction (func): a callback function to process the response data that is received from the subscription
472507
"""
473-
# generate subscriptionId, an unique id to sent to the server so that we can have multiple subscriptions using the same WebSocket
508+
# create a new subscription
474509
subscriptionId = str(uuid.uuid4())
475510
subscription = Subscription(subscriptionId, callbackFunction)
476-
with self._webSocketLock:
477-
self._subscriptions[subscriptionId] = subscription
478-
if self._eventLoop is None or not self._eventLoop.is_running:
479-
self._InitializeEventLoopThread()
480511

481512
async def _Subscribe():
482-
# need to ensure only one thread is initializing the WebSocket
483-
with self._webSocketLock:
484-
# check if _webSocket exists
485-
if self._webSocket is None:
486-
await self._OpenWebSocketConnection()
487-
# create a coroutine that is specially used for listening to the WebSocket
488-
asyncio.run_coroutine_threadsafe(self._ListenToWebSocket(), self._eventLoop)
489-
513+
try:
490514
# start a new subscription on the WebSocket connection
491-
await self._webSocket.send(json.dumps({
492-
'id': subscriptionId,
515+
message = {
516+
'id': subscription.GetSubscriptionID(),
493517
'type': 'start',
494-
'payload': {'query': query, 'variables': variables or {}}
495-
}))
518+
'payload': { 'query': query }
519+
}
520+
if variables:
521+
message['payload']['variables'] = variables
522+
await self._webSocket.send(json.dumps(message))
523+
except Exception as e:
524+
log.exception('caught WebSocket exception: %s', e)
525+
await self._StopAllSubscriptions(ControllerGraphClientException(_('Failed to subscribe: %s') % (e)))
526+
527+
with self._subscriptionLock:
528+
# make sure the websocket connection is running
529+
self._EnsureWebSocketConnection()
530+
531+
# wait until the subscription is created
532+
self._backgroundThread.RunCoroutine(_Subscribe()).result()
533+
self._subscriptions[subscriptionId] = subscription
496534

497-
asyncio.run_coroutine_threadsafe(_Subscribe(), self._eventLoop)
498-
return subscription
535+
return subscription
499536

500537
def UnsubscribeGraphAPI(self, subscription: Subscription):
501538
""" Unsubscribes to Mujin controller.
502539
503540
Args:
504541
subscription (Subscription): the subscription that the user wants to unsubscribe
505542
"""
506-
async def _StopSubscription():
507-
subscriptionId = subscription.GetSubscriptionID()
508-
with self._webSocketLock:
543+
subscriptionId = subscription.GetSubscriptionID()
544+
545+
async def _Unsubscribe():
546+
try:
509547
# check if self._subscriptionIds has subscriptionId
510548
if subscriptionId in self._subscriptions:
511549
await self._webSocket.send(json.dumps({
512550
'id': subscriptionId,
513551
'type': 'stop'
514552
}))
515553
# remove subscription
516-
self._subscriptions.pop(subscription.GetSubscriptionID(), None)
554+
self._subscriptions.pop(subscriptionId, None)
517555

518-
asyncio.run_coroutine_threadsafe(_StopSubscription(), self._eventLoop)
556+
# close the websocket connection if no more subscribers are left
557+
if len(self._subscriptions) == 0:
558+
await self._CloseWebSocket()
559+
except Exception as e:
560+
log.exception('caught WebSocket exception: %s', e)
561+
await self._StopAllSubscriptions(ControllerGraphClientException(_('Failed to unsubscribe: %s') % (e)))
562+
563+
with self._subscriptionLock:
564+
# nothing to do if websocket is not established
565+
if not self._IsWebSocketConnectionOpen():
566+
return
567+
568+
# check if the subscription exists
569+
if subscription.GetSubscriptionID() not in self._subscriptions:
570+
raise ControllerGraphClientException(_('Unknown subscription %r') % (subscription))
571+
572+
# actually unsubscribe and wait until there is a result
573+
self._backgroundThread.RunCoroutine(_Unsubscribe()).result()

0 commit comments

Comments
 (0)