Skip to content

Commit ecec7a7

Browse files
committed
Use namedtuple to store curve results (ROC, PR)
1 parent fcf3598 commit ecec7a7

File tree

7 files changed

+507
-107
lines changed

7 files changed

+507
-107
lines changed

cyclops/evaluate/metrics/experimental/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,13 @@
5252
multilabel_tpr,
5353
)
5454
from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
55+
PRCurve,
5556
binary_precision_recall_curve,
5657
multiclass_precision_recall_curve,
5758
multilabel_precision_recall_curve,
5859
)
5960
from cyclops.evaluate.metrics.experimental.functional.roc import (
61+
ROCCurve,
6062
binary_roc,
6163
multiclass_roc,
6264
multilabel_roc,

cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Functions for computing the precision and recall for different unique thresholds."""
22
from types import ModuleType
3-
from typing import Any, List, Literal, Optional, Sequence, Tuple, Union
3+
from typing import Any, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
44

55
import array_api_compat as apc
66
import numpy as np
@@ -28,6 +28,14 @@
2828
)
2929

3030

31+
class PRCurve(NamedTuple):
32+
"""Named tuple with Precision-Recall curve (Precision, Recall and thresholds)."""
33+
34+
precision: Union[Array, List[Array]]
35+
recall: Union[Array, List[Array]]
36+
thresholds: Union[Array, List[Array]]
37+
38+
3139
def _validate_thresholds(thresholds: Optional[Union[int, List[float], Array]]) -> None:
3240
"""Validate the `thresholds` argument."""
3341
if thresholds is not None and not (
@@ -352,14 +360,13 @@ def binary_precision_recall_curve(
352360
353361
Returns
354362
-------
355-
precision : Array
356-
The precision values for all unique thresholds. The shape of the array is
363+
PRCurve
364+
A named tuple that contains the following elements:
365+
- `precision` values for all unique thresholds. The shape of the array is
357366
`(num_thresholds + 1,)`.
358-
recall : Array
359-
The recall values for all unique thresholds. The shape of the array is
367+
- `recall` values for all unique thresholds. The shape of the array is
360368
`(num_thresholds + 1,)`.
361-
thresholds : Array
362-
The thresholds used for computing the precision and recall values, in
369+
- `thresholds` used for computing the precision and recall values, in
363370
ascending order. The shape of the array is `(num_thresholds,)`.
364371
365372
Raises
@@ -688,7 +695,7 @@ def multiclass_precision_recall_curve(
688695
thresholds: Optional[Union[int, List[float], Array]] = None,
689696
average: Optional[Literal["macro", "micro", "none"]] = None,
690697
ignore_index: Optional[Union[int, Tuple[int]]] = None,
691-
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
698+
) -> PRCurve:
692699
"""Compute the precision and recall for all unique thresholds.
693700
694701
Parameters
@@ -730,18 +737,17 @@ def multiclass_precision_recall_curve(
730737
731738
Returns
732739
-------
733-
precision : Array or List[Array]
734-
The precision values for all unique thresholds. If `thresholds` is `"none"`
740+
PRCurve
741+
A named tuple that contains the following elements:
742+
- `precision` values for all unique thresholds. If `thresholds` is `"none"`
735743
or `None`, a list for each class is returned with 1-D Arrays of shape
736744
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
737745
`(num_thresholds + 1, num_classes)` is returned.
738-
recall : Array or List[Array]
739-
The recall values for all unique thresholds. If `thresholds` is `"none"`
746+
- `recall` values for all unique thresholds. If `thresholds` is `"none"`
740747
or `None`, a list for each class is returned with 1-D Arrays of shape
741748
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
742749
`(num_thresholds + 1, num_classes)` is returned.
743-
thresholds : Array or List[Array]
744-
The thresholds used for computing the precision and recall values, in
750+
- `thresholds` used for computing the precision and recall values, in
745751
ascending order. If `thresholds` is `"none"` or `None`, a list for each
746752
class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise,
747753
a 1-D Array of shape `(num_thresholds,)` is returned.
@@ -868,12 +874,13 @@ class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise,
868874
average,
869875
xp=xp,
870876
)
871-
return _multiclass_precision_recall_curve_compute(
877+
precision, recall, thresholds_ = _multiclass_precision_recall_curve_compute(
872878
state,
873879
num_classes,
874880
thresholds=thresholds,
875881
average=average,
876882
)
883+
return PRCurve(precision, recall, thresholds_)
877884

878885

879886
def _multilabel_precision_recall_curve_validate_args(
@@ -1035,7 +1042,7 @@ def multilabel_precision_recall_curve(
10351042
num_labels: int,
10361043
thresholds: Optional[Union[int, List[float], Array]] = None,
10371044
ignore_index: Optional[int] = None,
1038-
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
1045+
) -> PRCurve:
10391046
"""Compute the precision and recall for all unique thresholds.
10401047
10411048
Parameters
@@ -1067,21 +1074,20 @@ def multilabel_precision_recall_curve(
10671074
10681075
Returns
10691076
-------
1070-
precision : Array or List[Array]
1071-
The precision values for all unique thresholds. If `thresholds` is `None`,
1072-
a list for each label is returned with 1-D Arrays of shape
1077+
PRCurve
1078+
A named tuple that contains the following elements:
1079+
- `precision` values for all unique thresholds. If `thresholds` is `"none"`
1080+
or `None`, a list for each class is returned with 1-D Arrays of shape
10731081
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
1074-
`(num_thresholds + 1, num_labels)` is returned.
1075-
recall : Array or List[Array]
1076-
The recall values for all unique thresholds. If `thresholds` is `None`,
1077-
a list for each label is returned with 1-D Arrays of shape
1082+
`(num_thresholds + 1, num_classes)` is returned.
1083+
- `recall` values for all unique thresholds. If `thresholds` is `"none"`
1084+
or `None`, a list for each class is returned with 1-D Arrays of shape
10781085
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
1079-
`(num_thresholds + 1, num_labels)` is returned.
1080-
thresholds : Array or List[Array]
1081-
The thresholds used for computing the precision and recall values, in
1082-
ascending order. If `thresholds` is `None`, a list for each label is
1083-
returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D
1084-
Array of shape `(num_thresholds,)` is returned.
1086+
`(num_thresholds + 1, num_classes)` is returned.
1087+
- `thresholds` used for computing the precision and recall values, in
1088+
ascending order. If `thresholds` is `"none"` or `None`, a list for each
1089+
class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise,
1090+
a 1-D Array of shape `(num_thresholds,)` is returned.
10851091
10861092
Raises
10871093
------
@@ -1193,9 +1199,10 @@ def multilabel_precision_recall_curve(
11931199
thresholds,
11941200
xp=xp,
11951201
)
1196-
return _multilabel_precision_recall_curve_compute(
1202+
precision, recall, thresholds_ = _multilabel_precision_recall_curve_compute(
11971203
state,
11981204
num_labels,
11991205
thresholds,
12001206
ignore_index,
12011207
)
1208+
return PRCurve(precision, recall, thresholds_)

cyclops/evaluate/metrics/experimental/functional/roc.py

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Functions for computing Receiver Operating Characteristic (ROC) curves."""
22
import warnings
3-
from typing import List, Literal, Optional, Tuple, Union
3+
from typing import List, Literal, NamedTuple, Optional, Tuple, Union
44

55
import array_api_compat as apc
66

@@ -28,6 +28,14 @@
2828
from cyclops.evaluate.metrics.experimental.utils.types import Array
2929

3030

31+
class ROCCurve(NamedTuple):
32+
"""Named tuple to store ROC curve (FPR, TPR and thresholds)."""
33+
34+
fpr: Union[Array, List[Array]]
35+
tpr: Union[Array, List[Array]]
36+
thresholds: Union[Array, List[Array]]
37+
38+
3139
def _binary_roc_compute(
3240
state: Union[Array, Tuple[Array, Array]],
3341
thresholds: Optional[Array],
@@ -91,7 +99,7 @@ def binary_roc(
9199
preds: Array,
92100
thresholds: Optional[Union[int, List[float], Array]] = None,
93101
ignore_index: Optional[int] = None,
94-
) -> Tuple[Array, Array, Array]:
102+
) -> ROCCurve:
95103
"""Compute the receiver operating characteristic (ROC) curve for binary tasks.
96104
97105
Parameters
@@ -120,15 +128,11 @@ def binary_roc(
120128
121129
Returns
122130
-------
123-
fpr : Array
124-
The false positive rates for all unique thresholds. The shape of the array is
125-
`(num_thresholds + 1,)`.
126-
tpr : Array
127-
The true positive rates for all unique thresholds. The shape of the array is
128-
`(num_thresholds + 1,)`.
129-
thresholds : Array
130-
The thresholds used for computing the ROC curve values, in descending order.
131-
The shape of the array is `(num_thresholds,)`.
131+
ROCCurve
132+
A named tuple containing the false positive rate (FPR), true positive rate
133+
(TPR) and thresholds. The FPR and TPR are arrays of of shape
134+
`(num_thresholds + 1,)` and the thresholds are an array of shape
135+
`(num_thresholds,)`.
132136
133137
Raises
134138
------
@@ -209,7 +213,8 @@ def binary_roc(
209213
xp=xp,
210214
)
211215
state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp)
212-
return _binary_roc_compute(state, thresholds)
216+
fpr, tpr, thresh = _binary_roc_compute(state, thresholds)
217+
return ROCCurve(fpr, tpr, thresh)
213218

214219

215220
def _multiclass_roc_compute(
@@ -277,7 +282,7 @@ def multiclass_roc(
277282
thresholds: Optional[Union[int, List[float], Array]] = None,
278283
average: Optional[Literal["macro", "micro", "none"]] = None,
279284
ignore_index: Optional[Union[int, Tuple[int]]] = None,
280-
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
285+
) -> ROCCurve:
281286
"""Compute the receiver operating characteristic (ROC) curve for multiclass tasks.
282287
283288
Parameters
@@ -318,19 +323,13 @@ def multiclass_roc(
318323
319324
Returns
320325
-------
321-
fpr : Array or List[Array]
322-
The false positive rates for all unique thresholds. If `thresholds` is `"none"`
323-
or `None`, a list for each class is returned with 1-D Arrays of shape
324-
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
326+
ROCCurve
327+
A named tuple that contains the false positive rate, true positive rate,
328+
and the thresholds used for computing the ROC curve. If `thresholds` is `"none"`
329+
or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays
330+
of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
325331
`(num_thresholds + 1, num_classes)` is returned.
326-
tpr : Array or List[Array]
327-
The true positive rates for all unique thresholds. If `thresholds` is `"none"`
328-
or `None`, a list for each class is returned with 1-D Arrays of shape
329-
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
330-
`(num_thresholds + 1, num_classes)` is returned.
331-
thresholds : Array or List[Array]
332-
The thresholds used for computing the ROC curve values, in descending order.
333-
If `thresholds` is `"none"` or `None`, a list for each class is returned
332+
Similarly, a list of thresholds for each class is returned
334333
with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of
335334
shape `(num_thresholds,)` is returned.
336335
@@ -455,12 +454,13 @@ def multiclass_roc(
455454
average,
456455
xp=xp,
457456
)
458-
return _multiclass_roc_compute(
457+
fpr_, tpr_, thresholds_ = _multiclass_roc_compute(
459458
state,
460459
num_classes,
461460
thresholds=thresholds,
462461
average=average,
463462
)
463+
return ROCCurve(fpr_, tpr_, thresholds_)
464464

465465

466466
def _multilabel_roc_compute(
@@ -504,7 +504,7 @@ def multilabel_roc(
504504
num_labels: int,
505505
thresholds: Optional[Union[int, List[float], Array]] = None,
506506
ignore_index: Optional[int] = None,
507-
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
507+
) -> ROCCurve:
508508
"""Compute the receiver operating characteristic (ROC) curve for multilabel tasks.
509509
510510
Parameters
@@ -535,21 +535,15 @@ def multilabel_roc(
535535
536536
Returns
537537
-------
538-
fpr : Array or List[Array]
539-
The false positive rates for all unique thresholds. If `thresholds` is `None`,
540-
a list for each label is returned with 1-D Arrays of shape
541-
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
542-
`(num_thresholds + 1, num_labels)` is returned.
543-
tpr : Array or List[Array]
544-
The true positive rates for all unique thresholds. If `thresholds` is `None`,
545-
a list for each label is returned with 1-D Arrays of shape
546-
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
547-
`(num_thresholds + 1, num_labels)` is returned.
548-
thresholds : Array or List[Array]
549-
The thresholds used for computing the ROC curve values, in
550-
descending order. If `thresholds` is `None`, a list for each label is
551-
returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D
552-
Array of shape `(num_thresholds,)` is returned.
538+
ROCCurve
539+
A named tuple that contains the false positive rate, true positive rate,
540+
and the thresholds used for computing the ROC curve. If `thresholds` is `"none"`
541+
or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays
542+
of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
543+
`(num_thresholds + 1, num_classes)` is returned.
544+
Similarly, a list of thresholds for each class is returned
545+
with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of
546+
shape `(num_thresholds,)` is returned.
553547
554548
Raises
555549
------
@@ -660,9 +654,10 @@ def multilabel_roc(
660654
thresholds,
661655
xp=xp,
662656
)
663-
return _multilabel_roc_compute(
657+
fpr_, tpr_, thresholds_ = _multilabel_roc_compute(
664658
state,
665659
num_labels,
666660
thresholds,
667661
ignore_index,
668662
)
663+
return ROCCurve(fpr_, tpr_, thresholds_)

cyclops/evaluate/metrics/experimental/precision_recall_curve.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import array_api_compat as apc
66

77
from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
8+
PRCurve,
89
_binary_precision_recall_curve_compute,
910
_binary_precision_recall_curve_format_arrays,
1011
_binary_precision_recall_curve_update,
@@ -140,14 +141,18 @@ def _update_state(self, target: Array, preds: Array) -> None:
140141
self.target.append(state[0]) # type: ignore[attr-defined]
141142
self.preds.append(state[1]) # type: ignore[attr-defined]
142143

143-
def _compute_metric(self) -> Tuple[Array, Array, Array]:
144+
def _compute_metric(self) -> PRCurve:
144145
"""Compute the metric."""
145146
state = (
146147
(dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined]
147148
if self.thresholds is None
148149
else self.confmat # type: ignore[attr-defined]
149150
)
150-
return _binary_precision_recall_curve_compute(state, self.thresholds) # type: ignore[arg-type]
151+
precision, recall, thresholds = _binary_precision_recall_curve_compute(
152+
state,
153+
self.thresholds, # type: ignore
154+
)
155+
return PRCurve(precision, recall, thresholds)
151156

152157

153158
class MulticlassPrecisionRecallCurve(

cyclops/evaluate/metrics/experimental/roc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List, Tuple, Union
33

44
from cyclops.evaluate.metrics.experimental.functional.roc import (
5+
ROCCurve,
56
_binary_roc_compute,
67
_multiclass_roc_compute,
78
_multilabel_roc_compute,
@@ -55,13 +56,14 @@ class BinaryROC(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"):
5556

5657
name: str = "ROC Curve"
5758

58-
def _compute_metric(self) -> Tuple[Array, Array, Array]:
59+
def _compute_metric(self) -> ROCCurve: # type: ignore
5960
state = (
6061
(dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined]
6162
if self.thresholds is None
6263
else self.confmat # type: ignore[attr-defined]
6364
)
64-
return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type]
65+
fpr, tpr, thresholds = _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type]
66+
return ROCCurve(fpr, tpr, thresholds)
6567

6668

6769
class MulticlassROC(

0 commit comments

Comments
 (0)