Commit 49b17f1

logging
1 parent 615bb09 commit 49b17f1

File tree

2 files changed: +91 -28 lines changed


minerva/models/nets/setr.py

Lines changed: 90 additions & 27 deletions
@@ -420,6 +420,8 @@ def __init__(
         train_metrics: Optional[nn.Module] = None,
         log_val_metrics: bool = False,
         val_metrics: Optional[nn.Module] = None,
+        log_test_metrics: bool = False,
+        test_metrics: Optional[nn.Module] = None,
     ):
         """
         Initializes the SetR model.
@@ -481,6 +483,7 @@ def __init__(

         self.log_train_metrics = log_train_metrics
         self.log_val_metrics = log_val_metrics
+        self.log_test_metrics = log_test_metrics

         if log_train_metrics:
             assert (
@@ -494,6 +497,12 @@ def __init__(
             ), "val_metrics must be provided if log_val_metrics is True"
             self.val_metrics = val_metrics

+        if log_test_metrics:
+            assert (
+                test_metrics is not None
+            ), "test_metrics must be provided if log_test_metrics is True"
+            self.test_metrics = test_metrics
+
         self.model = _SetR_PUP(
             image_size=image_size,
             patch_size=patch_size,
@@ -515,6 +524,15 @@ def __init__(
             align_corners=align_corners,
         )

+        self.train_step_outputs = []
+        self.train_step_labels = []
+
+        self.val_step_outputs = []
+        self.val_step_labels = []
+
+        self.test_step_outputs = []
+        self.test_step_labels = []
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x)

@@ -536,9 +554,7 @@ def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         loss = self.loss_fn(y_hat, y.long())
         return loss

-    def _single_step(
-        self, batch: torch.Tensor, batch_idx: int, step_name: str
-    ) -> torch.Tensor:
+    def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
         """Perform a single step of the training/validation loop.

         Parameters
@@ -559,39 +575,86 @@ def _single_step(
         y_hat = self.model(x.float())
         loss = self._loss_func(y_hat[0], y.squeeze(1))

-        if step_name == "train" and self.log_train_metrics:
-            preds = torch.argmax(y_hat[0], dim=1, keepdim=True)
+        if step_name == "train":
+            self.train_step_outputs.append(y_hat[0])
+            self.train_step_labels.append(y)
+        elif step_name == "val":
+            self.val_step_outputs.append(y_hat[0])
+            self.val_step_labels.append(y)
+        elif step_name == "test":
+            self.test_step_outputs.append(y_hat[0])
+            self.test_step_labels.append(y)
+
+        self.log_dict(
+            {
+                f"{step_name}_loss": loss,
+            },
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+
+        return loss
+
+    def on_train_epoch_end(self):
+        if self.log_train_metrics:
+            y_hat = torch.cat(self.train_step_outputs)
+            y = torch.cat(self.train_step_labels)
+            preds = torch.argmax(y_hat, dim=1, keepdim=True)
             self.train_metrics(preds, y)
             mIoU = self.train_metrics.compute()
-            self.log(
-                f"{step_name}_metrics",
-                mIoU,
-                on_step=True,
+
+            self.log_dict(
+                {
+                    f"train_metrics": mIoU,
+                },
+                on_step=False,
                 on_epoch=True,
+                prog_bar=True,
                 logger=True,
             )
-
-        if step_name == "val" and self.log_val_metrics:
-            preds = torch.argmax(y_hat[0], dim=1, keepdim=True)
-            self.train_metrics(preds, y)
+            self.train_step_outputs.clear()
+            self.train_step_labels.clear()
+
+    def on_validation_epoch_end(self):
+        if self.log_val_metrics:
+            y_hat = torch.cat(self.val_step_outputs)
+            y = torch.cat(self.val_step_labels)
+            preds = torch.argmax(y_hat, dim=1, keepdim=True)
+            self.val_metrics(preds, y)
             mIoU = self.val_metrics.compute()
-            self.log(
-                f"{step_name}_metrics",
-                mIoU,
-                on_step=True,
+
+            self.log_dict(
+                {
+                    f"val_metrics": mIoU,
+                },
+                on_step=False,
                 on_epoch=True,
+                prog_bar=True,
                 logger=True,
             )
-
-        self.log(
-            f"{step_name}_loss",
-            loss,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-        return loss
+            self.val_step_outputs.clear()
+            self.val_step_labels.clear()
+
+    def on_test_epoch_end(self):
+        if self.log_test_metrics:
+            y_hat = torch.cat(self.test_step_outputs)
+            y = torch.cat(self.test_step_labels)
+            preds = torch.argmax(y_hat, dim=1, keepdim=True)
+            self.test_metrics(preds, y)
+            mIoU = self.test_metrics.compute()
+            self.log_dict(
+                {
+                    f"test_metrics": mIoU,
+                },
+                on_step=False,
                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+            self.test_step_outputs.clear()
+            self.test_step_labels.clear()

     def training_step(self, batch: torch.Tensor, batch_idx: int):
         return self._single_step(batch, batch_idx, "train")
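
For orientation, a minimal usage sketch of the new test-metric options. Only the log_*_metrics flags and *_metrics parameters are taken from the diff; the public class name SETR_PUP, the remaining constructor defaults, and the use of torchmetrics are assumptions made for illustration.

# Hypothetical usage of the test-metric arguments added in this commit.
# SETR_PUP as the public class name and torchmetrics as the metric provider
# are assumptions; the diff only guarantees the flag/metric parameters.
from torchmetrics import JaccardIndex
from minerva.models.nets.setr import SETR_PUP

num_classes = 6  # placeholder value

model = SETR_PUP(
    log_train_metrics=True,
    train_metrics=JaccardIndex(task="multiclass", num_classes=num_classes),
    log_val_metrics=True,
    val_metrics=JaccardIndex(task="multiclass", num_classes=num_classes),
    log_test_metrics=True,  # new flag; defaults to False
    test_metrics=JaccardIndex(task="multiclass", num_classes=num_classes),  # required when the flag is True
)

With the flags enabled, the model now buffers step outputs and labels in the new *_step_outputs and *_step_labels lists, computes the metric once per epoch in on_train_epoch_end, on_validation_epoch_end and on_test_epoch_end, and clears the buffers afterwards; per-step logging is limited to the loss.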

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ dependencies = [
 ]

 [tool.setuptools]
-py-modules = []
+packages = ["minerva"]

 [project.optional-dependencies]
 dev = ["mock", "pytest", "black", "isort"]
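
A note on the packaging change, based on my reading of setuptools rather than anything stated in the commit: declaring py-modules = [] counts as explicit configuration and, as far as I can tell, also disables automatic package discovery, so built distributions shipped no minerva code; packages = ["minerva"] now lists the package explicitly. A flat packages list does not recurse into subpackages such as minerva.models.nets; [tool.setuptools.packages.find] is the usual way to pick those up. A hypothetical sanity check after a plain pip install .:

# Hypothetical check, not part of the commit: the top-level package should now
# be importable from site-packages after a non-editable install.
import minerva
print(minerva.__file__)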
