Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions jupyter_kernel_client/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) 2023-2024 Datalayer, Inc.
# Copyright (c) 2025 Google
#
# BSD 3-Clause License

import json
from jupyter_kernel_client.utils import serialize_msg_to_ws_json, serialize_msg_to_ws_default, deserialize_msg_from_ws_default, serialize_msg_to_ws_v1, deserialize_msg_from_ws_v1

def test_serialize_msg_to_ws_json():
src_msg = {
"header": {
"msg_id": "c0a8012e-1f3b-4c3a-9c77-123456789abc",
"username": "test-user",
"session": "1c2d3e4f-aaaa-bbbb-cccc-0123456789ab",
"date": "2026-01-26T12:34:56.789Z",
"msg_type": "execute_request",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": "print('hello world')",
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": True,
"stop_on_error": True,
},
"buffers": [],
}

expected_output = json.dumps(src_msg)
serialized_msg = serialize_msg_to_ws_json(src_msg)
assert expected_output == serialized_msg

def test_serialize_and_deserialize_msg_to_ws_default():
src_msg = {
"header": {
"msg_id": "7f3b9d8d-2c6c-4b71-9b64-111111111111",
"username": "test-user",
"session": "1c2d3e4f-aaaa-bbbb-cccc-0123456789ab",
"date": "2026-01-26T12:35:10.123Z",
"msg_type": "comm_msg",
"version": "5.3",
},
"parent_header": {},
"metadata": {
"buffer_paths": [
["content", "data", "payload"],
["content", "data", "extra_blob"]
]
},
"content": {
"comm_id": "abc123abc123abc123abc123abc123ab",
"target_name": "my_binary_comm",
"data": {
"dtype": "uint8",
"shape": [4],
"payload": None,
"extra_blob": None,
"note": "payload + extra_blob come from buffers",
},
},
"buffers": [
b"\x01\x02\x03\x04",
b"\xde\xad\xbe\xef\x00\xff",
],
}

serialized_msg = serialize_msg_to_ws_default(src_msg)
bufn = int.from_bytes(serialized_msg[0:4], byteorder="big")
buffers = src_msg['buffers'] or []

for i in range(1, bufn):
# ignore the json message for now, it's tested the deserialized msg
start = (i+1) * 4
offset = int.from_bytes(serialized_msg[start:start+4], byteorder="big")
buf = buffers[i-1]
serialized_buf_val = serialized_msg[offset:offset+len(buf)]
assert serialized_buf_val == buf

deserialized_msg = deserialize_msg_from_ws_default(serialized_msg)
assert deserialized_msg == src_msg

def test_serialize_and_deserialize_msg_to_ws_v1():
def pack(obj) -> bytes:
return json.dumps(obj, separators=(",", ":"), sort_keys=True).encode("utf-8")

src_msg = {
"channel": "shell",
"header": {
"msg_id": "7f3b9d8d-2c6c-4b71-9b64-111111111111",
"username": "test-user",
"session": "1c2d3e4f-aaaa-bbbb-cccc-0123456789ab",
"date": "2026-01-26T12:35:10.123Z",
"msg_type": "comm_msg",
"version": "5.3",
},
"parent_header": {},
"metadata": {
"buffer_paths": [
["content", "data", "payload"],
["content", "data", "extra_blob"]
]
},
"content": {
"comm_id": "abc123abc123abc123abc123abc123ab",
"target_name": "my_binary_comm",
"data": {
"dtype": "uint8",
"shape": [4],
"payload": None,
"extra_blob": None,
"note": "payload + extra_blob come from buffers",
},
},
"buffers": [
b"\x01\x02\x03\x04",
b"\xde\xad\xbe\xef\x00\xff",
],
}

serialized_msg = serialize_msg_to_ws_v1(src_msg, channel="shell", pack=pack)
# construct the msg lists for the serialized msg
offset = int.from_bytes(serialized_msg[:8], byteorder="little")
offsets = [
int.from_bytes(serialized_msg[8 * (i + 1) : 8 * (i + 2)], byteorder="little") for i in range(offset)
]
serialized_list = [serialized_msg[offsets[i]:offsets[i+1]] for i in range(1, offset-1)]

_, deserialized_msg = deserialize_msg_from_ws_v1(serialized_msg)
assert serialized_list == deserialized_msg

78 changes: 72 additions & 6 deletions jupyter_kernel_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,80 @@ def deserialize_msg_from_ws_v1(ws_msg):
return channel, msg_list


def serialize_msg_to_ws_json(msg):
"""Serialize as JSON (for Colab)."""
return json.dumps(msg, default=str)
def serialize_msg_to_ws_default(msg):
"""Serialize a message using the default protocol."""
offsets = []
buffers = []

msg_copy = dict(msg)
msg_copy['header']['date'] = str(msg_copy['header']['date'])
orig_buffers = msg_copy.pop("buffers", [])
json_bytes = json.dumps(msg_copy).encode("utf-8")
buffers.append(json_bytes)

for b in orig_buffers:
buffers.append(b)

nbufs = len(buffers)
offsets.append(4 * (nbufs + 1))

for i in range(0, nbufs - 1):
offsets.append(offsets[-1] + len(buffers[i]))

total_size = offsets[-1] + len(buffers[-1])
msg_buf = bytearray(total_size)

msg_buf[0:4] = nbufs.to_bytes(4, byteorder="big")

for i, off in enumerate(offsets):
start = 4 * (i + 1)
msg_buf[start:start+4] = off.to_bytes(4, byteorder="big")

for i, b in enumerate(buffers):
start = offsets[i]
msg_buf[start:start+len(b)] = b

return bytes(msg_buf)


def deserialize_msg_from_ws_json(ws_msg):
"""Deserialize as JSON (for Colab)."""
return json.loads(ws_msg.encode('utf-8'))
def deserialize_msg_from_ws_default(ws_msg):
"""Deserialize a message using the default protocol."""
if isinstance(ws_msg, str):
return json.loads(ws_msg.encode('utf-8'))
else:
nbufs = int.from_bytes(ws_msg[:4], byteorder="big")
offsets = []
if nbufs < 2:
raise ValueError("unsupported number of buffers")

for i in range(nbufs):
start = 4 * (i + 1)
off = int.from_bytes(ws_msg[start:start+4], byteorder="big")
offsets.append(off)

json_start = offsets[0]
json_stop = offsets[1]

if not (0 <= json_start <= json_stop <= len(ws_msg)):
raise ValueError("Invalid JSON offsets")

json_bytes = ws_msg[json_start:json_stop]
msg = json.loads(json_bytes.decode("utf-8"))
msg["buffers"] = []
for i in range(1, nbufs):
start = offsets[i]
stop = offsets[i+1] if (i+1) < len(offsets) else len(ws_msg)

if not (0 <= start <= stop <= len(ws_msg)):
raise ValueError(f"Invalid buffer offsets for chunk {i}")

msg["buffers"].append(ws_msg[start:stop])

return msg

def serialize_msg_to_ws_json(msg):
"""Serialize a default protocol with no buffers."""
return json.dumps(msg, default=str)

def url_path_join(*pieces: str) -> str:
"""Join components of url into a relative url
Expand Down
32 changes: 19 additions & 13 deletions jupyter_kernel_client/wsclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from jupyter_kernel_client.constants import REQUEST_TIMEOUT
from jupyter_kernel_client.log import get_logger
from jupyter_kernel_client.utils import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1, deserialize_msg_from_ws_json, serialize_msg_to_ws_json
from jupyter_kernel_client.utils import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1, deserialize_msg_from_ws_default, serialize_msg_to_ws_json, serialize_msg_to_ws_default


class JupyterSubprotocol(Enum):
Expand All @@ -47,9 +47,6 @@ class JupyterSubprotocol(Enum):
# https://jupyter-server.readthedocs.io/en/latest/developers/websocket-protocols.html#v1-kernel-websocket-jupyter-org-protocol
V1 = 1

# JSON is specifically for Colab (until we modify this library to actually have support for DEFAULT)
JSON = 2


class WSSession(Session):
"""WebSocket session."""
Expand Down Expand Up @@ -252,13 +249,22 @@ def send( # type:ignore[override]
to_send = self.serialize(msg)
to_send.extend(buffers)

if self.subprotocol == JupyterSubprotocol.JSON:
stream.send_text(serialize_msg_to_ws_json(msg))
elif self.subprotocol == JupyterSubprotocol.V1:
if self.subprotocol == JupyterSubprotocol.V1:
stream.send_bytes(serialize_msg_to_ws_v1(to_send, channel))
else:
# V0 / DEFAULT is currently unsupported
raise ValueError("JupyterSubprotocol.DEFAULT is currently unsupported.")
# The Default protocol is a bytearray with a header pointing to
# offsets where buffers are appended.
#
# Buffers are namely added for cases such as comm messages.
# In the case of the common message without a buffers list, the
# headers will always be '\x00\x00\x00\x01\x00\x00\x00\x08'.
# Since this is constant it might as well not be included, which is
# what Jupyter is doing with the default protocol.
# [server code found here](https://github.com/jupyter-server/jupyter_server/blob/main/jupyter_server/services/kernels/connection/channels.py#L445-L464)
if 'buffers' in msg and len(msg['buffers']) > 0:
stream.send_bytes(serialize_msg_to_ws_default(msg))
else:
stream.send_text(serialize_msg_to_ws_json(msg))

self.log.debug("WSSession.send\n%s\n%s\n%s", msg, to_send, buffers)

Expand Down Expand Up @@ -1245,18 +1251,18 @@ def _on_close(self, _: websocket.WebSocket, close_status_code: t.Any, close_msg:
self.connection_ready.clear()

def _on_message(self, s: websocket.WebSocket, message: bytes) -> None:
if self._subprotocol == JupyterSubprotocol.JSON:
deserialize_msg = deserialize_msg_from_ws_json(message)
if self._subprotocol == JupyterSubprotocol.DEFAULT:
deserialize_msg = deserialize_msg_from_ws_default(message)
channel = deserialize_msg['channel']
elif self._subprotocol == JupyterSubprotocol.V1:
channel, msg_list = deserialize_msg_from_ws_v1(message)
deserialize_msg = self.session.deserialize(msg_list)
else:
raise ValueError("JupyterSubprotocol.DEFAULT is unsupported.")
raise ValueError("unsupported protocol.")

self.log.debug(
"Received message on channel: {channel}, msg_id: {msg_id}, msg_type: {msg_type}".format(
channel=channel,
channel,
**(deserialize_msg or {"msg_id": "null", "msg_type": "null"}),
)
)
Expand Down
Loading