diff --git a/test/evaluator/test_evaluator.py b/test/evaluator/test_evaluator.py
index e09eb579..eecc59a0
--- a/test/evaluator/test_evaluator.py
+++ b/test/evaluator/test_evaluator.py
@@ -1,6 +1,6 @@
 """ Test the TBEvaluator class."""
 import pytest
-
+import torch
 from topobenchmark.evaluator import TBEvaluator
 
 class TestTBEvaluator:
@@ -8,8 +8,36 @@ class TestTBEvaluator:
 
     def setup_method(self):
         """ Setup the test."""
-        self.evaluator_multilable = TBEvaluator(task="multilabel classification")
-        self.evaluator_regression = TBEvaluator(task="regression")
+        self.classification_metrics = ["accuracy", "precision", "recall", "auroc"]
+        self.evaluator_classification = TBEvaluator(task="classification", num_classes=3, metrics=self.classification_metrics)
+        self.evaluator_multilabel = TBEvaluator(task="multilabel classification", num_classes=2, metrics=self.classification_metrics)
+        self.regression_metrics = ["example", "mae"]
+        self.evaluator_regression = TBEvaluator(task="regression", num_classes=1, metrics=self.regression_metrics)
         with pytest.raises(ValueError):
-            TBEvaluator(task="wrong")
-        repr = self.evaluator_multilable.__repr__()
\ No newline at end of file
+            TBEvaluator(task="wrong", num_classes=2, metrics=self.classification_metrics)
+
+    def test_repr(self):
+        """Test the __repr__ method."""
+        assert "TBEvaluator" in self.evaluator_classification.__repr__()
+        assert "TBEvaluator" in self.evaluator_multilabel.__repr__()
+        assert "TBEvaluator" in self.evaluator_regression.__repr__()
+
+    def test_update_and_compute(self):
+        """Test the update and compute methods."""
+        self.evaluator_classification.update({"logits": torch.randn(10, 3), "labels": torch.randint(0, 3, (10,))})
+        out = self.evaluator_classification.compute()
+        for metric in self.classification_metrics:
+            assert metric in out
+        self.evaluator_multilabel.update({"logits": torch.randn(10, 2), "labels": torch.randint(0, 2, (10, 2))})
+        out = self.evaluator_multilabel.compute()
+        for metric in self.classification_metrics:
+            assert metric in out
+        self.evaluator_regression.update({"logits": torch.randn(10, 1), "labels": torch.randn(10,)})
+        out = self.evaluator_regression.compute()
+        for metric in self.regression_metrics:
+            assert metric in out
+
+    def test_reset(self):
+        """Test the reset method."""
+        self.evaluator_multilabel.reset()
+        self.evaluator_regression.reset()
diff --git a/test/utils/test_config_resolvers.py b/test/utils/test_config_resolvers.py
index 9137de1a..6da4697f 100644
--- a/test/utils/test_config_resolvers.py
+++ b/test/utils/test_config_resolvers.py
@@ -117,6 +117,9 @@ def test_infer_num_cell_dimensions(self):
 
     def test_get_default_metrics(self):
         """Test get_default_metrics."""
+        out = get_default_metrics("classification", ["accuracy", "precision"])
+        assert out == ["accuracy", "precision"]
+
         out = get_default_metrics("classification")
         assert out == ["accuracy", "precision", "recall", "auroc"]
 
diff --git a/topobenchmark/evaluator/evaluator.py b/topobenchmark/evaluator/evaluator.py
index 8206f87e..c091ca62 100755
--- a/topobenchmark/evaluator/evaluator.py
+++ b/topobenchmark/evaluator/evaluator.py
@@ -37,6 +37,7 @@ def __init__(self, task, **kwargs):
         elif self.task == "multilabel classification":
             parameters = {"num_classes": kwargs["num_classes"]}
             parameters["task"] = "multilabel"
+            parameters["num_labels"] = kwargs["num_classes"]
             metric_names = kwargs["metrics"]
 
         elif self.task == "regression":
@@ -44,7 +45,7 @@ def __init__(self, task, **kwargs):
             metric_names = kwargs["metrics"]
 
         else:
-            raise ValueError(f"Invalid task {kwargs['task']}")
+            raise ValueError(f"Invalid task {task}")
 
         metrics = {}
         for name in metric_names:
@@ -83,7 +84,10 @@ def update(self, model_out: dict):
 
         if self.task == "regression":
             self.metrics.update(preds, target.unsqueeze(1))
-        elif self.task == "classification":
+        elif (
+            self.task == "classification"
+            or self.task == "multilabel classification"
+        ):
             self.metrics.update(preds, target)
         else:
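
For reviewers, a minimal usage sketch of the evaluator API these tests exercise. It mirrors the calls in test_evaluator.py above; the shapes and metric names are taken from the test setup, not from library defaults, so treat them as illustrative only.

# Sketch, assuming the constructor signature shown in the diff:
# TBEvaluator(task=..., num_classes=..., metrics=[...]).
import torch

from topobenchmark.evaluator import TBEvaluator

evaluator = TBEvaluator(
    task="classification",
    num_classes=3,
    metrics=["accuracy", "precision", "recall", "auroc"],
)

# One evaluation step: logits of shape (batch, num_classes), integer labels.
evaluator.update({"logits": torch.randn(10, 3), "labels": torch.randint(0, 3, (10,))})
results = evaluator.compute()  # dict keyed by metric name, e.g. results["accuracy"]
evaluator.reset()              # clear accumulated state before the next evaluation loop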