diff --git a/minerva/models/nets/base.py b/minerva/models/nets/base.py
index a00355e2..8b46e7d4 100644
--- a/minerva/models/nets/base.py
+++ b/minerva/models/nets/base.py
@@ -1,7 +1,8 @@
-from typing import Dict
+from typing import Dict, Optional
 
 import lightning as L
 import torch
+from torchmetrics import Metric
 
 
 class SimpleSupervisedModel(L.LightningModule):
@@ -30,8 +31,15 @@ def __init__(
         loss_fn: torch.nn.Module,
         learning_rate: float = 1e-3,
         flatten: bool = True,
+        train_metrics: Optional[Dict[str, Metric]] = None,
+        val_metrics: Optional[Dict[str, Metric]] = None,
+        test_metrics: Optional[Dict[str, Metric]] = None,
     ):
-        """Initialize the model.
+        """Initialize the model with the backbone, fc model, loss function,
+        and metrics. Metrics are used to evaluate the model during training,
+        validation, and testing. They are logged with the Lightning logger at
+        the end of each epoch. Each metric should implement the
+        `torchmetrics.Metric` interface.
 
         Parameters
         ----------
@@ -47,6 +55,13 @@ def __init__(
         flatten : bool, optional
             If `True` the input data will be flattened before passing through
             the fc model, by default True
+
+        train_metrics : Dict[str, Metric], optional
+            The metrics to be used during training, by default None
+        val_metrics : Dict[str, Metric], optional
+            The metrics to be used during validation, by default None
+        test_metrics : Dict[str, Metric], optional
+            The metrics to be used during testing, by default None
         """
         super().__init__()
         self.backbone = backbone
@@ -55,6 +70,12 @@ def __init__(
         self.learning_rate = learning_rate
         self.flatten = flatten
 
+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }
+
     def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """Calculate the loss between the output and the input data.
 
@@ -92,6 +113,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc(x)
         return x
 
+    def _compute_metrics(
+        self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str
+    ) -> Dict[str, torch.Tensor]:
+        """Calculate the metrics for the given step.
+
+        Parameters
+        ----------
+        y_hat : torch.Tensor
+            The output data from the forward pass.
+        y : torch.Tensor
+            The target data (labels).
+        step_name : str
+            Name of the step. It is used to select the metrics from the
+            `self.metrics` attribute.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            A dictionary with the metric values.
+        """
+        if self.metrics[step_name] is None:
+            return {}
+
+        return {
+            f"{step_name}_{metric_name}": metric.to(self.device)(y_hat, y)
+            for metric_name, metric in self.metrics[step_name].items()
+        }
+
     def _single_step(
         self, batch: torch.Tensor, batch_idx: int, step_name: str
     ) -> torch.Tensor:
@@ -127,6 +176,18 @@ def _single_step(
             prog_bar=True,
             logger=True,
         )
+
+        metrics = self._compute_metrics(y_hat, y, step_name)
+        for metric_name, metric_value in metrics.items():
+            self.log(
+                metric_name,
+                metric_value,
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+
         return loss
 
     def training_step(self, batch: torch.Tensor, batch_idx: int):
diff --git a/minerva/models/nets/deeplabv3.py b/minerva/models/nets/deeplabv3.py
index 685c5994..e24ea25b 100644
--- a/minerva/models/nets/deeplabv3.py
+++ b/minerva/models/nets/deeplabv3.py
@@ -22,10 +22,14 @@ class DeepLabV3_Head(nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+        raise NotImplementedError(
+            "DeepLabV3's head has not yet been implemented"
+        )
 
     def forward(self, x):
-        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+        raise NotImplementedError(
+            "DeepLabV3's head has not yet been implemented"
+        )
 
 
 class DeepLabV3(SimpleSupervisedModel):
@@ -33,10 +37,16 @@ class DeepLabV3(SimpleSupervisedModel):
 
     References
     ----------
-    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam. "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017
+    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
+    "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017
     """
 
-    def __init__(self, learning_rate: float = 1e-3, loss_fn: torch.nn.Module = None):
+    def __init__(
+        self,
+        learning_rate: float = 1e-3,
+        loss_fn: torch.nn.Module = None,
+        **kwargs,
+    ):
         """Wrapper implementation of the DeepLabv3 model.
 
         Parameters
@@ -46,10 +56,14 @@ def __init__(
         loss_fn : torch.nn.Module, optional
             The function used to compute the loss. If `None`, it will be used
             the MSELoss, by default None.
+        kwargs : Dict
+            Additional arguments to be passed to the `SimpleSupervisedModel`
+            class.
         """
         super().__init__(
             backbone=Resnet50Backbone(),
             fc=DeepLabV3_Head(),
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
+            **kwargs,
         )
diff --git a/minerva/models/nets/unet.py b/minerva/models/nets/unet.py
index b5ec57b8..ec42fd8c 100644
--- a/minerva/models/nets/unet.py
+++ b/minerva/models/nets/unet.py
@@ -201,6 +201,7 @@ def __init__(
         bilinear: bool = False,
         learning_rate: float = 1e-3,
         loss_fn: Optional[torch.nn.Module] = None,
+        **kwargs,
     ):
         """Wrapper implementation of the U-Net model.
 
@@ -216,6 +217,9 @@ def __init__(
         loss_fn : torch.nn.Module, optional
             The function used to compute the loss. If `None`, it will be used
             the MSELoss, by default None.
+        kwargs : Dict
+            Additional arguments to be passed to the `SimpleSupervisedModel`
+            class.
         """
         super().__init__(
             backbone=_UNet(n_channels=n_channels, bilinear=bilinear),
@@ -223,4 +227,5 @@ def __init__(
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
             flatten=False,
+            **kwargs,
         )
diff --git a/minerva/models/nets/wisenet.py b/minerva/models/nets/wisenet.py
index 5c1959a4..e1147d94 100644
--- a/minerva/models/nets/wisenet.py
+++ b/minerva/models/nets/wisenet.py
@@ -101,6 +101,7 @@ def __init__(
         out_channels: int = 1,
         loss_fn: torch.nn.Module = None,
         learning_rate: float = 1e-3,
+        **kwargs,
     ):
         super().__init__(
             backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels),
@@ -108,6 +109,7 @@ def __init__(
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
             flatten=False,
+            **kwargs,
         )
 
     def _single_step(
diff --git a/tests/models/nets/test_unet.py b/tests/models/nets/test_unet.py
index 4fd306e7..916fe0a9 100644
--- a/tests/models/nets/test_unet.py
+++ b/tests/models/nets/test_unet.py
@@ -1,6 +1,9 @@
 import torch
 
 from minerva.models.nets.unet import UNet
+import torchmetrics
+import lightning as L
+from minerva.utils.data import RandomDataModule
 
 
 def test_unet():
@@ -24,3 +27,34 @@ def test_unet():
     loss = model.training_step((x, mask), 0).item()
     assert loss is not None
     assert loss >= 0, f"Expected non-negative loss, but got {loss}"
+
+
+def test_unet_train_metrics():
+    metrics = {
+        "mse": torchmetrics.MeanSquaredError(squared=True),
+        "mae": torchmetrics.MeanAbsoluteError(),
+    }
+
+    # Random data module that yields input tensors and masks of shape
+    # (1, 128, 128), i.e. single-channel 128x128 samples
+    data_module = RandomDataModule(
+        data_shape=(1, 128, 128),
+        label_shape=(1, 128, 128),
+        num_train_samples=2,
+        batch_size=2,
+    )
+
+    # Instantiate the model with training metrics and a CPU trainer
+    model = UNet(train_metrics=metrics)
+    trainer = L.Trainer(accelerator="cpu", max_epochs=1, devices=1)
+
+    assert data_module is not None
+    assert model is not None
+    assert trainer is not None
+
+    # Fit for one epoch so the epoch-level metrics are logged
+    trainer.fit(model, data_module)
+
+    assert "train_mse" in trainer.logged_metrics
+    assert "train_mae" in trainer.logged_metrics
+    assert "train_loss" in trainer.logged_metrics