diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 973e707..ad974b3 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -26,8 +26,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff pytest pyright types-PyYAML types-tqdm - pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install --upgrade ruff pytest pyright types-PyYAML types-tqdm + pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with ruff run: | diff --git a/.gitignore b/.gitignore index 85e1ea6..0ae14f3 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ tags.* **/*.pickle models/ FIGURES/ +.mypy_cache/ +.ruff_cache/ +utils/data/MNIST/raw diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 18b99e4..01e3ff7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.7 + rev: v0.5.5 hooks: # Run the linter. - id: ruff diff --git a/README.md b/README.md index be3da25..c33e01b 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) [![pre-commit](https://img.shields.io/badge/Pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) [![Vim](https://img.shields.io/badge/VIM%20ready!-forestgreen?style=for-the-badge&logo=vim)](https://github.com/DubiousCactus/bells-and-whistles/blob/main/.vimspector.json) - + A batteries-included PyTorch template with a terminal display that stays out of your way! Click on [Use this @@ -91,6 +91,9 @@ This template writes the necessary boilerplate for you, while staying out of you ``` my-pytorch-project/ + bootstrap/ + factories.py <-- Factory functions for instantiating models, optimizers, etc. + launch_experiment.py <-- Bootstraps the experiment and launches the training/testing loop conf/ experiment.py <-- experiment-level configurations project.py <-- project-level constants @@ -118,10 +121,9 @@ my-pytorch-project/ helpers.py <-- high-level utilities training.py <-- training-related utilities vendor/ - . <-- third-party code goes here - launch_experiment.py <-- Builds the trainer and tester, instantiates all partials, etc. - train.py <-- training entry point (calls launch_experiment) - test.py <-- testing entry point (calls launch_experiment) + . <-- third-party code goes here (github submodules, etc.) + train.py <-- training entry point (calls bootstrap/launch_experiment) + test.py <-- testing entry point (calls bootstrap/launch_experiment) ``` ## Setting up diff --git a/bootstrap/__init__.py b/bootstrap/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bootstrap/factories.py b/bootstrap/factories.py new file mode 100644 index 0000000..a705593 --- /dev/null +++ b/bootstrap/factories.py @@ -0,0 +1,138 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +# +# Copyright © 2024 Théo Morales +# +# Distributed under terms of the MIT license. + +""" +All factories. 
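+
+Each factory builds one component of the experiment (datasets, dataloaders,
+model, optimizer, scheduler, training loss) from its hydra-zen partial, so
+that bootstrap/launch_experiment.py only has to wire them together.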
+""" + +from typing import Any, Dict, Optional, Tuple + +import torch +from hydra_zen import just +from hydra_zen.typing import Partial +from rich.console import Console, Group +from rich.live import Live +from rich.panel import Panel +from rich.progress import Progress, TaskID +from torch.utils.data import DataLoader, Dataset + +from conf import project as project_conf +from model import TransparentDataParallel + +console = Console() + + +def make_datasets( + training_mode: bool, seed: int, dataset_partial: Partial[Dataset[Any]] +) -> Tuple[Optional[Dataset[Any]], Optional[Dataset[Any]], Optional[Dataset[Any]]]: + datasets: Dict[str, Optional[Dataset[Any]]] = { + "train": None, + "val": None, + "test": None, + } + status = console.status("Loading dataset...", spinner="runner") + progress = Progress(transient=False) + with Live(Panel(Group(status, progress), title="Loading datasets")): + splits = ("train", "val") if training_mode else ("test",) + for split in splits: + status.update(f"Loading {split} dataset...") + job_id: TaskID = progress.add_task(f"Processing {split} split...") + aug = {"augment": False} if split == "test" else {} + datasets[split] = dataset_partial( + split=split, seed=seed, progress=progress, job_id=job_id, **aug + ) + return datasets["train"], datasets["val"], datasets["test"] + + +def make_dataloaders( + data_loader_partial: Partial[DataLoader[Dataset[Any]]], + train_dataset: Optional[Dataset[Any]], + val_dataset: Optional[Dataset[Any]], + test_dataset: Optional[Dataset[Any]], + training_mode: bool, + seed: int, +) -> Tuple[ + Optional[DataLoader[Dataset[Any]]], + Optional[DataLoader[Dataset[Any]]], + Optional[DataLoader[Dataset[Any]]], +]: + generator = None + if project_conf.REPRODUCIBLE: + generator = torch.Generator() + generator.manual_seed(seed) + + train_loader_inst: Optional[DataLoader[Any]] = None + val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None + test_loader_inst: Optional[DataLoader[Any]] = None + if training_mode: + if train_dataset is None or val_dataset is None: + raise ValueError( + "train_dataset and val_dataset must be defined in training mode!" 
+ ) + train_loader_inst = data_loader_partial(train_dataset, generator=generator) + val_loader_inst = data_loader_partial( + val_dataset, generator=generator, shuffle=False, drop_last=False + ) + else: + if test_dataset is None: + raise ValueError("test_dataset must be defined in testing mode!") + test_loader_inst = data_loader_partial( + test_dataset, generator=generator, shuffle=False, drop_last=False + ) + return train_loader_inst, val_loader_inst, test_loader_inst + + +def make_model( + model_partial: Partial[torch.nn.Module], dataset: Partial[Dataset[Any]] +) -> torch.nn.Module: + with console.status("Loading model...", spinner="runner"): + model_inst = model_partial( + encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore + ) # Use just() to get the config out of the Zen-Partial + return model_inst + + +def parallelize_model(model: torch.nn.Module) -> torch.nn.Module: + console.print( + f"[*] Number of GPUs: {torch.cuda.device_count()}", + style="bold cyan", + ) + if torch.cuda.device_count() > 1: + console.print( + f"-> Using {torch.cuda.device_count()} GPUs!", + style="bold cyan", + ) + model = TransparentDataParallel(model) + return model + + +def make_optimizer( + optimizer_partial: Partial[torch.optim.optimizer.Optimizer], model: torch.nn.Module +) -> torch.optim.optimizer.Optimizer: + return optimizer_partial(model.parameters()) + + +def make_scheduler( + scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler], + optimizer: torch.optim.optimizer.Optimizer, + epochs: int, +) -> torch.optim.lr_scheduler.LRScheduler: + scheduler = scheduler_partial( + optimizer + ) # TODO: less hacky way to set T_max for CosineAnnealingLR? + if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR): + scheduler.T_max = epochs + return scheduler + + +def make_training_loss( + training_mode: bool, training_loss_partial: Partial[torch.nn.Module] +): + training_loss: Optional[torch.nn.Module] = None + if training_mode: + training_loss = training_loss_partial() + return training_loss diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py new file mode 100644 index 0000000..63e51c5 --- /dev/null +++ b/bootstrap/launch_experiment.py @@ -0,0 +1,213 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +# +# Copyright © 2023 Théo Morales +# +# Distributed under terms of the MIT license. 
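+
+"""
+Bootstrap the experiment: build every component from its hydra-zen partial via
+bootstrap.factories, then hand everything to the training or testing loop that
+runs inside the Textual GUI.
+"""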
+ + +import asyncio +import os +from dataclasses import asdict +from time import sleep +from typing import Any + +import hydra_zen +import torch +import torch.multiprocessing as mp +import wandb +import yaml +from hydra.core.hydra_config import HydraConfig +from hydra_zen.typing import Partial +from rich.console import Console, Group +from rich.panel import Panel +from rich.pretty import Pretty +from rich.syntax import Syntax +from torch.utils.data import DataLoader, Dataset + +from bootstrap.factories import ( + make_dataloaders, + make_datasets, + make_model, + make_optimizer, + make_scheduler, + make_training_loss, + parallelize_model, +) +from conf import project as project_conf +from src.base_tester import BaseTester +from src.base_trainer import BaseTrainer +from utils import load_model_ckpt, to_cuda_ +from utils.gui import GUI + +console = Console() + + +# ================================= Printing ===================================== +def print_config(run_name: str, exp_conf: str) -> None: + # Generate a random ANSI code: + run_color = f"color({hash(run_name) % 255})" + background_color = f"color({(hash(run_name) + 128) % 255})" + console.print( + f"Running {run_name}", + style=f"bold {run_color} on {background_color}", + justify="center", + ) + console.rule() + console.print( + Panel( + Syntax( + exp_conf, lexer="yaml", dedent=True, word_wrap=False, theme="dracula" + ), + title="Experiment configuration", + expand=False, + ), + overflow="ellipsis", + ) + + +def print_model(model: torch.nn.Module) -> None: + console.print( + Panel( + Group( + Pretty(model), + f"Number of parameters: {sum(p.numel() for p in model.parameters())}", + "Number of trainable parameters: " + + f"{sum(p.numel() for p in model.parameters() if p.requires_grad)}", + ), + title="Model architecture", + expand=False, + ), + overflow="ellipsis", + ) + console.rule() + + +# ================================================================================== + + +def init_wandb( + run_name: str, + model: torch.nn.Module, + exp_conf: str, + log="gradients", + log_graph=False, +) -> None: + if project_conf.USE_WANDB: + with console.status("Initializing Weights & Biases...", spinner="moon"): + # exp_conf is a string, so we need to load it back to a dict: + exp_conf = yaml.safe_load(exp_conf) + wandb.init( # type: ignore + project=project_conf.PROJECT_NAME, + name=run_name, + config=exp_conf, + ) + wandb.watch(model, log=log, log_graph=log_graph) # type: ignore + + +def launch_experiment( + run, # type: ignore + data_loader: Partial[DataLoader[Any]], + optimizer: Partial[torch.optim.optimizer.Optimizer], + scheduler: Partial[torch.optim.lr_scheduler.LRScheduler], + trainer: Partial[BaseTrainer], + tester: Partial[BaseTester], + dataset: Partial[Dataset[Any]], + model: Partial[torch.nn.Module], + training_loss: Partial[torch.nn.Module], +): + run_name = os.path.basename(HydraConfig.get().runtime.output_dir) + exp_conf = hydra_zen.to_yaml( + dict( + run_conf=run, + dataset=dataset, + model=model, + optimizer=optimizer, + scheduler=scheduler, + training_loss=training_loss, + ) + ) + print_config(run_name, exp_conf) + + # ============ Partials instantiation ============ + model_inst = make_model(model, dataset) + print_model(model_inst) + train_dataset, val_dataset, test_dataset = make_datasets( + run.training_mode, run.seed, dataset + ) + opt_inst = make_optimizer(optimizer, model_inst) + scheduler_inst = make_scheduler(scheduler, opt_inst, run.epochs) + model_inst = to_cuda_(parallelize_model(model_inst)) + 
training_loss_inst = to_cuda_(make_training_loss(run.training_mode, training_loss)) + + # Somehow, the dataloader will crash if it's not forked when using multiprocessing + # along with Textual. + mp.set_start_method("fork") + train_loader_inst, val_loader_inst, test_loader_inst = make_dataloaders( + data_loader, + train_dataset, + val_dataset, + test_dataset, + run.training_mode, + run.seed, + ) + init_wandb(run_name, model_inst, exp_conf) + + # ============ Training ============ + console.print( + "Launching TUI...", + style="bold cyan", + ) + sleep(1) + + async def launch_with_async_gui(): + gui = GUI(run_name, project_conf.LOG_SCALE_PLOT) + task = asyncio.create_task(gui.run_async()) + while not gui.is_running: + await asyncio.sleep(0.01) # Wait for the app to start up + model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) + common_args = dict( + run_name=run_name, + model=model_inst, + model_ckpt_path=model_ckpt_path, + training_loss=training_loss_inst, + gui=gui, + ) + if run.training_mode: + gui.print("Training started!") + if training_loss_inst is None: + raise ValueError("training_loss must be defined in training mode!") + if val_loader_inst is None or train_loader_inst is None: + raise ValueError( + "val_loader and train_loader must be defined in training mode!" + ) + await trainer( + train_loader=train_loader_inst, + val_loader=val_loader_inst, + opt=opt_inst, + scheduler=scheduler_inst, + **common_args, + **asdict(run), + ).train( + epochs=run.epochs, + val_every=run.val_every, + visualize_every=run.viz_every, + visualize_train_every=run.viz_train_every, + visualize_n_samples=run.viz_num_samples, + ) + gui.print("Training finished!") + else: + gui.print("Testing started!") + if test_loader_inst is None: + raise ValueError("test_loader must be defined in testing mode!") + await tester( + data_loader=test_loader_inst, + **common_args, + ).test( + visualize_every=run.viz_every, + **asdict(run), + ) + gui.print("Testing finished!") + _ = await task + + asyncio.run(launch_with_async_gui()) diff --git a/conf/experiment.py b/conf/experiment.py index 3c8f81f..a149442 100644 --- a/conf/experiment.py +++ b/conf/experiment.py @@ -1,10 +1,3 @@ -#! /usr/bin/env python3 -# vim:fenc=utf-8 -# -# Copyright © 2023 Théo Morales -# -# Distributed under terms of the MIT license. - """ Configurations for the experiments and config groups, using hydra-zen. """ @@ -28,8 +21,8 @@ from unique_names_generator import get_random_name from unique_names_generator.data import ADJECTIVES, NAMES -from dataset.example import ExampleDataset -from launch_experiment import launch_experiment +from bootstrap.launch_experiment import launch_experiment +from dataset.example import SingleProcessingExampleDataset from model.example import ExampleModel from src.base_tester import BaseTester from src.base_trainer import BaseTrainer @@ -53,10 +46,11 @@ zen_partial=True, populate_full_signature=False ) -""" ================== Dataset ================== """ +# ================== Dataset ================== -# Dataclasses are a great and simple way to define a base config group with default values. +# Dataclasses are a great and simple way to define a base config group with default +# values. 
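+# Store entries can then specialize individual fields through pbuilds, as the
+# dataset entries below do with dataset_root, img_dim and tiny.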
 @dataclass
 class ExampleDatasetConf:
     dataset_name: str = "image_dataset"
@@ -65,18 +59,19 @@ class ExampleDatasetConf:
     normalize: bool = True
     augment: bool = False
     debug: bool = False
-    img_dim: int = ExampleDataset.IMG_SIZE[0]
+    img_dim: int = SingleProcessingExampleDataset.IMG_SIZE[0]
 
 
 # Pre-set the group for store's dataset entries
 dataset_store = store(group="dataset")
 dataset_store(
-    pbuilds(ExampleDataset, builds_bases=(ExampleDatasetConf,)), name="image_a"
+    pbuilds(SingleProcessingExampleDataset, builds_bases=(ExampleDatasetConf,)),
+    name="image_a",
 )
 
 dataset_store(
     pbuilds(
-        ExampleDataset,
+        SingleProcessingExampleDataset,
         builds_bases=(ExampleDatasetConf,),
         dataset_root="data/b",
         img_dim=64,
@@ -85,14 +80,14 @@ class ExampleDatasetConf:
 )
 dataset_store(
     pbuilds(
-        ExampleDataset,
+        SingleProcessingExampleDataset,
         builds_bases=(ExampleDatasetConf,),
         tiny=True,
     ),
     name="image_a_tiny",
 )
 
-""" ================== Dataloader & sampler ================== """
+# ================== Dataloader & sampler ==================
 
 
 @dataclass
@@ -113,12 +108,12 @@ class DataloaderConf:
     persistent_workers: bool = False
 
 
-""" ================== Model ================== """
+# ================== Model ==================
 
 # Pre-set the group for store's model entries
 model_store = store(group="model")
-# Not that encoder_input_dim depend on dataset.img_dim, so we need to use a partial to set them in
-# the launch_experiment function.
+# Note that encoder_input_dim depends on dataset.img_dim, so we need to use a partial
+# to set them in the launch_experiment function.
 model_store(
     pbuilds(
         ExampleModel,
@@ -140,7 +135,7 @@ class DataloaderConf:
     name="model_b",
 )
 
-""" ================== Losses ================== """
+# ================== Losses ==================
 training_loss_store = store(group="training_loss")
 training_loss_store(
     pbuilds(
@@ -151,7 +146,7 @@ class DataloaderConf:
     ),
 )
 
-""" ================== Optimizer ================== """
+# ================== Optimizer ==================
 
 
 @dataclass
@@ -163,21 +158,21 @@ class Optimizer:
 opt_store = store(group="optimizer")
 opt_store(
     pbuilds(
-        torch.optim.Adam,
+        torch.optim.adam.Adam,
         builds_bases=(Optimizer,),
    ),
     name="adam",
 )
 opt_store(
     pbuilds(
-        torch.optim.SGD,
+        torch.optim.sgd.SGD,
         builds_bases=(Optimizer,),
     ),
     name="sgd",
 )
 
 
-""" ================== Scheduler ================== """
+# ================== Scheduler ==================
 sched_store = store(group="scheduler")
 sched_store(
     pbuilds(
@@ -203,7 +198,7 @@ class Optimizer:
     name="cosine",
 )
 
-""" ================== Experiment ================== """
+# ================== Experiment ==================
 
 
 @dataclass
@@ -228,58 +223,59 @@ class RunConfig:
 tester_store = store(group="tester")
 tester_store(pbuilds(BaseTester, populate_full_signature=True), name="base")
 
-Experiment = builds(
-    launch_experiment,
-    populate_full_signature=True,
-    hydra_defaults=[
-        "_self_",
-        {"trainer": "base"},
-        {"tester": "base"},
-        {"dataset": "image_a"},
-        {"model": "model_a"},
-        {"optimizer": "adam"},
-        {"scheduler": "step"},
-        {"run": "default"},
-        {"training_loss": "mse"},
-    ],
-    trainer=MISSING,
-    tester=MISSING,
-    dataset=MISSING,
-    model=MISSING,
-    optimizer=MISSING,
-    scheduler=MISSING,
-    run=MISSING,
-    training_loss=MISSING,
-    data_loader=pbuilds(
-        DataLoader, builds_bases=(DataloaderConf,)
-    ),  # Needs a partial because we need to set the dataset
-)
-store(Experiment, name="base_experiment")
-
-# the experiment configs:
-# - must be stored under the _global_ package
-# - must inherit from
`Experiment` -experiment_store = store(group="experiment", package="_global_") -experiment_store( - make_config( - hydra_defaults=[ - "_self_", - {"override /model": "model_a"}, - {"override /dataset": "image_a"}, - ], - # training=dict(epochs=100), - bases=(Experiment,), - ), - name="exp_a", -) -experiment_store( - make_config( + +def make_experiment_configs(): + Experiment = builds( + launch_experiment, + populate_full_signature=True, hydra_defaults=[ "_self_", - {"override /model": "model_b"}, - {"override /dataset": "image_b"}, + {"trainer": "base"}, + {"tester": "base"}, + {"dataset": "image_a"}, + {"model": "model_a"}, + {"optimizer": "adam"}, + {"scheduler": "step"}, + {"run": "default"}, + {"training_loss": "mse"}, ], - bases=(Experiment,), - ), - name="exp_b", -) + trainer=MISSING, + tester=MISSING, + dataset=MISSING, + model=MISSING, + optimizer=MISSING, + scheduler=MISSING, + run=MISSING, + training_loss=MISSING, + data_loader=pbuilds( + DataLoader, builds_bases=(DataloaderConf,) + ), # Needs a partial because we need to set the dataset + ) + store(Experiment, name="base_experiment") + + # the experiment configs: + # - must be stored under the _global_ package + # - must inherit from `Experiment` + experiment_store = store(group="experiment", package="_global_") + experiment_store( + make_config( + hydra_defaults=[ + "_self_", + {"override /model": "model_a"}, + {"override /dataset": "image_a"}, + ], + bases=(Experiment,), + ), + name="exp_a", + ) + experiment_store( + make_config( + hydra_defaults=[ + "_self_", + {"override /model": "model_b"}, + {"override /dataset": "image_b"}, + ], + bases=(Experiment,), + ), + name="exp_b", + ) diff --git a/dataset/base/__init__.py b/dataset/base/__init__.py deleted file mode 100644 index 57f5f71..0000000 --- a/dataset/base/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -#! /usr/bin/env python3 -# vim:fenc=utf-8 -# -# Copyright © 2023 Théo Morales -# -# Distributed under terms of the MIT license. - -""" -Base dataset. -In this file you may implement other base datasets that share the same characteristics and which -need the same data loading + transformation pipeline. The specificities of loading the data or -transforming it may be extended through class inheritance in a specific dataset file. 
-""" - -import abc -import os -import os.path as osp -from typing import Any, Dict, List, Tuple, Union - -from hydra.utils import get_original_cwd -from torch import Tensor -from torch.utils.data import Dataset - - -class BaseDataset(Dataset[Any], abc.ABC): - def __init__( - self, - dataset_root: str, - dataset_name: str, - augment: bool, - normalize: bool, - split: str, - seed: int, - debug: bool, - tiny: bool = False, - ) -> None: - super().__init__() - self._samples: Union[Dict[Any, Any], List[Any], Tensor] - self._labels: Union[Dict[Any, Any], List[Any], Tensor] - self._samples, self._labels = self._load(dataset_root, tiny, split, seed) - self._augment = augment and split == "train" - self._normalize = normalize - self._dataset_name = dataset_name - self._debug = debug - self._cache_dir = osp.join( - get_original_cwd(), "data", f"{dataset_name}_preprocessed" - ) - os.makedirs(self._cache_dir, exist_ok=True) - - @abc.abstractmethod - def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int - ) -> Tuple[ - Union[Dict[str, Any], List[Any], Tensor], - Union[Dict[str, Any], List[Any], Tensor], - ]: - # Implement this - raise NotImplementedError - - def __len__(self) -> int: - return len(self._samples) - - def disable_augs(self) -> None: - self._augment = False diff --git a/dataset/base/image.py b/dataset/base/image.py index e2c01b7..02aa7c4 100644 --- a/dataset/base/image.py +++ b/dataset/base/image.py @@ -9,17 +9,18 @@ Base dataset for images. """ -import abc -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple +from rich.progress import Progress, TaskID from torch import Tensor +from torch.utils.data import Dataset from torchvision.io.image import read_image # type: ignore from torchvision.transforms import transforms # type: ignore -from dataset.base import BaseDataset +from dataset.mixins import BaseDatasetMixin -class ImageDataset(BaseDataset, abc.ABC): +class ImageDataset(BaseDatasetMixin, Dataset): IMAGE_NET_MEAN: List[float] = [] IMAGE_NET_STD: List[float] = [] COCO_MEAN, COCO_STD = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) @@ -31,6 +32,8 @@ def __init__( dataset_name: str, split: str, seed: int, + progress: Progress, + job_id: TaskID, img_size: Optional[tuple[int, ...]] = None, augment: bool = False, normalize: bool = False, @@ -44,8 +47,10 @@ def __init__( normalize, split, seed, - debug=debug, - tiny=tiny, + debug, + tiny, + progress, + job_id, ) self._img_size = self.IMG_SIZE if img_size is None else img_size self._transforms: Callable[[Tensor], Tensor] = transforms.Compose( @@ -56,29 +61,20 @@ def __init__( self._normalization: Callable[[Tensor], Tensor] = transforms.Normalize( self.IMAGE_NET_MEAN, self.IMAGE_NET_STD ) - try: - import albumentations as A # type: ignore - except ImportError: - raise ImportError( - "Please install albumentations to use the augmentation pipeline." + if self._augment: + try: + import albumentations as A # type: ignore + except ImportError: + raise ImportError( + "Please install albumentations to use the augmentation pipeline." 
+ ) + self._augs: Callable[..., Dict[str, Any]] = A.Compose( + [ + A.RandomCropFromBorders(), + A.RandomBrightnessContrast(), + A.RandomGamma(), + ] ) - self._augs: Callable[..., Dict[str, Any]] = A.Compose( - [ - A.RandomCropFromBorders(), - A.RandomBrightnessContrast(), - A.RandomGamma(), - ] - ) - - @abc.abstractmethod - def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int - ) -> Tuple[ - Union[Dict[str, Any], List[Any], Tensor], - Union[Dict[str, Any], List[Any], Tensor], - ]: - # Implement this - raise NotImplementedError def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """ @@ -88,8 +84,6 @@ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: # ==== Load image and apply transforms === img: Tensor img = read_image(self._samples[index]) # type: ignore - if not isinstance(img, Tensor): - raise ValueError("Image not loaded as a Tensor.") img = self._transforms(img) if self._normalize: img = self._normalization(img) diff --git a/dataset/example.py b/dataset/example.py index a526e40..e1408ca 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -10,15 +10,17 @@ This is mostly used to test the framework. """ +from time import sleep from typing import Optional, Tuple, Union import torch +from rich.progress import Progress, TaskID from torch import Tensor from dataset.base.image import ImageDataset -class ExampleDataset(ImageDataset): +class SingleProcessingExampleDataset(ImageDataset): IMG_SIZE = (32, 32) def __init__( @@ -27,29 +29,112 @@ def __init__( dataset_name: str, split: str, seed: int, + progress: Progress, + job_id: TaskID, img_dim: Optional[int] = None, augment: bool = False, normalize: bool = False, tiny: bool = False, debug: bool = False, ) -> None: - self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim super().__init__( dataset_root, dataset_name, split, seed, + progress, + job_id, (img_dim, img_dim) if img_dim is not None else None, augment=augment, normalize=normalize, debug=debug, tiny=tiny, ) + self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim + self._samples, self._labels = self._load( + progress, + job_id, + ) def _load( - self, dataset_root: str, tiny: bool, split: str, seed: int + self, + progress: Progress, + job_id: TaskID, ) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]: + length = 3 if self._tiny else 20 + progress.update(job_id, total=length) + for _ in range(length): + progress.advance(job_id) + sleep(0.001 if self._tiny else 0.1) return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8) def __getitem__(self, index: int): return self._samples[index], self._labels[index] + + +class MultiProcessingExampleDataset(ImageDataset): # TODO + IMG_SIZE = (32, 32) + + def __init__( + self, + dataset_root: str, + dataset_name: str, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, + img_dim: Optional[int] = None, + augment: bool = False, + normalize: bool = False, + tiny: bool = False, + debug: bool = False, + ) -> None: + self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim + super().__init__( + dataset_root, + dataset_name, + split, + seed, + progress, + job_id, + (img_dim, img_dim) if img_dim is not None else None, + augment=augment, + normalize=normalize, + debug=debug, + tiny=tiny, + ) + # TODO: + + +class MultiProcessingWithCachingExampleDataset(ImageDataset): # TODO + IMG_SIZE = (32, 32) + + def __init__( + self, + dataset_root: str, + dataset_name: str, + split: str, + seed: int, + progress: Progress, + job_id: TaskID, + img_dim: 
Optional[int] = None, + augment: bool = False, + normalize: bool = False, + tiny: bool = False, + debug: bool = False, + ) -> None: + self._img_dim = self.IMG_SIZE[0] if img_dim is None else img_dim + super().__init__( + dataset_root, + dataset_name, + split, + seed, + progress, + job_id, + (img_dim, img_dim) if img_dim is not None else None, + augment=augment, + normalize=normalize, + debug=debug, + tiny=tiny, + ) + # TODO: diff --git a/dataset/mixins/__init__.py b/dataset/mixins/__init__.py new file mode 100644 index 0000000..98162fd --- /dev/null +++ b/dataset/mixins/__init__.py @@ -0,0 +1,521 @@ +""" +Base dataset. +In this file you may implement other base datasets that share the same characteristics +and which need the same data loading + transformation pipeline. The specificities of +loading the data or transforming it may be extended through class inheritance in a +specific dataset file. +""" + +import abc +import ast +import hashlib +import inspect +import itertools +import os +import os.path as osp +import pickle +import shutil +from multiprocessing.pool import Pool +from os import cpu_count +from types import FrameType +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +from hydra.utils import get_original_cwd +from rich.progress import Progress, TaskID +from tqdm import tqdm + +from utils.helpers import compressed_read, compressed_write + +# TODO: Can we speed all of this up with Cython or Numba? + + +class DatasetMixinInterface(abc.ABC): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + **kwargs, + ): + _ = dataset_root + _ = dataset_name + _ = augment + _ = normalize + _ = split + _ = seed + _ = debug + _ = tiny + _ = progress + _ = job_id + super().__init__(**kwargs) + + +class BaseDatasetMixin(DatasetMixinInterface): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + **kwargs, + ): + self._samples, self._labels = [], [] + self._augment = augment and split == "train" + self._normalize = normalize + self._dataset_name = dataset_name + self._debug = debug + self._tiny = tiny + super().__init__( + dataset_root, + dataset_name, + augment, + normalize, + split, + seed, + debug, + tiny, + progress, + job_id, + **kwargs, + ) + + def __len__(self) -> int: + return len(self._samples) + + def disable_augs(self) -> None: + self._augment = False + + +# TODO: Automatically "remove" this mixin if debug=True. Idk how, maybe metaclass? 
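+# SafeCacheDatasetMixin caches each processed (sample, label) pair on disk and
+# keys the cache on a fingerprint of the dataset class's source code and
+# constructor arguments, so a stale cache is detected on the next run.
+# Sketch of the intended composition (MyDataset and its method bodies are
+# hypothetical, not part of the template): list the multiprocessing mixin
+# first so that its hooks can defer to the cache mixin through super(), e.g.
+#
+#   class MyDataset(MultiProcessingDatasetMixin, SafeCacheDatasetMixin, Dataset):
+#       def _get_raw_elements(self, dataset_root, tiny, split, seed): ...
+#       def _load(self, element, idx, tiny, split, seed): ...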
+class SafeCacheDatasetMixin(DatasetMixinInterface): + def __init__( + self, + dataset_root: str, + dataset_name: str, + augment: bool, + normalize: bool, + split: str, + seed: int, + debug: bool, + tiny: bool, + progress: Progress, + job_id: TaskID, + scd_lazy: bool = True, + **kwargs, + ): + self._cache_dir = osp.join( + get_original_cwd(), + "data", + f"{dataset_name}_preprocessed", + f"{'tiny_' if tiny else ''}{split}", + ) + self._split = split + self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy) + fingerprints, not_found, mismatch_list = self._compute_fingerprints() + + flush = False + if not_found: + print("No fingerprint found, flushing cache.") + flush = True + elif mismatch_list != []: + while flush not in ["y", "n"]: + flush = input( + f"Fingerprint mismatch in {' and '.join(mismatch_list)}, " + + "flush cache? (y/n) " + ).lower() + flush = flush.lower().strip() == "y" + if not flush: + print( + "[!] Warning: Fingerprint mismatch, but cache will not be flushed." + ) + + if flush: + shutil.rmtree(self._cache_dir, ignore_errors=True) + os.makedirs(self._cache_dir, exist_ok=True) + + super().__init__( + dataset_root, + dataset_name, + augment, + normalize, + split, + seed, + debug, + tiny, + progress, + job_id, + **kwargs, + ) + if flush: + with open(osp.join(self._cache_dir, "fingerprints"), "w") as f: + for k, v in fingerprints.items(): + f.write(f"{k}:{v}\n") + + def _compute_fingerprints(self) -> Tuple[Dict[str, str], bool, List[str]]: + argnames = inspect.getfullargspec(self.__class__.__init__).args + found = False + frame: FrameType | None = inspect.currentframe() + while not found: + frame = frame.f_back if frame is not None else None + if frame is None: + break + found = ( + frame.f_code.co_qualname.strip() + == f"{self.__class__.__qualname__}.__init__".strip() + ) + if frame is None: + raise RuntimeError( + f"Could not find frame for {self.__class__.__qualname__}.__init__" + ) + argvalues = { + k: v + for k, v in inspect.getargvalues(frame).locals.items() + if k in argnames and k not in ["self", "tiny", "scd_lazy"] + } + # TODO: Recursively hash the source code for user's methods in self.__class__ + # NOTE: getsource() won't work if I have a decorator that wraps the method. I + # think it's best to keep this behaviour and not use decorators. 
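+        # Two fingerprints are computed: "code" hashes the class's AST, so
+        # edits to the dataset code invalidate the cache, while "args" hashes
+        # the pickled constructor arguments.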
+        hashers = {
+            "code": hashlib.new("md5", usedforsecurity=False),
+            "args": hashlib.new("md5", usedforsecurity=False),
+        }
+        tree = ast.parse(inspect.getsource(self.__class__))
+        hashers["code"].update(ast.dump(tree).encode())
+        hashers["args"].update(pickle.dumps(argvalues))
+        fingerprint_els: Dict[str, str] = {}
+        for k, v in hashers.items():
+            fingerprint_els[k] = v.hexdigest()  # type: ignore
+        mismatches, not_found = {k: True for k in fingerprint_els}, True
+        if osp.isfile(osp.join(self._cache_dir, "fingerprints")):
+            with open(osp.join(self._cache_dir, "fingerprints"), "r") as f:
+                not_found = False
+                cached_fingerprints = f.readlines()
+                for line in cached_fingerprints:
+                    key, value = line.split(":")
+                    mismatches[key] = value.strip() != fingerprint_els[key]
+        mismatch_list = [k for k, v in mismatches.items() if v]
+
+        return fingerprint_els, not_found, mismatch_list
+
+    def _get_raw_elements_hook(self, *args, **kwargs) -> Sequence[Any]:
+        # TODO: Investigate slowness issues
+        class LazyCacheSequence(Sequence):
+            def __init__(self, cache_dir: str, seq_len: int, seq_type: str):
+                self._cache_paths = []
+                self._seq_len = seq_len
+                self._seq_type = seq_type
+
+                for i in itertools.count():  # TODO: Could this be slow?
+                    cache_path = osp.join(cache_dir, f"{i:04d}.pkl")
+                    if not osp.isfile(cache_path):
+                        break
+                    self._cache_paths.append(cache_path)
+
+                if len(self._cache_paths) != seq_len:
+                    raise ValueError(
+                        f"Cache info file {osp.join(cache_dir, 'info.txt')} "
+                        + f"does not match the number of cache files in {cache_dir}. "
+                        + "This may be due to an interrupted dataset computation. "
+                        + "Please manually flush the cache to recompute."
+                    )
+
+                el_type_str = "unknown"
+                try:
+                    el_type_str = str(type(compressed_read(self._cache_paths[0])))
+                except Exception:
+                    pass
+
+                if el_type_str != seq_type:
+                    raise ValueError(
+                        f"Cache info file {osp.join(cache_dir, 'info.txt')} "
+                        + f"does not match the type of cache files in {cache_dir}. "
+                        + "This may be due to an interrupted dataset computation. "
+                        + "Please manually flush the cache to recompute."
+                    )
+
+            def __len__(self):
+                return self._seq_len
+
+            def __getitem__(self, idx):
+                if idx >= self._seq_len:
+                    raise IndexError
+                return compressed_read(self._cache_paths[idx])
+
+        # This hooks onto the user's _get_raw_elements method and overrides it if a
+        # cache entry is found. If not it just calls the user's _get_raw_elements
+        # method.
+        path = osp.join(self._cache_dir, "raw_elements")
+        try:
+            info = [None, None]
+            if not osp.isfile(osp.join(path, "info.txt")):
+                raise FileNotFoundError(
+                    f"Cache info file not found at {osp.join(path, 'info.txt')}."
+                )
+            with open(osp.join(path, "info.txt"), "r") as f:
+                info = f.readlines()
+            if len(info) != 2 or not info[0].strip().isdigit():
+                raise ValueError(
+                    f"Invalid cache info file {osp.join(path, 'info.txt')}."
+                )
+            raw_elements = LazyCacheSequence(
+                path, int(info[0].strip()), info[1].strip()
+            )
+        except FileNotFoundError:
+            if not hasattr(self, "_get_raw_elements"):
+                raise NotImplementedError(
+                    "SafeCacheDatasetMixin._get_raw_elements() is called "
+                    + "but the user has not implemented a _get_raw_elements "
+                    + f"method in {self.__class__.__name__}."
+                )
+            # Compute them:
+            raw_elements: Sequence[Any] = self._get_raw_elements(
+                *args, **kwargs
+            )  # type: ignore
+            type_str = "unknown"
+            try:
+                type_str = str(type(raw_elements[0]))
+            except Exception:
+                pass
+            # Cache them:
+            os.makedirs(path, exist_ok=True)
+            with open(osp.join(path, "info.txt"), "w") as f:
+                f.writelines([f"{len(raw_elements)}\n", f"{type_str}"])
+
+            print(f"going to cache {len(raw_elements)} elements")
+            for i, element in enumerate(raw_elements):
+                compressed_write(osp.join(path, f"{i:04d}.pkl"), element)
+            print(f"cached {len(raw_elements)} elements")
+        # TODO: Rich.log("Raw elements cached here: <>")
+        return raw_elements
+
+    def _load_hook(self, *args, **kwargs) -> Tuple[int, Any, Any]:
+        # This hooks onto the user's _load method and overrides it if a cache entry is
+        # found. If not it just calls the user's _load method.
+        idx = args[1]
+        cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl")
+        if osp.isfile(cache_path):
+            sample, label = None, None
+        else:
+            if not hasattr(self, "_load"):
+                raise NotImplementedError(
+                    "SafeCacheDatasetMixin._load() is called but "
+                    + "the user has not implemented a _load method "
+                    + f"in {self.__class__.__name__}."
+                )
+            _idx, sample, label = self._load(*args, **kwargs)  # type: ignore
+            if _idx != idx:
+                raise ValueError(
+                    "The _load method returned an index different "
+                    + "from the one requested."
+                )
+        return idx, sample, label
+
+    def _load_sample_label(self, idx: int) -> Tuple[Any, Any]:
+        if hasattr(super(), "_load_sample_label"):
+            raise Exception(
+                "SafeCacheDatasetMixin._load_sample_label() is overriden. "
+                + "As best practice, you should inherit from SafeCacheDatasetMixin "
+                + f"after {super().__class__.__name__} to avoid unwanted behavior."
+            )
+        cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl")
+        if not osp.isfile(cache_path):
+            raise KeyError(f"Cache file {cache_path} not found, will recompute.")
+        return compressed_read(cache_path)
+
+    def _register_sample_label(
+        self,
+        idx: int,
+        sample: Any,
+        label: Any,
+        memory_samples: List[Any],
+        memory_labels: List[Any],
+    ) -> None:
+        if hasattr(super(), "_register_sample_label"):
+            raise Exception(
+                "SafeCacheDatasetMixin._register_sample_label() is overriden. "
+                + "As best practice, you should inherit from SafeCacheDatasetMixin "
+                + f"after {super().__class__.__name__} to avoid unwanted behavior."
+            )
+        cache_path = osp.join(self._cache_dir, f"{idx:04d}.pkl")
+        if not osp.isfile(cache_path):
+            if sample is None:
+                raise ValueError(
+                    "The _load_hook method returned sample=None, "
+                    + "but no cache entry was found."
+                )
+            compressed_write(cache_path, (sample, label))
+        memory_samples.insert(idx, cache_path)
+        memory_labels.insert(idx, cache_path)
+
+
+class MultiProcessingDatasetMixin(DatasetMixinInterface, abc.ABC):
+    def __init__(
+        self,
+        dataset_root: str,
+        dataset_name: str,
+        augment: bool,
+        normalize: bool,
+        split: str,
+        seed: int,
+        debug: bool,
+        tiny: bool,
+        progress: Progress,
+        job_id: TaskID,
+        mpd_lazy: bool = True,
+        mpd_chunk_size: int = 1,
+        mpd_processes: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            dataset_root,
+            dataset_name,
+            augment,
+            normalize,
+            split,
+            seed,
+            debug,
+            tiny,
+            progress,
+            job_id,
+            **kwargs,
+        )
+        self._samples, self._labels = [], []
+        cpus = cpu_count()
+        # Pool(0) is invalid, so fall back to a single worker when cpu_count()
+        # is unknown:
+        processes = (
+            1 if debug else (mpd_processes or (max(cpus - 1, 1) if cpus else 1))
+        )
+
+        with Pool(processes) as pool:
+            raw_elements = self._get_raw_elements_hook(
+                dataset_root,
+                tiny,
+                split,
+                seed,
+            )
+            if raw_elements is None:
+                raise ValueError("The _get_raw_elements method returned None.")
+
+            if len(raw_elements) == 0:
+                raise ValueError(
+                    "The _get_raw_elements method must return a sequence of elements. "
+                    + "If the dataset is empty, return an empty list."
+                )
+
+            if raw_elements[0] is None:
+                raise ValueError(
+                    "The _get_raw_elements method returned a sequence of None."
+                )
+
+            pool_dispatch_method: Callable = pool.imap if mpd_lazy else pool.starmap
+            pool_dispatch_func: Callable = (
+                self._load_hook_unpack if mpd_lazy else self._load_hook
+            )
+            for idx, sample, label in tqdm(
+                pool_dispatch_method(
+                    pool_dispatch_func,
+                    zip(
+                        raw_elements,
+                        range(len(raw_elements)),
+                        itertools.repeat(tiny),
+                        itertools.repeat(split),
+                        itertools.repeat(seed),
+                    ),
+                    chunksize=mpd_chunk_size,
+                ),
+                total=len(raw_elements),
+            ):
+                self._register_sample_label(idx, sample, label)
+        print(f"{'Lazy' if mpd_lazy else 'Eager'} loaded {len(self._samples)} samples.")
+
+    def _get_raw_elements_hook(
+        self, dataset_root: str, tiny: bool, split: str, seed: int
+    ):
+        if hasattr(super(), "_get_raw_elements_hook"):
+            return super()._get_raw_elements_hook(  # type: ignore
+                dataset_root,
+                tiny,
+                split,
+                seed,
+            )
+        return self._get_raw_elements(dataset_root, tiny, split, seed)
+
+    def _load_hook_unpack(self, args):
+        return self._load_hook(*args)
+
+    def _load_hook(self, *args) -> Tuple[int, Any, Any]:
+        if hasattr(super(), "_load_hook"):
+            return super()._load_hook(*args)  # type: ignore
+        return self._load(*args)  # TODO: Rename to _load_sample?
+
+    def _load_sample_label(self, idx: int) -> Tuple[Any, Any]:
+        if hasattr(super(), "_load_sample_label"):
+            return super()._load_sample_label(idx)  # type: ignore
+        return self._samples[idx], self._labels[idx]
+
+    def _register_sample_label(self, idx: int, sample: Any, label: Any) -> None:
+        if hasattr(super(), "_register_sample_label"):
+            # The cache mixin inserts the cache path into the memory lists
+            # itself, so don't register the pair a second time below.
+            return super()._register_sample_label(  # type: ignore
+                idx, sample, label, self._samples, self._labels
+            )
+        if isinstance(sample, (List, Tuple)) or isinstance(label, (List, Tuple)):
+            raise NotImplementedError(
+                "_register_sample_label cannot yet handle lists of samples/labels"
+            )
+        self._samples.insert(idx, sample)
+        self._labels.insert(idx, label)
+
+    @abc.abstractmethod
+    def _get_raw_elements(
+        self, dataset_root: str, tiny: bool, split: str, seed: int
+    ) -> Sequence[Any]:
+        # Implement this
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def _load(
+        self, element: Any, idx: int, tiny: bool, split: str, seed: int
+    ) -> Tuple[int, Any, Any]:
+        # Implement this
+        raise NotImplementedError
+
+
+class BatchedTensorsMultiprocessingDatasetMixin(DatasetMixinInterface, abc.ABC):
+    def __init__(
+        self,
+        dataset_root: str,
+        dataset_name: str,
+        augment: bool,
+        normalize: bool,
+        split: str,
+        seed: int,
+        debug: bool,
+        tiny: bool,
+        progress: Progress,
+        job_id: TaskID,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            dataset_root,
+            dataset_name,
+            augment,
+            normalize,
+            split,
+            seed,
+            debug,
+            tiny,
+            progress,
+            job_id,
+            **kwargs,
+        )
+        raise NotImplementedError("This is not implemented yet.")
diff --git a/launch_experiment.py b/launch_experiment.py
deleted file mode 100644
index e328b03..0000000
--- a/launch_experiment.py
+++ /dev/null
@@ -1,218 +0,0 @@
-#! /usr/bin/env python3
-# vim:fenc=utf-8
-#
-# Copyright © 2023 Théo Morales
-#
-# Distributed under terms of the MIT license.
- - -import os -from dataclasses import asdict -from typing import Any, Optional - -import hydra_zen -import torch -import wandb -import yaml -from hydra.core.hydra_config import HydraConfig -from hydra.utils import to_absolute_path -from hydra_zen import just -from hydra_zen.typing import Partial -from torch.utils.data import DataLoader, Dataset - -import conf.experiment as exp_conf -from conf import project as project_conf -from model import TransparentDataParallel -from src.base_tester import BaseTester -from src.base_trainer import BaseTrainer -from utils import colorize, to_cuda_ - - -def launch_experiment( - run: exp_conf.RunConfig, - data_loader: Partial[DataLoader[Any]], - optimizer: Partial[torch.optim.Optimizer], - scheduler: Partial[torch.optim.lr_scheduler.LRScheduler], - trainer: Partial[BaseTrainer], - tester: Partial[BaseTester], - dataset: Partial[Dataset[Any]], - model: Partial[torch.nn.Module], - training_loss: Partial[torch.nn.Module], -): - run_name = os.path.basename(HydraConfig.get().runtime.output_dir) - # Generate a random ANSI code: - color_code = f"38;5;{hash(run_name) % 255}" - print( - colorize( - f"========================= Running {run_name} =========================", - color_code, - ) - ) - exp_conf = hydra_zen.to_yaml( - dict( - run_name=run_name, - run_conf=run, - dataset=dataset, - model=model, - optimizer=optimizer, - scheduler=scheduler, - training_loss=training_loss, - ) - ) - print( - colorize( - "Experiment config:\n" + "_" * 18 + "\n" + exp_conf + "_" * 18, color_code - ) - ) - - "============ Partials instantiation ============" - model_inst = model( - encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore - ) # Use just() to get the config out of the Zen-Partial - print(model_inst) - print(f"Number of parameters: {sum(p.numel() for p in model_inst.parameters())}") - print( - f"Number of trainable parameters: {sum(p.numel() for p in model_inst.parameters() if p.requires_grad)}" - ) - train_dataset: Optional[Dataset[Any]] = None - val_dataset: Optional[Dataset[Any]] = None - test_dataset: Optional[Dataset[Any]] = None - if run.training_mode: - train_dataset = dataset(split="train", seed=run.seed) - val_dataset = dataset(split="val", seed=run.seed) - else: - test_dataset = dataset(split="test", augment=False, seed=run.seed) - - opt_inst = optimizer(model_inst.parameters()) - scheduler_inst = scheduler( - opt_inst - ) # TODO: less hacky way to set T_max for CosineAnnealingLR? 
- if isinstance(scheduler_inst, torch.optim.lr_scheduler.CosineAnnealingLR): - scheduler_inst.T_max = run.epochs - - "======== Multi GPUs ==========" - print( - colorize( - f"[*] Number of GPUs: {torch.cuda.device_count()}", - project_conf.ANSI_COLORS["cyan"], - ) - ) - if torch.cuda.device_count() > 1: - print( - colorize( - f"-> Using {torch.cuda.device_count()} GPUs!", - project_conf.ANSI_COLORS["cyan"], - ) - ) - model_inst = TransparentDataParallel(model_inst) - - training_loss_inst: Optional[torch.nn.Module] = None - if run.training_mode: - training_loss_inst = training_loss() - - "============ CUDA ============" - model_inst: torch.nn.Module = to_cuda_(model_inst) # type: ignore - training_loss_inst = to_cuda_(training_loss_inst) # type: ignore - - "============ Weights & Biases ============" - if project_conf.USE_WANDB: - # exp_conf is a string, so we need to load it back to a dict: - exp_conf = yaml.safe_load(exp_conf) - wandb.init( # type: ignore - project=project_conf.PROJECT_NAME, - name=run_name, - config=exp_conf, - ) - wandb.watch(model_inst, log="all", log_graph=True) # type: ignore - " ============ Reproducibility of data loaders ============ " - g = None - if project_conf.REPRODUCIBLE: - g = torch.Generator() - g.manual_seed(run.seed) - - train_loader_inst: Optional[DataLoader[Any]] = None - val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None - test_loader_inst: Optional[DataLoader[Any]] = None - if run.training_mode: - if train_dataset is None or val_dataset is None: - raise ValueError( - "train_dataset and val_dataset must be defined in training mode!" - ) - train_loader_inst = data_loader(train_dataset, generator=g) - val_loader_inst = data_loader( - val_dataset, generator=g, shuffle=False, drop_last=False - ) - else: - if test_dataset is None: - raise ValueError("test_dataset must be defined in testing mode!") - test_loader_inst = data_loader( - test_dataset, generator=g, shuffle=False, drop_last=False - ) - - " ============ Training ============ " - model_ckpt_path = None - if run.load_from is not None: - if run.load_from.endswith(".ckpt"): - model_ckpt_path = to_absolute_path(run.load_from) - if not os.path.exists(model_ckpt_path): - raise ValueError(f"File {model_ckpt_path} does not exist!") - else: - run_models = sorted( - [ - f - for f in os.listdir(to_absolute_path(f"runs/{run.load_from}/")) - if f.endswith(".ckpt") - and (not f.startswith("last") if not run.training_mode else True) - ] - ) - if len(run_models) < 1: - raise ValueError(f"No model found in runs/{run.load_from}/") - model_ckpt_path = to_absolute_path( - os.path.join( - "runs", - run.load_from, - run_models[-1], - ) - ) - - if run.training_mode: - if training_loss_inst is None: - raise ValueError("training_loss must be defined in training mode!") - if val_loader_inst is None or train_loader_inst is None: - raise ValueError( - "val_loader and train_loader must be defined in training mode!" - ) - trainer( - run_name=run_name, - model=model_inst, - opt=opt_inst, - scheduler=scheduler_inst, - train_loader=train_loader_inst, - val_loader=val_loader_inst, - training_loss=training_loss_inst, - **asdict( - run - ), # Extra stuff if needed. 
You can get them from the trainer's __init__ with kwrags.get(key, default_value) - ).train( - epochs=run.epochs, - val_every=run.val_every, - visualize_every=run.viz_every, - visualize_train_every=run.viz_train_every, - visualize_n_samples=run.viz_num_samples, - model_ckpt_path=model_ckpt_path, - ) - else: - if test_loader_inst is None: - raise ValueError("test_loader must be defined in testing mode!") - tester( - run_name=run_name, - model=model_inst, - data_loader=test_loader_inst, - model_ckpt_path=model_ckpt_path, - training_loss=training_loss_inst, - ).test( - visualize_every=run.viz_every, - **asdict( - run - ), # Extra stuff if needed. You can get them from the trainer's __init__ with kwrags.get(key, default_value) - ) diff --git a/pyproject.toml b/pyproject.toml index 615a051..0ad2b90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ exclude = [ line-length = 88 indent-width = 4 -# Assume Python 3.8 -target-version = "py38" +# Assume Python 3.11 +target-version = "py311" [tool.ruff.lint] # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. @@ -69,7 +69,7 @@ line-ending = "auto" # # This is currently disabled by default, but it is planned for this # to be opt-out in the future. -docstring-code-format = false +docstring-code-format = true # Set the line length limit used when formatting code snippets in # docstrings. diff --git a/requirements.txt b/requirements.txt index ddfcdc4..825ee8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,6 @@ pre-commit blosc2 ipython neovim +rich +textual +textual_plotext diff --git a/src/base_tester.py b/src/base_tester.py index c575ceb..104ac49 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -9,30 +9,38 @@ Base tester class. """ +import asyncio import signal from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +import torch +from rich.console import Console +from rich.text import Text from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader from torchmetrics import MeanMetric -from tqdm import tqdm -from conf import project as project_conf from src.base_trainer import BaseTrainer -from utils import to_cuda, update_pbar_str +from utils import to_cuda +from utils.gui import GUI T = TypeVar("T") +console = Console() +print = console.print # skipcq: PYL-W0603, PYL-W0622 + + class BaseTester(BaseTrainer): def __init__( self, + gui: GUI, run_name: str, data_loader: DataLoader[T], model: Module, - model_ckpt_path: str, + model_ckpt_path: Optional[str] = None, training_loss: Optional[Module] = None, **kwargs: Optional[Dict[str, Any]], ) -> None: @@ -45,13 +53,18 @@ def __init__( """ _args = kwargs _loss = training_loss + self._gui = gui + global print # skipcq: PYL-W0603 + print = self._gui.print # skipcq: PYL-W0603, PYL-W0622 self._run_name = run_name self._model = model - assert model_ckpt_path is not None, "No model checkpoint path provided." 
- self._load_checkpoint(model_ckpt_path, model_only=True) + if model_ckpt_path is None: + print(Text("No model checkpoint path provided!", style="bold red")) + else: + print(Text("Loading model checkpoint...", style="bold cyan")) + self._load_checkpoint(model_ckpt_path, model_only=True) self._data_loader = data_loader self._running = True - self._pbar = tqdm(total=len(self._data_loader), desc="Testing") signal.signal(signal.SIGINT, self._terminator) @to_cuda @@ -68,9 +81,10 @@ def _visualize( def _test_iteration( self, batch: Union[Tuple, List, Tensor], - ) -> Dict[str, Tensor]: + ) -> Tuple[Tensor, Dict[str, Tensor]]: """Evaluation procedure for one batch. We want to keep the code DRY and avoid - making mistakes, so this code calls the BaseTrainer._train_val_iteration() method. + making mistakes, so this code calls the BaseTrainer._train_val_iteration() + method. Args: batch: The batch to process. Returns: @@ -79,47 +93,45 @@ def _test_iteration( x, y = batch # type: ignore # noqa y_hat = self._model(x) # type: ignore # noqa # TODO: Compute your metrics here! - return {} + return torch.tensor(torch.inf), {} - def test( + async def test( self, visualize_every: int = 0, **kwargs: Optional[Dict[str, Any]] ) -> None: """Computes the average loss on the test set. Args: - visualize_every (int, optional): Visualize the model predictions every n batches. + visualize_every (int, optional): Visualize the model predictions every n + batches. Defaults to 0 (no visualization). """ _ = kwargs test_loss: MeanMetric = MeanMetric() test_metrics: Dict[str, MeanMetric] = defaultdict(MeanMetric) self._model.eval() - self._pbar.reset() - self._pbar.set_description("Testing") - color_code = project_conf.ANSI_COLORS[project_conf.Theme.TESTING.value] - """ ==================== Training loop for one epoch ==================== """ - for i, batch in enumerate(self._data_loader): + print(Text(f"[*] Testing {self._run_name}", style="bold green")) + # ==================== Training loop for one epoch ==================== + pbar, update_loss_hook = self._gui.track_testing( + self._data_loader, total=len(self._data_loader) + ) + for i, batch in enumerate(pbar): if not self._running: print("[!] Testing aborted.") break - loss, metrics = self._test_iteration(batch) + loss, metrics = await asyncio.to_thread(self._test_iteration, batch) test_loss.update(loss.item()) for k, v in metrics.items(): test_metrics[k].update(v.item()) - update_pbar_str( - self._pbar, - f"Testing [loss={test_loss.compute():.4f}]", - color_code, - ) + update_loss_hook(test_loss.compute()) """ ==================== Visualization ==================== """ if visualize_every > 0 and (i + 1) % visualize_every == 0: self._visualize(batch, i) - self._pbar.update() - self._pbar.close() + # TODO: Report metrics in a special panel? Then hang the GUI until the user is + # done. print("=" * 81) print("==" + " " * 31 + " Test results " + " " * 31 + "==") print("=" * 81) for k, v in test_metrics.items(): print(f"\t -> {k}: {v.compute().item():.2f}") - print(f"\t -> Average loss: {test_loss:.4f}") + print(f"\t -> Average loss: {test_loss.compute():.4f}") print("_" * 81) diff --git a/src/base_trainer.py b/src/base_trainer.py index be05375..442f4c2 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -9,45 +9,52 @@ Base trainer class. 
""" +import asyncio import os import random import signal from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union -import plotext as plt import torch import wandb from hydra.core.hydra_config import HydraConfig +from rich.console import Console +from rich.text import Text from torch import Tensor from torch.nn import Module -from torch.optim import Optimizer +from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torchmetrics import MeanMetric -from tqdm import tqdm from conf import project as project_conf -from utils import blink_pbar, colorize, to_cuda, update_pbar_str +from utils import to_cuda +from utils.gui import GUI from utils.helpers import BestNModelSaver from utils.training import visualize_model_predictions +console = Console() +print = console.print # skipcq: PYL-W0603, PYL-W0622 + class BaseTrainer: def __init__( self, + gui: GUI, run_name: str, model: Module, opt: Optimizer, train_loader: DataLoader, val_loader: DataLoader, training_loss: Module, + model_ckpt_path: Optional[str] = None, scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, **kwargs: Dict[str, Optional[Union[str, int]]], ) -> None: """Base trainer class. Args: model (torch.nn.Module): Model to train. - opt (torch.optim.Optimizer): Optimizer to use. + opt (torch.optim.optimizer.Optimizer): Optimizer to use. train_loader (torch.utils.data.DataLoader): Training dataloader. val_loader (torch.utils.data.DataLoader): Validation dataloader. """ @@ -65,10 +72,14 @@ def __init__( project_conf.BEST_N_MODELS_TO_KEEP, self._save_checkpoint ) self._minimize_metric = "val_loss" - self._pbar = tqdm(total=len(self._train_loader), desc="Training") self._training_loss = training_loss self._viz_n_samples = 1 self._n_ctrl_c = 0 + self._gui = gui + global print # skipcq: PYL-W0603 + print = self._gui.print # skipcq: PYL-W0603, PYL-W0622 + if model_ckpt_path is not None: + self._load_checkpoint(model_ckpt_path) signal.signal(signal.SIGINT, self._terminator) @to_cuda @@ -93,8 +104,9 @@ def _train_val_iteration( epoch: int, validation: bool = False, ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Training or validation procedure for one batch. We want to keep the code DRY and avoid - making mistakes, so write this code only once at the cost of many function calls! + """Training or validation procedure for one batch. We want to keep the code DRY + and avoid making mistakes, so write this code only once at the cost of many + function calls! Args: batch: The batch to process. 
Returns: @@ -124,12 +136,13 @@ def _train_epoch( """ epoch_loss: MeanMetric = MeanMetric() epoch_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) - self._pbar.reset() - self._pbar.set_description(description) - color_code = project_conf.ANSI_COLORS[project_conf.Theme.TRAINING.value] has_visualized = 0 - """ ==================== Training loop for one epoch ==================== """ - for i, batch in enumerate(self._train_loader): + # ==================== Training loop for one epoch ==================== + pbar, update_loss_hook = self._gui.track_training( + self._train_loader, + total=len(self._train_loader), + ) + for i, batch in enumerate(pbar): if ( not self._running and project_conf.SIGINT_BEHAVIOR @@ -146,12 +159,7 @@ def _train_epoch( epoch_loss.update(loss.item()) for k, v in loss_components.items(): epoch_loss_components[k].update(v.item()) - update_pbar_str( - self._pbar, - f"{description} [loss={epoch_loss.compute():.4f} /" - + f" val_loss={last_val_loss:.4f}]", - color_code, - ) + update_loss_hook(epoch_loss.compute()) if ( visualize and has_visualized < self._viz_n_samples @@ -160,7 +168,6 @@ def _train_epoch( with torch.no_grad(): self._visualize(batch, epoch) has_visualized += 1 - self._pbar.update() mean_epoch_loss: float = epoch_loss.compute().item() if project_conf.USE_WANDB: wandb.log({"train_loss": mean_epoch_loss}, step=epoch) @@ -182,12 +189,15 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: float: Average validation loss for the epoch. """ has_visualized = 0 - color_code = project_conf.ANSI_COLORS[project_conf.Theme.VALIDATION.value] - """ ==================== Validation loop for one epoch ==================== """ + # ==================== Validation loop for one epoch ==================== with torch.no_grad(): val_loss: MeanMetric = MeanMetric() val_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric) - for i, batch in enumerate(self._val_loader): + pbar, update_loss_hook = self._gui.track_validation( + self._val_loader, + total=len(self._val_loader), + ) + for i, batch in enumerate(pbar): if ( not self._running and project_conf.SIGINT_BEHAVIOR @@ -195,8 +205,6 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: ): print("[!] 
Training aborted.") break - # Blink the progress bar to indicate that the validation loop is running - blink_pbar(i, self._pbar, 4) loss, loss_components = self._train_val_iteration( batch, epoch, @@ -204,13 +212,8 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: val_loss.update(loss.item()) for k, v in loss_components.items(): val_loss_components[k].update(v.item()) - update_pbar_str( - self._pbar, - f"{description} [loss={val_loss.compute():.4f} /" - + f" min_val_loss={self._model_saver.min_val_loss:.4f}]", - color_code, - ) - """ ==================== Visualization ==================== """ + update_loss_hook(val_loss.compute()) + # ==================== Visualization ==================== if ( visualize and has_visualized < self._viz_n_samples @@ -233,8 +236,8 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: }, step=epoch, ) - # Set minimize_metric to a key in val_loss_components if you wish to minimize - # a specific metric instead of the validation loss: + # Set minimize_metric to a key in val_loss_components if you wish to + # minimize a specific metric instead of the validation loss: self._model_saver( epoch, mean_val_loss, @@ -243,14 +246,13 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float: ) return mean_val_loss - def train( + async def train( self, epochs: int = 10, val_every: int = 1, # Validate every n epochs visualize_every: int = 10, # Visualize every n validations visualize_train_every: int = 0, # Visualize every n training epochs visualize_n_samples: int = 1, - model_ckpt_path: Optional[str] = None, ): """Train the model for a given number of epochs. Args: @@ -260,53 +262,44 @@ def train( Returns: None """ - if model_ckpt_path is not None: - self._load_checkpoint(model_ckpt_path) print( - colorize( - f"[*] Training {self._run_name} for {epochs} epochs", - project_conf.ANSI_COLORS["green"], + Text( + f"[*] Training {self._run_name} for {epochs} epochs", style="bold green" ) ) self._viz_n_samples = visualize_n_samples - train_losses: List[float] = [] - val_losses: List[float] = [] - """ ==================== Training loop ==================== """ + self._gui.set_start_epoch(self._epoch) + # ==================== Training loop ==================== + last_val_loss = float("inf") for epoch in range(self._epoch, epochs): + print(f"Epoch: {epoch}") self._epoch = epoch # Update for the model saver if not self._running: break self._model.train() - self._pbar.colour = project_conf.Theme.TRAINING.value - train_losses.append( - self._train_epoch( - f"Epoch {epoch}/{epochs}: Training", - visualize_train_every > 0 - and (epoch + 1) % visualize_train_every == 0, - epoch, - last_val_loss=val_losses[-1] - if len(val_losses) > 0 - else float("inf"), - ) + train_loss: float = await asyncio.to_thread( + self._train_epoch, + f"Epoch {epoch}/{epochs}: Training", + visualize_train_every > 0 and (epoch + 1) % visualize_train_every == 0, + epoch, + last_val_loss=last_val_loss, ) if epoch % val_every == 0: self._model.eval() - self._pbar.colour = project_conf.Theme.VALIDATION.value - val_losses.append( - self._val_epoch( - f"Epoch {epoch}/{epochs}: Validation", - visualize_every > 0 and (epoch + 1) % visualize_every == 0, - epoch, - ) + val_loss = await asyncio.to_thread( + self._val_epoch, + f"Epoch {epoch}/{epochs}: Validation", + visualize_every > 0 and (epoch + 1) % visualize_every == 0, + epoch, ) + last_val_loss = val_loss if self._scheduler is not None: - self._scheduler.step() - """ ==================== 
Plotting ==================== """ - if project_conf.PLOT_ENABLED: - self._plot(epoch, train_losses, val_losses) - self._pbar.close() - self._save_checkpoint( - val_losses[-1], + await asyncio.to_thread(self._scheduler.step) + # ==================== Plotting ==================== + self._gui.plot(epoch, train_loss, last_val_loss) # , self._model_saver) + await asyncio.to_thread( + self._save_checkpoint, + last_val_loss, os.path.join(HydraConfig.get().runtime.output_dir, "last.ckpt"), ) print(f"[*] Training finished for {self._run_name}!") @@ -315,68 +308,6 @@ def train( + f"at epoch {self._model_saver.min_val_loss_epoch}." ) - @staticmethod - def _setup_plot(run_name: str, log_scale: bool = False): - """Setup the plot for training and validation losses.""" - plt.title(f"Training curves for {run_name}") - plt.theme("dark") - plt.xlabel("Epoch") - if log_scale: - plt.ylabel("Loss (log scale)") - plt.yscale("log") - else: - plt.ylabel("Loss") - plt.grid(True, True) - - def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): - """Plot the training and validation losses. - Args: - epoch (int): Current epoch number. - train_losses (List[float]): List of training losses. - val_losses (List[float]): List of validation losses. - Returns: - None - """ - plt.clf() - if project_conf.LOG_SCALE_PLOT and any( - loss_val <= 0 for loss_val in train_losses + val_losses - ): - raise ValueError( - "Cannot plot on a log scale if there are non-positive losses." - ) - self._setup_plot(self._run_name, log_scale=project_conf.LOG_SCALE_PLOT) - plt.plot( - list(range(self._starting_epoch, epoch + 1)), - train_losses, - color=project_conf.Theme.TRAINING.value, - label="Training loss", - ) - plt.plot( - list(range(self._starting_epoch, epoch + 1)), - val_losses, - color=project_conf.Theme.VALIDATION.value, - label="Validation loss", - ) - best_metrics = ( - "[" - + ", ".join( - [ - f"{metric_name}={metric_value:.2e} " - for metric_name, metric_value in self._model_saver.best_metrics.items() - ] - ) - + "]" - ) - plt.scatter( - [self._model_saver.min_val_loss_epoch], - [self._model_saver.min_val_loss], - color="red", - marker="+", - label=f"Best model {best_metrics}", - style="inverted", - ) - plt.show() - def _save_checkpoint(self, val_loss: float, ckpt_path: str, **kwargs) -> None: """Saves the model and optimizer state to a checkpoint file. Args: @@ -403,18 +334,19 @@ def _save_checkpoint(self, val_loss: float, ckpt_path: str, **kwargs) -> None: ) def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None: - """Loads the model and optimizer state from a checkpoint file. This method should remain in - this class because it should be extendable in classes inheriting from this class, instead - of being overwritten/modified. That would be a source of bugs and a bad practice. + """Loads the model and optimizer state from a checkpoint file. This method + should remain in this class because it should be extendable in classes + inheriting from this class, instead of being overwritten/modified. That would be + a source of bugs and a bad practice. Args: ckpt_path (str): The path to the checkpoint file. - model_only (bool): If True, only the model is loaded (useful for BaseTester). + model_only (bool): If True, only the model is loaded (useful for + BaseTester). 
         Returns:
             None
         """
         print(f"[*] Restoring from checkpoint: {ckpt_path}")
         ckpt = torch.load(ckpt_path)
-        # If the model was optimized with torch.optimize() we need to remove the "_orig_mod"
+        # If the model was compiled with torch.compile(), we need to strip the "_orig_mod"
         # prefix:
         if "_orig_mod" in list(ckpt["model_ckpt"].keys())[0]:
             ckpt["model_ckpt"] = {
@@ -425,7 +357,8 @@ def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None:
         except Exception:
             if project_conf.PARTIALLY_LOAD_MODEL_IF_NO_FULL_MATCH:
                 print(
-                    "[!] Partially loading model weights (no full match between model and checkpoint)"
+                    "[!] Partially loading model weights "
+                    + "(no full match between model and checkpoint)"
                 )
                 self._model.load_state_dict(ckpt["model_ckpt"], strict=False)
         if not model_only:
@@ -448,7 +381,8 @@ def _terminator(self, sig, frame):
             and self._n_ctrl_c == 0
         ):
             print(
-                f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}. Press Ctrl+C again to abort."
+                f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}."
+                + " Press Ctrl+C again to abort."
             )
             self._n_ctrl_c += 1
         elif (
diff --git a/test.py b/test.py
index 8b642ab..3b8c4f9 100755
--- a/test.py
+++ b/test.py
@@ -8,12 +8,13 @@
 from hydra_zen import store, zen
 
-import conf.experiment  # Must import the config to add all components to the store!  # noqa
+from bootstrap.launch_experiment import launch_experiment
 from conf import project as project_conf
-from launch_experiment import launch_experiment
+from conf.experiment import make_experiment_configs
 from utils import seed_everything
 
 if __name__ == "__main__":
+    make_experiment_configs()
 
     def set_test_mode(cfg):
         cfg.run.training_mode = False
diff --git a/train.py b/train.py
index 090f533..ea3242e 100755
--- a/train.py
+++ b/train.py
@@ -8,12 +8,13 @@
 from hydra_zen import store, zen
 
-import conf.experiment  # Must import the config to add all components to the store!
# noqa +from bootstrap.launch_experiment import launch_experiment from conf import project as project_conf -from launch_experiment import launch_experiment +from conf.experiment import make_experiment_configs from utils import seed_everything if __name__ == "__main__": + make_experiment_configs() "============ Hydra-Zen ============" store.add_to_hydra_store( overwrite_ok=True diff --git a/utils/__init__.py b/utils/__init__.py index 4832e55..933b4b4 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -8,22 +8,50 @@ # import importlib # import inspect +import os import random # import sys import traceback -from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional # import IPython import numpy as np import torch +from hydra.utils import to_absolute_path from torch import Tensor, nn -from tqdm import tqdm from conf import project as project_conf +def load_model_ckpt(load_from: Optional[str], training_mode: bool) -> Optional[str]: + model_ckpt_path = None + if load_from is not None: + if load_from.endswith(".ckpt"): + model_ckpt_path = to_absolute_path(load_from) + if not os.path.exists(model_ckpt_path): + raise ValueError(f"File {model_ckpt_path} does not exist!") + else: + run_models = sorted( + [ + f + for f in os.listdir(to_absolute_path(f"runs/{load_from}/")) + if f.endswith(".ckpt") + and (not f.startswith("last") if not training_mode else True) + ] + ) + if len(run_models) < 1: + raise ValueError(f"No model found in runs/{load_from}/") + model_ckpt_path = to_absolute_path( + os.path.join( + "runs", + load_from, + run_models[-1], + ) + ) + return model_ckpt_path + + def seed_everything(seed: int): torch.manual_seed(seed) # type: ignore np.random.seed(seed) @@ -69,44 +97,29 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return wrapper -def colorize(string: str, ansii_code: Union[int, str]) -> str: - return f"\033[{ansii_code}m{string}\033[0m" - - -def blink_pbar(i: int, pbar: tqdm, n: int) -> None: - """Blink the progress bar every n iterations. - Args: - i (int): current iteration - pbar (tqdm): progress bar - n (int): blink every n iterations - """ - if i % n == 0: - pbar.colour = ( - project_conf.Theme.TRAINING.value - if pbar.colour == project_conf.Theme.VALIDATION.value - else project_conf.Theme.VALIDATION.value - ) - - -@contextmanager -def colorize_prints(ansii_code: Union[int, str]): - if isinstance(ansii_code, str): - ansii_code = project_conf.ANSI_COLORS[ansii_code] - print(f"\033[{ansii_code}m", end="") - try: - yield - finally: - print("\033[0m", end="") - - -def update_pbar_str(pbar: tqdm, string: str, color_code: int) -> None: - """Update the progress bar string. - Args: - pbar (tqdm): progress bar - string (str): string to update the progress bar with - color_code (int): color code for the string - """ - pbar.set_description_str(colorize(string, color_code)) +# def blink_pbar(i: int, pbar: tqdm, n: int) -> None: +# """Blink the progress bar every n iterations. +# Args: +# i (int): current iteration +# pbar (tqdm): progress bar +# n (int): blink every n iterations +# """ +# if i % n == 0: +# pbar.colour = ( +# project_conf.Theme.TRAINING.value +# if pbar.colour == project_conf.Theme.VALIDATION.value +# else project_conf.Theme.VALIDATION.value +# ) +# +# +# def update_pbar_str(pbar: tqdm, string: str, color_code: int) -> None: +# """Update the progress bar string. 
+#     Args:
+#         pbar (tqdm): progress bar
+#         string (str): string to update the progress bar with
+#         color_code (int): color code for the string
+#     """
+#     pbar.set_description_str(colorize(string, color_code))
 
 
 def get_function_frame(func, exc_traceback):
diff --git a/utils/gui.py b/utils/gui.py
new file mode 100644
index 0000000..0daceea
--- /dev/null
+++ b/utils/gui.py
@@ -0,0 +1,451 @@
+import asyncio
+from collections import abc
+from datetime import datetime
+from enum import Enum
+from functools import partial
+from itertools import cycle
+from random import random
+from time import sleep
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    Optional,
+    Sequence,
+    Tuple,
+)
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+from rich.console import Group, RenderableType
+from rich.pretty import Pretty
+from rich.text import Text
+from textual.app import App, ComposeResult
+from textual.containers import Center
+from textual.reactive import var
+from textual.widgets import (
+    Footer,
+    Header,
+    Label,
+    Placeholder,
+    ProgressBar,
+    RichLog,
+    Static,
+)
+from textual_plotext import PlotextPlot
+from torch.utils.data.dataloader import DataLoader
+from torchvision.datasets import MNIST
+from torchvision.transforms.functional import to_tensor
+
+
+class PlotterWidget(PlotextPlot):
+    marker: var[str] = var("sd")
+    """The type of marker to use for the plot."""
+
+    def __init__(
+        self,
+        title: str,
+        use_log_scale: bool = False,
+        *,
+        name: str | None = None,
+        id: str | None = None,  # pylint:disable=redefined-builtin
+        classes: str | None = None,
+        disabled: bool = False,
+    ) -> None:
+        """Initialise the training curves plotter widget.
+
+        Args:
+            name: The name of the plotter widget.
+            id: The ID of the plotter widget in the DOM.
+            classes: The CSS classes of the plotter widget.
+            disabled: Whether the plotter widget is disabled or not.
+        """
+        super().__init__(
+            name=name,
+            id=id,
+            classes=classes,
+            disabled=disabled,
+        )
+        self._title = title
+        self._log_scale = use_log_scale
+        self._train_losses: list[float] = []
+        self._val_losses: list[float] = []
+        self._start_epoch = 0
+        self._epoch = 0
+
+    def on_mount(self) -> None:
+        """Plot the data using Plotext."""
+        self.plt.title(self._title)
+        self.plt.xlabel("Epoch")
+        if self._log_scale:
+            self.plt.ylabel("Loss (log scale)")
+            self.plt.yscale("log")
+        else:
+            self.plt.ylabel("Loss")
+        self.plt.grid(True, True)
+
+    def replot(self) -> None:
+        """Redraw the plot."""
+        self.plt.clear_data()
+        if (
+            self._log_scale
+            and len(self._train_losses) > 0
+            and (self._train_losses[-1] <= 0 or self._val_losses[-1] <= 0)
+        ):
+            raise ValueError(
+                "Cannot plot on a log scale if there are non-positive losses."
+            )
+        if len(self._train_losses) > 0:
+            assert len(self._val_losses) == len(self._train_losses)
+            self.plt.plot(
+                list(range(self._start_epoch, self._epoch + 1)),
+                self._train_losses,
+                color="blue",  # TODO: Theme
+                label="Training loss",
+                marker=self.marker,
+            )
+            self.plt.plot(
+                list(range(self._start_epoch, self._epoch + 1)),
+                self._val_losses,
+                color="green",  # TODO: Theme
+                label="Validation loss",
+                marker=self.marker,
+            )
+        self.refresh()
+
+    def set_start_epoch(self, start_epoch: int):
+        self._start_epoch = start_epoch
+
+    def update(
+        self, epoch: int, train_loss: float, val_loss: Optional[float] = None
+    ) -> None:
+        """Update the data for the training curves plot.
+
+        Args:
+            epoch: (int) The current epoch number.
+            train_loss: (float) The last training loss.
+            val_loss: (float) The last validation loss.
+        """
+        self._epoch = epoch
+        self._train_losses.append(train_loss)
+        self._val_losses.append(
+            # Carry the previous value forward (or fall back to the training loss
+            # on the very first point) so that both series stay the same length.
+            val_loss
+            if val_loss is not None
+            else (self._val_losses[-1] if self._val_losses else train_loss)
+        )
+        self.replot()
+
+    def _watch_marker(self) -> None:
+        """React to the marker being changed."""
+        self.replot()
+
+
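The widget above keeps two parallel loss series and redraws on every `update()`; when `val_loss` is omitted, the previous value is carried forward so both series stay the same length. A minimal, hypothetical smoke test (assuming `utils/gui.py` as added here and `textual-plotext` installed; not part of the patch):

```python
# Hypothetical smoke test for PlotterWidget; names other than the widget API are ours.
from textual.app import App, ComposeResult

from utils.gui import PlotterWidget


class PlotterDemo(App):
    def compose(self) -> ComposeResult:
        yield PlotterWidget(title="demo curves", use_log_scale=False)

    def on_mount(self) -> None:
        plotter = self.query_one(PlotterWidget)
        plotter.set_start_epoch(0)
        plotter.update(epoch=0, train_loss=1.0, val_loss=0.9)
        plotter.update(epoch=1, train_loss=0.8)  # val_loss omitted: repeats 0.9


if __name__ == "__main__":
    PlotterDemo().run()
```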
+ """ + self._epoch = epoch + self._train_losses.append(train_loss) + self._val_losses.append( + val_loss if val_loss is not None else self._val_losses[-1] + ) + self.replot() + + def _watch_marker(self) -> None: + """React to the marker being changed.""" + self.replot() + + +# TODO: Also make a Rich renderable for Tensors (using tables?) + + +class Task(Enum): + IDLE = -1 + TRAINING = 0 + VALIDATION = 1 + TESTING = 2 + + +class DatasetProgressBar(Static): + """A progress bar for PyTorch dataloader iteration.""" + + DESCRIPTIONS = { + Task.IDLE: Text("Waiting for work..."), + Task.TRAINING: Text("Training: ", style="bold blue"), + Task.VALIDATION: Text("Validation: ", style="bold green"), + Task.TESTING: Text("Testing: ", style="bold yellow"), + } + + def compose(self) -> ComposeResult: + with Center(): + yield Label(self.DESCRIPTIONS[Task.IDLE], id="progress_label") + yield ProgressBar() + + def track_iterable( + self, + iterable: Iterable | Sequence | Iterator | DataLoader, + task: Task, + total: int, + ) -> Tuple[Iterable, Callable]: + class LossHook: + def __init__(self): + self._loss = None + + def update_loss_hook(self, loss: float) -> None: + """Update the loss value in the progress bar.""" + # TODO: min_val_loss during validation, val_loss during training. + # Ideally the second parameter would be super flexible (use a dict + # then). + self._loss = loss + + class SeqWrapper(abc.Iterator, LossHook): + def __init__( + self, + seq: Sequence, + length: int, + update_hook: Callable, + reset_hook: Callable, + ): + super().__init__() + self._sequence = seq + self._idx = 0 + self._len = length + self._update_hook = update_hook + self._reset_hook = reset_hook + + def __next__(self): + if self._idx >= self._len: + self._reset_hook() + raise StopIteration + item = self._sequence[self._idx] + self._update_hook(loss=self._loss) + self._idx += 1 + return item + + class IteratorWrapper(abc.Iterator, LossHook): + def __init__( + self, + iterator: Iterator | DataLoader, + length: int, + update_hook: Callable, + reset_hook: Callable, + ): + super().__init__() + self._iterator = iter(iterator) + self._len = length + self._update_hook = update_hook + self._reset_hook = reset_hook + + def __next__(self): + try: + item = next(self._iterator) + self._update_hook(loss=self._loss) + return item + except StopIteration: + self._reset_hook() + raise StopIteration + + def update_hook(loss: Optional[float] = None): + self.query_one(ProgressBar).advance() + if loss is not None: + plabel: Label = self.query_one("#progress_label") # type: ignore + plabel.update(self.DESCRIPTIONS[task] + f"[loss={loss:.4f}]") + + def reset_hook(): + sleep(0.5) + self.query_one(ProgressBar).update(total=100, progress=0) + plabel: Label = self.query_one("#progress_label") # type: ignore + plabel.update(self.DESCRIPTIONS[Task.IDLE]) + + wrapper = None + update_p, reset_p = partial(update_hook), partial(reset_hook) + if isinstance(iterable, abc.Sequence): + wrapper = SeqWrapper(iterable, total, update_p, reset_p) + elif isinstance(iterable, (abc.Iterator, DataLoader)): + wrapper = IteratorWrapper(iterable, total, update_p, reset_p) + else: + raise ValueError( + f"iterable must be a Sequence or an Iterator, got {type(iterable)}" + ) + self.query_one(ProgressBar).update(total=total, progress=0) + plabel: Label = self.query_one("#progress_label") # type: ignore + plabel.update(self.DESCRIPTIONS[task]) + return wrapper, wrapper.update_loss_hook + + +class GUI(App): + """ + A Textual app to serve as *useful* GUI/TUI for my pytorch-based 
+class GUI(App):
+    """
+    A Textual app to serve as *useful* GUI/TUI for my pytorch-based
+    micro framework.
+    """
+
+    TITLE = "Matchbox TUI"
+    CSS_PATH = "style.css"
+
+    BINDINGS = [
+        ("q", "quit", "Quit"),
+        ("d", "toggle_dark", "Toggle dark mode"),
+        ("p", "marker", "Change plotter style"),
+        ("ctrl+z", "suspend_progress"),
+    ]
+
+    MARKERS = {
+        "dot": "Dot",
+        "hd": "High Definition",
+        "fhd": "Higher Definition",
+        "braille": "Braille",
+        "sd": "Standard Definition",
+    }
+
+    marker: var[str] = var("hd")
+
+    def __init__(self, run_name: str, log_scale: bool) -> None:
+        """Initialise the application."""
+        super().__init__()
+        self._markers = cycle(self.MARKERS.keys())
+        self._log_scale = log_scale
+        self.run_name = run_name
+
+    def compose(self) -> ComposeResult:
+        yield Header()
+        yield PlotterWidget(
+            title=f"Training curves for {self.run_name}",
+            use_log_scale=self._log_scale,
+            classes="box",
+        )
+        yield RichLog(
+            highlight=True, markup=True, wrap=True, id="logger", classes="box"
+        )
+        yield DatasetProgressBar()
+        yield Placeholder(classes="box")
+        yield Footer()
+
+    def on_mount(self):
+        self.query_one(PlotterWidget).loading = True
+
+    def action_toggle_dark(self) -> None:
+        self.dark = not self.dark  # skipcq: PYL-W0201
+
+    def watch_marker(self) -> None:
+        """React to the marker type being changed."""
+        self.sub_title = self.MARKERS[self.marker]  # skipcq: PYL-W0201
+        self.query_one(PlotterWidget).marker = self.marker
+
+    def action_marker(self) -> None:
+        """Cycle to the next marker type."""
+        self.marker = next(self._markers)
+
+    def print(self, message: Any):
+        logger: RichLog = self.query_one(RichLog)
+        if isinstance(message, (RenderableType, str)):
+            logger.write(
+                Group(
+                    Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""),
+                    message,
+                ),
+            )
+        else:
+            ppable, pp_msg = True, None
+            try:
+                pp_msg = Pretty(message)
+            except Exception:
+                ppable = False
+            if ppable and pp_msg is not None:
+                logger.write(
+                    Group(
+                        Text(
+                            datetime.now().strftime("[%H:%M] "),
+                            style="dim cyan",
+                            end="",
+                        ),
+                        Text(str(type(message)) + " ", style="italic blue", end=""),
+                        pp_msg,
+                    )
+                )
+            else:
+                try:
+                    logger.write(
+                        Group(
+                            Text(
+                                datetime.now().strftime("[%H:%M] "),
+                                style="dim cyan",
+                                end="",
+                            ),
+                            message,
+                        ),
+                    )
+                except Exception as e:
+                    logger.write(
+                        Group(
+                            Text(
+                                datetime.now().strftime("[%H:%M] "),
+                                style="dim cyan",
+                                end="",
+                            ),
+                            Text("Logging error: ", style="bold red"),
+                            Text(str(e), style="bold red"),
+                        )
+                    )
+
+    def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]:
+        """Return an iterable that tracks the progress of the training process, and a
+        progress bar hook to update the loss value at each iteration."""
+        return self.query_one(DatasetProgressBar).track_iterable(
+            iterable, Task.TRAINING, total
+        )
+
+    def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]:
+        """Return an iterable that tracks the progress of the validation process, and a
+        progress bar hook to update the loss value at each iteration."""
+        return self.query_one(DatasetProgressBar).track_iterable(
+            iterable, Task.VALIDATION, total
+        )
+
+    def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]:
+        """Return an iterable that tracks the progress of the testing process, and a
+        progress bar hook to update the loss value at each iteration."""
+        return self.query_one(DatasetProgressBar).track_iterable(
+            iterable, Task.TESTING, total
+        )
+
+    def plot(
+        self,
+        epoch: int,
+        train_loss: float,
+        val_loss: Optional[float] = None,
+    ) -> None:
+        """Plot the training and validation losses for the current epoch."""
+        self.query_one(PlotterWidget).loading = False
+        self.query_one(PlotterWidget).update(epoch, train_loss, val_loss)
+
+    def set_start_epoch(self, start_epoch: int) -> None:
+        """Set the starting epoch for the plotter widget."""
+        self.query_one(PlotterWidget).set_start_epoch(start_epoch)
+
+
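`GUI.print` is what `BaseTrainer` swaps in for the built-in `print`: strings and Rich renderables are written to the `RichLog` with a timestamp, and anything else (tensors, arrays, batches) falls back to `rich.pretty.Pretty`. For instance, assuming a running `gui` instance as in the demo below:

```python
# Assumes `gui` is a GUI instance whose app is already running.
from rich.text import Text
import torch

gui.print("plain strings get a dim [HH:MM] timestamp prefix")
gui.print(Text("Rich renderables pass straight through", style="bold magenta"))
gui.print(torch.rand(2, 3))  # non-renderables are pretty-printed with their type
```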
current epoch.""" + self.query_one(PlotterWidget).loading = False + self.query_one(PlotterWidget).update(epoch, train_loss, val_loss) + + def set_start_epoch(self, start_epoch: int) -> None: + """Set the starting epoch for the plotter widget.""" + self.query_one(PlotterWidget).set_start_epoch(start_epoch) + + +async def run_my_app(): + gui = GUI("test-run", log_scale=False) + task = asyncio.create_task(gui.run_async()) + while not gui.is_running: + await asyncio.sleep(0.01) # Wait for the app to start up + gui.print("Hello, World!") + await asyncio.sleep(2) + gui.print(Text("Let's log some tensors :)", style="bold magenta")) + await asyncio.sleep(0.5) + gui.print(torch.rand(2, 4)) + await asyncio.sleep(2) + gui.print(Text("How about some numpy arrays?!", style="italic green")) + await asyncio.sleep(1) + gui.print(np.random.rand(3, 3)) + pbar, update_progress_loss = gui.track_training(range(10), 10) + for i, e in enumerate(pbar): + gui.print(f"[{i+1}/10]: We can iterate over iterables") + gui.print(e) + await asyncio.sleep(0.1) + await asyncio.sleep(2) + mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) + # Somehow, the dataloader will crash if it's not forked when using multiprocessing + # along with Textual. + mp.set_start_method("fork") + dataloader = DataLoader(mnist, 32, shuffle=True, num_workers=2) + pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader)) + for i, batch in enumerate(pbar): + await asyncio.sleep(0.01) + if i % 10 == 0: + gui.print(batch) + update_progress_loss(random()) + gui.plot(epoch=i, train_loss=random(), val_loss=random()) + gui.print( + f"[{i+1}/{len(dataloader)}]: " + + "We can also iterate over PyTorch dataloaders!" + ) + if i == 0: + gui.print(batch) + gui.print("Goodbye, world!") + _ = await task + + +if __name__ == "__main__": + asyncio.run(run_my_app()) diff --git a/utils/style.css b/utils/style.css new file mode 100644 index 0000000..283b8a2 --- /dev/null +++ b/utils/style.css @@ -0,0 +1,22 @@ +Screen { + layout: grid; + grid-size: 2; + grid-columns: 3fr 1fr; + grid-rows: 95% 5%; +} + +.box { + height: 100%; + border: solid green; +} + + +Center { + margin-top: 1; + margin-bottom: 1; + layout: horizontal; +} + +ProgressBar { + padding-left: 3; +}