Skip to content

Commit 244d164

Browse files
committed
Refactor tests to use a common thresholds list
1 parent 6e57b47 commit 244d164

File tree

2 files changed

+20
-34
lines changed

2 files changed

+20
-34
lines changed

tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Test precision-recall curve metric."""
22
from functools import partial
3-
from types import ModuleType
43
from typing import List, Tuple, Union
54

65
import array_api_compat as apc
@@ -32,16 +31,10 @@
3231
from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point
3332

3433
from ..conftest import NUM_CLASSES, NUM_LABELS
35-
from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases
34+
from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds
3635
from .testers import MetricTester, _inject_ignore_index
3736

3837

39-
def _thresholds_for_prc(*, xp: ModuleType) -> list:
40-
"""Return thresholds for precision-recall curve."""
41-
thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]
42-
return [None, 5, thresh_list, xp.asarray(thresh_list)]
43-
44-
4538
def _binary_precision_recall_curve_reference(
4639
target,
4740
preds,
@@ -63,7 +56,7 @@ class TestBinaryPrecisionRecallCurve(MetricTester):
6356
"""Test binary precision-recall curve function and class."""
6457

6558
@pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:])
66-
@pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp))
59+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
6760
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
6861
def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays(
6962
self,
@@ -99,7 +92,7 @@ def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays(
9992
)
10093

10194
@pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:])
102-
@pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp))
95+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
10396
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
10497
def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays(
10598
self,
@@ -149,7 +142,7 @@ def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays(
149142
@pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:])
150143
@pytest.mark.parametrize(
151144
"thresholds",
152-
_thresholds_for_prc(xp=array_api_compat.torch),
145+
_thresholds(xp=array_api_compat.torch),
153146
)
154147
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
155148
def test_binary_precision_recall_curve_with_torch_tensors(
@@ -233,7 +226,7 @@ class TestMulticlassPrecisionRecallCurve(MetricTester):
233226
"""Test multiclass precision-recall curve function and class."""
234227

235228
@pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:])
236-
@pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp))
229+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
237230
@pytest.mark.parametrize("average", [None, "none"])
238231
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
239232
def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays(
@@ -273,7 +266,7 @@ def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays(
273266
)
274267

275268
@pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:])
276-
@pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp))
269+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
277270
@pytest.mark.parametrize("average", [None, "none"])
278271
@pytest.mark.parametrize("ignore_index", [None, 1, -1])
279272
def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays(
@@ -316,7 +309,7 @@ def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays(
316309
@pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:])
317310
@pytest.mark.parametrize(
318311
"thresholds",
319-
_thresholds_for_prc(xp=array_api_compat.torch),
312+
_thresholds(xp=array_api_compat.torch),
320313
)
321314
@pytest.mark.parametrize("average", [None, "none"])
322315
@pytest.mark.parametrize("ignore_index", [None, 1, -1])
@@ -389,7 +382,7 @@ class TestMultilabelPrecisionRecallCurve(MetricTester):
389382
"""Test multilabel precision-recall curve function and class."""
390383

391384
@pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:])
392-
@pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp))
385+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
393386
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
394387
def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays(
395388
self,
@@ -420,7 +413,7 @@ def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays(
420413
)
421414

422415
@pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:])
423-
@pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp))
416+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
424417
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
425418
def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays(
426419
self,
@@ -454,7 +447,7 @@ def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays(
454447
@pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:])
455448
@pytest.mark.parametrize(
456449
"thresholds",
457-
_thresholds_for_prc(xp=array_api_compat.torch),
450+
_thresholds(xp=array_api_compat.torch),
458451
)
459452
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
460453
def test_multilabel_precision_recall_curve_class_with_torch_tensors(

tests/cyclops/evaluate/metrics/experimental/test_roc.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Test roc curve metric."""
22
from functools import partial
3-
from types import ModuleType
43
from typing import List, Tuple, Union
54

65
import array_api_compat as apc
@@ -32,16 +31,10 @@
3231
from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point
3332

3433
from ..conftest import NUM_CLASSES, NUM_LABELS
35-
from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases
34+
from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds
3635
from .testers import MetricTester, _inject_ignore_index
3736

3837

39-
def _thresholds_for_roc(*, xp: ModuleType) -> list:
40-
"""Return thresholds for roc curve."""
41-
thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]
42-
return [None, 5, thresh_list, xp.asarray(thresh_list)]
43-
44-
4538
def _binary_roc_reference(
4639
target,
4740
preds,
@@ -63,7 +56,7 @@ class TestBinaryROC(MetricTester):
6356
"""Test binary roc curve function and class."""
6457

6558
@pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:])
66-
@pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp))
59+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
6760
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
6861
def test_binary_roc_function_with_numpy_array_api_arrays(
6962
self,
@@ -99,7 +92,7 @@ def test_binary_roc_function_with_numpy_array_api_arrays(
9992
)
10093

10194
@pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:])
102-
@pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp))
95+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
10396
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
10497
def test_binary_roc_class_with_numpy_array_api_arrays(
10598
self,
@@ -149,7 +142,7 @@ def test_binary_roc_class_with_numpy_array_api_arrays(
149142
@pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:])
150143
@pytest.mark.parametrize(
151144
"thresholds",
152-
_thresholds_for_roc(xp=array_api_compat.torch),
145+
_thresholds(xp=array_api_compat.torch),
153146
)
154147
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
155148
def test_binary_roc_with_torch_tensors(
@@ -233,7 +226,7 @@ class TestMulticlassROC(MetricTester):
233226
"""Test multiclass roc curve function and class."""
234227

235228
@pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:])
236-
@pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp))
229+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
237230
@pytest.mark.parametrize("average", [None, "none"])
238231
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
239232
def test_multiclass_roc_with_numpy_array_api_arrays(
@@ -273,7 +266,7 @@ def test_multiclass_roc_with_numpy_array_api_arrays(
273266
)
274267

275268
@pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:])
276-
@pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp))
269+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
277270
@pytest.mark.parametrize("average", [None, "none"])
278271
@pytest.mark.parametrize("ignore_index", [None, 1, -1])
279272
def test_multiclass_roc_class_with_numpy_array_api_arrays(
@@ -316,7 +309,7 @@ def test_multiclass_roc_class_with_numpy_array_api_arrays(
316309
@pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:])
317310
@pytest.mark.parametrize(
318311
"thresholds",
319-
_thresholds_for_roc(xp=array_api_compat.torch),
312+
_thresholds(xp=array_api_compat.torch),
320313
)
321314
@pytest.mark.parametrize("average", [None, "none"])
322315
@pytest.mark.parametrize("ignore_index", [None, 1, -1])
@@ -389,7 +382,7 @@ class TestMultilabelROC(MetricTester):
389382
"""Test multilabel roc curve function and class."""
390383

391384
@pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:])
392-
@pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp))
385+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
393386
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
394387
def test_multilabel_roc_with_numpy_array_api_arrays(
395388
self,
@@ -420,7 +413,7 @@ def test_multilabel_roc_with_numpy_array_api_arrays(
420413
)
421414

422415
@pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:])
423-
@pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp))
416+
@pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
424417
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
425418
def test_multilabel_roc_class_with_numpy_array_api_arrays(
426419
self,
@@ -454,7 +447,7 @@ def test_multilabel_roc_class_with_numpy_array_api_arrays(
454447
@pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:])
455448
@pytest.mark.parametrize(
456449
"thresholds",
457-
_thresholds_for_roc(xp=array_api_compat.torch),
450+
_thresholds(xp=array_api_compat.torch),
458451
)
459452
@pytest.mark.parametrize("ignore_index", [None, 0, -1])
460453
def test_multilabel_roc_class_with_torch_tensors(

0 commit comments

Comments
 (0)