Skip to content

Commit

Permalink
Added test of metrics using Unet
Browse files Browse the repository at this point in the history
  • Loading branch information
otavioon committed May 13, 2024
1 parent 2c9e360 commit aecea7b
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/models/nets/test_unet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import torch

from minerva.models.nets.unet import UNet
import torchmetrics
import lightning as L
from minerva.utils.data import RandomDataModule


def test_unet():
Expand All @@ -24,3 +27,34 @@ def test_unet():
loss = model.training_step((x, mask), 0).item()
assert loss is not None
assert loss >= 0, f"Expected non-negative loss, but got {loss}"


def test_unet_train_metrics():
metrics = {
"mse": torchmetrics.MeanSquaredError(squared=True),
"mae": torchmetrics.MeanAbsoluteError(),
}

# Generate a random input tensor (B, C, H, W) and the random mask of the
# same shape
data_module = RandomDataModule(
data_shape=(1, 128, 128),
label_shape=(1, 128, 128),
num_train_samples=2,
batch_size=2,
)

# Test the class instantiation
model = UNet(train_metrics=metrics)
trainer = L.Trainer(accelerator="cpu", max_epochs=1, devices=1)

assert data_module is not None
assert model is not None
assert trainer is not None

# Do fit
trainer.fit(model, data_module)

assert "train_mse" in trainer.logged_metrics
assert "train_mae" in trainer.logged_metrics
assert "train_loss" in trainer.logged_metrics

0 comments on commit aecea7b

Please sign in to comment.