From 57ee873fd955a571fadc5db6b0a354601730cfae Mon Sep 17 00:00:00 2001
From: Gabriel Gutierrez
Date: Mon, 13 May 2024 10:07:28 -0300
Subject: [PATCH 1/4] Add simple logging class

---
 minerva/models/nets/base.py | 172 +++++++++++++++++++++++++++++++++++-
 1 file changed, 171 insertions(+), 1 deletion(-)

diff --git a/minerva/models/nets/base.py b/minerva/models/nets/base.py
index a00355e..959ffab 100644
--- a/minerva/models/nets/base.py
+++ b/minerva/models/nets/base.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Optional

 import lightning as L
 import torch
@@ -149,3 +149,173 @@ def configure_optimizers(self):
             lr=self.learning_rate,
         )
         return optimizer
+
+
+class LoggedSupervisedModel(L.LightningModule):
+    """Simple pipeline for supervised models with logging."""
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss_fn: torch.nn.Module,
+        learning_rate: float = 1e-3,
+        flatten: bool = True,
+        train_metric: Optional[torch.nn.Module] = None,
+        val_metric: Optional[torch.nn.Module] = None,
+        test_metric: Optional[torch.nn.Module] = None,
+    ):
+        super().__init__()
+        self.model = model
+        self.loss_fn = loss_fn
+        self.learning_rate = learning_rate
+        self.flatten = flatten
+
+        # Always set the metric attributes and flags, so the batch-end hooks
+        # never hit an AttributeError when a metric is not provided.
+        self.train_metric = train_metric
+        self.log_train_metric = train_metric is not None
+
+        self.val_metric = val_metric
+        self.log_val_metric = val_metric is not None
+
+        self.test_metric = test_metric
+        self.log_test_metric = test_metric is not None
+
+        self.train_step_outputs = []
+        self.train_step_labels = []
+
+        self.val_step_outputs = []
+        self.val_step_labels = []
+
+        self.test_step_outputs = []
+        self.test_step_labels = []
+
+    def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """Calculate the loss between the output and the input data.
+
+        Parameters
+        ----------
+        y_hat : torch.Tensor
+            The output data from the forward pass.
+        y : torch.Tensor
+            The input data/label.
+
+        Returns
+        -------
+        torch.Tensor
+            The loss value.
+        """
+        loss = self.loss_fn(y_hat, y)
+        return loss
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform a forward pass with the input data on the backbone model.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The input data.
+
+        Returns
+        -------
+        torch.Tensor
+            The output data from the forward pass.
+        """
+        return self.model(x)
+
+    def _single_step(
+        self, batch: torch.Tensor, batch_idx: int, step_name: str
+    ) -> torch.Tensor:
+
+        x, y = batch
+        y_hat = self.forward(x)
+        loss = self._loss_func(y_hat, y)
+
+        if step_name == "train":
+            self.train_step_outputs.append(y_hat)
+            self.train_step_labels.append(y)
+
+        elif step_name == "val":
+            self.val_step_outputs.append(y_hat)
+            self.val_step_labels.append(y)
+
+        elif step_name == "test":
+            self.test_step_outputs.append(y_hat)
+            self.test_step_labels.append(y)
+
+        self.log(
+            f"{step_name}_loss",
+            loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+        return loss
+
+    def training_step(self, batch: torch.Tensor, batch_idx: int):
+        return self._single_step(batch, batch_idx, step_name="train")
+
+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        if self.log_train_metric:
+            y_hat = torch.cat(self.train_step_outputs)
+            y = torch.cat(self.train_step_labels)
+            metric = self.train_metric(y_hat, y)
+            self.log_dict(
+                {"train_metric": metric},
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        # Clear unconditionally so the buffers do not grow without bound
+        # when no metric is configured.
+        self.train_step_outputs.clear()
+        self.train_step_labels.clear()
+
+    def validation_step(self, batch: torch.Tensor, batch_idx: int):
+        return self._single_step(batch, batch_idx, step_name="val")
+
+    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
+        if self.log_val_metric:
+            y_hat = torch.cat(self.val_step_outputs)
+            y = torch.cat(self.val_step_labels)
+            metric = self.val_metric(y_hat, y)
+            self.log_dict(
+                {"val_metric": metric},
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        self.val_step_outputs.clear()
+        self.val_step_labels.clear()
+
+    def test_step(self, batch: torch.Tensor, batch_idx: int):
+        return self._single_step(batch, batch_idx, step_name="test")
+
+    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
+        if self.log_test_metric:
+            y_hat = torch.cat(self.test_step_outputs)
+            y = torch.cat(self.test_step_labels)
+            metric = self.test_metric(y_hat, y)
+            self.log_dict(
+                {"test_metric": metric},
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        self.test_step_outputs.clear()
+        self.test_step_labels.clear()
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=None):
+        x, _ = batch
+        y_hat = self.forward(x)
+        return y_hat
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(
+            self.parameters(),
+            lr=self.learning_rate,
+        )
+        return optimizer
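For reference, the class above is meant to be driven roughly as follows. This is a minimal sketch, not part of the patch: the network, metric, and dataloader are hypothetical stand-ins.

    import torch
    import torchmetrics
    import lightning as L
    from torch.utils.data import DataLoader, TensorDataset

    from minerva.models.nets.base import LoggedSupervisedModel

    # Hypothetical two-class classifier over flat 16-dimensional inputs.
    net = torch.nn.Linear(16, 2)

    model = LoggedSupervisedModel(
        model=net,
        loss_fn=torch.nn.CrossEntropyLoss(),
        # Logged under the fixed key "train_metric" at the end of each epoch.
        train_metric=torchmetrics.Accuracy(task="multiclass", num_classes=2),
    )

    # _single_step unpacks batches as (x, y), so any (input, label) loader works.
    loader = DataLoader(
        TensorDataset(torch.randn(8, 16), torch.randint(0, 2, (8,))),
        batch_size=4,
    )
    L.Trainer(max_epochs=1, accelerator="cpu").fit(model, loader)

Note that each metric is logged under a single fixed name ("train_metric", "val_metric", "test_metric"), which is the limitation the next patch removes.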
From b943951511f9d291db88edd074f63623376d0c7c Mon Sep 17 00:00:00 2001
From: Otavio Napoli
Date: Mon, 13 May 2024 14:23:00 +0000
Subject: [PATCH 2/4] Change metrics to use the torchmetrics API and put them
 directly in the SimpleSupervisedModel class

---
 minerva/models/nets/base.py | 233 ++++++++++--------------------------
 1 file changed, 63 insertions(+), 170 deletions(-)

diff --git a/minerva/models/nets/base.py b/minerva/models/nets/base.py
index 959ffab..8b46e7d 100644
--- a/minerva/models/nets/base.py
+++ b/minerva/models/nets/base.py
@@ -2,6 +2,7 @@

 import lightning as L
 import torch
+from torchmetrics import Metric


 class SimpleSupervisedModel(L.LightningModule):
@@ -30,8 +31,15 @@ def __init__(
         loss_fn: torch.nn.Module,
         learning_rate: float = 1e-3,
         flatten: bool = True,
+        train_metrics: Optional[Dict[str, Metric]] = None,
+        val_metrics: Optional[Dict[str, Metric]] = None,
+        test_metrics: Optional[Dict[str, Metric]] = None,
     ):
-        """Initialize the model.
+        """Initialize the model with the backbone, fc, loss function and
+        metrics. Metrics are used to evaluate the model during training,
+        validation, and testing. They are logged by the Lightning logger at
+        the end of each epoch and must implement the `torchmetrics.Metric`
+        interface.

         Parameters
         ----------
@@ -47,6 +55,15 @@ def __init__(
         flatten : bool, optional
             If `True` the input data will be flattened before passing through
             the fc model, by default True
+
+        train_metrics : Dict[str, Metric], optional
+            The metrics to be used during training, by default None
+        val_metrics : Dict[str, Metric], optional
+            The metrics to be used during validation, by default None
+        test_metrics : Dict[str, Metric], optional
+            The metrics to be used during testing, by default None
         """
         super().__init__()
         self.backbone = backbone
@@ -55,6 +72,12 @@ def __init__(
         self.learning_rate = learning_rate
         self.flatten = flatten

+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }
+
     def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """Calculate the loss between the output and the input data.

@@ -92,6 +115,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc(x)
         return x

+    def _compute_metrics(
+        self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str
+    ) -> Dict[str, torch.Tensor]:
+        """Calculate the metrics for the given step.
+
+        Parameters
+        ----------
+        y_hat : torch.Tensor
+            The output data from the forward pass.
+        y : torch.Tensor
+            The input data/label.
+        step_name : str
+            Name of the step. It will be used to get the metrics from the
+            `self.metrics` attribute.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            A dictionary with the metrics values.
+        """
+        if self.metrics[step_name] is None:
+            return {}
+
+        return {
+            f"{step_name}_{metric_name}": metric.to(self.device)(y_hat, y)
+            for metric_name, metric in self.metrics[step_name].items()
+        }
+
     def _single_step(
         self, batch: torch.Tensor, batch_idx: int, step_name: str
     ) -> torch.Tensor:
@@ -127,187 +178,29 @@ def _single_step(
             prog_bar=True,
             logger=True,
         )
-        return loss
-
-    def training_step(self, batch: torch.Tensor, batch_idx: int):
-        return self._single_step(batch, batch_idx, step_name="train")
-
-    def validation_step(self, batch: torch.Tensor, batch_idx: int):
-        return self._single_step(batch, batch_idx, step_name="val")
-
-    def test_step(self, batch: torch.Tensor, batch_idx: int):
-        return self._single_step(batch, batch_idx, step_name="test")
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=None):
-        x, _ = batch
-        y_hat = self.forward(x)
-        return y_hat
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(
-            self.parameters(),
-            lr=self.learning_rate,
-        )
-        return optimizer
-
-
-class LoggedSupervisedModel(L.LightningModule):
-    """Simple pipeline for supervised models with logging."""
-
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        loss_fn: torch.nn.Module,
-        learning_rate: float = 1e-3,
-        flatten: bool = True,
-        train_metric: Optional[torch.nn.Module] = None,
-        val_metric: Optional[torch.nn.Module] = None,
-        test_metric: Optional[torch.nn.Module] = None,
-    ):
-        super().__init__()
-        self.model = model
-        self.loss_fn = loss_fn
-        self.learning_rate = learning_rate
-        self.flatten = flatten
-
-        # Always set the metric attributes and flags, so the batch-end hooks
-        # never hit an AttributeError when a metric is not provided.
-        self.train_metric = train_metric
-        self.log_train_metric = train_metric is not None
-
-        self.val_metric = val_metric
-        self.log_val_metric = val_metric is not None
-
-        self.test_metric = test_metric
-        self.log_test_metric = test_metric is not None
-
-        self.train_step_outputs = []
-        self.train_step_labels = []
-
-        self.val_step_outputs = []
-        self.val_step_labels = []
-
-        self.test_step_outputs = []
-        self.test_step_labels = []
-
-    def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        """Calculate the loss between the output and the input data.
-
-        Parameters
-        ----------
-        y_hat : torch.Tensor
-            The output data from the forward pass.
-        y : torch.Tensor
-            The input data/label.
-
-        Returns
-        -------
-        torch.Tensor
-            The loss value.
-        """
-        loss = self.loss_fn(y_hat, y)
-        return loss
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Perform a forward pass with the input data on the backbone model.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            The input data.
-
-        Returns
-        -------
-        torch.Tensor
-            The output data from the forward pass.
-        """
-        return self.model(x)
-
-    def _single_step(
-        self, batch: torch.Tensor, batch_idx: int, step_name: str
-    ) -> torch.Tensor:
-
-        x, y = batch
-        y_hat = self.forward(x)
-        loss = self._loss_func(y_hat, y)
-
-        if step_name == "train":
-            self.train_step_outputs.append(y_hat)
-            self.train_step_labels.append(y)
-
-        elif step_name == "val":
-            self.val_step_outputs.append(y_hat)
-            self.val_step_labels.append(y)
-
-        elif step_name == "test":
-            self.test_step_outputs.append(y_hat)
-            self.test_step_labels.append(y)
-
-        self.log(
-            f"{step_name}_loss",
-            loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-        return loss
-
-    def training_step(self, batch: torch.Tensor, batch_idx: int):
-        return self._single_step(batch, batch_idx, step_name="train")
-
-    def on_train_batch_end(self, outputs, batch, batch_idx):
-        if self.log_train_metric:
-            y_hat = torch.cat(self.train_step_outputs)
-            y = torch.cat(self.train_step_labels)
-            metric = self.train_metric(y_hat, y)
-            self.log_dict(
-                {"train_metric": metric},
+
+        metrics = self._compute_metrics(y_hat, y, step_name)
+        for metric_name, metric_value in metrics.items():
+            self.log(
+                metric_name,
+                metric_value,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
             )
-        # Clear unconditionally so the buffers do not grow without bound
-        # when no metric is configured.
-        self.train_step_outputs.clear()
-        self.train_step_labels.clear()
+
+        return loss
+
+    def training_step(self, batch: torch.Tensor, batch_idx: int):
+        return self._single_step(batch, batch_idx, step_name="train")

     def validation_step(self, batch: torch.Tensor, batch_idx: int):
         return self._single_step(batch, batch_idx, step_name="val")

-    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
-        if self.log_val_metric:
-            y_hat = torch.cat(self.val_step_outputs)
-            y = torch.cat(self.val_step_labels)
-            metric = self.val_metric(y_hat, y)
-            self.log_dict(
-                {"val_metric": metric},
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-        self.val_step_outputs.clear()
-        self.val_step_labels.clear()
-
     def test_step(self, batch: torch.Tensor, batch_idx: int):
         return self._single_step(batch, batch_idx, step_name="test")

-    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
-        if self.log_test_metric:
-            y_hat = torch.cat(self.test_step_outputs)
-            y = torch.cat(self.test_step_labels)
-            metric = self.test_metric(y_hat, y)
-            self.log_dict(
-                {"test_metric": metric},
-                on_step=False,
-                on_epoch=True,
-                prog_bar=True,
-                logger=True,
-            )
-        self.test_step_outputs.clear()
-        self.test_step_labels.clear()
-
     def predict_step(self, batch, batch_idx, dataloader_idx=None):
         x, _ = batch
         y_hat = self.forward(x)
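With metrics folded into `SimpleSupervisedModel`, the intended call pattern looks roughly like the sketch below. The backbone, fc head, and shapes are hypothetical placeholders; only the `*_metrics` arguments come from this patch.

    import torch
    import torchmetrics

    from minerva.models.nets.base import SimpleSupervisedModel

    num_classes = 10
    model = SimpleSupervisedModel(
        backbone=torch.nn.Conv2d(1, 8, kernel_size=3, padding=1),  # placeholder
        fc=torch.nn.Linear(8 * 28 * 28, num_classes),              # placeholder
        loss_fn=torch.nn.CrossEntropyLoss(),
        train_metrics={
            "acc": torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        },
        val_metrics={
            "acc": torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        },
    )
    # _compute_metrics prefixes each dict key with the step name, so these
    # appear in the logs as "train_acc" and "val_acc", next to
    # "train_loss"/"val_loss".

Compared to the previous class, any number of named metrics can be attached per step, and metric names are user-defined instead of the fixed "train_metric"/"val_metric"/"test_metric" keys.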
From 2c9e360934b8f2777ecedf674818fa8352e49c42 Mon Sep 17 00:00:00 2001
From: Otavio Napoli
Date: Mon, 13 May 2024 14:23:27 +0000
Subject: [PATCH 3/4] Adapt models to accept metrics via kwargs

---
 minerva/models/nets/deeplabv3.py | 22 ++++++++++++++++++----
 minerva/models/nets/unet.py      |  5 +++++
 minerva/models/nets/wisenet.py   |  2 ++
 3 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/minerva/models/nets/deeplabv3.py b/minerva/models/nets/deeplabv3.py
index 685c599..e24ea25 100644
--- a/minerva/models/nets/deeplabv3.py
+++ b/minerva/models/nets/deeplabv3.py
@@ -22,10 +22,14 @@ class DeepLabV3_Head(nn.Module):

     def __init__(self) -> None:
         super().__init__()
-        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+        raise NotImplementedError(
+            "DeepLabV3's head has not yet been implemented"
+        )

     def forward(self, x):
-        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+        raise NotImplementedError(
+            "DeepLabV3's head has not yet been implemented"
+        )


 class DeepLabV3(SimpleSupervisedModel):
@@ -33,10 +37,16 @@ class DeepLabV3(SimpleSupervisedModel):

     References
     ----------
-    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam. "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017
+    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
+    "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017
     """

-    def __init__(self, learning_rate: float = 1e-3, loss_fn: torch.nn.Module = None):
+    def __init__(
+        self,
+        learning_rate: float = 1e-3,
+        loss_fn: torch.nn.Module = None,
+        **kwargs,
+    ):
         """Wrapper implementation of the DeepLabv3 model.

         Parameters
@@ -46,10 +56,14 @@ def __init__(self, learning_rate: float = 1e-3, loss_fn: torch.nn.Module = None)
         loss_fn : torch.nn.Module, optional
             The function used to compute the loss. If `None`, it will be used
             the MSELoss, by default None.
+        kwargs : Dict
+            Additional arguments to be passed to the `SimpleSupervisedModel`
+            class.
         """
         super().__init__(
             backbone=Resnet50Backbone(),
             fc=DeepLabV3_Head(),
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
+            **kwargs,
         )
diff --git a/minerva/models/nets/unet.py b/minerva/models/nets/unet.py
index b5ec57b..ec42fd8 100644
--- a/minerva/models/nets/unet.py
+++ b/minerva/models/nets/unet.py
@@ -201,6 +201,7 @@ def __init__(
         bilinear: bool = False,
         learning_rate: float = 1e-3,
         loss_fn: Optional[torch.nn.Module] = None,
+        **kwargs,
     ):
         """Wrapper implementation of the U-Net model.

@@ -216,6 +217,9 @@ def __init__(
         loss_fn : torch.nn.Module, optional
             The function used to compute the loss. If `None`, it will be used
             the MSELoss, by default None.
+        kwargs : Dict
+            Additional arguments to be passed to the `SimpleSupervisedModel`
+            class.
         """
         super().__init__(
             backbone=_UNet(n_channels=n_channels, bilinear=bilinear),
@@ -223,4 +227,5 @@ def __init__(
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
             flatten=False,
+            **kwargs,
         )
diff --git a/minerva/models/nets/wisenet.py b/minerva/models/nets/wisenet.py
index 5c1959a..e1147d9 100644
--- a/minerva/models/nets/wisenet.py
+++ b/minerva/models/nets/wisenet.py
@@ -101,6 +101,7 @@ def __init__(
         out_channels: int = 1,
         loss_fn: torch.nn.Module = None,
         learning_rate: float = 1e-3,
+        **kwargs,
     ):
         super().__init__(
             backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels),
@@ -108,6 +109,7 @@ def __init__(
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
             flatten=False,
+            **kwargs,
         )

     def _single_step(
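Because the wrappers now forward `**kwargs` to `SimpleSupervisedModel`, metrics can be attached to any of them without changing the wrapper signatures. A short sketch with `UNet` (the metric choices are illustrative):

    import torchmetrics

    from minerva.models.nets.unet import UNet

    model = UNet(
        n_channels=1,
        # Forwarded through **kwargs to SimpleSupervisedModel.
        train_metrics={"mse": torchmetrics.MeanSquaredError()},
        val_metrics={"mae": torchmetrics.MeanAbsoluteError()},
    )

Any future constructor argument of `SimpleSupervisedModel` will be picked up by the wrappers for free, at the cost of slightly weaker signature documentation.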
""" super().__init__( backbone=_UNet(n_channels=n_channels, bilinear=bilinear), @@ -223,4 +227,5 @@ def __init__( loss_fn=loss_fn or torch.nn.MSELoss(), learning_rate=learning_rate, flatten=False, + **kwargs ) diff --git a/minerva/models/nets/wisenet.py b/minerva/models/nets/wisenet.py index 5c1959a..e1147d9 100644 --- a/minerva/models/nets/wisenet.py +++ b/minerva/models/nets/wisenet.py @@ -101,6 +101,7 @@ def __init__( out_channels: int = 1, loss_fn: torch.nn.Module = None, learning_rate: float = 1e-3, + **kwargs, ): super().__init__( backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels), @@ -108,6 +109,7 @@ def __init__( loss_fn=loss_fn or torch.nn.MSELoss(), learning_rate=learning_rate, flatten=False, + **kwargs, ) def _single_step( From aecea7b7df79e76a9682b37ff04ca534008c6471 Mon Sep 17 00:00:00 2001 From: Otavio Napoli Date: Mon, 13 May 2024 14:23:38 +0000 Subject: [PATCH 4/4] Added test of metrics using Unet --- tests/models/nets/test_unet.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/models/nets/test_unet.py b/tests/models/nets/test_unet.py index 4fd306e..916fe0a 100644 --- a/tests/models/nets/test_unet.py +++ b/tests/models/nets/test_unet.py @@ -1,6 +1,9 @@ import torch from minerva.models.nets.unet import UNet +import torchmetrics +import lightning as L +from minerva.utils.data import RandomDataModule def test_unet(): @@ -24,3 +27,34 @@ def test_unet(): loss = model.training_step((x, mask), 0).item() assert loss is not None assert loss >= 0, f"Expected non-negative loss, but got {loss}" + + +def test_unet_train_metrics(): + metrics = { + "mse": torchmetrics.MeanSquaredError(squared=True), + "mae": torchmetrics.MeanAbsoluteError(), + } + + # Generate a random input tensor (B, C, H, W) and the random mask of the + # same shape + data_module = RandomDataModule( + data_shape=(1, 128, 128), + label_shape=(1, 128, 128), + num_train_samples=2, + batch_size=2, + ) + + # Test the class instantiation + model = UNet(train_metrics=metrics) + trainer = L.Trainer(accelerator="cpu", max_epochs=1, devices=1) + + assert data_module is not None + assert model is not None + assert trainer is not None + + # Do fit + trainer.fit(model, data_module) + + assert "train_mse" in trainer.logged_metrics + assert "train_mae" in trainer.logged_metrics + assert "train_loss" in trainer.logged_metrics