Skip to content

Commit

Permalink
Add todos and test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed Jan 5, 2025
1 parent ac8335a commit a39dba7
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 19 deletions.
46 changes: 28 additions & 18 deletions bootstrap/hot_reloading/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ async def _reload_code_obj(self, code_obj: CodeType, module: MatchboxModule):
rld_module = importlib.reload(code_module)
if code_obj.co_qualname.endswith("__init__"):
class_name = code_obj.co_qualname.split(".")[0]
self.ui.log_tracer(
Text(
f"-> Reloading class {class_name} from module {code_module}",
style="purple",
)
)
# self.ui.log_tracer(
# Text(
# f"-> Reloading class {class_name} from module {code_module}",
# style="purple",
# )
# )
rld_callable = getattr(rld_module, class_name)
if rld_callable is not None:
self.ui.log_tracer(
Expand All @@ -248,19 +248,18 @@ async def _reload_code_obj(self, code_obj: CodeType, module: MatchboxModule):
style="cyan",
)
)
print(inspect.getsource(rld_callable))
# print(inspect.getsource(rld_callable))
module.reload(rld_callable)
return

else:
if code_obj.co_qualname.find(".") != -1:
class_name, _ = code_obj.co_qualname.split(".")
self.ui.log_tracer(
Text(
f"-> Reloading class {class_name} from module {code_module}",
style="purple",
)
)
# self.ui.log_tracer(
# Text(
# f"-> Reloading class {class_name} from module {code_module}",
# style="purple",
# )
# )
rld_class = getattr(rld_module, class_name)
rld_callable = None
# Now find the method in the reloaded class, and replace the
Expand Down Expand Up @@ -315,19 +314,23 @@ async def reload_module(self, module: MatchboxModule):
)
code_obj = module.root_frame.f_code
else:
# TODO: In the future we should simplify all these cases into one general
# case, where we reload:
# 1. The arguments, and reinstantiate them if they're objects
# 2. The callable itself (class.__init__, function, lambda, etc.)
# and then we rehook arguments into the callable.
if module.underlying_fn.__name__ == "<lambda>":
self.ui.exit(1)
raise NotImplementedError(
"Non-throwing Lambda reloading not implemented yet."
)
# TODO: Get the lambda arguments, and for each argument, find the code
# object and the arg name. The run the following.
lambda_args = inspect.getargs(module.underlying_fn.__code__).args
# print(lambda_args)
print(module.partial.args, module.partial.keywords)
print(lambda_args)
print("Partial args:", module.partial.args, module.partial.keywords)
all_args = list(module.partial.args) + list(
module.partial.keywords.values()
)
print("Combined partial args:", all_args)

def get_code_obj(a):
if inspect.iscode(a):
Expand All @@ -342,9 +345,16 @@ def get_code_obj(a):

assert len(lambda_args) == len(all_args)
for argname, argval in zip(lambda_args, all_args):
print(f"reloading argument '{argname}'")
code_obj = get_code_obj(argval)
module.root_lambda_argname = argname
# FIXME: This won't work because the code object is the argument
# callable of the lambda, which will be reloaded as the module's
# underlying function. The logic for reloading lambdas was written
# too specifically and can't work in this case. We should just
# refactor the whole hot reloading engine.
await self._reload_code_obj(code_obj, module)
return
elif inspect.isclass(module.underlying_fn):
code_obj = module.underlying_fn.__init__.__code__
else:
Expand Down
44 changes: 43 additions & 1 deletion bootstrap/launch_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
from dataclasses import asdict
from time import sleep
from typing import Any
from typing import Any, Callable

import hydra_zen
import torch
Expand Down Expand Up @@ -107,6 +107,40 @@ def init_wandb(
wandb.watch(model, log=log, log_graph=log_graph) # type: ignore


class TestClass:
def __init__(self):
print("init test class")

def test(self):
print(123)
# raise Exception("Test class method exception")


def test_func_depth_with_throwing_arg_2(class_obj):
print("arg4 will throw!!")
class_obj.test()


def test_func_depth_with_throwing_arg(arg4: Callable):
print("arg4 will throw (or not)")
arg4("outch")


def test_func_depth(arg3):
# print("heyyyyyy")
print(arg3)
# raise Exception("This is a depth-1 module-level function exception")


def test_function(arg1, arg2, arg3):
# print("hey bg")
print(arg1)
# raise Exception("This is a depth-0 module-level function exception")
test_func_depth(arg1)
test_func_depth_with_throwing_arg(test_func_depth)
test_func_depth_with_throwing_arg_2(arg3)


def launch_builder(
run, # type: ignore
data_loader: Partial[DataLoader[Any]],
Expand Down Expand Up @@ -136,7 +170,15 @@ def launch_builder(
model,
encoder_input_dim=hydra_zen.just(dataset).img_dim ** 2, # type: ignore
)
test_module = MatchboxModule(
"Test",
test_function,
arg1="hey man",
arg2=12,
arg3=TestClass(),
)
chain = [
# test_module,
dataset_module,
MatchboxModule(
"Dataset test",
Expand Down

0 comments on commit a39dba7

Please sign in to comment.