diff --git a/jupyter_kernel_client/tests/test_utils.py b/jupyter_kernel_client/tests/test_utils.py new file mode 100644 index 0000000..86e3d03 --- /dev/null +++ b/jupyter_kernel_client/tests/test_utils.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023-2024 Datalayer, Inc. +# Copyright (c) 2025 Google +# +# BSD 3-Clause License + +import json +from jupyter_kernel_client.utils import serialize_msg_to_ws_json, serialize_msg_to_ws_default, deserialize_msg_from_ws_default, serialize_msg_to_ws_v1, deserialize_msg_from_ws_v1 + +def test_serialize_msg_to_ws_json(): + src_msg = { + "header": { + "msg_id": "c0a8012e-1f3b-4c3a-9c77-123456789abc", + "username": "test-user", + "session": "1c2d3e4f-aaaa-bbbb-cccc-0123456789ab", + "date": "2026-01-26T12:34:56.789Z", + "msg_type": "execute_request", + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": "print('hello world')", + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": True, + "stop_on_error": True, + }, + "buffers": [], + } + + expected_output = json.dumps(src_msg) + serialized_msg = serialize_msg_to_ws_json(src_msg) + assert expected_output == serialized_msg + +def test_serialize_and_deserialize_msg_to_ws_default(): + src_msg = { + "header": { + "msg_id": "7f3b9d8d-2c6c-4b71-9b64-111111111111", + "username": "test-user", + "session": "1c2d3e4f-aaaa-bbbb-cccc-0123456789ab", + "date": "2026-01-26T12:35:10.123Z", + "msg_type": "comm_msg", + "version": "5.3", + }, + "parent_header": {}, + "metadata": { + "buffer_paths": [ + ["content", "data", "payload"], + ["content", "data", "extra_blob"] + ] + }, + "content": { + "comm_id": "abc123abc123abc123abc123abc123ab", + "target_name": "my_binary_comm", + "data": { + "dtype": "uint8", + "shape": [4], + "payload": None, + "extra_blob": None, + "note": "payload + extra_blob come from buffers", + }, + }, + "buffers": [ + b"\x01\x02\x03\x04", + b"\xde\xad\xbe\xef\x00\xff", + ], + } + + serialized_msg = serialize_msg_to_ws_default(src_msg) + bufn = int.from_bytes(serialized_msg[0:4], byteorder="big") + buffers = src_msg['buffers'] or [] + + for i in range(1, bufn): + # ignore the json message for now, it's tested the deserialized msg + start = (i+1) * 4 + offset = int.from_bytes(serialized_msg[start:start+4], byteorder="big") + buf = buffers[i-1] + serialized_buf_val = serialized_msg[offset:offset+len(buf)] + assert serialized_buf_val == buf + + deserialized_msg = deserialize_msg_from_ws_default(serialized_msg) + assert deserialized_msg == src_msg + +def test_serialize_and_deserialize_msg_to_ws_v1(): + def pack(obj) -> bytes: + return json.dumps(obj, separators=(",", ":"), sort_keys=True).encode("utf-8") + + src_msg = { + "channel": "shell", + "header": { + "msg_id": "7f3b9d8d-2c6c-4b71-9b64-111111111111", + "username": "test-user", + "session": "1c2d3e4f-aaaa-bbbb-cccc-0123456789ab", + "date": "2026-01-26T12:35:10.123Z", + "msg_type": "comm_msg", + "version": "5.3", + }, + "parent_header": {}, + "metadata": { + "buffer_paths": [ + ["content", "data", "payload"], + ["content", "data", "extra_blob"] + ] + }, + "content": { + "comm_id": "abc123abc123abc123abc123abc123ab", + "target_name": "my_binary_comm", + "data": { + "dtype": "uint8", + "shape": [4], + "payload": None, + "extra_blob": None, + "note": "payload + extra_blob come from buffers", + }, + }, + "buffers": [ + b"\x01\x02\x03\x04", + b"\xde\xad\xbe\xef\x00\xff", + ], + } + + serialized_msg = serialize_msg_to_ws_v1(src_msg, channel="shell", pack=pack) + # construct the msg lists for the serialized msg + offset = int.from_bytes(serialized_msg[:8], byteorder="little") + offsets = [ + int.from_bytes(serialized_msg[8 * (i + 1) : 8 * (i + 2)], byteorder="little") for i in range(offset) + ] + serialized_list = [serialized_msg[offsets[i]:offsets[i+1]] for i in range(1, offset-1)] + + _, deserialized_msg = deserialize_msg_from_ws_v1(serialized_msg) + assert serialized_list == deserialized_msg + diff --git a/jupyter_kernel_client/utils.py b/jupyter_kernel_client/utils.py index ba1a365..34c203e 100644 --- a/jupyter_kernel_client/utils.py +++ b/jupyter_kernel_client/utils.py @@ -53,14 +53,80 @@ def deserialize_msg_from_ws_v1(ws_msg): return channel, msg_list -def serialize_msg_to_ws_json(msg): - """Serialize as JSON (for Colab).""" - return json.dumps(msg, default=str) +def serialize_msg_to_ws_default(msg): + """Serialize a message using the default protocol.""" + offsets = [] + buffers = [] + + msg_copy = dict(msg) + msg_copy['header']['date'] = str(msg_copy['header']['date']) + orig_buffers = msg_copy.pop("buffers", []) + json_bytes = json.dumps(msg_copy).encode("utf-8") + buffers.append(json_bytes) + + for b in orig_buffers: + buffers.append(b) + + nbufs = len(buffers) + offsets.append(4 * (nbufs + 1)) + + for i in range(0, nbufs - 1): + offsets.append(offsets[-1] + len(buffers[i])) + + total_size = offsets[-1] + len(buffers[-1]) + msg_buf = bytearray(total_size) + + msg_buf[0:4] = nbufs.to_bytes(4, byteorder="big") + + for i, off in enumerate(offsets): + start = 4 * (i + 1) + msg_buf[start:start+4] = off.to_bytes(4, byteorder="big") + + for i, b in enumerate(buffers): + start = offsets[i] + msg_buf[start:start+len(b)] = b + + return bytes(msg_buf) -def deserialize_msg_from_ws_json(ws_msg): - """Deserialize as JSON (for Colab).""" - return json.loads(ws_msg.encode('utf-8')) +def deserialize_msg_from_ws_default(ws_msg): + """Deserialize a message using the default protocol.""" + if isinstance(ws_msg, str): + return json.loads(ws_msg.encode('utf-8')) + else: + nbufs = int.from_bytes(ws_msg[:4], byteorder="big") + offsets = [] + if nbufs < 2: + raise ValueError("unsupported number of buffers") + + for i in range(nbufs): + start = 4 * (i + 1) + off = int.from_bytes(ws_msg[start:start+4], byteorder="big") + offsets.append(off) + + json_start = offsets[0] + json_stop = offsets[1] + + if not (0 <= json_start <= json_stop <= len(ws_msg)): + raise ValueError("Invalid JSON offsets") + + json_bytes = ws_msg[json_start:json_stop] + msg = json.loads(json_bytes.decode("utf-8")) + msg["buffers"] = [] + for i in range(1, nbufs): + start = offsets[i] + stop = offsets[i+1] if (i+1) < len(offsets) else len(ws_msg) + + if not (0 <= start <= stop <= len(ws_msg)): + raise ValueError(f"Invalid buffer offsets for chunk {i}") + + msg["buffers"].append(ws_msg[start:stop]) + + return msg + +def serialize_msg_to_ws_json(msg): + """Serialize a default protocol with no buffers.""" + return json.dumps(msg, default=str) def url_path_join(*pieces: str) -> str: """Join components of url into a relative url diff --git a/jupyter_kernel_client/wsclient.py b/jupyter_kernel_client/wsclient.py index a561dec..8898250 100644 --- a/jupyter_kernel_client/wsclient.py +++ b/jupyter_kernel_client/wsclient.py @@ -31,7 +31,7 @@ from jupyter_kernel_client.constants import REQUEST_TIMEOUT from jupyter_kernel_client.log import get_logger -from jupyter_kernel_client.utils import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1, deserialize_msg_from_ws_json, serialize_msg_to_ws_json +from jupyter_kernel_client.utils import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1, deserialize_msg_from_ws_default, serialize_msg_to_ws_json, serialize_msg_to_ws_default class JupyterSubprotocol(Enum): @@ -47,9 +47,6 @@ class JupyterSubprotocol(Enum): # https://jupyter-server.readthedocs.io/en/latest/developers/websocket-protocols.html#v1-kernel-websocket-jupyter-org-protocol V1 = 1 - # JSON is specifically for Colab (until we modify this library to actually have support for DEFAULT) - JSON = 2 - class WSSession(Session): """WebSocket session.""" @@ -252,13 +249,22 @@ def send( # type:ignore[override] to_send = self.serialize(msg) to_send.extend(buffers) - if self.subprotocol == JupyterSubprotocol.JSON: - stream.send_text(serialize_msg_to_ws_json(msg)) - elif self.subprotocol == JupyterSubprotocol.V1: + if self.subprotocol == JupyterSubprotocol.V1: stream.send_bytes(serialize_msg_to_ws_v1(to_send, channel)) else: - # V0 / DEFAULT is currently unsupported - raise ValueError("JupyterSubprotocol.DEFAULT is currently unsupported.") + # The Default protocol is a bytearray with a header pointing to + # offsets where buffers are appended. + # + # Buffers are namely added for cases such as comm messages. + # In the case of the common message without a buffers list, the + # headers will always be '\x00\x00\x00\x01\x00\x00\x00\x08'. + # Since this is constant it might as well not be included, which is + # what Jupyter is doing with the default protocol. + # [server code found here](https://github.com/jupyter-server/jupyter_server/blob/main/jupyter_server/services/kernels/connection/channels.py#L445-L464) + if 'buffers' in msg and len(msg['buffers']) > 0: + stream.send_bytes(serialize_msg_to_ws_default(msg)) + else: + stream.send_text(serialize_msg_to_ws_json(msg)) self.log.debug("WSSession.send\n%s\n%s\n%s", msg, to_send, buffers) @@ -1245,18 +1251,18 @@ def _on_close(self, _: websocket.WebSocket, close_status_code: t.Any, close_msg: self.connection_ready.clear() def _on_message(self, s: websocket.WebSocket, message: bytes) -> None: - if self._subprotocol == JupyterSubprotocol.JSON: - deserialize_msg = deserialize_msg_from_ws_json(message) + if self._subprotocol == JupyterSubprotocol.DEFAULT: + deserialize_msg = deserialize_msg_from_ws_default(message) channel = deserialize_msg['channel'] elif self._subprotocol == JupyterSubprotocol.V1: channel, msg_list = deserialize_msg_from_ws_v1(message) deserialize_msg = self.session.deserialize(msg_list) else: - raise ValueError("JupyterSubprotocol.DEFAULT is unsupported.") + raise ValueError("unsupported protocol.") self.log.debug( "Received message on channel: {channel}, msg_id: {msg_id}, msg_type: {msg_type}".format( - channel=channel, + channel, **(deserialize_msg or {"msg_id": "null", "msg_type": "null"}), ) )