From cb158290bf2331791d963863db630957af459f59 Mon Sep 17 00:00:00 2001 From: Theo Date: Sun, 3 Nov 2024 00:55:57 +0000 Subject: [PATCH] Refactor hot code reloading This is a major refactoring of hot code reloading. It now behaves much better, finds the origin of the exceptions much deeper in the traceback, but only reloads within the scope of the module. It handles class instantiations, class methods, lambdas, and soon module-level functions. There's no limit to how deep the exception is in the call stack, as long as it can be traced back to the module's underlying function. --- bootstrap/__init__.py | 36 ++- bootstrap/launch_experiment.py | 5 +- bootstrap/tui/builder_ui.py | 386 ++++++++++++++++----------------- dataset/example.py | 20 +- 4 files changed, 246 insertions(+), 201 deletions(-) diff --git a/bootstrap/__init__.py b/bootstrap/__init__.py index 5a5cc1f..020e629 100644 --- a/bootstrap/__init__.py +++ b/bootstrap/__init__.py @@ -1,6 +1,7 @@ import uuid from dataclasses import dataclass from functools import partial +from types import FrameType from typing import Any, Callable, List, Optional from hydra_zen.typing import Partial @@ -10,11 +11,42 @@ class MatchboxModule: def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs): self._str_rep = name self._uid = uuid.uuid4().hex - self.underlying_fn = fn.func if isinstance(fn, partial) else fn + self.underlying_fn: Callable = fn.func if isinstance(fn, partial) else fn self.partial = partial(fn, *args, **kwargs) - self.first_run = True + self.to_reload = False self.result = None self.is_frozen = False + self.throw_frame: Optional[FrameType] = None + self.throw_lambda_argname: Optional[str] = None + + def reload(self, new_func: Callable) -> None: + self.underlying_fn = new_func + self.partial = partial( + self.underlying_fn, *self.partial.args, **self.partial.keywords + ) + self.to_reload = False + + def reload_surgically(self, method_name: str, method: Callable) -> None: + setattr(self.underlying_fn, method_name, method) + self.partial = partial( + self.underlying_fn, *self.partial.args, **self.partial.keywords + ) + self.to_reload = False + + def reload_surgically_in_lambda( + self, arg_name: str, method_name: str, method: Callable + ) -> None: + if arg_name not in self.partial.keywords.keys(): + raise KeyError( + "Could not find the argument to replace in the partial kwargs!" + ) + for k, v in self.partial.keywords.items(): + setattr(v, method_name, partial(method, v)) + self.partial.keywords[k] = v # Need to update when using dict iterator + self.partial = partial( + self.underlying_fn, *self.partial.args, **self.partial.keywords + ) + self.to_reload = False def __call__(self, module_chain: List) -> Any: """ diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index bd8c3de..eb8e505 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -159,7 +159,10 @@ async def launch_with_async_gui(): model_module, MatchboxModule( "Model forward", - lambda model, dataset: model(dataset[0][0].unsqueeze(0)), + # NOTE: For now we need to call .forward() explicitly, as Matchbox + # doesn't yet handle hot code reloading for model() due to Pytorch + # wrapping + lambda model, dataset: model.forward(dataset[0][0].unsqueeze(0)), model=model_module, dataset=dataset_module, ), diff --git a/bootstrap/tui/builder_ui.py b/bootstrap/tui/builder_ui.py index 4855136..ef9f5d3 100644 --- a/bootstrap/tui/builder_ui.py +++ b/bootstrap/tui/builder_ui.py @@ -3,12 +3,13 @@ import inspect import sys import traceback -from functools import partial +from types import FrameType from typing import ( Any, Callable, Iterable, List, + Optional, Tuple, ) @@ -21,6 +22,7 @@ Footer, Header, Placeholder, + RichLog, ) from bootstrap.tui.widgets.logger import Logger @@ -73,21 +75,18 @@ async def _run_chain(self) -> None: self.log_tracer("Running the chain...") if len(self._module_chain) == 0: self.log_tracer(Text("The chain is empty!", style="bold red")) - # TODO: Should we reset all modules to "first_run"? Because if we restart the - # chain from a previously frozen step, we should run it as a first run, right? - # Not sure about this. - for i, module in enumerate(self._module_chain): - initial_run = module.first_run - module.first_run = False + for module in self._module_chain: if module.is_frozen: self.log_tracer(Text(f"Skipping frozen module {module}", style="green")) continue + if module.to_reload: + self.log_tracer(Text(f"Reloading module: {module}", style="yellow")) + await self._reload_module(module) self.log_tracer(Text(f"Running module: {module}", style="yellow")) - module.result = await self.catch_and_hang( - module, initial_run, self._module_chain - ) + module.result = await self._catch_and_hang(module, self._module_chain) self.log_tracer(Text(f"{module} ran sucessfully!", style="bold green")) self.print_info("Hanged.") + self.query_one("#traceback", RichLog).clear() await self.hang(threw=False) def run_chain(self) -> None: @@ -113,7 +112,12 @@ def compose(self) -> ComposeResult: lcls.border_title = "Frame locals" yield lcls yield Tracer(classes="box") - yield Placeholder(classes="box") + traceback = RichLog( + classes="box", id="traceback", highlight=True, markup=True, wrap=False + ) + traceback.border_title = "Exception traceback" + traceback.styles.border = ("solid", "gray") + yield traceback yield Footer() def action_reload(self) -> None: @@ -130,7 +134,6 @@ def on_checkbox_changed(self, message: Checkbox.Changed): for module in self._module_chain: if module.uid == message.checkbox.id: module.is_frozen = bool(message.value) - message.stop() def set_start_epoch(self, *args, **kwargs): _ = args @@ -192,151 +195,70 @@ def print_pretty(self, msg: Any) -> None: self.log_tracer(Pretty(msg)) @classmethod - def get_function_frame(cls, func, exc_traceback): + def get_class_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]: """ - Find the frame of the original function in the traceback. + Find the frame of the last callable within the scope of the MatchboxModule in + the traceback. In this instance, the MatchboxModule is a class so we want to find + the frame of the method that either (1) threw the exception, or (2) called a + function that threw (or originated) the exception. """ + last_frame = None for frame, _ in traceback.walk_tb(exc_traceback): - if frame.f_code.co_name == func.__name__: - return frame - return None - - # TODO: Refactor this - async def catch_and_hang( - self, callable: MatchboxModule | Callable, do_reload_code: bool, *args, **kwargs - ): - if not do_reload_code: # Take out this block. This is just code reloading. - _callable = ( - callable.underlying_fn - if isinstance(callable, MatchboxModule) - else callable - ) - # I think it's not a good idea to let the user reload other modules, because it could - # lead to unexpected behavior across the codebase (e.g. if the function called by the - # callable is used elsewhere where the reference to the function is not updated, - # which probably do not want to do). - self.print_info( - f"[*] Reloading callable '{_callable.__name__}'. (Anything outside this scope will not be reloaded)" - ) + for name, val in inspect.getmembers(func): + if ( + name == frame.f_code.co_name + and "self" in inspect.getargs(frame.f_code).args + ): + print(f"Found method {val} in traceback, continuing...") + last_frame = frame + return last_frame - # This is super interesting stuff about Python's inner workings! Look at the - # documentation for more information: - # https://docs.python.org/3/reference/datamodel.html?highlight=__func__#instance-methods - importlib.reload(sys.modules[_callable.__module__]) - reloaded_module = importlib.import_module(_callable.__module__) - rld_callable = None - # First case, it's a class that we're trying to instantiate. We just need to - # reload the class: - if inspect.isclass(_callable) or _callable.__name__.endswith("__init__"): - self.print_info( - f" -> Reloading class '{_callable.__name__}' from module '{_callable.__module__}'", - ) - reloaded_class = getattr(reloaded_module, _callable.__name__) - rld_callable = reloaded_class - elif hasattr(_callable, "__self__"): - # _callable is a *bound* class method, so we can retrieve the class and reload it - self.print_info( - f" -> Reloading class '{_callable.__self__.__class__.__name__}'" - ) - reloaded_class = getattr( - reloaded_module, _callable.__self__.__class__.__name__ - ) - # Now find the method in the reloaded class, and replace the - # with the reloaded one. - for name, val in inspect.getmembers(reloaded_class): - if inspect.isfunction(val) and val.__name__ == _callable.__name__: - self.print_info( - f" -> Reloading method '{name}'", - ) - rld_callable = val - elif _callable.__name__ == "": - self.print_info( - " -> Callable is a lambda function. This is not supported yet.", - ) - else: - # Most likely we end up here because _callable is the function object of the - # called method, not the method itself. Is there even a case where we end up - # with the method object? First we can try to reload it directly if it was a - # module level function: - try: - self.print_info( - f" -> Reloading module level function '{_callable.__name__}'", - ) - rld_callable = getattr(reloaded_module, _callable.__name__) - except AttributeError: - self.print_info( - f" -> Could not find '{_callable.__name__}' in module '{_callable.__module__}'. " - + "Looking for a class method...", - ) - # Ok that failed, so we need to find the class of the method and reload it, - # then find the method in the reloaded class and replace the function with - # the method's function object; this is the same as above. - # TODO: This feels very hacky! Can we find the class in a better way, maybe - # without going through all classes in the module? Because I'm not sure if the - # qualname always contains the class name in this way; like what about - # inheritance? - self.print_info( - f" -> Reloading class {_callable.__qualname__.split('.')[0]}", - ) - reloaded_class = getattr( - reloaded_module, _callable.__qualname__.split(".")[0] - ) - for name, val in inspect.getmembers(reloaded_class): - if inspect.isfunction(val) and name == _callable.__name__: - self.print_info( - f" -> Reloading method {name}", - ) - rld_callable = val - break - if rld_callable is None: - self.print_err( - f"Could not reload callable {_callable}!", - ) - await self.hang(threw=True) - self.print_info( - f":) Reloaded callable {_callable.__name__}! Retrying the call...", - ) - _callable = rld_callable - if isinstance(callable, MatchboxModule): - # callable.underlying_fn = _callable - callable = MatchboxModule( - callable._str_rep, - _callable, - *callable.partial.args, - **callable.partial.keywords, - ) - else: - raise NotImplementedError() - # TODO: What if the user modified other methods/functions called by the _callable? - # Should we find them and recursively reload them? Maybe we can keep track of - # every called *user* function, and if the user modifies any after an exception is - # caught, we can first ask if we should reload it, warning them about the previous - # calls that will be affected by the reload. - # TODO: Check if we changed the function signature and if so, backtrace the call - # and update the arguments by re-running the routine that generated them and made - # the call. - # TODO: Free memory allocation (open file descriptors, etc.) before retrying the - # call. + @classmethod + def get_lambda_child_frame( + cls, func: Callable, exc_traceback + ) -> Tuple[Optional[FrameType], Optional[str]]: + """ + Find the frame of the last callable within the scope of the MatchboxModule in + the traceback. In this instance, the MatchboxModule is a lambda function so we want + to find the frame of the first function called by the lambda. + """ + lambda_args = inspect.getargs(func.__code__) + potential_matches = {} + for frame, _ in traceback.walk_tb(exc_traceback): + assert lambda_args is not None + frame_args = inspect.getargvalues(frame) + for name, val in potential_matches.items(): + if val == frame.f_code.co_qualname: + return frame, name + elif hasattr(val, frame.f_code.co_name): + return frame, name + for name in lambda_args.args: + if name in frame_args.args: + # NOTE: Now we need to find the argument which initiated the call + # that threw! Which is somewhere deeper in the stack, which + # frame.f_code.co_qualname must match one of the frame_args.args! + # NOTE: We know the next frame in the loop WILL match one of + # this frame's arguments, either in the qual_name directly or in + # the qual_name base (the class) + potential_matches[name] = frame_args.locals[name] + return None, None + + @classmethod + def get_function_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]: + raise NotImplementedError() + + async def _catch_and_hang(self, module: MatchboxModule, *args, **kwargs): try: - if isinstance(callable, partial): - self.print_info(f"Calling {callable.func} with") - self.print_pretty({"args": callable.args, "kwargs": callable.keywords}) - elif isinstance(callable, MatchboxModule): - self.print_info( - f"Calling MatchboxModule({callable.underlying_fn}) with" - ) - self.print_pretty( - { - "args": args, - "kwargs": kwargs, - "partial.args": callable.partial.args, - "partial.kwargs": callable.partial.keywords, - } - ) - else: - self.print_info(f"Calling {callable} with") - self.print_pretty({"args": args, "kwargs": kwargs}) - output = await asyncio.to_thread(callable, *args, **kwargs) + self.print_info(f"Calling MatchboxModule({module.underlying_fn}) with") + self.print_pretty( + { + "args": args, + "kwargs": kwargs, + "partial.args": module.partial.args, + "partial.kwargs": module.partial.keywords, + } + ) + output = await asyncio.to_thread(module, *args, **kwargs) self.print_info("Output:") self.print_pretty(output) return output @@ -344,64 +266,134 @@ async def catch_and_hang( # If the exception came from the wrapper itself, we should not catch it! exc_type, exc_value, exc_traceback = sys.exc_info() if exc_traceback.tb_next is None: - self.print_warn("Could not find the next frame!") - self.print_pretty(traceback.format_exc()) + self.print_err( + "[ERROR] Could not find the next frame in the call stack!" + ) elif exc_traceback.tb_next.tb_frame.f_code.co_name == "catch_and_hang": - self.print_warn( - f"Caught exception in 'debug_trace': {exception}", + self.print_err( + f"[ERROR] Caught exception in the Builder: {exception}", ) - raise exception else: self.print_err( f"Caught exception: {exception}", ) - self.print_err(traceback.format_exc()) - func = ( - callable.underlying_fn - if isinstance(callable, MatchboxModule) - else callable - ) + self.query_one("#traceback", RichLog).write(traceback.format_exc()) + func = module.underlying_fn # NOTE: This frame is for the given function, which is the root of the # call tree (our MatchboxModule's underlying function). What we want is # to go down to the function that threw, and reload that only if it # wasn't called anywhere in the frozen module's call tree. - frame = self.get_function_frame(func, exc_traceback) + frame = None + if inspect.isclass(func): + frame = self.get_class_frame(func, exc_traceback) + elif inspect.isfunction(func) and func.__name__ == "": + frame, lambda_argname = self.get_lambda_child_frame( + func, exc_traceback + ) + module.throw_lambda_argname = lambda_argname + elif inspect.isfunction(func): + frame = self.get_function_frame(func, exc_traceback) + else: + raise NotImplementedError() if not frame: self.print_err( f"Could not find the frame of the original function {func} in the traceback." ) - # self.print_info("Exception thrown in:") - # self.print_pretty(frame) + module.throw_frame = frame + self.print_info("Exception thrown in:") + self.print_pretty(frame) + module.to_reload = True self.print_info("Hanged.") await self.hang(threw=True) - # reload = self.prompt( - # "Take action? ([L]aunch IPython shell and reload the code/[r]eload the code/[a]bort) ", - # ) - # if reload.lower() in ("l", ""): - # # Drop into an IPython shell to inspect the callable and its context. - # # Get the frame of the original callable - # frame = get_function_frame(callable, exc_traceback) - # if not frame: - # raise Exception( - # f"Could not find the frame of the original function {callable} in the traceback." - # ) - # interactive_shell = IPython.terminal.embed.InteractiveShellEmbed( - # cfg=IPython.terminal.embed.load_default_config(), - # banner1=Text( - # f"[*] Dropping into an IPython shell to inspect {callable} " - # + "with the locals as they were at the time of the exception " - # + f"thrown at line {frame.f_lineno} of {frame.f_code.co_filename}." - # + "\n============================== TIPS ==============================" - # + "\n -> Use '%whos' to list variables in the current scope." - # + "\n -> Use '%debug' to launch the debugger." - # + "\n -> Use '' to display the value of a variable. " - # + "Add a '?' to display the type." - # + "\n -> Use '?' to display the function's docstring. " - # + "Add a '?' to display the source code." - # + "\n -> Use 'frame??' to display the source code of the current frame which threw the exception." - # + "\n==================================================================", - # style="green", - # ), - # exit_msg=Text("Leaving IPython shell.", style="yellow"), - # ) - # interactive_shell(local_ns={**frame.f_locals, "frame": frame}) + + async def _reload_module(self, module: MatchboxModule): + if module.throw_frame is None: + self.exit(1) + raise RuntimeError( + f"Module {module} is set to reload but we don't have the frame that threw!" + ) + self.log_tracer( + Text( + f"Reloading code from {module.throw_frame.f_code.co_filename}", + style="purple", + ) + ) + code_obj = module.throw_frame.f_code + code_module = inspect.getmodule(code_obj) + if code_module is None: + self.exit(1) + raise RuntimeError( + f"Could not find the module for the code object {code_obj}." + ) + rld_module = importlib.reload(code_module) + if code_obj.co_qualname.endswith("__init__"): + class_name = code_obj.co_qualname.split(".")[0] + self.log_tracer( + Text( + f"-> Reloading class {class_name} from module {code_module}", + style="purple", + ) + ) + rld_callable = getattr(rld_module, class_name) + if rld_callable is not None: + self.log_tracer( + Text( + f"-> Reloaded class {code_obj.co_qualname} from module {code_module.__name__}", + style="cyan", + ) + ) + module.reload(rld_callable) + return + + else: + if code_obj.co_qualname.find(".") != -1: + class_name, _ = code_obj.co_qualname.split(".") + self.log_tracer( + Text( + f"-> Reloading class {class_name} from module {code_module}", + style="purple", + ) + ) + rld_class = getattr(rld_module, class_name) + rld_callable = None + # Now find the method in the reloaded class, and replace the + # with the reloaded one. + for name, val in inspect.getmembers(rld_class): + if inspect.isfunction(val) and val.__name__ == code_obj.co_name: + self.print_info( + f" -> Reloading method '{name}'", + ) + rld_callable = val + if rld_callable is not None: + self.log_tracer( + Text( + f"-> Reloaded class-level method {code_obj.co_qualname} from module {code_module.__name__}", + style="cyan", + ) + ) + if module.underlying_fn.__name__ == "": + assert module.throw_lambda_argname is not None + module.reload_surgically_in_lambda( + module.throw_lambda_argname, code_obj.co_name, rld_callable + ) + else: + module.reload_surgically(code_obj.co_name, rld_callable) + return + else: + print(code_module, code_obj, code_obj.co_name) + self.log_tracer( + Text( + f"-> Reloading module-level function {code_obj.co_name} from module {code_module.__name__}", + style="purple", + ) + ) + func = getattr(rld_module, code_obj.co_name) + if func is not None: + self.print_info( + f" -> Reloaded module level function {code_obj.co_name}", + ) + print(inspect.getsource(func)) + module.reload(func) + return + while True: + await asyncio.sleep(1) diff --git a/dataset/example.py b/dataset/example.py index 0033b79..439d4b8 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -20,6 +20,21 @@ from dataset.base.image import ImageDataset +def test(): + # return + raise Exception("This is an exception") + + +def test_target(): + return + raise Exception("This is an exception") + + +def test_recursive(): + # raise Exception("This is an exception") + test_target() + + class SingleProcessingExampleDataset(ImageDataset): IMG_SIZE = (32, 32) @@ -50,6 +65,7 @@ def __init__( debug=debug, tiny=tiny, ) + # raise NotImplementedError("This is a dummy error") self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim self._samples, self._labels = self._load( progress, @@ -68,8 +84,10 @@ def _load( if progress is not None: assert job_id is not None progress.update(job_id, total=length) + # raise Exception("This is an exception") + # test() + test_recursive() for _ in range(length): - # raise NotImplementedError("This is a dummy error") if progress is not None: assert job_id is not None progress.advance(job_id)