19
19
import threading
20
20
import traceback
21
21
import uuid
22
+ import copy
22
23
import websockets
23
24
from requests import auth as requests_auth
24
25
from requests import adapters as requests_adapters
34
35
from .unixsocketadapter import UnixSocketAdapter
35
36
36
37
import logging
38
+ logging .getLogger ('websockets' ).setLevel (logging .WARNING )
37
39
log = logging .getLogger (__name__ )
38
40
39
41
class JSONWebTokenAuth (requests_auth .AuthBase ):
@@ -203,18 +205,15 @@ def _RunEventLoop(self):
203
205
self ._eventLoop .run_forever ()
204
206
205
207
def _StopEventLoop (self ):
206
- if self ._eventLoop is None :
207
- return
208
- self ._eventLoop .stop ()
208
+ if self ._eventLoop is not None :
209
+ self ._eventLoop .stop ()
209
210
210
211
def _CloseEventLoop (self ):
211
- if self ._eventLoop is None :
212
- return
213
- if self ._eventLoop .is_running ():
212
+ if self ._eventLoop is not None and self ._eventLoop .is_running ():
214
213
self ._eventLoop .call_soon_threadsafe (self ._StopEventLoop )
215
- if self ._eventLoopThread .is_alive ():
214
+ if self ._eventLoopThread is not None and self . _eventLoopThread .is_alive ():
216
215
self ._eventLoopThread .join ()
217
- if not self ._eventLoop .is_closed ():
216
+ if self . _eventLoop is not None and not self ._eventLoop .is_closed ():
218
217
self ._eventLoop .close ()
219
218
220
219
def Request (self , method , path , timeout = 5 , headers = None , ** kwargs ):
@@ -371,11 +370,9 @@ async def _OpenWebSocketConnection(self):
371
370
uri = '%s://%s%s' % (webSocketScheme , parsedUrl .netloc , parsedUrl .path )
372
371
373
372
# prepare the headers
374
- headers = {
375
- 'Content-Type' : 'application/json' ,
376
- 'Accept' : 'application/json' ,
377
- 'X-CSRFToken' : 'token' ,
378
- }
373
+ headers = copy .deepcopy (self ._headers )
374
+ headers ['Content-Type' ] = 'application/json'
375
+ headers ['Accept' ] = 'application/json'
379
376
subprotocols = ['graphql-ws' ]
380
377
381
378
# decide on using unix socket or not
@@ -404,38 +401,58 @@ async def _OpenWebSocketConnection(self):
404
401
async def _ListenToWebSocket (self ):
405
402
try :
406
403
async for response in self ._webSocket :
407
- try :
408
- content = json .loads (response )
409
- except ValueError as e :
410
- log .exception ('caught exception parsing json response: %s: %s' , e , response )
411
-
412
- if content ['type' ] == 'connection_ack' :
413
- log .debug ('received connection_ack' )
414
- elif content ['type' ] == 'ka' :
415
- # received keep-alive "ka" message
416
- pass
417
- else :
418
- # raise any error returned
419
- if content is not None and 'payload' in content and 'errors' in content ['payload' ] and len (content ['payload' ]['errors' ]) > 0 :
420
- message = content ['payload' ]['errors' ][0 ].get ('message' , response )
421
- errorCode = None
422
- if 'extensions' in content ['payload' ]['errors' ][0 ]:
423
- errorCode = content ['payload' ]['errors' ][0 ]['extensions' ].get ('errorCode' , None )
424
- raise ControllerGraphClientException (message , content = content , errorCode = errorCode )
425
-
426
- if content is None or 'payload' not in content :
427
- raise ControllerGraphClientException (_ ('Unexpected server response: %s' ) % (response ))
428
-
429
- # parse to get the subscriptionId so that we can call the correct callback function
430
- subscriptionId = content .get ('id' )
431
- if subscriptionId in self ._subscriptions :
432
- subscription = self ._subscriptions [subscriptionId ]
433
- subscription .GetSubscriptionCallbackFunction ()(error = None , response = content .get ('payload' ) or {})
434
-
435
- # stop listening if there is no subscriptions and the connection will close automatically after breaking out the loop
404
+ # stop listening if there is no subscriptions
405
+ # the connection will close automatically after breaking out the loop
436
406
if len (self ._subscriptions ) == 0 :
437
407
await self ._webSocket .close ()
438
408
break
409
+
410
+ # parse the result
411
+ content = None
412
+ if len (response ) > 0 :
413
+ try :
414
+ content = json .loads (response )
415
+ except ValueError as e :
416
+ log .exception ('caught exception parsing json response: %s: %s' , e , response )
417
+
418
+ # sanity checks
419
+ if content is None or 'type' not in content :
420
+ # raise an error, this should never happen
421
+ raise ControllerGraphClientException (_ ('Unexpected server response: %s' ) % (response ))
422
+
423
+ # handle control messages
424
+ contentType = content ['type' ]
425
+ if contentType == 'connection_ack' :
426
+ log .debug ('received connection_ack' )
427
+ continue
428
+ if contentType == 'ka' :
429
+ # received keep-alive "ka" message
430
+ continue
431
+
432
+ # sanity checks
433
+ if 'id' not in content :
434
+ # raise an error, this should never happen
435
+ raise ControllerGraphClientException (_ ('Unexpected server response, missing id: %s' ) % (response ))
436
+
437
+ # select the right subscription
438
+ subscriptionId = content ['id' ]
439
+ subscription = self ._subscriptions .get (subscriptionId )
440
+ if subscription is None :
441
+ # subscriber is gone
442
+ continue
443
+
444
+ # return if there is an error
445
+ if 'payload' in content and 'errors' in content ['payload' ] and len (content ['payload' ]['errors' ]) > 0 :
446
+ message = content ['payload' ]['errors' ][0 ].get ('message' , response )
447
+ errorCode = None
448
+ if 'extensions' in content ['payload' ]['errors' ][0 ]:
449
+ errorCode = content ['payload' ]['errors' ][0 ]['extensions' ].get ('errorCode' , None )
450
+ subscription .GetSubscriptionCallbackFunction ()(error = ControllerGraphClientException (message , content = content , errorCode = errorCode ), response = None )
451
+ continue
452
+
453
+ # return the payload
454
+ subscription .GetSubscriptionCallbackFunction ()(error = None , response = content .get ('payload' ) or {})
455
+
439
456
except Exception as e :
440
457
log .exception ('caught WebSocket exception: %s' , e )
441
458
self ._webSocket = None
@@ -456,8 +473,8 @@ def SubscribeGraphAPI(self, query: str, callbackFunction: Callable[[Optional[str
456
473
# generate subscriptionId, an unique id to sent to the server so that we can have multiple subscriptions using the same WebSocket
457
474
subscriptionId = str (uuid .uuid4 ())
458
475
subscription = Subscription (subscriptionId , callbackFunction )
459
- self ._subscriptions [subscriptionId ] = subscription
460
476
with self ._webSocketLock :
477
+ self ._subscriptions [subscriptionId ] = subscription
461
478
if self ._eventLoop is None or not self ._eventLoop .is_running :
462
479
self ._InitializeEventLoopThread ()
463
480
0 commit comments