From a6181c80b5f57b37875aa946e8a847d17ee7b282 Mon Sep 17 00:00:00 2001 From: Theo Date: Wed, 5 Jun 2024 22:39:35 +0100 Subject: [PATCH 01/38] launch_experiment.py: refactor and use Rich --- dataset/example.py | 4 + launch_experiment.py | 316 ++++++++++++++++++++++++++----------------- requirements.txt | 1 + src/base_trainer.py | 61 +++++---- utils/__init__.py | 47 ++++--- 5 files changed, 259 insertions(+), 170 deletions(-) diff --git a/dataset/example.py b/dataset/example.py index a526e40..c0390f8 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -10,9 +10,11 @@ This is mostly used to test the framework. """ +from time import sleep from typing import Optional, Tuple, Union import torch +from rich.progress import track from torch import Tensor from dataset.base.image import ImageDataset @@ -49,6 +51,8 @@ def __init__( def _load( self, dataset_root: str, tiny: bool, split: str, seed: int ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: + for _ in track(range(10), description=f"Loading dataset splt '{split}'"): + sleep(0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) def __getitem__(self, index: int): diff --git a/launch_experiment.py b/launch_experiment.py index e328b03..842e0e4 100644 --- a/launch_experiment.py +++ b/launch_experiment.py @@ -8,173 +8,239 @@ import os from dataclasses import asdict -from typing import Any, Optional +from typing import Any, Dict, Optional, Tuple import hydra_zen import torch import wandb import yaml from hydra.core.hydra_config import HydraConfig -from hydra.utils import to_absolute_path from hydra_zen import just from hydra_zen.typing import Partial +from rich.console import Console, Group +from rich.panel import Panel +from rich.pretty import Pretty +from rich.syntax import Syntax from torch.utils.data import DataLoader, Dataset -import conf.experiment as exp_conf from conf import project as project_conf from model import TransparentDataParallel from src.base_tester import 
BaseTester from src.base_trainer import BaseTrainer -from utils import colorize, to_cuda_ +from utils import load_model_ckpt, to_cuda_ +console = Console() -def launch_experiment( - run: exp_conf.RunConfig, - data_loader: Partial[DataLoader[Any]], - optimizer: Partial[torch.optim.Optimizer], - scheduler: Partial[torch.optim.lr_scheduler.LRScheduler], - trainer: Partial[BaseTrainer], - tester: Partial[BaseTester], - dataset: Partial[Dataset[Any]], - model: Partial[torch.nn.Module], - training_loss: Partial[torch.nn.Module], -): - run_name = os.path.basename(HydraConfig.get().runtime.output_dir) + +def print_config(run_name: str, exp_conf: str) -> None: # Generate a random ANSI code: - color_code = f"38;5;{hash(run_name) % 255}" - print( - colorize( - f"========================= Running {run_name} =========================", - color_code, - ) - ) - exp_conf = hydra_zen.to_yaml( - dict( - run_name=run_name, - run_conf=run, - dataset=dataset, - model=model, - optimizer=optimizer, - scheduler=scheduler, - training_loss=training_loss, - ) + run_color = f"color({hash(run_name) % 255})" + background_color = f"color({(hash(run_name) + 128) % 255})" + console.print( + f"Running {run_name}", + style=f"bold {run_color} on {background_color}", + justify="center", ) - print( - colorize( - "Experiment config:\n" + "_" * 18 + "\n" + exp_conf + "_" * 18, color_code - ) + console.rule() + console.print( + Panel( + Syntax( + exp_conf, lexer="yaml", dedent=True, word_wrap=False, theme="dracula" + ), + title="Experiment configuration", + expand=False, + ), + overflow="ellipsis", ) - "============ Partials instantiation ============" - model_inst = model( - encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore - ) # Use just() to get the config out of the Zen-Partial - print(model_inst) - print(f"Number of parameters: {sum(p.numel() for p in model_inst.parameters())}") - print( - f"Number of trainable parameters: {sum(p.numel() for p in model_inst.parameters() if 
p.requires_grad)}" - ) - train_dataset: Optional[Dataset[Any]] = None - val_dataset: Optional[Dataset[Any]] = None - test_dataset: Optional[Dataset[Any]] = None - if run.training_mode: - train_dataset = dataset(split="train", seed=run.seed) - val_dataset = dataset(split="val", seed=run.seed) - else: - test_dataset = dataset(split="test", augment=False, seed=run.seed) - opt_inst = optimizer(model_inst.parameters()) - scheduler_inst = scheduler( - opt_inst - ) # TODO: less hacky way to set T_max for CosineAnnealingLR? - if isinstance(scheduler_inst, torch.optim.lr_scheduler.CosineAnnealingLR): - scheduler_inst.T_max = run.epochs - - "======== Multi GPUs ==========" - print( - colorize( - f"[*] Number of GPUs: {torch.cuda.device_count()}", - project_conf.ANSI_COLORS["cyan"], - ) +def print_model(model: torch.nn.Module) -> None: + console.print( + Panel( + Group( + Pretty(model), + f"Number of parameters: {sum(p.numel() for p in model.parameters())}", + f"Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}", + ), + title="Model architecture", + expand=False, + ), + overflow="ellipsis", ) - if torch.cuda.device_count() > 1: - print( - colorize( - f"-> Using {torch.cuda.device_count()} GPUs!", - project_conf.ANSI_COLORS["cyan"], + console.rule() + + +def init_wandb( + run_name: str, + model: torch.nn.Module, + exp_conf: str, + log="gradients", + log_graph=False, +) -> None: + if project_conf.USE_WANDB: + with console.status("Initializing Weights & Biases...", spinner="moon"): + # exp_conf is a string, so we need to load it back to a dict: + exp_conf = yaml.safe_load(exp_conf) + wandb.init( # type: ignore + project=project_conf.PROJECT_NAME, + name=run_name, + config=exp_conf, ) - ) - model_inst = TransparentDataParallel(model_inst) + wandb.watch(model, log=log, log_graph=log_graph) # type: ignore - training_loss_inst: Optional[torch.nn.Module] = None - if run.training_mode: - training_loss_inst = training_loss() - "============ 
CUDA ============" - model_inst: torch.nn.Module = to_cuda_(model_inst) # type: ignore - training_loss_inst = to_cuda_(training_loss_inst) # type: ignore +def make_datasets( + training_mode: bool, seed: int, dataset_partial: Partial[Dataset[Any]] +) -> Tuple[Optional[Dataset[Any]], Optional[Dataset[Any]], Optional[Dataset[Any]]]: + train_dataset: Optional[Dataset[Any]] = None + val_dataset: Optional[Dataset[Any]] = None + test_dataset: Optional[Dataset[Any]] = None + with console.status("Loading datasets...", spinner="monkey"): + if training_mode: + train_dataset = dataset_partial(split="train", seed=seed) + val_dataset = dataset_partial(split="val", seed=seed) + else: + test_dataset = dataset_partial(split="test", augment=False, seed=seed) + return train_dataset, val_dataset, test_dataset - "============ Weights & Biases ============" - if project_conf.USE_WANDB: - # exp_conf is a string, so we need to load it back to a dict: - exp_conf = yaml.safe_load(exp_conf) - wandb.init( # type: ignore - project=project_conf.PROJECT_NAME, - name=run_name, - config=exp_conf, - ) - wandb.watch(model_inst, log="all", log_graph=True) # type: ignore - " ============ Reproducibility of data loaders ============ " - g = None + +def make_dataloaders( + data_loader_partial: Partial[DataLoader[Dataset[Any]]], + train_dataset: Optional[Dataset[Any]], + val_dataset: Optional[Dataset[Any]], + test_dataset: Optional[Dataset[Any]], + training_mode: bool, + seed: int, +) -> Tuple[ + Optional[DataLoader[Dataset[Any]]], + Optional[DataLoader[Dataset[Any]]], + Optional[DataLoader[Dataset[Any]]], +]: + generator = None if project_conf.REPRODUCIBLE: - g = torch.Generator() - g.manual_seed(run.seed) + generator = torch.Generator() + generator.manual_seed(seed) train_loader_inst: Optional[DataLoader[Any]] = None val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None test_loader_inst: Optional[DataLoader[Any]] = None - if run.training_mode: + if training_mode: if train_dataset is None or 
val_dataset is None: raise ValueError( "train_dataset and val_dataset must be defined in training mode!" ) - train_loader_inst = data_loader(train_dataset, generator=g) - val_loader_inst = data_loader( - val_dataset, generator=g, shuffle=False, drop_last=False + train_loader_inst = data_loader_partial(train_dataset, generator=generator) + val_loader_inst = data_loader_partial( + val_dataset, generator=generator, shuffle=False, drop_last=False ) else: if test_dataset is None: raise ValueError("test_dataset must be defined in testing mode!") - test_loader_inst = data_loader( - test_dataset, generator=g, shuffle=False, drop_last=False + test_loader_inst = data_loader_partial( + test_dataset, generator=generator, shuffle=False, drop_last=False ) + return train_loader_inst, val_loader_inst, test_loader_inst - " ============ Training ============ " - model_ckpt_path = None - if run.load_from is not None: - if run.load_from.endswith(".ckpt"): - model_ckpt_path = to_absolute_path(run.load_from) - if not os.path.exists(model_ckpt_path): - raise ValueError(f"File {model_ckpt_path} does not exist!") - else: - run_models = sorted( - [ - f - for f in os.listdir(to_absolute_path(f"runs/{run.load_from}/")) - if f.endswith(".ckpt") - and (not f.startswith("last") if not run.training_mode else True) - ] - ) - if len(run_models) < 1: - raise ValueError(f"No model found in runs/{run.load_from}/") - model_ckpt_path = to_absolute_path( - os.path.join( - "runs", - run.load_from, - run_models[-1], - ) - ) +def make_model( + model_partial: Partial[torch.nn.Module], dataset: Partial[Dataset[Any]] +) -> torch.nn.Module: + with console.status("Loading model...", spinner="runner"): + model_inst = model_partial( + encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore + ) # Use just() to get the config out of the Zen-Partial + + return model_inst + + +def parallelize_model(model: torch.nn.Module) -> torch.nn.Module: + console.print( + f"[*] Number of GPUs: {torch.cuda.device_count()}", + 
style="bold cyan", + ) + if torch.cuda.device_count() > 1: + console.print( + f"-> Using {torch.cuda.device_count()} GPUs!", + style="bold cyan", + ) + model = TransparentDataParallel(model) + return model + + +def make_optimizer( + optimizer_partial: Partial[torch.optim.Optimizer], model: torch.nn.Module +) -> torch.optim.Optimizer: + return optimizer_partial(model.parameters()) + + +def make_scheduler( + scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler], + optimizer: torch.optim.Optimizer, + epochs: int, +) -> torch.optim.lr_scheduler.LRScheduler: + scheduler = scheduler_partial( + optimizer + ) # TODO: less hacky way to set T_max for CosineAnnealingLR? + if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR): + scheduler.T_max = epochs + return scheduler + + +def make_training_loss( + training_mode: bool, training_loss_partial: Partial[torch.nn.Module] +): + training_loss: Optional[torch.nn.Module] = None + if training_mode: + training_loss = training_loss_partial() + return training_loss + + +def launch_experiment( + run, # type: ignore + data_loader: Partial[DataLoader[Any]], + optimizer: Partial[torch.optim.Optimizer], + scheduler: Partial[torch.optim.lr_scheduler.LRScheduler], + trainer: Partial[BaseTrainer], + tester: Partial[BaseTester], + dataset: Partial[Dataset[Any]], + model: Partial[torch.nn.Module], + training_loss: Partial[torch.nn.Module], +): + run_name = os.path.basename(HydraConfig.get().runtime.output_dir) + exp_conf = hydra_zen.to_yaml( + dict( + run_conf=run, + dataset=dataset, + model=model, + optimizer=optimizer, + scheduler=scheduler, + training_loss=training_loss, + ) + ) + print_config(run_name, exp_conf) + + """ ============ Partials instantiation ============ """ + model_inst = make_model(model, dataset) + print_model(model_inst) + train_dataset, val_dataset, test_dataset = make_datasets( + run.training_mode, run.seed, dataset + ) + opt_inst = make_optimizer(optimizer, model_inst) + scheduler_inst = 
make_scheduler(scheduler, opt_inst, run.epochs) + model_inst = to_cuda_(parallelize_model(model_inst)) + training_loss_inst = to_cuda_(make_training_loss(run.training_mode, training_loss)) + train_loader_inst, val_loader_inst, test_loader_inst = make_dataloaders( + data_loader, + train_dataset, + val_dataset, + test_dataset, + run.training_mode, + run.seed, + ) + init_wandb(run_name, model_inst, exp_conf) + + """ ============ Training ============ """ + model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) if run.training_mode: if training_loss_inst is None: raise ValueError("training_loss must be defined in training mode!") diff --git a/requirements.txt b/requirements.txt index ddfcdc4..b16827f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ pre-commit blosc2 ipython neovim +rich diff --git a/src/base_trainer.py b/src/base_trainer.py index be05375..c1ead4c 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -19,18 +19,21 @@ import torch import wandb from hydra.core.hydra_config import HydraConfig +from rich.console import Console +from rich.progress import track from torch import Tensor from torch.nn import Module from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import MeanMetric -from tqdm import tqdm from conf import project as project_conf -from utils import blink_pbar, colorize, to_cuda, update_pbar_str +from utils import to_cuda from utils.helpers import BestNModelSaver from utils.training import visualize_model_predictions +console = Console() + class BaseTrainer: def __init__( @@ -65,7 +68,6 @@ def __init__( project_conf.BEST_N_MODELS_TO_KEEP, self._save_checkpoint ) self._minimize_metric = "val_loss" - self._pbar = tqdm(total=len(self._train_loader), desc="Training") self._training_loss = training_loss self._viz_n_samples = 1 self._n_ctrl_c = 0 @@ -124,12 +126,14 @@ def _train_epoch( """ epoch_loss: MeanMetric = MeanMetric() epoch_loss_components: Dict[str, MeanMetric] 
= defaultdict(MeanMetric) - self._pbar.reset() - self._pbar.set_description(description) color_code = project_conf.ANSI_COLORS[project_conf.Theme.TRAINING.value] has_visualized = 0 """ ==================== Training loop for one epoch ==================== """ - for i, batch in enumerate(self._train_loader): + for i, batch in track( + enumerate(self._train_loader), + description=description, + total=len(self._train_loader), + ): if ( not self._running and project_conf.SIGINT_BEHAVIOR @@ -146,12 +150,12 @@ def _train_epoch( epoch_loss.update(loss.item()) for k, v in loss_components.items(): epoch_loss_components[k].update(v.item()) - update_pbar_str( - self._pbar, - f"{description} [loss={epoch_loss.compute():.4f} /" - + f" val_loss={last_val_loss:.4f}]", - color_code, - ) + # update_pbar_str( + # self._pbar, + # f"{description} [loss={epoch_loss.compute():.4f} /" + # + f" val_loss={last_val_loss:.4f}]", + # color_code, + # ) if ( visualize and has_visualized < self._viz_n_samples @@ -160,7 +164,7 @@ def _train_epoch( with torch.no_grad(): self._visualize(batch, epoch) has_visualized += 1 - self._pbar.update() + # self._pbar.update() mean_epoch_loss: float = epoch_loss.compute().item() if project_conf.USE_WANDB: wandb.log({"train_loss": mean_epoch_loss}, step=epoch) @@ -187,7 +191,11 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: with torch.no_grad(): val_loss: MeanMetric = MeanMetric() val_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) - for i, batch in enumerate(self._val_loader): + for i, batch in track( + enumerate(self._val_loader), + description=description, + total=len(self._val_loader), + ): if ( not self._running and project_conf.SIGINT_BEHAVIOR @@ -196,7 +204,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: print("[!] 
Training aborted.") break # Blink the progress bar to indicate that the validation loop is running - blink_pbar(i, self._pbar, 4) + # blink_pbar(i, self._pbar, 4) loss, loss_components = self._train_val_iteration( batch, epoch, @@ -204,12 +212,12 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: val_loss.update(loss.item()) for k, v in loss_components.items(): val_loss_components[k].update(v.item()) - update_pbar_str( - self._pbar, - f"{description} [loss={val_loss.compute():.4f} /" - + f" min_val_loss={self._model_saver.min_val_loss:.4f}]", - color_code, - ) + # update_pbar_str( + # self._pbar, + # f"{description} [loss={val_loss.compute():.4f} /" + # + f" min_val_loss={self._model_saver.min_val_loss:.4f}]", + # color_code, + # ) """ ==================== Visualization ==================== """ if ( visualize @@ -262,12 +270,7 @@ def train( """ if model_ckpt_path is not None: self._load_checkpoint(model_ckpt_path) - print( - colorize( - f"[*] Training {self._run_name} for {epochs} epochs", - project_conf.ANSI_COLORS["green"], - ) - ) + # console.print(f"[*] Training {self._run_name} for {epochs} epochs", style="bold green") self._viz_n_samples = visualize_n_samples train_losses: List[float] = [] val_losses: List[float] = [] @@ -277,7 +280,7 @@ def train( if not self._running: break self._model.train() - self._pbar.colour = project_conf.Theme.TRAINING.value + # self._pbar.colour = project_conf.Theme.TRAINING.value train_losses.append( self._train_epoch( f"Epoch {epoch}/{epochs}: Training", @@ -291,7 +294,7 @@ def train( ) if epoch % val_every == 0: self._model.eval() - self._pbar.colour = project_conf.Theme.VALIDATION.value + # self._pbar.colour = project_conf.Theme.VALIDATION.value val_losses.append( self._val_epoch( f"Epoch {epoch}/{epochs}: Validation", diff --git a/utils/__init__.py b/utils/__init__.py index 4832e55..680ef44 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -8,22 +8,52 @@ # import importlib # import inspect 
+import os import random # import sys import traceback from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union # import IPython import numpy as np import torch +from hydra.utils import to_absolute_path from torch import Tensor, nn from tqdm import tqdm from conf import project as project_conf +def load_model_ckpt(load_from: Optional[str], training_mode: bool) -> Optional[str]: + model_ckpt_path = None + if load_from is not None: + if load_from.endswith(".ckpt"): + model_ckpt_path = to_absolute_path(load_from) + if not os.path.exists(model_ckpt_path): + raise ValueError(f"File {model_ckpt_path} does not exist!") + else: + run_models = sorted( + [ + f + for f in os.listdir(to_absolute_path(f"runs/{load_from}/")) + if f.endswith(".ckpt") + and (not f.startswith("last") if not training_mode else True) + ] + ) + if len(run_models) < 1: + raise ValueError(f"No model found in runs/{load_from}/") + model_ckpt_path = to_absolute_path( + os.path.join( + "runs", + load_from, + run_models[-1], + ) + ) + return model_ckpt_path + + def seed_everything(seed: int): torch.manual_seed(seed) # type: ignore np.random.seed(seed) @@ -69,10 +99,6 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return wrapper -def colorize(string: str, ansii_code: Union[int, str]) -> str: - return f"\033[{ansii_code}m{string}\033[0m" - - def blink_pbar(i: int, pbar: tqdm, n: int) -> None: """Blink the progress bar every n iterations. Args: @@ -88,17 +114,6 @@ def blink_pbar(i: int, pbar: tqdm, n: int) -> None: ) -@contextmanager -def colorize_prints(ansii_code: Union[int, str]): - if isinstance(ansii_code, str): - ansii_code = project_conf.ANSI_COLORS[ansii_code] - print(f"\033[{ansii_code}m", end="") - try: - yield - finally: - print("\033[0m", end="") - - def update_pbar_str(pbar: tqdm, string: str, color_code: int) -> None: """Update the progress bar string. 
Args: From 174bdcbc2ab7a7ddea5581b4651e6431c1450f9f Mon Sep 17 00:00:00 2001 From: Theo Date: Wed, 5 Jun 2024 22:46:00 +0100 Subject: [PATCH 02/38] Refactor: move factories and launcher in bootstrap/ --- bootstrap/factories.py | 141 ++++++++++++++++++ .../launch_experiment.py | 120 ++------------- conf/experiment.py | 2 +- test.py | 2 +- train.py | 2 +- 5 files changed, 157 insertions(+), 110 deletions(-) create mode 100644 bootstrap/factories.py rename launch_experiment.py => bootstrap/launch_experiment.py (60%) diff --git a/bootstrap/factories.py b/bootstrap/factories.py new file mode 100644 index 0000000..ecbb6de --- /dev/null +++ b/bootstrap/factories.py @@ -0,0 +1,141 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +# +# Copyright © 2024 Théo Morales +# +# Distributed under terms of the MIT license. + +""" +All factories. +""" + +import os +from dataclasses import asdict +from typing import Any, Optional, Tuple + +import hydra_zen +import torch +import wandb +import yaml +from hydra.core.hydra_config import HydraConfig +from hydra_zen import just +from hydra_zen.typing import Partial +from rich.console import Console, Group +from rich.panel import Panel +from rich.pretty import Pretty +from rich.syntax import Syntax +from torch.utils.data import DataLoader, Dataset + +from conf import project as project_conf +from model import TransparentDataParallel +from src.base_tester import BaseTester +from src.base_trainer import BaseTrainer +from utils import load_model_ckpt, to_cuda_ + +console = Console() + + +def make_datasets( + training_mode: bool, seed: int, dataset_partial: Partial[Dataset[Any]] +) -> Tuple[Optional[Dataset[Any]], Optional[Dataset[Any]], Optional[Dataset[Any]]]: + train_dataset: Optional[Dataset[Any]] = None + val_dataset: Optional[Dataset[Any]] = None + test_dataset: Optional[Dataset[Any]] = None + with console.status("Loading datasets...", spinner="monkey"): + if training_mode: + train_dataset = dataset_partial(split="train", seed=seed) + 
val_dataset = dataset_partial(split="val", seed=seed) + else: + test_dataset = dataset_partial(split="test", augment=False, seed=seed) + return train_dataset, val_dataset, test_dataset + + +def make_dataloaders( + data_loader_partial: Partial[DataLoader[Dataset[Any]]], + train_dataset: Optional[Dataset[Any]], + val_dataset: Optional[Dataset[Any]], + test_dataset: Optional[Dataset[Any]], + training_mode: bool, + seed: int, +) -> Tuple[ + Optional[DataLoader[Dataset[Any]]], + Optional[DataLoader[Dataset[Any]]], + Optional[DataLoader[Dataset[Any]]], +]: + generator = None + if project_conf.REPRODUCIBLE: + generator = torch.Generator() + generator.manual_seed(seed) + + train_loader_inst: Optional[DataLoader[Any]] = None + val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None + test_loader_inst: Optional[DataLoader[Any]] = None + if training_mode: + if train_dataset is None or val_dataset is None: + raise ValueError( + "train_dataset and val_dataset must be defined in training mode!" + ) + train_loader_inst = data_loader_partial(train_dataset, generator=generator) + val_loader_inst = data_loader_partial( + val_dataset, generator=generator, shuffle=False, drop_last=False + ) + else: + if test_dataset is None: + raise ValueError("test_dataset must be defined in testing mode!") + test_loader_inst = data_loader_partial( + test_dataset, generator=generator, shuffle=False, drop_last=False + ) + return train_loader_inst, val_loader_inst, test_loader_inst + + +def make_model( + model_partial: Partial[torch.nn.Module], dataset: Partial[Dataset[Any]] +) -> torch.nn.Module: + with console.status("Loading model...", spinner="runner"): + model_inst = model_partial( + encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore + ) # Use just() to get the config out of the Zen-Partial + + return model_inst + + +def parallelize_model(model: torch.nn.Module) -> torch.nn.Module: + console.print( + f"[*] Number of GPUs: {torch.cuda.device_count()}", + style="bold cyan", + ) + if 
torch.cuda.device_count() > 1: + console.print( + f"-> Using {torch.cuda.device_count()} GPUs!", + style="bold cyan", + ) + model = TransparentDataParallel(model) + return model + + +def make_optimizer( + optimizer_partial: Partial[torch.optim.Optimizer], model: torch.nn.Module +) -> torch.optim.Optimizer: + return optimizer_partial(model.parameters()) + + +def make_scheduler( + scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler], + optimizer: torch.optim.Optimizer, + epochs: int, +) -> torch.optim.lr_scheduler.LRScheduler: + scheduler = scheduler_partial( + optimizer + ) # TODO: less hacky way to set T_max for CosineAnnealingLR? + if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR): + scheduler.T_max = epochs + return scheduler + + +def make_training_loss( + training_mode: bool, training_loss_partial: Partial[torch.nn.Module] +): + training_loss: Optional[torch.nn.Module] = None + if training_mode: + training_loss = training_loss_partial() + return training_loss diff --git a/launch_experiment.py b/bootstrap/launch_experiment.py similarity index 60% rename from launch_experiment.py rename to bootstrap/launch_experiment.py index 842e0e4..a67e8eb 100644 --- a/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -8,7 +8,7 @@ import os from dataclasses import asdict -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional, Tuple import hydra_zen import torch @@ -23,6 +23,14 @@ from rich.syntax import Syntax from torch.utils.data import DataLoader, Dataset +from bootstrap.factories import ( + make_dataloaders, + make_datasets, + make_model, + make_optimizer, + make_scheduler, + make_training_loss, +) from conf import project as project_conf from model import TransparentDataParallel from src.base_tester import BaseTester @@ -32,6 +40,7 @@ console = Console() +# =========================================== Printing =========================================== def print_config(run_name: str, exp_conf: str) -> 
None: # Generate a random ANSI code: run_color = f"color({hash(run_name) % 255})" @@ -70,6 +79,9 @@ def print_model(model: torch.nn.Module) -> None: console.rule() +# ================================================================================================= + + def init_wandb( run_name: str, model: torch.nn.Module, @@ -89,112 +101,6 @@ def init_wandb( wandb.watch(model, log=log, log_graph=log_graph) # type: ignore -def make_datasets( - training_mode: bool, seed: int, dataset_partial: Partial[Dataset[Any]] -) -> Tuple[Optional[Dataset[Any]], Optional[Dataset[Any]], Optional[Dataset[Any]]]: - train_dataset: Optional[Dataset[Any]] = None - val_dataset: Optional[Dataset[Any]] = None - test_dataset: Optional[Dataset[Any]] = None - with console.status("Loading datasets...", spinner="monkey"): - if training_mode: - train_dataset = dataset_partial(split="train", seed=seed) - val_dataset = dataset_partial(split="val", seed=seed) - else: - test_dataset = dataset_partial(split="test", augment=False, seed=seed) - return train_dataset, val_dataset, test_dataset - - -def make_dataloaders( - data_loader_partial: Partial[DataLoader[Dataset[Any]]], - train_dataset: Optional[Dataset[Any]], - val_dataset: Optional[Dataset[Any]], - test_dataset: Optional[Dataset[Any]], - training_mode: bool, - seed: int, -) -> Tuple[ - Optional[DataLoader[Dataset[Any]]], - Optional[DataLoader[Dataset[Any]]], - Optional[DataLoader[Dataset[Any]]], -]: - generator = None - if project_conf.REPRODUCIBLE: - generator = torch.Generator() - generator.manual_seed(seed) - - train_loader_inst: Optional[DataLoader[Any]] = None - val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None - test_loader_inst: Optional[DataLoader[Any]] = None - if training_mode: - if train_dataset is None or val_dataset is None: - raise ValueError( - "train_dataset and val_dataset must be defined in training mode!" 
- ) - train_loader_inst = data_loader_partial(train_dataset, generator=generator) - val_loader_inst = data_loader_partial( - val_dataset, generator=generator, shuffle=False, drop_last=False - ) - else: - if test_dataset is None: - raise ValueError("test_dataset must be defined in testing mode!") - test_loader_inst = data_loader_partial( - test_dataset, generator=generator, shuffle=False, drop_last=False - ) - return train_loader_inst, val_loader_inst, test_loader_inst - - -def make_model( - model_partial: Partial[torch.nn.Module], dataset: Partial[Dataset[Any]] -) -> torch.nn.Module: - with console.status("Loading model...", spinner="runner"): - model_inst = model_partial( - encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore - ) # Use just() to get the config out of the Zen-Partial - - return model_inst - - -def parallelize_model(model: torch.nn.Module) -> torch.nn.Module: - console.print( - f"[*] Number of GPUs: {torch.cuda.device_count()}", - style="bold cyan", - ) - if torch.cuda.device_count() > 1: - console.print( - f"-> Using {torch.cuda.device_count()} GPUs!", - style="bold cyan", - ) - model = TransparentDataParallel(model) - return model - - -def make_optimizer( - optimizer_partial: Partial[torch.optim.Optimizer], model: torch.nn.Module -) -> torch.optim.Optimizer: - return optimizer_partial(model.parameters()) - - -def make_scheduler( - scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler], - optimizer: torch.optim.Optimizer, - epochs: int, -) -> torch.optim.lr_scheduler.LRScheduler: - scheduler = scheduler_partial( - optimizer - ) # TODO: less hacky way to set T_max for CosineAnnealingLR? 
- if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR): - scheduler.T_max = epochs - return scheduler - - -def make_training_loss( - training_mode: bool, training_loss_partial: Partial[torch.nn.Module] -): - training_loss: Optional[torch.nn.Module] = None - if training_mode: - training_loss = training_loss_partial() - return training_loss - - def launch_experiment( run, # type: ignore data_loader: Partial[DataLoader[Any]], diff --git a/conf/experiment.py b/conf/experiment.py index 3c8f81f..b48ebf5 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -28,8 +28,8 @@ from unique_names_generator import get_random_name from unique_names_generator.data import ADJECTIVES, NAMES +from bootstrap.launch_experiment import launch_experiment from dataset.example import ExampleDataset -from launch_experiment import launch_experiment from model.example import ExampleModel from src.base_tester import BaseTester from src.base_trainer import BaseTrainer diff --git a/test.py b/test.py index 8b642ab..d5f2171 100755 --- a/test.py +++ b/test.py @@ -9,8 +9,8 @@ from hydra_zen import store, zen import conf.experiment # Must import the config to add all components to the store! # noqa +from bootstrap.launch_experiment import launch_experiment from conf import project as project_conf -from launch_experiment import launch_experiment from utils import seed_everything if __name__ == "__main__": diff --git a/train.py b/train.py index 090f533..24a98b4 100755 --- a/train.py +++ b/train.py @@ -9,8 +9,8 @@ from hydra_zen import store, zen import conf.experiment # Must import the config to add all components to the store! 
# noqa +from bootstrap.launch_experiment import launch_experiment from conf import project as project_conf -from launch_experiment import launch_experiment from utils import seed_everything if __name__ == "__main__": From 6039c0402bf2fcce27afd8c453cc0bfd80924049 Mon Sep 17 00:00:00 2001 From: Theo Date: Wed, 5 Jun 2024 22:50:13 +0100 Subject: [PATCH 03/38] Update project structure in README --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index be3da25..c33e01b 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) [![pre-commit](https://img.shields.io/badge/Pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) [![Vim](https://img.shields.io/badge/VIM%20ready!-forestgreen?style=for-the-badge&logo=vim)](https://github.com/DubiousCactus/bells-and-whistles/blob/main/.vimspector.json) - + A batteries-included PyTorch template with a terminal display that stays out of your way! Click on [Use this @@ -91,6 +91,9 @@ This template writes the necessary boilerplate for you, while staying out of you ``` my-pytorch-project/ + bootstrap/ + factories.py <-- Factory functions for instantiating models, optimizers, etc. + launch_experiment.py <-- Bootstraps the experiment and launches the training/testing loop conf/ experiment.py <-- experiment-level configurations project.py <-- project-level constants @@ -118,10 +121,9 @@ my-pytorch-project/ helpers.py <-- high-level utilities training.py <-- training-related utilities vendor/ - . <-- third-party code goes here - launch_experiment.py <-- Builds the trainer and tester, instantiates all partials, etc. - train.py <-- training entry point (calls launch_experiment) - test.py <-- testing entry point (calls launch_experiment) + . <-- third-party code goes here (github submodules, etc.) 
+ train.py <-- training entry point (calls bootstrap/launch_experiment) + test.py <-- testing entry point (calls bootstrap/launch_experiment) ``` ## Setting up From d26a74c3c2a0df637f6c725605ee5e5144d4ae19 Mon Sep 17 00:00:00 2001 From: Theo Date: Wed, 5 Jun 2024 22:58:56 +0100 Subject: [PATCH 04/38] Refactor experiment building so that we dont rely on import --- bootstrap/factories.py | 14 +---- bootstrap/launch_experiment.py | 5 +- conf/experiment.py | 106 +++++++++++++++++---------------- test.py | 3 +- train.py | 3 +- 5 files changed, 61 insertions(+), 70 deletions(-) diff --git a/bootstrap/factories.py b/bootstrap/factories.py index ecbb6de..ec94400 100644 --- a/bootstrap/factories.py +++ b/bootstrap/factories.py @@ -9,28 +9,16 @@ All factories. """ -import os -from dataclasses import asdict from typing import Any, Optional, Tuple -import hydra_zen import torch -import wandb -import yaml -from hydra.core.hydra_config import HydraConfig from hydra_zen import just from hydra_zen.typing import Partial -from rich.console import Console, Group -from rich.panel import Panel -from rich.pretty import Pretty -from rich.syntax import Syntax +from rich.console import Console from torch.utils.data import DataLoader, Dataset from conf import project as project_conf from model import TransparentDataParallel -from src.base_tester import BaseTester -from src.base_trainer import BaseTrainer -from utils import load_model_ckpt, to_cuda_ console = Console() diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index a67e8eb..6552572 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -8,14 +8,13 @@ import os from dataclasses import asdict -from typing import Any, Optional, Tuple +from typing import Any import hydra_zen import torch import wandb import yaml from hydra.core.hydra_config import HydraConfig -from hydra_zen import just from hydra_zen.typing import Partial from rich.console import Console, Group from rich.panel 
import Panel @@ -30,9 +29,9 @@ make_optimizer, make_scheduler, make_training_loss, + parallelize_model, ) from conf import project as project_conf -from model import TransparentDataParallel from src.base_tester import BaseTester from src.base_trainer import BaseTrainer from utils import load_model_ckpt, to_cuda_ diff --git a/conf/experiment.py b/conf/experiment.py index b48ebf5..e8c9577 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -228,58 +228,60 @@ class RunConfig: tester_store = store(group="tester") tester_store(pbuilds(BaseTester, populate_full_signature=True), name="base") -Experiment = builds( - launch_experiment, - populate_full_signature=True, - hydra_defaults=[ - "_self_", - {"trainer": "base"}, - {"tester": "base"}, - {"dataset": "image_a"}, - {"model": "model_a"}, - {"optimizer": "adam"}, - {"scheduler": "step"}, - {"run": "default"}, - {"training_loss": "mse"}, - ], - trainer=MISSING, - tester=MISSING, - dataset=MISSING, - model=MISSING, - optimizer=MISSING, - scheduler=MISSING, - run=MISSING, - training_loss=MISSING, - data_loader=pbuilds( - DataLoader, builds_bases=(DataloaderConf,) - ), # Needs a partial because we need to set the dataset -) -store(Experiment, name="base_experiment") - -# the experiment configs: -# - must be stored under the _global_ package -# - must inherit from `Experiment` -experiment_store = store(group="experiment", package="_global_") -experiment_store( - make_config( - hydra_defaults=[ - "_self_", - {"override /model": "model_a"}, - {"override /dataset": "image_a"}, - ], - # training=dict(epochs=100), - bases=(Experiment,), - ), - name="exp_a", -) -experiment_store( - make_config( + +def make_experiment_configs(): + Experiment = builds( + launch_experiment, + populate_full_signature=True, hydra_defaults=[ "_self_", - {"override /model": "model_b"}, - {"override /dataset": "image_b"}, + {"trainer": "base"}, + {"tester": "base"}, + {"dataset": "image_a"}, + {"model": "model_a"}, + {"optimizer": "adam"}, + 
{"scheduler": "step"}, + {"run": "default"}, + {"training_loss": "mse"}, ], - bases=(Experiment,), - ), - name="exp_b", -) + trainer=MISSING, + tester=MISSING, + dataset=MISSING, + model=MISSING, + optimizer=MISSING, + scheduler=MISSING, + run=MISSING, + training_loss=MISSING, + data_loader=pbuilds( + DataLoader, builds_bases=(DataloaderConf,) + ), # Needs a partial because we need to set the dataset + ) + store(Experiment, name="base_experiment") + + # the experiment configs: + # - must be stored under the _global_ package + # - must inherit from `Experiment` + experiment_store = store(group="experiment", package="_global_") + experiment_store( + make_config( + hydra_defaults=[ + "_self_", + {"override /model": "model_a"}, + {"override /dataset": "image_a"}, + ], + # training=dict(epochs=100), + bases=(Experiment,), + ), + name="exp_a", + ) + experiment_store( + make_config( + hydra_defaults=[ + "_self_", + {"override /model": "model_b"}, + {"override /dataset": "image_b"}, + ], + bases=(Experiment,), + ), + name="exp_b", + ) diff --git a/test.py b/test.py index d5f2171..3b8c4f9 100755 --- a/test.py +++ b/test.py @@ -8,12 +8,13 @@ from hydra_zen import store, zen -import conf.experiment # Must import the config to add all components to the store! # noqa from bootstrap.launch_experiment import launch_experiment from conf import project as project_conf +from conf.experiment import make_experiment_configs from utils import seed_everything if __name__ == "__main__": + make_experiment_configs() def set_test_mode(cfg): cfg.run.training_mode = False diff --git a/train.py b/train.py index 24a98b4..ea3242e 100755 --- a/train.py +++ b/train.py @@ -8,12 +8,13 @@ from hydra_zen import store, zen -import conf.experiment # Must import the config to add all components to the store! 
# noqa from bootstrap.launch_experiment import launch_experiment from conf import project as project_conf +from conf.experiment import make_experiment_configs from utils import seed_everything if __name__ == "__main__": + make_experiment_configs() "============ Hydra-Zen ============" store.add_to_hydra_store( overwrite_ok=True From d7b348bc93ccbd2e3c2d600a25333d08519d216f Mon Sep 17 00:00:00 2001 From: Theo Date: Wed, 5 Jun 2024 23:04:14 +0100 Subject: [PATCH 05/38] Refactor trainer/tester inst. --- bootstrap/launch_experiment.py | 20 ++++++++++---------- src/base_trainer.py | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index 6552572..960c663 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -146,6 +146,12 @@ def launch_experiment( """ ============ Training ============ """ model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) + common_args = dict( + run_name=run_name, + model=model_inst, + model_ckpt_path=model_ckpt_path, + training_loss=training_loss_inst, + ) if run.training_mode: if training_loss_inst is None: raise ValueError("training_loss must be defined in training mode!") @@ -154,13 +160,11 @@ def launch_experiment( "val_loader and train_loader must be defined in training mode!" ) trainer( - run_name=run_name, - model=model_inst, - opt=opt_inst, - scheduler=scheduler_inst, train_loader=train_loader_inst, val_loader=val_loader_inst, - training_loss=training_loss_inst, + opt=opt_inst, + scheduler=scheduler_inst, + **common_args, **asdict( run ), # Extra stuff if needed. 
You can get them from the trainer's __init__ with kwrags.get(key, default_value) @@ -170,17 +174,13 @@ def launch_experiment( visualize_every=run.viz_every, visualize_train_every=run.viz_train_every, visualize_n_samples=run.viz_num_samples, - model_ckpt_path=model_ckpt_path, ) else: if test_loader_inst is None: raise ValueError("test_loader must be defined in testing mode!") tester( - run_name=run_name, - model=model_inst, data_loader=test_loader_inst, - model_ckpt_path=model_ckpt_path, - training_loss=training_loss_inst, + **common_args, ).test( visualize_every=run.viz_every, **asdict( diff --git a/src/base_trainer.py b/src/base_trainer.py index c1ead4c..92b3116 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -44,6 +44,7 @@ def __init__( train_loader: DataLoader, val_loader: DataLoader, training_loss: Module, + model_ckpt_path: Optional[str] = None, scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, **kwargs: Dict[str, Optional[Union[str, int]]], ) -> None: @@ -71,6 +72,8 @@ def __init__( self._training_loss = training_loss self._viz_n_samples = 1 self._n_ctrl_c = 0 + if model_ckpt_path is not None: + self._load_checkpoint(model_ckpt_path) signal.signal(signal.SIGINT, self._terminator) @to_cuda @@ -258,7 +261,6 @@ def train( visualize_every: int = 10, # Visualize every n validations visualize_train_every: int = 0, # Visualize every n training epochs visualize_n_samples: int = 1, - model_ckpt_path: Optional[str] = None, ): """Train the model for a given number of epochs. 
Args: @@ -268,8 +270,6 @@ def train( Returns: None """ - if model_ckpt_path is not None: - self._load_checkpoint(model_ckpt_path) # console.print(f"[*] Training {self._run_name} for {epochs} epochs", style="bold green") self._viz_n_samples = visualize_n_samples train_losses: List[float] = [] From af36d43f8eb53f31f4efe5603f7f9b9c126381a4 Mon Sep 17 00:00:00 2001 From: Theo Date: Wed, 5 Jun 2024 23:05:10 +0100 Subject: [PATCH 06/38] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 85e1ea6..5b79084 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ tags.* **/*.pickle models/ FIGURES/ +.mypy_cache/ +.ruff_cache/ From de256cfaac75ea45d1f70e93ada650b19c5e0516 Mon Sep 17 00:00:00 2001 From: Theo Date: Thu, 6 Jun 2024 10:20:00 +0100 Subject: [PATCH 07/38] implement Live panel for dataset creation --- bootstrap/factories.py | 23 ++++++++++++++++------- dataset/base/__init__.py | 10 ++++++++-- dataset/base/image.py | 7 ++++++- dataset/example.py | 12 +++++++++--- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/bootstrap/factories.py b/bootstrap/factories.py index ec94400..8d8e00e 100644 --- a/bootstrap/factories.py +++ b/bootstrap/factories.py @@ -9,12 +9,16 @@ All factories. 
""" +from functools import partial from typing import Any, Optional, Tuple import torch from hydra_zen import just from hydra_zen.typing import Partial -from rich.console import Console +from rich.console import Console, Group +from rich.live import Live +from rich.panel import Panel +from rich.progress import Progress, TaskID from torch.utils.data import DataLoader, Dataset from conf import project as project_conf @@ -29,12 +33,17 @@ def make_datasets( train_dataset: Optional[Dataset[Any]] = None val_dataset: Optional[Dataset[Any]] = None test_dataset: Optional[Dataset[Any]] = None - with console.status("Loading datasets...", spinner="monkey"): - if training_mode: - train_dataset = dataset_partial(split="train", seed=seed) - val_dataset = dataset_partial(split="val", seed=seed) - else: - test_dataset = dataset_partial(split="test", augment=False, seed=seed) + status = console.status("Loading dataset...", spinner="monkey") + progress = Progress(transient=True) + with Live(Panel(Group(status, progress), title="Loading datasets")): + splits = ("train", "val") if training_mode else ("test") + for split in splits: + status.update(f"Loading {split} dataset...") + job_id: TaskID = progress.add_task(f"Processing {split} split...") + aug = {"augment": False} if split == "test" else {} + locals()[f"{split}_dataset"] = dataset_partial( + split=split, seed=seed, progress=progress, job_id=job_id, **aug + ) return train_dataset, val_dataset, test_dataset diff --git a/dataset/base/__init__.py b/dataset/base/__init__.py index 57f5f71..5b055c9 100644 --- a/dataset/base/__init__.py +++ b/dataset/base/__init__.py @@ -18,6 +18,7 @@ from typing import Any, Dict, List, Tuple, Union from hydra.utils import get_original_cwd +from rich.progress import Progress, TaskID from torch import Tensor from torch.utils.data import Dataset @@ -31,13 +32,18 @@ def __init__( normalize: bool, split: str, seed: int, + progress: Progress, + job_id: TaskID, debug: bool, tiny: bool = False, ) -> None: 
super().__init__() self._samples: Union[Dict[Any, Any], List[Any], Tensor] self._labels: Union[Dict[Any, Any], List[Any], Tensor] - self._samples, self._labels = self._load(dataset_root, tiny, split, seed) + self._progress = progress + self._samples, self._labels = self._load( + dataset_root, tiny, split, seed, job_id + ) self._augment = augment and split == "train" self._normalize = normalize self._dataset_name = dataset_name @@ -49,7 +55,7 @@ def __init__( @abc.abstractmethod def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int + self, dataset_root: str, tiny: bool, split: str, seed: int, job_id: TaskID ) -> Tuple[ Union[Dict[str, Any], List[Any], Tensor], Union[Dict[str, Any], List[Any], Tensor], diff --git a/dataset/base/image.py b/dataset/base/image.py index e2c01b7..2d03559 100644 --- a/dataset/base/image.py +++ b/dataset/base/image.py @@ -12,6 +12,7 @@ import abc from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from rich.progress import Progress, TaskID from torch import Tensor from torchvision.io.image import read_image # type: ignore from torchvision.transforms import transforms # type: ignore @@ -31,6 +32,8 @@ def __init__( dataset_name: str, split: str, seed: int, + progress: Progress, + job_id: TaskID, img_size: Optional[tuple[int, ...]] = None, augment: bool = False, normalize: bool = False, @@ -44,6 +47,8 @@ def __init__( normalize, split, seed, + progress, + job_id, debug=debug, tiny=tiny, ) @@ -72,7 +77,7 @@ def __init__( @abc.abstractmethod def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int + self, dataset_root: str, tiny: bool, split: str, seed: int, job_id: TaskID ) -> Tuple[ Union[Dict[str, Any], List[Any], Tensor], Union[Dict[str, Any], List[Any], Tensor], diff --git a/dataset/example.py b/dataset/example.py index c0390f8..1902f37 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -14,7 +14,7 @@ from typing import Optional, Tuple, Union import torch -from rich.progress import 
track +from rich.progress import Progress, TaskID from torch import Tensor from dataset.base.image import ImageDataset @@ -29,6 +29,8 @@ def __init__( dataset_name: str, split: str, seed: int, + progress: Progress, + job_id: TaskID, img_dim: Optional[int] = None, augment: bool = False, normalize: bool = False, @@ -41,6 +43,8 @@ def __init__( dataset_name, split, seed, + progress, + job_id, (img_dim, img_dim) if img_dim is not None else None, augment=augment, normalize=normalize, @@ -49,9 +53,11 @@ def __init__( ) def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int + self, dataset_root: str, tiny: bool, split: str, seed: int, job_id: TaskID ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: - for _ in track(range(10), description=f"Loading dataset splt '{split}'"): + self._progress.update(job_id, total=100) + for _ in range(100): + self._progress.advance(job_id) sleep(0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) From 5cdc8908131587b05544faa697df88fb8525f80d Mon Sep 17 00:00:00 2001 From: Theo Date: Thu, 6 Jun 2024 11:01:52 +0100 Subject: [PATCH 08/38] Fix dataset inst --- bootstrap/factories.py | 15 ++++++++------- dataset/example.py | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/bootstrap/factories.py b/bootstrap/factories.py index 8d8e00e..87dfc17 100644 --- a/bootstrap/factories.py +++ b/bootstrap/factories.py @@ -10,7 +10,7 @@ """ from functools import partial -from typing import Any, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch from hydra_zen import just @@ -30,9 +30,11 @@ def make_datasets( training_mode: bool, seed: int, dataset_partial: Partial[Dataset[Any]] ) -> Tuple[Optional[Dataset[Any]], Optional[Dataset[Any]], Optional[Dataset[Any]]]: - train_dataset: Optional[Dataset[Any]] = None - val_dataset: Optional[Dataset[Any]] = None - test_dataset: Optional[Dataset[Any]] = None + datasets: Dict[str, Optional[Dataset[Any]]] = { + 
"train": None, + "val": None, + "test": None, + } status = console.status("Loading dataset...", spinner="monkey") progress = Progress(transient=True) with Live(Panel(Group(status, progress), title="Loading datasets")): @@ -41,10 +43,10 @@ def make_datasets( status.update(f"Loading {split} dataset...") job_id: TaskID = progress.add_task(f"Processing {split} split...") aug = {"augment": False} if split == "test" else {} - locals()[f"{split}_dataset"] = dataset_partial( + datasets[split] = dataset_partial( split=split, seed=seed, progress=progress, job_id=job_id, **aug ) - return train_dataset, val_dataset, test_dataset + return datasets["train"], datasets["val"], datasets["test"] def make_dataloaders( @@ -92,7 +94,6 @@ def make_model( model_inst = model_partial( encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore ) # Use just() to get the config out of the Zen-Partial - return model_inst diff --git a/dataset/example.py b/dataset/example.py index 1902f37..22d21bc 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -55,8 +55,8 @@ def __init__( def _load( self, dataset_root: str, tiny: bool, split: str, seed: int, job_id: TaskID ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: - self._progress.update(job_id, total=100) - for _ in range(100): + self._progress.update(job_id, total=20) + for _ in range(20): self._progress.advance(job_id) sleep(0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) From eb8faa65c99c2702ee375fcd3988b78918f37d00 Mon Sep 17 00:00:00 2001 From: Theo Date: Thu, 6 Jun 2024 11:05:03 +0100 Subject: [PATCH 09/38] Fix pickling of datasets by removing progress attr --- dataset/base/__init__.py | 16 +++++++++++++--- dataset/base/image.py | 8 +++++++- dataset/example.py | 12 +++++++++--- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/dataset/base/__init__.py b/dataset/base/__init__.py index 5b055c9..db56449 100644 --- a/dataset/base/__init__.py +++ b/dataset/base/__init__.py @@ 
-40,9 +40,13 @@ def __init__( super().__init__() self._samples: Union[Dict[Any, Any], List[Any], Tensor] self._labels: Union[Dict[Any, Any], List[Any], Tensor] - self._progress = progress self._samples, self._labels = self._load( - dataset_root, tiny, split, seed, job_id + dataset_root, + tiny, + split, + seed, + progress, + job_id, ) self._augment = augment and split == "train" self._normalize = normalize @@ -55,7 +59,13 @@ def __init__( @abc.abstractmethod def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int, job_id: TaskID + self, + dataset_root: str, + tiny: bool, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, ) -> Tuple[ Union[Dict[str, Any], List[Any], Tensor], Union[Dict[str, Any], List[Any], Tensor], diff --git a/dataset/base/image.py b/dataset/base/image.py index 2d03559..cc3c21b 100644 --- a/dataset/base/image.py +++ b/dataset/base/image.py @@ -77,7 +77,13 @@ def __init__( @abc.abstractmethod def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int, job_id: TaskID + self, + dataset_root: str, + tiny: bool, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, ) -> Tuple[ Union[Dict[str, Any], List[Any], Tensor], Union[Dict[str, Any], List[Any], Tensor], diff --git a/dataset/example.py b/dataset/example.py index 22d21bc..e94a450 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -53,11 +53,17 @@ def __init__( ) def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int, job_id: TaskID + self, + dataset_root: str, + tiny: bool, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: - self._progress.update(job_id, total=20) + progress.update(job_id, total=20) for _ in range(20): - self._progress.advance(job_id) + progress.advance(job_id) sleep(0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) From 48eb97e54806b281b3fc3c3abdbe9c88e7376df8 Mon Sep 17 00:00:00 2001 
From: Theo Date: Tue, 25 Jun 2024 14:56:46 +0100 Subject: [PATCH 10/38] Use fancy new dataset mixins --- conf/experiment.py | 11 +- dataset/base/__init__.py | 80 ------ dataset/base/image.py | 27 +- dataset/example.py | 80 +++++- dataset/mixins/__init__.py | 491 +++++++++++++++++++++++++++++++++++++ 5 files changed, 576 insertions(+), 113 deletions(-) delete mode 100644 dataset/base/__init__.py create mode 100644 dataset/mixins/__init__.py diff --git a/conf/experiment.py b/conf/experiment.py index e8c9577..858a0cf 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -29,7 +29,7 @@ from unique_names_generator.data import ADJECTIVES, NAMES from bootstrap.launch_experiment import launch_experiment -from dataset.example import ExampleDataset +from dataset.example import SingleProcessingExampleDataset from model.example import ExampleModel from src.base_tester import BaseTester from src.base_trainer import BaseTrainer @@ -65,18 +65,19 @@ class ExampleDatasetConf: normalize: bool = True augment: bool = False debug: bool = False - img_dim: int = ExampleDataset.IMG_SIZE[0] + img_dim: int = SingleProcessingExampleDataset.IMG_SIZE[0] # Pre-set the group for store's dataset entries dataset_store = store(group="dataset") dataset_store( - pbuilds(ExampleDataset, builds_bases=(ExampleDatasetConf,)), name="image_a" + pbuilds(SingleProcessingExampleDataset, builds_bases=(ExampleDatasetConf,)), + name="image_a", ) dataset_store( pbuilds( - ExampleDataset, + SingleProcessingExampleDataset, builds_bases=(ExampleDatasetConf,), dataset_root="data/b", img_dim=64, @@ -85,7 +86,7 @@ class ExampleDatasetConf: ) dataset_store( pbuilds( - ExampleDataset, + SingleProcessingExampleDataset, builds_bases=(ExampleDatasetConf,), tiny=True, ), diff --git a/dataset/base/__init__.py b/dataset/base/__init__.py deleted file mode 100644 index db56449..0000000 --- a/dataset/base/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -#! 
/usr/bin/env python3 -# vim:fenc=utf-8 -# -# Copyright © 2023 Théo Morales -# -# Distributed under terms of the MIT license. - -""" -Base dataset. -In this file you may implement other base datasets that share the same characteristics and which -need the same data loading + transformation pipeline. The specificities of loading the data or -transforming it may be extended through class inheritance in a specific dataset file. -""" - -import abc -import os -import os.path as osp -from typing import Any, Dict, List, Tuple, Union - -from hydra.utils import get_original_cwd -from rich.progress import Progress, TaskID -from torch import Tensor -from torch.utils.data import Dataset - - -class BaseDataset(Dataset[Any], abc.ABC): - def __init__( - self, - dataset_root: str, - dataset_name: str, - augment: bool, - normalize: bool, - split: str, - seed: int, - progress: Progress, - job_id: TaskID, - debug: bool, - tiny: bool = False, - ) -> None: - super().__init__() - self._samples: Union[Dict[Any, Any], List[Any], Tensor] - self._labels: Union[Dict[Any, Any], List[Any], Tensor] - self._samples, self._labels = self._load( - dataset_root, - tiny, - split, - seed, - progress, - job_id, - ) - self._augment = augment and split == "train" - self._normalize = normalize - self._dataset_name = dataset_name - self._debug = debug - self._cache_dir = osp.join( - get_original_cwd(), "data", f"{dataset_name}_preprocessed" - ) - os.makedirs(self._cache_dir, exist_ok=True) - - @abc.abstractmethod - def _load( - self, - dataset_root: str, - tiny: bool, - split: str, - seed: int, - progress: Progress, - job_id: TaskID, - ) -> Tuple[ - Union[Dict[str, Any], List[Any], Tensor], - Union[Dict[str, Any], List[Any], Tensor], - ]: - # Implement this - raise NotImplementedError - - def __len__(self) -> int: - return len(self._samples) - - def disable_augs(self) -> None: - self._augment = False diff --git a/dataset/base/image.py b/dataset/base/image.py index cc3c21b..208f60d 100644 --- 
a/dataset/base/image.py +++ b/dataset/base/image.py @@ -10,17 +10,18 @@ """ import abc -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from rich.progress import Progress, TaskID from torch import Tensor +from torch.utils.data import Dataset from torchvision.io.image import read_image # type: ignore from torchvision.transforms import transforms # type: ignore -from dataset.base import BaseDataset +from dataset.mixins import BaseDatasetMixin -class ImageDataset(BaseDataset, abc.ABC): +class ImageDataset(BaseDatasetMixin, Dataset): IMAGE_NET_MEAN: List[float] = [] IMAGE_NET_STD: List[float] = [] COCO_MEAN, COCO_STD = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) @@ -47,10 +48,10 @@ def __init__( normalize, split, seed, + debug, + tiny, progress, job_id, - debug=debug, - tiny=tiny, ) self._img_size = self.IMG_SIZE if img_size is None else img_size self._transforms: Callable[[Tensor], Tensor] = transforms.Compose( @@ -75,22 +76,6 @@ def __init__( ] ) - @abc.abstractmethod - def _load( - self, - dataset_root: str, - tiny: bool, - split: str, - seed: int, - progress: Progress, - job_id: TaskID, - ) -> Tuple[ - Union[Dict[str, Any], List[Any], Tensor], - Union[Dict[str, Any], List[Any], Tensor], - ]: - # Implement this - raise NotImplementedError - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """ This should be common to all image datasets! 
diff --git a/dataset/example.py b/dataset/example.py index e94a450..01593fc 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -11,7 +11,7 @@ """ from time import sleep -from typing import Optional, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union import torch from rich.progress import Progress, TaskID @@ -20,7 +20,7 @@ from dataset.base.image import ImageDataset -class ExampleDataset(ImageDataset): +class SingleProcessingExampleDataset(ImageDataset): IMG_SIZE = (32, 32) def __init__( @@ -37,7 +37,6 @@ def __init__( tiny: bool = False, debug: bool = False, ) -> None: - self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim super().__init__( dataset_root, dataset_name, @@ -51,13 +50,14 @@ def __init__( debug=debug, tiny=tiny, ) + self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim + self._samples, self._labels = self._load( + progress, + job_id, + ) def _load( self, - dataset_root: str, - tiny: bool, - split: str, - seed: int, progress: Progress, job_id: TaskID, ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: @@ -69,3 +69,69 @@ def _load( def __getitem__(self, index: int): return self._samples[index], self._labels[index] + + +class MultiProcessingExampleDataset(ImageDataset): + IMG_SIZE = (32, 32) + + def __init__( + self, + dataset_root: str, + dataset_name: str, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, + img_dim: Optional[int] = None, + augment: bool = False, + normalize: bool = False, + tiny: bool = False, + debug: bool = False, + ) -> None: + self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim + super().__init__( + dataset_root, + dataset_name, + split, + seed, + progress, + job_id, + (img_dim, img_dim) if img_dim is not None else None, + augment=augment, + normalize=normalize, + debug=debug, + tiny=tiny, + ) + + +class MultiProcessingWithCachingExampleDataset(ImageDataset): + IMG_SIZE = (32, 32) + + def __init__( + self, + dataset_root: str, + 
dataset_name: str, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, + img_dim: Optional[int] = None, + augment: bool = False, + normalize: bool = False, + tiny: bool = False, + debug: bool = False, + ) -> None: + self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim + super().__init__( + dataset_root, + dataset_name, + split, + seed, + progress, + job_id, + (img_dim, img_dim) if img_dim is not None else None, + augment=augment, + normalize=normalize, + debug=debug, + tiny=tiny, + ) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py new file mode 100644 index 0000000..4ff3bf4 --- /dev/null +++ b/dataset/mixins/__init__.py @@ -0,0 +1,491 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +# +# Copyright © 2023 Théo Morales +# +# Distributed under terms of the MIT license. + +""" +Base dataset. +In this file you may implement other base datasets that share the same characteristics and which +need the same data loading + transformation pipeline. The specificities of loading the data or +transforming it may be extended through class inheritance in a specific dataset file. 
+""" + +import abc +import hashlib +import inspect +import itertools +import os +import os.path as osp +import pickle +import shutil +from multiprocessing.pool import Pool +from os import cpu_count +from typing import Any, Callable, List, Optional, Sequence, Tuple + +from hydra.utils import get_original_cwd +from rich.progress import Progress, TaskID +from tqdm import tqdm + +from utils.helpers import compressed_read, compressed_write + +# TODO: Create a named tuple for progress and job_id + + +class DatasetMixinInterface(abc.ABC): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + **kwargs, + ): + _ = dataset_root + _ = dataset_name + _ = augment + _ = normalize + _ = split + _ = seed + _ = debug + _ = tiny + _ = progress + _ = job_id + super().__init__(**kwargs) + + +class BaseDatasetMixin(DatasetMixinInterface): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + **kwargs, + ): + self._samples, self._labels = None, None + self._augment = augment and split == "train" + self._normalize = normalize + self._dataset_name = dataset_name + self._debug = debug + super().__init__( + dataset_root, + dataset_name, + augment, + normalize, + split, + seed, + debug, + tiny, + progress, + job_id, + **kwargs, + ) + + def __len__(self) -> int: + if self._samples is None or self._labels is None: + raise ValueError( + "Dataset not loaded. Please make sure to assign self._samples and self._labels." 
+ ) + return len(self._samples) + + def disable_augs(self) -> None: + self._augment = False + + +class SafeCacheDatasetMixin(DatasetMixinInterface): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + scd_lazy: bool = True, + **kwargs, + ): + self._cache_dir = osp.join( + get_original_cwd(), + "data", + f"{dataset_name}_preprocessed", + f"{'tiny_' if tiny else ''}{split}", + ) + self._split = split + self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) + # TODO: Compute fingerprint of dataset parameters, data source and *most importantly* + # code implementation of _load method and every other user function called from it!!! + # (could we use git for that? Like just querying git diff on the dataset implementation + # file) + # If a fingerprint is found in self._cache_dir, compare it to the current fingerprint. If + # they differ, flush the cache and recompute. If not, load the dataset from cache. If no + # fingerprint is found, store the current fingerprint. + argnames = inspect.getfullargspec(SafeCacheDatasetMixin.__init__).args + argvalues = { + k: v + for k, v in inspect.getargvalues(inspect.currentframe()).locals.items() + if k in argnames and k not in ["self", "tiny", "scd_lazy"] + } + hasher = hashlib.new("md5") + # TODO: We should also hash the locals of the user's class __init__ method! + hasher.update(pickle.dumps(argvalues)) + # TODO: Make sure the comments of the user's methods are not included in the fingerprint, + # and make sure to recursively hash the source code of the user's methods. 
+ hasher.update(pickle.dumps(inspect.getsource(self._load))) + hasher.update(pickle.dumps(inspect.getsource(self._get_raw_elements))) + self.fingerprint = hasher.hexdigest() + print(f"Fingerprint: {self.fingerprint}") + mismatch, not_found = False, False + if osp.isfile(osp.join(self._cache_dir, "fingerprint")): + with open(osp.join(self._cache_dir, "fingerprint"), "r") as f: + cached_fingerprint = f.read() + if cached_fingerprint != self.fingerprint: + mismatch = True + else: + not_found = True + if mismatch or not_found: + print( + ("Fingerprint mismatch" if mismatch else "No fingerprint found") + + ", flushing cache." + ) + shutil.rmtree(self._cache_dir, ignore_errors=True) + os.makedirs(self._cache_dir, exist_ok=True) + super().__init__( + dataset_root, + dataset_name, + augment, + normalize, + split, + seed, + debug, + tiny, + progress, + job_id, + **kwargs, + ) + with open(osp.join(self._cache_dir, "fingerprint"), "w") as f: + f.write(self.fingerprint) + + def _get_raw_elements_hook(self, *args) -> Sequence[Any]: + class LazyCacheSequence(Sequence): + def __init__(self, cache_dir: str, seq_len: int, seq_type: str): + self._cache_paths = [] + self._seq_len = seq_len + self._seq_type = seq_type + + for i in itertools.count(): + cache_path = osp.join(cache_dir, f"{i:04d}.pkl") + if not osp.isfile(cache_path): + break + self._cache_paths.append(cache_path) + + if len(self._cache_paths) != seq_len: + raise ValueError( + f"Cache info file {osp.join(cache_dir, 'info.txt')} does not match the number of " + + f"cache files in {cache_dir}. " + + "This may be due to an interrupted dataset computation. " + + "Please manually flush the cash to recompute." + ) + + el_type_str = "unknown" + try: + el_type_str = str(type(compressed_read(self._cache_paths[0]))) + except Exception: + pass + + if el_type_str != seq_type: + raise ValueError( + f"Cache info file {osp.join(cache_dir, 'info.txt')} does not match the type of " + + f"cache files in {cache_dir}. 
" + + "This may be due to an interrupted dataset computation. " + + "Please manually flush the cash to recompute." + ) + + def __len__(self): + return self._seq_len + + def __getitem__(self, idx): + if idx >= self._seq_len: + raise IndexError + return compressed_read(self._cache_paths[idx]) + + # This hooks onto the user's _get_raw_elements method and overrides it if a cache entry is + # found. If not it just calls the user's _get_raw_elements method. + # return self._get_raw_elements(*args, **kwargs) + path = osp.join(self._cache_dir, "raw_elements") + try: + info = [None, None] + if not osp.isfile(osp.join(path, "info.txt")): + raise FileNotFoundError( + f"Cache info file not found at {osp.join(path, 'info.txt')}." + ) + with open(osp.join(path, "info.txt"), "r") as f: + info = f.readlines() + if len(info) != 2 or not info[0].strip().isdigit(): + raise ValueError( + f"Invalid cache info file {osp.join(path, 'info.txt')}." + ) + raw_elements = LazyCacheSequence( + path, int(info[0].strip()), info[1].strip() + ) + except FileNotFoundError: + if not hasattr(self, "_get_raw_elements"): + raise NotImplementedError( + "SafeCacheDatasetMixin._get_raw_elements() is called but the user has not " + + f"implemented a _get_raw_elements method in {self.__class__.__name__}." 
+ ) + # Compute them: + raw_elements: Sequence[Any] = self._get_raw_elements(*args, **kwargs) + type_str = "unknown" + try: + type_str = type(raw_elements[0]) + except Exception: + pass + # Cache them: + os.makedirs(path, exist_ok=True) + with open(osp.join(path, "info.txt"), "w") as f: + f.writelines([f"{len(raw_elements)}\n", f"{type_str}"]) + + print(f"going to cache {len(raw_elements)} elements") + for i, element in enumerate(raw_elements): + compressed_write(osp.join(path, f"{i:04d}.pkl"), element) + print(f"cached {len(raw_elements)} elements") + # TODO: Rich.log("Raw elements cached here: <>") + return raw_elements + + def _load_hook(self, *args, **kwargs) -> Tuple[int, Any, Any]: + # This hooks onto the user's _load method and overrides it if a cache entry is found. If + # not it just calls the user's _load method. + idx = args[1] + try: + sample, label = self._load_sample_label(idx) + except KeyError: + if not hasattr(self, "_load"): + raise NotImplementedError( + "SafeCacheDatasetMixin._load() is called but the user has not implemented " + + f"a _load method in {self.__class__.__name__}." + ) + _idx, sample, label = self._load(*args, **kwargs) + assert _idx == idx + return idx, sample, label + + def _load_sample_label(self, idx: int) -> Tuple[Any, Any]: + if hasattr(super(), "_load_sample_label"): + raise Exception( + "SafeCacheDatasetMixin._load_sample_label() is overriden. " + + "As best practice, you should inherit from SafeCacheDatasetMixin " + + f"after {super().__class__.__name__} to avoid unwanted behavior." 
+ ) + cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl") + if not osp.isfile(cache_path): + raise KeyError(f"Cache file {cache_path} not found, will recompute.") + return compressed_read(cache_path) + + def _register_sample_label( + self, + idx: int, + sample: Any, + label: Any, + memory_samples: List[Any], + memory_labels: List[Any], + ): + if hasattr(super(), "_register_sample_label"): + raise Exception( + "SafeCacheDatasetMixin._register_sample_label() is overriden. " + + "As best practice, you should inherit from SafeCacheDatasetMixin " + + f"after {super().__class__.__name__} to avoid unwanted behavior." + ) + # TODO: Implement safe caching here? Well all the caching logic should move to a caching + # module, but should be hooked here. So basically *ALL* dataset parameters, as well as the + # data source, should be hashed into a fingerprint. When the fingerprint is different than + # the cache's, we flush the cache and recompute. Now, a common use case of mine is to have + # several stages of data pre-processing, with one typically faster than the other but still + # slow enough to be cached. In this case, it would be handy to have a decorator to wrap the + # user's methods. But the _load method should probably be cached by default, no? 
+ + # Let's do some temporary caching for now: + cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl") + compressed_write(cache_path, (sample, label)) + memory_samples.insert(idx, cache_path) + memory_labels.insert(idx, cache_path) + + +class MultiProcessingDatasetMixin(DatasetMixinInterface, abc.ABC): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + mpd_lazy: bool = True, + mpd_chunk_size: int = 1, + mpd_processes: Optional[int] = None, + **kwargs, + ) -> None: + super().__init__( + dataset_root, + dataset_name, + augment, + normalize, + split, + seed, + debug, + tiny, + progress, + job_id, + **kwargs, + ) + self._samples, self._labels = [], [] + processes = mpd_processes or (cpu_count() - 1) + + with Pool(processes) as pool: + raw_elements = self._get_raw_elements_hook( + dataset_root, tiny, split, seed, progress, job_id + ) + if raw_elements[0] is None or raw_elements[0] is None: + raise ValueError( + "The _get_raw_elements method returned None or a sequence of None. " + ) + + if len(raw_elements) == 0: + raise ValueError( + "The _get_raw_elements method must return a sequence of elements. " + + "If the dataset is empty, return an empty list." 
+ ) + + pool_dispatch_method: Callable = pool.imap if mpd_lazy else pool.starmap + pool_dispatch_func: Callable = ( + self._load_hook_unpack if mpd_lazy else self._load_hook + ) + for idx, sample, label in tqdm( + pool_dispatch_method( + pool_dispatch_func, + zip( + raw_elements, + # itertools.count(len(raw_elements)), + range(len(raw_elements)), + itertools.repeat(tiny), + itertools.repeat(split), + itertools.repeat(seed), + ), # TODO: Pass the progress and job_id here + chunksize=mpd_chunk_size, + ), + total=len(raw_elements), + ): + self._register_sample_label(idx, sample, label) + print(f"{'Lazy' if mpd_lazy else 'Eager'} loaded {len(self._samples)} samples.") + + def _get_raw_elements_hook(self, *args): + if hasattr(super(), "_get_raw_elements_hook"): + return super()._get_raw_elements_hook(*args) + else: + return self._get_raw_elements(*args) + + def _load_hook_unpack(self, args): + # TODO: Fix this mess. It should not just load the samples and labels when we're calling + # the load hook from cache. + return self._load_hook(*args) + + def _load_hook(self, *args) -> Tuple[int, Any, Any]: + if hasattr(super(), "_load_hook"): + return super()._load_hook(*args) + else: + return self._load(*args) # TODO: Rename to _load_sample? 
+ + def _load_sample_label(self, idx: int) -> Tuple[Any, Any]: + if hasattr(super(), "_load_sample_label"): + return super()._load_sample_label(idx) + return self._samples[idx], self._labels[idx] + + def _register_sample_label(self, idx: int, sample: Any, label: Any): + if hasattr(super(), "_register_sample_label"): + return super()._register_sample_label( + idx, sample, label, self._samples, self._labels + ) + if isinstance(sample, (List, Tuple)) or isinstance(label, (List, Tuple)): + raise NotImplementedError( + "_register_sample_label cannot yet handle lists of samples/labels" + ) + self._samples.insert(idx, sample) + self._labels.insert(idx, label) + + @abc.abstractmethod + def _get_raw_elements( + self, + dataset_root: str, + tiny: bool, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, + ) -> Sequence[Any]: + # Implement this + raise NotImplementedError + + @abc.abstractmethod + def _load( + self, + element: Any, + idx: int, + tiny: bool, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, + ) -> Tuple[int, Any, Any]: + # Implement this + raise NotImplementedError + + +class BatchedTensorsMultiprocessingDatasetMixin(DatasetMixinInterface, abc.ABC): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + **kwargs, + ) -> None: + super().__init__( + dataset_root, + dataset_name, + augment, + normalize, + split, + seed, + debug, + tiny, + progress, + job_id, + **kwargs, + ) + raise NotImplementedError("This is not implemented yet.") From f7fac2ac7983121b917bc03f67d4581282c6ebb7 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 25 Jun 2024 14:59:40 +0100 Subject: [PATCH 11/38] Use runner instead of monkey --- bootstrap/factories.py | 2 +- dataset/base/image.py | 3 +-- dataset/example.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/bootstrap/factories.py 
b/bootstrap/factories.py index 87dfc17..296757c 100644 --- a/bootstrap/factories.py +++ b/bootstrap/factories.py @@ -35,7 +35,7 @@ def make_datasets( "val": None, "test": None, } - status = console.status("Loading dataset...", spinner="monkey") + status = console.status("Loading dataset...", spinner="runner") progress = Progress(transient=True) with Live(Panel(Group(status, progress), title="Loading datasets")): splits = ("train", "val") if training_mode else ("test") diff --git a/dataset/base/image.py b/dataset/base/image.py index 208f60d..66e1e59 100644 --- a/dataset/base/image.py +++ b/dataset/base/image.py @@ -9,8 +9,7 @@ Base dataset for images. """ -import abc -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple from rich.progress import Progress, TaskID from torch import Tensor diff --git a/dataset/example.py b/dataset/example.py index 01593fc..6eff78d 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -11,7 +11,7 @@ """ from time import sleep -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Tuple, Union import torch from rich.progress import Progress, TaskID From 410d4b9f9f385db26569076025528c5f4e48d075 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 2 Jul 2024 23:57:21 +0100 Subject: [PATCH 12/38] Fix bugs and improve mixins --- dataset/mixins/__init__.py | 135 ++++++++++++++++++++----------------- 1 file changed, 72 insertions(+), 63 deletions(-) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 4ff3bf4..ca37e1b 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -30,7 +30,7 @@ from utils.helpers import compressed_read, compressed_write -# TODO: Create a named tuple for progress and job_id +# TODO: Can we speed all of this up with Cython or Numba? 
class DatasetMixinInterface(abc.ABC): @@ -76,7 +76,7 @@ def __init__( job_id: TaskID, **kwargs, ): - self._samples, self._labels = None, None + self._samples, self._labels = [], [] self._augment = augment and split == "train" self._normalize = normalize self._dataset_name = dataset_name @@ -96,16 +96,13 @@ def __init__( ) def __len__(self) -> int: - if self._samples is None or self._labels is None: - raise ValueError( - "Dataset not loaded. Please make sure to assign self._samples and self._labels." - ) return len(self._samples) def disable_augs(self) -> None: self._augment = False +# TODO: Automatically "remove" this mixin if debug=True. Idk how, maybe metaclass? class SafeCacheDatasetMixin(DatasetMixinInterface): def __init__( self, @@ -138,20 +135,26 @@ def __init__( # they differ, flush the cache and recompute. If not, load the dataset from cache. If no # fingerprint is found, store the current fingerprint. argnames = inspect.getfullargspec(SafeCacheDatasetMixin.__init__).args + frame = inspect.currentframe() + if frame is None: + raise RuntimeError("Cannot compute fingerprint without a frame.") argvalues = { k: v - for k, v in inspect.getargvalues(inspect.currentframe()).locals.items() + for k, v in inspect.getargvalues(frame).locals.items() if k in argnames and k not in ["self", "tiny", "scd_lazy"] } hasher = hashlib.new("md5") # TODO: We should also hash the locals of the user's class __init__ method! hasher.update(pickle.dumps(argvalues)) # TODO: Make sure the comments of the user's methods are not included in the fingerprint, - # and make sure to recursively hash the source code of the user's methods. - hasher.update(pickle.dumps(inspect.getsource(self._load))) - hasher.update(pickle.dumps(inspect.getsource(self._get_raw_elements))) + # and make sure to recursively hash the source code of the user's methods. For the former, + # we could use the inspect module, for the latter we could use the ast module or a regex + # with inspect.getcodelines(). 
+ # NOTE: getsource() won't work if I have a decorator that wraps the method. I think it's + # best to keep this behaviour and not use decorators. + hasher.update(pickle.dumps(inspect.getsource(self.__class__._load))) # type: ignore + hasher.update(pickle.dumps(inspect.getsource(self.__class__._get_raw_elements))) # type: ignore self.fingerprint = hasher.hexdigest() - print(f"Fingerprint: {self.fingerprint}") mismatch, not_found = False, False if osp.isfile(osp.join(self._cache_dir, "fingerprint")): with open(osp.join(self._cache_dir, "fingerprint"), "r") as f: @@ -160,13 +163,24 @@ def __init__( mismatch = True else: not_found = True - if mismatch or not_found: - print( - ("Fingerprint mismatch" if mismatch else "No fingerprint found") - + ", flushing cache." - ) + + flush = False + if mismatch: + while flush not in ["y", "n"]: + flush = input("Fingerprint mismatch, flush cache? (y/n) ").lower() + flush = flush.lower().strip() == "y" + if not flush: + print( + "[!] Warning: Fingerprint mismatch, but cache will not be flushed." 
+ ) + + if not_found: + print("No fingerprint found, flushing cache.") + flush = True + + if flush: shutil.rmtree(self._cache_dir, ignore_errors=True) - os.makedirs(self._cache_dir, exist_ok=True) + os.makedirs(self._cache_dir, exist_ok=True) super().__init__( dataset_root, dataset_name, @@ -180,17 +194,19 @@ def __init__( job_id, **kwargs, ) - with open(osp.join(self._cache_dir, "fingerprint"), "w") as f: - f.write(self.fingerprint) + if flush or not_found: + with open(osp.join(self._cache_dir, "fingerprint"), "w") as f: + f.write(self.fingerprint) - def _get_raw_elements_hook(self, *args) -> Sequence[Any]: + def _get_raw_elements_hook(self, *args, **kwargs) -> Sequence[Any]: + # TODO: Investigate slowness issues class LazyCacheSequence(Sequence): def __init__(self, cache_dir: str, seq_len: int, seq_type: str): self._cache_paths = [] self._seq_len = seq_len self._seq_type = seq_type - for i in itertools.count(): + for i in itertools.count(): # TODO: Could this be slow? cache_path = osp.join(cache_dir, f"{i:04d}.pkl") if not osp.isfile(cache_path): break @@ -252,7 +268,7 @@ def __getitem__(self, idx): + f"implemented a _get_raw_elements method in {self.__class__.__name__}." ) # Compute them: - raw_elements: Sequence[Any] = self._get_raw_elements(*args, **kwargs) + raw_elements: Sequence[Any] = self._get_raw_elements(*args, **kwargs) # type: ignore type_str = "unknown" try: type_str = type(raw_elements[0]) @@ -274,16 +290,20 @@ def _load_hook(self, *args, **kwargs) -> Tuple[int, Any, Any]: # This hooks onto the user's _load method and overrides it if a cache entry is found. If # not it just calls the user's _load method. 
idx = args[1] - try: - sample, label = self._load_sample_label(idx) - except KeyError: + cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl") + if osp.isfile(cache_path): + sample, label = None, None + else: if not hasattr(self, "_load"): raise NotImplementedError( "SafeCacheDatasetMixin._load() is called but the user has not implemented " + f"a _load method in {self.__class__.__name__}." ) - _idx, sample, label = self._load(*args, **kwargs) - assert _idx == idx + _idx, sample, label = self._load(*args, **kwargs) # type: ignore + if _idx != idx: + raise ValueError( + "The _load method returned an index different from the one requested." + ) return idx, sample, label def _load_sample_label(self, idx: int) -> Tuple[Any, Any]: @@ -312,17 +332,13 @@ def _register_sample_label( + "As best practice, you should inherit from SafeCacheDatasetMixin " + f"after {super().__class__.__name__} to avoid unwanted behavior." ) - # TODO: Implement safe caching here? Well all the caching logic should move to a caching - # module, but should be hooked here. So basically *ALL* dataset parameters, as well as the - # data source, should be hashed into a fingerprint. When the fingerprint is different than - # the cache's, we flush the cache and recompute. Now, a common use case of mine is to have - # several stages of data pre-processing, with one typically faster than the other but still - # slow enough to be cached. In this case, it would be handy to have a decorator to wrap the - # user's methods. But the _load method should probably be cached by default, no? - - # Let's do some temporary caching for now: cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl") - compressed_write(cache_path, (sample, label)) + if not osp.isfile(cache_path): + if sample is None: + raise ValueError( + "The _load_hook method returned sample=None, but no cache entry was found. 
" + ) + compressed_write(cache_path, (sample, label)) memory_samples.insert(idx, cache_path) memory_labels.insert(idx, cache_path) @@ -359,11 +375,17 @@ def __init__( **kwargs, ) self._samples, self._labels = [], [] - processes = mpd_processes or (cpu_count() - 1) + cpus = cpu_count() + processes = ( + 1 if debug else (mpd_processes or ((cpus - 1) if cpus is not None else 0)) + ) with Pool(processes) as pool: raw_elements = self._get_raw_elements_hook( - dataset_root, tiny, split, seed, progress, job_id + dataset_root, + tiny, + split, + seed, ) if raw_elements[0] is None or raw_elements[0] is None: raise ValueError( @@ -390,7 +412,7 @@ def __init__( itertools.repeat(tiny), itertools.repeat(split), itertools.repeat(seed), - ), # TODO: Pass the progress and job_id here + ), chunksize=mpd_chunk_size, ), total=len(raw_elements), @@ -398,31 +420,31 @@ def __init__( self._register_sample_label(idx, sample, label) print(f"{'Lazy' if mpd_lazy else 'Eager'} loaded {len(self._samples)} samples.") - def _get_raw_elements_hook(self, *args): + def _get_raw_elements_hook( + self, dataset_root: str, tiny: bool, split: str, seed: int + ): if hasattr(super(), "_get_raw_elements_hook"): - return super()._get_raw_elements_hook(*args) + return super()._get_raw_elements_hook(dataset_root, tiny, split, seed) # type: ignore else: - return self._get_raw_elements(*args) + return self._get_raw_elements(dataset_root, tiny, split, seed) def _load_hook_unpack(self, args): - # TODO: Fix this mess. It should not just load the samples and labels when we're calling - # the load hook from cache. return self._load_hook(*args) def _load_hook(self, *args) -> Tuple[int, Any, Any]: if hasattr(super(), "_load_hook"): - return super()._load_hook(*args) + return super()._load_hook(*args) # type: ignore else: return self._load(*args) # TODO: Rename to _load_sample? 
def _load_sample_label(self, idx: int) -> Tuple[Any, Any]: if hasattr(super(), "_load_sample_label"): - return super()._load_sample_label(idx) + return super()._load_sample_label(idx) # type: ignore return self._samples[idx], self._labels[idx] def _register_sample_label(self, idx: int, sample: Any, label: Any): if hasattr(super(), "_register_sample_label"): - return super()._register_sample_label( + return super()._register_sample_label( # type: ignore idx, sample, label, self._samples, self._labels ) if isinstance(sample, (List, Tuple)) or isinstance(label, (List, Tuple)): @@ -434,27 +456,14 @@ def _register_sample_label(self, idx: int, sample: Any, label: Any): @abc.abstractmethod def _get_raw_elements( - self, - dataset_root: str, - tiny: bool, - split: str, - seed: int, - progress: Progress, - job_id: TaskID, + self, dataset_root: str, tiny: bool, split: str, seed: int ) -> Sequence[Any]: # Implement this raise NotImplementedError @abc.abstractmethod def _load( - self, - element: Any, - idx: int, - tiny: bool, - split: str, - seed: int, - progress: Progress, - job_id: TaskID, + self, element: Any, idx: int, tiny: bool, split: str, seed: int ) -> Tuple[int, Any, Any]: # Implement this raise NotImplementedError From f312f5242bcbc02a6bc3845cf63efa9f4e606d95 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 2 Jul 2024 23:58:19 +0100 Subject: [PATCH 13/38] Pass the right arguments to mixins --- dataset/example.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/dataset/example.py b/dataset/example.py index 6eff78d..9d0b363 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -71,7 +71,7 @@ def __getitem__(self, index: int): return self._samples[index], self._labels[index] -class MultiProcessingExampleDataset(ImageDataset): +class MultiProcessingExampleDataset(ImageDataset): # TODO IMG_SIZE = (32, 32) def __init__( @@ -92,19 +92,19 @@ def __init__( super().__init__( dataset_root, dataset_name, + augment, + normalize, split, 
seed, + debug, + tiny, progress, job_id, (img_dim, img_dim) if img_dim is not None else None, - augment=augment, - normalize=normalize, - debug=debug, - tiny=tiny, ) -class MultiProcessingWithCachingExampleDataset(ImageDataset): +class MultiProcessingWithCachingExampleDataset(ImageDataset): # TODO IMG_SIZE = (32, 32) def __init__( @@ -125,13 +125,13 @@ def __init__( super().__init__( dataset_root, dataset_name, + augment, + normalize, split, seed, + debug, + tiny, progress, job_id, (img_dim, img_dim) if img_dim is not None else None, - augment=augment, - normalize=normalize, - debug=debug, - tiny=tiny, ) From aaf1ea6edd04b6fb07e01f8d53ba8f197e590334 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 13 Jul 2024 18:28:14 +0100 Subject: [PATCH 14/38] Greatly improve safe caching Compute code and argument fingerprints the right way --- dataset/mixins/__init__.py | 69 +++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index ca37e1b..eda36a4 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -127,15 +127,15 @@ def __init__( ) self._split = split self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) - # TODO: Compute fingerprint of dataset parameters, data source and *most importantly* - # code implementation of _load method and every other user function called from it!!! - # (could we use git for that? Like just querying git diff on the dataset implementation - # file) - # If a fingerprint is found in self._cache_dir, compare it to the current fingerprint. If - # they differ, flush the cache and recompute. If not, load the dataset from cache. If no - # fingerprint is found, store the current fingerprint. 
- argnames = inspect.getfullargspec(SafeCacheDatasetMixin.__init__).args + argnames = inspect.getfullargspec(self.__class__.__init__).args + found = False frame = inspect.currentframe() + while not found: + frame = frame.f_back + found = ( + frame.f_code.co_qualname.strip() + == f"{self.__class__.__qualname__}.__init__".strip() + ) if frame is None: raise RuntimeError("Cannot compute fingerprint without a frame.") argvalues = { @@ -143,44 +143,44 @@ def __init__( for k, v in inspect.getargvalues(frame).locals.items() if k in argnames and k not in ["self", "tiny", "scd_lazy"] } - hasher = hashlib.new("md5") - # TODO: We should also hash the locals of the user's class __init__ method! - hasher.update(pickle.dumps(argvalues)) - # TODO: Make sure the comments of the user's methods are not included in the fingerprint, - # and make sure to recursively hash the source code of the user's methods. For the former, - # we could use the inspect module, for the latter we could use the ast module or a regex - # with inspect.getcodelines(). + # TODO: Recursively hash the source code for user's methods in self.__class__ # NOTE: getsource() won't work if I have a decorator that wraps the method. I think it's # best to keep this behaviour and not use decorators. 
- hasher.update(pickle.dumps(inspect.getsource(self.__class__._load))) # type: ignore - hasher.update(pickle.dumps(inspect.getsource(self.__class__._get_raw_elements))) # type: ignore - self.fingerprint = hasher.hexdigest() - mismatch, not_found = False, False - if osp.isfile(osp.join(self._cache_dir, "fingerprint")): - with open(osp.join(self._cache_dir, "fingerprint"), "r") as f: - cached_fingerprint = f.read() - if cached_fingerprint != self.fingerprint: - mismatch = True - else: - not_found = True + fingerprint_els = {"code": hashlib.new("md5"), "args": hashlib.new("md5")} + tree = ast.parse(inspect.getsource(self.__class__)) + fingerprint_els["code"].update(ast.dump(tree).encode()) + fingerprint_els["args"].update(pickle.dumps(argvalues)) + for k, v in fingerprint_els.items(): + fingerprint_els[k] = v.hexdigest() + mismatches, not_found = {k: True for k in fingerprint_els}, True + if osp.isfile(osp.join(self._cache_dir, "fingerprints")): + with open(osp.join(self._cache_dir, "fingerprints"), "r") as f: + not_found = False + cached_fingerprints = f.readlines() + for line in cached_fingerprints: + key, value = line.split(":") + mismatches[key] = value.strip() != fingerprint_els[key] + mismatch_list = [k for k, v in mismatches.items() if v] flush = False - if mismatch: + if not_found: + print("No fingerprint found, flushing cache.") + flush = True + elif mismatch_list != []: while flush not in ["y", "n"]: - flush = input("Fingerprint mismatch, flush cache? (y/n) ").lower() + flush = input( + f"Fingerprint mismatch in {' and '.join(mismatch_list)}, flush cache? (y/n) " + ).lower() flush = flush.lower().strip() == "y" if not flush: print( "[!] Warning: Fingerprint mismatch, but cache will not be flushed." 
) - if not_found: - print("No fingerprint found, flushing cache.") - flush = True - if flush: shutil.rmtree(self._cache_dir, ignore_errors=True) os.makedirs(self._cache_dir, exist_ok=True) + super().__init__( dataset_root, dataset_name, @@ -195,8 +195,9 @@ def __init__( **kwargs, ) if flush or not_found: - with open(osp.join(self._cache_dir, "fingerprint"), "w") as f: - f.write(self.fingerprint) + with open(osp.join(self._cache_dir, "fingerprints"), "w") as f: + for k, v in fingerprint_els.items(): + f.write(f"{k}:{v}\n") def _get_raw_elements_hook(self, *args, **kwargs) -> Sequence[Any]: # TODO: Investigate slowness issues From 70f9401d187e0c92d5ebe8db0463b0943910efb9 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 13 Jul 2024 18:39:49 +0100 Subject: [PATCH 15/38] Fix potential bugs --- dataset/mixins/__init__.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index eda36a4..0fc9d9d 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -13,6 +13,7 @@ """ import abc +import ast import hashlib import inspect import itertools @@ -22,6 +23,7 @@ import shutil from multiprocessing.pool import Pool from os import cpu_count +from types import FrameType from typing import Any, Callable, List, Optional, Sequence, Tuple from hydra.utils import get_original_cwd @@ -129,15 +131,19 @@ def __init__( self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) argnames = inspect.getfullargspec(self.__class__.__init__).args found = False - frame = inspect.currentframe() + frame: FrameType | None = inspect.currentframe() while not found: - frame = frame.f_back + frame = frame.f_back if frame is not None else None + if frame is None: + break found = ( frame.f_code.co_qualname.strip() == f"{self.__class__.__qualname__}.__init__".strip() ) if frame is None: - raise RuntimeError("Cannot compute fingerprint without a frame.") + raise RuntimeError( + f"Could 
not find frame for {self.__class__.__qualname__}.__init__" + ) argvalues = { k: v for k, v in inspect.getargvalues(frame).locals.items() @@ -151,7 +157,7 @@ def __init__( fingerprint_els["code"].update(ast.dump(tree).encode()) fingerprint_els["args"].update(pickle.dumps(argvalues)) for k, v in fingerprint_els.items(): - fingerprint_els[k] = v.hexdigest() + fingerprint_els[k] = v.hexdigest() # type: ignore mismatches, not_found = {k: True for k in fingerprint_els}, True if osp.isfile(osp.join(self._cache_dir, "fingerprints")): with open(osp.join(self._cache_dir, "fingerprints"), "r") as f: From 15dab5ee99e1afeebaacc20a247f3a43abeeaf03 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 16 Jul 2024 23:36:27 +0100 Subject: [PATCH 16/38] Implement a *very* basic GUI --- utils/gui.py | 205 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 utils/gui.py diff --git a/utils/gui.py b/utils/gui.py new file mode 100644 index 0000000..4e17c14 --- /dev/null +++ b/utils/gui.py @@ -0,0 +1,205 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +# +# Copyright © 2024 Théo Morales +# +# Distributed under terms of the MIT license. + +""" +The fancy new GUI. 
+""" + +from collections import abc +from functools import partial +from time import sleep +from typing import Callable, Iterable, Iterator, Sequence + +from rich.layout import Layout +from rich.live import Live +from rich.panel import Panel +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskID, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, +) +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import MNIST +from torchvision.transforms.functional import to_tensor + + +class GUI: + def __init__(self) -> None: + self._layout = Layout() + self._layout.split( + Layout(name="header", size=2), + Layout(name="main", ratio=1), + Layout(name="footer", size=2), + ) + self._layout["main"].split_row( + Layout(name="body", ratio=3, minimum_size=60), + Layout(name="side"), + ) + self._live = Live(self._layout, screen=True) + self._console = self._live.console + self._pbar = Progress( + SpinnerColumn(spinner_name="monkey"), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + console=self._live.console, + # expand=True, + ) + self._main_progress = Panel( + self._pbar, + title="Training epoch ?/?", + expand=True, + ) + self._layout["footer"].update(self._pbar) + self._layout["header"].update(Panel("Stuff", title="Training run ...")) + self.tasks = { + "training": self._pbar.add_task("Training", visible=False), + "validation": self._pbar.add_task("Validation", visible=False), + "testing": self._pbar.add_task("Testing", visible=False), + } + + @property + def console(self): + return self._console + + def open(self) -> None: + self._live.__enter__() + + def close(self) -> None: + self._live.__exit__(None, None, None) + + def _track_iterable(self, iterable, task, total) -> Iterable: + class SeqWrapper(abc.Iterator): + def __init__( + self, + seq: Sequence, + len: int, + main_progress, + update_hook: Callable, + reset_hook: Callable, + ): + 
self._sequence = seq + self._idx = 0 + self._len = len + self.__main_progress = main_progress + self._update_hook = update_hook + self._reset_hook = reset_hook + + def __next__(self): + if self._idx >= self._len: + self._reset_hook() + raise StopIteration + item = self._sequence[self._idx] + self._update_hook() + # self.__main_progress.update(self.__pbar) + self._idx += 1 + return item + + class IteratorWrapper(abc.Iterator): + def __init__( + self, + iterator: Iterator | DataLoader, + len: int, + main_progress, + update_hook: Callable, + reset_hook: Callable, + ): + self._iterator = iter(iterator) + self._len = len + self.__main_progress = main_progress + self._update_hook = update_hook + self._reset_hook = reset_hook + + def __next__(self): + try: + item = next(self._iterator) + self._update_hook() + return item + except StopIteration: + self._reset_hook() + raise StopIteration + + def update_hook(task_id: TaskID): + self._pbar.advance(task_id) + + def reset_hook(task_id: TaskID, total: int): + self._pbar.reset(task_id, total=total, visible=False) + + wrapper = None + update_p, reset_p = ( + partial(update_hook, task), + partial(reset_hook, task, total), + ) + if isinstance(iterable, abc.Sequence): + wrapper = SeqWrapper( + iterable, + total, + self._main_progress, + update_p, + reset_p, + ) + elif isinstance(iterable, (abc.Iterator, DataLoader)): + wrapper = IteratorWrapper( + iterable, + total, + self._main_progress, + update_p, + reset_p, + ) + else: + raise ValueError( + f"iterable must be a Sequence or an Iterator, got {type(iterable)}" + ) + self._pbar.reset(task, total=total, visible=True) + return wrapper + + def track_training(self, iterable, description: str, total: int) -> Iterable: + task = self.tasks["training"] + return self._track_iterable(iterable, task, total) + + def track_validation(self, iterable, description: str, total: int) -> Iterable: + task = self.tasks["validation"] + return self._track_iterable(iterable, task, total) + + def 
print_footer(self, text: str): + self._layout["footer"].update(text) + + def print_header(self, text: str): + self._layout["header"].update(text) + + def print(self, text: str): + """ + Print text to the side panel. + """ + # NOTE: We could use a table to append messages in the renderable. I don't really know of + # another way to print stuff in a specific panel. + self._layout["side"].update(Panel(text, title="Logs")) + + +if __name__ == "__main__": + mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) + dataloader = DataLoader(mnist, 32, shuffle=True) + gui = GUI() + gui.open() + try: + for i, e in enumerate(gui.track_training(range(10), "Training", 10)): + gui.print(f"{i}/10") + sleep(0.1) + for i, e in enumerate( + gui.track_validation(dataloader, "Validation", len(dataloader)) + ): + gui.print(e) # TODO: Make this work! + gui.print(f"{i}/{len(dataloader)}") + sleep(0.01) + except Exception: + gui.close() + gui.close() From 8d40563a82933ba573e1002fcdebe475517547c4 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 23 Jul 2024 12:23:33 +0100 Subject: [PATCH 17/38] Implement brand new GUI! Create a Rich GUI for training and testing, and massively refactor all UI-related stuff (plotting, progress bars, etc.) --- .gitignore | 1 + bootstrap/factories.py | 3 +- bootstrap/launch_experiment.py | 88 ++++++---- dataset/example.py | 7 +- dataset/mixins/__init__.py | 1 + src/base_tester.py | 59 ++++--- src/base_trainer.py | 155 +++++++++--------- utils/__init__.py | 46 +++--- utils/gui.py | 290 ++++++++++++++++++++++++++++----- 9 files changed, 451 insertions(+), 199 deletions(-) diff --git a/.gitignore b/.gitignore index 5b79084..0ae14f3 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ models/ FIGURES/ .mypy_cache/ .ruff_cache/ +utils/data/MNIST/raw diff --git a/bootstrap/factories.py b/bootstrap/factories.py index 296757c..1afde4a 100644 --- a/bootstrap/factories.py +++ b/bootstrap/factories.py @@ -9,7 +9,6 @@ All factories. 
""" -from functools import partial from typing import Any, Dict, Optional, Tuple import torch @@ -38,7 +37,7 @@ def make_datasets( status = console.status("Loading dataset...", spinner="runner") progress = Progress(transient=True) with Live(Panel(Group(status, progress), title="Loading datasets")): - splits = ("train", "val") if training_mode else ("test") + splits = ("train", "val") if training_mode else ("test",) for split in splits: status.update(f"Loading {split} dataset...") job_id: TaskID = progress.add_task(f"Processing {split} split...") diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index 960c663..b476b59 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -8,6 +8,7 @@ import os from dataclasses import asdict +from time import sleep from typing import Any import hydra_zen @@ -35,6 +36,7 @@ from src.base_tester import BaseTester from src.base_trainer import BaseTrainer from utils import load_model_ckpt, to_cuda_ +from utils.gui import GUI console = Console() @@ -145,45 +147,63 @@ def launch_experiment( init_wandb(run_name, model_inst, exp_conf) """ ============ Training ============ """ + console.print( + "Launching GUI...", + style="bold cyan", + ) + sleep(1) + gui = GUI(run_name, project_conf.LOG_SCALE_PLOT) model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) common_args = dict( run_name=run_name, model=model_inst, model_ckpt_path=model_ckpt_path, training_loss=training_loss_inst, + gui=gui, ) - if run.training_mode: - if training_loss_inst is None: - raise ValueError("training_loss must be defined in training mode!") - if val_loader_inst is None or train_loader_inst is None: - raise ValueError( - "val_loader and train_loader must be defined in training mode!" 
+ gui.open() + try: + if run.training_mode: + gui.print("Training started!") + if training_loss_inst is None: + raise ValueError("training_loss must be defined in training mode!") + if val_loader_inst is None or train_loader_inst is None: + raise ValueError( + "val_loader and train_loader must be defined in training mode!" + ) + trainer( + train_loader=train_loader_inst, + val_loader=val_loader_inst, + opt=opt_inst, + scheduler=scheduler_inst, + **common_args, + **asdict( + run + ), # Extra stuff if needed. You can get them from the trainer's __init__ with kwrags.get(key, default_value) + ).train( + epochs=run.epochs, + val_every=run.val_every, + visualize_every=run.viz_every, + visualize_train_every=run.viz_train_every, + visualize_n_samples=run.viz_num_samples, ) - trainer( - train_loader=train_loader_inst, - val_loader=val_loader_inst, - opt=opt_inst, - scheduler=scheduler_inst, - **common_args, - **asdict( - run - ), # Extra stuff if needed. You can get them from the trainer's __init__ with kwrags.get(key, default_value) - ).train( - epochs=run.epochs, - val_every=run.val_every, - visualize_every=run.viz_every, - visualize_train_every=run.viz_train_every, - visualize_n_samples=run.viz_num_samples, - ) - else: - if test_loader_inst is None: - raise ValueError("test_loader must be defined in testing mode!") - tester( - data_loader=test_loader_inst, - **common_args, - ).test( - visualize_every=run.viz_every, - **asdict( - run - ), # Extra stuff if needed. You can get them from the trainer's __init__ with kwrags.get(key, default_value) - ) + gui.print("Training finished!") + else: + gui.print("Testing started!") + if test_loader_inst is None: + raise ValueError("test_loader must be defined in testing mode!") + tester( + data_loader=test_loader_inst, + **common_args, + ).test( + visualize_every=run.viz_every, + **asdict( + run + ), # Extra stuff if needed. 
You can get them from the trainer's __init__ with kwrags.get(key, default_value) + ) + gui.print("Testing finished!") + except Exception as e: + gui.close() + raise e + finally: + gui.close() diff --git a/dataset/example.py b/dataset/example.py index 9d0b363..1c8e622 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -61,10 +61,11 @@ def _load( progress: Progress, job_id: TaskID, ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: - progress.update(job_id, total=20) - for _ in range(20): + len = 3 if self._tiny else 20 + progress.update(job_id, total=len) + for _ in range(len): progress.advance(job_id) - sleep(0.1) + sleep(0.001 if self._tiny else 0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) def __getitem__(self, index: int): diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 0fc9d9d..7af001f 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -83,6 +83,7 @@ def __init__( self._normalize = normalize self._dataset_name = dataset_name self._debug = debug + self._tiny = tiny super().__init__( dataset_root, dataset_name, diff --git a/src/base_tester.py b/src/base_tester.py index c575ceb..fde5f59 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -13,26 +13,36 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +import torch +from rich.console import Console +from rich.text import Text from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader from torchmetrics import MeanMetric -from tqdm import tqdm from conf import project as project_conf from src.base_trainer import BaseTrainer -from utils import to_cuda, update_pbar_str +from utils import to_cuda +from utils.gui import GUI T = TypeVar("T") +console = Console() + +global print +print = console.print + + class BaseTester(BaseTrainer): def __init__( self, + gui: GUI, run_name: str, data_loader: DataLoader[T], model: 
Module, - model_ckpt_path: str, + model_ckpt_path: Optional[str] = None, training_loss: Optional[Module] = None, **kwargs: Optional[Dict[str, Any]], ) -> None: @@ -45,13 +55,19 @@ def __init__( """ _args = kwargs _loss = training_loss + self._gui = gui + global print + print = self._gui.print self._run_name = run_name self._model = model - assert model_ckpt_path is not None, "No model checkpoint path provided." - self._load_checkpoint(model_ckpt_path, model_only=True) + if model_ckpt_path is None: + print(Text("No model checkpoint path provided!", style="bold red")) + else: + print(Text("Loading model checkpoint...", style="bold cyan")) + self._load_checkpoint(model_ckpt_path, model_only=True) self._data_loader = data_loader self._running = True - self._pbar = tqdm(total=len(self._data_loader), desc="Testing") + # self._pbar = tqdm(total=len(self._data_loader), desc="Testing") signal.signal(signal.SIGINT, self._terminator) @to_cuda @@ -68,7 +84,7 @@ def _visualize( def _test_iteration( self, batch: Union[Tuple, List, Tensor], - ) -> Dict[str, Tensor]: + ) -> Tuple[Tensor, Dict[str, Tensor]]: """Evaluation procedure for one batch. We want to keep the code DRY and avoid making mistakes, so this code calls the BaseTrainer._train_val_iteration() method. Args: @@ -79,7 +95,7 @@ def _test_iteration( x, y = batch # type: ignore # noqa y_hat = self._model(x) # type: ignore # noqa # TODO: Compute your metrics here! 
- return {} + return torch.tensor(torch.inf), {} def test( self, visualize_every: int = 0, **kwargs: Optional[Dict[str, Any]] @@ -93,11 +109,14 @@ def test( test_loss: MeanMetric = MeanMetric() test_metrics: Dict[str, MeanMetric] = defaultdict(MeanMetric) self._model.eval() - self._pbar.reset() - self._pbar.set_description("Testing") + # self._pbar.reset() + # self._pbar.set_description("Testing") color_code = project_conf.ANSI_COLORS[project_conf.Theme.TESTING.value] """ ==================== Training loop for one epoch ==================== """ - for i, batch in enumerate(self._data_loader): + pbar, update_loss_hook = self._gui.track_testing( + self._data_loader, total=len(self._data_loader) + ) + for i, batch in enumerate(pbar): if not self._running: print("[!] Testing aborted.") break @@ -105,21 +124,23 @@ def test( test_loss.update(loss.item()) for k, v in metrics.items(): test_metrics[k].update(v.item()) - update_pbar_str( - self._pbar, - f"Testing [loss={test_loss.compute():.4f}]", - color_code, - ) + update_loss_hook(test_loss.compute()) + # update_pbar_str( + # self._pbar, + # f"Testing [loss={test_loss.compute():.4f}]", + # color_code, + # ) """ ==================== Visualization ==================== """ if visualize_every > 0 and (i + 1) % visualize_every == 0: self._visualize(batch, i) - self._pbar.update() - self._pbar.close() + # self._pbar.update() + # self._pbar.close() + # TODO: Report metrics in a special panel? Then hang the GUI until the user is done. 
print("=" * 81) print("==" + " " * 31 + " Test results " + " " * 31 + "==") print("=" * 81) for k, v in test_metrics.items(): print(f"\t -> {k}: {v.compute().item():.2f}") - print(f"\t -> Average loss: {test_loss:.4f}") + print(f"\t -> Average loss: {test_loss.compute():.4f}") print("_" * 81) diff --git a/src/base_trainer.py b/src/base_trainer.py index 92b3116..98fef27 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -15,12 +15,10 @@ from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union -import plotext as plt import torch import wandb from hydra.core.hydra_config import HydraConfig from rich.console import Console -from rich.progress import track from torch import Tensor from torch.nn import Module from torch.optim import Optimizer @@ -29,15 +27,20 @@ from conf import project as project_conf from utils import to_cuda +from utils.gui import GUI from utils.helpers import BestNModelSaver from utils.training import visualize_model_predictions console = Console() +global print +print = console.print + class BaseTrainer: def __init__( self, + gui: GUI, run_name: str, model: Module, opt: Optimizer, @@ -72,6 +75,9 @@ def __init__( self._training_loss = training_loss self._viz_n_samples = 1 self._n_ctrl_c = 0 + self._gui = gui + global print + print = self._gui.print if model_ckpt_path is not None: self._load_checkpoint(model_ckpt_path) signal.signal(signal.SIGINT, self._terminator) @@ -132,11 +138,11 @@ def _train_epoch( color_code = project_conf.ANSI_COLORS[project_conf.Theme.TRAINING.value] has_visualized = 0 """ ==================== Training loop for one epoch ==================== """ - for i, batch in track( - enumerate(self._train_loader), - description=description, + pbar, update_loss_hook = self._gui.track_training( + self._train_loader, total=len(self._train_loader), - ): + ) + for i, batch in enumerate(pbar): if ( not self._running and project_conf.SIGINT_BEHAVIOR @@ -153,6 +159,7 @@ def _train_epoch( 
epoch_loss.update(loss.item()) for k, v in loss_components.items(): epoch_loss_components[k].update(v.item()) + update_loss_hook(epoch_loss.compute()) # update_pbar_str( # self._pbar, # f"{description} [loss={epoch_loss.compute():.4f} /" @@ -194,11 +201,11 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: with torch.no_grad(): val_loss: MeanMetric = MeanMetric() val_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) - for i, batch in track( - enumerate(self._val_loader), - description=description, + pbar, update_loss_hook = self._gui.track_validation( + self._val_loader, total=len(self._val_loader), - ): + ) + for i, batch in enumerate(pbar): if ( not self._running and project_conf.SIGINT_BEHAVIOR @@ -215,6 +222,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: val_loss.update(loss.item()) for k, v in loss_components.items(): val_loss_components[k].update(v.item()) + update_loss_hook(val_loss.compute()) # update_pbar_str( # self._pbar, # f"{description} [loss={val_loss.compute():.4f} /" @@ -305,9 +313,10 @@ def train( if self._scheduler is not None: self._scheduler.step() """ ==================== Plotting ==================== """ - if project_conf.PLOT_ENABLED: - self._plot(epoch, train_losses, val_losses) - self._pbar.close() + self._gui.plot(epoch, train_losses, val_losses, self._model_saver) + # if project_conf.PLOT_ENABLED: + # self._plot(epoch, train_losses, val_losses) + # self._pbar.close() self._save_checkpoint( val_losses[-1], os.path.join(HydraConfig.get().runtime.output_dir, "last.ckpt"), @@ -318,67 +327,67 @@ def train( + f"at epoch {self._model_saver.min_val_loss_epoch}." 
) - @staticmethod - def _setup_plot(run_name: str, log_scale: bool = False): - """Setup the plot for training and validation losses.""" - plt.title(f"Training curves for {run_name}") - plt.theme("dark") - plt.xlabel("Epoch") - if log_scale: - plt.ylabel("Loss (log scale)") - plt.yscale("log") - else: - plt.ylabel("Loss") - plt.grid(True, True) + # @staticmethod + # def _setup_plot(run_name: str, log_scale: bool = False): + # """Setup the plot for training and validation losses.""" + # plt.title(f"Training curves for {run_name}") + # plt.theme("dark") + # plt.xlabel("Epoch") + # if log_scale: + # plt.ylabel("Loss (log scale)") + # plt.yscale("log") + # else: + # plt.ylabel("Loss") + # plt.grid(True, True) - def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): - """Plot the training and validation losses. - Args: - epoch (int): Current epoch number. - train_losses (List[float]): List of training losses. - val_losses (List[float]): List of validation losses. - Returns: - None - """ - plt.clf() - if project_conf.LOG_SCALE_PLOT and any( - loss_val <= 0 for loss_val in train_losses + val_losses - ): - raise ValueError( - "Cannot plot on a log scale if there are non-positive losses." 
- ) - self._setup_plot(self._run_name, log_scale=project_conf.LOG_SCALE_PLOT) - plt.plot( - list(range(self._starting_epoch, epoch + 1)), - train_losses, - color=project_conf.Theme.TRAINING.value, - label="Training loss", - ) - plt.plot( - list(range(self._starting_epoch, epoch + 1)), - val_losses, - color=project_conf.Theme.VALIDATION.value, - label="Validation loss", - ) - best_metrics = ( - "[" - + ", ".join( - [ - f"{metric_name}={metric_value:.2e} " - for metric_name, metric_value in self._model_saver.best_metrics.items() - ] - ) - + "]" - ) - plt.scatter( - [self._model_saver.min_val_loss_epoch], - [self._model_saver.min_val_loss], - color="red", - marker="+", - label=f"Best model {best_metrics}", - style="inverted", - ) - plt.show() + # def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): + # """Plot the training and validation losses. + # Args: + # epoch (int): Current epoch number. + # train_losses (List[float]): List of training losses. + # val_losses (List[float]): List of validation losses. + # Returns: + # None + # """ + # plt.clf() + # if project_conf.LOG_SCALE_PLOT and any( + # loss_val <= 0 for loss_val in train_losses + val_losses + # ): + # raise ValueError( + # "Cannot plot on a log scale if there are non-positive losses." 
+ # ) + # self._setup_plot(self._run_name, log_scale=project_conf.LOG_SCALE_PLOT) + # plt.plot( + # list(range(self._starting_epoch, epoch + 1)), + # train_losses, + # color=project_conf.Theme.TRAINING.value, + # label="Training loss", + # ) + # plt.plot( + # list(range(self._starting_epoch, epoch + 1)), + # val_losses, + # color=project_conf.Theme.VALIDATION.value, + # label="Validation loss", + # ) + # best_metrics = ( + # "[" + # + ", ".join( + # [ + # f"{metric_name}={metric_value:.2e} " + # for metric_name, metric_value in self._model_saver.best_metrics.items() + # ] + # ) + # + "]" + # ) + # plt.scatter( + # [self._model_saver.min_val_loss_epoch], + # [self._model_saver.min_val_loss], + # color="red", + # marker="+", + # label=f"Best model {best_metrics}", + # style="inverted", + # ) + # plt.show() def _save_checkpoint(self, val_loss: float, ckpt_path: str, **kwargs) -> None: """Saves the model and optimizer state to a checkpoint file. diff --git a/utils/__init__.py b/utils/__init__.py index 680ef44..a650dfa 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -99,29 +99,29 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return wrapper -def blink_pbar(i: int, pbar: tqdm, n: int) -> None: - """Blink the progress bar every n iterations. - Args: - i (int): current iteration - pbar (tqdm): progress bar - n (int): blink every n iterations - """ - if i % n == 0: - pbar.colour = ( - project_conf.Theme.TRAINING.value - if pbar.colour == project_conf.Theme.VALIDATION.value - else project_conf.Theme.VALIDATION.value - ) - - -def update_pbar_str(pbar: tqdm, string: str, color_code: int) -> None: - """Update the progress bar string. - Args: - pbar (tqdm): progress bar - string (str): string to update the progress bar with - color_code (int): color code for the string - """ - pbar.set_description_str(colorize(string, color_code)) +# def blink_pbar(i: int, pbar: tqdm, n: int) -> None: +# """Blink the progress bar every n iterations. 
+# Args: +# i (int): current iteration +# pbar (tqdm): progress bar +# n (int): blink every n iterations +# """ +# if i % n == 0: +# pbar.colour = ( +# project_conf.Theme.TRAINING.value +# if pbar.colour == project_conf.Theme.VALIDATION.value +# else project_conf.Theme.VALIDATION.value +# ) +# +# +# def update_pbar_str(pbar: tqdm, string: str, color_code: int) -> None: +# """Update the progress bar string. +# Args: +# pbar (tqdm): progress bar +# string (str): string to update the progress bar with +# color_code (int): color code for the string +# """ +# pbar.set_description_str(colorize(string, color_code)) def get_function_frame(func, exc_traceback): diff --git a/utils/gui.py b/utils/gui.py index 4e17c14..0680c58 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -9,13 +9,30 @@ The fancy new GUI. """ -from collections import abc +import random +from collections import abc, namedtuple from functools import partial from time import sleep -from typing import Callable, Iterable, Iterator, Sequence +from typing import ( + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + TypeVar, +) +import plotext as plt +import torch +from rich import box +from rich.ansi import AnsiDecoder +from rich.console import Group +from rich.jupyter import JupyterMixin from rich.layout import Layout from rich.live import Live +from rich.padding import Padding from rich.panel import Panel from rich.progress import ( BarColumn, @@ -26,13 +43,46 @@ TextColumn, TimeRemainingColumn, ) +from rich.style import Style +from rich.table import Table +from rich.text import Text +from torch import Tensor from torch.utils.data.dataloader import DataLoader from torchvision.datasets import MNIST from torchvision.transforms.functional import to_tensor +if __name__ != "__main__": + from conf import project as project_conf + from utils.helpers import BestNModelSaver +else: + project_conf = namedtuple("project_conf", "Theme")( + Theme=namedtuple("Theme", "TRAINING VALIDATION TESTING")( + 
TRAINING=namedtuple("value", "value")("blue"), + VALIDATION=namedtuple("value", "value")("green"), + TESTING=namedtuple("value", "value")("cyan"), + ) + ) + BestNModelSaver = TypeVar("BestNModelSaver") + + +class PlotextMixin(JupyterMixin): + def __init__(self, p_make_plot): + self.decoder = AnsiDecoder() + self.mk_plot = p_make_plot + + def __rich_console__(self, console, options): + self.width = options.max_width or console.width + self.height = options.height or console.height + canvas = self.mk_plot(width=self.width, height=self.height) + self.rich_canvas = Group(*self.decoder.decode(canvas)) + yield self.rich_canvas + class GUI: - def __init__(self) -> None: + def __init__(self, run_name: str, plot_log_scale: bool) -> None: + self._run_name = run_name + self._plot_log_scale = plot_log_scale + self._starting_epoch = 0 # TODO: self._layout = Layout() self._layout.split( Layout(name="header", size=2), @@ -43,28 +93,70 @@ def __init__(self) -> None: Layout(name="body", ratio=3, minimum_size=60), Layout(name="side"), ) + self._plot = Panel( + Padding( + Text( + "Waiting for training curves...", + justify="center", + style=Style(color="blue", bold=True), + ), + pad=30, + style="on black", + ), + title="Training curves", + expand=True, + ) + self._layout["body"].update(self._plot) self._live = Live(self._layout, screen=True) self._console = self._live.console self._pbar = Progress( SpinnerColumn(spinner_name="monkey"), - TextColumn("[progress.description]{task.description}"), + TextColumn( + "[progress.description]{task.description} \[loss={task.fields[loss]:.3f}]" + ), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), - console=self._live.console, # expand=True, ) - self._main_progress = Panel( + self._main_progress = Panel( # TODO: self._pbar, title="Training epoch ?/?", expand=True, ) self._layout["footer"].update(self._pbar) - self._layout["header"].update(Panel("Stuff", title="Training run ...")) + run_color = f"color({hash(run_name) % 255})" + 
background_color = f"color({(hash(run_name) + 128) % 255})" + self._layout["header"].update( + Text( + f"Running {run_name}", + style=f"bold {run_color} on {background_color}", + justify="center", + ) + ) + self._logger = Table.grid(padding=0) + self._logger.add_column(no_wrap=True) + self._layout["side"].update( + Panel( + self._logger, title="Logs", border_style="bright_red", box=box.ROUNDED + ) + ) self.tasks = { - "training": self._pbar.add_task("Training", visible=False), - "validation": self._pbar.add_task("Validation", visible=False), - "testing": self._pbar.add_task("Testing", visible=False), + "training": self._pbar.add_task( + f"[{project_conf.Theme.TRAINING.value}]Training", + visible=False, + loss=torch.inf, + ), + "validation": self._pbar.add_task( + f"[{project_conf.Theme.VALIDATION.value}]Validation", + visible=False, + loss=torch.inf, + ), + "testing": self._pbar.add_task( + f"[{project_conf.Theme.TESTING.value}]Testing", + visible=False, + loss=torch.inf, + ), } @property @@ -77,20 +169,26 @@ def open(self) -> None: def close(self) -> None: self._live.__exit__(None, None, None) - def _track_iterable(self, iterable, task, total) -> Iterable: - class SeqWrapper(abc.Iterator): + def _track_iterable(self, iterable, task, total) -> Tuple[Iterable, Callable]: + class LossHook: + def __init__(self): + self._loss = None + + def update_loss_hook(self, loss: float): + self._loss = loss + + class SeqWrapper(abc.Iterator, LossHook): def __init__( self, seq: Sequence, len: int, - main_progress, update_hook: Callable, reset_hook: Callable, ): + super().__init__() self._sequence = seq self._idx = 0 self._len = len - self.__main_progress = main_progress self._update_hook = update_hook self._reset_hook = reset_hook @@ -99,51 +197,54 @@ def __next__(self): self._reset_hook() raise StopIteration item = self._sequence[self._idx] - self._update_hook() - # self.__main_progress.update(self.__pbar) + self._update_hook(loss=self._loss) self._idx += 1 return item - class 
IteratorWrapper(abc.Iterator): + class IteratorWrapper(abc.Iterator, LossHook): def __init__( self, iterator: Iterator | DataLoader, len: int, - main_progress, update_hook: Callable, reset_hook: Callable, ): + super().__init__() self._iterator = iter(iterator) self._len = len - self.__main_progress = main_progress self._update_hook = update_hook self._reset_hook = reset_hook def __next__(self): try: item = next(self._iterator) - self._update_hook() + self._update_hook(loss=self._loss) return item except StopIteration: self._reset_hook() raise StopIteration - def update_hook(task_id: TaskID): + def update_hook(task_id: TaskID, loss: Optional[float] = None): self._pbar.advance(task_id) + if loss is not None: + self._pbar.tasks[task_id].fields["loss"] = loss + # TODO: Nice progress panel with overall progress and epoch progress + # self._main_progress = Panel(self._pbar, title="Training epoch ?/?") + # self._layout["footer"].update(self._main_progress) + # self._live.refresh() def reset_hook(task_id: TaskID, total: int): self._pbar.reset(task_id, total=total, visible=False) wrapper = None update_p, reset_p = ( - partial(update_hook, task), + partial(update_hook, task_id=task), partial(reset_hook, task, total), ) if isinstance(iterable, abc.Sequence): wrapper = SeqWrapper( iterable, total, - self._main_progress, update_p, reset_p, ) @@ -151,7 +252,6 @@ def reset_hook(task_id: TaskID, total: int): wrapper = IteratorWrapper( iterable, total, - self._main_progress, update_p, reset_p, ) @@ -160,46 +260,146 @@ def reset_hook(task_id: TaskID, total: int): f"iterable must be a Sequence or an Iterator, got {type(iterable)}" ) self._pbar.reset(task, total=total, visible=True) - return wrapper + return wrapper, wrapper.update_loss_hook - def track_training(self, iterable, description: str, total: int) -> Iterable: + def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: task = self.tasks["training"] return self._track_iterable(iterable, task, total) - def 
track_validation(self, iterable, description: str, total: int) -> Iterable: + def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: task = self.tasks["validation"] return self._track_iterable(iterable, task, total) - def print_footer(self, text: str): - self._layout["footer"].update(text) - def print_header(self, text: str): self._layout["header"].update(text) - def print(self, text: str): + def print(self, text: str | Tensor): """ Print text to the side panel. """ - # NOTE: We could use a table to append messages in the renderable. I don't really know of - # another way to print stuff in a specific panel. - self._layout["side"].update(Panel(text, title="Logs")) + if not isinstance(text, str): + raise NotImplementedError("Only text is supported for now.") + self._logger.add_row(text) + + def _make_plot( + self, + width, + height, + epoch: int, + train_losses: List[float], + val_losses: List[float], + model_saver: Optional[BestNModelSaver] = None, + ): + """Plot the training and validation losses. + Args: + epoch (int): Current epoch number. + train_losses (List[float]): List of training losses. + val_losses (List[float]): List of validation losses. + Returns: + None + """ + if self._plot_log_scale and any( + loss_val <= 0 for loss_val in train_losses + val_losses + ): + raise ValueError( + "Cannot plot on a log scale if there are non-positive losses." 
+ ) + plt.clf() + plt.plotsize(width, height) + plt.title(f"Training curves for {self._run_name}") + plt.xlabel("Epoch") + plt.theme("dark") + if self._plot_log_scale: + plt.ylabel("Loss (log scale)") + plt.yscale("log") + else: + plt.ylabel("Loss") + plt.grid(True, True) + + plt.plot( + list(range(self._starting_epoch, epoch + 1)), + train_losses, + color=project_conf.Theme.TRAINING.value, + # color="blue", + label="Training loss", + ) + plt.plot( + list(range(self._starting_epoch, epoch + 1)), + val_losses, + color=project_conf.Theme.VALIDATION.value, + # color="green", + label="Validation loss", + ) + best_metrics = ( + "[" + + ", ".join( + [ + f"{metric_name}={metric_value:.2e} " + for metric_name, metric_value in model_saver.best_metrics.items() + ] + if model_saver is not None + else [] + ) + + "]" + ) + if model_saver is not None: + plt.scatter( + [model_saver.min_val_loss_epoch], + [model_saver.min_val_loss], + color="red", + marker="+", + label=f"Best model {best_metrics}", + style="inverted", + ) + return plt.build() + + def plot( + self, + epoch: int, + train_losses: List[float], + val_losses: List[float], + model_saver: Optional[BestNModelSaver] = None, + ) -> None: + mk_plot = partial( + self._make_plot, + epoch=epoch, + train_losses=train_losses, + val_losses=val_losses, + model_saver=model_saver, + ) + self._plot = Panel(PlotextMixin(mk_plot), title="Training curves") + self._layout["body"].update(self._plot) + self._live.refresh() if __name__ == "__main__": mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) dataloader = DataLoader(mnist, 32, shuffle=True) - gui = GUI() - gui.open() + gui = GUI("test-run", plot_log_scale=False) + gui.open() # TODO: Use a context manager, why not?? 
try: - for i, e in enumerate(gui.track_training(range(10), "Training", 10)): - gui.print(f"{i}/10") + gui.print("Hello, world!") + pbar, update_progress_loss = gui.track_training(range(10), 10) + for i, e in enumerate(pbar): + gui.print(f"[{i}/10]: We can iterate over iterables") sleep(0.1) - for i, e in enumerate( - gui.track_validation(dataloader, "Validation", len(dataloader)) - ): - gui.print(e) # TODO: Make this work! - gui.print(f"{i}/{len(dataloader)}") + train_losses, val_losses = [], [] + pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) + for i, e in enumerate(pbar): + # gui.print(e) # TODO: Make this work! + if i % 10 == 0: + train_losses.append(random.random()) + val_losses.append(random.random()) + update_progress_loss(random.random()) + gui.plot(epoch=i, train_losses=train_losses, val_losses=val_losses) + gui.print( + f"[{i}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" + ) sleep(0.01) - except Exception: + gui.print("Goodbye, world!") + sleep(1) + except Exception as e: + gui.close() + raise e + finally: gui.close() - gui.close() From c39155399f66a0af58074cba3825816d7fa7c42b Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 23 Jul 2024 12:23:53 +0100 Subject: [PATCH 18/38] Add missing testing tracking --- utils/gui.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/utils/gui.py b/utils/gui.py index 0680c58..89a2ac0 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -270,14 +270,18 @@ def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: task = self.tasks["validation"] return self._track_iterable(iterable, task, total) + def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: + task = self.tasks["testing"] + return self._track_iterable(iterable, task, total) + def print_header(self, text: str): self._layout["header"].update(text) - def print(self, text: str | Tensor): + def print(self, text: str | Tensor | Text): """ Print text to the 
side panel. """ - if not isinstance(text, str): + if not isinstance(text, (str, Text)): raise NotImplementedError("Only text is supported for now.") self._logger.add_row(text) From edccd51d1822b889804887e01dd51bcf30c68984 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 23 Jul 2024 12:45:06 +0100 Subject: [PATCH 19/38] Improve the UI Add time to the logs --- bootstrap/factories.py | 2 +- utils/gui.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/bootstrap/factories.py b/bootstrap/factories.py index 1afde4a..9db036e 100644 --- a/bootstrap/factories.py +++ b/bootstrap/factories.py @@ -35,7 +35,7 @@ def make_datasets( "test": None, } status = console.status("Loading dataset...", spinner="runner") - progress = Progress(transient=True) + progress = Progress(transient=False) with Live(Panel(Group(status, progress), title="Loading datasets")): splits = ("train", "val") if training_mode else ("test",) for split in splits: diff --git a/utils/gui.py b/utils/gui.py index 89a2ac0..fca8965 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -11,6 +11,7 @@ import random from collections import abc, namedtuple +from datetime import datetime from functools import partial from time import sleep from typing import ( @@ -78,6 +79,8 @@ def __rich_console__(self, console, options): yield self.rich_canvas +# TODO: Make it a singleton so we can print from anywhere in the code, without passing a reference +# around. 
class GUI: def __init__(self, run_name: str, plot_log_scale: bool) -> None: self._run_name = run_name @@ -135,7 +138,7 @@ def __init__(self, run_name: str, plot_log_scale: bool) -> None: ) ) self._logger = Table.grid(padding=0) - self._logger.add_column(no_wrap=True) + self._logger.add_column(no_wrap=False) self._layout["side"].update( Panel( self._logger, title="Logs", border_style="bright_red", box=box.ROUNDED @@ -283,7 +286,10 @@ def print(self, text: str | Tensor | Text): """ if not isinstance(text, (str, Text)): raise NotImplementedError("Only text is supported for now.") - self._logger.add_row(text) + + self._logger.add_row( + Text(datetime.now().strftime("[%H:%M] "), style="dim cyan"), text + ) def _make_plot( self, @@ -383,6 +389,9 @@ def plot( gui.open() # TODO: Use a context manager, why not?? try: gui.print("Hello, world!") + gui.print( + "Veeeeeeeeeeeeeeeeeeeryyyyyyyyyyyyyyyy looooooooooooooooooooooooooooong seeeeeeeeeeeenteeeeeeeeeeeeeeennnnnnnnnce!!!!!!!!!!!!!!!!!" + ) pbar, update_progress_loss = gui.track_training(range(10), 10) for i, e in enumerate(pbar): gui.print(f"[{i}/10]: We can iterate over iterables") From 76796f54f22e44f24619e51f02cd926eaf215805 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 23 Jul 2024 12:45:19 +0100 Subject: [PATCH 20/38] Disable albumentations if not needed --- dataset/base/image.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/dataset/base/image.py b/dataset/base/image.py index 66e1e59..02aa7c4 100644 --- a/dataset/base/image.py +++ b/dataset/base/image.py @@ -61,19 +61,20 @@ def __init__( self._normalization: Callable[[Tensor], Tensor] = transforms.Normalize( self.IMAGE_NET_MEAN, self.IMAGE_NET_STD ) - try: - import albumentations as A # type: ignore - except ImportError: - raise ImportError( - "Please install albumentations to use the augmentation pipeline." 
+ if self._augment: + try: + import albumentations as A # type: ignore + except ImportError: + raise ImportError( + "Please install albumentations to use the augmentation pipeline." + ) + self._augs: Callable[..., Dict[str, Any]] = A.Compose( + [ + A.RandomCropFromBorders(), + A.RandomBrightnessContrast(), + A.RandomGamma(), + ] ) - self._augs: Callable[..., Dict[str, Any]] = A.Compose( - [ - A.RandomCropFromBorders(), - A.RandomBrightnessContrast(), - A.RandomGamma(), - ] - ) def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """ @@ -83,8 +84,6 @@ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: # ==== Load image and apply transforms === img: Tensor img = read_image(self._samples[index]) # type: ignore - if not isinstance(img, Tensor): - raise ValueError("Image not loaded as a Tensor.") img = self._transforms(img) if self._normalize: img = self._normalization(img) From b50592e42c59335aaac329cda9c2264ca3dc025e Mon Sep 17 00:00:00 2001 From: Theo Date: Thu, 25 Jul 2024 23:31:51 +0100 Subject: [PATCH 21/38] Improve GUI --- utils/gui.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/utils/gui.py b/utils/gui.py index fca8965..b118639 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -26,14 +26,15 @@ ) import plotext as plt +import rich import torch from rich import box +from rich.align import Align from rich.ansi import AnsiDecoder from rich.console import Group from rich.jupyter import JupyterMixin from rich.layout import Layout from rich.live import Live -from rich.padding import Padding from rich.panel import Panel from rich.progress import ( BarColumn, @@ -47,7 +48,6 @@ from rich.style import Style from rich.table import Table from rich.text import Text -from torch import Tensor from torch.utils.data.dataloader import DataLoader from torchvision.datasets import MNIST from torchvision.transforms.functional import to_tensor @@ -97,14 +97,13 @@ def __init__(self, run_name: str, plot_log_scale: 
bool) -> None: Layout(name="side"), ) self._plot = Panel( - Padding( + Align.center( Text( "Waiting for training curves...", justify="center", style=Style(color="blue", bold=True), ), - pad=30, - style="on black", + vertical="middle", ), title="Training curves", expand=True, @@ -280,16 +279,24 @@ def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: def print_header(self, text: str): self._layout["header"].update(text) - def print(self, text: str | Tensor | Text): + def print(self, text: str | Text): """ Print text to the side panel. """ - if not isinstance(text, (str, Text)): - raise NotImplementedError("Only text is supported for now.") - - self._logger.add_row( - Text(datetime.now().strftime("[%H:%M] "), style="dim cyan"), text - ) + # TODO: Use a fifo and pop the first item if we can't view the enw item. + try: + self._logger.add_row( + Text(datetime.now().strftime("[%H:%M] "), style="dim cyan"), text + ) + # TODO: Compute the max number of displayable rows. We have access to the height: + # height = self._console.size.height + # But what about the width of the column? Can we get that anywhere? + # Then, we need to compute the width of each row to see if it overflows and by how many + # lines. + # Finally, we'll have the percentage of height that our rows take, and if it's above + # 100%-threshold, we remove the first row. + except rich.errors.NotRenderableError as e: + self.print(Text("[Rich]: " + str(e), style="bold red")) def _make_plot( self, @@ -408,6 +415,8 @@ def plot( gui.print( f"[{i}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" 
) + if i == 0: + gui.print(e) sleep(0.01) gui.print("Goodbye, world!") sleep(1) From 68ce0b3a3e52cd9fef0fb467eeb11d82cecd56af Mon Sep 17 00:00:00 2001 From: Theo Date: Thu, 25 Jul 2024 23:32:06 +0100 Subject: [PATCH 22/38] Rebuild GUI with Textual :) --- utils/gui2.py | 248 ++++++++++++++++++++++++++++++++++++++++++++++++ utils/style.css | 11 +++ 2 files changed, 259 insertions(+) create mode 100644 utils/gui2.py create mode 100644 utils/style.css diff --git a/utils/gui2.py b/utils/gui2.py new file mode 100644 index 0000000..383c514 --- /dev/null +++ b/utils/gui2.py @@ -0,0 +1,248 @@ +# Let's use Textual to rewrite the GUI with better features. + +import asyncio +from datetime import datetime +from itertools import cycle +from random import random +from typing import ( + List, +) + +import numpy as np +import torch +from rich.console import Group, RenderableType +from rich.pretty import Pretty +from rich.text import Text +from textual.app import App, ComposeResult, RenderResult +from textual.reactive import var +from textual.widgets import ( + Footer, + Header, + Placeholder, + RichLog, + Static, +) +from textual_plotext import PlotextPlot + + +class PlotterWidget(PlotextPlot): + marker: var[str] = var("sd") + + """The type of marker to use for the plot.""" + + def __init__( + self, + title: str, + *, + name: str | None = None, + id: str | None = None, # pylint:disable=redefined-builtin + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialise the training curves plotter widget. + + Args: + name: The name of the plotter widget. + id: The ID of the plotter widget in the DOM. + classes: The CSS classes of the plotter widget. + disabled: Whether the plotter widget is disabled or not. 
+ """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._title = title + self._train_losses: list[float] = [] + self._val_losses: list[float] = [] + self._epoch = 0 + + def on_mount(self) -> None: + """Plot the data using Plotext.""" + self.plt.title(self._title) + self.plt.xlabel("Epoch") + self.plt.ylabel("Loss") # TODO: update + self.plt.grid(True, True) + + def replot(self) -> None: + """Redraw the plot.""" + self.plt.clear_data() + # self.plt.plot(self._time, self._data, marker=self.marker) + if len(self._train_losses) > 0: + assert (self._epoch + 1) == len(self._train_losses) + assert len(self._val_losses) == len(self._train_losses) + self.plt.plot( + list(range(0, self._epoch + 1)), # TODO: start epoch + self._train_losses, + # color=project_conf.Theme.TRAINING.value, # TODO: + color="blue", + label="Training loss", + marker=self.marker, + ) + self.plt.plot( + list(range(0, self._epoch + 1)), # TODO: start epoch + self._val_losses, + # color=project_conf.Theme.VALIDATION.value, # TODO: + color="green", + label="Validation loss", + marker=self.marker, + ) + self.refresh() + + def update( + self, epoch: int, train_losses: List[float], val_losses: List[float] + ) -> None: + """Update the data for the training curves plot. + + Args: + epoch: (int) The current epoch number. + train_losses: (List[float]) The list of training losses. + val_losses: (List[float]) The list of validation losses. + """ + # TODO: We only need to append to the losses. Do we need to keep track of them in the + # trianing loop? If not we should only keep track of the last one and let the GUI keep + # track of them all. 
+ self._epoch = epoch + self._train_losses = train_losses + self._val_losses = val_losses + self.replot() + + def _watch_marker(self) -> None: + """React to the marker being changed.""" + self.replot() + + +# TODO: Also make a Rich renderable for Tensors + + +class TensorWidget(Static): # TODO: PRETTYER + def __init__(self, tensor: torch.Tensor): + self.tensor = tensor + + def render(self) -> RenderResult: + return Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Pretty(self.tensor), + ) + + +class GUI(App): + """A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.""" + + CSS_PATH = "style.css" + + BINDINGS = [ + ("q", "quit", "Quit"), + ("d", "toggle_dark", "Toggle dark mode"), + ("m", "marker", "Cycle example markers"), + ("ctrl+z", "suspend_progress"), + ] + + MARKERS = { + "dot": "Dot", + "hd": "High Definition", + "fhd": "Higher Definition", + "braille": "Braille", + "sd": "Standard Definition", + } + + marker: var[str] = var("sd") + + def __init__(self) -> None: + """Initialise the application.""" + super().__init__() + self._markers = cycle(self.MARKERS.keys()) + + def compose(self) -> ComposeResult: + yield Header() + yield PlotterWidget(title="Trainign curves for run-name", classes="box") + yield RichLog( + highlight=True, markup=True, wrap=True, id="logger", classes="box" + ) + yield Placeholder(classes="box") + yield Placeholder(classes="box") + yield Footer() + + def action_toggle_dark(self) -> None: + self.dark = not self.dark + + def watch_marker(self) -> None: + """React to the marker type being changed.""" + self.sub_title = self.MARKERS[self.marker] + self.query_one(PlotterWidget).marker = self.marker + + def action_marker(self) -> None: + """Cycle to the next marker type.""" + self.marker = next(self._markers) + + def on_key(self, event) -> None: + logger: RichLog = self.query_one(RichLog) + logger.write( + Group( + Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", 
end=""), + f"Key pressed: {event.key!r}", + ), + ) + if event.key == "t": + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + ) + ) + elif event.key == "p": + self.query_one(PlotterWidget).update( + epoch=9, + train_losses=[random() for _ in range(10)], + val_losses=[random() for _ in range(10)], + ) + + def print(self, message: RenderableType | str | torch.Tensor | np.ndarray): + logger: RichLog = self.query_one(RichLog) + if isinstance(message, (str, RenderableType)): + logger.write( + Group( + Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""), + message, + ), + ) + elif isinstance(message, (torch.Tensor, np.ndarray)): + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Pretty(message), + ) + ) + + +async def run_my_app(): + gui = GUI() + task = asyncio.create_task(gui.run_async()) + await asyncio.sleep(1) # Wait for the app to start up + gui.print("Hello, World!") + await asyncio.sleep(2) + gui.print(Text("Let's log some tensors :)", style="bold magenta")) + await asyncio.sleep(0.5) + gui.print(torch.rand(2, 4)) + await asyncio.sleep(2) + gui.print(Text("How about some numpy arrays?!", style="italic green")) + await asyncio.sleep(1) + gui.print(np.random.rand(3, 3)) + await asyncio.sleep(3) + gui.print("...") + await asyncio.sleep(3) + gui.print("Go on... Press 'p'! 
I know you want to!") + await asyncio.sleep(4) + gui.print("COME ON PRESS P!!!!") + _ = await task + + +if __name__ == "__main__": + asyncio.run(run_my_app()) diff --git a/utils/style.css b/utils/style.css new file mode 100644 index 0000000..1b57eb0 --- /dev/null +++ b/utils/style.css @@ -0,0 +1,11 @@ +Screen { + layout: grid; + grid-size: 2; + grid-columns: 3fr 1fr; + grid-rows: 95% 5%; +} + +.box { + height: 100%; + border: solid green; +} From aac869ba002f76b910964183984060f8782c7ec3 Mon Sep 17 00:00:00 2001 From: Theo Date: Fri, 26 Jul 2024 23:37:40 +0100 Subject: [PATCH 23/38] [V2] Implement progress bar and improve logger --- utils/gui2.py | 273 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 246 insertions(+), 27 deletions(-) diff --git a/utils/gui2.py b/utils/gui2.py index 383c514..395a988 100644 --- a/utils/gui2.py +++ b/utils/gui2.py @@ -1,11 +1,21 @@ # Let's use Textual to rewrite the GUI with better features. import asyncio +from collections import abc from datetime import datetime +from enum import Enum +from functools import partial from itertools import cycle -from random import random +from random import randint, random +from time import sleep from typing import ( + Callable, + Iterable, + Iterator, List, + Optional, + Sequence, + Tuple, ) import numpy as np @@ -14,15 +24,21 @@ from rich.pretty import Pretty from rich.text import Text from textual.app import App, ComposeResult, RenderResult +from textual.containers import Horizontal from textual.reactive import var from textual.widgets import ( Footer, Header, + Label, Placeholder, + ProgressBar, RichLog, Static, ) from textual_plotext import PlotextPlot +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import MNIST +from torchvision.transforms.functional import to_tensor class PlotterWidget(PlotextPlot): @@ -126,6 +142,131 @@ def render(self) -> RenderResult: ) +class Task(Enum): + TRAINING = 0 + VALIDATION = 1 + TESTING = 2 + + +class 
DatasetProgressBar(Static): + """A progress bar for PyTorch dataloader iteration.""" + + DESCRIPTIONS = { + Task.TRAINING: Text("Training: ", style="bold blue"), + Task.VALIDATION: Text("Validation: ", style="bold green"), + Task.TESTING: Text("Testing: ", style="bold yellow"), + } + + # def __init__(self): + # self.description = None + # self.total = None + # self.progress = 0 + + def compose(self) -> ComposeResult: + with Horizontal(): + yield Label("Waiting...", id="progress_label") + yield ProgressBar() + + def track_iterable( + self, + iterable: Iterable | Sequence | Iterator | DataLoader, + task: Task, + total: int, + ) -> Tuple[Iterable, Callable]: + class LossHook: + def __init__(self): + self._loss = None + + def update_loss_hook(self, loss: float): + self._loss = loss + + class SeqWrapper(abc.Iterator, LossHook): + def __init__( + self, + seq: Sequence, + len: int, + update_hook: Callable, + reset_hook: Callable, + ): + super().__init__() + self._sequence = seq + self._idx = 0 + self._len = len + self._update_hook = update_hook + self._reset_hook = reset_hook + + def __next__(self): + if self._idx >= self._len: + self._reset_hook() + raise StopIteration + item = self._sequence[self._idx] + self._update_hook(loss=self._loss) + self._idx += 1 + return item + + class IteratorWrapper(abc.Iterator, LossHook): + def __init__( + self, + iterator: Iterator | DataLoader, + len: int, + update_hook: Callable, + reset_hook: Callable, + ): + super().__init__() + self._iterator = iter(iterator) + self._len = len + self._update_hook = update_hook + self._reset_hook = reset_hook + + def __next__(self): + try: + item = next(self._iterator) + self._update_hook(loss=self._loss) + return item + except StopIteration: + self._reset_hook() + raise StopIteration + + def update_hook(loss: Optional[float] = None): + self.query_one(ProgressBar).advance() + if loss is not None: + self.query_one("#progress_label").update( + self.DESCRIPTIONS[task] + f"[loss={loss:.4f}]" + ) + + def 
reset_hook(total: int): + sleep(0.5) + self.query_one(ProgressBar).update(total=100, progress=0) + self.query_one("#progress_label").update("Waiting...") + + wrapper = None + update_p, reset_p = ( + partial(update_hook), + partial(reset_hook, total), + ) + if isinstance(iterable, abc.Sequence): + wrapper = SeqWrapper( + iterable, + total, + update_p, + reset_p, + ) + elif isinstance(iterable, (abc.Iterator, DataLoader)): + wrapper = IteratorWrapper( + iterable, + total, + update_p, + reset_p, + ) + else: + raise ValueError( + f"iterable must be a Sequence or an Iterator, got {type(iterable)}" + ) + self.query_one(ProgressBar).update(total=total, progress=0) + self.query_one("#progress_label").update(self.DESCRIPTIONS[task]) + return wrapper, wrapper.update_loss_hook + + class GUI(App): """A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.""" @@ -159,7 +300,8 @@ def compose(self) -> ComposeResult: yield RichLog( highlight=True, markup=True, wrap=True, id="logger", classes="box" ) - yield Placeholder(classes="box") + # yield Placeholder(classes="box") + yield DatasetProgressBar() yield Placeholder(classes="box") yield Footer() @@ -191,6 +333,7 @@ def on_key(self, event) -> None: style="dim cyan", end="", ), + Pretty(torch.rand(randint(1, 12), randint(1, 12))), ) ) elif event.key == "p": @@ -202,14 +345,7 @@ def on_key(self, event) -> None: def print(self, message: RenderableType | str | torch.Tensor | np.ndarray): logger: RichLog = self.query_one(RichLog) - if isinstance(message, (str, RenderableType)): - logger.write( - Group( - Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""), - message, - ), - ) - elif isinstance(message, (torch.Tensor, np.ndarray)): + if isinstance(message, (RenderableType, str)): logger.write( Group( Text( @@ -217,30 +353,113 @@ def print(self, message: RenderableType | str | torch.Tensor | np.ndarray): style="dim cyan", end="", ), - Pretty(message), - ) + message, + ), ) + else: + ppable, pp_msg 
= True, None + try: + pp_msg = Pretty(message) + except Exception: + ppable = False + if ppable and pp_msg is not None: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text(str(type(message)) + " ", style="italic blue", end=""), + pp_msg, + ) + ) + else: + try: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + message, + ), + ) + except Exception as e: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text("Logging error: ", style="bold red"), + Text(str(e), style="bold red"), + ) + ) + + def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: + return self.query_one(DatasetProgressBar).track_iterable( + iterable, Task.TRAINING, total + ) + + def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: + return self.query_one(DatasetProgressBar).track_iterable( + iterable, Task.VALIDATION, total + ) + + def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: + return self.query_one(DatasetProgressBar).track_iterable( + iterable, Task.TESTING, total + ) async def run_my_app(): gui = GUI() task = asyncio.create_task(gui.run_async()) - await asyncio.sleep(1) # Wait for the app to start up + await asyncio.sleep(0.1) # Wait for the app to start up gui.print("Hello, World!") - await asyncio.sleep(2) - gui.print(Text("Let's log some tensors :)", style="bold magenta")) - await asyncio.sleep(0.5) - gui.print(torch.rand(2, 4)) - await asyncio.sleep(2) - gui.print(Text("How about some numpy arrays?!", style="italic green")) - await asyncio.sleep(1) - gui.print(np.random.rand(3, 3)) - await asyncio.sleep(3) - gui.print("...") - await asyncio.sleep(3) - gui.print("Go on... Press 'p'! 
I know you want to!") - await asyncio.sleep(4) - gui.print("COME ON PRESS P!!!!") + # await asyncio.sleep(2) + # gui.print(Text("Let's log some tensors :)", style="bold magenta")) + # await asyncio.sleep(0.5) + # gui.print(torch.rand(2, 4)) + # await asyncio.sleep(2) + # gui.print(Text("How about some numpy arrays?!", style="italic green")) + # await asyncio.sleep(1) + # gui.print(np.random.rand(3, 3)) + # await asyncio.sleep(3) + # gui.print("...") + # await asyncio.sleep(3) + # gui.print("Go on... Press 'p'! I know you want to!") + # await asyncio.sleep(4) + # gui.print("COME ON PRESS P!!!!") + # await asyncio.sleep(1) + pbar, update_progress_loss = gui.track_training(range(10), 10) + for i, e in enumerate(pbar): + gui.print(f"[{i+1}/10]: We can iterate over iterables") + gui.print(e) + # sleep(0.1) + await asyncio.sleep(0.1) + await asyncio.sleep(5) + mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) + dataloader = DataLoader(mnist, 32, shuffle=True) + train_losses, val_losses = [], [] + pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) + for i, batch in enumerate(pbar): + await asyncio.sleep(0.01) + gui.print(batch) # TODO: Make this work! + if i % 10 == 0: + train_losses.append(random()) + val_losses.append(random()) + update_progress_loss(random()) + # gui.plot(epoch=i, train_losses=train_losses, val_losses=val_losses) + gui.print( + f"[{i+1}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" 
+ ) + if i == 0: + gui.print(e) + gui.print("Goodbye, world!") _ = await task From 109ee867408bf2dd47254839a2309032159b5ffe Mon Sep 17 00:00:00 2001 From: Theo Date: Fri, 26 Jul 2024 23:47:07 +0100 Subject: [PATCH 24/38] [V2] Add plot() to GUI API --- utils/gui2.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/utils/gui2.py b/utils/gui2.py index 395a988..7058299 100644 --- a/utils/gui2.py +++ b/utils/gui2.py @@ -79,9 +79,7 @@ def on_mount(self) -> None: def replot(self) -> None: """Redraw the plot.""" self.plt.clear_data() - # self.plt.plot(self._time, self._data, marker=self.marker) if len(self._train_losses) > 0: - assert (self._epoch + 1) == len(self._train_losses) assert len(self._val_losses) == len(self._train_losses) self.plt.plot( list(range(0, self._epoch + 1)), # TODO: start epoch @@ -143,6 +141,7 @@ def render(self) -> RenderResult: class Task(Enum): + IDLE = -1 TRAINING = 0 VALIDATION = 1 TESTING = 2 @@ -152,6 +151,7 @@ class DatasetProgressBar(Static): """A progress bar for PyTorch dataloader iteration.""" DESCRIPTIONS = { + Task.IDLE: Text("Waiting for work..."), Task.TRAINING: Text("Training: ", style="bold blue"), Task.VALIDATION: Text("Validation: ", style="bold green"), Task.TESTING: Text("Testing: ", style="bold yellow"), @@ -164,7 +164,7 @@ class DatasetProgressBar(Static): def compose(self) -> ComposeResult: with Horizontal(): - yield Label("Waiting...", id="progress_label") + yield Label(self.DESCRIPTIONS[Task.IDLE], id="progress_label") yield ProgressBar() def track_iterable( @@ -237,7 +237,7 @@ def update_hook(loss: Optional[float] = None): def reset_hook(total: int): sleep(0.5) self.query_one(ProgressBar).update(total=100, progress=0) - self.query_one("#progress_label").update("Waiting...") + self.query_one("#progress_label").update(self.DESCRIPTIONS[Task.IDLE]) wrapper = None update_p, reset_p = ( @@ -414,6 +414,9 @@ def track_testing(self, iterable, total: int) -> Tuple[Iterable, 
Callable]: iterable, Task.TESTING, total ) + def plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): + self.query_one(PlotterWidget).update(epoch, train_losses, val_losses) + async def run_my_app(): gui = GUI() @@ -435,30 +438,30 @@ async def run_my_app(): # await asyncio.sleep(4) # gui.print("COME ON PRESS P!!!!") # await asyncio.sleep(1) - pbar, update_progress_loss = gui.track_training(range(10), 10) - for i, e in enumerate(pbar): - gui.print(f"[{i+1}/10]: We can iterate over iterables") - gui.print(e) - # sleep(0.1) - await asyncio.sleep(0.1) - await asyncio.sleep(5) + # pbar, update_progress_loss = gui.track_training(range(10), 10) + # for i, e in enumerate(pbar): + # gui.print(f"[{i+1}/10]: We can iterate over iterables") + # gui.print(e) + # # sleep(0.1) + # await asyncio.sleep(0.1) + # await asyncio.sleep(5) mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) dataloader = DataLoader(mnist, 32, shuffle=True) train_losses, val_losses = [], [] pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) for i, batch in enumerate(pbar): - await asyncio.sleep(0.01) - gui.print(batch) # TODO: Make this work! if i % 10 == 0: + await asyncio.sleep(0.01) + gui.print(batch) train_losses.append(random()) val_losses.append(random()) update_progress_loss(random()) - # gui.plot(epoch=i, train_losses=train_losses, val_losses=val_losses) + gui.plot(epoch=i, train_losses=train_losses, val_losses=val_losses) gui.print( f"[{i+1}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" 
) if i == 0: - gui.print(e) + gui.print(batch) gui.print("Goodbye, world!") _ = await task From 3dd8f71d3c89239efa2bf6247a0b9649bc9a2fb1 Mon Sep 17 00:00:00 2001 From: Theo Date: Fri, 26 Jul 2024 23:58:25 +0100 Subject: [PATCH 25/38] Center the bar --- utils/gui2.py | 5 +++-- utils/style.css | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/utils/gui2.py b/utils/gui2.py index 7058299..b7d2637 100644 --- a/utils/gui2.py +++ b/utils/gui2.py @@ -24,7 +24,7 @@ from rich.pretty import Pretty from rich.text import Text from textual.app import App, ComposeResult, RenderResult -from textual.containers import Horizontal +from textual.containers import Center from textual.reactive import var from textual.widgets import ( Footer, @@ -163,7 +163,8 @@ class DatasetProgressBar(Static): # self.progress = 0 def compose(self) -> ComposeResult: - with Horizontal(): + # with Horizontal(): + with Center(): yield Label(self.DESCRIPTIONS[Task.IDLE], id="progress_label") yield ProgressBar() diff --git a/utils/style.css b/utils/style.css index 1b57eb0..283b8a2 100644 --- a/utils/style.css +++ b/utils/style.css @@ -9,3 +9,14 @@ Screen { height: 100%; border: solid green; } + + +Center { + margin-top: 1; + margin-bottom: 1; + layout: horizontal; +} + +ProgressBar { + padding-left: 3; +} From bc04b441a28de23357d24a93784b6f6900e1c751 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 14:00:02 +0100 Subject: [PATCH 26/38] Ship the new GUI! 
- Integrate the Textual app into the framework - Add fancy details to the UI - Improve the API of the GUI - Refactor a little --- bootstrap/launch_experiment.py | 43 ++- src/base_tester.py | 17 +- src/base_trainer.py | 138 ++----- utils/gui.py | 657 +++++++++++++++++---------------- utils/gui2.py | 471 ----------------------- 5 files changed, 404 insertions(+), 922 deletions(-) delete mode 100644 utils/gui2.py diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index b476b59..96e5311 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -6,6 +6,7 @@ # Distributed under terms of the MIT license. +import asyncio import os from dataclasses import asdict from time import sleep @@ -13,6 +14,7 @@ import hydra_zen import torch +import torch.multiprocessing as mp import wandb import yaml from hydra.core.hydra_config import HydraConfig @@ -136,6 +138,10 @@ def launch_experiment( scheduler_inst = make_scheduler(scheduler, opt_inst, run.epochs) model_inst = to_cuda_(parallelize_model(model_inst)) training_loss_inst = to_cuda_(make_training_loss(run.training_mode, training_loss)) + + # Somehow, the dataloader will crash if it's not forked when using multiprocessing along with + # Textual. 
+ mp.set_start_method("fork") train_loader_inst, val_loader_inst, test_loader_inst = make_dataloaders( data_loader, train_dataset, @@ -152,17 +158,20 @@ def launch_experiment( style="bold cyan", ) sleep(1) - gui = GUI(run_name, project_conf.LOG_SCALE_PLOT) - model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) - common_args = dict( - run_name=run_name, - model=model_inst, - model_ckpt_path=model_ckpt_path, - training_loss=training_loss_inst, - gui=gui, - ) - gui.open() - try: + + async def launch_with_async_gui(): + gui = GUI(run_name, project_conf.LOG_SCALE_PLOT) + task = asyncio.create_task(gui.run_async()) + while not gui.is_running: + await asyncio.sleep(0.01) # Wait for the app to start up + model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) + common_args = dict( + run_name=run_name, + model=model_inst, + model_ckpt_path=model_ckpt_path, + training_loss=training_loss_inst, + gui=gui, + ) if run.training_mode: gui.print("Training started!") if training_loss_inst is None: @@ -171,7 +180,7 @@ def launch_experiment( raise ValueError( "val_loader and train_loader must be defined in training mode!" ) - trainer( + await trainer( train_loader=train_loader_inst, val_loader=val_loader_inst, opt=opt_inst, @@ -192,7 +201,7 @@ def launch_experiment( gui.print("Testing started!") if test_loader_inst is None: raise ValueError("test_loader must be defined in testing mode!") - tester( + await tester( data_loader=test_loader_inst, **common_args, ).test( @@ -202,8 +211,6 @@ def launch_experiment( ), # Extra stuff if needed. You can get them from the trainer's __init__ with kwrags.get(key, default_value) ) gui.print("Testing finished!") - except Exception as e: - gui.close() - raise e - finally: - gui.close() + _ = await task + + asyncio.run(launch_with_async_gui()) diff --git a/src/base_tester.py b/src/base_tester.py index fde5f59..de87233 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -9,6 +9,7 @@ Base tester class. 
""" +import asyncio import signal from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union @@ -21,7 +22,6 @@ from torch.utils.data import DataLoader from torchmetrics import MeanMetric -from conf import project as project_conf from src.base_trainer import BaseTrainer from utils import to_cuda from utils.gui import GUI @@ -97,7 +97,7 @@ def _test_iteration( # TODO: Compute your metrics here! return torch.tensor(torch.inf), {} - def test( + async def test( self, visualize_every: int = 0, **kwargs: Optional[Dict[str, Any]] ) -> None: """Computes the average loss on the test set. @@ -109,9 +109,7 @@ def test( test_loss: MeanMetric = MeanMetric() test_metrics: Dict[str, MeanMetric] = defaultdict(MeanMetric) self._model.eval() - # self._pbar.reset() - # self._pbar.set_description("Testing") - color_code = project_conf.ANSI_COLORS[project_conf.Theme.TESTING.value] + print(Text(f"[*] Testing {self._run_name}", style="bold green")) """ ==================== Training loop for one epoch ==================== """ pbar, update_loss_hook = self._gui.track_testing( self._data_loader, total=len(self._data_loader) @@ -120,22 +118,15 @@ def test( if not self._running: print("[!] Testing aborted.") break - loss, metrics = self._test_iteration(batch) + loss, metrics = await asyncio.to_thread(self._test_iteration, batch) test_loss.update(loss.item()) for k, v in metrics.items(): test_metrics[k].update(v.item()) update_loss_hook(test_loss.compute()) - # update_pbar_str( - # self._pbar, - # f"Testing [loss={test_loss.compute():.4f}]", - # color_code, - # ) """ ==================== Visualization ==================== """ if visualize_every > 0 and (i + 1) % visualize_every == 0: self._visualize(batch, i) - # self._pbar.update() - # self._pbar.close() # TODO: Report metrics in a special panel? Then hang the GUI until the user is done. 
print("=" * 81) print("==" + " " * 31 + " Test results " + " " * 31 + "==") diff --git a/src/base_trainer.py b/src/base_trainer.py index 98fef27..c1ddb43 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -9,6 +9,7 @@ Base trainer class. """ +import asyncio import os import random import signal @@ -19,6 +20,7 @@ import wandb from hydra.core.hydra_config import HydraConfig from rich.console import Console +from rich.text import Text from torch import Tensor from torch.nn import Module from torch.optim import Optimizer @@ -135,7 +137,7 @@ def _train_epoch( """ epoch_loss: MeanMetric = MeanMetric() epoch_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) - color_code = project_conf.ANSI_COLORS[project_conf.Theme.TRAINING.value] + # color_code = project_conf.ANSI_COLORS[project_conf.Theme.TRAINING.value] has_visualized = 0 """ ==================== Training loop for one epoch ==================== """ pbar, update_loss_hook = self._gui.track_training( @@ -160,12 +162,6 @@ def _train_epoch( for k, v in loss_components.items(): epoch_loss_components[k].update(v.item()) update_loss_hook(epoch_loss.compute()) - # update_pbar_str( - # self._pbar, - # f"{description} [loss={epoch_loss.compute():.4f} /" - # + f" val_loss={last_val_loss:.4f}]", - # color_code, - # ) if ( visualize and has_visualized < self._viz_n_samples @@ -174,7 +170,6 @@ def _train_epoch( with torch.no_grad(): self._visualize(batch, epoch) has_visualized += 1 - # self._pbar.update() mean_epoch_loss: float = epoch_loss.compute().item() if project_conf.USE_WANDB: wandb.log({"train_loss": mean_epoch_loss}, step=epoch) @@ -196,7 +191,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: float: Average validation loss for the epoch. 
""" has_visualized = 0 - color_code = project_conf.ANSI_COLORS[project_conf.Theme.VALIDATION.value] + # color_code = project_conf.ANSI_COLORS[project_conf.Theme.VALIDATION.value] """ ==================== Validation loop for one epoch ==================== """ with torch.no_grad(): val_loss: MeanMetric = MeanMetric() @@ -213,8 +208,6 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: ): print("[!] Training aborted.") break - # Blink the progress bar to indicate that the validation loop is running - # blink_pbar(i, self._pbar, 4) loss, loss_components = self._train_val_iteration( batch, epoch, @@ -223,12 +216,6 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: for k, v in loss_components.items(): val_loss_components[k].update(v.item()) update_loss_hook(val_loss.compute()) - # update_pbar_str( - # self._pbar, - # f"{description} [loss={val_loss.compute():.4f} /" - # + f" min_val_loss={self._model_saver.min_val_loss:.4f}]", - # color_code, - # ) """ ==================== Visualization ==================== """ if ( visualize @@ -262,7 +249,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: ) return mean_val_loss - def train( + async def train( self, epochs: int = 10, val_every: int = 1, # Validate every n epochs @@ -278,47 +265,44 @@ def train( Returns: None """ - # console.print(f"[*] Training {self._run_name} for {epochs} epochs", style="bold green") + print( + Text( + f"[*] Training {self._run_name} for {epochs} epochs", style="bold green" + ) + ) self._viz_n_samples = visualize_n_samples - train_losses: List[float] = [] - val_losses: List[float] = [] + self._gui.set_start_epoch(self._epoch) """ ==================== Training loop ==================== """ + last_val_loss = float("inf") for epoch in range(self._epoch, epochs): + print(f"Epoch: {epoch}") self._epoch = epoch # Update for the model saver if not self._running: break self._model.train() - # self._pbar.colour = 
project_conf.Theme.TRAINING.value - train_losses.append( - self._train_epoch( - f"Epoch {epoch}/{epochs}: Training", - visualize_train_every > 0 - and (epoch + 1) % visualize_train_every == 0, - epoch, - last_val_loss=val_losses[-1] - if len(val_losses) > 0 - else float("inf"), - ) + train_loss: float = await asyncio.to_thread( + self._train_epoch, + f"Epoch {epoch}/{epochs}: Training", + visualize_train_every > 0 and (epoch + 1) % visualize_train_every == 0, + epoch, + last_val_loss=last_val_loss, ) if epoch % val_every == 0: self._model.eval() - # self._pbar.colour = project_conf.Theme.VALIDATION.value - val_losses.append( - self._val_epoch( - f"Epoch {epoch}/{epochs}: Validation", - visualize_every > 0 and (epoch + 1) % visualize_every == 0, - epoch, - ) + val_loss = await asyncio.to_thread( + self._val_epoch, + f"Epoch {epoch}/{epochs}: Validation", + visualize_every > 0 and (epoch + 1) % visualize_every == 0, + epoch, ) + last_val_loss = val_loss if self._scheduler is not None: - self._scheduler.step() + await asyncio.to_thread(self._scheduler.step) """ ==================== Plotting ==================== """ - self._gui.plot(epoch, train_losses, val_losses, self._model_saver) - # if project_conf.PLOT_ENABLED: - # self._plot(epoch, train_losses, val_losses) - # self._pbar.close() - self._save_checkpoint( - val_losses[-1], + # Plot last_val_loss rather than val_loss: val_loss is only bound once a + # validation epoch has run, whereas last_val_loss is always defined. + self._gui.plot(epoch, train_loss, last_val_loss) # , self._model_saver) + await asyncio.to_thread( + self._save_checkpoint, + last_val_loss, os.path.join(HydraConfig.get().runtime.output_dir, "last.ckpt"), ) print(f"[*] Training finished for {self._run_name}!") @@ -327,68 +311,6 @@ def train( + f"at epoch {self._model_saver.min_val_loss_epoch}." 
) - # @staticmethod - # def _setup_plot(run_name: str, log_scale: bool = False): - # """Setup the plot for training and validation losses.""" - # plt.title(f"Training curves for {run_name}") - # plt.theme("dark") - # plt.xlabel("Epoch") - # if log_scale: - # plt.ylabel("Loss (log scale)") - # plt.yscale("log") - # else: - # plt.ylabel("Loss") - # plt.grid(True, True) - - # def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): - # """Plot the training and validation losses. - # Args: - # epoch (int): Current epoch number. - # train_losses (List[float]): List of training losses. - # val_losses (List[float]): List of validation losses. - # Returns: - # None - # """ - # plt.clf() - # if project_conf.LOG_SCALE_PLOT and any( - # loss_val <= 0 for loss_val in train_losses + val_losses - # ): - # raise ValueError( - # "Cannot plot on a log scale if there are non-positive losses." - # ) - # self._setup_plot(self._run_name, log_scale=project_conf.LOG_SCALE_PLOT) - # plt.plot( - # list(range(self._starting_epoch, epoch + 1)), - # train_losses, - # color=project_conf.Theme.TRAINING.value, - # label="Training loss", - # ) - # plt.plot( - # list(range(self._starting_epoch, epoch + 1)), - # val_losses, - # color=project_conf.Theme.VALIDATION.value, - # label="Validation loss", - # ) - # best_metrics = ( - # "[" - # + ", ".join( - # [ - # f"{metric_name}={metric_value:.2e} " - # for metric_name, metric_value in self._model_saver.best_metrics.items() - # ] - # ) - # + "]" - # ) - # plt.scatter( - # [self._model_saver.min_val_loss_epoch], - # [self._model_saver.min_val_loss], - # color="red", - # marker="+", - # label=f"Best model {best_metrics}", - # style="inverted", - # ) - # plt.show() - def _save_checkpoint(self, val_loss: float, ckpt_path: str, **kwargs) -> None: """Saves the model and optimizer state to a checkpoint file. 
Args: diff --git a/utils/gui.py b/utils/gui.py index b118639..8f93c58 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -1,182 +1,183 @@ -#! /usr/bin/env python3 -# vim:fenc=utf-8 -# -# Copyright © 2024 Théo Morales -# -# Distributed under terms of the MIT license. - -""" -The fancy new GUI. -""" - -import random -from collections import abc, namedtuple +# Let's use Textual to rewrite the GUI with better features. + +import asyncio +from collections import abc from datetime import datetime +from enum import Enum from functools import partial +from itertools import cycle +from random import random from time import sleep from typing import ( + Any, Callable, Iterable, Iterator, - List, Optional, Sequence, Tuple, - TypeVar, ) -import plotext as plt -import rich +import numpy as np import torch -from rich import box -from rich.align import Align -from rich.ansi import AnsiDecoder -from rich.console import Group -from rich.jupyter import JupyterMixin -from rich.layout import Layout -from rich.live import Live -from rich.panel import Panel -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TaskID, - TaskProgressColumn, - TextColumn, - TimeRemainingColumn, -) -from rich.style import Style -from rich.table import Table +import torch.multiprocessing as mp +from rich.console import Group, RenderableType +from rich.pretty import Pretty from rich.text import Text +from textual.app import App, ComposeResult +from textual.containers import Center +from textual.reactive import var +from textual.widgets import ( + Footer, + Header, + Label, + Placeholder, + ProgressBar, + RichLog, + Static, +) +from textual_plotext import PlotextPlot from torch.utils.data.dataloader import DataLoader from torchvision.datasets import MNIST from torchvision.transforms.functional import to_tensor -if __name__ != "__main__": - from conf import project as project_conf - from utils.helpers import BestNModelSaver -else: - project_conf = namedtuple("project_conf", "Theme")( - 
Theme=namedtuple("Theme", "TRAINING VALIDATION TESTING")( - TRAINING=namedtuple("value", "value")("blue"), - VALIDATION=namedtuple("value", "value")("green"), - TESTING=namedtuple("value", "value")("cyan"), - ) - ) - BestNModelSaver = TypeVar("BestNModelSaver") - - -class PlotextMixin(JupyterMixin): - def __init__(self, p_make_plot): - self.decoder = AnsiDecoder() - self.mk_plot = p_make_plot - - def __rich_console__(self, console, options): - self.width = options.max_width or console.width - self.height = options.height or console.height - canvas = self.mk_plot(width=self.width, height=self.height) - self.rich_canvas = Group(*self.decoder.decode(canvas)) - yield self.rich_canvas - - -# TODO: Make it a singleton so we can print from anywhere in the code, without passing a reference -# around. -class GUI: - def __init__(self, run_name: str, plot_log_scale: bool) -> None: - self._run_name = run_name - self._plot_log_scale = plot_log_scale - self._starting_epoch = 0 # TODO: - self._layout = Layout() - self._layout.split( - Layout(name="header", size=2), - Layout(name="main", ratio=1), - Layout(name="footer", size=2), - ) - self._layout["main"].split_row( - Layout(name="body", ratio=3, minimum_size=60), - Layout(name="side"), - ) - self._plot = Panel( - Align.center( - Text( - "Waiting for training curves...", - justify="center", - style=Style(color="blue", bold=True), - ), - vertical="middle", - ), - title="Training curves", - expand=True, - ) - self._layout["body"].update(self._plot) - self._live = Live(self._layout, screen=True) - self._console = self._live.console - self._pbar = Progress( - SpinnerColumn(spinner_name="monkey"), - TextColumn( - "[progress.description]{task.description} \[loss={task.fields[loss]:.3f}]" - ), - BarColumn(), - TaskProgressColumn(), - TimeRemainingColumn(), - # expand=True, - ) - self._main_progress = Panel( # TODO: - self._pbar, - title="Training epoch ?/?", - expand=True, - ) - self._layout["footer"].update(self._pbar) - run_color = 
f"color({hash(run_name) % 255})" - background_color = f"color({(hash(run_name) + 128) % 255})" - self._layout["header"].update( - Text( - f"Running {run_name}", - style=f"bold {run_color} on {background_color}", - justify="center", + +class PlotterWidget(PlotextPlot): + marker: var[str] = var("sd") + + """The type of marker to use for the plot.""" + + def __init__( + self, + title: str, + use_log_scale: bool = False, + *, + name: str | None = None, + id: str | None = None, # pylint:disable=redefined-builtin + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialise the training curves plotter widget. + + Args: + title: The title displayed above the plot. + use_log_scale: Whether to plot the losses on a logarithmic y-axis. + name: The name of the plotter widget. + id: The ID of the plotter widget in the DOM. + classes: The CSS classes of the plotter widget. + disabled: Whether the plotter widget is disabled or not. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._title = title + self._log_scale = use_log_scale + self._train_losses: list[float] = [] + self._val_losses: list[float] = [] + self._start_epoch = 0 + self._epoch = 0 + + def on_mount(self) -> None: + """Plot the data using Plotext.""" + self.plt.title(self._title) + self.plt.xlabel("Epoch") + if self._log_scale: + self.plt.ylabel("Loss (log scale)") + self.plt.yscale("log") + else: + self.plt.ylabel("Loss") + self.plt.grid(True, True) + + def replot(self) -> None: + """Redraw the plot.""" + self.plt.clear_data() + # The loss lists stay empty until the first update(), but the marker watcher + # can trigger a replot before then, so only validate when there is data. + if ( + self._log_scale + and self._train_losses + and (self._train_losses[-1] <= 0 or self._val_losses[-1] <= 0) + ): + raise ValueError( + "Cannot plot on a log scale if there are non-positive losses." 
) - ) - self._logger = Table.grid(padding=0) - self._logger.add_column(no_wrap=False) - self._layout["side"].update( - Panel( - self._logger, title="Logs", border_style="bright_red", box=box.ROUNDED + if len(self._train_losses) > 0: + assert len(self._val_losses) == len(self._train_losses) + self.plt.plot( + list(range(self._start_epoch, self._epoch + 1)), + self._train_losses, + color="blue", # TODO: Theme + label="Training loss", + marker=self.marker, ) + self.plt.plot( + list(range(self._start_epoch, self._epoch + 1)), + self._val_losses, + color="green", # TODO: Theme + label="Validation loss", + marker=self.marker, + ) + self.refresh() + + def set_start_epoch(self, start_epoch: int): + self._start_epoch = start_epoch + + def update( + self, epoch: int, train_loss: float, val_loss: Optional[float] = None + ) -> None: + """Update the data for the training curves plot. + + Args: + epoch: (int) The current epoch number. + train_loss: (float) The last training loss. + val_loss: (float) The last validation loss. 
+ """ + self._epoch = epoch + self._train_losses.append(train_loss) + self._val_losses.append( + val_loss if val_loss is not None else self._val_losses[-1] ) - self.tasks = { - "training": self._pbar.add_task( - f"[{project_conf.Theme.TRAINING.value}]Training", - visible=False, - loss=torch.inf, - ), - "validation": self._pbar.add_task( - f"[{project_conf.Theme.VALIDATION.value}]Validation", - visible=False, - loss=torch.inf, - ), - "testing": self._pbar.add_task( - f"[{project_conf.Theme.TESTING.value}]Testing", - visible=False, - loss=torch.inf, - ), - } - - @property - def console(self): - return self._console - - def open(self) -> None: - self._live.__enter__() - - def close(self) -> None: - self._live.__exit__(None, None, None) - - def _track_iterable(self, iterable, task, total) -> Tuple[Iterable, Callable]: + self.replot() + + def _watch_marker(self) -> None: + """React to the marker being changed.""" + self.replot() + + +# TODO: Also make a Rich renderable for Tensors (using tables?) 
+ + +class Task(Enum): + IDLE = -1 + TRAINING = 0 + VALIDATION = 1 + TESTING = 2 + + +class DatasetProgressBar(Static): + """A progress bar for PyTorch dataloader iteration.""" + + DESCRIPTIONS = { + Task.IDLE: Text("Waiting for work..."), + Task.TRAINING: Text("Training: ", style="bold blue"), + Task.VALIDATION: Text("Validation: ", style="bold green"), + Task.TESTING: Text("Testing: ", style="bold yellow"), + } + + def compose(self) -> ComposeResult: + # with Horizontal(): + with Center(): + yield Label(self.DESCRIPTIONS[Task.IDLE], id="progress_label") + yield ProgressBar() + + def track_iterable( + self, + iterable: Iterable | Sequence | Iterator | DataLoader, + task: Task, + total: int, + ) -> Tuple[Iterable, Callable]: class LossHook: def __init__(self): self._loss = None - def update_loss_hook(self, loss: float): + def update_loss_hook( + self, loss: float, min_val_loss: Optional[float] = None + ) -> None: + """Update the loss value in the progress bar.""" + # TODO: min_val_loss during validation, val_loss during training. Ideally the + # second parameter would be super flexible (use a dict then). 
self._loss = loss class SeqWrapper(abc.Iterator, LossHook): @@ -226,22 +227,22 @@ def __next__(self): self._reset_hook() raise StopIteration - def update_hook(task_id: TaskID, loss: Optional[float] = None): - self._pbar.advance(task_id) + def update_hook(loss: Optional[float] = None): + self.query_one(ProgressBar).advance() if loss is not None: - self._pbar.tasks[task_id].fields["loss"] = loss - # TODO: Nice progress panel with overall progress and epoch progress - # self._main_progress = Panel(self._pbar, title="Training epoch ?/?") - # self._layout["footer"].update(self._main_progress) - # self._live.refresh() + plabel: Label = self.query_one("#progress_label") # type: ignore + plabel.update(self.DESCRIPTIONS[task] + f"[loss={loss:.4f}]") - def reset_hook(task_id: TaskID, total: int): - self._pbar.reset(task_id, total=total, visible=False) + def reset_hook(total: int): + sleep(0.5) + self.query_one(ProgressBar).update(total=100, progress=0) + plabel: Label = self.query_one("#progress_label") # type: ignore + plabel.update(self.DESCRIPTIONS[Task.IDLE]) wrapper = None update_p, reset_p = ( - partial(update_hook, task_id=task), - partial(reset_hook, task, total), + partial(update_hook), + partial(reset_hook, total), ) if isinstance(iterable, abc.Sequence): wrapper = SeqWrapper( @@ -261,167 +262,199 @@ def reset_hook(task_id: TaskID, total: int): raise ValueError( f"iterable must be a Sequence or an Iterator, got {type(iterable)}" ) - self._pbar.reset(task, total=total, visible=True) + self.query_one(ProgressBar).update(total=total, progress=0) + plabel: Label = self.query_one("#progress_label") # type: ignore + plabel.update(self.DESCRIPTIONS[task]) return wrapper, wrapper.update_loss_hook - def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: - task = self.tasks["training"] - return self._track_iterable(iterable, task, total) - def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: - task = self.tasks["validation"] - 
return self._track_iterable(iterable, task, total) +class GUI(App): + """A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.""" - def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: - task = self.tasks["testing"] - return self._track_iterable(iterable, task, total) + CSS_PATH = "style.css" - def print_header(self, text: str): - self._layout["header"].update(text) + BINDINGS = [ + ("q", "quit", "Quit"), + ("d", "toggle_dark", "Toggle dark mode"), + ("p", "marker", "Change plotter style"), + ("ctrl+z", "suspend_progress"), + ] - def print(self, text: str | Text): - """ - Print text to the side panel. - """ - # TODO: Use a fifo and pop the first item if we can't view the enw item. - try: - self._logger.add_row( - Text(datetime.now().strftime("[%H:%M] "), style="dim cyan"), text - ) - # TODO: Compute the max number of displayable rows. We have access to the height: - # height = self._console.size.height - # But what about the width of the column? Can we get that anywhere? - # Then, we need to compute the width of each row to see if it overflows and by how many - # lines. - # Finally, we'll have the percentage of height that our rows take, and if it's above - # 100%-threshold, we remove the first row. - except rich.errors.NotRenderableError as e: - self.print(Text("[Rich]: " + str(e), style="bold red")) - - def _make_plot( - self, - width, - height, - epoch: int, - train_losses: List[float], - val_losses: List[float], - model_saver: Optional[BestNModelSaver] = None, - ): - """Plot the training and validation losses. - Args: - epoch (int): Current epoch number. - train_losses (List[float]): List of training losses. - val_losses (List[float]): List of validation losses. - Returns: - None - """ - if self._plot_log_scale and any( - loss_val <= 0 for loss_val in train_losses + val_losses - ): - raise ValueError( - "Cannot plot on a log scale if there are non-positive losses." 
+ MARKERS = { + "dot": "Dot", + "hd": "High Definition", + "fhd": "Higher Definition", + "braille": "Braille", + "sd": "Standard Definition", + } + + marker: var[str] = var("hd") + + def __init__(self, run_name: str, log_scale: bool) -> None: + """Initialise the application.""" + super().__init__() + self._markers = cycle(self.MARKERS.keys()) + self._log_scale = log_scale + self.run_name = run_name + + def compose(self) -> ComposeResult: + yield Header() + yield PlotterWidget( + title=f"Training curves for {self.run_name}", + use_log_scale=self._log_scale, + classes="box", + ) + yield RichLog( + highlight=True, markup=True, wrap=True, id="logger", classes="box" + ) + yield DatasetProgressBar() + yield Placeholder(classes="box") + yield Footer() + + def on_mount(self): + self.query_one(PlotterWidget).loading = True + + def action_toggle_dark(self) -> None: + self.dark = not self.dark + + def watch_marker(self) -> None: + """React to the marker type being changed.""" + self.sub_title = self.MARKERS[self.marker] + self.query_one(PlotterWidget).marker = self.marker + + def action_marker(self) -> None: + """Cycle to the next marker type.""" + self.marker = next(self._markers) + + def print(self, message: Any): + logger: RichLog = self.query_one(RichLog) + if isinstance(message, (RenderableType, str)): + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + message, + ), ) - plt.clf() - plt.plotsize(width, height) - plt.title(f"Training curves for {self._run_name}") - plt.xlabel("Epoch") - plt.theme("dark") - if self._plot_log_scale: - plt.ylabel("Loss (log scale)") - plt.yscale("log") else: - plt.ylabel("Loss") - plt.grid(True, True) - - plt.plot( - list(range(self._starting_epoch, epoch + 1)), - train_losses, - color=project_conf.Theme.TRAINING.value, - # color="blue", - label="Training loss", + ppable, pp_msg = True, None + try: + pp_msg = Pretty(message) + except Exception: + ppable = False + if ppable and pp_msg 
is not None: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text(str(type(message)) + " ", style="italic blue", end=""), + pp_msg, + ) + ) + else: + try: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + message, + ), + ) + except Exception as e: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text("Logging error: ", style="bold red"), + Text(str(e), style="bold red"), + ) + ) + + def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: + """Return an iterable that tracks the progress of the training process, and a progress bar + hook to update the loss value at each iteration.""" + return self.query_one(DatasetProgressBar).track_iterable( + iterable, Task.TRAINING, total ) - plt.plot( - list(range(self._starting_epoch, epoch + 1)), - val_losses, - color=project_conf.Theme.VALIDATION.value, - # color="green", - label="Validation loss", + + def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: + """Return an iterable that tracks the progress of the validation process, and a progress bar + hook to update the loss value at each iteration.""" + return self.query_one(DatasetProgressBar).track_iterable( + iterable, Task.VALIDATION, total ) - best_metrics = ( - "[" - + ", ".join( - [ - f"{metric_name}={metric_value:.2e} " - for metric_name, metric_value in model_saver.best_metrics.items() - ] - if model_saver is not None - else [] - ) - + "]" + + def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: + """Return an iterable that tracks the progress of the testing process, and a progress bar + hook to update the loss value at each iteration.""" + return self.query_one(DatasetProgressBar).track_iterable( + iterable, Task.TESTING, total ) - if model_saver is not None: - plt.scatter( - [model_saver.min_val_loss_epoch], - 
[model_saver.min_val_loss], - color="red", - marker="+", - label=f"Best model {best_metrics}", - style="inverted", - ) - return plt.build() def plot( - self, - epoch: int, - train_losses: List[float], - val_losses: List[float], - model_saver: Optional[BestNModelSaver] = None, + self, epoch: int, train_loss: float, val_loss: Optional[float] = None ) -> None: - mk_plot = partial( - self._make_plot, - epoch=epoch, - train_losses=train_losses, - val_losses=val_losses, - model_saver=model_saver, - ) - self._plot = Panel(PlotextMixin(mk_plot), title="Training curves") - self._layout["body"].update(self._plot) - self._live.refresh() + """Plot the training and validation losses for the current epoch.""" + self.query_one(PlotterWidget).loading = False + self.query_one(PlotterWidget).update(epoch, train_loss, val_loss) + def set_start_epoch(self, start_epoch: int) -> None: + """Set the starting epoch for the plotter widget.""" + self.query_one(PlotterWidget).set_start_epoch(start_epoch) -if __name__ == "__main__": + +async def run_my_app(): + gui = GUI("test-run", log_scale=False) + task = asyncio.create_task(gui.run_async()) + while not gui.is_running: + await asyncio.sleep(0.01) # Wait for the app to start up + gui.print("Hello, World!") + await asyncio.sleep(2) + gui.print(Text("Let's log some tensors :)", style="bold magenta")) + await asyncio.sleep(0.5) + gui.print(torch.rand(2, 4)) + await asyncio.sleep(2) + gui.print(Text("How about some numpy arrays?!", style="italic green")) + await asyncio.sleep(1) + gui.print(np.random.rand(3, 3)) + pbar, update_progress_loss = gui.track_training(range(10), 10) + for i, e in enumerate(pbar): + gui.print(f"[{i+1}/10]: We can iterate over iterables") + gui.print(e) + await asyncio.sleep(0.1) + await asyncio.sleep(2) mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) - dataloader = DataLoader(mnist, 32, shuffle=True) - gui = GUI("test-run", plot_log_scale=False) - gui.open() # TODO: Use a context manager, why not?? 
- try: - gui.print("Hello, world!") - gui.print( - "Veeeeeeeeeeeeeeeeeeeryyyyyyyyyyyyyyyy looooooooooooooooooooooooooooong seeeeeeeeeeeenteeeeeeeeeeeeeeennnnnnnnnce!!!!!!!!!!!!!!!!!" - ) - pbar, update_progress_loss = gui.track_training(range(10), 10) - for i, e in enumerate(pbar): - gui.print(f"[{i}/10]: We can iterate over iterables") - sleep(0.1) - train_losses, val_losses = [], [] - pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) - for i, e in enumerate(pbar): - # gui.print(e) # TODO: Make this work! - if i % 10 == 0: - train_losses.append(random.random()) - val_losses.append(random.random()) - update_progress_loss(random.random()) - gui.plot(epoch=i, train_losses=train_losses, val_losses=val_losses) - gui.print( - f"[{i}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" - ) - if i == 0: - gui.print(e) - sleep(0.01) - gui.print("Goodbye, world!") - sleep(1) - except Exception as e: - gui.close() - raise e - finally: - gui.close() + # Somehow, the dataloader will crash if it's not forked when using multiprocessing along with + # Textual. + mp.set_start_method("fork") + dataloader = DataLoader(mnist, 32, shuffle=True, num_workers=2) + pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) + for i, batch in enumerate(pbar): + await asyncio.sleep(0.01) + if i % 10 == 0: + gui.print(batch) + update_progress_loss(random()) + gui.plot(epoch=i, train_loss=random(), val_loss=random()) + gui.print( + f"[{i+1}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" + ) + if i == 0: + gui.print(batch) + gui.print("Goodbye, world!") + _ = await task + + +if __name__ == "__main__": + asyncio.run(run_my_app()) diff --git a/utils/gui2.py b/utils/gui2.py deleted file mode 100644 index b7d2637..0000000 --- a/utils/gui2.py +++ /dev/null @@ -1,471 +0,0 @@ -# Let's use Textual to rewrite the GUI with better features. 
- -import asyncio -from collections import abc -from datetime import datetime -from enum import Enum -from functools import partial -from itertools import cycle -from random import randint, random -from time import sleep -from typing import ( - Callable, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, -) - -import numpy as np -import torch -from rich.console import Group, RenderableType -from rich.pretty import Pretty -from rich.text import Text -from textual.app import App, ComposeResult, RenderResult -from textual.containers import Center -from textual.reactive import var -from textual.widgets import ( - Footer, - Header, - Label, - Placeholder, - ProgressBar, - RichLog, - Static, -) -from textual_plotext import PlotextPlot -from torch.utils.data.dataloader import DataLoader -from torchvision.datasets import MNIST -from torchvision.transforms.functional import to_tensor - - -class PlotterWidget(PlotextPlot): - marker: var[str] = var("sd") - - """The type of marker to use for the plot.""" - - def __init__( - self, - title: str, - *, - name: str | None = None, - id: str | None = None, # pylint:disable=redefined-builtin - classes: str | None = None, - disabled: bool = False, - ) -> None: - """Initialise the training curves plotter widget. - - Args: - name: The name of the plotter widget. - id: The ID of the plotter widget in the DOM. - classes: The CSS classes of the plotter widget. - disabled: Whether the plotter widget is disabled or not. 
- """ - super().__init__(name=name, id=id, classes=classes, disabled=disabled) - self._title = title - self._train_losses: list[float] = [] - self._val_losses: list[float] = [] - self._epoch = 0 - - def on_mount(self) -> None: - """Plot the data using Plotext.""" - self.plt.title(self._title) - self.plt.xlabel("Epoch") - self.plt.ylabel("Loss") # TODO: update - self.plt.grid(True, True) - - def replot(self) -> None: - """Redraw the plot.""" - self.plt.clear_data() - if len(self._train_losses) > 0: - assert len(self._val_losses) == len(self._train_losses) - self.plt.plot( - list(range(0, self._epoch + 1)), # TODO: start epoch - self._train_losses, - # color=project_conf.Theme.TRAINING.value, # TODO: - color="blue", - label="Training loss", - marker=self.marker, - ) - self.plt.plot( - list(range(0, self._epoch + 1)), # TODO: start epoch - self._val_losses, - # color=project_conf.Theme.VALIDATION.value, # TODO: - color="green", - label="Validation loss", - marker=self.marker, - ) - self.refresh() - - def update( - self, epoch: int, train_losses: List[float], val_losses: List[float] - ) -> None: - """Update the data for the training curves plot. - - Args: - epoch: (int) The current epoch number. - train_losses: (List[float]) The list of training losses. - val_losses: (List[float]) The list of validation losses. - """ - # TODO: We only need to append to the losses. Do we need to keep track of them in the - # trianing loop? If not we should only keep track of the last one and let the GUI keep - # track of them all. 
- self._epoch = epoch - self._train_losses = train_losses - self._val_losses = val_losses - self.replot() - - def _watch_marker(self) -> None: - """React to the marker being changed.""" - self.replot() - - -# TODO: Also make a Rich renderable for Tensors - - -class TensorWidget(Static): # TODO: PRETTYER - def __init__(self, tensor: torch.Tensor): - self.tensor = tensor - - def render(self) -> RenderResult: - return Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - Pretty(self.tensor), - ) - - -class Task(Enum): - IDLE = -1 - TRAINING = 0 - VALIDATION = 1 - TESTING = 2 - - -class DatasetProgressBar(Static): - """A progress bar for PyTorch dataloader iteration.""" - - DESCRIPTIONS = { - Task.IDLE: Text("Waiting for work..."), - Task.TRAINING: Text("Training: ", style="bold blue"), - Task.VALIDATION: Text("Validation: ", style="bold green"), - Task.TESTING: Text("Testing: ", style="bold yellow"), - } - - # def __init__(self): - # self.description = None - # self.total = None - # self.progress = 0 - - def compose(self) -> ComposeResult: - # with Horizontal(): - with Center(): - yield Label(self.DESCRIPTIONS[Task.IDLE], id="progress_label") - yield ProgressBar() - - def track_iterable( - self, - iterable: Iterable | Sequence | Iterator | DataLoader, - task: Task, - total: int, - ) -> Tuple[Iterable, Callable]: - class LossHook: - def __init__(self): - self._loss = None - - def update_loss_hook(self, loss: float): - self._loss = loss - - class SeqWrapper(abc.Iterator, LossHook): - def __init__( - self, - seq: Sequence, - len: int, - update_hook: Callable, - reset_hook: Callable, - ): - super().__init__() - self._sequence = seq - self._idx = 0 - self._len = len - self._update_hook = update_hook - self._reset_hook = reset_hook - - def __next__(self): - if self._idx >= self._len: - self._reset_hook() - raise StopIteration - item = self._sequence[self._idx] - self._update_hook(loss=self._loss) - self._idx += 1 - return item - - class 
IteratorWrapper(abc.Iterator, LossHook): - def __init__( - self, - iterator: Iterator | DataLoader, - len: int, - update_hook: Callable, - reset_hook: Callable, - ): - super().__init__() - self._iterator = iter(iterator) - self._len = len - self._update_hook = update_hook - self._reset_hook = reset_hook - - def __next__(self): - try: - item = next(self._iterator) - self._update_hook(loss=self._loss) - return item - except StopIteration: - self._reset_hook() - raise StopIteration - - def update_hook(loss: Optional[float] = None): - self.query_one(ProgressBar).advance() - if loss is not None: - self.query_one("#progress_label").update( - self.DESCRIPTIONS[task] + f"[loss={loss:.4f}]" - ) - - def reset_hook(total: int): - sleep(0.5) - self.query_one(ProgressBar).update(total=100, progress=0) - self.query_one("#progress_label").update(self.DESCRIPTIONS[Task.IDLE]) - - wrapper = None - update_p, reset_p = ( - partial(update_hook), - partial(reset_hook, total), - ) - if isinstance(iterable, abc.Sequence): - wrapper = SeqWrapper( - iterable, - total, - update_p, - reset_p, - ) - elif isinstance(iterable, (abc.Iterator, DataLoader)): - wrapper = IteratorWrapper( - iterable, - total, - update_p, - reset_p, - ) - else: - raise ValueError( - f"iterable must be a Sequence or an Iterator, got {type(iterable)}" - ) - self.query_one(ProgressBar).update(total=total, progress=0) - self.query_one("#progress_label").update(self.DESCRIPTIONS[task]) - return wrapper, wrapper.update_loss_hook - - -class GUI(App): - """A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.""" - - CSS_PATH = "style.css" - - BINDINGS = [ - ("q", "quit", "Quit"), - ("d", "toggle_dark", "Toggle dark mode"), - ("m", "marker", "Cycle example markers"), - ("ctrl+z", "suspend_progress"), - ] - - MARKERS = { - "dot": "Dot", - "hd": "High Definition", - "fhd": "Higher Definition", - "braille": "Braille", - "sd": "Standard Definition", - } - - marker: var[str] = var("sd") - - def 
__init__(self) -> None: - """Initialise the application.""" - super().__init__() - self._markers = cycle(self.MARKERS.keys()) - - def compose(self) -> ComposeResult: - yield Header() - yield PlotterWidget(title="Trainign curves for run-name", classes="box") - yield RichLog( - highlight=True, markup=True, wrap=True, id="logger", classes="box" - ) - # yield Placeholder(classes="box") - yield DatasetProgressBar() - yield Placeholder(classes="box") - yield Footer() - - def action_toggle_dark(self) -> None: - self.dark = not self.dark - - def watch_marker(self) -> None: - """React to the marker type being changed.""" - self.sub_title = self.MARKERS[self.marker] - self.query_one(PlotterWidget).marker = self.marker - - def action_marker(self) -> None: - """Cycle to the next marker type.""" - self.marker = next(self._markers) - - def on_key(self, event) -> None: - logger: RichLog = self.query_one(RichLog) - logger.write( - Group( - Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""), - f"Key pressed: {event.key!r}", - ), - ) - if event.key == "t": - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - Pretty(torch.rand(randint(1, 12), randint(1, 12))), - ) - ) - elif event.key == "p": - self.query_one(PlotterWidget).update( - epoch=9, - train_losses=[random() for _ in range(10)], - val_losses=[random() for _ in range(10)], - ) - - def print(self, message: RenderableType | str | torch.Tensor | np.ndarray): - logger: RichLog = self.query_one(RichLog) - if isinstance(message, (RenderableType, str)): - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - message, - ), - ) - else: - ppable, pp_msg = True, None - try: - pp_msg = Pretty(message) - except Exception: - ppable = False - if ppable and pp_msg is not None: - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - Text(str(type(message)) + " ", 
style="italic blue", end=""), - pp_msg, - ) - ) - else: - try: - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - message, - ), - ) - except Exception as e: - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - Text("Logging error: ", style="bold red"), - Text(str(e), style="bold red"), - ) - ) - - def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: - return self.query_one(DatasetProgressBar).track_iterable( - iterable, Task.TRAINING, total - ) - - def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: - return self.query_one(DatasetProgressBar).track_iterable( - iterable, Task.VALIDATION, total - ) - - def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: - return self.query_one(DatasetProgressBar).track_iterable( - iterable, Task.TESTING, total - ) - - def plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): - self.query_one(PlotterWidget).update(epoch, train_losses, val_losses) - - -async def run_my_app(): - gui = GUI() - task = asyncio.create_task(gui.run_async()) - await asyncio.sleep(0.1) # Wait for the app to start up - gui.print("Hello, World!") - # await asyncio.sleep(2) - # gui.print(Text("Let's log some tensors :)", style="bold magenta")) - # await asyncio.sleep(0.5) - # gui.print(torch.rand(2, 4)) - # await asyncio.sleep(2) - # gui.print(Text("How about some numpy arrays?!", style="italic green")) - # await asyncio.sleep(1) - # gui.print(np.random.rand(3, 3)) - # await asyncio.sleep(3) - # gui.print("...") - # await asyncio.sleep(3) - # gui.print("Go on... Press 'p'! 
I know you want to!") - # await asyncio.sleep(4) - # gui.print("COME ON PRESS P!!!!") - # await asyncio.sleep(1) - # pbar, update_progress_loss = gui.track_training(range(10), 10) - # for i, e in enumerate(pbar): - # gui.print(f"[{i+1}/10]: We can iterate over iterables") - # gui.print(e) - # # sleep(0.1) - # await asyncio.sleep(0.1) - # await asyncio.sleep(5) - mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) - dataloader = DataLoader(mnist, 32, shuffle=True) - train_losses, val_losses = [], [] - pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) - for i, batch in enumerate(pbar): - if i % 10 == 0: - await asyncio.sleep(0.01) - gui.print(batch) - train_losses.append(random()) - val_losses.append(random()) - update_progress_loss(random()) - gui.plot(epoch=i, train_losses=train_losses, val_losses=val_losses) - gui.print( - f"[{i+1}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" - ) - if i == 0: - gui.print(batch) - gui.print("Goodbye, world!") - _ = await task - - -if __name__ == "__main__": - asyncio.run(run_my_app()) From b7202184ba0a551d56d01c4adf1eb3652c3b6f16 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 14:01:32 +0100 Subject: [PATCH 27/38] Fix typos and linter issues --- src/base_trainer.py | 2 +- utils/gui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/base_trainer.py b/src/base_trainer.py index c1ddb43..7a64a89 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -299,7 +299,7 @@ async def train( if self._scheduler is not None: await asyncio.to_thread(self._scheduler.step) """ ==================== Plotting ==================== """ - self._gui.plot(epoch, train_loss, val_loss) # , self._model_saver) + self._gui.plot(epoch, train_loss, last_val_loss) # , self._model_saver) await asyncio.to_thread( self._save_checkpoint, last_val_loss, diff --git a/utils/gui.py b/utils/gui.py index 8f93c58..d79fd04 100644 --- a/utils/gui.py +++ 
b/utils/gui.py @@ -412,7 +412,7 @@ def plot( def set_start_epoch(self, start_epoch: int) -> None: """Set the starting epoch for the plotter widget.""" - sef = self.query_one(PlotterWidget).set_start_epoch + self.query_one(PlotterWidget).set_start_epoch(start_epoch) async def run_my_app(): From ea3fe6f468d3c58b1783e68fefa841874c0ed33e Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 14:09:21 +0100 Subject: [PATCH 28/38] Set the TUI title --- bootstrap/launch_experiment.py | 2 +- utils/gui.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index 96e5311..cb0724d 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -154,7 +154,7 @@ def launch_experiment( """ ============ Training ============ """ console.print( - "Launching GUI...", + "Launching TUI...", style="bold cyan", ) sleep(1) diff --git a/utils/gui.py b/utils/gui.py index d79fd04..121cee3 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -271,6 +271,7 @@ def reset_hook(total: int): class GUI(App): """A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.""" + TITLE = "Matchbox TUI" CSS_PATH = "style.css" BINDINGS = [ From d89f7bffa245da30e8aa43ddd186001b69fe7cb8 Mon Sep 17 00:00:00 2001 From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com> Date: Sat, 27 Jul 2024 14:26:11 +0000 Subject: [PATCH 29/38] refactor: autofix issues in 6 files Resolved issues in the following files with DeepSource Autofix: 1. conf/experiment.py 2. dataset/mixins/__init__.py 3. src/base_tester.py 4. src/base_trainer.py 5. utils/gui.py 6. 
utils/__init__.py --- conf/experiment.py | 7 ------- dataset/mixins/__init__.py | 13 ++----------- src/base_tester.py | 3 --- src/base_trainer.py | 4 ---- utils/__init__.py | 2 -- utils/gui.py | 2 -- 6 files changed, 2 insertions(+), 29 deletions(-) diff --git a/conf/experiment.py b/conf/experiment.py index 858a0cf..b74c02b 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -1,9 +1,3 @@ -#! /usr/bin/env python3 -# vim:fenc=utf-8 -# -# Copyright © 2023 Théo Morales -# -# Distributed under terms of the MIT license. """ Configurations for the experiments and config groups, using hydra-zen. @@ -270,7 +264,6 @@ def make_experiment_configs(): {"override /model": "model_a"}, {"override /dataset": "image_a"}, ], - # training=dict(epochs=100), bases=(Experiment,), ), name="exp_a", diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 7af001f..968c84a 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -1,9 +1,3 @@ -#! /usr/bin/env python3 -# vim:fenc=utf-8 -# -# Copyright © 2023 Théo Morales -# -# Distributed under terms of the MIT license. """ Base dataset. @@ -415,7 +409,6 @@ def __init__( pool_dispatch_func, zip( raw_elements, - # itertools.count(len(raw_elements)), range(len(raw_elements)), itertools.repeat(tiny), itertools.repeat(split), @@ -433,8 +426,7 @@ def _get_raw_elements_hook( ): if hasattr(super(), "_get_raw_elements_hook"): return super()._get_raw_elements_hook(dataset_root, tiny, split, seed) # type: ignore - else: - return self._get_raw_elements(dataset_root, tiny, split, seed) + return self._get_raw_elements(dataset_root, tiny, split, seed) def _load_hook_unpack(self, args): return self._load_hook(*args) @@ -442,8 +434,7 @@ def _load_hook_unpack(self, args): def _load_hook(self, *args) -> Tuple[int, Any, Any]: if hasattr(super(), "_load_hook"): return super()._load_hook(*args) # type: ignore - else: - return self._load(*args) # TODO: Rename to _load_sample? 
+ return self._load(*args) # TODO: Rename to _load_sample? def _load_sample_label(self, idx: int) -> Tuple[Any, Any]: if hasattr(super(), "_load_sample_label"): diff --git a/src/base_tester.py b/src/base_tester.py index de87233..7f80761 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -30,8 +30,6 @@ console = Console() - -global print print = console.print @@ -67,7 +65,6 @@ def __init__( self._load_checkpoint(model_ckpt_path, model_only=True) self._data_loader = data_loader self._running = True - # self._pbar = tqdm(total=len(self._data_loader), desc="Testing") signal.signal(signal.SIGINT, self._terminator) @to_cuda diff --git a/src/base_trainer.py b/src/base_trainer.py index 7a64a89..a0dccb4 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -34,8 +34,6 @@ from utils.training import visualize_model_predictions console = Console() - -global print print = console.print @@ -137,7 +135,6 @@ def _train_epoch( """ epoch_loss: MeanMetric = MeanMetric() epoch_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) - # color_code = project_conf.ANSI_COLORS[project_conf.Theme.TRAINING.value] has_visualized = 0 """ ==================== Training loop for one epoch ==================== """ pbar, update_loss_hook = self._gui.track_training( @@ -191,7 +188,6 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: float: Average validation loss for the epoch. 
""" has_visualized = 0 - # color_code = project_conf.ANSI_COLORS[project_conf.Theme.VALIDATION.value] """ ==================== Validation loop for one epoch ==================== """ with torch.no_grad(): val_loss: MeanMetric = MeanMetric() diff --git a/utils/__init__.py b/utils/__init__.py index a650dfa..5564d43 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -13,7 +13,6 @@ # import sys import traceback -from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Tuple, Union # import IPython @@ -21,7 +20,6 @@ import torch from hydra.utils import to_absolute_path from torch import Tensor, nn -from tqdm import tqdm from conf import project as project_conf diff --git a/utils/gui.py b/utils/gui.py index d79fd04..071f6c3 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -1,4 +1,3 @@ -# Let's use Textual to rewrite the GUI with better features. import asyncio from collections import abc @@ -157,7 +156,6 @@ class DatasetProgressBar(Static): } def compose(self) -> ComposeResult: - # with Horizontal(): with Center(): yield Label(self.DESCRIPTIONS[Task.IDLE], id="progress_label") yield ProgressBar() From 20875de24bc92cdae3e0b885a6d675fc96e26dd0 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 15:39:37 +0100 Subject: [PATCH 30/38] Fix deepsource issues --- bootstrap/launch_experiment.py | 4 ++-- dataset/mixins/__init__.py | 6 ++++-- src/base_tester.py | 2 +- src/base_trainer.py | 2 +- utils/gui.py | 21 +++++++++------------ 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index cb0724d..7db7fd6 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -128,7 +128,7 @@ def launch_experiment( ) print_config(run_name, exp_conf) - """ ============ Partials instantiation ============ """ + # ============ Partials instantiation ============ model_inst = make_model(model, dataset) print_model(model_inst) train_dataset, 
val_dataset, test_dataset = make_datasets( @@ -152,7 +152,7 @@ def launch_experiment( ) init_wandb(run_name, model_inst, exp_conf) - """ ============ Training ============ """ + # ============ Training ============ console.print( "Launching TUI...", style="bold cyan", diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 968c84a..97f064a 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -1,4 +1,3 @@ - """ Base dataset. In this file you may implement other base datasets that share the same characteristics and which @@ -147,7 +146,10 @@ def __init__( # TODO: Recursively hash the source code for user's methods in self.__class__ # NOTE: getsource() won't work if I have a decorator that wraps the method. I think it's # best to keep this behaviour and not use decorators. - fingerprint_els = {"code": hashlib.new("md5"), "args": hashlib.new("md5")} + fingerprint_els = { + "code": hashlib.new("md5", usedforsecurity=False), + "args": hashlib.new("md5", usedforsecurity=False), + } tree = ast.parse(inspect.getsource(self.__class__)) fingerprint_els["code"].update(ast.dump(tree).encode()) fingerprint_els["args"].update(pickle.dumps(argvalues)) diff --git a/src/base_tester.py b/src/base_tester.py index 7f80761..2b0fb61 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -54,7 +54,7 @@ def __init__( _args = kwargs _loss = training_loss self._gui = gui - global print + global print # skipcq: PYL-W0603 print = self._gui.print self._run_name = run_name self._model = model diff --git a/src/base_trainer.py b/src/base_trainer.py index a0dccb4..4117be2 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -76,7 +76,7 @@ def __init__( self._viz_n_samples = 1 self._n_ctrl_c = 0 self._gui = gui - global print + global print # skipcq: PYL-W0603 print = self._gui.print if model_ckpt_path is not None: self._load_checkpoint(model_ckpt_path) diff --git a/utils/gui.py b/utils/gui.py index 05ca134..39a6a83 100644 --- a/utils/gui.py 
+++ b/utils/gui.py @@ -1,4 +1,3 @@ - import asyncio from collections import abc from datetime import datetime @@ -170,9 +169,7 @@ class LossHook: def __init__(self): self._loss = None - def update_loss_hook( - self, loss: float, min_val_loss: Optional[float] = None - ) -> None: + def update_loss_hook(self, loss: float) -> None: """Update the loss value in the progress bar.""" # TODO: min_val_loss during validation, val_loss during training. Ideally the # second parameter would be super flexible (use a dict then). @@ -182,14 +179,14 @@ class SeqWrapper(abc.Iterator, LossHook): def __init__( self, seq: Sequence, - len: int, + length: int, update_hook: Callable, reset_hook: Callable, ): super().__init__() self._sequence = seq self._idx = 0 - self._len = len + self._len = length self._update_hook = update_hook self._reset_hook = reset_hook @@ -206,13 +203,13 @@ class IteratorWrapper(abc.Iterator, LossHook): def __init__( self, iterator: Iterator | DataLoader, - len: int, + length: int, update_hook: Callable, reset_hook: Callable, ): super().__init__() self._iterator = iter(iterator) - self._len = len + self._len = length self._update_hook = update_hook self._reset_hook = reset_hook @@ -231,7 +228,7 @@ def update_hook(loss: Optional[float] = None): plabel: Label = self.query_one("#progress_label") # type: ignore plabel.update(self.DESCRIPTIONS[task] + f"[loss={loss:.4f}]") - def reset_hook(total: int): + def reset_hook(): sleep(0.5) self.query_one(ProgressBar).update(total=100, progress=0) plabel: Label = self.query_one("#progress_label") # type: ignore @@ -240,7 +237,7 @@ def reset_hook(total: int): wrapper = None update_p, reset_p = ( partial(update_hook), - partial(reset_hook, total), + partial(reset_hook), ) if isinstance(iterable, abc.Sequence): wrapper = SeqWrapper( @@ -314,11 +311,11 @@ def on_mount(self): self.query_one(PlotterWidget).loading = True def action_toggle_dark(self) -> None: - self.dark = not self.dark + self.dark = not self.dark # skipcq: PYL-W0201 
def watch_marker(self) -> None: """React to the marker type being changed.""" - self.sub_title = self.MARKERS[self.marker] + self.sub_title = self.MARKERS[self.marker] # skipcq: PYL-W0201 self.query_one(PlotterWidget).marker = self.marker def action_marker(self) -> None: From 8c01ae7d6dcd19e4aff4dd21d0d09be91ea451e1 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 17:27:56 +0100 Subject: [PATCH 31/38] Format code and comments --- conf/experiment.py | 1 - dataset/example.py | 6 ++-- dataset/mixins/__init__.py | 1 + pyproject.toml | 6 ++-- src/base_tester.py | 15 +++++---- src/base_trainer.py | 40 ++++++++++++----------- utils/__init__.py | 2 +- utils/gui.py | 65 ++++++++++++++++++-------------------- 8 files changed, 69 insertions(+), 67 deletions(-) diff --git a/conf/experiment.py b/conf/experiment.py index b74c02b..b294cbb 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -1,4 +1,3 @@ - """ Configurations for the experiments and config groups, using hydra-zen. """ diff --git a/dataset/example.py b/dataset/example.py index 1c8e622..b67eba5 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -61,9 +61,9 @@ def _load( progress: Progress, job_id: TaskID, ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: - len = 3 if self._tiny else 20 - progress.update(job_id, total=len) - for _ in range(len): + length = 3 if self._tiny else 20 + progress.update(job_id, total=length) + for _ in range(length): progress.advance(job_id) sleep(0.001 if self._tiny else 0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 97f064a..ff92bba 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -123,6 +123,7 @@ def __init__( ) self._split = split self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) + # TODO: Refactor and reduce cyclomatic complexity argnames = 
inspect.getfullargspec(self.__class__.__init__).args found = False frame: FrameType | None = inspect.currentframe() diff --git a/pyproject.toml b/pyproject.toml index 615a051..0ad2b90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ exclude = [ line-length = 88 indent-width = 4 -# Assume Python 3.8 -target-version = "py38" +# Assume Python 3.11 +target-version = "py311" [tool.ruff.lint] # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. @@ -69,7 +69,7 @@ line-ending = "auto" # # This is currently disabled by default, but it is planned for this # to be opt-out in the future. -docstring-code-format = false +docstring-code-format = true # Set the line length limit used when formatting code snippets in # docstrings. diff --git a/src/base_tester.py b/src/base_tester.py index 2b0fb61..9e582c2 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -30,7 +30,7 @@ console = Console() -print = console.print +print = console.print # skipcq: PYL-W0603 class BaseTester(BaseTrainer): @@ -55,7 +55,7 @@ def __init__( _loss = training_loss self._gui = gui global print # skipcq: PYL-W0603 - print = self._gui.print + print = self._gui.print # skipcq: PYL-W0603 self._run_name = run_name self._model = model if model_ckpt_path is None: @@ -83,7 +83,8 @@ def _test_iteration( batch: Union[Tuple, List, Tensor], ) -> Tuple[Tensor, Dict[str, Tensor]]: """Evaluation procedure for one batch. We want to keep the code DRY and avoid - making mistakes, so this code calls the BaseTrainer._train_val_iteration() method. + making mistakes, so this code calls the BaseTrainer._train_val_iteration() + method. Args: batch: The batch to process. Returns: @@ -99,7 +100,8 @@ async def test( ) -> None: """Computes the average loss on the test set. Args: - visualize_every (int, optional): Visualize the model predictions every n batches. + visualize_every (int, optional): Visualize the model predictions every n + batches. Defaults to 0 (no visualization). 
""" _ = kwargs @@ -107,7 +109,7 @@ async def test( test_metrics: Dict[str, MeanMetric] = defaultdict(MeanMetric) self._model.eval() print(Text(f"[*] Testing {self._run_name}", style="bold green")) - """ ==================== Training loop for one epoch ==================== """ + # ==================== Training loop for one epoch ==================== pbar, update_loss_hook = self._gui.track_testing( self._data_loader, total=len(self._data_loader) ) @@ -124,7 +126,8 @@ async def test( if visualize_every > 0 and (i + 1) % visualize_every == 0: self._visualize(batch, i) - # TODO: Report metrics in a special panel? Then hang the GUI until the user is done. + # TODO: Report metrics in a special panel? Then hang the GUI until the user is + # done. print("=" * 81) print("==" + " " * 31 + " Test results " + " " * 31 + "==") print("=" * 81) diff --git a/src/base_trainer.py b/src/base_trainer.py index 4117be2..2288b5e 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -34,7 +34,7 @@ from utils.training import visualize_model_predictions console = Console() -print = console.print +print = console.print # skipcq: PYL-W0603 class BaseTrainer: @@ -77,7 +77,7 @@ def __init__( self._n_ctrl_c = 0 self._gui = gui global print # skipcq: PYL-W0603 - print = self._gui.print + print = self._gui.print # skipcq: PYL-W0603 if model_ckpt_path is not None: self._load_checkpoint(model_ckpt_path) signal.signal(signal.SIGINT, self._terminator) @@ -104,8 +104,9 @@ def _train_val_iteration( epoch: int, validation: bool = False, ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Training or validation procedure for one batch. We want to keep the code DRY and avoid - making mistakes, so write this code only once at the cost of many function calls! + """Training or validation procedure for one batch. We want to keep the code DRY + and avoid making mistakes, so write this code only once at the cost of many + function calls! Args: batch: The batch to process. 
Returns: @@ -136,7 +137,7 @@ def _train_epoch( epoch_loss: MeanMetric = MeanMetric() epoch_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) has_visualized = 0 - """ ==================== Training loop for one epoch ==================== """ + # ==================== Training loop for one epoch ==================== pbar, update_loss_hook = self._gui.track_training( self._train_loader, total=len(self._train_loader), @@ -188,7 +189,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: float: Average validation loss for the epoch. """ has_visualized = 0 - """ ==================== Validation loop for one epoch ==================== """ + # ==================== Validation loop for one epoch ==================== with torch.no_grad(): val_loss: MeanMetric = MeanMetric() val_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) @@ -212,7 +213,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: for k, v in loss_components.items(): val_loss_components[k].update(v.item()) update_loss_hook(val_loss.compute()) - """ ==================== Visualization ==================== """ + # ==================== Visualization ==================== if ( visualize and has_visualized < self._viz_n_samples @@ -235,8 +236,8 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: }, step=epoch, ) - # Set minimize_metric to a key in val_loss_components if you wish to minimize - # a specific metric instead of the validation loss: + # Set minimize_metric to a key in val_loss_components if you wish to + # minimize a specific metric instead of the validation loss: self._model_saver( epoch, mean_val_loss, @@ -268,7 +269,7 @@ async def train( ) self._viz_n_samples = visualize_n_samples self._gui.set_start_epoch(self._epoch) - """ ==================== Training loop ==================== """ + # ==================== Training loop ==================== last_val_loss = float("inf") for epoch in 
range(self._epoch, epochs): print(f"Epoch: {epoch}") @@ -294,7 +295,7 @@ async def train( last_val_loss = val_loss if self._scheduler is not None: await asyncio.to_thread(self._scheduler.step) - """ ==================== Plotting ==================== """ + # ==================== Plotting ==================== self._gui.plot(epoch, train_loss, last_val_loss) # , self._model_saver) await asyncio.to_thread( self._save_checkpoint, @@ -333,18 +334,19 @@ def _save_checkpoint(self, val_loss: float, ckpt_path: str, **kwargs) -> None: ) def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None: - """Loads the model and optimizer state from a checkpoint file. This method should remain in - this class because it should be extendable in classes inheriting from this class, instead - of being overwritten/modified. That would be a source of bugs and a bad practice. + """Loads the model and optimizer state from a checkpoint file. This method + should remain in this class because it should be extendable in classes + inheriting from this class, instead of being overwritten/modified. That would be + a source of bugs and a bad practice. Args: ckpt_path (str): The path to the checkpoint file. - model_only (bool): If True, only the model is loaded (useful for BaseTester). + model_only (bool): If True, only the model is loaded (useful for + BaseTester). Returns: None """ print(f"[*] Restoring from checkpoint: {ckpt_path}") ckpt = torch.load(ckpt_path) - # If the model was optimized with torch.optimize() we need to remove the "_orig_mod" # prefix: if "_orig_mod" in list(ckpt["model_ckpt"].keys())[0]: ckpt["model_ckpt"] = { @@ -355,7 +357,8 @@ def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None: except Exception: if project_conf.PARTIALLY_LOAD_MODEL_IF_NO_FULL_MATCH: print( - "[!] Partially loading model weights (no full match between model and checkpoint)" + "[!] 
Partially loading model weights " + + "(no full match between model and checkpoint)" ) self._model.load_state_dict(ckpt["model_ckpt"], strict=False) if not model_only: @@ -378,7 +381,8 @@ def _terminator(self, sig, frame): and self._n_ctrl_c == 0 ): print( - f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}. Press Ctrl+C again to abort." + f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}." + + " Press Ctrl+C again to abort." ) self._n_ctrl_c += 1 elif ( diff --git a/utils/__init__.py b/utils/__init__.py index 5564d43..933b4b4 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -13,7 +13,7 @@ # import sys import traceback -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional # import IPython import numpy as np diff --git a/utils/gui.py b/utils/gui.py index 39a6a83..0daceea 100644 --- a/utils/gui.py +++ b/utils/gui.py @@ -63,7 +63,12 @@ def __init__( classes: The CSS classes of the plotter widget. disabled: Whether the plotter widget is disabled or not. """ - super().__init__(name=name, id=id, classes=classes, disabled=disabled) + super().__init__( + name=name, + id=id, + classes=classes, + disabled=disabled, + ) self._title = title self._log_scale = use_log_scale self._train_losses: list[float] = [] @@ -171,8 +176,9 @@ def __init__(self): def update_loss_hook(self, loss: float) -> None: """Update the loss value in the progress bar.""" - # TODO: min_val_loss during validation, val_loss during training. Ideally the - # second parameter would be super flexible (use a dict then). + # TODO: min_val_loss during validation, val_loss during training. + # Ideally the second parameter would be super flexible (use a dict + # then). 
self._loss = loss class SeqWrapper(abc.Iterator, LossHook): @@ -235,24 +241,11 @@ def reset_hook(): plabel.update(self.DESCRIPTIONS[Task.IDLE]) wrapper = None - update_p, reset_p = ( - partial(update_hook), - partial(reset_hook), - ) + update_p, reset_p = partial(update_hook), partial(reset_hook) if isinstance(iterable, abc.Sequence): - wrapper = SeqWrapper( - iterable, - total, - update_p, - reset_p, - ) + wrapper = SeqWrapper(iterable, total, update_p, reset_p) elif isinstance(iterable, (abc.Iterator, DataLoader)): - wrapper = IteratorWrapper( - iterable, - total, - update_p, - reset_p, - ) + wrapper = IteratorWrapper(iterable, total, update_p, reset_p) else: raise ValueError( f"iterable must be a Sequence or an Iterator, got {type(iterable)}" @@ -264,7 +257,9 @@ def reset_hook(): class GUI(App): - """A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.""" + """ + A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework. + """ TITLE = "Matchbox TUI" CSS_PATH = "style.css" @@ -327,11 +322,7 @@ def print(self, message: Any): if isinstance(message, (RenderableType, str)): logger.write( Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), + Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""), message, ), ) @@ -379,28 +370,31 @@ def print(self, message: Any): ) def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: - """Return an iterable that tracks the progress of the training process, and a progress bar - hook to update the loss value at each iteration.""" + """Return an iterable that tracks the progress of the training process, and a + progress bar hook to update the loss value at each iteration.""" return self.query_one(DatasetProgressBar).track_iterable( iterable, Task.TRAINING, total ) def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]: - """Return an iterable that tracks the progress of the validation process, 
and a progress bar - hook to update the loss value at each iteration.""" + """Return an iterable that tracks the progress of the validation process, and a + progress bar hook to update the loss value at each iteration.""" return self.query_one(DatasetProgressBar).track_iterable( iterable, Task.VALIDATION, total ) def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]: - """Return an iterable that tracks the progress of the testing process, and a progress bar - hook to update the loss value at each iteration.""" + """Return an iterable that tracks the progress of the testing process, and a + progress bar hook to update the loss value at each iteration.""" return self.query_one(DatasetProgressBar).track_iterable( iterable, Task.TESTING, total ) def plot( - self, epoch: int, train_loss: float, val_loss: Optional[float] = None + self, + epoch: int, + train_loss: float, + val_loss: Optional[float] = None, ) -> None: """Plot the training and validation losses for the current epoch.""" self.query_one(PlotterWidget).loading = False @@ -432,8 +426,8 @@ async def run_my_app(): await asyncio.sleep(0.1) await asyncio.sleep(2) mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) - # Somehow, the dataloader will crash if it's not forked when using multiprocessing along with - # Textual. + # Somehow, the dataloader will crash if it's not forked when using multiprocessing + # along with Textual. mp.set_start_method("fork") dataloader = DataLoader(mnist, 32, shuffle=True, num_workers=2) pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) @@ -444,7 +438,8 @@ async def run_my_app(): update_progress_loss(random()) gui.plot(epoch=i, train_loss=random(), val_loss=random()) gui.print( - f"[{i+1}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!" + f"[{i+1}/{len(dataloader)}]: " + + "We can also iterate over PyTorch dataloaders!" 
) if i == 0: gui.print(batch) From f150ad125d2b8163eb07f4e441aca2f256afc3b9 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 17:45:07 +0100 Subject: [PATCH 32/38] Format and ignore print redef --- bootstrap/launch_experiment.py | 19 ++++++++----------- src/base_tester.py | 4 ++-- src/base_trainer.py | 4 ++-- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index 7db7fd6..1d12eb1 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -43,7 +43,7 @@ console = Console() -# =========================================== Printing =========================================== +# ================================= Printing ===================================== def print_config(run_name: str, exp_conf: str) -> None: # Generate a random ANSI code: run_color = f"color({hash(run_name) % 255})" @@ -72,7 +72,8 @@ def print_model(model: torch.nn.Module) -> None: Group( Pretty(model), f"Number of parameters: {sum(p.numel() for p in model.parameters())}", - f"Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}", + "Number of trainable parameters: " + + f"{sum(p.numel() for p in model.parameters() if p.requires_grad)}", ), title="Model architecture", expand=False, @@ -82,7 +83,7 @@ def print_model(model: torch.nn.Module) -> None: console.rule() -# ================================================================================================= +# ================================================================================== def init_wandb( @@ -139,8 +140,8 @@ def launch_experiment( model_inst = to_cuda_(parallelize_model(model_inst)) training_loss_inst = to_cuda_(make_training_loss(run.training_mode, training_loss)) - # Somehow, the dataloader will crash if it's not forked when using multiprocessing along with - # Textual. 
+ # Somehow, the dataloader will crash if it's not forked when using multiprocessing + # along with Textual. mp.set_start_method("fork") train_loader_inst, val_loader_inst, test_loader_inst = make_dataloaders( data_loader, @@ -186,9 +187,7 @@ async def launch_with_async_gui(): opt=opt_inst, scheduler=scheduler_inst, **common_args, - **asdict( - run - ), # Extra stuff if needed. You can get them from the trainer's __init__ with kwrags.get(key, default_value) + **asdict(run), ).train( epochs=run.epochs, val_every=run.val_every, @@ -206,9 +205,7 @@ async def launch_with_async_gui(): **common_args, ).test( visualize_every=run.viz_every, - **asdict( - run - ), # Extra stuff if needed. You can get them from the trainer's __init__ with kwrags.get(key, default_value) + **asdict(run), ) gui.print("Testing finished!") _ = await task diff --git a/src/base_tester.py b/src/base_tester.py index 9e582c2..104ac49 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -30,7 +30,7 @@ console = Console() -print = console.print # skipcq: PYL-W0603 +print = console.print # skipcq: PYL-W0603, PYL-W0622 class BaseTester(BaseTrainer): @@ -55,7 +55,7 @@ def __init__( _loss = training_loss self._gui = gui global print # skipcq: PYL-W0603 - print = self._gui.print # skipcq: PYL-W0603 + print = self._gui.print # skipcq: PYL-W0603, PYL-W0622 self._run_name = run_name self._model = model if model_ckpt_path is None: diff --git a/src/base_trainer.py b/src/base_trainer.py index 2288b5e..365b216 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -34,7 +34,7 @@ from utils.training import visualize_model_predictions console = Console() -print = console.print # skipcq: PYL-W0603 +print = console.print # skipcq: PYL-W0603, PYL-W0622 class BaseTrainer: @@ -77,7 +77,7 @@ def __init__( self._n_ctrl_c = 0 self._gui = gui global print # skipcq: PYL-W0603 - print = self._gui.print # skipcq: PYL-W0603 + print = self._gui.print # skipcq: PYL-W0603, PYL-W0622 if model_ckpt_path is not None: 
self._load_checkpoint(model_ckpt_path) signal.signal(signal.SIGINT, self._terminator) From 7e9743204927513de1009f00769cafa4d7e494c1 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 17:45:16 +0100 Subject: [PATCH 33/38] Format and reduce complexity --- dataset/mixins/__init__.py | 112 ++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 51 deletions(-) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index ff92bba..1eb6a10 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -1,8 +1,9 @@ """ Base dataset. -In this file you may implement other base datasets that share the same characteristics and which -need the same data loading + transformation pipeline. The specificities of loading the data or -transforming it may be extended through class inheritance in a specific dataset file. +In this file you may implement other base datasets that share the same characteristics +and which need the same data loading + transformation pipeline. The specificities of +loading the data or transforming it may be extended through class inheritance in a +specific dataset file. 
""" import abc @@ -17,7 +18,7 @@ from multiprocessing.pool import Pool from os import cpu_count from types import FrameType -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple from hydra.utils import get_original_cwd from rich.progress import Progress, TaskID @@ -123,7 +124,30 @@ def __init__( ) self._split = split self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) - # TODO: Refactor and reduce cyclomatic complexity + fingerprints, flush, not_found = self._compute_fingerprints() + if flush: + shutil.rmtree(self._cache_dir, ignore_errors=True) + os.makedirs(self._cache_dir, exist_ok=True) + + super().__init__( + dataset_root, + dataset_name, + augment, + normalize, + split, + seed, + debug, + tiny, + progress, + job_id, + **kwargs, + ) + if flush or not_found: + with open(osp.join(self._cache_dir, "fingerprints"), "w") as f: + for k, v in fingerprints.items(): + f.write(f"{k}:{v}\n") + + def _compute_fingerprints(self) -> Tuple[Dict[str, str], bool, bool]: argnames = inspect.getfullargspec(self.__class__.__init__).args found = False frame: FrameType | None = inspect.currentframe() @@ -145,16 +169,17 @@ def __init__( if k in argnames and k not in ["self", "tiny", "scd_lazy"] } # TODO: Recursively hash the source code for user's methods in self.__class__ - # NOTE: getsource() won't work if I have a decorator that wraps the method. I think it's - # best to keep this behaviour and not use decorators. - fingerprint_els = { + # NOTE: getsource() won't work if I have a decorator that wraps the method. I + # think it's best to keep this behaviour and not use decorators. 
+ hashers = { "code": hashlib.new("md5", usedforsecurity=False), "args": hashlib.new("md5", usedforsecurity=False), } tree = ast.parse(inspect.getsource(self.__class__)) - fingerprint_els["code"].update(ast.dump(tree).encode()) - fingerprint_els["args"].update(pickle.dumps(argvalues)) - for k, v in fingerprint_els.items(): + hashers["code"].update(ast.dump(tree).encode()) + hashers["args"].update(pickle.dumps(argvalues)) + fingerprint_els: Dict[str, str] = {} + for k, v in hashers.items(): fingerprint_els[k] = v.hexdigest() # type: ignore mismatches, not_found = {k: True for k in fingerprint_els}, True if osp.isfile(osp.join(self._cache_dir, "fingerprints")): @@ -173,35 +198,15 @@ def __init__( elif mismatch_list != []: while flush not in ["y", "n"]: flush = input( - f"Fingerprint mismatch in {' and '.join(mismatch_list)}, flush cache? (y/n) " + f"Fingerprint mismatch in {' and '.join(mismatch_list)}, " + + "flush cache? (y/n) " ).lower() flush = flush.lower().strip() == "y" if not flush: print( "[!] Warning: Fingerprint mismatch, but cache will not be flushed." ) - - if flush: - shutil.rmtree(self._cache_dir, ignore_errors=True) - os.makedirs(self._cache_dir, exist_ok=True) - - super().__init__( - dataset_root, - dataset_name, - augment, - normalize, - split, - seed, - debug, - tiny, - progress, - job_id, - **kwargs, - ) - if flush or not_found: - with open(osp.join(self._cache_dir, "fingerprints"), "w") as f: - for k, v in fingerprint_els.items(): - f.write(f"{k}:{v}\n") + return fingerprint_els, flush, not_found def _get_raw_elements_hook(self, *args, **kwargs) -> Sequence[Any]: # TODO: Investigate slowness issues @@ -219,8 +224,8 @@ def __init__(self, cache_dir: str, seq_len: int, seq_type: str): if len(self._cache_paths) != seq_len: raise ValueError( - f"Cache info file {osp.join(cache_dir, 'info.txt')} does not match the number of " - + f"cache files in {cache_dir}. 
" + f"Cache info file {osp.join(cache_dir, 'info.txt')} " + + f"does not match the number of cache files in {cache_dir}. " + "This may be due to an interrupted dataset computation. " + "Please manually flush the cash to recompute." ) @@ -233,8 +238,8 @@ def __init__(self, cache_dir: str, seq_len: int, seq_type: str): if el_type_str != seq_type: raise ValueError( - f"Cache info file {osp.join(cache_dir, 'info.txt')} does not match the type of " - + f"cache files in {cache_dir}. " + f"Cache info file {osp.join(cache_dir, 'info.txt')} " + + f"does not match the type of cache files in {cache_dir}. " + "This may be due to an interrupted dataset computation. " + "Please manually flush the cash to recompute." ) @@ -247,8 +252,9 @@ def __getitem__(self, idx): raise IndexError return compressed_read(self._cache_paths[idx]) - # This hooks onto the user's _get_raw_elements method and overrides it if a cache entry is - # found. If not it just calls the user's _get_raw_elements method. + # This hooks onto the user's _get_raw_elements method and overrides it if a + # cache entry is found. If not it just calls the user's _get_raw_elements + # method. # return self._get_raw_elements(*args, **kwargs) path = osp.join(self._cache_dir, "raw_elements") try: @@ -269,8 +275,9 @@ def __getitem__(self, idx): except FileNotFoundError: if not hasattr(self, "_get_raw_elements"): raise NotImplementedError( - "SafeCacheDatasetMixin._get_raw_elements() is called but the user has not " - + f"implemented a _get_raw_elements method in {self.__class__.__name__}." + "SafeCacheDatasetMixin._get_raw_elements() is called " + + "but the user has not implemented a _get_raw_elements " + + f"method in {self.__class__.__name__}." 
) # Compute them: raw_elements: Sequence[Any] = self._get_raw_elements(*args, **kwargs) # type: ignore @@ -292,8 +299,8 @@ def __getitem__(self, idx): return raw_elements def _load_hook(self, *args, **kwargs) -> Tuple[int, Any, Any]: - # This hooks onto the user's _load method and overrides it if a cache entry is found. If - # not it just calls the user's _load method. + # This hooks onto the user's _load method and overrides it if a cache entry is + # found. If not it just calls the user's _load method. idx = args[1] cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl") if osp.isfile(cache_path): @@ -301,13 +308,15 @@ def _load_hook(self, *args, **kwargs) -> Tuple[int, Any, Any]: else: if not hasattr(self, "_load"): raise NotImplementedError( - "SafeCacheDatasetMixin._load() is called but the user has not implemented " - + f"a _load method in {self.__class__.__name__}." + "SafeCacheDatasetMixin._load() is called but " + + "the user has not implemented a _load method " + + f"in {self.__class__.__name__}." ) _idx, sample, label = self._load(*args, **kwargs) # type: ignore if _idx != idx: raise ValueError( - "The _load method returned an index different from the one requested." + "The _load method returned an index different " + + "from the one requested." ) return idx, sample, label @@ -330,7 +339,7 @@ def _register_sample_label( label: Any, memory_samples: List[Any], memory_labels: List[Any], - ): + ) -> None: if hasattr(super(), "_register_sample_label"): raise Exception( "SafeCacheDatasetMixin._register_sample_label() is overriden. " @@ -341,7 +350,8 @@ def _register_sample_label( if not osp.isfile(cache_path): if sample is None: raise ValueError( - "The _load_hook method returned sample=None, but no cache entry was found. " + "The _load_hook method returned sample=None, " + + "but no cache entry was found. 
" ) compressed_write(cache_path, (sample, label)) memory_samples.insert(idx, cache_path) @@ -444,9 +454,9 @@ def _load_sample_label(self, idx: int) -> Tuple[Any, Any]: return super()._load_sample_label(idx) # type: ignore return self._samples[idx], self._labels[idx] - def _register_sample_label(self, idx: int, sample: Any, label: Any): + def _register_sample_label(self, idx: int, sample: Any, label: Any) -> None: if hasattr(super(), "_register_sample_label"): - return super()._register_sample_label( # type: ignore + super()._register_sample_label( # type: ignore idx, sample, label, self._samples, self._labels ) if isinstance(sample, (List, Tuple)) or isinstance(label, (List, Tuple)): From e171ec483237ae4ebd0e2221185f233f6ede3e49 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 17:56:04 +0100 Subject: [PATCH 34/38] Try and reduce complexity further --- bootstrap/__init__.py | 0 dataset/mixins/__init__.py | 40 ++++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 19 deletions(-) create mode 100644 bootstrap/__init__.py diff --git a/bootstrap/__init__.py b/bootstrap/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 1eb6a10..0f70b81 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -124,7 +124,24 @@ def __init__( ) self._split = split self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) - fingerprints, flush, not_found = self._compute_fingerprints() + fingerprints, not_found, mismatch_list = self._compute_fingerprints() + + flush = False + if not_found: + print("No fingerprint found, flushing cache.") + flush = True + elif mismatch_list != []: + while flush not in ["y", "n"]: + flush = input( + f"Fingerprint mismatch in {' and '.join(mismatch_list)}, " + + "flush cache? (y/n) " + ).lower() + flush = flush.lower().strip() == "y" + if not flush: + print( + "[!] 
Warning: Fingerprint mismatch, but cache will not be flushed." + ) + if flush: shutil.rmtree(self._cache_dir, ignore_errors=True) os.makedirs(self._cache_dir, exist_ok=True) @@ -142,12 +159,12 @@ def __init__( job_id, **kwargs, ) - if flush or not_found: + if flush: with open(osp.join(self._cache_dir, "fingerprints"), "w") as f: for k, v in fingerprints.items(): f.write(f"{k}:{v}\n") - def _compute_fingerprints(self) -> Tuple[Dict[str, str], bool, bool]: + def _compute_fingerprints(self) -> Tuple[Dict[str, str], bool, List[str]]: argnames = inspect.getfullargspec(self.__class__.__init__).args found = False frame: FrameType | None = inspect.currentframe() @@ -191,22 +208,7 @@ def _compute_fingerprints(self) -> Tuple[Dict[str, str], bool, bool]: mismatches[key] = value.strip() != fingerprint_els[key] mismatch_list = [k for k, v in mismatches.items() if v] - flush = False - if not_found: - print("No fingerprint found, flushing cache.") - flush = True - elif mismatch_list != []: - while flush not in ["y", "n"]: - flush = input( - f"Fingerprint mismatch in {' and '.join(mismatch_list)}, " - + "flush cache? (y/n) " - ).lower() - flush = flush.lower().strip() == "y" - if not flush: - print( - "[!] Warning: Fingerprint mismatch, but cache will not be flushed." 
- ) - return fingerprint_els, flush, not_found + return fingerprint_els, not_found, mismatch_list def _get_raw_elements_hook(self, *args, **kwargs) -> Sequence[Any]: # TODO: Investigate slowness issues From 24e3b359deb72fc60e8ace189570faff850aeb61 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 17:59:10 +0100 Subject: [PATCH 35/38] Fix dataset examples --- dataset/example.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/dataset/example.py b/dataset/example.py index b67eba5..e1408ca 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -93,16 +93,17 @@ def __init__( super().__init__( dataset_root, dataset_name, - augment, - normalize, split, seed, - debug, - tiny, progress, job_id, (img_dim, img_dim) if img_dim is not None else None, + augment=augment, + normalize=normalize, + debug=debug, + tiny=tiny, ) + # TODO: class MultiProcessingWithCachingExampleDataset(ImageDataset): # TODO @@ -126,13 +127,14 @@ def __init__( super().__init__( dataset_root, dataset_name, - augment, - normalize, split, seed, - debug, - tiny, progress, job_id, (img_dim, img_dim) if img_dim is not None else None, + augment=augment, + normalize=normalize, + debug=debug, + tiny=tiny, ) + # TODO: From bc5fa955c3d8acb3c8fb7d1b86ed5f27a231fb05 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 18:03:35 +0100 Subject: [PATCH 36/38] Fix line length --- dataset/mixins/__init__.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py index 0f70b81..98162fd 100644 --- a/dataset/mixins/__init__.py +++ b/dataset/mixins/__init__.py @@ -282,7 +282,9 @@ def __getitem__(self, idx): + f"method in {self.__class__.__name__}." 
) # Compute them: - raw_elements: Sequence[Any] = self._get_raw_elements(*args, **kwargs) # type: ignore + raw_elements: Sequence[Any] = ( + self._get_raw_elements(*args, **kwargs), # type: ignore + ) type_str = "unknown" try: type_str = type(raw_elements[0]) @@ -440,7 +442,12 @@ def _get_raw_elements_hook( self, dataset_root: str, tiny: bool, split: str, seed: int ): if hasattr(super(), "_get_raw_elements_hook"): - return super()._get_raw_elements_hook(dataset_root, tiny, split, seed) # type: ignore + return super()._get_raw_elements_hook( # type: ignore + dataset_root, + tiny, + split, + seed, + ) return self._get_raw_elements(dataset_root, tiny, split, seed) def _load_hook_unpack(self, args): From 6e75ec4effca12270de6d9ab403ea192835dd141 Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 18:03:42 +0100 Subject: [PATCH 37/38] Add missing deps --- .github/workflows/python-app.yml | 4 ++-- requirements.txt | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 973e707..ad974b3 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -26,8 +26,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff pytest pyright types-PyYAML types-tqdm - pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install --upgrade ruff pytest pyright types-PyYAML types-tqdm + pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with ruff run: | diff --git a/requirements.txt b/requirements.txt index b16827f..825ee8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,5 @@ blosc2 ipython neovim rich +textual +textual_plotext From 7eadbe00b298565e04f5b5e162a96d89e52d053a Mon Sep 17 00:00:00 2001 From: Theo Date: Sat, 27 Jul 2024 18:19:36 
+0100 Subject: [PATCH 38/38] Lint --- .pre-commit-config.yaml | 2 +- bootstrap/factories.py | 6 +++--- bootstrap/launch_experiment.py | 2 +- conf/experiment.py | 25 +++++++++++++------------ src/base_trainer.py | 4 ++-- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 18b99e4..01e3ff7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.7 + rev: v0.5.5 hooks: # Run the linter. - id: ruff diff --git a/bootstrap/factories.py b/bootstrap/factories.py index 9db036e..a705593 100644 --- a/bootstrap/factories.py +++ b/bootstrap/factories.py @@ -111,14 +111,14 @@ def parallelize_model(model: torch.nn.Module) -> torch.nn.Module: def make_optimizer( - optimizer_partial: Partial[torch.optim.Optimizer], model: torch.nn.Module -) -> torch.optim.Optimizer: + optimizer_partial: Partial[torch.optim.optimizer.Optimizer], model: torch.nn.Module +) -> torch.optim.optimizer.Optimizer: return optimizer_partial(model.parameters()) def make_scheduler( scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler], - optimizer: torch.optim.Optimizer, + optimizer: torch.optim.optimizer.Optimizer, epochs: int, ) -> torch.optim.lr_scheduler.LRScheduler: scheduler = scheduler_partial( diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index 1d12eb1..63e51c5 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -108,7 +108,7 @@ def init_wandb( def launch_experiment( run, # type: ignore data_loader: Partial[DataLoader[Any]], - optimizer: Partial[torch.optim.Optimizer], + optimizer: Partial[torch.optim.optimizer.Optimizer], scheduler: Partial[torch.optim.lr_scheduler.LRScheduler], trainer: Partial[BaseTrainer], tester: Partial[BaseTester], diff --git a/conf/experiment.py b/conf/experiment.py index b294cbb..a149442 100644 
--- a/conf/experiment.py +++ b/conf/experiment.py @@ -46,10 +46,11 @@ zen_partial=True, populate_full_signature=False ) -""" ================== Dataset ================== """ +# ================== Dataset ================== -# Dataclasses are a great and simple way to define a base config group with default values. +# Dataclasses are a great and simple way to define a base config group with default +# values. @dataclass class ExampleDatasetConf: dataset_name: str = "image_dataset" @@ -86,7 +87,7 @@ class ExampleDatasetConf: name="image_a_tiny", ) -""" ================== Dataloader & sampler ================== """ +# ================== Dataloader & sampler ================== @dataclass @@ -107,12 +108,12 @@ class DataloaderConf: persistent_workers: bool = False -""" ================== Model ================== """ +# ================== Model ================== # Pre-set the group for store's model entries model_store = store(group="model") -# Not that encoder_input_dim depend on dataset.img_dim, so we need to use a partial to set them in -# the launch_experiment function. +# Not that encoder_input_dim depend on dataset.img_dim, so we need to use a partial to +# set them in the launch_experiment function. 
model_store( pbuilds( ExampleModel, @@ -134,7 +135,7 @@ class DataloaderConf: name="model_b", ) -""" ================== Losses ================== """ +# ================== Losses ================== training_loss_store = store(group="training_loss") training_loss_store( pbuilds( @@ -145,7 +146,7 @@ class DataloaderConf: ) -""" ================== Optimizer ================== """ +# ================== Optimizer ================== @dataclass @@ -157,21 +158,21 @@ class Optimizer: opt_store = store(group="optimizer") opt_store( pbuilds( - torch.optim.Adam, + torch.optim.adam.Adam, builds_bases=(Optimizer,), ), name="adam", ) opt_store( pbuilds( - torch.optim.SGD, + torch.optim.sgd.SGD, builds_bases=(Optimizer,), ), name="sgd", ) -""" ================== Scheduler ================== """ +# ================== Scheduler ================== sched_store = store(group="scheduler") sched_store( pbuilds( @@ -197,7 +198,7 @@ class Optimizer: name="cosine", ) -""" ================== Experiment ================== """ +# ================== Experiment ================== @dataclass diff --git a/src/base_trainer.py b/src/base_trainer.py index 365b216..442f4c2 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -23,7 +23,7 @@ from rich.text import Text from torch import Tensor from torch.nn import Module -from torch.optim import Optimizer +from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torchmetrics import MeanMetric @@ -54,7 +54,7 @@ def __init__( """Base trainer class. Args: model (torch.nn.Module): Model to train. - opt (torch.optim.Optimizer): Optimizer to use. + opt (torch.optim.optimizer.Optimizer): Optimizer to use. train_loader (torch.utils.data.DataLoader): Training dataloader. val_loader (torch.utils.data.DataLoader): Validation dataloader. """