simple logging class #57

Merged: 4 commits, May 13, 2024
65 changes: 63 additions & 2 deletions minerva/models/nets/base.py
@@ -1,7 +1,8 @@
-from typing import Dict
+from typing import Dict, Optional
 
 import lightning as L
 import torch
+from torchmetrics import Metric
 
 
 class SimpleSupervisedModel(L.LightningModule):
@@ -30,8 +31,15 @@ def __init__(
         loss_fn: torch.nn.Module,
         learning_rate: float = 1e-3,
         flatten: bool = True,
+        train_metrics: Optional[Dict[str, Metric]] = None,
+        val_metrics: Optional[Dict[str, Metric]] = None,
+        test_metrics: Optional[Dict[str, Metric]] = None,
     ):
-        """Initialize the model.
+        """Initialize the model with the backbone, fc, loss function and
+        metrics. Metrics are used to evaluate the model during training,
+        validation and testing. They are logged with the Lightning logger
+        at the end of each epoch and should implement the
+        `torchmetrics.Metric` interface.
 
         Parameters
         ----------
@@ -47,6 +55,13 @@ def __init__(
         flatten : bool, optional
             If `True` the input data will be flattened before passing through
             the fc model, by default True
+
+        train_metrics : Dict[str, Metric], optional
+            The metrics to be used during training, by default None
+        val_metrics : Dict[str, Metric], optional
+            The metrics to be used during validation, by default None
+        test_metrics : Dict[str, Metric], optional
+            The metrics to be used during testing, by default None
         """
         super().__init__()
         self.backbone = backbone
@@ -55,6 +70,12 @@ def __init__(
         self.learning_rate = learning_rate
         self.flatten = flatten
 
+        self.metrics = {
+            "train": train_metrics,
+            "val": val_metrics,
+            "test": test_metrics,
+        }
+
     def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """Calculate the loss between the output and the input data.
 
@@ -92,6 +113,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc(x)
         return x
 
+    def _compute_metrics(
+        self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str
+    ) -> Dict[str, torch.Tensor]:
+        """Calculate the metrics for the given step.
+
+        Parameters
+        ----------
+        y_hat : torch.Tensor
+            The output data from the forward pass.
+        y : torch.Tensor
+            The target data/label.
+        step_name : str
+            Name of the step. It is used to select the metrics from the
+            `self.metrics` attribute.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            A dictionary with the metric values.
+        """
+        if self.metrics[step_name] is None:
+            return {}
+
+        return {
+            f"{step_name}_{metric_name}": metric.to(self.device)(y_hat, y)
+            for metric_name, metric in self.metrics[step_name].items()
+        }
+
     def _single_step(
         self, batch: torch.Tensor, batch_idx: int, step_name: str
     ) -> torch.Tensor:
@@ -127,6 +176,18 @@ def _single_step(
             prog_bar=True,
             logger=True,
         )
+
+        metrics = self._compute_metrics(y_hat, y, step_name)
+        for metric_name, metric_value in metrics.items():
+            self.log(
+                metric_name,
+                metric_value,
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+
         return loss
 
     def training_step(self, batch: torch.Tensor, batch_idx: int):
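With these changes, metric dictionaries ride along with the model and are logged once per epoch under "<step>_<metric name>" keys. As a quick illustration of the new API, a minimal sketch (the Linear backbone/fc pair is a toy stand-in for this example, not part of the PR):

import torch
import torchmetrics
from minerva.models.nets.base import SimpleSupervisedModel

model = SimpleSupervisedModel(
    backbone=torch.nn.Linear(16, 8),   # toy backbone, for illustration only
    fc=torch.nn.Linear(8, 1),          # toy head, for illustration only
    loss_fn=torch.nn.MSELoss(),
    train_metrics={"mse": torchmetrics.MeanSquaredError()},
    val_metrics={"mae": torchmetrics.MeanAbsoluteError()},
)
# During fit, each metric is logged per epoch as "train_mse", "val_mae", ...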
22 changes: 18 additions & 4 deletions minerva/models/nets/deeplabv3.py
@@ -22,21 +22,31 @@ class DeepLabV3_Head(nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+        raise NotImplementedError(
+            "DeepLabV3's head has not yet been implemented"
+        )
 
     def forward(self, x):
-        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+        raise NotImplementedError(
+            "DeepLabV3's head has not yet been implemented"
+        )
 
 
 class DeepLabV3(SimpleSupervisedModel):
     """A DeepLabV3 with a ResNet50 backbone.
 
     References
     ----------
-    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam. "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017
+    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
+    "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017.
     """
 
-    def __init__(self, learning_rate: float = 1e-3, loss_fn: torch.nn.Module = None):
+    def __init__(
+        self,
+        learning_rate: float = 1e-3,
+        loss_fn: torch.nn.Module = None,
+        **kwargs,
+    ):
         """Wrapper implementation of the DeepLabv3 model.
 
         Parameters
@@ -46,10 +56,14 @@ def __init__(self, learning_rate: float = 1e-3, loss_fn: torch.nn.Module = None)
         loss_fn : torch.nn.Module, optional
             The function used to compute the loss. If `None`, the MSELoss
             is used, by default None.
+        kwargs : Dict
+            Additional arguments to be passed to the `SimpleSupervisedModel`
+            class.
         """
         super().__init__(
             backbone=Resnet50Backbone(),
             fc=DeepLabV3_Head(),
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
+            **kwargs,
         )
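The **kwargs forwarding means any SimpleSupervisedModel subclass picks up the new metric parameters without redeclaring them. A sketch of the pattern (TinyRegressor is a hypothetical wrapper written for this note, not part of the repository):

import torch
import torchmetrics
from minerva.models.nets.base import SimpleSupervisedModel

class TinyRegressor(SimpleSupervisedModel):
    # Hypothetical subclass mirroring the DeepLabV3/UNet wrapper pattern.
    def __init__(self, loss_fn: torch.nn.Module = None, **kwargs):
        super().__init__(
            backbone=torch.nn.Linear(16, 8),
            fc=torch.nn.Linear(8, 1),
            loss_fn=loss_fn or torch.nn.MSELoss(),
            **kwargs,  # forwards train_metrics / val_metrics / test_metrics
        )

model = TinyRegressor(val_metrics={"mae": torchmetrics.MeanAbsoluteError()})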
5 changes: 5 additions & 0 deletions minerva/models/nets/unet.py
@@ -201,6 +201,7 @@ def __init__(
         bilinear: bool = False,
         learning_rate: float = 1e-3,
         loss_fn: Optional[torch.nn.Module] = None,
+        **kwargs,
     ):
         """Wrapper implementation of the U-Net model.
 
@@ -216,11 +217,15 @@
         loss_fn : torch.nn.Module, optional
             The function used to compute the loss. If `None`, the MSELoss
             is used, by default None.
+        kwargs : Dict
+            Additional arguments to be passed to the `SimpleSupervisedModel`
+            class.
         """
         super().__init__(
             backbone=_UNet(n_channels=n_channels, bilinear=bilinear),
             fc=torch.nn.Identity(),
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
             flatten=False,
+            **kwargs,
         )
2 changes: 2 additions & 0 deletions minerva/models/nets/wisenet.py
@@ -101,13 +101,15 @@ def __init__(
         out_channels: int = 1,
         loss_fn: torch.nn.Module = None,
         learning_rate: float = 1e-3,
+        **kwargs,
     ):
         super().__init__(
             backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels),
             fc=torch.nn.Identity(),
             loss_fn=loss_fn or torch.nn.MSELoss(),
             learning_rate=learning_rate,
             flatten=False,
+            **kwargs,
         )
 
     def _single_step(
34 changes: 34 additions & 0 deletions tests/models/nets/test_unet.py
@@ -1,6 +1,9 @@
 import torch
 
 from minerva.models.nets.unet import UNet
+import torchmetrics
+import lightning as L
+from minerva.utils.data import RandomDataModule
 
 
 def test_unet():
@@ -24,3 +27,34 @@ def test_unet():
     loss = model.training_step((x, mask), 0).item()
     assert loss is not None
     assert loss >= 0, f"Expected non-negative loss, but got {loss}"
+
+
+def test_unet_train_metrics():
+    metrics = {
+        "mse": torchmetrics.MeanSquaredError(squared=True),
+        "mae": torchmetrics.MeanAbsoluteError(),
+    }
+
+    # Build a data module that yields random inputs (B, C, H, W) and random
+    # masks of the same shape
+    data_module = RandomDataModule(
+        data_shape=(1, 128, 128),
+        label_shape=(1, 128, 128),
+        num_train_samples=2,
+        batch_size=2,
+    )
+
+    # Test the class instantiation
+    model = UNet(train_metrics=metrics)
+    trainer = L.Trainer(accelerator="cpu", max_epochs=1, devices=1)
+
+    assert data_module is not None
+    assert model is not None
+    assert trainer is not None
+
+    # Do fit
+    trainer.fit(model, data_module)
+
+    assert "train_mse" in trainer.logged_metrics
+    assert "train_mae" in trainer.logged_metrics
+    assert "train_loss" in trainer.logged_metrics