Skip to content

Commit

Permalink
Refactor hot code reloading
Browse files Browse the repository at this point in the history
This is a major refactoring of hot code reloading. It now behaves much
better, finds the origin of the exceptions much deeper in the traceback,
but only reloads within the scope of the module. It handles class
instantiations, class methods, lambdas, and soon module-level functions.
There's no limit to how deep the exception is in the call stack, as long
as it can be traced back to the module's underlying function.
  • Loading branch information
DubiousCactus committed Nov 3, 2024
1 parent 9e7bce8 commit cb15829
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 201 deletions.
36 changes: 34 additions & 2 deletions bootstrap/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from dataclasses import dataclass
from functools import partial
from types import FrameType
from typing import Any, Callable, List, Optional

from hydra_zen.typing import Partial
Expand All @@ -10,11 +11,42 @@ class MatchboxModule:
def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs):
self._str_rep = name
self._uid = uuid.uuid4().hex
self.underlying_fn = fn.func if isinstance(fn, partial) else fn
self.underlying_fn: Callable = fn.func if isinstance(fn, partial) else fn
self.partial = partial(fn, *args, **kwargs)
self.first_run = True
self.to_reload = False
self.result = None
self.is_frozen = False
self.throw_frame: Optional[FrameType] = None
self.throw_lambda_argname: Optional[str] = None

def reload(self, new_func: Callable) -> None:
self.underlying_fn = new_func
self.partial = partial(
self.underlying_fn, *self.partial.args, **self.partial.keywords
)
self.to_reload = False

def reload_surgically(self, method_name: str, method: Callable) -> None:
setattr(self.underlying_fn, method_name, method)
self.partial = partial(
self.underlying_fn, *self.partial.args, **self.partial.keywords
)
self.to_reload = False

def reload_surgically_in_lambda(
self, arg_name: str, method_name: str, method: Callable
) -> None:
if arg_name not in self.partial.keywords.keys():
raise KeyError(
"Could not find the argument to replace in the partial kwargs!"
)
for k, v in self.partial.keywords.items():
setattr(v, method_name, partial(method, v))
self.partial.keywords[k] = v # Need to update when using dict iterator
self.partial = partial(
self.underlying_fn, *self.partial.args, **self.partial.keywords
)
self.to_reload = False

def __call__(self, module_chain: List) -> Any:
"""
Expand Down
5 changes: 4 additions & 1 deletion bootstrap/launch_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ async def launch_with_async_gui():
model_module,
MatchboxModule(
"Model forward",
lambda model, dataset: model(dataset[0][0].unsqueeze(0)),
# NOTE: For now we need to call .forward() explicitly, as Matchbox
# doesn't yet handle hot code reloading for model() due to Pytorch
# wrapping
lambda model, dataset: model.forward(dataset[0][0].unsqueeze(0)),
model=model_module,
dataset=dataset_module,
),
Expand Down
Loading

0 comments on commit cb15829

Please sign in to comment.