1
1
"""Test precision-recall curve metric."""
2
2
from functools import partial
3
- from types import ModuleType
4
3
from typing import List , Tuple , Union
5
4
6
5
import array_api_compat as apc
32
31
from cyclops .evaluate .metrics .experimental .utils .validation import is_floating_point
33
32
34
33
from ..conftest import NUM_CLASSES , NUM_LABELS
35
- from .inputs import _binary_cases , _multiclass_cases , _multilabel_cases
34
+ from .inputs import _binary_cases , _multiclass_cases , _multilabel_cases , _thresholds
36
35
from .testers import MetricTester , _inject_ignore_index
37
36
38
37
39
- def _thresholds_for_prc (* , xp : ModuleType ) -> list :
40
- """Return thresholds for precision-recall curve."""
41
- thresh_list = [0.0 , 0.3 , 0.5 , 0.7 , 0.9 , 1.0 ]
42
- return [None , 5 , thresh_list , xp .asarray (thresh_list )]
43
-
44
-
45
38
def _binary_precision_recall_curve_reference (
46
39
target ,
47
40
preds ,
@@ -63,7 +56,7 @@ class TestBinaryPrecisionRecallCurve(MetricTester):
63
56
"""Test binary precision-recall curve function and class."""
64
57
65
58
@pytest .mark .parametrize ("inputs" , _binary_cases (xp = anp )[3 :])
66
- @pytest .mark .parametrize ("thresholds" , _thresholds_for_prc (xp = anp ))
59
+ @pytest .mark .parametrize ("thresholds" , _thresholds (xp = anp ))
67
60
@pytest .mark .parametrize ("ignore_index" , [None , 0 , - 1 ])
68
61
def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays (
69
62
self ,
@@ -99,7 +92,7 @@ def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays(
99
92
)
100
93
101
94
@pytest .mark .parametrize ("inputs" , _binary_cases (xp = anp )[3 :])
102
- @pytest .mark .parametrize ("thresholds" , _thresholds_for_prc (xp = anp ))
95
+ @pytest .mark .parametrize ("thresholds" , _thresholds (xp = anp ))
103
96
@pytest .mark .parametrize ("ignore_index" , [None , 0 , - 1 ])
104
97
def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays (
105
98
self ,
@@ -149,7 +142,7 @@ def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays(
149
142
@pytest .mark .parametrize ("inputs" , _binary_cases (xp = array_api_compat .torch )[3 :])
150
143
@pytest .mark .parametrize (
151
144
"thresholds" ,
152
- _thresholds_for_prc (xp = array_api_compat .torch ),
145
+ _thresholds (xp = array_api_compat .torch ),
153
146
)
154
147
@pytest .mark .parametrize ("ignore_index" , [None , 0 , - 1 ])
155
148
def test_binary_precision_recall_curve_with_torch_tensors (
@@ -233,7 +226,7 @@ class TestMulticlassPrecisionRecallCurve(MetricTester):
233
226
"""Test multiclass precision-recall curve function and class."""
234
227
235
228
@pytest .mark .parametrize ("inputs" , _multiclass_cases (xp = anp )[4 :])
236
- @pytest .mark .parametrize ("thresholds" , _thresholds_for_prc (xp = anp ))
229
+ @pytest .mark .parametrize ("thresholds" , _thresholds (xp = anp ))
237
230
@pytest .mark .parametrize ("average" , [None , "none" ])
238
231
@pytest .mark .parametrize ("ignore_index" , [None , 0 , - 1 ])
239
232
def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays (
@@ -273,7 +266,7 @@ def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays(
273
266
)
274
267
275
268
@pytest .mark .parametrize ("inputs" , _multiclass_cases (xp = anp )[4 :])
276
- @pytest .mark .parametrize ("thresholds" , _thresholds_for_prc (xp = anp ))
269
+ @pytest .mark .parametrize ("thresholds" , _thresholds (xp = anp ))
277
270
@pytest .mark .parametrize ("average" , [None , "none" ])
278
271
@pytest .mark .parametrize ("ignore_index" , [None , 1 , - 1 ])
279
272
def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays (
@@ -316,7 +309,7 @@ def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays(
316
309
@pytest .mark .parametrize ("inputs" , _multiclass_cases (xp = array_api_compat .torch )[4 :])
317
310
@pytest .mark .parametrize (
318
311
"thresholds" ,
319
- _thresholds_for_prc (xp = array_api_compat .torch ),
312
+ _thresholds (xp = array_api_compat .torch ),
320
313
)
321
314
@pytest .mark .parametrize ("average" , [None , "none" ])
322
315
@pytest .mark .parametrize ("ignore_index" , [None , 1 , - 1 ])
@@ -389,7 +382,7 @@ class TestMultilabelPrecisionRecallCurve(MetricTester):
389
382
"""Test multilabel precision-recall curve function and class."""
390
383
391
384
@pytest .mark .parametrize ("inputs" , _multilabel_cases (xp = anp )[2 :])
392
- @pytest .mark .parametrize ("thresholds" , _thresholds_for_prc (xp = anp ))
385
+ @pytest .mark .parametrize ("thresholds" , _thresholds (xp = anp ))
393
386
@pytest .mark .parametrize ("ignore_index" , [None , 0 , - 1 ])
394
387
def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays (
395
388
self ,
@@ -420,7 +413,7 @@ def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays(
420
413
)
421
414
422
415
@pytest .mark .parametrize ("inputs" , _multilabel_cases (xp = anp )[2 :])
423
- @pytest .mark .parametrize ("thresholds" , _thresholds_for_prc (xp = anp ))
416
+ @pytest .mark .parametrize ("thresholds" , _thresholds (xp = anp ))
424
417
@pytest .mark .parametrize ("ignore_index" , [None , 0 , - 1 ])
425
418
def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays (
426
419
self ,
@@ -454,7 +447,7 @@ def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays(
454
447
@pytest .mark .parametrize ("inputs" , _multilabel_cases (xp = array_api_compat .torch )[2 :])
455
448
@pytest .mark .parametrize (
456
449
"thresholds" ,
457
- _thresholds_for_prc (xp = array_api_compat .torch ),
450
+ _thresholds (xp = array_api_compat .torch ),
458
451
)
459
452
@pytest .mark .parametrize ("ignore_index" , [None , 0 , - 1 ])
460
453
def test_multilabel_precision_recall_curve_class_with_torch_tensors (
0 commit comments