diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 1a2e5902b..fb8e19054 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -52,11 +52,13 @@ multilabel_tpr, ) from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + PRCurve, binary_precision_recall_curve, multiclass_precision_recall_curve, multilabel_precision_recall_curve, ) from cyclops.evaluate.metrics.experimental.functional.roc import ( + ROCCurve, binary_roc, multiclass_roc, multilabel_roc, diff --git a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py index d484a2b1b..7fe5dc968 100644 --- a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py @@ -1,6 +1,6 @@ """Functions for computing the precision and recall for different unique thresholds.""" from types import ModuleType -from typing import Any, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union import array_api_compat as apc import numpy as np @@ -28,6 +28,14 @@ ) +class PRCurve(NamedTuple): + """Named tuple with Precision-Recall curve (Precision, Recall and thresholds).""" + + precision: Union[Array, List[Array]] + recall: Union[Array, List[Array]] + thresholds: Union[Array, List[Array]] + + def _validate_thresholds(thresholds: Optional[Union[int, List[float], Array]]) -> None: """Validate the `thresholds` argument.""" if thresholds is not None and not ( @@ -352,14 +360,13 @@ def binary_precision_recall_curve( Returns ------- - precision : Array - The precision values for all unique thresholds. 
The shape of the array is + PRCurve + A named tuple that contains the following elements: + - `precision` values for all unique thresholds. The shape of the array is `(num_thresholds + 1,)`. - recall : Array - The recall values for all unique thresholds. The shape of the array is + - `recall` values for all unique thresholds. The shape of the array is `(num_thresholds + 1,)`. - thresholds : Array - The thresholds used for computing the precision and recall values, in + - `thresholds` used for computing the precision and recall values, in ascending order. The shape of the array is `(num_thresholds,)`. Raises @@ -688,7 +695,7 @@ def multiclass_precision_recall_curve( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, ignore_index: Optional[Union[int, Tuple[int]]] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> PRCurve: """Compute the precision and recall for all unique thresholds. Parameters @@ -730,18 +737,17 @@ def multiclass_precision_recall_curve( Returns ------- - precision : Array or List[Array] - The precision values for all unique thresholds. If `thresholds` is `"none"` + PRCurve + A named tuple that contains the following elements: + - `precision` values for all unique thresholds. If `thresholds` is `"none"` or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape `(num_thresholds + 1, num_classes)` is returned. - recall : Array or List[Array] - The recall values for all unique thresholds. If `thresholds` is `"none"` + - `recall` values for all unique thresholds. If `thresholds` is `"none"` or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape `(num_thresholds + 1, num_classes)` is returned. 
- thresholds : Array or List[Array] - The thresholds used for computing the precision and recall values, in + - `thresholds` used for computing the precision and recall values, in ascending order. If `thresholds` is `"none"` or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of shape `(num_thresholds,)` is returned. @@ -868,12 +874,13 @@ class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, average, xp=xp, ) - return _multiclass_precision_recall_curve_compute( + precision, recall, thresholds_ = _multiclass_precision_recall_curve_compute( state, num_classes, thresholds=thresholds, average=average, ) + return PRCurve(precision, recall, thresholds_) def _multilabel_precision_recall_curve_validate_args( @@ -1035,7 +1042,7 @@ def multilabel_precision_recall_curve( num_labels: int, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> PRCurve: """Compute the precision and recall for all unique thresholds. Parameters @@ -1067,21 +1074,20 @@ def multilabel_precision_recall_curve( Returns ------- - precision : Array or List[Array] - The precision values for all unique thresholds. If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape + PRCurve + A named tuple that contains the following elements: + - `precision` values for all unique thresholds. If `thresholds` is `None`, + a list for each label is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - recall : Array or List[Array] - The recall values for all unique thresholds. If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape + `(num_thresholds + 1, num_labels)` is returned. + - `recall` values for all unique thresholds. 
If `thresholds` is `None`, + a list for each label is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - thresholds : Array or List[Array] - The thresholds used for computing the precision and recall values, in - ascending order. If `thresholds` is `None`, a list for each label is - returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D - Array of shape `(num_thresholds,)` is returned. + `(num_thresholds + 1, num_labels)` is returned. + - `thresholds` used for computing the precision and recall values, in + ascending order. If `thresholds` is `None`, a list for each label is + returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, + a 1-D Array of shape `(num_thresholds,)` is returned. Raises ------ @@ -1193,9 +1199,10 @@ def multilabel_precision_recall_curve( thresholds, xp=xp, ) - return _multilabel_precision_recall_curve_compute( + precision, recall, thresholds_ = _multilabel_precision_recall_curve_compute( state, num_labels, thresholds, ignore_index, ) + return PRCurve(precision, recall, thresholds_) diff --git a/cyclops/evaluate/metrics/experimental/functional/roc.py b/cyclops/evaluate/metrics/experimental/functional/roc.py index 736696195..a9b7c9aac 100644 --- a/cyclops/evaluate/metrics/experimental/functional/roc.py +++ b/cyclops/evaluate/metrics/experimental/functional/roc.py @@ -1,6 +1,6 @@ """Functions for computing Receiver Operating Characteristic (ROC) curves.""" import warnings -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, NamedTuple, Optional, Tuple, Union import array_api_compat as apc @@ -28,6 +28,14 @@ from cyclops.evaluate.metrics.experimental.utils.types import Array +class ROCCurve(NamedTuple): + """Named tuple to store ROC curve (FPR, TPR and thresholds).""" + + fpr: Union[Array, List[Array]] + tpr: Union[Array, List[Array]] + thresholds: Union[Array, 
List[Array]] + + def _binary_roc_compute( state: Union[Array, Tuple[Array, Array]], thresholds: Optional[Array], @@ -91,7 +99,7 @@ def binary_roc( preds: Array, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, -) -> Tuple[Array, Array, Array]: +) -> ROCCurve: """Compute the receiver operating characteristic (ROC) curve for binary tasks. Parameters @@ -120,15 +128,11 @@ def binary_roc( Returns ------- - fpr : Array - The false positive rates for all unique thresholds. The shape of the array is - `(num_thresholds + 1,)`. - tpr : Array - The true positive rates for all unique thresholds. The shape of the array is - `(num_thresholds + 1,)`. - thresholds : Array - The thresholds used for computing the ROC curve values, in descending order. - The shape of the array is `(num_thresholds,)`. + ROCCurve + A named tuple containing the false positive rate (FPR), true positive rate + (TPR) and thresholds. The FPR and TPR are arrays of shape + `(num_thresholds + 1,)` and the thresholds are an array of shape + `(num_thresholds,)`. Raises ------ @@ -209,7 +213,8 @@ def binary_roc( xp=xp, ) state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) - return _binary_roc_compute(state, thresholds) + fpr, tpr, thresh = _binary_roc_compute(state, thresholds) + return ROCCurve(fpr, tpr, thresh) def _multiclass_roc_compute( @@ -277,7 +282,7 @@ def multiclass_roc( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, ignore_index: Optional[Union[int, Tuple[int]]] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> ROCCurve: """Compute the receiver operating characteristic (ROC) curve for multiclass tasks. Parameters @@ -318,19 +323,13 @@ def multiclass_roc( Returns ------- - fpr : Array or List[Array] - The false positive rates for all unique thresholds. 
If `thresholds` is `"none"` - or `None`, a list for each class is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + ROCCurve + A named tuple that contains the false positive rate, true positive rate, + and the thresholds used for computing the ROC curve. If `thresholds` is `"none"` + or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays + of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape `(num_thresholds + 1, num_classes)` is returned. - tpr : Array or List[Array] - The true positive rates for all unique thresholds. If `thresholds` is `"none"` - or `None`, a list for each class is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_classes)` is returned. - thresholds : Array or List[Array] - The thresholds used for computing the ROC curve values, in descending order. - If `thresholds` is `"none"` or `None`, a list for each class is returned + Similarly, a list of thresholds for each class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of shape `(num_thresholds,)` is returned. @@ -455,12 +454,13 @@ def multiclass_roc( average, xp=xp, ) - return _multiclass_roc_compute( + fpr_, tpr_, thresholds_ = _multiclass_roc_compute( state, num_classes, thresholds=thresholds, average=average, ) + return ROCCurve(fpr_, tpr_, thresholds_) def _multilabel_roc_compute( @@ -504,7 +504,7 @@ def multilabel_roc( num_labels: int, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> ROCCurve: """Compute the receiver operating characteristic (ROC) curve for multilabel tasks. Parameters @@ -535,21 +535,15 @@ def multilabel_roc( Returns ------- - fpr : Array or List[Array] - The false positive rates for all unique thresholds. 
If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - tpr : Array or List[Array] - The true positive rates for all unique thresholds. If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - thresholds : Array or List[Array] - The thresholds used for computing the ROC curve values, in - descending order. If `thresholds` is `None`, a list for each label is - returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D - Array of shape `(num_thresholds,)` is returned. + ROCCurve + A named tuple that contains the false positive rate, true positive rate, + and the thresholds used for computing the ROC curve. If `thresholds` is `None`, + a list of FPRs and TPRs for each label is returned with 1-D Arrays + of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_labels)` is returned. + Similarly, if `thresholds` is `None`, a list of thresholds for each label + is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D + Array of shape `(num_thresholds,)` is returned. 
Raises ------ @@ -660,9 +654,10 @@ def multilabel_roc( thresholds, xp=xp, ) - return _multilabel_roc_compute( + fpr_, tpr_, thresholds_ = _multilabel_roc_compute( state, num_labels, thresholds, ignore_index, ) + return ROCCurve(fpr_, tpr_, thresholds_) diff --git a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py index 36181a0eb..6c486310d 100644 --- a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py @@ -5,6 +5,7 @@ import array_api_compat as apc from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + PRCurve, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format_arrays, _binary_precision_recall_curve_update, @@ -54,14 +55,14 @@ class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_c >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryPrecisionRecallCurve(thresholds=None) >>> metric(target, preds) - (Array([0.5 , 0.6 , 0.5 , 0.6666667, - 0.5 , 1. , 1. ], dtype=float32), Array([1. , 1. , 0.6666667 , 0.6666667 , - 0.33333334, 0.33333334, 0. ], dtype=float32), Array([0.11, 0.22, 0.33, 0.73, 0.84, 0.92], dtype=float64)) + PRCurve(precision=Array([0.5 , 0.6 , 0.5 , 0.6666667, + 0.5 , 1. , 1. ], dtype=float32), recall=Array([1. , 1. , 0.6666667 , 0.6666667 , + 0.33333334, 0.33333334, 0. ], dtype=float32), thresholds=Array([0.11, 0.22, 0.33, 0.73, 0.84, 0.92], dtype=float64)) >>> metric = BinaryPrecisionRecallCurve(thresholds=5) >>> metric(target, preds) - (Array([0.5 , 0.5 , 0.6666667, 0.5 , - 0. , 1. ], dtype=float32), Array([1. , 0.6666667 , 0.6666667 , 0.33333334, - 0. , 0. ], dtype=float32), Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32)) + PRCurve(precision=Array([0.5 , 0.5 , 0.6666667, 0.5 , + 0. , 1. ], dtype=float32), recall=Array([1. , 0.6666667 , 0.6666667 , 0.33333334, + 0. , 0. 
], dtype=float32), thresholds=Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32)) """ # noqa: W505 @@ -140,14 +141,18 @@ def _update_state(self, target: Array, preds: Array) -> None: self.target.append(state[0]) # type: ignore[attr-defined] self.preds.append(state[1]) # type: ignore[attr-defined] - def _compute_metric(self) -> Tuple[Array, Array, Array]: + def _compute_metric(self) -> PRCurve: """Compute the metric.""" state = ( (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] if self.thresholds is None else self.confmat # type: ignore[attr-defined] ) - return _binary_precision_recall_curve_compute(state, self.thresholds) # type: ignore[arg-type] + precision, recall, thresholds = _binary_precision_recall_curve_compute( + state, + self.thresholds, # type: ignore + ) + return PRCurve(precision, recall, thresholds) class MulticlassPrecisionRecallCurve( diff --git a/cyclops/evaluate/metrics/experimental/roc.py b/cyclops/evaluate/metrics/experimental/roc.py index c3b6de9a5..766f1c9f5 100644 --- a/cyclops/evaluate/metrics/experimental/roc.py +++ b/cyclops/evaluate/metrics/experimental/roc.py @@ -2,6 +2,7 @@ from typing import List, Tuple, Union from cyclops.evaluate.metrics.experimental.functional.roc import ( + ROCCurve, _binary_roc_compute, _multiclass_roc_compute, _multilabel_roc_compute, @@ -42,26 +43,27 @@ class BinaryROC(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"): >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryROC(thresholds=None) >>> metric(target, preds) - (Array([0. , 0. , 0.33333334, 0.33333334, - 0.6666667 , 0.6666667 , 1. ], dtype=float32), Array([0. , 0.33333334, 0.33333334, 0.6666667 , - 0.6666667 , 1. , 1. ], dtype=float32), Array([1. , 0.92, 0.84, 0.73, 0.33, 0.22, 0.11], dtype=float64)) + ROCCurve(fpr=Array([0. , 0. , 0.33333334, 0.33333334, + 0.6666667 , 0.6666667 , 1. ], dtype=float32), tpr=Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 0.6666667 , 1. , 1. 
], dtype=float32), thresholds=Array([1. , 0.92, 0.84, 0.73, 0.33, 0.22, 0.11], dtype=float64)) >>> metric = BinaryROC(thresholds=5) >>> metric(target, preds) - (Array([0. , 0.33333334, 0.33333334, 0.6666667 , - 1. ], dtype=float32), Array([0. , 0.33333334, 0.6666667 , 0.6666667 , - 1. ], dtype=float32), Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32)) + ROCCurve(fpr=Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 1. ], dtype=float32), tpr=Array([0. , 0.33333334, 0.6666667 , 0.6666667 , + 1. ], dtype=float32), thresholds=Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32)) """ # noqa: W505 name: str = "ROC Curve" - def _compute_metric(self) -> Tuple[Array, Array, Array]: + def _compute_metric(self) -> ROCCurve: # type: ignore state = ( (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] if self.thresholds is None else self.confmat # type: ignore[attr-defined] ) - return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] + fpr, tpr, thresholds = _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] + return ROCCurve(fpr, tpr, thresholds) class MulticlassROC( diff --git a/cyclops/evaluate/metrics/functional/__init__.py b/cyclops/evaluate/metrics/functional/__init__.py index 14eee5b9e..ed64063d1 100644 --- a/cyclops/evaluate/metrics/functional/__init__.py +++ b/cyclops/evaluate/metrics/functional/__init__.py @@ -37,12 +37,14 @@ recall, ) from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # noqa: F401 + PRCurve, binary_precision_recall_curve, multiclass_precision_recall_curve, multilabel_precision_recall_curve, precision_recall_curve, ) from cyclops.evaluate.metrics.functional.roc import ( # noqa: F401 + ROCCurve, binary_roc_curve, multiclass_roc_curve, multilabel_roc_curve, diff --git a/cyclops/evaluate/metrics/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/functional/precision_recall_curve.py index bbe0bcc78..a0f9b69e3 100644 --- 
a/cyclops/evaluate/metrics/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/functional/precision_recall_curve.py @@ -1,6 +1,6 @@ """Functions for computing the precision-recall curve for different input types.""" -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -15,6 +15,14 @@ ) +class PRCurve(NamedTuple): + """Named tuple with Precision-Recall curve (Precision, Recall and thresholds).""" + + precision: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + recall: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + + def _format_thresholds( thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, ) -> Optional[npt.NDArray[np.float_]]: @@ -279,7 +287,7 @@ def binary_precision_recall_curve( preds: npt.ArrayLike, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, pos_label: int = 1, -) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: +) -> PRCurve: """Compute precision-recall curve for binary input. Parameters @@ -301,13 +309,10 @@ def binary_precision_recall_curve( Returns ------- - precision : numpy.ndarray - Precision scores such that element i is the precision of predictions - with score >= thresholds[i]. - recall : numpy.ndarray - Recall scores in descending order. - thresholds : numpy.ndarray - Thresholds used for computing the precision and recall scores. + PRCurve + A named tuple containing the precision (element i is the precision of predictions + with score >= thresholds[i]), recall (scores in descending order) + and thresholds used to compute the precision-recall curve. 
Examples -------- @@ -335,13 +340,14 @@ def binary_precision_recall_curve( thresholds = _format_thresholds(thresholds) state = _binary_precision_recall_curve_update(target, preds, thresholds) - - return _binary_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _binary_precision_recall_curve_compute( state, thresholds, pos_label=pos_label, ) + return PRCurve(precision_, recall_, thresholds_) + def _multiclass_precision_recall_curve_format( target: npt.ArrayLike, @@ -572,14 +578,7 @@ def multiclass_precision_recall_curve( preds: npt.ArrayLike, num_classes: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for multiclass problems. Parameters @@ -600,18 +599,13 @@ def multiclass_precision_recall_curve( Returns ------- - precision : numpy.ndarray or list of numpy.ndarray - Precision scores where element i is the precision score corresponding - to the threshold i. If state is a tuple of the target and predicted - probabilities, then precision is a list of arrays, where each array - corresponds to the precision scores for a class. - recall : numpy.ndarray or list of numpy.ndarray - Recall scores where element i is the recall score corresponding to - the threshold i. If state is a tuple of the target and predicted - probabilities, then recall is a list of arrays, where each array - corresponds to the recall scores for a class. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used for computing the precision and recall scores. + PRCurve + A named tuple containing the precision, recall, and thresholds. + Precision and recall are arrays where element i is the precision and + recall score corresponding to threshold i. 
If state is a tuple of the + target and predicted probabilities, then precision and recall are lists + of arrays, where each array corresponds to the precision and recall + scores for a class. Examples -------- @@ -652,11 +646,12 @@ def multiclass_precision_recall_curve( thresholds=thresholds, ) - return _multiclass_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _multiclass_precision_recall_curve_compute( state, thresholds, # type: ignore num_classes, ) + return PRCurve(precision_, recall_, thresholds_) def _multilabel_precision_recall_curve_format( @@ -868,14 +863,7 @@ def multilabel_precision_recall_curve( preds: npt.ArrayLike, num_labels: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for multilabel input. Parameters @@ -897,16 +885,18 @@ def multilabel_precision_recall_curve( Returns ------- - precision : numpy.ndarray or List[numpy.ndarray] + PRCurve + A named tuple with the following: + - ``precision``: numpy.ndarray or List[numpy.ndarray]. Precision values for each label. If ``thresholds`` is None, then precision is a list of arrays, one for each label. Otherwise, precision is a single array with shape (``num_labels``, len(``thresholds``)). - recall : numpy.ndarray or List[numpy.ndarray] + - ``recall``: numpy.ndarray or List[numpy.ndarray]. Recall values for each label. If ``thresholds`` is None, then recall is a list of arrays, one for each label. Otherwise, recall is a single array with shape (``num_labels``, len(``thresholds``)). - thresholds : numpy.ndarray or List[numpy.ndarray] + - ``thresholds``: numpy.ndarray or List[numpy.ndarray]. If ``thresholds`` is None, then thresholds is a list of arrays, one for each label. 
Otherwise, thresholds is a single array with shape (len(``thresholds``,). @@ -950,11 +940,12 @@ def multilabel_precision_recall_curve( thresholds=thresholds, ) - return _multilabel_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _multilabel_precision_recall_curve_compute( state, thresholds, # type: ignore num_labels, ) + return PRCurve(precision_, recall_, thresholds_) def precision_recall_curve( @@ -965,14 +956,7 @@ def precision_recall_curve( pos_label: int = 1, num_classes: Optional[int] = None, num_labels: Optional[int] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for different tasks/input types. Parameters @@ -997,17 +981,19 @@ def precision_recall_curve( Returns ------- - precision : numpy.ndarray + PRCurve + A named tuple with the following: + - ``precision``: numpy.ndarray or List[numpy.ndarray]. The precision scores where ``precision[i]`` is the precision score for - ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel', + ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel', then ``precision`` is a list of numpy arrays, where ``precision[i]`` is the precision scores for class or label ``i``. - recall : numpy.ndarray + - ``recall``: numpy.ndarray or List[numpy.ndarray]. The recall scores where ``recall[i]`` is the recall score for ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel', then ``recall`` is a list of numpy arrays, where ``recall[i]`` is the recall scores for class or label ``i``. - thresholds : numpy.ndarray + - ``thresholds``: numpy.ndarray or List[numpy.ndarray]. Thresholds used for computing the precision and recall scores. 
Raises diff --git a/cyclops/evaluate/metrics/functional/roc.py b/cyclops/evaluate/metrics/functional/roc.py index 24eb1dd16..b8935ded5 100644 --- a/cyclops/evaluate/metrics/functional/roc.py +++ b/cyclops/evaluate/metrics/functional/roc.py @@ -1,6 +1,6 @@ """Functions for computing the receiver operating characteristic (ROC) curve.""" import logging -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -23,6 +23,14 @@ setup_logging(print_level="WARN", logger=LOGGER) +class ROCCurve(NamedTuple): + """Named tuple to store ROC curve (FPR, TPR and thresholds).""" + + fpr: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + tpr: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + + def _roc_compute_from_confmat( confmat: npt.NDArray[Any], thresholds: npt.NDArray[np.float_], @@ -144,7 +152,7 @@ def binary_roc_curve( preds: npt.ArrayLike, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, pos_label: int = 1, -) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: +) -> ROCCurve: """Compute the ROC curve for binary classification tasks. Parameters @@ -166,12 +174,9 @@ def binary_roc_curve( Returns ------- - fpr : numpy.ndarray - False positive rate. - tpr : numpy.ndarray - True positive rate. - thresholds : numpy.ndarray - Thresholds used to compute fpr and tpr. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. 
Examples -------- @@ -197,8 +202,9 @@ def binary_roc_curve( thresholds = _format_thresholds(thresholds) state = _binary_precision_recall_curve_update(target, preds, thresholds) + fpr, tpr, thresholds = _binary_roc_compute(state, thresholds, pos_label) - return _binary_roc_compute(state, thresholds=thresholds, pos_label=pos_label) + return ROCCurve(fpr, tpr, thresholds) def _multiclass_roc_compute( @@ -272,14 +278,7 @@ def multiclass_roc_curve( preds: npt.ArrayLike, num_classes: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for multiclass classification tasks. Parameters @@ -301,17 +300,11 @@ def multiclass_roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``threshold`` is not None, ``fpr`` is a 1d numpy - array. Otherwise, ``fpr`` is a list of 1d numpy arrays, one for each - class. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``threshold`` is not None, ``tpr`` is a 1d numpy - array. Otherwise, ``tpr`` is a list of 1d numpy arrays, one for each class. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. ``threshold`` is not None, - thresholds is a 1d numpy array. Otherwise, thresholds is a list of - 1d numpy arrays, one for each class. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. If ``thresholds`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each class. 
Examples -------- @@ -352,8 +345,9 @@ def multiclass_roc_curve( num_classes=num_classes, thresholds=thresholds, ) + fpr_, tpr_, thresholds_ = _multiclass_roc_compute(state, num_classes, thresholds) - return _multiclass_roc_compute(state, num_classes, thresholds) + return ROCCurve(fpr=fpr_, tpr=tpr_, thresholds=thresholds_) def _multilabel_roc_compute( @@ -427,14 +421,7 @@ def multilabel_roc_curve( preds: npt.ArrayLike, num_labels: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for multilabel classification tasks. Parameters @@ -456,17 +443,11 @@ def multilabel_roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``threshold`` is not None, ``fpr`` is a 1d numpy - array. Otherwise, ``fpr`` is a list of 1d numpy arrays, one for each - label. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``threshold`` is not None, ``tpr`` is a 1d numpy - array. Otherwise, ``tpr`` is a list of 1d numpy arrays, one for each label. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. ``threshold`` is not None, - thresholds is a 1d numpy array. Otherwise, thresholds is a list of - 1d numpy arrays, one for each label. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each label. 
Examples -------- @@ -502,8 +483,9 @@ def multilabel_roc_curve( num_labels=num_labels, thresholds=thresholds, ) + fpr_, tpr_, thresholds_ = _multilabel_roc_compute(state, num_labels, thresholds) - return _multilabel_roc_compute(state, num_labels, thresholds) + return ROCCurve(fpr=fpr_, tpr=tpr_, thresholds=thresholds_) def roc_curve( @@ -514,14 +496,7 @@ def roc_curve( pos_label: int = 1, num_classes: Optional[int] = None, num_labels: Optional[int] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for different tasks/input types. Parameters @@ -558,22 +533,11 @@ def roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``task`` is 'binary' or ``threshold`` is not None, - ``fpr`` is a 1d numpy array. If ``task`` is 'multiclass' or 'multilabel', - and ``threshold`` is None, then ``fpr`` is a list of 1d numpy - arrays, one for each class or label. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``task`` is 'binary' or ``threshold`` is not None, - ``tpr`` is a 1d numpy array. If ``task`` is 'multiclass' or 'multilabel', - and ``threshold`` is None, then ``tpr`` is a list of 1d numpy - arrays, one for each class or label. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. If ``task`` is 'binary' or - ``threshold`` is not None, ``thresholds`` is a 1d numpy array. If - ``task`` is 'multiclass' or 'multilabel', and ``threshold`` is None, - then ``thresholds`` is a list of 1d numpy arrays, one for each class - or label. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. 
If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each label. Raises ------ diff --git a/cyclops/evaluate/metrics/precision_recall_curve.py b/cyclops/evaluate/metrics/precision_recall_curve.py index 64bf08833..9a5ce76b5 100644 --- a/cyclops/evaluate/metrics/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/precision_recall_curve.py @@ -1,11 +1,12 @@ """Classes for computing precision-recall curves.""" -from typing import Any, List, Literal, Optional, Tuple, Type, Union +from typing import Any, List, Literal, Optional, Type, Union import numpy as np import numpy.typing as npt from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # type: ignore # noqa: E501 + PRCurve, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format, _binary_precision_recall_curve_update, @@ -42,14 +43,14 @@ class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_c >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = BinaryPrecisionRecallCurve(thresholds=3) >>> metric(target, preds) - (array([0.5, 1. , 0. ]), array([1. , 0.5, 0. ]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([0.5, 1. , 0. ]), recall=array([1. , 0.5, 0. ]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 0, 1], [1, 1, 0, 0]] >>> preds = [[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0.5 , 0.66666667, 0. ]), array([1. , 0.5, 0. ]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([0.5 , 0.66666667, 0. ]), recall=array([1. , 0.5, 0. ]), thresholds=array([0. , 0.5, 1. 
])) """ @@ -101,7 +102,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -111,11 +112,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _binary_precision_recall_curve_compute( + precision, recall, thresholds = _binary_precision_recall_curve_compute( state=state, thresholds=self.thresholds, pos_label=self.pos_label, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -181,11 +183,11 @@ class MulticlassPrecisionRecallCurve( >>> preds = [[0.1, 0.6, 0.3], [0.05, 0.95, 0.0], [0.5, 0.3, 0.2], [0.2, 0.5, 0.3]] >>> metric = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=3) >>> metric(target, preds) - (array([[0.5 , 0. , 0. , 1. ], + PRCurve(precision=array([[0.5 , 0. , 0. , 1. ], [0.25 , 0.33333333, 0. , 1. ], - [0.25 , 0. , 0. , 1. ]]), array([[1., 0., 0., 0.], + [0.25 , 0. , 0. , 1. ]]), recall=array([[1., 0., 0., 0.], [1., 1., 0., 0.], - [1., 0., 0., 0.]]), array([0. , 0.5, 1. ])) + [1., 0., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 2, 0], [1, 2, 0, 1]] >>> preds = [ @@ -195,11 +197,11 @@ class MulticlassPrecisionRecallCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.375, 0.5 , 0. , 1. ], + PRCurve(precision=array([[0.375, 0.5 , 0. , 1. ], [0.375, 0.4 , 0. , 1. ], - [0.25 , 0. , 0. , 1. ]]), array([[1. , 0.33333333, 0. , 0. ], + [0.25 , 0. , 0. , 1. ]]), recall=array([[1. , 0.33333333, 0. , 0. ], [1. , 0.66666667, 0. , 0. ], - [1. , 0. , 0. , 0. ]]), array([0. , 0.5, 1. ])) + [1. , 0. , 0. , 0. ]]), thresholds=array([0. , 0.5, 1. 
])) """ @@ -253,14 +255,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -270,11 +265,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multiclass_precision_recall_curve_compute( + precision, recall, thresholds = _multiclass_precision_recall_curve_compute( state=state, thresholds=self.thresholds, # type: ignore[arg-type] num_classes=self.num_classes, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -340,18 +336,18 @@ class MultilabelPrecisionRecallCurve( >>> preds = [[0.1, 0.9], [0.8, 0.2]] >>> metric = MultilabelPrecisionRecallCurve(num_labels=2, thresholds=3) >>> metric(target, preds) - (array([[0.5, 1. , 0. , 1. ], - [0.5, 1. , 0. , 1. ]]), array([[1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 1. , 0. , 1. ], + [0.5, 1. , 0. , 1. ]]), recall=array([[1., 1., 0., 0.], + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[[0, 1], [1, 0]], [[1, 0], [0, 1]]] >>> preds = [[[0.1, 0.9], [0.8, 0.2]], [[0.2, 0.8], [0.7, 0.3]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.5, 0.5, 0. , 1. ], - [0.5, 0.5, 0. , 1. ]]), array([[1. , 0.5, 0. , 0. ], - [1. , 0.5, 0. , 0. ]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 0.5, 0. , 1. ], + [0.5, 0.5, 0. , 1. ]]), recall=array([[1. , 0.5, 0. , 0. ], + [1. , 0.5, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. 
])) """ @@ -405,14 +401,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -422,11 +411,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multilabel_precision_recall_curve_compute( + precision, recall, thresholds = _multilabel_precision_recall_curve_compute( state, thresholds=self.thresholds, # type: ignore[arg-type] num_labels=self.num_labels, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -502,15 +492,15 @@ class PrecisionRecallCurve( >>> preds = [0.6, 0.2, 0.3, 0.8] >>> metric = PrecisionRecallCurve(task="binary", thresholds=None) >>> metric(target, preds) - (array([0.75 , 0.66666667, 0.5 , 0. , 1. ]), array([1. , 0.66666667, 0.33333333, 0. , 0. ]), array([0.2, 0.3, 0.6, 0.8])) + PRCurve(precision=array([0.75 , 0.66666667, 0.5 , 0. , 1. ]), recall=array([1. , 0.66666667, 0.33333333, 0. , 0. ]), thresholds=array([0.2, 0.3, 0.6, 0.8])) >>> metric.reset_state() >>> target = [[1, 0, 1, 1], [0, 0, 0, 1]] >>> preds = [[0.5, 0.4, 0.1, 0.3], [0.9, 0.6, 0.45, 0.8]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0.5 , 0.42857143, 0.33333333, 0.4 , 0.5 , - 0.33333333, 0.5 , 0. , 1. ]), array([1. , 0.75, 0.5 , 0.5 , 0.5 , 0.25, 0.25, 0. , 0. ]), array([0.1 , 0.3 , 0.4 , 0.45, 0.5 , 0.6 , 0.8 , 0.9 ])) + PRCurve(precision=array([0.5 , 0.42857143, 0.33333333, 0.4 , 0.5 , + 0.33333333, 0.5 , 0. , 1. ]), recall=array([1. , 0.75, 0.5 , 0.5 , 0.5 , 0.25, 0.25, 0. , 0. 
]), thresholds=array([0.1 , 0.3 , 0.4 , 0.45, 0.5 , 0.6 , 0.8 , 0.9 ])) >>> # (multiclass) >>> from cyclops.evaluate.metrics import PrecisionRecallCurve @@ -518,11 +508,11 @@ class PrecisionRecallCurve( >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.2, 0.2, 0.6]] >>> metric = PrecisionRecallCurve(task="multiclass", num_classes=3, thresholds=3) >>> metric(target, preds) - (array([[0.25, 0. , 0. , 1. ], + PRCurve(precision=array([[0.25, 0. , 0. , 1. ], [0.25, 0.5 , 0. , 1. ], - [0.5 , 1. , 0. , 1. ]]), array([[1., 0., 0., 0.], + [0.5 , 1. , 0. , 1. ]]), recall=array([[1., 0., 0., 0.], [1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 2, 2], [1, 2, 0, 1]] >>> preds = [ @@ -532,11 +522,11 @@ class PrecisionRecallCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.25 , 0. , 0. , 1. ], + PRCurve(precision=array([[0.25 , 0. , 0. , 1. ], [0.375, 0.5 , 0. , 1. ], - [0.375, 0.5 , 0. , 1. ]]), array([[1. , 0. , 0. , 0. ], + [0.375, 0.5 , 0. , 1. ]]), recall=array([[1. , 0. , 0. , 0. ], [1. , 0.66666667, 0. , 0. ], - [1. , 0.66666667, 0. , 0. ]]), array([0. , 0.5, 1. ])) + [1. , 0.66666667, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) >>> # (multilabel) >>> from cyclops.evaluate.metrics import PrecisionRecallCurve @@ -544,18 +534,18 @@ class PrecisionRecallCurve( >>> preds = [[0.1, 0.9], [0.8, 0.2]] >>> metric = PrecisionRecallCurve(task="multilabel", num_labels=2, thresholds=3) >>> metric(target, preds) - (array([[0.5, 1. , 0. , 1. ], - [0.5, 1. , 0. , 1. ]]), array([[1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 1. , 0. , 1. ], + [0.5, 1. , 0. , 1. ]]), recall=array([[1., 1., 0., 0.], + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. 
])) >>> metric.reset_state() >>> target = [[[0, 1], [1, 0]], [[1, 0], [0, 1]]] >>> preds = [[[0.1, 0.9], [0.8, 0.2]], [[0.1, 0.9], [0.8, 0.2]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.5, 0.5, 0. , 1. ], - [0.5, 0.5, 0. , 1. ]]), array([[1. , 0.5, 0. , 0. ], - [1. , 0.5, 0. , 0. ]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 0.5, 0. , 1. ], + [0.5, 0.5, 0. , 1. ]]), recall=array([[1. , 0.5, 0. , 0. ], + [1. , 0.5, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) """ diff --git a/cyclops/evaluate/metrics/roc.py b/cyclops/evaluate/metrics/roc.py index 1b2774c8b..58587856c 100644 --- a/cyclops/evaluate/metrics/roc.py +++ b/cyclops/evaluate/metrics/roc.py @@ -1,10 +1,13 @@ """Classes for computing ROC metrics.""" -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Union import numpy as np import numpy.typing as npt +from cyclops.evaluate.metrics.functional.roc import ( + ROCCurve as ROCCurveData, +) from cyclops.evaluate.metrics.functional.roc import ( _binary_roc_compute, _multiclass_roc_compute, @@ -39,22 +42,22 @@ class BinaryROCCurve(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve" >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = BinaryROCCurve() >>> metric(target, preds) - (array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) + ROCCurve(fpr=array([0. , 0. , 0.5, 0.5, 1. ]), tpr=array([0. , 0.5, 0.5, 1. , 1. ]), thresholds=array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) + ROCCurve(fpr=array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), tpr=array([0. 
, 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), thresholds=array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -63,13 +66,12 @@ def compute( ) else: state = self.confmat # type: ignore[attr-defined] - - return _binary_roc_compute( - state, - thresholds=self.thresholds, - pos_label=self.pos_label, + fpr_, tpr_, thresholds_ = _binary_roc_compute( + state, thresholds=self.thresholds, pos_label=self.pos_label ) + return ROCCurveData(fpr_, tpr_, thresholds_) + class MulticlassROCCurve( MulticlassPrecisionRecallCurve, @@ -101,11 +103,11 @@ class MulticlassROCCurve( >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.9, 0.1, 0]] >>> metric = MulticlassROCCurve(num_classes=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0.33333333, 0.33333333, 1. ], - [0. , 0. , 0. , 1. ]]), array([[0. , 0.5, 0.5, 1. ], + [0. , 0. , 0. , 1. ]]), tpr=array([[0. , 0.5, 0.5, 1. ], [0. , 1. , 1. , 1. ], - [0. , 0. , 1. , 1. ]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0. , 0. , 1. , 1. ]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [ @@ -115,26 +117,19 @@ class MulticlassROCCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0.25, 0.5 , 1. ], + ROCCurve(fpr=array([[0. , 0.25, 0.5 , 1. ], [0. , 0. , 0.25, 1. ], - [0. , 0.25, 0.5 , 1. ]]), array([[0. , 0.25, 0.5 , 1. ], + [0. , 0.25, 0.5 , 1. ]]), tpr=array([[0. , 0.25, 0.5 , 1. ], [0. , 0. , 0.25, 1. ], - [0. , 0. , 0. , 0. ]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0. , 0. , 0. , 0. ]]), thresholds=array([1. 
, 0.66666667, 0.33333333, 0. ])) """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -143,13 +138,12 @@ def compute( ) else: state = self.confmat # type: ignore[attr-defined] - - return _multiclass_roc_compute( - state=state, - num_classes=self.num_classes, - thresholds=self.thresholds, + fpr_, tpr_, thresholds_ = _multiclass_roc_compute( + state, thresholds=self.thresholds, num_classes=self.num_classes ) + return ROCCurveData(fpr_, tpr_, thresholds_) + class MultilabelROCCurve( MultilabelPrecisionRecallCurve, @@ -175,37 +169,30 @@ class MultilabelROCCurve( >>> preds = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> metric = MultilabelROCCurve(num_labels=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]] >>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. 
])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) - """ + """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -215,11 +202,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multilabel_roc_compute( + fpr_, tpr_, thresholds_ = _multilabel_roc_compute( state=state, num_labels=self.num_labels, thresholds=self.thresholds, ) + return ROCCurveData(fpr_, tpr_, thresholds_) class ROCCurve(Metric, registry_key="roc_curve", force_register=True): @@ -258,14 +246,14 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = ROCCurve(task="binary", thresholds=None) >>> metric(target, preds) - (array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) + ROCCurve(fpr=array([0. , 0. , 0.5, 0.5, 1. ]), tpr=array([0. , 0.5, 0.5, 1. , 1. ]), thresholds=array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) + ROCCurve(fpr=array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), tpr=array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), thresholds=array([1. 
, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) >>> # (multiclass) >>> from cyclops.evaluate.metrics import ROCCurve @@ -273,22 +261,22 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]] >>> metric = ROCCurve(task="multiclass", num_classes=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0.5, 0.5, 1. ], - [0. , 0. , 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0. , 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 1.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [1, 2] >>> preds = [[[0.05, 0.75, 0.2]], [[0.1, 0.8, 0.1]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0., 0., 0., 1.], + ROCCurve(fpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([[0., 0., 0., 0.], + [0., 0., 0., 1.]]), tpr=array([[0., 0., 0., 0.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 1.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> # (multilabel) >>> from cyclops.evaluate.metrics import ROCCurve @@ -296,22 +284,22 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> metric = ROCCurve(task="multilabel", num_labels=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. 
])) >>> metric.reset_state() >>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]] >>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) """ # noqa: W505 diff --git a/cyclops/evaluate/metrics/utils.py b/cyclops/evaluate/metrics/utils.py index 3bff5b85a..86f7fc9c3 100644 --- a/cyclops/evaluate/metrics/utils.py +++ b/cyclops/evaluate/metrics/utils.py @@ -1,6 +1,16 @@ """Utility functions for metrics.""" -from typing import Any, Callable, List, Literal, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import numpy.typing as npt @@ -295,10 +305,15 @@ def _apply_function_recursively( """ data_type = type(data) - if isinstance(data, (list, tuple, set)): - return data_type( - [_apply_function_recursively(el, func, *args, **kwargs) for el in data], - ) + is_namedtuple_ = ( + isinstance(data, tuple) + and hasattr(data, "_asdict") + and hasattr(data, "_fields") + ) + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + if is_namedtuple_ or is_sequence: + out = [_apply_function_recursively(el, func, *args, **kwargs) for el in data] + return data_type(*out) if is_namedtuple_ else data_type(out) if isinstance(data, Mapping): return data_type( { diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index e9ffaa0a8..255dee72a 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -1,13 
+1,16 @@ """Classification plotter.""" from collections import defaultdict -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Union import numpy as np import numpy.typing as npt import plotly.graph_objs as go from plotly.subplots import make_subplots +from cyclops.evaluate.metrics.experimental.functional import PRCurve as PRCurveExp +from cyclops.evaluate.metrics.experimental.functional import ROCCurve as ROCCurveExp +from cyclops.evaluate.metrics.functional import PRCurve, ROCCurve from cyclops.report.plot.base import Plotter from cyclops.report.plot.utils import ( bar_plot, @@ -92,11 +95,7 @@ def _set_class_names(self, class_names: List[str]) -> None: def roc_curve( self, - roc_curve: Tuple[ - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - ], + roc_curve: Union[ROCCurve, ROCCurveExp], auroc: Optional[Union[float, List[float], npt.NDArray[np.float_]]] = None, title: Optional[str] = "ROC Curve", layout: Optional[go.Layout] = None, @@ -106,8 +105,8 @@ def roc_curve( Parameters ---------- - roc_curve : Tuple[np.ndarray, np.ndarray, np.ndarray] - Tuple of (fprs, tprs, thresholds) + roc_curve : ROCCurve + Named tuple of (fprs, tprs, thresholds) auroc : Union[float, list, np.ndarray], optional AUROCs, by default None title: str, optional @@ -123,8 +122,8 @@ def roc_curve( The figure object. 
""" - fprs = roc_curve[0] - tprs = roc_curve[1] + fprs = roc_curve.fpr + tprs = roc_curve.tpr trace = [] if self.task_type == "binary": @@ -191,7 +190,7 @@ def roc_curve( def roc_curve_comparison( self, - roc_curves: Dict[str, Tuple[npt.NDArray[np.float_], ...]], + roc_curves: Dict[str, Union[ROCCurve, ROCCurveExp]], aurocs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None, @@ -205,7 +204,7 @@ def roc_curve_comparison( ---------- roc_curves : Dict[str, Tuple] Dictionary of roc curves, with keys being the name of the subpopulation - or group and values being the roc curve tuple (fprs, tprs, thresholds) + or group and values being the roc curve namedtuples (fprs, tprs, thresholds) aurocs : Dict[str, Union[float, list, np.ndarray]], optional AUROCs for each subpopulation or group specified by name, by default None title: str, optional @@ -232,8 +231,8 @@ def roc_curve_comparison( name = f"{slice_name} (AUC = {aurocs[slice_name]:.2f})" else: name = slice_name - fprs = slice_curve[0] - tprs = slice_curve[1] + fprs = slice_curve.fpr + tprs = slice_curve.tpr trace.append( line_plot( x=fprs, @@ -296,11 +295,7 @@ def roc_curve_comparison( def precision_recall_curve( self, - precision_recall_curve: Tuple[ - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - ], + precision_recall_curve: Union[PRCurve, PRCurveExp], title: Optional[str] = "Precision-Recall Curve", layout: Optional[go.Layout] = None, **plot_kwargs: Any, @@ -309,8 +304,8 @@ def precision_recall_curve( Parameters ---------- - precision_recall_curve : Tuple[np.ndarray, np.ndarray, np.ndarray] - Tuple of (recalls, precisions, thresholds) + precision_recall_curve : PRCurve + Named tuple of (precisions, recalls, thresholds) title : str, optional Plot title, by default "Precision-Recall Curve" layout : go.Layout, optional @@ -324,8 +319,8 @@ def precision_recall_curve( The figure object.
""" - recalls = precision_recall_curve[1] - precisions = precision_recall_curve[0] + recalls = precision_recall_curve.recall + precisions = precision_recall_curve.precision if self.task_type == "binary": trace = line_plot( @@ -364,7 +359,7 @@ def precision_recall_curve( def precision_recall_curve_comparison( self, - precision_recall_curves: Dict[str, Tuple[npt.NDArray[np.float_], ...]], + precision_recall_curves: Dict[str, Union[PRCurve, PRCurveExp]], auprcs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None, @@ -378,7 +373,7 @@ def precision_recall_curve_comparison( ---------- precision_recall_curves : Dict[str, Tuple] Dictionary of precision-recall curves, where the key is \ - the group or subpopulation name and the value is a tuple \ + the group or subpopulation name and the value is a namedtuple \ of (recalls, precisions, thresholds) auprcs : Dict[str, Union[float, list, np.ndarray]], optional AUPRCs for each subpopulation or group specified by name, by default None @@ -408,8 +403,8 @@ def precision_recall_curve_comparison( name = f"{slice_name}" trace.append( line_plot( - x=slice_curve[1], - y=slice_curve[0], + x=slice_curve.recall, + y=slice_curve.precision, trace_name=name, **plot_kwargs, ), @@ -417,7 +412,9 @@ def precision_recall_curve_comparison( else: for slice_name, slice_curve in precision_recall_curves.items(): assert ( - len(slice_curve[0]) == len(slice_curve[1]) == self.class_num + len(slice_curve.precision) + == len(slice_curve.recall) + == self.class_num ), f"Recalls and precisions must be of length class_num for \ multiclass/multilabel tasks in slice {slice_name}" for i in range(self.class_num): @@ -432,8 +429,8 @@ def precision_recall_curve_comparison( name = f"{slice_name}: {self.class_names[i]}" trace.append( line_plot( - x=slice_curve[1][i], - y=slice_curve[0][i], + x=slice_curve.recall[i], + y=slice_curve.precision[i], trace_name=name, **plot_kwargs, ),