From b01e3d898e3455432edf1ebae99bc40248392804 Mon Sep 17 00:00:00 2001 From: Mahad Munir Date: Tue, 4 Nov 2025 12:38:06 +0500 Subject: [PATCH] send all websocket subprotocols if serializer not specified --- tests/interop_test.py | 3 ++- xconn/async_client.py | 2 +- xconn/client.py | 2 +- xconn/helpers.py | 3 +++ xconn/joiner.py | 38 ++++++++++++++++++++++---------------- xconn/transports.py | 6 ++++++ 6 files changed, 35 insertions(+), 19 deletions(-) diff --git a/tests/interop_test.py b/tests/interop_test.py index 065b561..959762a 100644 --- a/tests/interop_test.py +++ b/tests/interop_test.py @@ -7,8 +7,9 @@ PROCEDURE_ADD = "io.xconn.backend.add2" -SERIALIZERS = [serializers.JSONSerializer(), serializers.CBORSerializer(), serializers.MsgPackSerializer()] +SERIALIZERS = [None, serializers.JSONSerializer(), serializers.CBORSerializer(), serializers.MsgPackSerializer()] AUTHENTICATORS = [ + None, auth.AnonymousAuthenticator(""), auth.TicketAuthenticator("ticket-user", "ticket-pass", {}), auth.WAMPCRAAuthenticator("wamp-cra-user", "cra-secret", {}), diff --git a/xconn/async_client.py b/xconn/async_client.py index 8893226..d95aab1 100644 --- a/xconn/async_client.py +++ b/xconn/async_client.py @@ -12,7 +12,7 @@ class AsyncClient: def __init__( self, authenticator: auth.IClientAuthenticator = auth.AnonymousAuthenticator(""), - serializer: serializers.Serializer = serializers.JSONSerializer(), + serializer: serializers.Serializer = None, ws_config: types.WebsocketConfig = types.WebsocketConfig(), ): self._authenticator = authenticator diff --git a/xconn/client.py b/xconn/client.py index d08dc26..ba42a5d 100644 --- a/xconn/client.py +++ b/xconn/client.py @@ -12,7 +12,7 @@ class Client: def __init__( self, authenticator: auth.IClientAuthenticator = auth.AnonymousAuthenticator(""), - serializer: serializers.Serializer = serializers.JSONSerializer(), + serializer: serializers.Serializer = None, config: types.TransportConfig = types.TransportConfig(), ): self._authenticator = authenticator diff --git a/xconn/helpers.py b/xconn/helpers.py index bf9c4b0..4d7eb02 100644 --- a/xconn/helpers.py +++ b/xconn/helpers.py @@ -20,6 +20,9 @@ SERIALIZER_TYPE_CAPNPROTO = 14 +WS_SUBPROTOCOLS = [CBOR_SUBPROTOCOL, MSGPACK_SUBPROTOCOL, JSON_SUBPROTOCOL] +if _CAPNP_AVAILABLE: + WS_SUBPROTOCOLS.append(CAPNPROTO_SUBPROTOCOL) def get_ws_subprotocol(serializer: serializers.Serializer): if isinstance(serializer, serializers.JSONSerializer): diff --git a/xconn/joiner.py b/xconn/joiner.py index fae5ea8..bb97423 100644 --- a/xconn/joiner.py +++ b/xconn/joiner.py @@ -9,7 +9,7 @@ class WebsocketsJoiner: def __init__( self, authenticator: auth.IClientAuthenticator = None, - serializer: serializers.Serializer = serializers.JSONSerializer(), + serializer: serializers.Serializer = None, ws_config: types.WebsocketConfig = types.WebsocketConfig(), ): self._authenticator = authenticator @@ -17,11 +17,14 @@ def __init__( self._ws_config = ws_config def join(self, uri: str, realm: str) -> types.BaseSession: - transport = WebSocketTransport.connect( - uri, - subprotocols=[helpers.get_ws_subprotocol(serializer=self._serializer)], - config=self._ws_config, - ) + if self._serializer is None: + subprotocols = helpers.WS_SUBPROTOCOLS + else: + subprotocols = [helpers.get_ws_subprotocol(serializer=self._serializer)] + + transport = WebSocketTransport.connect(uri, subprotocols=subprotocols, config=self._ws_config) + if self._serializer is None: + self._serializer = helpers.get_serializer(transport.subprotocol()) j: Joiner = joiner.Joiner(realm, serializer=self._serializer, authenticator=self._authenticator) transport.write(j.send_hello()) @@ -39,7 +42,7 @@ class AsyncWebsocketsJoiner: def __init__( self, authenticator: auth.IClientAuthenticator = None, - serializer: serializers.Serializer = serializers.JSONSerializer(), + serializer: serializers.Serializer = None, ws_config: types.WebsocketConfig = types.WebsocketConfig(), ): self._ws_config = ws_config @@ -47,11 +50,14 @@ def __init__( self._serializer = serializer async def join(self, uri: str, realm: str) -> types.AsyncBaseSession: - transport = await AsyncWebSocketTransport.connect( - uri, - subprotocols=[helpers.get_ws_subprotocol(serializer=self._serializer)], - config=self._ws_config, - ) + if self._serializer is None: + subprotocols = helpers.WS_SUBPROTOCOLS + else: + subprotocols = [helpers.get_ws_subprotocol(serializer=self._serializer)] + + transport = await AsyncWebSocketTransport.connect(uri, subprotocols=subprotocols, config=self._ws_config) + if self._serializer is None: + self._serializer = helpers.get_serializer(transport.subprotocol()) j: Joiner = joiner.Joiner(realm, serializer=self._serializer, authenticator=self._authenticator) await transport.write(j.send_hello()) @@ -69,11 +75,11 @@ class RawSocketJoiner: def __init__( self, authenticator: auth.IClientAuthenticator = None, - serializer: serializers.Serializer = serializers.JSONSerializer(), + serializer: serializers.Serializer = serializers.CBORSerializer(), config: types.TransportConfig = types.TransportConfig(), ): self._authenticator = authenticator - self._serializer = serializer + self._serializer = serializer if serializer is not None else serializers.CBORSerializer() self._config = config def join(self, uri: str, realm: str) -> types.BaseSession: @@ -95,11 +101,11 @@ class AsyncRawSocketJoiner: def __init__( self, authenticator: auth.IClientAuthenticator = None, - serializer: serializers.Serializer = serializers.JSONSerializer(), + serializer: serializers.Serializer = serializers.CBORSerializer(), config: types.TransportConfig = types.TransportConfig(), ): self._authenticator = authenticator - self._serializer = serializer + self._serializer = serializer if serializer is not None else serializers.CBORSerializer() self._config = config async def join(self, uri: str, realm: str) -> types.AsyncBaseSession: diff --git a/xconn/transports.py b/xconn/transports.py index c0812ed..61d504c 100644 --- a/xconn/transports.py +++ b/xconn/transports.py @@ -304,6 +304,9 @@ def ping(self, timeout: int = 10) -> float: received_at = time.time() * 1000 return received_at - created_at + def subprotocol(self): + return self._websocket.subprotocol + class AsyncWebSocketTransport(IAsyncTransport): def __init__(self, websocket: ClientConnection): @@ -355,3 +358,6 @@ async def ping(self, timeout: int = 10) -> float: await asyncio.wait_for(awaitable, timeout) received_at = time.time() * 1000 return received_at - created_at + + def subprotocol(self): + return self._websocket.subprotocol