Skip to content

Commit 838bc96

Browse files
amrit110 and fcogidi
authored
Use namedtuple to store curve results (ROC, PR) for non-experimental … (#574)
* Use namedtuple to store curve results (ROC, PR) for non-experimental metrics * Remove commented out doc example * handle named tuples in `_apply_function_recursively` and fix doctests --------- Co-authored-by: Franklin <41602287+fcogidi@users.noreply.github.com>
1 parent c056d71 commit 838bc96

File tree

7 files changed

+197
-250
lines changed

7 files changed

+197
-250
lines changed

cyclops/evaluate/metrics/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@
3737
recall,
3838
)
3939
from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # noqa: F401
40+
PRCurve,
4041
binary_precision_recall_curve,
4142
multiclass_precision_recall_curve,
4243
multilabel_precision_recall_curve,
4344
precision_recall_curve,
4445
)
4546
from cyclops.evaluate.metrics.functional.roc import ( # noqa: F401
47+
ROCCurve,
4648
binary_roc_curve,
4749
multiclass_roc_curve,
4850
multilabel_roc_curve,

cyclops/evaluate/metrics/functional/precision_recall_curve.py

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Functions for computing the precision-recall curve for different input types."""
22

3-
from typing import Any, List, Literal, Optional, Tuple, Union
3+
from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union
44

55
import numpy as np
66
import numpy.typing as npt
@@ -15,6 +15,14 @@
1515
)
1616

1717

18+
class PRCurve(NamedTuple):
19+
"""Named tuple with Precision-Recall curve (Precision, Recall and thresholds)."""
20+
21+
precision: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]
22+
recall: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]
23+
thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]
24+
25+
1826
def _format_thresholds(
1927
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
2028
) -> Optional[npt.NDArray[np.float_]]:
@@ -279,7 +287,7 @@ def binary_precision_recall_curve(
279287
preds: npt.ArrayLike,
280288
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
281289
pos_label: int = 1,
282-
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
290+
) -> PRCurve:
283291
"""Compute precision-recall curve for binary input.
284292
285293
Parameters
@@ -301,13 +309,10 @@ def binary_precision_recall_curve(
301309
302310
Returns
303311
-------
304-
precision : numpy.ndarray
305-
Precision scores such that element i is the precision of predictions
306-
with score >= thresholds[i].
307-
recall : numpy.ndarray
308-
Recall scores in descending order.
309-
thresholds : numpy.ndarray
310-
Thresholds used for computing the precision and recall scores.
312+
PRCurve
313+
A named tuple containing the precision (element i is the precision of predictions
314+
with score >= thresholds[i]), recall (scores in descending order)
315+
and thresholds used to compute the precision-recall curve.
311316
312317
Examples
313318
--------
@@ -335,13 +340,14 @@ def binary_precision_recall_curve(
335340
thresholds = _format_thresholds(thresholds)
336341

337342
state = _binary_precision_recall_curve_update(target, preds, thresholds)
338-
339-
return _binary_precision_recall_curve_compute(
343+
precision_, recall_, thresholds_ = _binary_precision_recall_curve_compute(
340344
state,
341345
thresholds,
342346
pos_label=pos_label,
343347
)
344348

349+
return PRCurve(precision_, recall_, thresholds_)
350+
345351

346352
def _multiclass_precision_recall_curve_format(
347353
target: npt.ArrayLike,
@@ -572,14 +578,7 @@ def multiclass_precision_recall_curve(
572578
preds: npt.ArrayLike,
573579
num_classes: int,
574580
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
575-
) -> Union[
576-
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
577-
Tuple[
578-
List[npt.NDArray[np.float_]],
579-
List[npt.NDArray[np.float_]],
580-
List[npt.NDArray[np.float_]],
581-
],
582-
]:
581+
) -> PRCurve:
583582
"""Compute the precision-recall curve for multiclass problems.
584583
585584
Parameters
@@ -600,18 +599,13 @@ def multiclass_precision_recall_curve(
600599
601600
Returns
602601
-------
603-
precision : numpy.ndarray or list of numpy.ndarray
604-
Precision scores where element i is the precision score corresponding
605-
to the threshold i. If state is a tuple of the target and predicted
606-
probabilities, then precision is a list of arrays, where each array
607-
corresponds to the precision scores for a class.
608-
recall : numpy.ndarray or list of numpy.ndarray
609-
Recall scores where element i is the recall score corresponding to
610-
the threshold i. If state is a tuple of the target and predicted
611-
probabilities, then recall is a list of arrays, where each array
612-
corresponds to the recall scores for a class.
613-
thresholds : numpy.ndarray or list of numpy.ndarray
614-
Thresholds used for computing the precision and recall scores.
602+
PRCurve
603+
A named tuple containing the precision, recall, and thresholds.
604+
Precision and recall are arrays where element i is the precision and
605+
recall score corresponding to threshold i. If state is a tuple of the
606+
target and predicted probabilities, then precision and recall are lists
607+
of arrays, where each array corresponds to the precision and recall
608+
scores for a class.
615609
616610
Examples
617611
--------
@@ -652,11 +646,12 @@ def multiclass_precision_recall_curve(
652646
thresholds=thresholds,
653647
)
654648

655-
return _multiclass_precision_recall_curve_compute(
649+
precision_, recall_, thresholds_ = _multiclass_precision_recall_curve_compute(
656650
state,
657651
thresholds, # type: ignore
658652
num_classes,
659653
)
654+
return PRCurve(precision_, recall_, thresholds_)
660655

661656

662657
def _multilabel_precision_recall_curve_format(
@@ -868,14 +863,7 @@ def multilabel_precision_recall_curve(
868863
preds: npt.ArrayLike,
869864
num_labels: int,
870865
thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None,
871-
) -> Union[
872-
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
873-
Tuple[
874-
List[npt.NDArray[np.float_]],
875-
List[npt.NDArray[np.float_]],
876-
List[npt.NDArray[np.float_]],
877-
],
878-
]:
866+
) -> PRCurve:
879867
"""Compute the precision-recall curve for multilabel input.
880868
881869
Parameters
@@ -897,16 +885,18 @@ def multilabel_precision_recall_curve(
897885
898886
Returns
899887
-------
900-
precision : numpy.ndarray or List[numpy.ndarray]
888+
PRCurve
889+
A named tuple with the following:
890+
- ``precision``: numpy.ndarray or List[numpy.ndarray].
901891
Precision values for each label. If ``thresholds`` is None, then
902892
precision is a list of arrays, one for each label. Otherwise,
903893
precision is a single array with shape
904894
(``num_labels``, len(``thresholds``)).
905-
recall : numpy.ndarray or List[numpy.ndarray]
895+
- ``recall``: numpy.ndarray or List[numpy.ndarray].
906896
Recall values for each label. If ``thresholds`` is None, then
907897
recall is a list of arrays, one for each label. Otherwise,
908898
recall is a single array with shape (``num_labels``, len(``thresholds``)).
909-
thresholds : numpy.ndarray or List[numpy.ndarray]
899+
- ``thresholds``: numpy.ndarray or List[numpy.ndarray].
910900
If ``thresholds`` is None, then thresholds is a list of arrays, one for
911901
each label. Otherwise, thresholds is a single array with shape
912902
(len(``thresholds``),).
@@ -950,11 +940,12 @@ def multilabel_precision_recall_curve(
950940
thresholds=thresholds,
951941
)
952942

953-
return _multilabel_precision_recall_curve_compute(
943+
precision_, recall_, thresholds_ = _multilabel_precision_recall_curve_compute(
954944
state,
955945
thresholds, # type: ignore
956946
num_labels,
957947
)
948+
return PRCurve(precision_, recall_, thresholds_)
958949

959950

960951
def precision_recall_curve(
@@ -965,14 +956,7 @@ def precision_recall_curve(
965956
pos_label: int = 1,
966957
num_classes: Optional[int] = None,
967958
num_labels: Optional[int] = None,
968-
) -> Union[
969-
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
970-
Tuple[
971-
List[npt.NDArray[np.float_]],
972-
List[npt.NDArray[np.float_]],
973-
List[npt.NDArray[np.float_]],
974-
],
975-
]:
959+
) -> PRCurve:
976960
"""Compute the precision-recall curve for different tasks/input types.
977961
978962
Parameters
@@ -997,17 +981,19 @@ def precision_recall_curve(
997981
998982
Returns
999983
-------
1000-
precision : numpy.ndarray
984+
PRCurve
985+
A named tuple with the following:
986+
- ``precision``: numpy.ndarray or List[numpy.ndarray].
1001987
The precision scores where ``precision[i]`` is the precision score for
1002-
``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel',
988+
``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel',
1003989
then ``precision`` is a list of numpy arrays, where ``precision[i]`` is the
1004990
precision scores for class or label ``i``.
1005-
recall : numpy.ndarray
991+
- ``recall``: numpy.ndarray or List[numpy.ndarray].
1006992
The recall scores where ``recall[i]`` is the recall score for ``scores >=
1007993
thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel', then
1008994
``recall`` is a list of numpy arrays, where ``recall[i]`` is the recall
1009995
scores for class or label ``i``.
1010-
thresholds : numpy.ndarray
996+
- ``thresholds``: numpy.ndarray or List[numpy.ndarray].
1011997
Thresholds used for computing the precision and recall scores.
1012998
1013999
Raises

0 commit comments

Comments
 (0)