Skip to content

Commit 37cd36e

Browse files
author
Barkin Simsek
committed
Code review changes.
1 parent d431b37 commit 37cd36e

File tree

1 file changed

+60
-43
lines changed

1 file changed

+60
-43
lines changed

python/mujinwebstackclient/controllerwebclientraw.py

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import threading
2020
import traceback
2121
import uuid
22+
import copy
2223
import websockets
2324
from requests import auth as requests_auth
2425
from requests import adapters as requests_adapters
@@ -34,6 +35,7 @@
3435
from .unixsocketadapter import UnixSocketAdapter
3536

3637
import logging
38+
logging.getLogger('websockets').setLevel(logging.WARNING)
3739
log = logging.getLogger(__name__)
3840

3941
class JSONWebTokenAuth(requests_auth.AuthBase):
@@ -203,18 +205,15 @@ def _RunEventLoop(self):
203205
self._eventLoop.run_forever()
204206

205207
def _StopEventLoop(self):
206-
if self._eventLoop is None:
207-
return
208-
self._eventLoop.stop()
208+
if self._eventLoop is not None:
209+
self._eventLoop.stop()
209210

210211
def _CloseEventLoop(self):
211-
if self._eventLoop is None:
212-
return
213-
if self._eventLoop.is_running():
212+
if self._eventLoop is not None and self._eventLoop.is_running():
214213
self._eventLoop.call_soon_threadsafe(self._StopEventLoop)
215-
if self._eventLoopThread.is_alive():
214+
if self._eventLoopThread is not None and self._eventLoopThread.is_alive():
216215
self._eventLoopThread.join()
217-
if not self._eventLoop.is_closed():
216+
if self._eventLoop is not None and not self._eventLoop.is_closed():
218217
self._eventLoop.close()
219218

220219
def Request(self, method, path, timeout=5, headers=None, **kwargs):
@@ -371,11 +370,9 @@ async def _OpenWebSocketConnection(self):
371370
uri = '%s://%s%s' % (webSocketScheme, parsedUrl.netloc, parsedUrl.path)
372371

373372
# prepare the headers
374-
headers = {
375-
'Content-Type': 'application/json',
376-
'Accept': 'application/json',
377-
'X-CSRFToken': 'token',
378-
}
373+
headers = copy.deepcopy(self._headers)
374+
headers['Content-Type'] = 'application/json'
375+
headers['Accept'] = 'application/json'
379376
subprotocols = ['graphql-ws']
380377

381378
# decide on using unix socket or not
@@ -404,38 +401,58 @@ async def _OpenWebSocketConnection(self):
404401
async def _ListenToWebSocket(self):
405402
try:
406403
async for response in self._webSocket:
407-
try:
408-
content = json.loads(response)
409-
except ValueError as e:
410-
log.exception('caught exception parsing json response: %s: %s', e, response)
411-
412-
if content['type'] == 'connection_ack':
413-
log.debug('received connection_ack')
414-
elif content['type'] == 'ka':
415-
# received keep-alive "ka" message
416-
pass
417-
else:
418-
# raise any error returned
419-
if content is not None and 'payload' in content and 'errors' in content['payload'] and len(content['payload']['errors']) > 0:
420-
message = content['payload']['errors'][0].get('message', response)
421-
errorCode = None
422-
if 'extensions' in content['payload']['errors'][0]:
423-
errorCode = content['payload']['errors'][0]['extensions'].get('errorCode', None)
424-
raise ControllerGraphClientException(message, content=content, errorCode=errorCode)
425-
426-
if content is None or 'payload' not in content:
427-
raise ControllerGraphClientException(_('Unexpected server response: %s') % (response))
428-
429-
# parse to get the subscriptionId so that we can call the correct callback function
430-
subscriptionId = content.get('id')
431-
if subscriptionId in self._subscriptions:
432-
subscription = self._subscriptions[subscriptionId]
433-
subscription.GetSubscriptionCallbackFunction()(error=None, response=content.get('payload') or {})
434-
435-
# stop listening if there is no subscriptions and the connection will close automatically after breaking out the loop
404+
# stop listening if there is no subscriptions
405+
# the connection will close automatically after breaking out the loop
436406
if len(self._subscriptions) == 0:
437407
await self._webSocket.close()
438408
break
409+
410+
# parse the result
411+
content = None
412+
if len(response) > 0:
413+
try:
414+
content = json.loads(response)
415+
except ValueError as e:
416+
log.exception('caught exception parsing json response: %s: %s', e, response)
417+
418+
# sanity checks
419+
if content is None or 'type' not in content:
420+
# raise an error, this should never happen
421+
raise ControllerGraphClientException(_('Unexpected server response: %s') % (response))
422+
423+
# handle control messages
424+
contentType = content['type']
425+
if contentType == 'connection_ack':
426+
log.debug('received connection_ack')
427+
continue
428+
if contentType == 'ka':
429+
# received keep-alive "ka" message
430+
continue
431+
432+
# sanity checks
433+
if 'id' not in content:
434+
# raise an error, this should never happen
435+
raise ControllerGraphClientException(_('Unexpected server response, missing id: %s') % (response))
436+
437+
# select the right subscription
438+
subscriptionId = content['id']
439+
subscription = self._subscriptions.get(subscriptionId)
440+
if subscription is None:
441+
# subscriber is gone
442+
continue
443+
444+
# return if there is an error
445+
if 'payload' in content and 'errors' in content['payload'] and len(content['payload']['errors']) > 0:
446+
message = content['payload']['errors'][0].get('message', response)
447+
errorCode = None
448+
if 'extensions' in content['payload']['errors'][0]:
449+
errorCode = content['payload']['errors'][0]['extensions'].get('errorCode', None)
450+
subscription.GetSubscriptionCallbackFunction()(error=ControllerGraphClientException(message, content=content, errorCode=errorCode), response=None)
451+
continue
452+
453+
# return the payload
454+
subscription.GetSubscriptionCallbackFunction()(error=None, response=content.get('payload') or {})
455+
439456
except Exception as e:
440457
log.exception('caught WebSocket exception: %s', e)
441458
self._webSocket = None
@@ -456,8 +473,8 @@ def SubscribeGraphAPI(self, query: str, callbackFunction: Callable[[Optional[str
456473
# generate subscriptionId, an unique id to sent to the server so that we can have multiple subscriptions using the same WebSocket
457474
subscriptionId = str(uuid.uuid4())
458475
subscription = Subscription(subscriptionId, callbackFunction)
459-
self._subscriptions[subscriptionId] = subscription
460476
with self._webSocketLock:
477+
self._subscriptions[subscriptionId] = subscription
461478
if self._eventLoop is None or not self._eventLoop.is_running:
462479
self._InitializeEventLoopThread()
463480

0 commit comments

Comments
 (0)