@@ -420,6 +420,8 @@ def __init__(
420
420
train_metrics : Optional [nn .Module ] = None ,
421
421
log_val_metrics : bool = False ,
422
422
val_metrics : Optional [nn .Module ] = None ,
423
+ log_test_metrics : bool = False ,
424
+ test_metrics : Optional [nn .Module ] = None ,
423
425
):
424
426
"""
425
427
Initializes the SetR model.
@@ -481,6 +483,7 @@ def __init__(
481
483
482
484
self .log_train_metrics = log_train_metrics
483
485
self .log_val_metrics = log_val_metrics
486
+ self .log_test_metrics = log_test_metrics
484
487
485
488
if log_train_metrics :
486
489
assert (
@@ -494,6 +497,12 @@ def __init__(
494
497
), "val_metrics must be provided if log_val_metrics is True"
495
498
self .val_metrics = val_metrics
496
499
500
+ if log_test_metrics :
501
+ assert (
502
+ test_metrics is not None
503
+ ), "test_metrics must be provided if log_test_metrics is True"
504
+ self .test_metrics = test_metrics
505
+
497
506
self .model = _SetR_PUP (
498
507
image_size = image_size ,
499
508
patch_size = patch_size ,
@@ -515,6 +524,15 @@ def __init__(
515
524
align_corners = align_corners ,
516
525
)
517
526
527
+ self .train_step_outputs = []
528
+ self .train_step_labels = []
529
+
530
+ self .val_step_outputs = []
531
+ self .val_step_labels = []
532
+
533
+ self .test_step_outputs = []
534
+ self .test_step_labels = []
535
+
518
536
def forward (self , x : torch .Tensor ) -> torch .Tensor :
519
537
return self .model (x )
520
538
@@ -536,9 +554,7 @@ def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
536
554
loss = self .loss_fn (y_hat , y .long ())
537
555
return loss
538
556
539
- def _single_step (
540
- self , batch : torch .Tensor , batch_idx : int , step_name : str
541
- ) -> torch .Tensor :
557
+ def _single_step (self , batch : torch .Tensor , batch_idx : int , step_name : str ):
542
558
"""Perform a single step of the training/validation loop.
543
559
544
560
Parameters
@@ -559,39 +575,86 @@ def _single_step(
559
575
y_hat = self .model (x .float ())
560
576
loss = self ._loss_func (y_hat [0 ], y .squeeze (1 ))
561
577
562
- if step_name == "train" and self .log_train_metrics :
563
- preds = torch .argmax (y_hat [0 ], dim = 1 , keepdim = True )
578
+ if step_name == "train" :
579
+ self .train_step_outputs .append (y_hat [0 ])
580
+ self .train_step_labels .append (y )
581
+ elif step_name == "val" :
582
+ self .val_step_outputs .append (y_hat [0 ])
583
+ self .val_step_labels .append (y )
584
+ elif step_name == "test" :
585
+ self .test_step_outputs .append (y_hat [0 ])
586
+ self .test_step_labels .append (y )
587
+
588
+ self .log_dict (
589
+ {
590
+ f"{ step_name } _loss" : loss ,
591
+ },
592
+ on_step = True ,
593
+ on_epoch = True ,
594
+ prog_bar = True ,
595
+ logger = True ,
596
+ )
597
+
598
+ return loss
599
+
600
+ def on_train_epoch_end (self ):
601
+ if self .log_train_metrics :
602
+ y_hat = torch .cat (self .train_step_outputs )
603
+ y = torch .cat (self .train_step_labels )
604
+ preds = torch .argmax (y_hat , dim = 1 , keepdim = True )
564
605
self .train_metrics (preds , y )
565
606
mIoU = self .train_metrics .compute ()
566
- self .log (
567
- f"{ step_name } _metrics" ,
568
- mIoU ,
569
- on_step = True ,
607
+
608
+ self .log_dict (
609
+ {
610
+ f"train_metrics" : mIoU ,
611
+ },
612
+ on_step = False ,
570
613
on_epoch = True ,
614
+ prog_bar = True ,
571
615
logger = True ,
572
616
)
573
-
574
- if step_name == "val" and self .log_val_metrics :
575
- preds = torch .argmax (y_hat [0 ], dim = 1 , keepdim = True )
576
- self .train_metrics (preds , y )
617
+ self .train_step_outputs .clear ()
618
+ self .train_step_labels .clear ()
619
+
620
+ def on_validation_epoch_end (self ):
621
+ if self .log_val_metrics :
622
+ y_hat = torch .cat (self .val_step_outputs )
623
+ y = torch .cat (self .val_step_labels )
624
+ preds = torch .argmax (y_hat , dim = 1 , keepdim = True )
625
+ self .val_metrics (preds , y )
577
626
mIoU = self .val_metrics .compute ()
578
- self .log (
579
- f"{ step_name } _metrics" ,
580
- mIoU ,
581
- on_step = True ,
627
+
628
+ self .log_dict (
629
+ {
630
+ f"val_metrics" : mIoU ,
631
+ },
632
+ on_step = False ,
582
633
on_epoch = True ,
634
+ prog_bar = True ,
583
635
logger = True ,
584
636
)
585
-
586
- self .log (
587
- f"{ step_name } _loss" ,
588
- loss ,
589
- on_step = True ,
590
- on_epoch = True ,
591
- prog_bar = True ,
592
- logger = True ,
593
- )
594
- return loss
637
+ self .val_step_outputs .clear ()
638
+ self .val_step_labels .clear ()
639
+
640
+ def on_test_epoch_end (self ):
641
+ if self .log_test_metrics :
642
+ y_hat = torch .cat (self .test_step_outputs )
643
+ y = torch .cat (self .test_step_labels )
644
+ preds = torch .argmax (y_hat , dim = 1 , keepdim = True )
645
+ self .test_metrics (preds , y )
646
+ mIoU = self .test_metrics .compute ()
647
+ self .log_dict (
648
+ {
649
+ f"test_metrics" : mIoU ,
650
+ },
651
+ on_step = False ,
652
+ on_epoch = True ,
653
+ prog_bar = True ,
654
+ logger = True ,
655
+ )
656
+ self .test_step_outputs .clear ()
657
+ self .test_step_labels .clear ()
595
658
596
659
def training_step (self , batch : torch .Tensor , batch_idx : int ):
597
660
return self ._single_step (batch , batch_idx , "train" )
0 commit comments