From 7c6e5fb5dfcb06e5c6d4294503dd4c4a01d4391a Mon Sep 17 00:00:00 2001 From: golmschenk Date: Sat, 13 Jul 2024 12:42:19 -0400 Subject: [PATCH] Add test that functional metrics return the correct compute value --- tests/unit_tests/metrics.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/unit_tests/metrics.py b/tests/unit_tests/metrics.py index 61bd188..dcdec42 100644 --- a/tests/unit_tests/metrics.py +++ b/tests/unit_tests/metrics.py @@ -1,5 +1,6 @@ import torch from torch.nn import MSELoss +from torchmetrics import MeanSquaredError from qusi.internal.train_session import update_logging_metrics @@ -12,3 +13,15 @@ def test_update_logging_metrics_for_functional_metrics(): metric = MSELoss() update_logging_metrics(predicted_targets, targets, [metric], metric_totals) assert metric_totals == expected_metric_totals + + +def test_update_logging_metrics_for_state_based_torchmetrics(): + metric = MeanSquaredError() + value = metric(torch.tensor([1.]), torch.tensor([2.])) # Add some already stored metric value. + predicted_targets = torch.tensor([1.]) + targets = torch.tensor([3.]) + metric_totals = torch.tensor([0.]) + expected_computed_metric_value = torch.tensor(2.5) # Average of 1 and 4. + update_logging_metrics(predicted_targets, targets, [metric], metric_totals) + computed_metric_value = metric.compute() + assert computed_metric_value == expected_computed_metric_value