1
1
"""Functions for computing the precision-recall curve for different input types."""
2
2
3
- from typing import Any , List , Literal , Optional , Tuple , Union
3
+ from typing import Any , List , Literal , NamedTuple , Optional , Tuple , Union
4
4
5
5
import numpy as np
6
6
import numpy .typing as npt
15
15
)
16
16
17
17
18
+ class PRCurve (NamedTuple ):
19
+ """Named tuple with Precision-Recall curve (Precision, Recall and thresholds)."""
20
+
21
+ precision : Union [npt .NDArray [np .float_ ], List [npt .NDArray [np .float_ ]]]
22
+ recall : Union [npt .NDArray [np .float_ ], List [npt .NDArray [np .float_ ]]]
23
+ thresholds : Union [npt .NDArray [np .float_ ], List [npt .NDArray [np .float_ ]]]
24
+
25
+
18
26
def _format_thresholds (
19
27
thresholds : Optional [Union [int , List [float ], npt .NDArray [np .float_ ]]] = None ,
20
28
) -> Optional [npt .NDArray [np .float_ ]]:
@@ -279,7 +287,7 @@ def binary_precision_recall_curve(
279
287
preds : npt .ArrayLike ,
280
288
thresholds : Optional [Union [int , List [float ], npt .NDArray [np .float_ ]]] = None ,
281
289
pos_label : int = 1 ,
282
- ) -> Tuple [ npt . NDArray [ np . float_ ], npt . NDArray [ np . float_ ], npt . NDArray [ np . float_ ]] :
290
+ ) -> PRCurve :
283
291
"""Compute precision-recall curve for binary input.
284
292
285
293
Parameters
@@ -301,13 +309,10 @@ def binary_precision_recall_curve(
301
309
302
310
Returns
303
311
-------
304
- precision : numpy.ndarray
305
- Precision scores such that element i is the precision of predictions
306
- with score >= thresholds[i].
307
- recall : numpy.ndarray
308
- Recall scores in descending order.
309
- thresholds : numpy.ndarray
310
- Thresholds used for computing the precision and recall scores.
312
+ PRCurve
313
+ A named tuple containing the precision (element i is the precision of predictions
314
+ with score >= thresholds[i]), recall (scores in descending order)
315
+ and thresholds used to compute the precision-recall curve.
311
316
312
317
Examples
313
318
--------
@@ -335,13 +340,14 @@ def binary_precision_recall_curve(
335
340
thresholds = _format_thresholds (thresholds )
336
341
337
342
state = _binary_precision_recall_curve_update (target , preds , thresholds )
338
-
339
- return _binary_precision_recall_curve_compute (
343
+ precision_ , recall_ , thresholds_ = _binary_precision_recall_curve_compute (
340
344
state ,
341
345
thresholds ,
342
346
pos_label = pos_label ,
343
347
)
344
348
349
+ return PRCurve (precision_ , recall_ , thresholds_ )
350
+
345
351
346
352
def _multiclass_precision_recall_curve_format (
347
353
target : npt .ArrayLike ,
@@ -572,14 +578,7 @@ def multiclass_precision_recall_curve(
572
578
preds : npt .ArrayLike ,
573
579
num_classes : int ,
574
580
thresholds : Optional [Union [int , List [float ], npt .NDArray [np .float_ ]]] = None ,
575
- ) -> Union [
576
- Tuple [npt .NDArray [np .float_ ], npt .NDArray [np .float_ ], npt .NDArray [np .float_ ]],
577
- Tuple [
578
- List [npt .NDArray [np .float_ ]],
579
- List [npt .NDArray [np .float_ ]],
580
- List [npt .NDArray [np .float_ ]],
581
- ],
582
- ]:
581
+ ) -> PRCurve :
583
582
"""Compute the precision-recall curve for multiclass problems.
584
583
585
584
Parameters
@@ -600,18 +599,13 @@ def multiclass_precision_recall_curve(
600
599
601
600
Returns
602
601
-------
603
- precision : numpy.ndarray or list of numpy.ndarray
604
- Precision scores where element i is the precision score corresponding
605
- to the threshold i. If state is a tuple of the target and predicted
606
- probabilities, then precision is a list of arrays, where each array
607
- corresponds to the precision scores for a class.
608
- recall : numpy.ndarray or list of numpy.ndarray
609
- Recall scores where element i is the recall score corresponding to
610
- the threshold i. If state is a tuple of the target and predicted
611
- probabilities, then recall is a list of arrays, where each array
612
- corresponds to the recall scores for a class.
613
- thresholds : numpy.ndarray or list of numpy.ndarray
614
- Thresholds used for computing the precision and recall scores.
602
+ PRCurve
603
+ A named tuple containing the precision, recall, and thresholds.
604
+ Precision and recall are arrays where element i is the precision and
605
+ recall score corresponding to threshold i. If state is a tuple of the
606
+ target and predicted probabilities, then precision and recall are lists
607
+ of arrays, where each array corresponds to the precision and recall
608
+ scores for a class.
615
609
616
610
Examples
617
611
--------
@@ -652,11 +646,12 @@ def multiclass_precision_recall_curve(
652
646
thresholds = thresholds ,
653
647
)
654
648
655
- return _multiclass_precision_recall_curve_compute (
649
+ precision_ , recall_ , thresholds_ = _multiclass_precision_recall_curve_compute (
656
650
state ,
657
651
thresholds , # type: ignore
658
652
num_classes ,
659
653
)
654
+ return PRCurve (precision_ , recall_ , thresholds_ )
660
655
661
656
662
657
def _multilabel_precision_recall_curve_format (
@@ -868,14 +863,7 @@ def multilabel_precision_recall_curve(
868
863
preds : npt .ArrayLike ,
869
864
num_labels : int ,
870
865
thresholds : Optional [Union [int , List [float ], npt .NDArray [np .float_ ]]] = None ,
871
- ) -> Union [
872
- Tuple [npt .NDArray [np .float_ ], npt .NDArray [np .float_ ], npt .NDArray [np .float_ ]],
873
- Tuple [
874
- List [npt .NDArray [np .float_ ]],
875
- List [npt .NDArray [np .float_ ]],
876
- List [npt .NDArray [np .float_ ]],
877
- ],
878
- ]:
866
+ ) -> PRCurve :
879
867
"""Compute the precision-recall curve for multilabel input.
880
868
881
869
Parameters
@@ -897,16 +885,18 @@ def multilabel_precision_recall_curve(
897
885
898
886
Returns
899
887
-------
900
- precision : numpy.ndarray or List[numpy.ndarray]
888
+ PRCurve
889
+ A named tuple with the following:
890
+ - ``precision``: numpy.ndarray or List[numpy.ndarray].
901
891
Precision values for each label. If ``thresholds`` is None, then
902
892
precision is a list of arrays, one for each label. Otherwise,
903
893
precision is a single array with shape
904
894
(``num_labels``, len(``thresholds``)).
905
- recall : numpy.ndarray or List[numpy.ndarray]
895
+ - ``recall`` : numpy.ndarray or List[numpy.ndarray].
906
896
Recall values for each label. If ``thresholds`` is None, then
907
897
recall is a list of arrays, one for each label. Otherwise,
908
898
recall is a single array with shape (``num_labels``, len(``thresholds``)).
909
- thresholds : numpy.ndarray or List[numpy.ndarray]
899
+ - ``thresholds`` : numpy.ndarray or List[numpy.ndarray].
910
900
If ``thresholds`` is None, then thresholds is a list of arrays, one for
911
901
each label. Otherwise, thresholds is a single array with shape
912
902
(len(``thresholds``),).
@@ -950,11 +940,12 @@ def multilabel_precision_recall_curve(
950
940
thresholds = thresholds ,
951
941
)
952
942
953
- return _multilabel_precision_recall_curve_compute (
943
+ precision_ , recall_ , thresholds_ = _multilabel_precision_recall_curve_compute (
954
944
state ,
955
945
thresholds , # type: ignore
956
946
num_labels ,
957
947
)
948
+ return PRCurve (precision_ , recall_ , thresholds_ )
958
949
959
950
960
951
def precision_recall_curve (
@@ -965,14 +956,7 @@ def precision_recall_curve(
965
956
pos_label : int = 1 ,
966
957
num_classes : Optional [int ] = None ,
967
958
num_labels : Optional [int ] = None ,
968
- ) -> Union [
969
- Tuple [npt .NDArray [np .float_ ], npt .NDArray [np .float_ ], npt .NDArray [np .float_ ]],
970
- Tuple [
971
- List [npt .NDArray [np .float_ ]],
972
- List [npt .NDArray [np .float_ ]],
973
- List [npt .NDArray [np .float_ ]],
974
- ],
975
- ]:
959
+ ) -> PRCurve :
976
960
"""Compute the precision-recall curve for different tasks/input types.
977
961
978
962
Parameters
@@ -997,17 +981,19 @@ def precision_recall_curve(
997
981
998
982
Returns
999
983
-------
1000
- precision : numpy.ndarray
984
+ PRCurve
985
+ A named tuple with the following:
986
+ - ``precision``: numpy.ndarray or List[numpy.ndarray].
1001
987
The precision scores where ``precision[i]`` is the precision score for
1002
- ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel ',
988
+ ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel',
1003
989
then ``precision`` is a list of numpy arrays, where ``precision[i]`` is the
1004
990
precision scores for class or label ``i``.
1005
- recall : numpy.ndarray
991
+ - ``recall`` : numpy.ndarray or List[numpy.ndarray].
1006
992
The recall scores where ``recall[i]`` is the recall score for ``scores >=
1007
993
thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel', then
1008
994
``recall`` is a list of numpy arrays, where ``recall[i]`` is the recall
1009
995
scores for class or label ``i``.
1010
- thresholds : numpy.ndarray
996
+ - ``thresholds`` : numpy.ndarray or List[numpy.ndarray].
1011
997
Thresholds used for computing the precision and recall scores.
1012
998
1013
999
Raises
0 commit comments