From 35136851f20397cbac50cbd98ca7ac1663f2aed8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 16 Apr 2024 17:22:11 +0100 Subject: [PATCH] Import web_pdb lazily This ensures that `import jax` is not affected by `web_pdb` being installed. Note that I also added an atexit handler closing the active `web_pdb` consoles. This is strictly speaking not necessary, as the server powering the console is running in a daemon thread, but nice-to-have anyway. --- jax/_src/debugger/web_debugger.py | 41 ++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/jax/_src/debugger/web_debugger.py b/jax/_src/debugger/web_debugger.py index d5c97d1903ad..443bfa676715 100644 --- a/jax/_src/debugger/web_debugger.py +++ b/jax/_src/debugger/web_debugger.py @@ -13,6 +13,9 @@ # limitations under the License. from __future__ import annotations +import atexit +import functools +import importlib.util import os from typing import Any import weakref @@ -20,16 +23,22 @@ from jax._src.debugger import cli_debugger from jax._src.debugger import core as debugger_core -web_pdb_version: tuple[int, ...] | None = None -try: + +@functools.cache +def _web_pdb_version() -> tuple[int, ...]: import web_pdb # pytype: disable=import-error - web_pdb_version = tuple(map(int, web_pdb.__version__.split("."))) - WEB_PDB_ENABLED = True -except: - WEB_PDB_ENABLED = False + return tuple(map(int, web_pdb.__version__.split("."))) + + +_web_consoles: dict[tuple[str, int], Any] = {} -_web_consoles: dict[tuple[str, int], web_pdb.WebConsole] = {} +@atexit.register +def _close_debuggers(): + for console in _web_consoles.values(): + console.close() + _web_consoles.clear() + class WebDebugger(cli_debugger.CliDebugger): """A web-based debugger.""" @@ -39,6 +48,7 @@ class WebDebugger(cli_debugger.CliDebugger): def __init__(self, frames: list[debugger_core.DebuggerFrame], thread_id, completekey: str = "tab", host: str = "", port: int = 5555): if (host, port) not in _web_consoles: + import web_pdb # pytype: disable=import-error _web_consoles[host, port] = web_pdb.WebConsole(host, port, self) # Clobber the debugger in the web console _web_console = _web_consoles[host, port] @@ -54,7 +64,7 @@ def get_current_frame_data(self): current_line = None if current_frame.offset is not None: current_line = current_frame.offset + 1 - if web_pdb_version and web_pdb_version < (1, 4, 4): + if _web_pdb_version() < (1, 4, 4): return { 'filename': filename, 'listing': '\n'.join(lines), @@ -74,15 +84,15 @@ def get_current_frame_data(self): def get_globals(self): current_frame = self.current_frame() - globals = "\n".join([f"{key} = {value}" for key, value in - sorted(current_frame.globals.items())]) - return globals + return "\n".join( + f"{key} = {value}" + for key, value in sorted(current_frame.globals.items())) def get_locals(self): current_frame = self.current_frame() - locals = "\n".join([f"{key} = {value}" for key, value in - sorted(current_frame.locals.items())]) - return locals + return "\n".join( + f"{key} = {value}" + for key, value in sorted(current_frame.locals.items())) def run(self): return self.cmdloop() @@ -91,5 +101,6 @@ def run_debugger(frames: list[debugger_core.DebuggerFrame], thread_id: int | None, **kwargs: Any): WebDebugger(frames, thread_id, **kwargs).run() -if WEB_PDB_ENABLED: + +if importlib.util.find_spec("web_pdb") is not None: debugger_core.register_debugger("web", run_debugger, -2)