Skip to content

Commit

Permalink
Merge pull request #20778 from superbobry:lazy-imports
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625415993
  • Loading branch information
jax authors committed Apr 16, 2024
2 parents 5bd6013 + 3513685 commit adbb11f
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions jax/_src/debugger/web_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,32 @@
# limitations under the License.
from __future__ import annotations

import atexit
import functools
import importlib.util
import os
from typing import Any
import weakref

from jax._src.debugger import cli_debugger
from jax._src.debugger import core as debugger_core

web_pdb_version: tuple[int, ...] | None = None
try:

@functools.cache
def _web_pdb_version() -> tuple[int, ...]:
import web_pdb # pytype: disable=import-error
web_pdb_version = tuple(map(int, web_pdb.__version__.split(".")))
WEB_PDB_ENABLED = True
except:
WEB_PDB_ENABLED = False
return tuple(map(int, web_pdb.__version__.split(".")))


_web_consoles: dict[tuple[str, int], Any] = {}


_web_consoles: dict[tuple[str, int], web_pdb.WebConsole] = {}
@atexit.register
def _close_debuggers():
for console in _web_consoles.values():
console.close()
_web_consoles.clear()


class WebDebugger(cli_debugger.CliDebugger):
"""A web-based debugger."""
Expand All @@ -39,6 +48,7 @@ class WebDebugger(cli_debugger.CliDebugger):
def __init__(self, frames: list[debugger_core.DebuggerFrame], thread_id,
completekey: str = "tab", host: str = "", port: int = 5555):
if (host, port) not in _web_consoles:
import web_pdb # pytype: disable=import-error
_web_consoles[host, port] = web_pdb.WebConsole(host, port, self)
# Clobber the debugger in the web console
_web_console = _web_consoles[host, port]
Expand All @@ -54,7 +64,7 @@ def get_current_frame_data(self):
current_line = None
if current_frame.offset is not None:
current_line = current_frame.offset + 1
if web_pdb_version and web_pdb_version < (1, 4, 4):
if _web_pdb_version() < (1, 4, 4):
return {
'filename': filename,
'listing': '\n'.join(lines),
Expand All @@ -74,15 +84,15 @@ def get_current_frame_data(self):

def get_globals(self):
current_frame = self.current_frame()
globals = "\n".join([f"{key} = {value}" for key, value in
sorted(current_frame.globals.items())])
return globals
return "\n".join(
f"{key} = {value}"
for key, value in sorted(current_frame.globals.items()))

def get_locals(self):
current_frame = self.current_frame()
locals = "\n".join([f"{key} = {value}" for key, value in
sorted(current_frame.locals.items())])
return locals
return "\n".join(
f"{key} = {value}"
for key, value in sorted(current_frame.locals.items()))

def run(self):
return self.cmdloop()
Expand All @@ -91,5 +101,6 @@ def run_debugger(frames: list[debugger_core.DebuggerFrame],
thread_id: int | None, **kwargs: Any):
WebDebugger(frames, thread_id, **kwargs).run()

if WEB_PDB_ENABLED:

if importlib.util.find_spec("web_pdb") is not None:
debugger_core.register_debugger("web", run_debugger, -2)

0 comments on commit adbb11f

Please sign in to comment.