diff --git a/.gitignore b/.gitignore index 2ef21c8..dd745d5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,169 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + + +.vscode/ +checkpoints/ +wandb/ +outputs/ +datasets/* /logs/ /checkpoints/ diff --git a/src/configs/computer/1-gpu.yaml b/src/configs/computer/1-gpu.yaml new file mode 100644 index 0000000..d46e420 --- /dev/null +++ b/src/configs/computer/1-gpu.yaml @@ -0,0 +1,7 @@ +devices: 1 +progress_bar_refresh_rate: 2 +num_workers: 8 +accelerator: gpu +precision: bf16-mixed +strategy: auto +num_nodes: 1 \ No newline at end of file diff --git a/src/configs/computer/2-gpu.yaml b/src/configs/computer/2-gpu.yaml new file mode 100644 index 0000000..2a4f591 --- /dev/null +++ b/src/configs/computer/2-gpu.yaml @@ -0,0 +1,7 @@ +devices: 2 +progress_bar_refresh_rate: 2 +num_workers: 8 +accelerator: gpu +precision: bf16-mixed +strategy: ddp +num_nodes: 1 \ No newline at end of file diff --git a/src/configs/computer/4-gpu.yaml b/src/configs/computer/4-gpu.yaml new file mode 100644 index 0000000..c36c65f --- /dev/null +++ b/src/configs/computer/4-gpu.yaml @@ -0,0 +1,7 @@ +devices: 4 +progress_bar_refresh_rate: 2 +num_workers: 8 +accelerator: gpu +precision: bf16-mixed +strategy: ddp +num_nodes: 1 \ No newline at end of file diff --git a/src/configs/data/additional_train_transforms/randaugment.yaml b/src/configs/data/additional_train_transforms/randaugment.yaml new file mode 100644 index 0000000..11574f8 --- /dev/null +++ b/src/configs/data/additional_train_transforms/randaugment.yaml @@ -0,0 +1,11 @@ +instance: + - _target_: torchvision.transforms.v2.RandomApply + transforms: + - _target_: torchvision.transforms.v2.RandAugment + magnitude: ${data.additional_train_transforms.randaugment_magnitude} + p: ${data.additional_train_transforms.randaugment_p} + + + +randaugment_magnitude: 6 +randaugment_p: 0.1 \ No newline at end of file diff --git a/src/configs/data/imagenet.yaml b/src/configs/data/imagenet.yaml new file mode 100644 index 0000000..4ff41da --- /dev/null +++ b/src/configs/data/imagenet.yaml @@ -0,0 +1,24 @@ +defaults: + - additional_train_transforms: randaugment + - _self_ + +full_batch_size: 128 +size: 256 +num_classes: 1000 +name: Imagenet-1K + +dataset_builder: + _partial_: true + _target_: utils.data.build_imagenet + data_dir: ${data_dir}/imagenet + num_classes: ${data.num_classes} + size: ${data.size} + additional_transforms: ${data.additional_train_transforms.instance} + +datamodule: + _target_: utils.datamodule.ImageDataModule + dataset_builder: ${data.dataset_builder} + full_batch_size: ${data.full_batch_size} + num_workers: ${computer.num_workers} + num_nodes: ${computer.num_nodes} + num_devices: ${computer.devices} \ No newline at end of file diff --git a/src/configs/logger/tensorboard.yaml b/src/configs/logger/tensorboard.yaml new file mode 100644 index 0000000..e69de29 diff --git a/src/configs/logger/wandb.yaml b/src/configs/logger/wandb.yaml new file mode 100644 index 0000000..05b1958 --- /dev/null +++ b/src/configs/logger/wandb.yaml @@ -0,0 +1,6 @@ +_target_: pytorch_lightning.loggers.WandbLogger +save_dir: ${root_dir}/wandb +name: ${experiment_name} +project: Homm +log_model: False +offline: False \ No newline at end of file diff --git a/src/configs/model/homm.yaml b/src/configs/model/homm.yaml new file mode 100644 index 0000000..d42ca5a --- /dev/null +++ b/src/configs/model/homm.yaml @@ -0,0 +1,24 @@ +defaults: + - train_batch_preprocess: mixup_and_cutmix + - network: homm + - loss: CE + - optimizer: adamw + - lr_scheduler: warmup_cosine_decay + - train_metric: accuracy + - val_metric: accuracy + - test_metric: accuracy + +name: HoMM +instance: + _target_: 
model.classification.ClassificationModule + model: ${model.network} + loss: ${model.loss} + optimizer_cfg: ${model.optimizer} + lr_scheduler_builder: ${model.lr_scheduler} + train_batch_preprocess: ${model.train_batch_preprocess} + train_metrics: ${model.train_metric} + val_metrics: ${model.val_metric} + test_metrics: ${model.test_metric} + + +lsuv_normalize: False \ No newline at end of file diff --git a/src/configs/model/loss/CE.yaml b/src/configs/model/loss/CE.yaml new file mode 100644 index 0000000..908308e --- /dev/null +++ b/src/configs/model/loss/CE.yaml @@ -0,0 +1,2 @@ +_target_: model.losses.CE.CrossEntropyLossModule +sum_label_dim: True \ No newline at end of file diff --git a/src/configs/model/lr_scheduler/warmup.yaml b/src/configs/model/lr_scheduler/warmup.yaml new file mode 100644 index 0000000..168a29f --- /dev/null +++ b/src/configs/model/lr_scheduler/warmup.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: utils.lr_scheduler.WarmupLR +warmup_steps: 10000 + diff --git a/src/configs/model/lr_scheduler/warmup_cosine_decay.yaml b/src/configs/model/lr_scheduler/warmup_cosine_decay.yaml new file mode 100644 index 0000000..76c18cf --- /dev/null +++ b/src/configs/model/lr_scheduler/warmup_cosine_decay.yaml @@ -0,0 +1,5 @@ +_partial_: true +_target_: utils.lr_scheduler.WarmupCosineDecayLR +warmup_steps: 10000 +total_steps: ${trainer.max_steps} + diff --git a/src/configs/model/network/homm.yaml b/src/configs/model/network/homm.yaml new file mode 100644 index 0000000..4135eb0 --- /dev/null +++ b/src/configs/model/network/homm.yaml @@ -0,0 +1,12 @@ +_target_: model.network.vision.HoMVision +nb_classes: ${data.num_classes} +dim: 256 +im_size: 256 +kernel_size: 32 +nb_layers: 4 +order: 2 +order_expand: 4 +ffw_expand: 4 +dropout: 0.0 +pooling: cls +in_conv: true diff --git a/src/configs/model/network/vit.yaml b/src/configs/model/network/vit.yaml new file mode 100644 index 0000000..e69de29 diff --git a/src/configs/model/optimizer/adam.yaml b/src/configs/model/optimizer/adam.yaml new file mode 100755 index 0000000..d6cca7c --- /dev/null +++ b/src/configs/model/optimizer/adam.yaml @@ -0,0 +1,8 @@ +optim: + _partial_: true + _target_: torch.optim.Adam + lr: 1e-3 + betas: [0.9, 0.999] + weight_decay: 0.01 + +exclude_ln_and_biases_from_weight_decay: False \ No newline at end of file diff --git a/src/configs/model/optimizer/adamw.yaml b/src/configs/model/optimizer/adamw.yaml new file mode 100755 index 0000000..01fd575 --- /dev/null +++ b/src/configs/model/optimizer/adamw.yaml @@ -0,0 +1,8 @@ +optim: + _partial_: true + _target_: torch.optim.AdamW + lr: 1e-3 + betas: [0.9, 0.999] + weight_decay: 0.01 + +exclude_ln_and_biases_from_weight_decay: False \ No newline at end of file diff --git a/src/configs/model/optimizer/sgd.yaml b/src/configs/model/optimizer/sgd.yaml new file mode 100644 index 0000000..663ff94 --- /dev/null +++ b/src/configs/model/optimizer/sgd.yaml @@ -0,0 +1,7 @@ +optim: + _partial_: true + _target_: torch.optim.SGD + lr: 1e-3 + weight_decay: 0.01 + +exclude_ln_and_biases_from_weight_decay: False \ No newline at end of file diff --git a/src/configs/model/test_metric/accuracy.yaml b/src/configs/model/test_metric/accuracy.yaml new file mode 100644 index 0000000..0c6874a --- /dev/null +++ b/src/configs/model/test_metric/accuracy.yaml @@ -0,0 +1,2 @@ +_target_: utils.metrics.ClassificationMetrics +num_classes: ${data.num_classes} \ No newline at end of file diff --git a/src/configs/model/train_batch_preprocess/cutmix.yaml b/src/configs/model/train_batch_preprocess/cutmix.yaml new file 
mode 100644 index 0000000..609a8ed --- /dev/null +++ b/src/configs/model/train_batch_preprocess/cutmix.yaml @@ -0,0 +1,4 @@ +_target_: utils.mixup.CutMix +apply_transform_prob: 1.0 +alpha: 1.0 +num_classes: 1000 diff --git a/src/configs/model/train_batch_preprocess/mixup.yaml b/src/configs/model/train_batch_preprocess/mixup.yaml new file mode 100644 index 0000000..eb8b2cc --- /dev/null +++ b/src/configs/model/train_batch_preprocess/mixup.yaml @@ -0,0 +1,4 @@ +_target_: utils.mixup.MixUp +apply_transform_prob: 1.0 +alpha: 0.1 +num_classes: 1000 diff --git a/src/configs/model/train_batch_preprocess/mixup_and_cutmix.yaml b/src/configs/model/train_batch_preprocess/mixup_and_cutmix.yaml new file mode 100644 index 0000000..ee0755e --- /dev/null +++ b/src/configs/model/train_batch_preprocess/mixup_and_cutmix.yaml @@ -0,0 +1,6 @@ +_target_: utils.mixup.CutMixUp +apply_transform_prob: 1.0 +mixup_prob: 0.5 +alpha_mixup: 0.1 +alpha_cutmix: 1.0 +num_classes: 1000 diff --git a/src/configs/model/train_metric/accuracy.yaml b/src/configs/model/train_metric/accuracy.yaml new file mode 100644 index 0000000..0c6874a --- /dev/null +++ b/src/configs/model/train_metric/accuracy.yaml @@ -0,0 +1,2 @@ +_target_: utils.metrics.ClassificationMetrics +num_classes: ${data.num_classes} \ No newline at end of file diff --git a/src/configs/model/val_metric/accuracy.yaml b/src/configs/model/val_metric/accuracy.yaml new file mode 100644 index 0000000..0c6874a --- /dev/null +++ b/src/configs/model/val_metric/accuracy.yaml @@ -0,0 +1,2 @@ +_target_: utils.metrics.ClassificationMetrics +num_classes: ${data.num_classes} \ No newline at end of file diff --git a/src/configs/train.yaml b/src/configs/train.yaml new file mode 100644 index 0000000..6850941 --- /dev/null +++ b/src/configs/train.yaml @@ -0,0 +1,42 @@ +defaults: + - model: homm + - computer: 1-gpu + - data: imagenet + - logger: wandb + - _self_ + +trainer: + _target_: pytorch_lightning.Trainer + max_steps: 300000 + deterministic: True + devices: ${computer.devices} + accelerator: ${computer.accelerator} + strategy: ${computer.strategy} + log_every_n_steps: 1 + num_nodes: ${computer.num_nodes} + precision: ${computer.precision} + gradient_clip_val: 1.0 + +checkpoints: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: ${checkpoint_dir}/${experiment_name} + save_last: True + monitor: val/loss + mode: min + +progress_bar: + _target_: pytorch_lightning.callbacks.TQDMProgressBar + refresh_rate: ${computer.progress_bar_refresh_rate} + +seed: 3407 +data_dir: ${root_dir}/datasets +root_dir: ${hydra:runtime.cwd} +checkpoint_dir: ${root_dir}/checkpoints +experiment_name_suffix: base +experiment_name: ${data.name}_${model.name}_${experiment_name_suffix} + +hydra: + run: + dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: true \ No newline at end of file diff --git a/src/model/classification.py b/src/model/classification.py new file mode 100644 index 0000000..c72f43c --- /dev/null +++ b/src/model/classification.py @@ -0,0 +1,148 @@ +import os +from typing import Any +import pytorch_lightning as L +import torch +import torch.nn as nn +from hydra.utils import instantiate +import copy +import numpy as np + + +class ClassificationModule(L.LightningModule): + def __init__( + self, + model, + loss, + optimizer_cfg, + lr_scheduler_builder, + train_batch_preprocess, + train_metrics, + val_metrics, + test_metrics, + ): + super().__init__() + self.model = model + self.loss = loss + self.optimizer_cfg = optimizer_cfg + 
self.lr_scheduler_builder = lr_scheduler_builder + self.train_batch_preprocess = train_batch_preprocess + self.train_metrics = train_metrics + self.val_metrics = val_metrics + self.test_metrics = test_metrics + + def training_step(self, batch, batch_idx): + img, label = batch + img, label = self.train_batch_preprocess(img, label) + pred = self.model(img) + loss = self.loss(pred, label, average=True) + for metric_name, metric_value in loss.items(): + self.log( + f"train/{metric_name}", + metric_value, + sync_dist=True, + on_step=True, + on_epoch=True, + ) + self.train_metrics.update(pred, label) + return loss + + def on_train_epoch_end(self): + metrics = self.train_metrics.compute() + for metric_name, metric_value in metrics.items(): + self.log( + f"train/{metric_name}", + metric_value, + sync_dist=True, + on_step=False, + on_epoch=True, + ) + + def validation_step(self, batch, batch_idx): + img, label = batch + pred = self.model(img) + loss = self.loss(pred, label, average=True) + self.val_metrics(pred, label) + self.log("val/loss", loss["loss"], sync_dist=True, on_step=False, on_epoch=True) + + def on_validation_epoch_end(self): + metrics = self.val_metrics.compute() + for metric_name, metric_value in metrics.items(): + self.log( + f"val/{metric_name}", + metric_value, + sync_dist=True, + on_step=False, + on_epoch=True, + ) + + def test_step(self, batch, batch_idx): + img, label = batch + pred = self.model(img) + loss = self.loss(pred, label, average=True) + self.test_metrics.update(pred, label) + self.log( + "test/loss", loss["loss"], sync_dist=True, on_step=False, on_epoch=True + ) + + def on_test_epoch_end(self): + metrics = self.test_metrics.compute() + for metric_name, metric_value in metrics.items(): + self.log( + f"test/{metric_name}", + metric_value, + sync_dist=True, + on_step=False, + on_epoch=True, + ) + + def configure_optimizers(self): + if self.optimizer_cfg.exclude_ln_and_biases_from_weight_decay: + parameters_names_wd = get_parameter_names(self.model, [nn.LayerNorm]) + parameters_names_wd = [ + name for name in parameters_names_wd if "bias" not in name + ] + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in self.model.named_parameters() + if n in parameters_names_wd + ], + "weight_decay": self.optimizer_cfg.optim.keywords["weight_decay"],  # optim is a hydra-built functools.partial + "layer_adaptation": True, # for lamb + }, + { + "params": [ + p + for n, p in self.model.named_parameters() + if n not in parameters_names_wd + ], + "weight_decay": 0.0, + "layer_adaptation": False, # for lamb + }, + ] + optimizer = self.optimizer_cfg.optim(optimizer_grouped_parameters) + else: + optimizer = self.optimizer_cfg.optim(self.model.parameters()) + scheduler = self.lr_scheduler_builder(optimizer) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + def lr_scheduler_step(self, scheduler, metric): + scheduler.step(self.global_step) + + +def get_parameter_names(model, forbidden_layer_types): + """ + Returns the names of the model parameters that are not inside a forbidden layer. + Taken from HuggingFace transformers. + """ + result = [] + for name, child in model.named_children(): + result += [ + f"{name}.{n}" + for n in get_parameter_names(child, forbidden_layer_types) + if not isinstance(child, tuple(forbidden_layer_types)) + ] + # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 
+ result += list(model._parameters.keys()) + return result diff --git a/src/model/losses/CE.py b/src/model/losses/CE.py new file mode 100644 index 0000000..c2c14b0 --- /dev/null +++ b/src/model/losses/CE.py @@ -0,0 +1,17 @@ +import torch.nn as nn + + +class CrossEntropyLossModule(nn.Module): + def __init__(self, sum_label_dim=True): + super().__init__() + self.loss = nn.BCEWithLogitsLoss(reduction="none") + self.sum_label_dim = sum_label_dim + + def forward(self, pred, label, average=True): + loss = self.loss(pred, label) + if self.sum_label_dim: + loss = loss.sum(dim=1) + if average: + loss = loss.mean() + output = {"loss": loss} + return output diff --git a/src/model/losses/__init__.py b/src/model/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/network/__init__.py b/src/model/network/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/layers.py b/src/model/network/layers.py similarity index 100% rename from src/model/layers.py rename to src/model/network/layers.py diff --git a/src/model/vision.py b/src/model/network/vision.py similarity index 67% rename from src/model/vision.py rename to src/model/network/vision.py index ae8ddba..ddc8bba 100644 --- a/src/model/vision.py +++ b/src/model/network/vision.py @@ -3,8 +3,22 @@ import torch.nn as nn from .layers import HoMLayer + class HoMVision(nn.Module): - def __init__(self, nb_classes, dim=256, im_size=256, kernel_size=16, nb_layers=12, order=4, order_expand=8, ffw_expand=4, dropout=0., pooling='cls', in_conv=True): + def __init__( + self, + nb_classes, + dim=256, + im_size=256, + kernel_size=16, + nb_layers=12, + order=4, + order_expand=8, + ffw_expand=4, + dropout=0.0, + pooling="cls", + in_conv=True, + ): super().__init__() self.nb_classes = nb_classes self.dim = dim @@ -19,28 +33,28 @@ def __init__(self, nb_classes, dim=256, im_size=256, kernel_size=16, nb_layers=1 self.conv = None if in_conv: self.conv = nn.Conv2d(3, dim, kernel_size=kernel_size, stride=kernel_size) - self.layers = nn.ModuleList([HoMLayer(dim, order, order_expand, ffw_expand, dropout=dropout) for i in range(nb_layers)]) + self.layers = nn.ModuleList( + [ + HoMLayer(dim, order, order_expand, ffw_expand, dropout=dropout) + for i in range(nb_layers) + ] + ) self.out_proj = nn.Linear(dim, nb_classes) - n = (im_size//kernel_size)**2 + n = (im_size // kernel_size) ** 2 self.position = nn.Parameter(torch.randn((1, n, dim)), requires_grad=True) - nn.init.trunc_normal_( - self.position, std=0.02, a=-2 * 0.02, b=2 * 0.02 - ) + nn.init.trunc_normal_(self.position, std=0.02, a=-2 * 0.02, b=2 * 0.02) self.cls = nn.Parameter(torch.randn((1, 1, dim)), requires_grad=True) - nn.init.trunc_normal_( - self.cls, std=0.02, a=-2 * 0.02, b=2 * 0.02 - ) + nn.init.trunc_normal_(self.cls, std=0.02, a=-2 * 0.02, b=2 * 0.02) # init self.apply(self.init_weights_) nn.init.zeros_(self.out_proj.weight) - if pooling == 'cls': + if pooling == "cls": nn.init.constant_(self.out_proj.bias, -6.9) else: nn.init.zeros_(self.out_proj.bias) - def init_weights_(self, m): if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): nn.init.xavier_uniform_(m.weight) @@ -48,19 +62,20 @@ def init_weights_(self, m): nn.init.zeros_(m.bias) def forward(self, x, mask=None): - if self.conv is not None: x = self.conv(x) - x = einops.rearrange(x, 'b c h w -> b (h w) c') + x = einops.rearrange(x, "b c h w -> b (h w) c") b, n, c = x.shape ones = torch.ones(b, 1, 1).to(x.device) if n == self.position.shape[1]: position = self.position else: - position = 
einops.rearrange( + self.position, "b (m n) d ->b d m n", m=self.im_size // self.kernel_size + ) m = torch.sqrt(torch.tensor(n)).int() - position = nn.functional.interpolate(position, size=(m, m), mode='bicubic') + position = nn.functional.interpolate(position, size=(m, m), mode="bicubic") position = einops.rearrange(position, "b d m n -> b (m n) d") x = x + position * ones cls = self.cls * ones @@ -72,10 +87,9 @@ def forward(self, x, mask=None): for i in range(self.nb_layers): x = self.layers[i](x, mask=mask) - if self.pooling == 'cls': + if self.pooling == "cls": x = self.out_proj(x[:, 0, :]) else: - x = self.out_proj(x)[:, 1:, :] # return without cls + x = self.out_proj(x)[:, 1:, :] # return without cls return x - diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..dbb3a0b --- /dev/null +++ b/src/train.py @@ -0,0 +1,96 @@ +import hydra +import shutil +import os + +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning import seed_everything + +from pathlib import Path + +from omegaconf import OmegaConf +import torch + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") + + +@hydra.main(config_path="configs", config_name="train", version_base=None) +def train(cfg): + dict_config = OmegaConf.to_container(cfg, resolve=True) + + Path(cfg.checkpoints.dirpath).mkdir(parents=True, exist_ok=True) + + print("Working directory : {}".format(os.getcwd())) + + # copy full config and overrides to checkpoint dir + shutil.copyfile( + Path(".hydra/config.yaml"), + f"{cfg.checkpoint_dir}/config.yaml", + ) + shutil.copyfile( + Path(".hydra/overrides.yaml"), + f"{cfg.checkpoint_dir}/overrides.yaml", + ) + + log_dict = {} + + log_dict["model"] = dict_config["model"] + + log_dict["data"] = dict_config["data"] + + log_dict["trainer"] = dict_config["trainer"] + + seed_everything(cfg.seed) + + datamodule = hydra.utils.instantiate(cfg.data.datamodule) + + checkpoint_callback = hydra.utils.instantiate(cfg.checkpoints) + + progress_bar = hydra.utils.instantiate(cfg.progress_bar) + + lr_monitor = LearningRateMonitor() + + callbacks = [ + checkpoint_callback, + progress_bar, + lr_monitor, + ] + + logger = hydra.utils.instantiate(cfg.logger) + logger.log_hyperparams(dict_config) + # Instantiate model and trainer + model = hydra.utils.instantiate(cfg.model.instance) + trainer = hydra.utils.instantiate( + cfg.trainer, + logger=logger, + callbacks=callbacks, + ) + # Resume experiments if last.ckpt exists for this experiment + ckpt_path = None + + if (Path(cfg.checkpoints.dirpath) / Path("last.ckpt")).exists(): + ckpt_path = Path(cfg.checkpoints.dirpath) / Path("last.ckpt") + else: + ckpt_path = None + if cfg.model.lsuv_normalize: + from lsuv import lsuv_with_dataloader + + datamodule.setup() + model.model = lsuv_with_dataloader( + model.model, + datamodule.train_dataloader(), + device=torch.device("cuda:0"), + verbose=False, + ) + torch.nn.init.zeros_(model.model.out_proj.weight)  # the network lives at model.model + torch.nn.init.constant_(model.model.out_proj.bias, -6.9) + # Log activation and gradients if wandb + if "WandbLogger" in cfg.logger._target_: + logger.experiment.watch(model, log="all", log_graph=True, log_freq=100) + + trainer.fit(model, datamodule, ckpt_path=ckpt_path) + + +if __name__ == "__main__": + train() diff --git a/src/train_imagenet.py b/src/train_imagenet.py index 
4faed5f..06eeb92 100644 --- a/src/train_imagenet.py +++ b/src/train_imagenet.py @@ -16,10 +16,11 @@ try: import accimage from torchvision import set_image_backend - set_image_backend('accimage') - print('Accelerated image loading available') + + set_image_backend("accimage") + print("Accelerated image loading available") except ImportError: - print('No accelerated image loading') + print("No accelerated image loading") from model.vision import HoMVision @@ -27,7 +28,6 @@ from utils.mixup import CutMixUp - def eval(model, val_ds, criterion): model.eval() val_loss = [] @@ -37,10 +37,24 @@ def eval(model, val_ds, criterion): imgs = imgs.to(device) lbls = lbls.to(device) outputs = model(imgs) - loss = criterion(outputs, nn.functional.one_hot(lbls, num_classes=1000).float()).sum(dim=1).mean().detach().cpu() + loss = ( + criterion( + outputs, nn.functional.one_hot(lbls, num_classes=1000).float() + ) + .sum(dim=1) + .mean() + .detach() + .cpu() + ) val_loss.append(loss) - val_acc.append(((outputs.argmax(dim=1) == lbls).sum() / lbls.shape[0]).detach().cpu()) - val.set_postfix_str(s='val loss {:5.02f} val acc {:5.02f}'.format(torch.stack(val_loss).mean(), 100. * torch.stack(val_acc).mean())) + val_acc.append( + ((outputs.argmax(dim=1) == lbls).sum() / lbls.shape[0]).detach().cpu() + ) + val.set_postfix_str( + s="val loss {:5.02f} val acc {:5.02f}".format( + torch.stack(val_loss).mean(), 100.0 * torch.stack(val_acc).mean() + ) + ) return torch.stack(val_loss).mean().item(), torch.stack(val_acc).mean().item() @@ -50,38 +64,38 @@ def eval(model, val_ds, criterion): parser = argparse.ArgumentParser() -parser.add_argument("--data_dir", help="path to imagenet") +parser.add_argument("--data_dir", help="path to imagenet") # ok parser.add_argument("--seed", type=int, default=3407) # model param -parser.add_argument("--dim", type=int, default=128) -parser.add_argument("--size", type=int, default=256) -parser.add_argument("--kernel_size", type=int, default=16) -parser.add_argument("--nb_layers", type=int, default=8) -parser.add_argument("--order", type=int, default=4) -parser.add_argument("--order_expand", type=int, default=8) -parser.add_argument("--ffw_expand", type=int, default=4) -parser.add_argument("--dropout", type=float, default=0.) 
-parser.add_argument("--wd", type=float, default=0.01) +parser.add_argument("--dim", type=int, default=128) # ok +parser.add_argument("--size", type=int, default=256) # ok +parser.add_argument("--kernel_size", type=int, default=16) # ok +parser.add_argument("--nb_layers", type=int, default=8) # ok +parser.add_argument("--order", type=int, default=4) # ok +parser.add_argument("--order_expand", type=int, default=8) # ok +parser.add_argument("--ffw_expand", type=int, default=4) # ok +parser.add_argument("--dropout", type=float, default=0.0) # ok +parser.add_argument("--wd", type=float, default=0.01) # ok # training params -parser.add_argument("--lr", type=float, default=0.001) -parser.add_argument("--batch_size", type=int, default=128) -parser.add_argument("--val_batch_size", type=int, default=25) -parser.add_argument("--max_iteration", type=int, default=300000) -parser.add_argument("--warmup", type=int, default=10000) -parser.add_argument("--num_worker", type=int, default=8) -parser.add_argument("--precision", type=str, default="bf16") +parser.add_argument("--lr", type=float, default=0.001) # ok +parser.add_argument("--batch_size", type=int, default=128) # ok +parser.add_argument("--val_batch_size", type=int, default=25) # ok +parser.add_argument("--max_iteration", type=int, default=300000) # ok +parser.add_argument("--warmup", type=int, default=10000) # ok +parser.add_argument("--num_worker", type=int, default=8) # ok +parser.add_argument("--precision", type=str, default="bf16") # ok # augment param -parser.add_argument("--ra", type=bool, default=False) -parser.add_argument("--ra_prob", type=float, default=0.1) -parser.add_argument("--mixup_prob", type=float, default=1.) +parser.add_argument("--ra", type=bool, default=False) # ok +parser.add_argument("--ra_prob", type=float, default=0.1) # ok +parser.add_argument("--mixup_prob", type=float, default=1.0) # ok # log param -parser.add_argument("--log_dir", type=str, default="./logs/") -parser.add_argument("--log_freq", type=int, default=5000) -parser.add_argument("--log_graph", type=bool, default=False) +parser.add_argument("--log_dir", type=str, default="./logs/") # ok +parser.add_argument("--log_freq", type=int, default=5000) # ok +parser.add_argument("--log_graph", type=bool, default=False) # ok # checkpoints -parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints/") -parser.add_argument("--load_checkpoint", type=str, default=None) -parser.add_argument("--load_weights", type=str, default=None) +parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints/") # ok +parser.add_argument("--load_checkpoint", type=str, default=None) # ok +parser.add_argument("--load_weights", type=str, default=None) # ok args = parser.parse_args() @@ -96,92 +110,156 @@ def eval(model, val_ds, criterion): # augment cutmix_or_mixup = CutMixUp() -randaug = [v2.RandomApply([v2.RandAugment(magnitude=6)], p=args.ra_prob)] if args.ra else None +randaug = ( + [v2.RandomApply([v2.RandAugment(magnitude=6)], p=args.ra_prob)] if args.ra else None +) # build dataset -train, val = build_imagenet(args.data_dir, size=args.size, additional_transforms=randaug) -train_ds = DataLoader(train, batch_size=args.batch_size, num_workers=args.num_worker, shuffle=True, prefetch_factor=4, pin_memory=True, persistent_workers=True, drop_last=True) +train, val = build_imagenet( + args.data_dir, size=args.size, additional_transforms=randaug +) +train_ds = DataLoader( + train, + batch_size=args.batch_size, + num_workers=args.num_worker, + shuffle=True, + 
prefetch_factor=4, + pin_memory=True, + persistent_workers=True, + drop_last=True, +) val_ds = DataLoader(val, batch_size=args.val_batch_size, num_workers=2) n_train = len(train_ds) epoch = args.max_iteration // n_train + 1 # loss crterion # criterion = nn.CrossEntropyLoss(label_smoothing=0.1) -criterion = nn.BCEWithLogitsLoss(reduction='none') +criterion = nn.BCEWithLogitsLoss(reduction="none") tr_loss = [] tr_acc = [] if args.load_checkpoint is not None: - print('loading model from checkpoint: {}'.format(args.load_checkpoint)) + print("loading model from checkpoint: {}".format(args.load_checkpoint)) ckpt = torch.load(args.load_checkpoint) - resume_args = SimpleNamespace(**ckpt['train_config']) - model = HoMVision(1000, resume_args.dim, resume_args.size, resume_args.kernel_size, resume_args.nb_layers, resume_args.order, resume_args.order_expand, - resume_args.ffw_expand, resume_args.dropout) - model.load_state_dict(ckpt['model']) + resume_args = SimpleNamespace(**ckpt["train_config"]) + model = HoMVision( + 1000, + resume_args.dim, + resume_args.size, + resume_args.kernel_size, + resume_args.nb_layers, + resume_args.order, + resume_args.order_expand, + resume_args.ffw_expand, + resume_args.dropout, + ) + model.load_state_dict(ckpt["model"]) model = model.to(device) - model_name = "i{}_k_{}_d{}_n{}_o{}_e{}_f{}".format(resume_args.size, resume_args.kernel_size, resume_args.dim, - resume_args.nb_layers, resume_args.order, resume_args.order_expand, resume_args.ffw_expand) - optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.wd) - optimizer.load_state_dict(ckpt['optimizer']) + model_name = "i{}_k_{}_d{}_n{}_o{}_e{}_f{}".format( + resume_args.size, + resume_args.kernel_size, + resume_args.dim, + resume_args.nb_layers, + resume_args.order, + resume_args.order_expand, + resume_args.ffw_expand, + ) + optimizer = torch.optim.AdamW( + params=model.parameters(), lr=args.lr, weight_decay=args.wd + ) + optimizer.load_state_dict(ckpt["optimizer"]) scaler = torch.cuda.amp.GradScaler(enabled=True) - scaler.load_state_dict(ckpt['scaler']) - sched = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=args.lr, total_steps=args.max_iteration+1, - anneal_strategy='cos', pct_start=args.warmup / args.max_iteration, - last_epoch=ckpt['global_step']-1) - start_step = ckpt['global_step'] - start_epoch = start_step//n_train + scaler.load_state_dict(ckpt["scaler"]) + sched = torch.optim.lr_scheduler.OneCycleLR( + optimizer=optimizer, + max_lr=args.lr, + total_steps=args.max_iteration + 1, + anneal_strategy="cos", + pct_start=args.warmup / args.max_iteration, + last_epoch=ckpt["global_step"] - 1, + ) + start_step = ckpt["global_step"] + start_epoch = start_step // n_train else: - model = HoMVision(1000, args.dim, args.size, args.kernel_size, args.nb_layers, args.order, args.order_expand, - args.ffw_expand, args.dropout) + model = HoMVision( + 1000, + args.dim, + args.size, + args.kernel_size, + args.nb_layers, + args.order, + args.order_expand, + args.ffw_expand, + args.dropout, + ) model = model.to(device) - model = lsuv_with_dataloader(model, train_ds, device=torch.device(device), verbose=False) + model = lsuv_with_dataloader( + model, train_ds, device=torch.device(device), verbose=False + ) nn.init.zeros_(model.out_proj.weight) nn.init.constant_(model.out_proj.bias, -6.9) - model_name = "i{}_k_{}_d{}_n{}_o{}_e{}_f{}".format(args.size, args.kernel_size, args.dim, - args.nb_layers, args.order, args.order_expand, args.ffw_expand) - - optimizer = 
torch.optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.wd) + model_name = "i{}_k_{}_d{}_n{}_o{}_e{}_f{}".format( + args.size, + args.kernel_size, + args.dim, + args.nb_layers, + args.order, + args.order_expand, + args.ffw_expand, + ) + + optimizer = torch.optim.AdamW( + params=model.parameters(), lr=args.lr, weight_decay=args.wd + ) scaler = torch.cuda.amp.GradScaler(enabled=True) - sched = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=args.lr, total_steps=args.max_iteration+1, - anneal_strategy='cos', pct_start=args.warmup/args.max_iteration) + sched = torch.optim.lr_scheduler.OneCycleLR( + optimizer=optimizer, + max_lr=args.lr, + total_steps=args.max_iteration + 1, + anneal_strategy="cos", + pct_start=args.warmup / args.max_iteration, + ) - start_step=1 - start_epoch=0 + start_step = 1 + start_epoch = 0 if args.load_weights is not None: - print('loading weights from chekpoint: {}'.format(args.load_weights)) + print("loading weights from checkpoint: {}".format(args.load_weights)) ckpt = torch.load(args.load_weights) - model.load_state_dict(ckpt['model']) + model.load_state_dict(ckpt["model"]) model = model.to(device) -print('model and optimizer built') -print('training model {}'.format(model_name)) +print("model and optimizer built") +print("training model {}".format(model_name)) # loging version = 0 -path = Path(args.log_dir+"/train/"+model_name+"_{}".format(version)) +path = Path(args.log_dir + "/train/" + model_name + "_{}".format(version)) if path.exists(): - while(path.exists()): + while path.exists(): version += 1 - path = Path(args.log_dir + "/train/" + model_name+"_{}".format(version)) + path = Path(args.log_dir + "/train/" + model_name + "_{}".format(version)) -train_writer = SummaryWriter(args.log_dir+"/train/"+model_name+"_{}".format(version)) -val_writer = SummaryWriter(args.log_dir+"/val/"+model_name+"_{}".format(version)) +train_writer = SummaryWriter( + args.log_dir + "/train/" + model_name + "_{}".format(version) +) +val_writer = SummaryWriter(args.log_dir + "/val/" + model_name + "_{}".format(version)) x = torch.randn((8, 3, args.size, args.size)).to(device) torchinfo.summary(model, input_data=x.to(device)) if args.log_graph: train_writer.add_graph(model, x) -train_writer.add_hparams(hparam_dict=vars(args), metric_dict={"version": version}, run_name="") +train_writer.add_hparams( + hparam_dict=vars(args), metric_dict={"version": version}, run_name="" +) # big loop i = start_step for e in range(start_epoch, epoch): # loop over the dataset multiple times - with tqdm(train_ds, desc='Epoch={}'.format(e)) as tepoch: + with tqdm(train_ds, desc="Epoch={}".format(e)) as tepoch: for imgs, lbls in tepoch: - # to gpu imgs = imgs.to(device) lbls = lbls.to(device) @@ -195,7 +273,6 @@ def eval(model, val_ds, criterion): optimizer.zero_grad() with torch.autocast(device_type=device, dtype=precision_type, enabled=True): - outputs = model(imgs) loss = criterion(outputs, lbls).sum(dim=1).mean() @@ -209,12 +286,18 @@ def eval(model, val_ds, criterion): sched.step() # print statistics running_loss = loss.detach().cpu() - running_acc = ((outputs.argmax(dim=1) == lbls.argmax(dim=1)).sum() / lbls.shape[0]).detach().cpu() + running_acc = ( + ((outputs.argmax(dim=1) == lbls.argmax(dim=1)).sum() / lbls.shape[0]) + .detach() + .cpu() + ) if i % 10 == 0: train_writer.add_scalar("loss", running_loss, global_step=i) train_writer.add_scalar("acc", running_acc, global_step=i) - tepoch.set_postfix_str(s='loss: {:5.02f} acc: {:5.02f}'.format(running_loss, 100 * 
running_acc)) + tepoch.set_postfix_str( + s="loss: {:5.02f} acc: {:5.02f}".format(running_loss, 100 * running_acc) + ) if i % args.log_freq == 0: l, a = eval(model, val_ds, criterion) @@ -222,29 +305,31 @@ def eval(model, val_ds, criterion): val_writer.add_scalar("acc", a, global_step=i) model.train() - checkpoint = {"model": model.state_dict(), - "optimizer": optimizer.state_dict(), - "scaler": scaler.state_dict(), - "global_step": i, - "train_config": vars(args) - } - torch.save(checkpoint, "{}/{}_{}.ckpt".format(args.checkpoint_dir, model_name, version)) + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scaler": scaler.state_dict(), + "global_step": i, + "train_config": vars(args), + } + torch.save( + checkpoint, + "{}/{}_{}.ckpt".format(args.checkpoint_dir, model_name, version), + ) if i >= args.max_iteration: - print('training finished, saving last model') - checkpoint = {"model": model.state_dict(), - "optimizer": optimizer.state_dict(), - "scaler": scaler.state_dict(), - "global_step": i, - "train_config": vars(args) - } - torch.save(checkpoint, "{}/{}_{}.ckpt".format(args.checkpoint_dir, model_name, version)) + print("training finished, saving last model") + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scaler": scaler.state_dict(), + "global_step": i, + "train_config": vars(args), + } + torch.save( + checkpoint, + "{}/{}_{}.ckpt".format(args.checkpoint_dir, model_name, version), + ) sys.exit(0) i += 1 - - - - - - diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/data.py b/src/utils/data.py index 83ba524..d6cb6fe 100644 --- a/src/utils/data.py +++ b/src/utils/data.py @@ -1,44 +1,69 @@ - from torchvision import transforms from torchvision.datasets import ImageNet, ImageFolder +import torch -normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) -denormalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], - std=[1./0.229, 1./0.224, 1./0.225]) +normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +denormalize = transforms.Normalize( + mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], + std=[1.0 / 0.229, 1.0 / 0.224, 1.0 / 0.225], +) -def build_imagenet(data_dir, device="cuda", size=224, additional_transforms=None): +def build_imagenet(data_dir, num_classes, size=224, additional_transforms=None): tr = [ transforms.RandomResizedCrop(size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - normalize + normalize, ] if additional_transforms is not None: tr.extend(additional_transforms) transform_train = transforms.Compose(tr) - transform_val = transforms.Compose([ - transforms.Resize(int(size/0.95)), - transforms.CenterCrop(size), - transforms.ToTensor(), - normalize - ]) + transform_val = transforms.Compose( + [ + transforms.Resize(int(size / 0.95)), + transforms.CenterCrop(size), + transforms.ToTensor(), + normalize, + ] + ) - train = ImageNet(data_dir, transform=transform_train) - val = ImageNet(data_dir, split='val', transform=transform_val) + train = ImageNet( + data_dir, + transform=transform_train, + target_transform=lambda x: torch.nn.functional.one_hot( + torch.LongTensor([x]), num_classes + ) + .float() + .squeeze(0), + ) + val = ImageNet( + data_dir, + split="val", + transform=transform_val, + target_transform=lambda x: torch.nn.functional.one_hot( + torch.LongTensor([x]), num_classes + ) + .float() + .squeeze(0), + 
) return train, val -def build_imagefolder(data_dir, size=224, additional_transforms=None): + +def build_imagefolder(data_dir, num_classes, size=224, additional_transforms=None): tr = [ transforms.RandomResizedCrop(size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - normalize + normalize, ] if additional_transforms is not None: tr.extend(additional_transforms) transform_train = transforms.Compose(tr) - return ImageFolder(data_dir, transform_train) \ No newline at end of file + return ImageFolder( + data_dir, + transform_train, + target_transform=lambda x: torch.nn.functional.one_hot(torch.tensor(x), num_classes).float(),  # ImageFolder targets are ints; one_hot needs a tensor + ) diff --git a/src/utils/datamodule.py b/src/utils/datamodule.py new file mode 100644 index 0000000..9f7071f --- /dev/null +++ b/src/utils/datamodule.py @@ -0,0 +1,59 @@ +import pytorch_lightning as L +from torch.utils.data import DataLoader +import math +import torch + + +class ImageDataModule(L.LightningDataModule): + """ + Image data module that splits the global batch size across nodes and devices + """ + + def __init__( + self, + dataset_builder, + full_batch_size, + num_workers, + num_nodes=1, + num_devices=1, + ): + super().__init__() + self.batch_size = full_batch_size // (num_nodes * num_devices) + print(f"Each GPU will receive {self.batch_size} images") + self.num_workers = num_workers + self._dataset_builder = dataset_builder + + def setup(self, stage=None): + self.train_dataset, self.val_dataset = self._dataset_builder() + print(f"Train dataset size: {len(self.train_dataset)}") + print(f"Val dataset size: {len(self.val_dataset)}") + self.train_aug = self.train_dataset.transform + self.val_aug = self.val_dataset.transform + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + prefetch_factor=4, + pin_memory=True, + persistent_workers=True, + drop_last=True, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) diff --git a/src/utils/lr_scheduler.py b/src/utils/lr_scheduler.py new file mode 100755 index 0000000..b3b116a --- /dev/null +++ b/src/utils/lr_scheduler.py @@ -0,0 +1,74 @@ +import math + + +class WarmupLR: + def __init__(self, optimizer, warmup_steps): + self.optimizer = optimizer + self.warmup_steps = warmup_steps + self.base_lr = None + + def get_lr(self, lr, step): + return lr * min(step / max(self.warmup_steps, 1), 1.0) + + def step(self, step): + if self.base_lr is None: + self.base_lr = [ + param_group["lr"] for param_group in self.optimizer.param_groups + ] + for param_group, base_lr_group in zip( + self.optimizer.param_groups, self.base_lr + ): + param_group["lr"] = self.get_lr(base_lr_group, step) + + def state_dict(self): + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + +class WarmupCosineDecayLR: + def __init__(self, optimizer, warmup_steps, total_steps, rate=1.0): + self.optimizer = optimizer + self.warmup_steps = warmup_steps + self.base_lr = None + self.total_steps = total_steps + self.rate = rate + + def get_lr(self, lr, step): + if step < self.warmup_steps: + return lr * min(step / max(self.warmup_steps, 1), 1.0) + else: + return ( + 0.5 + * lr + * ( + 1 + + math.cos( + self.rate + * math.pi + * (step - 
self.warmup_steps) + / (self.total_steps - self.warmup_steps) + ) + ) + ) + + def step(self, step): + if self.base_lr is None: + self.base_lr = [ + param_group["lr"] for param_group in self.optimizer.param_groups + ] + for param_group, base_lr_group in zip( + self.optimizer.param_groups, self.base_lr + ): + param_group["lr"] = self.get_lr(base_lr_group, step) + + def state_dict(self): + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) diff --git a/src/utils/metrics.py b/src/utils/metrics.py new file mode 100644 index 0000000..59a56fc --- /dev/null +++ b/src/utils/metrics.py @@ -0,0 +1,34 @@ +from torchmetrics import Metric +import torch + + +class ClassificationMetrics(Metric): + def __init__(self, num_classes): + super().__init__() + self.num_classes = num_classes + self.add_state("tp", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("fp", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("fn", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("tn", default=torch.zeros(num_classes), dist_reduce_fx="sum") + + def update(self, preds, target): + preds = preds.argmax(dim=1) + target = target.argmax(dim=1) + for c in range(self.num_classes): + self.tp[c] += ((preds == c) & (target == c)).sum() + self.fp[c] += ((preds == c) & (target != c)).sum() + self.fn[c] += ((preds != c) & (target == c)).sum() + self.tn[c] += ((preds != c) & (target != c)).sum() + + def compute(self): + precision = self.tp / (self.tp + self.fp + 1e-8) + recall = self.tp / (self.tp + self.fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + acc = self.tp.sum() / (self.tp.sum() + self.fn.sum() + 1e-8)  # top-1 accuracy: correct predictions over all samples + + return { + "acc": acc.mean(), + "f1": f1.mean(), + "precision": precision.mean(), + "recall": recall.mean(), + } diff --git a/src/utils/mixup.py b/src/utils/mixup.py index 7d5e7b8..50c92ba 100644 --- a/src/utils/mixup.py +++ b/src/utils/mixup.py @@ -2,73 +2,115 @@ from torchvision.transforms import v2 import einops -class MixUp(): - def __init__(self, alpha=0.1, num_classes=1000): + +class MixUp: + def __init__(self, apply_transform_prob=1.0, alpha=0.1, num_classes=1000): super().__init__() self.alpha = alpha self._mixup = v2.MixUp(alpha=alpha, num_classes=num_classes) + self.apply_transform_prob = apply_transform_prob @torch.no_grad() def __call__(self, x, y): - x_2 = einops.rearrange(x, "(b m) c h w -> b m c h w", m=2) - y_2 = einops.rearrange(y, "(b m) -> b m", m=2) - b, m = y_2.shape - x_out = b*[None]; y_out = b*[None] - for i in torch.arange(0, b): - x_, y_ = self._mixup(x_2[i], y_2[i]) - x_out[i] = x_ - y_out[i] = y_ - x_out = torch.cat(x_out, dim=0) - y_out = torch.cat(y_out, dim=0) - return x_out, y_out - - -class CutMix(): - def __init__(self, alpha=1., num_classes=1000): + if torch.rand(1) < self.apply_transform_prob: + if y.ndim == 2: + y = torch.argmax(y, dim=1) + elif y.ndim == 1: + pass + else: + raise ValueError("y must be 1 or 2 dim") + x_2 = einops.rearrange(x, "(b m) c h w -> b m c h w", m=2) + y_2 = einops.rearrange(y, "(b m) -> b m", m=2) + b, m = y_2.shape + x_out = b * [None] + y_out = b * [None] + for i in torch.arange(0, b): + x_, y_ = self._mixup(x_2[i], y_2[i]) + x_out[i] = x_ + y_out[i] = y_ + x_out = torch.cat(x_out, dim=0) + y_out = torch.cat(y_out, dim=0) + return x_out, y_out + else: + return x, y + + +class CutMix: + def __init__(self, apply_transform_prob=1.0, alpha=1.0, 
num_classes=1000): super().__init__() self.alpha = alpha self._cutmix = v2.CutMix(alpha=alpha, num_classes=num_classes) + self.apply_transform_prob = apply_transform_prob @torch.no_grad() def __call__(self, x, y): - x_2 = einops.rearrange(x, "(b m) c h w -> b m c h w", m=2) - y_2 = einops.rearrange(y, "(b m) -> b m", m=2) - b, m = y_2.shape - x_out = b*[None]; y_out = b*[None] - for i in torch.arange(0, b): - x_, y_ = self._cutmix(x_2[i], y_2[i]) - x_out[i] = x_ - y_out[i] = y_ - x_out = torch.cat(x_out, dim=0) - y_out = torch.cat(y_out, dim=0) - return x_out, y_out - -class CutMixUp(): - def __init__(self, mixup_prob=0.5, alpha_mixup=0.1, alpha_cutmix=1.0, num_classes=1000): + if torch.rand(1) < self.apply_transform_prob: + if y.ndim == 2: + y = torch.argmax(y, dim=1) + elif y.ndim == 1: + pass + else: + raise ValueError("y must be 1 or 2 dim") + x_2 = einops.rearrange(x, "(b m) c h w -> b m c h w", m=2) + y_2 = einops.rearrange(y, "(b m) -> b m", m=2) + b, m = y_2.shape + x_out = b * [None] + y_out = b * [None] + for i in torch.arange(0, b): + x_, y_ = self._cutmix(x_2[i], y_2[i]) + x_out[i] = x_ + y_out[i] = y_ + x_out = torch.cat(x_out, dim=0) + y_out = torch.cat(y_out, dim=0) + return x_out, y_out + else: + return x, y + + +class CutMixUp: + def __init__( + self, + apply_transform_prob=1.0, + mixup_prob=0.5, + alpha_mixup=0.1, + alpha_cutmix=1.0, + num_classes=1000, + ): self.mixup_prob = mixup_prob self.alpha_mixup = alpha_mixup self.alpha_cutmix = alpha_cutmix self.num_classes = num_classes self._mixup = v2.MixUp(alpha=alpha_mixup, num_classes=num_classes) self._cutmix = v2.CutMix(alpha=alpha_cutmix, num_classes=num_classes) + self.apply_transform_prob = apply_transform_prob @torch.no_grad() def __call__(self, x, y): - x_2 = einops.rearrange(x, "(b m) c h w -> b m c h w", m=8) - y_2 = einops.rearrange(y, "(b m) -> b m", m=8) - b, m = y_2.shape - x_out = b*[None]; y_out = b*[None] - r = torch.rand((b,)) - for i in torch.arange(0, b): - if r[i] < self.mixup_prob: - x_, y_ = self._mixup(x_2[i], y_2[i]) + if torch.rand(1) < self.apply_transform_prob: + if y.ndim == 2: + y = torch.argmax(y, dim=1) + elif y.ndim == 1: + pass else: - x_, y_ = self._cutmix(x_2[i], y_2[i]) - x_out[i] = x_ - y_out[i] = y_ - x_out = torch.cat(x_out, dim=0) - y_out = torch.cat(y_out, dim=0) - return x_out, y_out + raise ValueError("y must be 1 or 2 dim") + x_2 = einops.rearrange(x, "(b m) c h w -> b m c h w", m=8) + y_2 = einops.rearrange(y, "(b m) -> b m", m=8) + b, m = y_2.shape + x_out = b * [None] + y_out = b * [None] + r = torch.rand((b,)) + for i in torch.arange(0, b): + if r[i] < self.mixup_prob: + x_, y_ = self._mixup(x_2[i], y_2[i]) + else: + x_, y_ = self._cutmix(x_2[i], y_2[i]) + x_out[i] = x_ + y_out[i] = y_ + x_out = torch.cat(x_out, dim=0) + y_out = torch.cat(y_out, dim=0) + return x_out, y_out + else: + return x, y if __name__ == "__main__": @@ -86,7 +128,7 @@ def __call__(self, x, y): plt.subplot(4, 1, 1) plt.imshow(einops.rearrange(imgs, "b c h w -> h (b w) c")) - mixup = MixUp(alpha=1.) + mixup = MixUp(alpha=1.0) m_imgs, m_lbls = mixup(imgs, lbls) plt.subplot(4, 1, 2) plt.imshow(einops.rearrange(m_imgs, "b c h w -> h (b w) c")) @@ -101,4 +143,4 @@ def __call__(self, x, y): plt.subplot(4, 1, 4) plt.imshow(einops.rearrange(cm_imgs, "b c h w -> h (b w) c")) - plt.show() \ No newline at end of file + plt.show()
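
A quick sanity check for the hand-rolled schedulers added in src/utils/lr_scheduler.py (a minimal sketch, not part of the diff; it assumes the repo's src/ directory is on PYTHONPATH so that utils.lr_scheduler is importable):

    import torch
    from utils.lr_scheduler import WarmupCosineDecayLR

    # Dummy model/optimizer; the scheduler captures the base lr from the
    # optimizer's param groups on its first step() call.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = WarmupCosineDecayLR(optimizer, warmup_steps=10, total_steps=100)

    for step in range(0, 101, 10):
        scheduler.step(step)  # takes the global step explicitly; no internal counter
        print(step, optimizer.param_groups[0]["lr"])
    # Expected shape: linear ramp to 1e-3 over the first 10 steps,
    # then cosine decay to ~0 at step 100.

This is also why ClassificationModule overrides lr_scheduler_step with scheduler.step(self.global_step): these schedulers keep no internal counter and recompute the lr from the global step on every call. With the configs above, a run would be launched through standard Hydra overrides matching the config groups in this diff, e.g. python src/train.py computer=2-gpu model/optimizer=adamw data.full_batch_size=256.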