From a39dba72d2abf731bd74c8028fb59cd1cac3615e Mon Sep 17 00:00:00 2001 From: Theo Date: Sun, 5 Jan 2025 16:56:51 +0000 Subject: [PATCH] Add todos and test cases --- bootstrap/hot_reloading/engine.py | 46 +++++++++++++++++++------------ bootstrap/launch_experiment.py | 44 ++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/bootstrap/hot_reloading/engine.py b/bootstrap/hot_reloading/engine.py index 6c25ffd..e8cc7ea 100644 --- a/bootstrap/hot_reloading/engine.py +++ b/bootstrap/hot_reloading/engine.py @@ -234,12 +234,12 @@ async def _reload_code_obj(self, code_obj: CodeType, module: MatchboxModule): rld_module = importlib.reload(code_module) if code_obj.co_qualname.endswith("__init__"): class_name = code_obj.co_qualname.split(".")[0] - self.ui.log_tracer( - Text( - f"-> Reloading class {class_name} from module {code_module}", - style="purple", - ) - ) + # self.ui.log_tracer( + # Text( + # f"-> Reloading class {class_name} from module {code_module}", + # style="purple", + # ) + # ) rld_callable = getattr(rld_module, class_name) if rld_callable is not None: self.ui.log_tracer( @@ -248,19 +248,18 @@ async def _reload_code_obj(self, code_obj: CodeType, module: MatchboxModule): style="cyan", ) ) - print(inspect.getsource(rld_callable)) + # print(inspect.getsource(rld_callable)) module.reload(rld_callable) return - else: if code_obj.co_qualname.find(".") != -1: class_name, _ = code_obj.co_qualname.split(".") - self.ui.log_tracer( - Text( - f"-> Reloading class {class_name} from module {code_module}", - style="purple", - ) - ) + # self.ui.log_tracer( + # Text( + # f"-> Reloading class {class_name} from module {code_module}", + # style="purple", + # ) + # ) rld_class = getattr(rld_module, class_name) rld_callable = None # Now find the method in the reloaded class, and replace the @@ -315,19 +314,23 @@ async def reload_module(self, module: MatchboxModule): ) code_obj = module.root_frame.f_code else: + # TODO: In the future we should simplify all these cases into one general + # case, where we reload: + # 1. The arguments, and reinstantiate them if they're objects + # 2. The callable itself (class.__init__, function, lambda, etc.) + # and then we rehook arguments into the callable. if module.underlying_fn.__name__ == "": self.ui.exit(1) raise NotImplementedError( "Non-throwing Lambda reloading not implemented yet." ) - # TODO: Get the lambda arguments, and for each argument, find the code - # object and the arg name. The run the following. lambda_args = inspect.getargs(module.underlying_fn.__code__).args - # print(lambda_args) - print(module.partial.args, module.partial.keywords) + print(lambda_args) + print("Partial args:", module.partial.args, module.partial.keywords) all_args = list(module.partial.args) + list( module.partial.keywords.values() ) + print("Combined partial args:", all_args) def get_code_obj(a): if inspect.iscode(a): @@ -342,9 +345,16 @@ def get_code_obj(a): assert len(lambda_args) == len(all_args) for argname, argval in zip(lambda_args, all_args): + print(f"reloading argument '{argname}'") code_obj = get_code_obj(argval) module.root_lambda_argname = argname + # FIXME: This won't work because the code object is the argument + # callable of the lambda, which will be reloaded as the module's + # underlying function. The logic for reloading lambdas was written + # too specifically and can't work in this case. We should just + # refactor the whole hot reloading engine. await self._reload_code_obj(code_obj, module) + return elif inspect.isclass(module.underlying_fn): code_obj = module.underlying_fn.__init__.__code__ else: diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index bd4ab6d..aa75424 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -10,7 +10,7 @@ import os from dataclasses import asdict from time import sleep -from typing import Any +from typing import Any, Callable import hydra_zen import torch @@ -107,6 +107,40 @@ def init_wandb( wandb.watch(model, log=log, log_graph=log_graph) # type: ignore +class TestClass: + def __init__(self): + print("init test class") + + def test(self): + print(123) + # raise Exception("Test class method exception") + + +def test_func_depth_with_throwing_arg_2(class_obj): + print("arg4 will throw!!") + class_obj.test() + + +def test_func_depth_with_throwing_arg(arg4: Callable): + print("arg4 will throw (or not)") + arg4("outch") + + +def test_func_depth(arg3): + # print("heyyyyyy") + print(arg3) + # raise Exception("This is a depth-1 module-level function exception") + + +def test_function(arg1, arg2, arg3): + # print("hey bg") + print(arg1) + # raise Exception("This is a depth-0 module-level function exception") + test_func_depth(arg1) + test_func_depth_with_throwing_arg(test_func_depth) + test_func_depth_with_throwing_arg_2(arg3) + + def launch_builder( run, # type: ignore data_loader: Partial[DataLoader[Any]], @@ -136,7 +170,15 @@ def launch_builder( model, encoder_input_dim=hydra_zen.just(dataset).img_dim ** 2, # type: ignore ) + test_module = MatchboxModule( + "Test", + test_function, + arg1="hey man", + arg2=12, + arg3=TestClass(), + ) chain = [ + # test_module, dataset_module, MatchboxModule( "Dataset test",