diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..183ab29 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,31 @@ +{ + "configurations": [ + { + "name": "Launch Train [exp_a]", + "type": "python", + "request": "launch", + "python": "/Users/cactus/miniforge3/envs/bellsnw/bin/python", + "autoReload": { "enable": true }, + "program": "${workspaceFolder}/train.py", + "args": ["+experiment=exp_a", "dataset.tiny=1"] + }, + { + "name": "Launch Build [exp_a]", + "type": "python", + "request": "launch", + "python": "/Users/cactus/miniforge3/envs/bellsnw/bin/python", + "autoReload": { "enable": true }, + "program": "${workspaceFolder}/build.py", + "args": ["+experiment=exp_a", "dataset.tiny=1"] + }, + { + "name": "Attach Build [exp_a]", + "type": "python", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5555 + } + } + ] +} diff --git a/bootstrap/__init__.py b/bootstrap/__init__.py index e69de29..8c4b543 100644 --- a/bootstrap/__init__.py +++ b/bootstrap/__init__.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Optional + +from hydra_zen.typing import Partial + + +class MatchboxModule: + PREV = "MatchboxModule.PREV" # TODO: This is used as an enum value. Should figure it out + + def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs): + # TODO: Figure out this entire class. It's a hack, I'm still figuring things + # out as I go. + self._str_rep = name + self.underlying_fn = fn.func if isinstance(fn, partial) else fn + self.partial = partial(fn, *args, **kwargs) + + def __call__(self, prev_result: Any) -> Any: + # TODO: Replace .PREV in any of the function's args/kwargs with prev_result + for i, arg in enumerate(self.partial.args): + if arg == self.PREV: + assert prev_result is not None + self.partial.args[i] = prev_result + for key, value in self.partial.keywords.items(): + if value == self.PREV: + assert prev_result is not None + self.partial.keywords[key] = prev_result + return self.partial() + + def __str__(self) -> str: + return self._str_rep + + +@dataclass +class MatchboxModuleState: + first_run: bool + result: Any + is_frozen: bool diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index 820813f..710675d 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -25,6 +25,7 @@ from rich.syntax import Syntax from torch.utils.data import DataLoader, Dataset +from bootstrap import MatchboxModule from bootstrap.factories import ( make_dataloaders, make_datasets, @@ -34,6 +35,7 @@ make_training_loss, parallelize_model, ) +from bootstrap.tui.builder_ui import BuilderUI from bootstrap.tui.training_ui import TrainingUI from conf import project as project_conf from src.base_tester import BaseTester @@ -105,6 +107,145 @@ def init_wandb( wandb.watch(model, log=log, log_graph=log_graph) # type: ignore +def launch_builder( + run, # type: ignore + data_loader: Partial[DataLoader[Any]], + optimizer: Partial[torch.optim.Optimizer], # pyright: ignore + scheduler: Partial[torch.optim.lr_scheduler.LRScheduler], + trainer: Partial[BaseTrainer], + tester: Partial[BaseTester], + dataset: Partial[Dataset[Any]], + model: Partial[torch.nn.Module], + training_loss: Partial[torch.nn.Module], +): + exp_conf = hydra_zen.to_yaml( + dict( + run_conf=run, + dataset=dataset, + model=model, + optimizer=optimizer, + scheduler=scheduler, + training_loss=training_loss, + ) + ) + # TODO: Overwrite data_loader.num_workers=0 + # data_loader.num_workers = 0 + + async def launch_with_async_gui(): + tui = BuilderUI() + task = asyncio.create_task(tui.run_async()) + await asyncio.sleep(0.5) # Wait for the app to start up + while not tui.is_running: + await asyncio.sleep(0.01) # Wait for the app to start up + # trace_catcher = TraceCatcher(tui) + + # ============ Partials instantiation ============ + # NOTE: We're gonna need a lot of thinking and right now I'm just too tired. We + # basically need to have a complex mechanism that does conditional hot code + # reloading in the following places. Of course, we'll never re-run the entire + # program while in the builder. We'll just reload pieces of code and restart the + # execution at some specific places. + + # train_dataset = await trace_catcher.catch_and_hang( + # dataset, split="train", seed=run.seed, progress=None, job_id=None + # ) + # model_inst = await trace_catcher.catch_and_hang( + # make_model, model, train_dataset + # ) + # opt_inst = await trace_catcher.catch_and_hang( + # make_optimizer, optimizer, model_inst + # ) + # scheduler_inst = await trace_catcher.catch_and_hang( + # make_scheduler, scheduler, opt_inst, run.epochs + # ) + # training_loss_inst = await trace_catcher.catch_and_hang( + # make_training_loss, run.training_mode, training_loss + # ) + # if model_inst is not None: + # model_inst = to_cuda_(parallelize_model(model_inst)) + # if training_loss_inst is not None: + # training_loss_inst = to_cuda_(training_loss_inst) + tui.chain_up( + [ + MatchboxModule( + "Dataset", + dataset, # TODO: Fix the code reloading, then revert to using the dataset factory + split="train", + seed=run.seed, + progress=None, + job_id=None, + ), + MatchboxModule( + "Model", + make_model, + model, + dataset=dataset, + ), + MatchboxModule( + "Optimizer", make_optimizer, optimizer, model=MatchboxModule.PREV + ), + MatchboxModule( + "Scheduler", + make_scheduler, + scheduler, + optimizer=MatchboxModule.PREV, + epochs=run.epochs, + ), + MatchboxModule( + "Loss", make_training_loss, run.training_mode, training_loss + ), + ] + ) + tui.run_chain() + # all_success = False # TODO: + # if all_success: + # # TODO: idk how to handle this YET + # # Somehow, the dataloader will crash if it's not forked when using multiprocessing + # # along with Textual. + # mp.set_start_method("fork") + # train_loader_inst, val_loader_inst, test_loader_inst = make_dataloaders( + # data_loader, + # train_dataset, + # val_dataset, + # test_dataset, + # run.training_mode, + # run.seed, + # ) + # init_wandb("test-run", model_inst, exp_conf) + # + # model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) + # common_args = dict( + # run_name="build-run", + # model=model_inst, + # model_ckpt_path=model_ckpt_path, + # training_loss=training_loss_inst, + # tui=tui, + # ) + # if training_loss_inst is None: + # raise ValueError("training_loss must be defined in training mode!") + # if val_loader_inst is None or train_loader_inst is None: + # raise ValueError( + # "val_loader and train_loader must be defined in training mode!" + # ) + # await trainer( + # train_loader=train_loader_inst, + # val_loader=val_loader_inst, + # opt=opt_inst, + # scheduler=scheduler_inst, + # **common_args, + # **asdict(run), + # ).train( + # epochs=run.epochs, + # val_every=run.val_every, + # visualize_every=run.viz_every, + # visualize_train_every=run.viz_train_every, + # visualize_n_samples=run.viz_num_samples, + # ) + _ = await task + + asyncio.run(launch_with_async_gui()) + + def launch_experiment( run, # type: ignore data_loader: Partial[DataLoader[Any]], diff --git a/bootstrap/tui/builder_ui.py b/bootstrap/tui/builder_ui.py new file mode 100644 index 0000000..cef7def --- /dev/null +++ b/bootstrap/tui/builder_ui.py @@ -0,0 +1,425 @@ +import asyncio +import importlib +import inspect +import sys +import traceback +from functools import partial +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Tuple, +) + +from rich.console import RenderableType +from rich.pretty import Pretty +from rich.text import Text +from textual import log +from textual.app import App, ComposeResult +from textual.widgets import ( + Checkbox, + Footer, + Header, + Placeholder, + RichLog, +) + +from bootstrap.tui.widgets.tracer import Tracer + +if __name__ == "__main__": + import os + import sys + + sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + ) +from bootstrap import MatchboxModule, MatchboxModuleState +from bootstrap.tui.widgets.checkbox_panel import CheckboxPanel +from bootstrap.tui.widgets.editor import CodeEditor +from bootstrap.tui.widgets.files_tree import FilesTree + + +class BuilderUI(App): + """ + A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework. + """ + + TITLE = "Matchbox Builder TUI" + CSS_PATH = "styles/builder_ui.css" + + BINDINGS = [ + ("q", "quit", "Quit"), + ("d", "toggle_dark", "Toggle dark mode"), + ("r", "reload", "Reload hot code"), + ] + + def __init__(self): + super().__init__() + # TODO: Unify the module chain and the module states! + self._module_chain: List[MatchboxModule] = [] + self._runner_task = None + self._module_states: Dict[str, MatchboxModuleState] = {} + + def chain_up(self, modules_seq: List[MatchboxModule]) -> None: + """Add a module (callable to interactively implement and debug) to the + run-reload chain.""" + self._module_chain = modules_seq + for module in modules_seq: + self._module_states[str(module)] = MatchboxModuleState( + first_run=True, result=None, is_frozen=False + ) + self.query_one(CheckboxPanel).add_checkbox(str(module)) + + async def _run_chain(self) -> None: + log("_run_chain()") + self.log_tracer("Running the chain...") + if len(self._module_chain) == 0: + self.log_tracer(Text("The chain is empty!", style="bold red")) + # TODO: Should we reset all modules to "first_run"? Because if we restart the + # chain from a previously frozen step, we should run it as a first run, right? + # Not sure about this. + for module_idx, module in enumerate(self._module_chain): + log(f"Running module: {module}...") + initial_run = self._module_states[str(module)].first_run + self._module_states[str(module)].first_run = False + if self._module_states[str(module)].is_frozen: + self.log_tracer(Text(f"Skipping frozen module {module}", style="green")) + continue + self.log_tracer(Text(f"Running module: {module}", style="yellow")) + prev_result = self._module_states[ + list(self._module_states.keys())[module_idx - 1] + ].result + self._module_states[str(module)].result = await self.catch_and_hang( + module, initial_run, prev_result + ) + self.log_tracer(Text(f"{module} ran sucessfully!", style="bold green")) + self.print_info("Hanged.") + await self.hang(threw=False) + self.query_one(Tracer).clear() + + def run_chain(self) -> None: + if self._runner_task is not None: + log("Cancelling previous chain run...") + self._runner_task.cancel() + self._runner_task = None + log("Starting new chain run...") + self._runner_task = asyncio.create_task(self._run_chain(), name="run_chain") + + def compose(self) -> ComposeResult: + yield Header() + yield CheckboxPanel(classes="box") + yield CodeEditor(classes="box", id="code") + logs = RichLog(classes="box", id="logger") + logs.border_title = "User logs" + logs.styles.border = ("solid", "gray") + yield logs + ftree = FilesTree(classes="box") + ftree.border_title = "Project tree" + ftree.styles.border = ("solid", "gray") + yield ftree + lcls = Placeholder("Locals area", classes="box") + lcls.loading = True + lcls.border_title = "Frame locals" + yield lcls + yield Tracer(classes="box") + yield Placeholder(classes="box") + yield Footer() + + def action_reload(self) -> None: + log("Reloading...") + self.query_one(Tracer).clear() + self.log_tracer("Reloading hot code...") + self.query_one(CheckboxPanel).ready() + self.query_one(Tracer).ready() + # self.query_one(CheckboxPanel).hang(threw) + self.query_one(CodeEditor).ready() + self.run_chain() + + def on_checkbox_changed(self, message: Checkbox.Changed): + self.query_one("#logger", RichLog).write( + f"Checkbox {message.checkbox.id} changed to: {message.value}" + ) + assert message.checkbox.id is not None + self._module_states[message.checkbox.id].is_frozen = bool(message.value) + # setattr(self, f"{message.checkbox.id}_is_frozen", message.value) + + def print_log(self, message: str) -> None: + self.query_one("#logger", RichLog).write(message) + + def print(self, message: str) -> None: + # TODO: Remove this by merging main into this branch + self.print_log(message) + + def set_start_epoch(self, *args, **kwargs): + _ = args + _ = kwargs + pass + + def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: + _ = total + + def noop(*args, **kwargs): + _ = args + _ = kwargs + pass + + return iterable, noop + + def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: + _ = total + + def noop(*args, **kwargs): + _ = args + _ = kwargs + pass + + return iterable, noop + + def log_tracer(self, message: str | RenderableType) -> None: + self.query_one(Tracer).write(message) + + async def hang(self, threw: bool) -> None: + """ + Give visual signal that the builder is hung, either due to an exception or + because the function ran successfully. + """ + self.query_one(Tracer).hang(threw) + self.query_one(CodeEditor).hang(threw) + while self.is_running: + await asyncio.sleep(1) + + def print_err(self, msg: str | Exception) -> None: + self.log_tracer( + Text("[!] " + msg, style="bold red") + if isinstance(msg, str) + else Pretty(msg) + ) + + def print_warn(self, msg: str) -> None: + self.log_tracer(Text("[!] " + msg, style="bold yellow")) + + def prompt(self, msg: str) -> str: + # TODO: We need to use a popup callback + self.log_tracer(Text("[?] " + msg, style="italic pink")) + return "y" + + def print_info(self, msg: str) -> None: + self.log_tracer(Text(msg, style="bold blue")) + + def print_pretty(self, msg: Any) -> None: + self.log_tracer(Pretty(msg)) + + @classmethod + def get_function_frame(cls, func, exc_traceback): + """ + Find the frame of the original function in the traceback. + """ + for frame, _ in traceback.walk_tb(exc_traceback): + if frame.f_code.co_name == func.__name__: + return frame + return None + + # TODO: Refactor this + async def catch_and_hang( + self, callable: MatchboxModule | Callable, reload_code: bool, *args, **kwargs + ): + if not reload_code: # Take out this block. This is just code reloading. + _callable = ( + callable.underlying_fn + if isinstance(callable, MatchboxModule) + else callable + ) + # I think it's not a good idea to let the user reload other modules, because it could + # lead to unexpected behavior across the codebase (e.g. if the function called by the + # callable is used elsewhere where the reference to the function is not updated, + # which probably do not want to do). + self.print_info( + f"[*] Reloading callable '{_callable.__name__}'. (Anything outside this scope will not be reloaded)" + ) + + # This is super interesting stuff about Python's inner workings! Look at the + # documentation for more information: + # https://docs.python.org/3/reference/datamodel.html?highlight=__func__#instance-methods + importlib.reload(sys.modules[_callable.__module__]) + reloaded_module = importlib.import_module(_callable.__module__) + rld_callable = None + # First case, it's a class that we're trying to instantiate. We just need to + # reload the class: + if inspect.isclass(_callable) or _callable.__name__.endswith("__init__"): + self.print_info( + f" -> Reloading class '{_callable.__name__}' from module '{_callable.__module__}'", + ) + reloaded_class = getattr(reloaded_module, _callable.__name__) + self.print_log(reloaded_class) + self.print_log(inspect.getsource(reloaded_class)) + rld_callable = reloaded_class + elif hasattr(_callable, "__self__"): + # _callable is a *bound* class method, so we can retrieve the class and reload it + self.print_info( + f" -> Reloading class '{_callable.__self__.__class__.__name__}'" + ) + reloaded_class = getattr( + reloaded_module, _callable.__self__.__class__.__name__ + ) + # Now find the method in the reloaded class, and replace the + # with the reloaded one. + for name, val in inspect.getmembers(reloaded_class): + if inspect.isfunction(val) and val.__name__ == _callable.__name__: + self.print_info( + f" -> Reloading method '{name}'", + ) + rld_callable = val + else: + # Most likely we end up here because _callable is the function object of the + # called method, not the method itself. Is there even a case where we end up + # with the method object? First we can try to reload it directly if it was a + # module level function: + try: + self.print_info( + f" -> Reloading module level function '{_callable.__name__}'", + ) + rld_callable = getattr(reloaded_module, _callable.__name__) + except AttributeError: + self.print_info( + f" -> Could not find '{_callable.__name__}' in module '{_callable.__module__}'. " + + "Looking for a class method...", + ) + # Ok that failed, so we need to find the class of the method and reload it, + # then find the method in the reloaded class and replace the function with + # the method's function object; this is the same as above. + # TODO: This feels very hacky! Can we find the class in a better way, maybe + # without going through all classes in the module? Because I'm not sure if the + # qualname always contains the class name in this way; like what about + # inheritance? + self.print_info( + f" -> Reloading class {_callable.__qualname__.split('.')[0]}", + ) + reloaded_class = getattr( + reloaded_module, _callable.__qualname__.split(".")[0] + ) + for name, val in inspect.getmembers(reloaded_class): + if inspect.isfunction(val) and name == _callable.__name__: + self.print_info( + f" -> Reloading method {name}", + ) + rld_callable = val + break + if rld_callable is None: + self.print_err( + f"Could not reload callable {_callable}!", + ) + await self.hang(threw=True) + self.print_info( + f":) Reloaded callable {_callable.__name__}! Retrying the call...", + ) + _callable = rld_callable + if isinstance(callable, MatchboxModule): + # callable.underlying_fn = _callable + callable = MatchboxModule( + callable._str_rep, + _callable, + *callable.partial.args, + **callable.partial.keywords, + ) + else: + raise NotImplementedError() + # TODO: What if the user modified other methods/functions called by the _callable? + # Should we find them and recursively reload them? Maybe we can keep track of + # every called *user* function, and if the user modifies any after an exception is + # caught, we can first ask if we should reload it, warning them about the previous + # calls that will be affected by the reload. + # TODO: Check if we changed the function signature and if so, backtrace the call + # and update the arguments by re-running the routine that generated them and made + # the call. + # TODO: Free memory allocation (open file descriptors, etc.) before retrying the + # call. + try: + if isinstance(callable, partial): + self.print_info(f"Calling {callable.func} with") + self.print_pretty({"args": callable.args, "kwargs": callable.keywords}) + elif isinstance(callable, MatchboxModule): + self.print_info( + f"Calling MatchboxModule({callable.underlying_fn}) with" + ) + self.print_pretty( + { + "args": args, + "kwargs": kwargs, + "partial.args": callable.partial.args, + "partial.kwargs": callable.partial.keywords, + } + ) + else: + self.print_info(f"Calling {callable} with") + self.print_pretty({"args": args, "kwargs": kwargs}) + output = await asyncio.to_thread(callable, *args, **kwargs) + self.print_info("Output:") + self.print_pretty(output) + return output + except Exception as exception: + # If the exception came from the wrapper itself, we should not catch it! + exc_type, exc_value, exc_traceback = sys.exc_info() + if exc_traceback.tb_next is None: + self.print_warn("Could not find the next frame!") + self.print_pretty(traceback.format_exc()) + elif exc_traceback.tb_next.tb_frame.f_code.co_name == "catch_and_hang": + self.print_warn( + f"Caught exception in 'debug_trace': {exception}", + ) + raise exception + else: + self.print_err( + f"Caught exception: {exception}", + ) + self.print_err(traceback.format_exc()) + func = ( + callable.underlying_fn + if isinstance(callable, MatchboxModule) + else callable + ) + # NOTE: This frame is for the given function, which is the root of the + # call tree (our MatchboxModule's underlying function). What we want is + # to go down to the function that threw, and reload that only if it + # wasn't called anywhere in the frozen module's call tree. + frame = self.get_function_frame(func, exc_traceback) + if not frame: + self.print_err( + f"Could not find the frame of the original function {func} in the traceback." + ) + # self.print_info("Exception thrown in:") + # self.print_pretty(frame) + self.print_info("Hanged.") + await self.hang(threw=True) + # reload = self.prompt( + # "Take action? ([L]aunch IPython shell and reload the code/[r]eload the code/[a]bort) ", + # ) + # if reload.lower() in ("l", ""): + # # Drop into an IPython shell to inspect the callable and its context. + # # Get the frame of the original callable + # frame = get_function_frame(callable, exc_traceback) + # if not frame: + # raise Exception( + # f"Could not find the frame of the original function {callable} in the traceback." + # ) + # interactive_shell = IPython.terminal.embed.InteractiveShellEmbed( + # cfg=IPython.terminal.embed.load_default_config(), + # banner1=Text( + # f"[*] Dropping into an IPython shell to inspect {callable} " + # + "with the locals as they were at the time of the exception " + # + f"thrown at line {frame.f_lineno} of {frame.f_code.co_filename}." + # + "\n============================== TIPS ==============================" + # + "\n -> Use '%whos' to list variables in the current scope." + # + "\n -> Use '%debug' to launch the debugger." + # + "\n -> Use '' to display the value of a variable. " + # + "Add a '?' to display the type." + # + "\n -> Use '?' to display the function's docstring. " + # + "Add a '?' to display the source code." + # + "\n -> Use 'frame??' to display the source code of the current frame which threw the exception." + # + "\n==================================================================", + # style="green", + # ), + # exit_msg=Text("Leaving IPython shell.", style="yellow"), + # ) + # interactive_shell(local_ns={**frame.f_locals, "frame": frame}) diff --git a/bootstrap/tui/styles/builder_ui.css b/bootstrap/tui/styles/builder_ui.css new file mode 100644 index 0000000..3b094e5 --- /dev/null +++ b/bootstrap/tui/styles/builder_ui.css @@ -0,0 +1,31 @@ +Screen { + layout: grid; + grid-size: 3; + grid-columns: 1fr 3fr 1fr; + /* grid-rows: 65% 35%; */ + grid-rows: 1fr 2fr 2fr; +} + +.box { + height: 100%; + border: solid green; +} + +#code { + row-span: 2; +} + +#logger { + row-span: 2; +} + + +Center { + margin-top: 1; + margin-bottom: 1; + layout: horizontal; +} + +ProgressBar { + padding-left: 3; +} diff --git a/bootstrap/tui/training_ui.py b/bootstrap/tui/training_ui.py index 6070b67..3b76080 100644 --- a/bootstrap/tui/training_ui.py +++ b/bootstrap/tui/training_ui.py @@ -24,6 +24,13 @@ from torchvision.datasets import MNIST from torchvision.transforms.functional import to_tensor +if __name__ == "__main__": + import os + import sys + + sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + ) from bootstrap.tui import Plot_BestModel, Task from bootstrap.tui.logger import Logger from bootstrap.tui.widgets.plotting import PlotterWidget @@ -35,7 +42,7 @@ class TrainingUI(App): A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework. """ - TITLE = "Matchbox TUI" + TITLE = "Matchbox Training TUI" CSS_PATH = "styles/training_ui.css" BINDINGS = [ diff --git a/bootstrap/tui/widgets/__init__.py b/bootstrap/tui/widgets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bootstrap/tui/widgets/checkbox_panel.py b/bootstrap/tui/widgets/checkbox_panel.py new file mode 100644 index 0000000..efaa687 --- /dev/null +++ b/bootstrap/tui/widgets/checkbox_panel.py @@ -0,0 +1,38 @@ +from textual.app import ComposeResult +from textual.containers import VerticalScroll +from textual.widgets import Checkbox, Static + + +class CheckboxPanel(Static): + def compose(self) -> ComposeResult: + yield VerticalScroll(id="tickers") + + def add_checkbox(self, label: str) -> None: + checkbox = Checkbox(label, value=False, id=label) + # yield Switch() # TODO: Use switches!! + self.query_one("#tickers").mount(checkbox) + + def on_mount(self): + self.ready() + + def on_checkbox_changed(self, message: Checkbox.Changed): + _ = message + self.due() + + def due(self) -> None: + # TODO: Blink the border + self.styles.border = ("dashed", "yellow") + self.styles.opacity = 0.8 + self.border_title = "Frozen modules: due for reloading" + + def hang(self, threw: bool) -> None: + if threw: + self.styles.border = ("dashed", "red") + self.border_title = "Frozen modules: exception was thrown" + else: + self.due() + + def ready(self) -> None: + self.styles.border = ("solid", "green") + self.styles.opacity = 1.0 + self.border_title = "Frozen modules: active" diff --git a/bootstrap/tui/widgets/editor.py b/bootstrap/tui/widgets/editor.py new file mode 100644 index 0000000..ee8583b --- /dev/null +++ b/bootstrap/tui/widgets/editor.py @@ -0,0 +1,55 @@ +from textual.app import ComposeResult +from textual.widgets import Static, TextArea + +TEXT = """\ +from developer import Help + +class App: + def __init__(self): + # This code editor is not yet operational, it is a placeholder for a future + # feature. You may use the editor of your choosing to edit your code in the + # meantime. Press 'r' or click on the footer button to reload the training + # program. + + def help_im_stuck(self) -> ?: + self.escape_key = '' + # Press if you are stuck in this text area! + + + def how_does_it_work(self) -> str: + # Using the checkboxes on the top left, you can freeze/unfreeze modules of your + # PyTorch program. + self.instructions() + + def instructions(self) -> Help: + # 1. The frozen modules will remain in memory and the code will + # only be executed once, which will save you precious time in your research. + + # 2. The unfrozen modules will be entirely reloaded from disk and re-run when you + # press 'r' or click on the footer button. This is called 'hot code reloading'. + + # 3. When your program crashes or throws an uncaught exception (i.e. bad tensor + # operation), Matchbox will catch it and display the trace below this text area. + # You will be able to debug it easily with the frame locals at time of death on + # the lower left corner. We will soon introduce a REPL to aid in debugging. + + # 4. You may as well call builder.print() in your code to log anything on the + # right Rich log panel. You can use any Rich renderables in addition to strings, + # tensors and what not. + # Hope this helps you become a more pragmatic and faster deep learning researcher :) +""" + + +class CodeEditor(Static): + def compose(self) -> ComposeResult: + yield TextArea.code_editor(TEXT, language="python") + + def on_mount(self): + self.border_title = "Code editor" + self.ready() + + def hang(self, threw: bool) -> None: + self.styles.border = ("dashed", "red" if threw else "yellow") + + def ready(self) -> None: + self.styles.border = ("solid", "green") diff --git a/bootstrap/tui/widgets/files_tree.py b/bootstrap/tui/widgets/files_tree.py new file mode 100644 index 0000000..4649d86 --- /dev/null +++ b/bootstrap/tui/widgets/files_tree.py @@ -0,0 +1,19 @@ +from textual.app import ComposeResult +from textual.widgets import ( + Static, + Tree, +) + + +class FilesTree(Static): + def compose(self) -> ComposeResult: + tree: Tree[dict] = Tree("root") + tree.root.expand() + src = tree.root.add("src", expand=True) + src.add_leaf("base_trainer.py") + src.add_leaf("base_tester.py") + src.add_leaf("this_is_dummy.py") + other = tree.root.add("it_is_all_dummy", expand=True) + other.add_leaf("even_this.py") + other.add_leaf("the_whole_tree.py") + yield tree diff --git a/bootstrap/tui/widgets/logger.py b/bootstrap/tui/widgets/logger.py new file mode 100644 index 0000000..147e122 --- /dev/null +++ b/bootstrap/tui/widgets/logger.py @@ -0,0 +1,78 @@ +from datetime import datetime +from typing import Any + +from rich.console import Group, RenderableType +from rich.pretty import Pretty +from rich.text import Text +from textual.app import ComposeResult +from textual.events import Print +from textual.widgets import RichLog, Static + + +class Logger(Static): + def compose(self) -> ComposeResult: + yield RichLog(highlight=True, markup=True, wrap=True) + + def on_mount(self): + self.begin_capture_print() + + def on_print(self, event: Print) -> None: + if event.text.strip() != "": + # FIXME: Why do we need this hack?! + self.wite(event.text, event.stderr) + + def wite(self, message: Any, is_stderr: bool): + logger: RichLog = self.query_one(RichLog) + if isinstance(message, (RenderableType, str)): + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan" if not is_stderr else "bold red", + end="", + ), + message, + ), + ) + else: + ppable, pp_msg = True, None + try: + pp_msg = Pretty(message) + except Exception: + ppable = False + if ppable and pp_msg is not None: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text(str(type(message)) + " ", style="italic blue", end=""), + pp_msg, + ) + ) + else: + try: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + message, + ), + ) + except Exception: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text(str(type(message)) + " ", style="italic blue", end=""), + Text(str(message)), + ), + ) diff --git a/bootstrap/tui/widgets/progress.py b/bootstrap/tui/widgets/progress.py index 05ecb4a..35256d4 100644 --- a/bootstrap/tui/widgets/progress.py +++ b/bootstrap/tui/widgets/progress.py @@ -105,14 +105,16 @@ def __next__(self): def update_hook(loss: Optional[float] = None): self.query_one(ProgressBar).advance() if loss is not None: - plabel: Label = self.query_one("#progress_label") # type: ignore - plabel.update(self.DESCRIPTIONS[task] + f"[loss={loss:.4f}]") + self.query_one("#progress_label", Label).update( + self.DESCRIPTIONS[task] + f"[loss={loss:.4f}]" + ) def reset_hook(): sleep(0.5) self.query_one(ProgressBar).update(total=100, progress=0) - plabel: Label = self.query_one("#progress_label") # type: ignore - plabel.update(self.DESCRIPTIONS[Task.IDLE]) + self.query_one("#progress_label", Label).update( + self.DESCRIPTIONS[Task.IDLE] + ) wrapper = None update_p, reset_p = partial(update_hook), partial(reset_hook) @@ -125,6 +127,5 @@ def reset_hook(): f"iterable must be a Sequence or an Iterator, got {type(iterable)}" ) self.query_one(ProgressBar).update(total=total, progress=0) - plabel: Label = self.query_one("#progress_label") # type: ignore - plabel.update(self.DESCRIPTIONS[task]) + self.query_one("#progress_label", Label).update(self.DESCRIPTIONS[task]) return wrapper, wrapper.update_loss_hook diff --git a/bootstrap/tui/widgets/tracer.py b/bootstrap/tui/widgets/tracer.py new file mode 100644 index 0000000..2341c7f --- /dev/null +++ b/bootstrap/tui/widgets/tracer.py @@ -0,0 +1,30 @@ +from rich.console import RenderableType +from textual.app import ComposeResult +from textual.widgets import RichLog, Static + + +class Tracer(Static): + def compose(self) -> ComposeResult: + yield RichLog() + + def on_mount(self): + self.ready() + + def hang(self, threw: bool) -> None: + # TODO: Blink the border + self.styles.border = ("dashed", "red" if threw else "yellow") + self.border_title = "Exception trace: hanged" + ( + "(exception thrown)" if threw else "(no exception thrown)" + ) + + def ready(self) -> None: + self.loading = True + self.styles.border = ("solid", "green") + self.border_title = "Exception trace: running" + + def write(self, message: str | RenderableType) -> None: + self.loading = False + self.query_one(RichLog).write(message) + + def clear(self) -> None: + self.query_one(RichLog).clear() diff --git a/build.py b/build.py new file mode 100755 index 0000000..aaed15d --- /dev/null +++ b/build.py @@ -0,0 +1,42 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +# +# Copyright © 2023 Théo Morales +# +# Distributed under terms of the MIT license. + + +from rich.console import Console +from rich.live import Live + +if __name__ == "__main__": + console = Console() + status = console.status( + "[bold cyan]Building experiment configurations...", spinner="monkey" + ) + with Live(status, console=console): + from hydra_zen import store, zen + + from bootstrap.launch_experiment import launch_builder + from conf import project as project_conf + from conf.experiment import make_experiment_configs + from utils import seed_everything + + make_experiment_configs() + # ============ Hydra-Zen ============ + store.add_to_hydra_store( + overwrite_ok=True + ) # Overwrite Hydra's default config to update it + zen( + launch_builder, + pre_call=[ + lambda cfg: seed_everything( + cfg.run.seed + ) # training is the config of the training group, part of the base config + if project_conf.REPRODUCIBLE + else lambda: None, + ], + ).hydra_main( + config_name="base_experiment", + version_base="1.3", # Hydra base version + ) diff --git a/conf/experiment.py b/conf/experiment.py index 48f20b0..590c25d 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -211,6 +211,7 @@ class RunConfig: viz_num_samples: int = 5 load_from: Optional[str] = None training_mode: bool = True + build_mode: bool = True run_store = store(group="run") diff --git a/dataset/base/image.py b/dataset/base/image.py index 02aa7c4..0c75dd3 100644 --- a/dataset/base/image.py +++ b/dataset/base/image.py @@ -32,8 +32,8 @@ def __init__( dataset_name: str, split: str, seed: int, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, img_size: Optional[tuple[int, ...]] = None, augment: bool = False, normalize: bool = False, diff --git a/dataset/example.py b/dataset/example.py index e1408ca..228d905 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -29,8 +29,8 @@ def __init__( dataset_name: str, split: str, seed: int, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, img_dim: Optional[int] = None, augment: bool = False, normalize: bool = False, @@ -58,13 +58,18 @@ def __init__( def _load( self, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: length = 3 if self._tiny else 20 - progress.update(job_id, total=length) + if progress is not None: + assert job_id is not None + progress.update(job_id, total=length) for _ in range(length): - progress.advance(job_id) + # raise NotImplementedError("This is a dummy error") + if progress is not None: + assert job_id is not None + progress.advance(job_id) sleep(0.001 if self._tiny else 0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) @@ -81,8 +86,8 @@ def __init__( dataset_name: str, split: str, seed: int, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, img_dim: Optional[int] = None, augment: bool = False, normalize: bool = False, @@ -115,8 +120,8 @@ def __init__( dataset_name: str, split: str, seed: int, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, img_dim: Optional[int] = None, augment: bool = False, normalize: bool = False, diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 98162fd..3b77803 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -40,8 +40,8 @@ def __init__( seed: int, debug: bool, tiny: bool, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, **kwargs, ): _ = dataset_root @@ -68,8 +68,8 @@ def __init__( seed: int, debug: bool, tiny: bool, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, **kwargs, ): self._samples, self._labels = [], [] @@ -111,8 +111,8 @@ def __init__( seed: int, debug: bool, tiny: bool, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, scd_lazy: bool = True, **kwargs, ): @@ -373,8 +373,8 @@ def __init__( seed: int, debug: bool, tiny: bool, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, mpd_lazy: bool = True, mpd_chunk_size: int = 1, mpd_processes: Optional[int] = None, @@ -501,8 +501,8 @@ def __init__( seed: int, debug: bool, tiny: bool, - progress: Progress, - job_id: TaskID, + progress: Optional[Progress] = None, + job_id: Optional[TaskID] = None, **kwargs, ) -> None: super().__init__( diff --git a/model/example.py b/model/example.py index 8895f7b..75c53bc 100644 --- a/model/example.py +++ b/model/example.py @@ -27,6 +27,7 @@ def __init__( nn.ReLU(), nn.Linear(encoder_dim, latent_dim), ) + # raise Exception("This is an exception") self._decoder = nn.Sequential( nn.Linear(latent_dim, decoder_dim), nn.ReLU(), diff --git a/requirements.txt b/requirements.txt index 825ee8f..84f9e5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ neovim rich textual textual_plotext +textual[syntax] diff --git a/utils/__init__.py b/utils/__init__.py index 933b4b4..abf9b81 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -6,16 +6,10 @@ # Distributed under terms of the MIT license. -# import importlib -# import inspect import os import random - -# import sys -import traceback from typing import Any, Callable, Dict, List, Optional -# import IPython import numpy as np import torch from hydra.utils import to_absolute_path @@ -95,269 +89,3 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return func(*args, **kwargs) return wrapper - - -# def blink_pbar(i: int, pbar: tqdm, n: int) -> None: -# """Blink the progress bar every n iterations. -# Args: -# i (int): current iteration -# pbar (tqdm): progress bar -# n (int): blink every n iterations -# """ -# if i % n == 0: -# pbar.colour = ( -# project_conf.Theme.TRAINING.value -# if pbar.colour == project_conf.Theme.VALIDATION.value -# else project_conf.Theme.VALIDATION.value -# ) -# -# -# def update_pbar_str(pbar: tqdm, string: str, color_code: int) -> None: -# """Update the progress bar string. -# Args: -# pbar (tqdm): progress bar -# string (str): string to update the progress bar with -# color_code (int): color code for the string -# """ -# pbar.set_description_str(colorize(string, color_code)) - - -def get_function_frame(func, exc_traceback): - """ - Find the frame of the original function in the traceback. - """ - for frame, _ in traceback.walk_tb(exc_traceback): - if frame.f_code.co_name == func.__name__: - return frame - return None - - -''' -# TODO: Refactor this -def debug_trace(callable): - """ - Decorator to call a callable and launch an IPython shell after an exception is thrown. This - lets the user debug the callable in the context of the exception and fix the function/method. It will - then retry the call until no exception is thrown, after reloading the function/method code. - """ - - def wrapper(*args, **kwargs): - nonlocal callable - while True: - try: - return callable(*args, **kwargs) - except Exception as exception: - # If the exception came from the wrapper itself, we should not catch it! - exc_type, exc_value, exc_traceback = sys.exc_info() - if exc_traceback.tb_next is None: - traceback.print_exc() - sys.exit(1) - elif exc_traceback.tb_next.tb_frame.f_code.co_name == "wrapper": - print( - colorize( - f"[!] Caught exception in 'debug_trace': {exception}", - project_conf.ANSI_COLORS["red"], - ) - ) - sys.exit(1) - print( - colorize( - f"[!] Caught exception: {exception}", - project_conf.ANSI_COLORS["red"], - ) - ) - full_traceback = input( - colorize( - "[?] Display full traceback? (y/[N]) ", - project_conf.ANSI_COLORS["yellow"], - ) - ) - if full_traceback.lower() == "y": - traceback.print_exc() - reload = input( - colorize( - "[?] Take action? ([L]aunch IPython shell and reload the code/[r]eload the code/[a]bort) ", - project_conf.ANSI_COLORS["yellow"], - ) - ) - - if reload.lower() not in ("l", "", "r"): - print("[!] Aborting") - # TODO: Why can't I just raise the exception? It's weird but it gets caught by - # the wrapper a few times until it finally gets raised. - sys.exit(1) - if reload.lower() in ("l", ""): - # Drop into an IPython shell to inspect the callable and its context. - # Get the frame of the original callable - frame = get_function_frame(callable, exc_traceback) - if not frame: - raise Exception( - f"Could not find the frame of the original function {callable} in the traceback." - ) - interactive_shell = IPython.terminal.embed.InteractiveShellEmbed( - cfg=IPython.terminal.embed.load_default_config(), - banner1=colorize( - f"[*] Dropping into an IPython shell to inspect {callable} " - + "with the locals as they were at the time of the exception " - + f"thrown at line {frame.f_lineno} of {frame.f_code.co_filename}." - + "\n============================== TIPS ==============================" - + "\n -> Use '%whos' to list variables in the current scope." - + "\n -> Use '%debug' to launch the debugger." - + "\n -> Use '' to display the value of a variable. " - + "Add a '?' to display the type." - + "\n -> Use '?' to display the function's docstring. " - + "Add a '?' to display the source code." - + "\n -> Use 'frame??' to display the source code of the current frame which threw the exception." - + "\n==================================================================", - project_conf.ANSI_COLORS["green"], - ), - exit_msg=colorize( - "Leaving IPython shell.", project_conf.ANSI_COLORS["yellow"] - ), - ) - interactive_shell(local_ns={**frame.f_locals, "frame": frame}) - # I think it's not a good idea to let the user reload other modules, because it could - # lead to unexpected behavior across the codebase (e.g. if the function called by the - # callable is used elsewhere where the reference to the function is not updated, - # which probably do not want to do). - print( - colorize( - f"[*] Reloading callable {callable.__name__}.", - project_conf.ANSI_COLORS["green"], - ) - + colorize( - " (Anything outside this scope will not be reloaded)", - project_conf.ANSI_COLORS["red"], - ) - ) - - # This is super interesting stuff about Python's inner workings! Look at the - # documentation for more information: - # https://docs.python.org/3/reference/datamodel.html?highlight=__func__#instance-methods - importlib.reload(sys.modules[callable.__module__]) - reloaded_module = importlib.import_module(callable.__module__) - rld_callable = None - if hasattr(callable, "__self__"): - # callable is a *bound* class method, so we can retrieve the class and reload it - print( - colorize( - f"-> Reloading class {callable.__self__.__class__.__name__}", - project_conf.ANSI_COLORS["cyan"], - ) - ) - reloaded_class = getattr( - reloaded_module, callable.__self__.__class__.__name__ - ) - # Now find the method in the reloaded class, and replace the - # with the reloaded one. - for name, val in inspect.getmembers(reloaded_class): - if ( - inspect.isfunction(val) - and val.__name__ == callable.__name__ - ): - print( - colorize( - f"-> Reloading method {name}", - project_conf.ANSI_COLORS["cyan"], - ) - ) - rld_callable = val - else: - # Most likely we end up here because callable is the function object of the - # called method, not the method itself. Is there even a case where we end up - # with the method object? First we can try to reload it directly if it was a - # module level function: - try: - print( - colorize( - f"-> Reloading module level function {callable.__name__}", - project_conf.ANSI_COLORS["cyan"], - ) - ) - callable = getattr(reloaded_module, callable.__name__) - except AttributeError: - print( - colorize( - f"-> Could not find {callable.__name__} in module {callable.__module__}. " - + "Looking for a class method...", - project_conf.ANSI_COLORS["magenta"], - ) - ) - # Ok that failed, so we need to find the class of the method and reload it, - # then find the method in the reloaded class and replace the function with - # the method's function object; this is the same as above. - # TODO: This feels very hacky! Can we find the class in a better way, maybe - # without going through all classes in the module? Because I'm not sure if the - # qualname always contains the class name in this way; like what about - # inheritance? - print( - colorize( - f"-> Reloading class {callable.__qualname__.split('.')[0]}", - project_conf.ANSI_COLORS["cyan"], - ) - ) - reloaded_class = getattr( - reloaded_module, callable.__qualname__.split(".")[0] - ) - for name, val in inspect.getmembers(reloaded_class): - if inspect.isfunction(val) and name == callable.__name__: - print( - colorize( - f"-> Reloading method {name}", - project_conf.ANSI_COLORS["cyan"], - ) - ) - rld_callable = val - break - if rld_callable is None: - print( - colorize( - f"[!] Could not reload callable {callable}!", - project_conf.ANSI_COLORS["red"], - ) - ) - sys.exit(1) - print( - colorize( - f"[*] Reloaded callable {callable.__name__}! Retrying the call...", - project_conf.ANSI_COLORS["green"], - ) - ) - callable = rld_callable - # TODO: What if the user modified other methods/functions called by the callable? - # Should we find them and recursively reload them? Maybe we can keep track of - # every called *user* function, and if the user modifies any after an exception is - # caught, we can first ask if we should reload it, warning them about the previous - # calls that will be affected by the reload. - # TODO: Check if we changed the function signature and if so, backtrace the call - # and update the arguments by re-running the routine that generated them and made - # the call. - # TODO: Free memory allocation (open file descriptors, etc.) before retrying the - # call. - - return wrapper - - -def debug_methods(cls): - """ - Decorator to debug all methods of a class using the debug_trace decorator if the DEBUG environment variable is set. - """ - if not project_conf.DEBUG: - return cls - for key, val in vars(cls).items(): - if callable(val): - setattr(cls, key, debug_trace(val)) - return cls - - -class DebugMetaclass(type): - """ - We can use this metaclass to automatically decorate all methods of a class with the debug_trace - decorator, making it simpler with inheritance. - """ - - def __new__(cls, name, bases, dct): - obj = super().__new__(cls, name, bases, dct) - obj = debug_methods(obj) - return obj -'''