Add tests
gbg141 committed Dec 6, 2024
1 parent e34d5b1 commit 9271ec4
Showing 3 changed files with 42 additions and 7 deletions.
38 changes: 33 additions & 5 deletions test/evaluator/test_evaluator.py
@@ -1,15 +1,43 @@
""" Test the TBEvaluator class."""
import pytest

import torch
from topobenchmark.evaluator import TBEvaluator

class TestTBEvaluator:
""" Test the TBXEvaluator class."""

def setup_method(self):
""" Setup the test."""
self.evaluator_multilable = TBEvaluator(task="multilabel classification")
self.evaluator_regression = TBEvaluator(task="regression")
self.classification_metrics = ["accuracy", "precision", "recall", "auroc"]
self.evaluator_classification = TBEvaluator(task="classification", num_classes=3, metrics=self.classification_metrics)
self.evaluator_multilabel = TBEvaluator(task="multilabel classification", num_classes=2, metrics=self.classification_metrics)
self.regression_metrics = ["example", "mae"]
self.evaluator_regression = TBEvaluator(task="regression", num_classes=1, metrics=self.regression_metrics)
with pytest.raises(ValueError):
TBEvaluator(task="wrong")
repr = self.evaluator_multilable.__repr__()
TBEvaluator(task="wrong", num_classes=2, metrics=self.classification_metrics)

def test_repr(self):
"""Test the __repr__ method."""
assert "TBEvaluator" in self.evaluator_classification.__repr__()
assert "TBEvaluator" in self.evaluator_multilabel.__repr__()
assert "TBEvaluator" in self.evaluator_regression.__repr__()

def test_update_and_compute(self):
"""Test the update and compute methods."""
self.evaluator_classification.update({"logits": torch.randn(10, 3), "labels": torch.randint(0, 3, (10,))})
out = self.evaluator_classification.compute()
for metric in self.classification_metrics:
assert metric in out
self.evaluator_multilabel.update({"logits": torch.randn(10, 2), "labels": torch.randint(0, 2, (10, 2))})
out = self.evaluator_multilabel.compute()
for metric in self.classification_metrics:
assert metric in out
self.evaluator_regression.update({"logits": torch.randn(10, 1), "labels": torch.randn(10,)})
out = self.evaluator_regression.compute()
for metric in self.regression_metrics:
assert metric in out

def test_reset(self):
"""Test the reset method."""
self.evaluator_multilabel.reset()
self.evaluator_regression.reset()
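Taken together, the new tests exercise the full evaluator lifecycle: construct, update with a batch, compute, reset. Outside the test harness the same flow looks like the sketch below, which assumes only the TBEvaluator keyword arguments already used in setup_method:

    import torch
    from topobenchmark.evaluator import TBEvaluator

    # Build an evaluator for a 3-class task (mirrors setup_method above).
    evaluator = TBEvaluator(
        task="classification",
        num_classes=3,
        metrics=["accuracy", "precision", "recall", "auroc"],
    )

    # Accumulate one batch of raw logits and integer labels.
    evaluator.update({
        "logits": torch.randn(32, 3),
        "labels": torch.randint(0, 3, (32,)),
    })

    # Aggregate everything seen since the last reset, then clear state
    # before the next epoch or data split.
    scores = evaluator.compute()  # dict keyed by metric name, per test_update_and_compute
    evaluator.reset()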
3 changes: 3 additions & 0 deletions test/utils/test_config_resolvers.py
@@ -117,6 +117,9 @@ def test_infer_num_cell_dimensions(self):

     def test_get_default_metrics(self):
         """Test get_default_metrics."""
+        out = get_default_metrics("classification", ["accuracy", "precision"])
+        assert out == ["accuracy", "precision"]
+
         out = get_default_metrics("classification")
         assert out == ["accuracy", "precision", "recall", "auroc"]
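The two assertions pin down both branches of get_default_metrics: explicitly supplied metrics pass through unchanged, and omitting them falls back to per-task defaults. Logic consistent with these tests would look like the sketch below; the real resolver imported by this test file may differ, and the regression defaults here are an assumption the tests do not cover:

    def get_default_metrics(task, metrics=None):
        """Return `metrics` if given, otherwise the defaults for `task`."""
        if metrics is not None:
            return metrics
        if "classification" in task:
            return ["accuracy", "precision", "recall", "auroc"]
        if "regression" in task:
            return ["mse", "mae"]  # assumed defaults; not pinned by the tests above
        raise ValueError(f"Invalid task {task}")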

8 changes: 6 additions & 2 deletions topobenchmark/evaluator/evaluator.py
@@ -37,14 +37,15 @@ def __init__(self, task, **kwargs):
elif self.task == "multilabel classification":
parameters = {"num_classes": kwargs["num_classes"]}
parameters["task"] = "multilabel"
parameters["num_labels"] = kwargs["num_classes"]
metric_names = kwargs["metrics"]

elif self.task == "regression":
parameters = {}
metric_names = kwargs["metrics"]

else:
raise ValueError(f"Invalid task {kwargs['task']}")
raise ValueError(f"Invalid task {task}")

metrics = {}
for name in metric_names:
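Two fixes land in this hunk: the multilabel branch now also sets num_labels, and the error message reads the positional task argument instead of kwargs['task'], which is never passed as a keyword. The num_labels line matters because torchmetrics keys multilabel metrics on num_labels rather than num_classes. A sketch of how such a parameters dict typically instantiates metrics, assuming a torchmetrics backend (the actual wiring in evaluator.py may differ):

    import torch
    from torchmetrics import AUROC, Accuracy, MetricCollection

    # Illustration only: for num_classes=2 the branch above yields a dict
    # like {"task": "multilabel", "num_labels": 2}; num_classes is dropped
    # here for the sketch.
    parameters = {"task": "multilabel", "num_labels": 2}

    metrics = MetricCollection({
        "accuracy": Accuracy(**parameters),
        "auroc": AUROC(**parameters),
    })

    preds = torch.rand(10, 2)              # per-label scores in [0, 1]
    target = torch.randint(0, 2, (10, 2))  # binary label matrix
    metrics.update(preds, target)
    print(metrics.compute())               # {"accuracy": tensor(...), "auroc": tensor(...)}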
@@ -83,7 +84,10 @@ def update(self, model_out: dict):
if self.task == "regression":
self.metrics.update(preds, target.unsqueeze(1))

elif self.task == "classification":
elif (
self.task == "classification"
or self.task == "multilabel classification"
):
self.metrics.update(preds, target)

else:
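Before this change, a "multilabel classification" evaluator would have fallen through to the else branch on every update, so the widened elif is what lets the new multilabel test pass. The regression branch keeps its unsqueeze(1), which lifts a flat (N,) target to match (N, 1) predictions; a minimal shape check, assuming the tensor shapes used in the tests above:

    import torch

    logits = torch.randn(10, 1)  # regression predictions, shape (N, 1)
    labels = torch.randn(10)     # flat targets, shape (N,)

    # Metric backends compare predictions and targets elementwise, so the
    # target needs the same trailing dimension as the predictions.
    assert labels.unsqueeze(1).shape == logits.shape  # (10, 1) == (10, 1)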
