From 78b4bc39c5cb6fca35d6405e9017ad9ac5627bac Mon Sep 17 00:00:00 2001 From: Drake Aiman Date: Fri, 6 Feb 2026 20:07:18 +0000 Subject: [PATCH] add black formatter --- Makefile | 4 ++ jupyter_kernel_client/client.py | 44 ++++++++---- jupyter_kernel_client/konsoleapp.py | 16 +++-- jupyter_kernel_client/manager.py | 77 ++++++++++++++++----- jupyter_kernel_client/shell.py | 11 ++- jupyter_kernel_client/snippets.py | 4 +- jupyter_kernel_client/tests/test_client.py | 57 ++++++++------- jupyter_kernel_client/tests/test_manager.py | 6 +- jupyter_kernel_client/tests/test_utils.py | 33 ++++++--- jupyter_kernel_client/utils.py | 23 +++--- jupyter_kernel_client/wsclient.py | 63 +++++++++-------- 11 files changed, 219 insertions(+), 119 deletions(-) diff --git a/Makefile b/Makefile index 9c94cd9..5baaaf9 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,10 @@ default: all ## Default target is all. help: ## display this help. @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) +format: ## format python + pip install black + python -m black . + all: clean dev ## Clean Install and Build install: diff --git a/jupyter_kernel_client/client.py b/jupyter_kernel_client/client.py index 955f4cc..671fcbc 100644 --- a/jupyter_kernel_client/client.py +++ b/jupyter_kernel_client/client.py @@ -21,7 +21,9 @@ from jupyter_kernel_client.utils import UTC -def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> set[int]: # noqa: C901 +def output_hook( + outputs: list[dict[str, t.Any]], message: dict[str, t.Any] +) -> set[int]: # noqa: C901 """Callback on messages captured during a code snippet execution. The return list of updated output will be empty if no output where changed. @@ -163,7 +165,9 @@ def __init__( self, kernel_id: str | None = None, log: logging.Logger | None = None, **kwargs ) -> None: super().__init__(log=log or get_logger()) - self._manager = self.kernel_manager_class(parent=self, kernel_id=kernel_id, **kwargs) + self._manager = self.kernel_manager_class( + parent=self, kernel_id=kernel_id, **kwargs + ) # Set it after the manager as if a kernel_id is provided, # we will try to connect to it. self._own_kernel = self._manager.kernel is None @@ -456,7 +460,8 @@ def set_variable(self, name: str, value: t.Any) -> None: """ kernel_language = (self.kernel_info or {}).get("language_info", {}).get("name") if kernel_language not in SNIPPETS_REGISTRY.available_languages: - raise ValueError(f"""Code snippet for language {kernel_language} are not available. + raise ValueError( + f"""Code snippet for language {kernel_language} are not available. You can set them yourself using: from jupyter_kernel_client import SNIPPETS_REGISTRY, LanguageSnippets @@ -469,10 +474,13 @@ def set_variable(self, name: str, value: t.Any) -> None: get_variable_mimetypes="", ) ) -""") +""" + ) snippet = SNIPPETS_REGISTRY.get_set_variable(kernel_language) data, metadata = serialize_object(value) - results = self.execute(snippet.format(name=name, data=data, metadata=metadata), silent=True) + results = self.execute( + snippet.format(name=name, data=data, metadata=metadata), silent=True + ) self.log.debug("Set variables: %s", results) if results["status"] == "ok": pass @@ -496,7 +504,8 @@ def get_variable(self, name: str) -> tuple[dict[str, t.Any], dict[str, t.Any]]: """ kernel_language = (self.kernel_info or {}).get("language_info", {}).get("name") if kernel_language not in SNIPPETS_REGISTRY.available_languages: - raise ValueError(f"""Code snippet for language {kernel_language} are not available. + raise ValueError( + f"""Code snippet for language {kernel_language} are not available. You can set them yourself using: from jupyter_kernel_client import SNIPPETS_REGISTRY, LanguageSnippets @@ -509,7 +518,8 @@ def get_variable(self, name: str) -> tuple[dict[str, t.Any], dict[str, t.Any]]: get_variable_mimetypes="", ) ) -""") +""" + ) snippet = SNIPPETS_REGISTRY.get_get_variable(kernel_language) results = self.execute(snippet.format(name=name), silent=True) @@ -536,7 +546,8 @@ def list_variables(self) -> list[VariableDescription]: """ kernel_language = (self.kernel_info or {}).get("language_info", {}).get("name") if kernel_language not in SNIPPETS_REGISTRY.available_languages: - raise ValueError(f"""Code snippet for language {kernel_language} are not available. + raise ValueError( + f"""Code snippet for language {kernel_language} are not available. You can set them yourself using: from jupyter_kernel_client import SNIPPETS_REGISTRY, LanguageSnippets @@ -549,7 +560,8 @@ def list_variables(self) -> list[VariableDescription]: get_variable_mimetypes="", ) ) -""") +""" + ) snippet = SNIPPETS_REGISTRY.get_list_variables(kernel_language) results = self.execute(snippet, silent=True) @@ -592,21 +604,27 @@ def get_variable_mimetypes( """ kernel_language = (self.kernel_info or {}).get("language_info", {}).get("name") if kernel_language not in SNIPPETS_REGISTRY.available_languages: - raise ValueError(f"""Code snippet for language {kernel_language} are not available. + raise ValueError( + f"""Code snippet for language {kernel_language} are not available. You can set them yourself using: from jupyter_kernel_client import SNIPPETS_REGISTRY, LanguageSnippets SNIPPETS_REGISTRY.register("my-language", LanguageSnippets(list_variables="", get_variable="")) -""") +""" + ) snippet = SNIPPETS_REGISTRY.get_get_variable_mimetypes(kernel_language) - results = self.execute(snippet.format(name=name, mimetype=mimetype), silent=True) + results = self.execute( + snippet.format(name=name, mimetype=mimetype), silent=True + ) self.log.debug("Kernel variables: %s", results) if results["status"] == "ok" and results["outputs"]: if mimetype is None: - return results["outputs"][0]["data"], results["outputs"][0].get("metadata", {}) + return results["outputs"][0]["data"], results["outputs"][0].get( + "metadata", {} + ) else: def filter_dict(d: dict, mimetype: str) -> dict: diff --git a/jupyter_kernel_client/konsoleapp.py b/jupyter_kernel_client/konsoleapp.py index ac5ccfe..6451a94 100644 --- a/jupyter_kernel_client/konsoleapp.py +++ b/jupyter_kernel_client/konsoleapp.py @@ -107,7 +107,9 @@ class KonsoleApp(JupyterApp): subcommands = Dict() - server_url = Unicode("http://localhost:8888", config=True, help="URL to the Jupyter Server.") + server_url = Unicode( + "http://localhost:8888", config=True, help="URL to the Jupyter Server." + ) # FIXME it does not support password token = Unicode("", config=True, help="Jupyter Server token.") @@ -126,10 +128,14 @@ class KonsoleApp(JupyterApp): existing = CUnicode("", config=True, help="""Existing kernel ID to connect to.""") - kernel_name = Unicode("python3", config=True, help="""The name of the kernel to connect to.""") + kernel_name = Unicode( + "python3", config=True, help="""The name of the kernel to connect to.""" + ) kernel_path = Unicode( - "", config=True, help="API path from server root to the kernel working directory." + "", + config=True, + help="API path from server root to the kernel working directory.", ) confirm_exit = CBool( @@ -198,7 +204,9 @@ def init_kernel_manager(self) -> None: ) if not self.existing: - self.kernel_client.start_kernel(name=self.kernel_name, path=self.kernel_path) + self.kernel_client.start_kernel( + name=self.kernel_name, path=self.kernel_path + ) elif self.kernel_client.kernel is None: msg = f"Unable to connect to kernel with ID {self.existing}." raise RuntimeError(msg) diff --git a/jupyter_kernel_client/manager.py b/jupyter_kernel_client/manager.py index d9827a4..7703052 100644 --- a/jupyter_kernel_client/manager.py +++ b/jupyter_kernel_client/manager.py @@ -34,7 +34,7 @@ def fetch( headers = { "Accept": "application/json", "Content-Type": "application/json", - "User-Agent": "Jupyter Kernel Client" + "User-Agent": "Jupyter Kernel Client", } headers.update(kwargs.pop("headers", {})) if token: @@ -88,15 +88,17 @@ def __init__( else: self.__extra_headers = {} - if 'headers' not in self.__client_kwargs: - self.__client_kwargs['headers'] = {} - self.__client_kwargs['headers'].update(self.__extra_headers) + if "headers" not in self.__client_kwargs: + self.__client_kwargs["headers"] = {} + self.__client_kwargs["headers"].update(self.__extra_headers) if kernel_id: self.__kernel = { "id": kernel_id, "execution_state": "unknown", - "last_activity": datetime.datetime.strftime(utcnow(), "%Y-%m-%dT%H:%M:%S.%fZ"), + "last_activity": datetime.datetime.strftime( + utcnow(), "%Y-%m-%dT%H:%M:%S.%fZ" + ), } self.refresh_model() @@ -118,7 +120,9 @@ def kernel_url(self) -> str | None: else: return None - client_class = DottedObjectName("jupyter_kernel_client.wsclient.KernelWebSocketClient") + client_class = DottedObjectName( + "jupyter_kernel_client.wsclient.KernelWebSocketClient" + ) client_factory = Type(klass="jupyter_client.client.KernelClientABC") @default("client_factory") @@ -137,7 +141,9 @@ def _client_class_changed(self, change: dict[str, DottedObjectName]) -> None: def client(self) -> t.Any: """Create a client configured to connect to our kernel.""" if not self.kernel_url: - raise RuntimeError("You must first start a kernel before requesting a client.") + raise RuntimeError( + "You must first start a kernel before requesting a client." + ) if not self.__client: base_ws_url = HTTP_PROTOCOL_REGEXP.sub("ws", self.kernel_url, 1) @@ -156,7 +162,9 @@ def client(self) -> t.Any: return self.__client - def refresh_model(self, timeout: float = REQUEST_TIMEOUT) -> dict[str, t.Any] | None: + def refresh_model( + self, timeout: float = REQUEST_TIMEOUT + ) -> dict[str, t.Any] | None: """Refresh the kernel model. Returns @@ -173,7 +181,13 @@ def refresh_model(self, timeout: float = REQUEST_TIMEOUT) -> dict[str, t.Any] | self.log.debug("Request kernel at: %s", self.kernel_url) try: - response = fetch(self.kernel_url, token=self.token, method="GET", timeout=timeout, headers=self.__extra_headers) + response = fetch( + self.kernel_url, + token=self.token, + method="GET", + timeout=timeout, + headers=self.__extra_headers, + ) except HTTPError as error: if error.response.status_code == 404: self.log.warning("Kernel not found at: %s", self.kernel_url) @@ -200,7 +214,13 @@ def list_kernels(self, timeout: float = REQUEST_TIMEOUT) -> list[dict[str, t.Any kernels_url = url_path_join(self.server_url, "api/kernels") self.log.debug("Request kernels at: %s", kernels_url) try: - response = fetch(kernels_url, token=self.token, method="GET", timeout=timeout, headers=self.__extra_headers) + response = fetch( + kernels_url, + token=self.token, + method="GET", + timeout=timeout, + headers=self.__extra_headers, + ) except HTTPError as error: self.log.error("Error fetching kernels: %s", error) return [] @@ -209,7 +229,6 @@ def list_kernels(self, timeout: float = REQUEST_TIMEOUT) -> list[dict[str, t.Any self.log.debug("Kernels retrieved: %s", models) return models - # -------------------------------------------------------------------------- # Kernel management # -------------------------------------------------------------------------- @@ -241,7 +260,7 @@ def start_kernel( method="POST", json={"name": name, "path": path}, timeout=timeout, - headers=self.__extra_headers + headers=self.__extra_headers, ) self.__kernel = response.json() @@ -256,7 +275,9 @@ def shutdown_kernel( ): """Attempts to stop the kernel process cleanly via HTTP.""" if not self.kernel_url: - raise RuntimeError("You must first start a kernel before requesting a client.") + raise RuntimeError( + "You must first start a kernel before requesting a client." + ) self.log.debug("Request shutdown kernel at: %s", self.kernel_url) if not now: @@ -271,7 +292,13 @@ def shutdown_kernel( # If not now and refreshing the model still returns it, try the http way try: - response = fetch(self.kernel_url, token=self.token, method="DELETE", timeout=timeout, headers=self.__extra_headers) + response = fetch( + self.kernel_url, + token=self.token, + method="DELETE", + timeout=timeout, + headers=self.__extra_headers, + ) self.log.debug( "Shutdown kernel response: %d %s", response.status_code, @@ -291,17 +318,29 @@ def shutdown_kernel( def restart_kernel(self, timeout: float = REQUEST_TIMEOUT, **kw): """Restarts a kernel via HTTP request.""" if not self.kernel_url: - raise RuntimeError("You must first start a kernel before requesting a client.") + raise RuntimeError( + "You must first start a kernel before requesting a client." + ) kernel_url = self.kernel_url + "/restart" self.log.debug("Request restart kernel at: %s", kernel_url) - response = fetch(kernel_url, token=self.token, method="POST", timeout=timeout, headers=self.__extra_headers) - self.log.debug("Restart kernel response: %d %s", response.status_code, response.reason) + response = fetch( + kernel_url, + token=self.token, + method="POST", + timeout=timeout, + headers=self.__extra_headers, + ) + self.log.debug( + "Restart kernel response: %d %s", response.status_code, response.reason + ) def interrupt_kernel(self, timeout: float = REQUEST_TIMEOUT): """Interrupts the kernel via an HTTP request.""" if not self.kernel_url: - raise RuntimeError("You must first start a kernel before requesting a client.") + raise RuntimeError( + "You must first start a kernel before requesting a client." + ) kernel_url = self.kernel_url + "/interrupt" self.log.debug("Request interrupt kernel at: %s", kernel_url) @@ -310,7 +349,7 @@ def interrupt_kernel(self, timeout: float = REQUEST_TIMEOUT): token=self.token, method="POST", timeout=timeout, - headers=self.__extra_headers + headers=self.__extra_headers, ) self.log.debug( "Interrupt kernel response: %d %s", diff --git a/jupyter_kernel_client/shell.py b/jupyter_kernel_client/shell.py index 3d2ef01..6bd54f6 100644 --- a/jupyter_kernel_client/shell.py +++ b/jupyter_kernel_client/shell.py @@ -18,8 +18,12 @@ from jupyter_console.ptshell import ZMQTerminalInteractiveShell class WSTerminalInteractiveShell(ZMQTerminalInteractiveShell): - manager = Instance("jupyter_kernel_client.manager.KernelHttpManager", allow_none=True) - client = Instance("jupyter_kernel_client.wsclient.KernelWebSocketClient", allow_none=True) + manager = Instance( + "jupyter_kernel_client.manager.KernelHttpManager", allow_none=True + ) + client = Instance( + "jupyter_kernel_client.wsclient.KernelWebSocketClient", allow_none=True + ) @default("banner") def _default_banner(self): @@ -36,7 +40,8 @@ async def handle_external_iopub(self, loop=None): def show_banner(self): print( # noqa T201 self.banner.format( - version=__version__, kernel_banner=self.kernel_info.get("banner", "") + version=__version__, + kernel_banner=self.kernel_info.get("banner", ""), ), end="", flush=True, diff --git a/jupyter_kernel_client/snippets.py b/jupyter_kernel_client/snippets.py index 15df2e9..c5b9947 100644 --- a/jupyter_kernel_client/snippets.py +++ b/jupyter_kernel_client/snippets.py @@ -87,7 +87,9 @@ def register(self, language: str, snippets: LanguageSnippets) -> None: snippets: Language snippets """ if language in self._snippets: - warnings.warn(f"Snippets for language {language} will be overridden.", stacklevel=2) + warnings.warn( + f"Snippets for language {language} will be overridden.", stacklevel=2 + ) self._snippets[language] = snippets def get_list_variables(self, language: str) -> str: diff --git a/jupyter_kernel_client/tests/test_client.py b/jupyter_kernel_client/tests/test_client.py index 964b657..ca8cb96 100644 --- a/jupyter_kernel_client/tests/test_client.py +++ b/jupyter_kernel_client/tests/test_client.py @@ -17,12 +17,10 @@ def test_execution_as_context_manager(jupyter_server): port, token = jupyter_server with KernelClient(server_url=f"http://localhost:{port}", token=token) as kernel: - reply = kernel.execute( - """import os + reply = kernel.execute("""import os from platform import node print(f"Hey {os.environ.get('USER', 'John Smith')} from {node()}.") -""" - ) +""") assert reply["execution_count"] == 1 assert reply["outputs"] == [ @@ -41,12 +39,10 @@ def test_execution_no_context_manager(jupyter_server): kernel = KernelClient(server_url=f"http://localhost:{port}", token=token) kernel.start() try: - reply = kernel.execute( - """import os + reply = kernel.execute("""import os from platform import node print(f"Hey {os.environ.get('USER', 'John Smith')} from {node()}.") -""" - ) +""") finally: kernel.stop() @@ -69,7 +65,9 @@ def test_list_kernels_client(jupyter_server): kernel_id = kernel.id # Use a new client to list the kernels - listing_client = KernelClient(server_url=f"http://localhost:{port}", token=token) + listing_client = KernelClient( + server_url=f"http://localhost:{port}", token=token + ) kernels = listing_client.list_kernels() assert isinstance(kernels, list) @@ -83,34 +81,26 @@ def test_list_kernels_client(jupyter_server): if k["id"] == kernel_id: found = True - assert found, f"Kernel with id {kernel_id} not found in the list of running kernels." + assert ( + found + ), f"Kernel with id {kernel_id} not found in the list of running kernels." def test_list_variables(jupyter_server): port, token = jupyter_server with KernelClient(server_url=f"http://localhost:{port}", token=token) as kernel: - kernel.execute( - """a = 1.0 + kernel.execute("""a = 1.0 b = "hello the world" c = {3, 4, 5} d = {"name": "titi"} -""" - ) +""") variables = kernel.list_variables() assert variables == [ - VariableDescription( - name="a", - type=["builtins", "float"], - size=None - ), - VariableDescription( - name="b", - type=["builtins", "str"], - size=None - ), + VariableDescription(name="a", type=["builtins", "float"], size=None), + VariableDescription(name="b", type=["builtins", "str"], size=None), VariableDescription( name="c", type=["builtins", "set"], @@ -169,7 +159,16 @@ def test_get_textplain_variables(jupyter_server, variable, set_variable, expecte ( ("lst", [1, 2, 3, 4]), ("arr", np.random.randn(100000)), - ("df", pd.DataFrame({'values': np.random.randn(1000), 'categories': np.random.choice(['A', 'B', 'C'], 1000), 'integers': np.random.randint(1, 100, 1000)})), + ( + "df", + pd.DataFrame( + { + "values": np.random.randn(1000), + "categories": np.random.choice(["A", "B", "C"], 1000), + "integers": np.random.randint(1, 100, 1000), + } + ), + ), ("s", pd.Series(np.random.randn(100000))), ), ) @@ -201,7 +200,7 @@ def test_set_variables_on_execute(jupyter_server, variable, value): port, token = jupyter_server variables = {variable: value} with KernelClient(server_url=f"http://localhost:{port}", token=token) as kernel: - reply = kernel.execute(f'print({variable})', variables=variables) + reply = kernel.execute(f"print({variable})", variables=variables) assert reply["execution_count"] == 1 assert reply["outputs"] == [ { @@ -237,7 +236,7 @@ def test_set_variables(jupyter_server, variable, set_variable, expected): async def test_multi_execution_in_event_loop(jupyter_server): port, token = jupyter_server - current_user = os.environ.get('USER', 'John Smith') + current_user = os.environ.get("USER", "John Smith") current_node = node() with KernelClient(server_url=f"http://localhost:{port}", token=token) as kernel: @@ -249,13 +248,13 @@ async def test_multi_execution_in_event_loop(jupyter_server): import time time.sleep(5) print(f"Hey {{os.environ.get('USER', 'John Smith')}} from {{node()}}.") -""" +""", ), asyncio.to_thread( kernel.execute, """import time time.sleep(1) -print("Hello")""" +print("Hello")""", ), ) diff --git a/jupyter_kernel_client/tests/test_manager.py b/jupyter_kernel_client/tests/test_manager.py index ec47792..5389afe 100644 --- a/jupyter_kernel_client/tests/test_manager.py +++ b/jupyter_kernel_client/tests/test_manager.py @@ -1,4 +1,3 @@ - # Copyright (c) 2023-2024 Datalayer, Inc. # Copyright (c) 2025 Google # @@ -28,6 +27,7 @@ def test_list_kernels(jupyter_server): assert "name" in k if k["id"] == kernel_id: found = True - - assert found, f"Kernel with id {kernel_id} not found in the list of running kernels." + assert ( + found + ), f"Kernel with id {kernel_id} not found in the list of running kernels." diff --git a/jupyter_kernel_client/tests/test_utils.py b/jupyter_kernel_client/tests/test_utils.py index 86e3d03..a3f5dbb 100644 --- a/jupyter_kernel_client/tests/test_utils.py +++ b/jupyter_kernel_client/tests/test_utils.py @@ -4,7 +4,14 @@ # BSD 3-Clause License import json -from jupyter_kernel_client.utils import serialize_msg_to_ws_json, serialize_msg_to_ws_default, deserialize_msg_from_ws_default, serialize_msg_to_ws_v1, deserialize_msg_from_ws_v1 +from jupyter_kernel_client.utils import ( + serialize_msg_to_ws_json, + serialize_msg_to_ws_default, + deserialize_msg_from_ws_default, + serialize_msg_to_ws_v1, + deserialize_msg_from_ws_v1, +) + def test_serialize_msg_to_ws_json(): src_msg = { @@ -33,6 +40,7 @@ def test_serialize_msg_to_ws_json(): serialized_msg = serialize_msg_to_ws_json(src_msg) assert expected_output == serialized_msg + def test_serialize_and_deserialize_msg_to_ws_default(): src_msg = { "header": { @@ -47,7 +55,7 @@ def test_serialize_and_deserialize_msg_to_ws_default(): "metadata": { "buffer_paths": [ ["content", "data", "payload"], - ["content", "data", "extra_blob"] + ["content", "data", "extra_blob"], ] }, "content": { @@ -69,19 +77,20 @@ def test_serialize_and_deserialize_msg_to_ws_default(): serialized_msg = serialize_msg_to_ws_default(src_msg) bufn = int.from_bytes(serialized_msg[0:4], byteorder="big") - buffers = src_msg['buffers'] or [] + buffers = src_msg["buffers"] or [] for i in range(1, bufn): # ignore the json message for now, it's tested the deserialized msg - start = (i+1) * 4 - offset = int.from_bytes(serialized_msg[start:start+4], byteorder="big") - buf = buffers[i-1] - serialized_buf_val = serialized_msg[offset:offset+len(buf)] + start = (i + 1) * 4 + offset = int.from_bytes(serialized_msg[start : start + 4], byteorder="big") + buf = buffers[i - 1] + serialized_buf_val = serialized_msg[offset : offset + len(buf)] assert serialized_buf_val == buf deserialized_msg = deserialize_msg_from_ws_default(serialized_msg) assert deserialized_msg == src_msg + def test_serialize_and_deserialize_msg_to_ws_v1(): def pack(obj) -> bytes: return json.dumps(obj, separators=(",", ":"), sort_keys=True).encode("utf-8") @@ -100,7 +109,7 @@ def pack(obj) -> bytes: "metadata": { "buffer_paths": [ ["content", "data", "payload"], - ["content", "data", "extra_blob"] + ["content", "data", "extra_blob"], ] }, "content": { @@ -124,10 +133,12 @@ def pack(obj) -> bytes: # construct the msg lists for the serialized msg offset = int.from_bytes(serialized_msg[:8], byteorder="little") offsets = [ - int.from_bytes(serialized_msg[8 * (i + 1) : 8 * (i + 2)], byteorder="little") for i in range(offset) + int.from_bytes(serialized_msg[8 * (i + 1) : 8 * (i + 2)], byteorder="little") + for i in range(offset) + ] + serialized_list = [ + serialized_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset - 1) ] - serialized_list = [serialized_msg[offsets[i]:offsets[i+1]] for i in range(1, offset-1)] _, deserialized_msg = deserialize_msg_from_ws_v1(serialized_msg) assert serialized_list == deserialized_msg - diff --git a/jupyter_kernel_client/utils.py b/jupyter_kernel_client/utils.py index 34c203e..0ee235b 100644 --- a/jupyter_kernel_client/utils.py +++ b/jupyter_kernel_client/utils.py @@ -46,10 +46,13 @@ def deserialize_msg_from_ws_v1(ws_msg): """Deserialize a message using the v1 protocol.""" offset_number = int.from_bytes(ws_msg[:8], "little") offsets = [ - int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number) + int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") + for i in range(offset_number) ] channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8") - msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)] + msg_list = [ + ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1) + ] return channel, msg_list @@ -59,7 +62,7 @@ def serialize_msg_to_ws_default(msg): buffers = [] msg_copy = dict(msg) - msg_copy['header']['date'] = str(msg_copy['header']['date']) + msg_copy["header"]["date"] = str(msg_copy["header"]["date"]) orig_buffers = msg_copy.pop("buffers", []) json_bytes = json.dumps(msg_copy).encode("utf-8") buffers.append(json_bytes) @@ -80,11 +83,11 @@ def serialize_msg_to_ws_default(msg): for i, off in enumerate(offsets): start = 4 * (i + 1) - msg_buf[start:start+4] = off.to_bytes(4, byteorder="big") + msg_buf[start : start + 4] = off.to_bytes(4, byteorder="big") for i, b in enumerate(buffers): start = offsets[i] - msg_buf[start:start+len(b)] = b + msg_buf[start : start + len(b)] = b return bytes(msg_buf) @@ -92,7 +95,7 @@ def serialize_msg_to_ws_default(msg): def deserialize_msg_from_ws_default(ws_msg): """Deserialize a message using the default protocol.""" if isinstance(ws_msg, str): - return json.loads(ws_msg.encode('utf-8')) + return json.loads(ws_msg.encode("utf-8")) else: nbufs = int.from_bytes(ws_msg[:4], byteorder="big") offsets = [] @@ -101,7 +104,7 @@ def deserialize_msg_from_ws_default(ws_msg): for i in range(nbufs): start = 4 * (i + 1) - off = int.from_bytes(ws_msg[start:start+4], byteorder="big") + off = int.from_bytes(ws_msg[start : start + 4], byteorder="big") offsets.append(off) json_start = offsets[0] @@ -115,7 +118,7 @@ def deserialize_msg_from_ws_default(ws_msg): msg["buffers"] = [] for i in range(1, nbufs): start = offsets[i] - stop = offsets[i+1] if (i+1) < len(offsets) else len(ws_msg) + stop = offsets[i + 1] if (i + 1) < len(offsets) else len(ws_msg) if not (0 <= start <= stop <= len(ws_msg)): raise ValueError(f"Invalid buffer offsets for chunk {i}") @@ -124,10 +127,12 @@ def deserialize_msg_from_ws_default(ws_msg): return msg + def serialize_msg_to_ws_json(msg): """Serialize a default protocol with no buffers.""" return json.dumps(msg, default=str) + def url_path_join(*pieces: str) -> str: """Join components of url into a relative url @@ -168,4 +173,4 @@ def utcnow() -> datetime: return datetime.now(timezone.utc) -UTC = tzUTC() # type:ignore[abstract] +UTC = tzUTC() # type: ignore[abstract] diff --git a/jupyter_kernel_client/wsclient.py b/jupyter_kernel_client/wsclient.py index 8898250..4d16674 100644 --- a/jupyter_kernel_client/wsclient.py +++ b/jupyter_kernel_client/wsclient.py @@ -20,7 +20,7 @@ from threading import Event, Lock, Thread from urllib.parse import urlencode -import websocket # type:ignore[import-untyped] +import websocket # type: ignore[import-untyped] from jupyter_client.adapter import adapt from jupyter_client.channels import major_protocol_version from jupyter_client.channelsabc import ChannelABC, HBChannelABC @@ -31,7 +31,13 @@ from jupyter_kernel_client.constants import REQUEST_TIMEOUT from jupyter_kernel_client.log import get_logger -from jupyter_kernel_client.utils import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1, deserialize_msg_from_ws_default, serialize_msg_to_ws_json, serialize_msg_to_ws_default +from jupyter_kernel_client.utils import ( + deserialize_msg_from_ws_v1, + serialize_msg_to_ws_v1, + deserialize_msg_from_ws_default, + serialize_msg_to_ws_json, + serialize_msg_to_ws_default, +) class JupyterSubprotocol(Enum): @@ -51,18 +57,19 @@ class JupyterSubprotocol(Enum): class WSSession(Session): """WebSocket session.""" - def __init__(self, - log: logging.Logger | None = None, - subprotocol: JupyterSubprotocol | None = JupyterSubprotocol.V1, - **kwargs): + def __init__( + self, + log: logging.Logger | None = None, + subprotocol: JupyterSubprotocol | None = JupyterSubprotocol.V1, + **kwargs, + ): super().__init__(**kwargs) self.log = log or get_logger() if not self.debug: self.debug = self.log.level == logging.DEBUG self.subprotocol = subprotocol - - def serialize(self, msg: dict[str, t.Any], **kwargs) -> list[bytes]: # type:ignore[override,no-untyped-def] + def serialize(self, msg: dict[str, t.Any], **kwargs) -> list[bytes]: # type: ignore[override,no-untyped-def] """Serialize the message components to bytes. This is roughly the inverse of deserialize. The serialize/deserialize @@ -113,7 +120,7 @@ def serialize(self, msg: dict[str, t.Any], **kwargs) -> list[bytes]: # type:ign return to_send - def deserialize( # type:ignore[override,no-untyped-def] + def deserialize( # type: ignore[override,no-untyped-def] self, msg_list: list[bytes], content: bool = True, **kwargs ) -> dict[str, t.Any]: """Deserialize a msg_list to a nested message dict. @@ -158,7 +165,7 @@ def deserialize( # type:ignore[override,no-untyped-def] # adapt to the current version return adapt(message) - def send( # type:ignore[override] + def send( # type: ignore[override] self, stream: websocket.WebSocketApp, channel: str, @@ -223,7 +230,7 @@ def send( # type:ignore[override] header=header, metadata=metadata, ) - msg['channel'] = channel + msg["channel"] = channel if self.check_pid and os.getpid() != self.pid: get_logger().warning("Attempted to send message from fork\n%s", msg) @@ -252,16 +259,16 @@ def send( # type:ignore[override] if self.subprotocol == JupyterSubprotocol.V1: stream.send_bytes(serialize_msg_to_ws_v1(to_send, channel)) else: - # The Default protocol is a bytearray with a header pointing to + # The Default protocol is a bytearray with a header pointing to # offsets where buffers are appended. # # Buffers are namely added for cases such as comm messages. # In the case of the common message without a buffers list, the # headers will always be '\x00\x00\x00\x01\x00\x00\x00\x08'. - # Since this is constant it might as well not be included, which is + # Since this is constant it might as well not be included, which is # what Jupyter is doing with the default protocol. # [server code found here](https://github.com/jupyter-server/jupyter_server/blob/main/jupyter_server/services/kernels/connection/channels.py#L445-L464) - if 'buffers' in msg and len(msg['buffers']) > 0: + if "buffers" in msg and len(msg["buffers"]) > 0: stream.send_bytes(serialize_msg_to_ws_default(msg)) else: stream.send_text(serialize_msg_to_ws_json(msg)) @@ -483,9 +490,7 @@ class KernelWebSocketClient(KernelClientABC): DEFAULT_INTERRUPT_WAIT = 1 - - - def __init__( # type:ignore[no-untyped-def] + def __init__( # type: ignore[no-untyped-def] self, endpoint: str, token: str | None = None, @@ -576,7 +581,7 @@ def start_channels( url = self.kernel_ws_endpoint params = {"session_id": self.session.session} if self.token is not None: - params['token'] = self.token + params["token"] = self.token url += "?" + urlencode(params) subprotocols = [] if self._subprotocol == JupyterSubprotocol.V1: @@ -787,7 +792,9 @@ def complete(self, code: str, cursor_pos: int | None = None) -> str: self.shell_channel.send(msg) return msg["header"]["msg_id"] - def inspect(self, code: str, cursor_pos: int | None = None, detail_level: int = 0) -> str: + def inspect( + self, code: str, cursor_pos: int | None = None, detail_level: int = 0 + ) -> str: """Get metadata information about an object in the kernel's namespace. It is up to the kernel to determine the appropriate object to inspect. @@ -859,7 +866,9 @@ def history( if hist_access_type == "range": kwargs.setdefault("session", 0) kwargs.setdefault("start", 0) - content = dict(raw=raw, output=output, hist_access_type=hist_access_type, **kwargs) + content = dict( + raw=raw, output=output, hist_access_type=hist_access_type, **kwargs + ) msg = self.session.msg("history_request", content) self.shell_channel.send(msg) return msg["header"]["msg_id"] @@ -893,7 +902,7 @@ def kernel_info_interactive( return self._kernel_info self.wait_for_ready(timeout) - return self._kernel_info # type:ignore[return-value] + return self._kernel_info # type: ignore[return-value] def comm_info(self, target_name: str | None = None) -> str: """Request comm info @@ -1185,9 +1194,7 @@ def wait_for_ready(self, timeout: float | None = None) -> None: # noqa: C901 # before checking for kernel_info reply while not self.is_alive(): if time.time() > abs_timeout: - message = ( - f"Kernel didn't respond to heartbeats in {timeout:d} seconds and timed out" - ) + message = f"Kernel didn't respond to heartbeats in {timeout:d} seconds and timed out" raise RuntimeError(message) time.sleep(0.2) @@ -1242,7 +1249,9 @@ def _on_open(self, _: websocket.WebSocket) -> None: self.log.debug("Websocket connection is ready.") self.connection_ready.set() - def _on_close(self, _: websocket.WebSocket, close_status_code: t.Any, close_msg: t.Any) -> None: + def _on_close( + self, _: websocket.WebSocket, close_status_code: t.Any, close_msg: t.Any + ) -> None: msg = "Websocket connection is closed" if close_status_code or close_msg: self.log.info("%s: %s %s", msg, close_status_code, close_msg) @@ -1253,7 +1262,7 @@ def _on_close(self, _: websocket.WebSocket, close_status_code: t.Any, close_msg: def _on_message(self, s: websocket.WebSocket, message: bytes) -> None: if self._subprotocol == JupyterSubprotocol.DEFAULT: deserialize_msg = deserialize_msg_from_ws_default(message) - channel = deserialize_msg['channel'] + channel = deserialize_msg["channel"] elif self._subprotocol == JupyterSubprotocol.V1: channel, msg_list = deserialize_msg_from_ws_v1(message) deserialize_msg = self.session.deserialize(msg_list) @@ -1315,7 +1324,7 @@ def _stdin_hook_default(self, msg: dict[str, t.Any]) -> None: def double_int(sig, frame): # call real handler (forwards sigint to kernel), # then raise local interrupt, stopping local raw_input - real_handler(sig, frame) # type:ignore[operator,misc] + real_handler(sig, frame) # type: ignore[operator,misc] raise KeyboardInterrupt signal.signal(signal.SIGINT, double_int)