From 8c01ae7d6dcd19e4aff4dd21d0d09be91ea451e1 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 17:27:56 +0100 Subject: [PATCH] Format code and comments --- conf/experiment.py | 1 - dataset/example.py | 6 ++-- dataset/mixins/__init__.py | 1 + pyproject.toml | 6 ++-- src/base_tester.py | 15 +++++---- src/base_trainer.py | 40 ++++++++++++----------- utils/__init__.py | 2 +- utils/gui.py | 65 ++++++++++++++++++-------------------- 8 files changed, 69 insertions(+), 67 deletions(-) diff --git a/conf/experiment.py b/conf/experiment.py index b74c02b..b294cbb 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -1,4 +1,3 @@ - """ Configurations for the experiments and config groups, using hydra-zen. """ diff --git a/dataset/example.py b/dataset/example.py index 1c8e622..b67eba5 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -61,9 +61,9 @@ def _load( progress: Progress, job_id: TaskID, ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: - len = 3 if self._tiny else 20 - progress.update(job_id, total=len) - for _ in range(len): + length = 3 if self._tiny else 20 + progress.update(job_id, total=length) + for _ in range(length): progress.advance(job_id) sleep(0.001 if self._tiny else 0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 97f064a..ff92bba 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -123,6 +123,7 @@ def __init__( ) self._split = split self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) + # TODO: Refactor and reduce cyclomatic complexity argnames = inspect.getfullargspec(self.__class__.__init__).args found = False frame: FrameType | None = inspect.currentframe() diff --git a/pyproject.toml b/pyproject.toml index 615a051..0ad2b90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ exclude = [ line-length = 88 indent-width = 4 -# Assume Python 3.8 -target-version = "py38" +# Assume Python 3.11 +target-version = "py311" [tool.ruff.lint] # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. @@ -69,7 +69,7 @@ line-ending = "auto" # # This is currently disabled by default, but it is planned for this # to be opt-out in the future. -docstring-code-format = false +docstring-code-format = true # Set the line length limit used when formatting code snippets in # docstrings. diff --git a/src/base_tester.py b/src/base_tester.py index 2b0fb61..9e582c2 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -30,7 +30,7 @@ console = Console() -print = console.print +print = console.print # skipcq: PYL-W0603 class BaseTester(BaseTrainer): @@ -55,7 +55,7 @@ def __init__( _loss = training_loss self._gui = gui global print # skipcq: PYL-W0603 - print = self._gui.print + print = self._gui.print # skipcq: PYL-W0603 self._run_name = run_name self._model = model if model_ckpt_path is None: @@ -83,7 +83,8 @@ def _test_iteration( batch: Union[Tuple, List, Tensor], ) -> Tuple[Tensor, Dict[str, Tensor]]: """Evaluation procedure for one batch. We want to keep the code DRY and avoid - making mistakes, so this code calls the BaseTrainer._train_val_iteration() method. + making mistakes, so this code calls the BaseTrainer._train_val_iteration() + method. Args: batch: The batch to process. Returns: @@ -99,7 +100,8 @@ async def test( ) -> None: """Computes the average loss on the test set. Args: - visualize_every (int, optional): Visualize the model predictions every n batches. + visualize_every (int, optional): Visualize the model predictions every n + batches. Defaults to 0 (no visualization). """ _ = kwargs @@ -107,7 +109,7 @@ async def test( test_metrics: Dict[str, MeanMetric] = defaultdict(MeanMetric) self._model.eval() print(Text(f"[*] Testing {self._run_name}", style="bold green")) - """ ==================== Training loop for one epoch ==================== """ + # ==================== Training loop for one epoch ==================== pbar, update_loss_hook = self._gui.track_testing( self._data_loader, total=len(self._data_loader) ) @@ -124,7 +126,8 @@ async def test( if visualize_every > 0 and (i + 1) % visualize_every == 0: self._visualize(batch, i) - # TODO: Report metrics in a special panel? Then hang the GUI until the user is done. + # TODO: Report metrics in a special panel? Then hang the GUI until the user is + # done. print("=" * 81) print("==" + " " * 31 + " Test results " + " " * 31 + "==") print("=" * 81) diff --git a/src/base_trainer.py b/src/base_trainer.py index 4117be2..2288b5e 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -34,7 +34,7 @@ from utils.training import visualize_model_predictions console = Console() -print = console.print +print = console.print # skipcq: PYL-W0603 class BaseTrainer: @@ -77,7 +77,7 @@ def __init__( self._n_ctrl_c = 0 self._gui = gui global print # skipcq: PYL-W0603 - print = self._gui.print + print = self._gui.print # skipcq: PYL-W0603 if model_ckpt_path is not None: self._load_checkpoint(model_ckpt_path) signal.signal(signal.SIGINT, self._terminator) @@ -104,8 +104,9 @@ def _train_val_iteration( epoch: int, validation: bool = False, ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Training or validation procedure for one batch. We want to keep the code DRY and avoid - making mistakes, so write this code only once at the cost of many function calls! + """Training or validation procedure for one batch. We want to keep the code DRY + and avoid making mistakes, so write this code only once at the cost of many + function calls! Args: batch: The batch to process. Returns: @@ -136,7 +137,7 @@ def _train_epoch( epoch_loss: MeanMetric = MeanMetric() epoch_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) has_visualized = 0 - """ ==================== Training loop for one epoch ==================== """ + # ==================== Training loop for one epoch ==================== pbar, update_loss_hook = self._gui.track_training( self._train_loader, total=len(self._train_loader), @@ -188,7 +189,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: float: Average validation loss for the epoch. """ has_visualized = 0 - """ ==================== Validation loop for one epoch ==================== """ + # ==================== Validation loop for one epoch ==================== with torch.no_grad(): val_loss: MeanMetric = MeanMetric() val_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) @@ -212,7 +213,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: for k, v in loss_components.items(): val_loss_components[k].update(v.item()) update_loss_hook(val_loss.compute()) - """ ==================== Visualization ==================== """ + # ==================== Visualization ==================== if ( visualize and has_visualized < self._viz_n_samples @@ -235,8 +236,8 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: }, step=epoch, ) - # Set minimize_metric to a key in val_loss_components if you wish to minimize - # a specific metric instead of the validation loss: + # Set minimize_metric to a key in val_loss_components if you wish to + # minimize a specific metric instead of the validation loss: self._model_saver( epoch, mean_val_loss, @@ -268,7 +269,7 @@ async def train( ) self._viz_n_samples = visualize_n_samples self._gui.set_start_epoch(self._epoch) - """ ==================== Training loop ==================== """ + # ==================== Training loop ==================== last_val_loss = float("inf") for epoch in range(self._epoch, epochs): print(f"Epoch: {epoch}") @@ -294,7 +295,7 @@ async def train( last_val_loss = val_loss if self._scheduler is not None: await asyncio.to_thread(self._scheduler.step) - """ ==================== Plotting ==================== """ + # ==================== Plotting ==================== self._gui.plot(epoch, train_loss, last_val_loss) # , self._model_saver) await asyncio.to_thread( self._save_checkpoint, @@ -333,18 +334,19 @@ def _save_checkpoint(self, val_loss: float, ckpt_path: str, **kwargs) -> None: ) def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None: - """Loads the model and optimizer state from a checkpoint file. This method should remain in - this class because it should be extendable in classes inheriting from this class, instead - of being overwritten/modified. That would be a source of bugs and a bad practice. + """Loads the model and optimizer state from a checkpoint file. This method + should remain in this class because it should be extendable in classes + inheriting from this class, instead of being overwritten/modified. That would be + a source of bugs and a bad practice. Args: ckpt_path (str): The path to the checkpoint file. - model_only (bool): If True, only the model is loaded (useful for BaseTester). + model_only (bool): If True, only the model is loaded (useful for + BaseTester). Returns: None """ print(f"[*] Restoring from checkpoint: {ckpt_path}") ckpt = torch.load(ckpt_path) - # If the model was optimized with torch.optimize() we need to remove the "_orig_mod" # prefix: if "_orig_mod" in list(ckpt["model_ckpt"].keys())[0]: ckpt["model_ckpt"] = { @@ -355,7 +357,8 @@ def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None: except Exception: if project_conf.PARTIALLY_LOAD_MODEL_IF_NO_FULL_MATCH: print( - "[!] Partially loading model weights (no full match between model and checkpoint)" + "[!] Partially loading model weights " + + "(no full match between model and checkpoint)" ) self._model.load_state_dict(ckpt["model_ckpt"], strict=False) if not model_only: @@ -378,7 +381,8 @@ def _terminator(self, sig, frame): and self._n_ctrl_c == 0 ): print( - f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}. Press Ctrl+C again to abort." + f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}." + + " Press Ctrl+C again to abort." ) self._n_ctrl_c += 1 elif ( diff --git a/utils/__init__.py b/utils/__init__.py index 5564d43..933b4b4 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -13,7 +13,7 @@ # import sys import traceback -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional # import IPython import numpy as np diff --git a/utils/gui.py b/utils/gui.py index 39a6a83..0daceea 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -63,7 +63,12 @@ def __init__( classes: The CSS classes of the plotter widget. disabled: Whether the plotter widget is disabled or not. """ - super().__init__(name=name, id=id, classes=classes, disabled=disabled) + super().__init__( + name=name, + id=id, + classes=classes, + disabled=disabled, + ) self._title = title self._log_scale = use_log_scale self._train_losses: list[float] = [] @@ -171,8 +176,9 @@ def __init__(self): def update_loss_hook(self, loss: float) -> None: """Update the loss value in the progress bar.""" - # TODO: min_val_loss during validation, val_loss during training. Ideally the - # second parameter would be super flexible (use a dict then). + # TODO: min_val_loss during validation, val_loss during training. + # Ideally the second parameter would be super flexible (use a dict + # then). self._loss = loss class SeqWrapper(abc.Iterator, LossHook): @@ -235,24 +241,11 @@ def reset_hook(): plabel.update(self.DESCRIPTIONS[Task.IDLE]) wrapper = None - update_p, reset_p = ( - partial(update_hook), - partial(reset_hook), - ) + update_p, reset_p = partial(update_hook), partial(reset_hook) if isinstance(iterable, abc.Sequence): - wrapper = SeqWrapper( - iterable, - total, - update_p, - reset_p, - ) + wrapper = SeqWrapper(iterable, total, update_p, reset_p) elif isinstance(iterable, (abc.Iterator, DataLoader)): - wrapper = IteratorWrapper( - iterable, - total, - update_p, - reset_p, - ) + wrapper = IteratorWrapper(iterable, total, update_p, reset_p) else: raise ValueError( f"iterable must be a Sequence or an Iterator, got {type(iterable)}" @@ -264,7 +257,9 @@ def reset_hook(): class GUI(App): - """A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.""" + """ + A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework. + """ TITLE = "Matchbox TUI" CSS_PATH = "style.css" @@ -327,11 +322,7 @@ def print(self, message: Any): if isinstance(message, (RenderableType, str)): logger.write( Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), + Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""), message, ), ) @@ -379,28 +370,31 @@ def print(self, message: Any): ) def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: - """Return an iterable that tracks the progress of the training process, and a progress bar - hook to update the loss value at each iteration.""" + """Return an iterable that tracks the progress of the training process, and a + progress bar hook to update the loss value at each iteration.""" return self.query_one(DatasetProgressBar).track_iterable( iterable, Task.TRAINING, total ) def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: - """Return an iterable that tracks the progress of the validation process, and a progress bar - hook to update the loss value at each iteration.""" + """Return an iterable that tracks the progress of the validation process, and a + progress bar hook to update the loss value at each iteration.""" return self.query_one(DatasetProgressBar).track_iterable( iterable, Task.VALIDATION, total ) def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: - """Return an iterable that tracks the progress of the testing process, and a progress bar - hook to update the loss value at each iteration.""" + """Return an iterable that tracks the progress of the testing process, and a + progress bar hook to update the loss value at each iteration.""" return self.query_one(DatasetProgressBar).track_iterable( iterable, Task.TESTING, total ) def plot( - self, epoch: int, train_loss: float, val_loss: Optional[float] = None + self, + epoch: int, + train_loss: float, + val_loss: Optional[float] = None, ) -> None: """Plot the training and validation losses for the current epoch.""" self.query_one(PlotterWidget).loading = False @@ -432,8 +426,8 @@ async def run_my_app(): await asyncio.sleep(0.1) await asyncio.sleep(2) mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) - # Somehow, the dataloader will crash if it's not forked when using multiprocessing along with - # Textual. + # Somehow, the dataloader will crash if it's not forked when using multiprocessing + # along with Textual. mp.set_start_method("fork") dataloader = DataLoader(mnist, 32, shuffle=True, num_workers=2) pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) @@ -444,7 +438,8 @@ async def run_my_app(): update_progress_loss(random()) gui.plot(epoch=i, train_loss=random(), val_loss=random()) gui.print( - f"[{i+1}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" + f"[{i+1}/{len(dataloader)}]: " + + "We can also iterate over PyTorch dataloaders!" ) if i == 0: gui.print(batch)