From ecec7a78f8cf2325df53077928bac7737eebac4b Mon Sep 17 00:00:00 2001 From: Amrit K Date: Mon, 4 Mar 2024 09:43:00 -0500 Subject: [PATCH 1/4] Use namedtuple to store curve results (ROC, PR) --- .../experimental/functional/__init__.py | 2 + .../functional/precision_recall_curve.py | 67 +-- .../metrics/experimental/functional/roc.py | 81 ++-- .../experimental/precision_recall_curve.py | 9 +- cyclops/evaluate/metrics/experimental/roc.py | 6 +- cyclops/models/calibrator.py | 394 ++++++++++++++++++ cyclops/report/plot/classification.py | 55 ++- 7 files changed, 507 insertions(+), 107 deletions(-) create mode 100644 cyclops/models/calibrator.py diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 1a2e5902b..fb8e19054 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -52,11 +52,13 @@ multilabel_tpr, ) from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + PRCurve, binary_precision_recall_curve, multiclass_precision_recall_curve, multilabel_precision_recall_curve, ) from cyclops.evaluate.metrics.experimental.functional.roc import ( + ROCCurve, binary_roc, multiclass_roc, multilabel_roc, diff --git a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py index d484a2b1b..7fe5dc968 100644 --- a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py @@ -1,6 +1,6 @@ """Functions for computing the precision and recall for different unique thresholds.""" from types import ModuleType -from typing import Any, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union import array_api_compat as apc import numpy as np @@ -28,6 +28,14 @@ ) +class PRCurve(NamedTuple): + """Named tuple with Precision-Recall curve (Precision, Recall and thresholds).""" + + precision: Union[Array, List[Array]] + recall: Union[Array, List[Array]] + thresholds: Union[Array, List[Array]] + + def _validate_thresholds(thresholds: Optional[Union[int, List[float], Array]]) -> None: """Validate the `thresholds` argument.""" if thresholds is not None and not ( @@ -352,14 +360,13 @@ def binary_precision_recall_curve( Returns ------- - precision : Array - The precision values for all unique thresholds. The shape of the array is + PRCurve + A named tuple that contains the following elements: + - `precision` values for all unique thresholds. The shape of the array is `(num_thresholds + 1,)`. - recall : Array - The recall values for all unique thresholds. The shape of the array is + - `recall` values for all unique thresholds. The shape of the array is `(num_thresholds + 1,)`. - thresholds : Array - The thresholds used for computing the precision and recall values, in + - `thresholds` used for computing the precision and recall values, in ascending order. The shape of the array is `(num_thresholds,)`. Raises @@ -688,7 +695,7 @@ def multiclass_precision_recall_curve( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, ignore_index: Optional[Union[int, Tuple[int]]] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> PRCurve: """Compute the precision and recall for all unique thresholds. Parameters @@ -730,18 +737,17 @@ def multiclass_precision_recall_curve( Returns ------- - precision : Array or List[Array] - The precision values for all unique thresholds. If `thresholds` is `"none"` + PRCurve + A named tuple that contains the following elements: + - `precision` values for all unique thresholds. If `thresholds` is `"none"` or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape `(num_thresholds + 1, num_classes)` is returned. - recall : Array or List[Array] - The recall values for all unique thresholds. If `thresholds` is `"none"` + - `recall` values for all unique thresholds. If `thresholds` is `"none"` or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape `(num_thresholds + 1, num_classes)` is returned. - thresholds : Array or List[Array] - The thresholds used for computing the precision and recall values, in + - `thresholds` used for computing the precision and recall values, in ascending order. If `thresholds` is `"none"` or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of shape `(num_thresholds,)` is returned. @@ -868,12 +874,13 @@ class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, average, xp=xp, ) - return _multiclass_precision_recall_curve_compute( + precision, recall, thresholds_ = _multiclass_precision_recall_curve_compute( state, num_classes, thresholds=thresholds, average=average, ) + return PRCurve(precision, recall, thresholds_) def _multilabel_precision_recall_curve_validate_args( @@ -1035,7 +1042,7 @@ def multilabel_precision_recall_curve( num_labels: int, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> PRCurve: """Compute the precision and recall for all unique thresholds. Parameters @@ -1067,21 +1074,20 @@ def multilabel_precision_recall_curve( Returns ------- - precision : Array or List[Array] - The precision values for all unique thresholds. If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape + PRCurve + A named tuple that contains the following elements: + - `precision` values for all unique thresholds. If `thresholds` is `"none"` + or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - recall : Array or List[Array] - The recall values for all unique thresholds. If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape + `(num_thresholds + 1, num_classes)` is returned. + - `recall` values for all unique thresholds. If `thresholds` is `"none"` + or `None`, a list for each class is returned with 1-D Arrays of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - thresholds : Array or List[Array] - The thresholds used for computing the precision and recall values, in - ascending order. If `thresholds` is `None`, a list for each label is - returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D - Array of shape `(num_thresholds,)` is returned. + `(num_thresholds + 1, num_classes)` is returned. + - `thresholds` used for computing the precision and recall values, in + ascending order. If `thresholds` is `"none"` or `None`, a list for each + class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, + a 1-D Array of shape `(num_thresholds,)` is returned. Raises ------ @@ -1193,9 +1199,10 @@ def multilabel_precision_recall_curve( thresholds, xp=xp, ) - return _multilabel_precision_recall_curve_compute( + precision, recall, thresholds_ = _multilabel_precision_recall_curve_compute( state, num_labels, thresholds, ignore_index, ) + return PRCurve(precision, recall, thresholds_) diff --git a/cyclops/evaluate/metrics/experimental/functional/roc.py b/cyclops/evaluate/metrics/experimental/functional/roc.py index 736696195..0034faaa8 100644 --- a/cyclops/evaluate/metrics/experimental/functional/roc.py +++ b/cyclops/evaluate/metrics/experimental/functional/roc.py @@ -1,6 +1,6 @@ """Functions for computing Receiver Operating Characteristic (ROC) curves.""" import warnings -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, NamedTuple, Optional, Tuple, Union import array_api_compat as apc @@ -28,6 +28,14 @@ from cyclops.evaluate.metrics.experimental.utils.types import Array +class ROCCurve(NamedTuple): + """Named tuple to store ROC curve (FPR, TPR and thresholds).""" + + fpr: Union[Array, List[Array]] + tpr: Union[Array, List[Array]] + thresholds: Union[Array, List[Array]] + + def _binary_roc_compute( state: Union[Array, Tuple[Array, Array]], thresholds: Optional[Array], @@ -91,7 +99,7 @@ def binary_roc( preds: Array, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, -) -> Tuple[Array, Array, Array]: +) -> ROCCurve: """Compute the receiver operating characteristic (ROC) curve for binary tasks. Parameters @@ -120,15 +128,11 @@ def binary_roc( Returns ------- - fpr : Array - The false positive rates for all unique thresholds. The shape of the array is - `(num_thresholds + 1,)`. - tpr : Array - The true positive rates for all unique thresholds. The shape of the array is - `(num_thresholds + 1,)`. - thresholds : Array - The thresholds used for computing the ROC curve values, in descending order. - The shape of the array is `(num_thresholds,)`. + ROCCurve + A named tuple containing the false positive rate (FPR), true positive rate + (TPR) and thresholds. The FPR and TPR are arrays of of shape + `(num_thresholds + 1,)` and the thresholds are an array of shape + `(num_thresholds,)`. Raises ------ @@ -209,7 +213,8 @@ def binary_roc( xp=xp, ) state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) - return _binary_roc_compute(state, thresholds) + fpr, tpr, thresh = _binary_roc_compute(state, thresholds) + return ROCCurve(fpr, tpr, thresh) def _multiclass_roc_compute( @@ -277,7 +282,7 @@ def multiclass_roc( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, ignore_index: Optional[Union[int, Tuple[int]]] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> ROCCurve: """Compute the receiver operating characteristic (ROC) curve for multiclass tasks. Parameters @@ -318,19 +323,13 @@ def multiclass_roc( Returns ------- - fpr : Array or List[Array] - The false positive rates for all unique thresholds. If `thresholds` is `"none"` - or `None`, a list for each class is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + ROCCurve + A named tuple that contains the false positive rate, true positive rate, + and the thresholds used for computing the ROC curve. If `thresholds` is `"none"` + or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays + of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape `(num_thresholds + 1, num_classes)` is returned. - tpr : Array or List[Array] - The true positive rates for all unique thresholds. If `thresholds` is `"none"` - or `None`, a list for each class is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_classes)` is returned. - thresholds : Array or List[Array] - The thresholds used for computing the ROC curve values, in descending order. - If `thresholds` is `"none"` or `None`, a list for each class is returned + Similarly, a list of thresholds for each class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of shape `(num_thresholds,)` is returned. @@ -455,12 +454,13 @@ def multiclass_roc( average, xp=xp, ) - return _multiclass_roc_compute( + fpr_, tpr_, thresholds_ = _multiclass_roc_compute( state, num_classes, thresholds=thresholds, average=average, ) + return ROCCurve(fpr_, tpr_, thresholds_) def _multilabel_roc_compute( @@ -504,7 +504,7 @@ def multilabel_roc( num_labels: int, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: +) -> ROCCurve: """Compute the receiver operating characteristic (ROC) curve for multilabel tasks. Parameters @@ -535,21 +535,15 @@ def multilabel_roc( Returns ------- - fpr : Array or List[Array] - The false positive rates for all unique thresholds. If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - tpr : Array or List[Array] - The true positive rates for all unique thresholds. If `thresholds` is `None`, - a list for each label is returned with 1-D Arrays of shape - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape - `(num_thresholds + 1, num_labels)` is returned. - thresholds : Array or List[Array] - The thresholds used for computing the ROC curve values, in - descending order. If `thresholds` is `None`, a list for each label is - returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D - Array of shape `(num_thresholds,)` is returned. + ROCCurve + A named tuple that contains the false positive rate, true positive rate, + and the thresholds used for computing the ROC curve. If `thresholds` is `"none"` + or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays + of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_classes)` is returned. + Similarly, a list of thresholds for each class is returned + with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of + shape `(num_thresholds,)` is returned. Raises ------ @@ -660,9 +654,10 @@ def multilabel_roc( thresholds, xp=xp, ) - return _multilabel_roc_compute( + fpr_, tpr_, thresholds_ = _multilabel_roc_compute( state, num_labels, thresholds, ignore_index, ) + return ROCCurve(fpr_, tpr_, thresholds_) diff --git a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py index 36181a0eb..b6a89ca50 100644 --- a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py @@ -5,6 +5,7 @@ import array_api_compat as apc from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + PRCurve, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format_arrays, _binary_precision_recall_curve_update, @@ -140,14 +141,18 @@ def _update_state(self, target: Array, preds: Array) -> None: self.target.append(state[0]) # type: ignore[attr-defined] self.preds.append(state[1]) # type: ignore[attr-defined] - def _compute_metric(self) -> Tuple[Array, Array, Array]: + def _compute_metric(self) -> PRCurve: """Compute the metric.""" state = ( (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] if self.thresholds is None else self.confmat # type: ignore[attr-defined] ) - return _binary_precision_recall_curve_compute(state, self.thresholds) # type: ignore[arg-type] + precision, recall, thresholds = _binary_precision_recall_curve_compute( + state, + self.thresholds, # type: ignore + ) + return PRCurve(precision, recall, thresholds) class MulticlassPrecisionRecallCurve( diff --git a/cyclops/evaluate/metrics/experimental/roc.py b/cyclops/evaluate/metrics/experimental/roc.py index c3b6de9a5..9a0c3f3a0 100644 --- a/cyclops/evaluate/metrics/experimental/roc.py +++ b/cyclops/evaluate/metrics/experimental/roc.py @@ -2,6 +2,7 @@ from typing import List, Tuple, Union from cyclops.evaluate.metrics.experimental.functional.roc import ( + ROCCurve, _binary_roc_compute, _multiclass_roc_compute, _multilabel_roc_compute, @@ -55,13 +56,14 @@ class BinaryROC(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"): name: str = "ROC Curve" - def _compute_metric(self) -> Tuple[Array, Array, Array]: + def _compute_metric(self) -> ROCCurve: # type: ignore state = ( (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] if self.thresholds is None else self.confmat # type: ignore[attr-defined] ) - return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] + fpr, tpr, thresholds = _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] + return ROCCurve(fpr, tpr, thresholds) class MulticlassROC( diff --git a/cyclops/models/calibrator.py b/cyclops/models/calibrator.py new file mode 100644 index 000000000..e5e08b9bc --- /dev/null +++ b/cyclops/models/calibrator.py @@ -0,0 +1,394 @@ +"""Calibrate model scores into probabilities.""" + +import warnings +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Tuple, Type + +import numpy as np +import numpy.typing as npt +import pandas as pd +from sklearn.isotonic import IsotonicRegression +from sklearn.metrics import roc_auc_score +from sklearn.model_selection import StratifiedShuffleSplit + + +class Calibrator(ABC): + """Class that is able to calibrate ``y_pred_proba`` scores into probabilities.""" + + @abstractmethod + def fit( + self, + y_pred_proba: npt.NDArray[np.float_], + y_true: npt.NDArray[np.int_], + *args: Any, + **kwargs: Any, + ) -> "Calibrator": + """Fits the calibrator using a reference data set. + + Parameters + ---------- + y_pred_proba: numpy.ndarray + Vector of continuous reference scores/probabilities. + Has to be the same shape as y_true. + y_true : numpy.ndarray + Vector with reference binary targets - 0 or 1. Shape (n,). + + Returns + ------- + self + + """ + raise NotImplementedError + + @abstractmethod + def calibrate( + self, y_pred_proba: npt.NDArray[np.float_], *args: Any, **kwargs: Any + ) -> npt.NDArray[np.float_]: + """Perform calibration of prediction scores. + + Parameters + ---------- + y_pred_proba: numpy.ndarray + Vector of continuous scores/probabilities. + Has to be the same shape as y_true. + + """ + raise NotImplementedError + + +class CalibratorFactory: + """Factory class to aid in construction of Calibrators.""" + + _registry: Dict[str, Type[Calibrator]] = {} + + @classmethod + def register_calibrator(cls, key: str, calibrator: Type[Calibrator]) -> None: + """Register a new calibrator to the index. + + This index associates a certain key with a function that can be used to + construct a new Calibrator instance. + + Parameters + ---------- + key: str + The key used to retrieve a Calibrator. When providing a key that is already + in the index, the value will be overwritten. + calibrator: Type[Calibrator] + A function that - given a ``**kwargs`` argument - create a new instance + of a Calibrator subclass. + + Examples + -------- + >>> CalibratorFactory.register_calibrator("isotonic", IsotonicCalibrator) + + """ + cls._registry[key] = calibrator + + @classmethod + def register(cls, key: str) -> Callable[[Type[Calibrator]], Type[Calibrator]]: + """Register a new calibrator to the index.""" + + def inner_wrapper(wrapped_class: Type[Calibrator]) -> Type[Calibrator]: + if key in cls._registry: + warnings.warn( + f"re-registering calibrator with key '{key}'", stacklevel=2 + ) + + cls._registry[key] = wrapped_class + return wrapped_class + + return inner_wrapper + + @classmethod + def create(cls, key: str = "isotonic", **kwargs: Any) -> Calibrator: + """Create a new Calibrator given a key value and optional keyword args. + + If the provided key equals ``None``, then a new instance of the default + Calibrator (IsotonicCalibrator) will be returned. + + If a non-existent key is provided an ``InvalidArgumentsException`` is raised. + + Parameters + ---------- + key : str, default='isotonic' + The key used to retrieve a Calibrator. When providing a key that is + already in the index, the value will be overwritten. + kwargs : dict + Optional keyword arguments that will be passed along to the function + associated with the key. It can then use these arguments during the + creation of a new Calibrator instance. + + Returns + ------- + calibrator: Calibrator + A new instance of a specific Calibrator subclass. + + Examples + -------- + >>> calibrator = CalibratorFactory.create("isotonic", **{"foo": "bar"}) + + """ + if key not in cls._registry: + raise ValueError( + f"calibrator '{key}' unknown. " + f"Please provide one of the following: {cls._registry.keys()}" + ) + + calibrator_class = cls._registry.get(key) + assert calibrator_class + + return calibrator_class(**kwargs) + + +@CalibratorFactory.register("isotonic") +class IsotonicCalibrator(Calibrator): + """Calibrates using IsotonicRegression model.""" + + def __init__(self) -> None: + """Create a new IsotonicCalibrator.""" + regressor = IsotonicRegression(out_of_bounds="clip", increasing=True) + self._regressor = regressor + + def fit( + self, + y_pred_proba: npt.NDArray[np.float_], + y_true: npt.NDArray[np.int_], + *args: Any, + **kwargs: Any, + ) -> Any: + """Fits the calibrator using a reference data set. + + Parameters + ---------- + y_pred_proba: numpy.ndarray + Vector of continuous reference scores/probabilities. Has to be the same + shape as y_true. + y_true : numpy.ndarray + Vector with reference binary targets - 0 or 1. Shape (n,). + + Returns + ------- + self: IsotonicCalibrator + The instance itself. + + """ + return self._regressor.fit(y_pred_proba, y_true) + + def calibrate( + self, y_pred_proba: npt.NDArray[np.float_], *args: Any, **kwargs: Any + ) -> Any: + """Perform calibration of prediction scores. + + Parameters + ---------- + y_pred_proba: numpy.ndarray + Vector of continuous scores/probabilities. + Has to be the same shape as ``y_true``. + + Returns + ------- + calibrated_scores: numpy.ndarray + Vector of calibrated scores/probabilities. + + """ + return self._regressor.predict(y_pred_proba) + + +class NoopCalibrator(Calibrator): + """A Calibrator subclass that simply returns the inputs unaltered.""" + + def fit( + self, + y_pred_proba: npt.NDArray[np.float_], + y_true: npt.NDArray[np.int_], + *args: Any, + **kwargs: Any, + ) -> Any: + """Fit nothing and just return the calibrator.""" + return self + + def calibrate( + self, + y_pred_proba: npt.NDArray[np.float_], + *args: Any, + **kwargs: Any, + ) -> npt.NDArray[np.float_]: + """Calibrate nothing and just return the original ``y_pred_proba`` inputs.""" + return np.asarray(y_pred_proba) + + +def _get_bin_index_edges(vector_length: int, bin_count: int) -> List[Tuple[int, int]]: + """Generate edges of bins for specified vector length and number of bins required. + + Parameters + ---------- + vector_length : int + The length of the vector that will be binned using bins. + bin_count : int + Number of bins and bin edges that will be generated. + + Returns + ------- + bin_index_edges : list of tuples with bin edges (indexes) + See the example below for best intuition. + + Examples + -------- + >>> _get_bin_edge_indexes(20, 4) + [(0, 5), (5, 10), (10, 15), (15, 20)] + + """ + if vector_length <= 2 * bin_count: + bin_count = vector_length // 2 + if bin_count < 2: + raise ValueError( + "cannot split into minimum of 2 bins. Current sample size " + f"is {vector_length}, please increase sample size. " + ) + + bin_width = vector_length // bin_count + bin_edges = np.asarray(range(0, vector_length + 1, bin_width)) + bin_edges[-1] = vector_length + bin_index_left = bin_edges[:-1] + bin_index_right = bin_edges[1:] + bin_index_edges = [(x, y) for x, y in zip(bin_index_left, bin_index_right)] # noqa: C416 + + return bin_index_edges # noqa: RET504 + + +def _calculate_expected_calibration_error( + y_true: npt.NDArray[np.int_], + y_pred_proba: npt.NDArray[np.float_], + bin_index_edges: List[Tuple[int, int]], +) -> float: + terms = [] + + y_pred_proba, y_true = np.asarray(y_pred_proba), np.asarray(y_true) + + # sort both y_pred_proba and y_true, just to make sure + sort_index = y_pred_proba.argsort() + y_pred_proba = y_pred_proba[sort_index] + y_true = y_true[sort_index] + + for left_edge, right_edge in bin_index_edges: + bin_proba = y_pred_proba[left_edge:right_edge] + bin_true = y_true[left_edge:right_edge] + mean_bin_proba = np.mean(bin_proba) + mean_bin_true = np.mean(bin_true) + weight = len(bin_proba) / len(y_pred_proba) + terms.append(weight * abs(mean_bin_proba - mean_bin_true)) + + expected_calibration_error = float(np.sum(terms)) + + return expected_calibration_error # noqa: RET504 + + +def needs_calibration( + y_true: npt.NDArray[np.int_], + y_pred_proba: npt.NDArray[np.float_], + calibrator: Calibrator, + bin_count: int = 10, + split_count: int = 10, +) -> bool: + """Return whether prediction scores benefits from additional calibration or not. + + Performs probability calibration in cross validation loop. For each fold a + difference of Expected Calibration Error (ECE) between non calibrated and calibrated + probabilites is calculated. If in any of the folds the difference is lower than zero + (i.e. ECE of calibrated probability is larger than that of non-calibrated) + returns ``False``. Otherwise - returns ``True``. + + Parameters + ---------- + calibrator : Calibrator + The Calibrator to use during testing. + y_true : np.array + Series with reference binary targets - ``0`` or ``1``. Shape ``(n,)``. + y_pred_proba : np.array + Series or DataFrame of continuous reference scores/probabilities. + Has to be the same shape as ``y_true``. + bin_count : int + Desired amount of bins to calculate ECE on. + split_count : int + Desired number of splits to make, i.e. number of times to evaluate calibration. + + Returns + ------- + needs_calibration: bool + ``True`` when the scores benefit from calibration, ``False`` otherwise. + + Examples + -------- + >>> import numpy as np + >>> from cyclops.estimate.calibrator import IsotonicCalibrator, needs_calibration + >>> np.random.seed(1) + >>> y_true = np.random.binomial(1, 0.5, 10) + >>> y_pred_proba = np.linspace(0, 1, 10) + >>> calibrator = IsotonicCalibrator() + >>> needs_calibration(y_true, y_pred_proba, calibrator, bin_count=2, split_count=3) + True + + """ + if y_true.dtype == "object": + if pd.isnull(y_true).any(): + raise ValueError( + "target values contain NaN. " + "Please ensure reference targets do not contain NaN values." + ) + elif np.isnan(y_true).any(): + raise ValueError( + "target values contain NaN. " + "Please ensure reference targets do not contain NaN values." + ) + + if np.isnan(y_pred_proba).any(): + raise ValueError( + "predicted probabilities contain NaN. " + "Please ensure reference predicted probabilities do not contain NaN values." + ) + + # Check if we have a single class in y_true. This would crash the AUROC check below. + # If we do only have a single class in y_true, no calibration will be required. + if len(np.unique(y_true)) == 1: + return False + + if roc_auc_score(y_true, y_pred_proba, multi_class="ovr") > 0.999: + return False + + sss = StratifiedShuffleSplit(n_splits=split_count, test_size=0.1, random_state=42) + + list_y_true_test = [] + list_y_pred_proba_test = [] + list_calibrated_y_pred_proba_test = [] + + for train, test in sss.split(y_pred_proba, y_true): + if isinstance(y_pred_proba, pd.DataFrame): + y_pred_proba_train, y_true_train = ( + y_pred_proba.iloc[train, :], + y_true[train], + ) + y_pred_proba_test, y_true_test = y_pred_proba.iloc[test, :], y_true[test] + else: + y_pred_proba_train, y_true_train = y_pred_proba[train], y_true[train] + y_pred_proba_test, y_true_test = y_pred_proba[test], y_true[test] + + calibrator.fit(y_pred_proba_train, y_true_train) + calibrated_y_pred_proba_test = calibrator.calibrate(y_pred_proba_test) + + list_y_true_test.append(y_true_test) + list_y_pred_proba_test.append(y_pred_proba_test) + list_calibrated_y_pred_proba_test.append(calibrated_y_pred_proba_test) + + vec_y_true_test = np.concatenate(list_y_true_test) + vec_y_pred_proba_test = np.concatenate(list_y_pred_proba_test) + vec_calibrated_y_pred_proba_test = np.concatenate(list_calibrated_y_pred_proba_test) + + bin_index_edges = _get_bin_index_edges(len(vec_y_pred_proba_test), bin_count) + ece_before_calibration = _calculate_expected_calibration_error( + vec_y_true_test, vec_y_pred_proba_test, bin_index_edges + ) + ece_after_calibration = _calculate_expected_calibration_error( + vec_y_true_test, vec_calibrated_y_pred_proba_test, bin_index_edges + ) + + return ece_before_calibration > ece_after_calibration diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index e9ffaa0a8..69aa612d8 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -1,13 +1,14 @@ """Classification plotter.""" from collections import defaultdict -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Union import numpy as np import numpy.typing as npt import plotly.graph_objs as go from plotly.subplots import make_subplots +from cyclops.evaluate.metrics.experimental.functional import PRCurve, ROCCurve from cyclops.report.plot.base import Plotter from cyclops.report.plot.utils import ( bar_plot, @@ -92,11 +93,7 @@ def _set_class_names(self, class_names: List[str]) -> None: def roc_curve( self, - roc_curve: Tuple[ - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - ], + roc_curve: ROCCurve, auroc: Optional[Union[float, List[float], npt.NDArray[np.float_]]] = None, title: Optional[str] = "ROC Curve", layout: Optional[go.Layout] = None, @@ -106,8 +103,8 @@ def roc_curve( Parameters ---------- - roc_curve : Tuple[np.ndarray, np.ndarray, np.ndarray] - Tuple of (fprs, tprs, thresholds) + roc_curve : ROCCurve + Named tuple of (fprs, tprs, thresholds) auroc : Union[float, list, np.ndarray], optional AUROCs, by default None title: str, optional @@ -123,8 +120,8 @@ def roc_curve( The figure object. """ - fprs = roc_curve[0] - tprs = roc_curve[1] + fprs = roc_curve.fpr + tprs = roc_curve.tpr trace = [] if self.task_type == "binary": @@ -191,7 +188,7 @@ def roc_curve( def roc_curve_comparison( self, - roc_curves: Dict[str, Tuple[npt.NDArray[np.float_], ...]], + roc_curves: Dict[str, ROCCurve], aurocs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None, @@ -205,7 +202,7 @@ def roc_curve_comparison( ---------- roc_curves : Dict[str, Tuple] Dictionary of roc curves, with keys being the name of the subpopulation - or group and values being the roc curve tuple (fprs, tprs, thresholds) + or group and values being the roc curve namedtuples (fprs, tprs, thresholds) aurocs : Dict[str, Union[float, list, np.ndarray]], optional AUROCs for each subpopulation or group specified by name, by default None title: str, optional @@ -232,8 +229,8 @@ def roc_curve_comparison( name = f"{slice_name} (AUC = {aurocs[slice_name]:.2f})" else: name = slice_name - fprs = slice_curve[0] - tprs = slice_curve[1] + fprs = slice_curve.fpr + tprs = slice_curve.tpr trace.append( line_plot( x=fprs, @@ -296,11 +293,7 @@ def roc_curve_comparison( def precision_recall_curve( self, - precision_recall_curve: Tuple[ - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - ], + precision_recall_curve: PRCurve, title: Optional[str] = "Precision-Recall Curve", layout: Optional[go.Layout] = None, **plot_kwargs: Any, @@ -309,8 +302,8 @@ def precision_recall_curve( Parameters ---------- - precision_recall_curve : Tuple[np.ndarray, np.ndarray, np.ndarray] - Tuple of (recalls, precisions, thresholds) + precision_recall_curve : PRcurve + Named tuple of (recalls, precisions, thresholds) title : str, optional Plot title, by default "Precision-Recall Curve" layout : go.Layout, optional @@ -324,8 +317,8 @@ def precision_recall_curve( The figure object. """ - recalls = precision_recall_curve[1] - precisions = precision_recall_curve[0] + recalls = precision_recall_curve.recall + precisions = precision_recall_curve.precision if self.task_type == "binary": trace = line_plot( @@ -364,7 +357,7 @@ def precision_recall_curve( def precision_recall_curve_comparison( self, - precision_recall_curves: Dict[str, Tuple[npt.NDArray[np.float_], ...]], + precision_recall_curves: Dict[str, PRCurve], auprcs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None, @@ -378,7 +371,7 @@ def precision_recall_curve_comparison( ---------- precision_recall_curves : Dict[str, Tuple] Dictionary of precision-recall curves, where the key is \ - the group or subpopulation name and the value is a tuple \ + the group or subpopulation name and the value is a namedtuple \ of (recalls, precisions, thresholds) auprcs : Dict[str, Union[float, list, np.ndarray]], optional AUPRCs for each subpopulation or group specified by name, by default None @@ -408,8 +401,8 @@ def precision_recall_curve_comparison( name = f"{slice_name}" trace.append( line_plot( - x=slice_curve[1], - y=slice_curve[0], + x=slice_curve.recall, + y=slice_curve.precision, trace_name=name, **plot_kwargs, ), @@ -417,7 +410,9 @@ def precision_recall_curve_comparison( else: for slice_name, slice_curve in precision_recall_curves.items(): assert ( - len(slice_curve[0]) == len(slice_curve[1]) == self.class_num + len(slice_curve.precision) + == len(slice_curve.recall) + == self.class_num ), f"Recalls and precisions must be of length class_num for \ multiclass/multilabel tasks in slice {slice_name}" for i in range(self.class_num): @@ -432,8 +427,8 @@ def precision_recall_curve_comparison( name = f"{slice_name}: {self.class_names[i]}" trace.append( line_plot( - x=slice_curve[1][i], - y=slice_curve[0][i], + x=slice_curve.recall[i], + y=slice_curve.precision[i], trace_name=name, **plot_kwargs, ), From 6a09df0781b9c91013c60a22c77b05ebb62eebc2 Mon Sep 17 00:00:00 2001 From: Amrit K Date: Mon, 4 Mar 2024 09:48:08 -0500 Subject: [PATCH 2/4] Remove calibrator which was added by mistake --- cyclops/models/calibrator.py | 394 ----------------------------------- 1 file changed, 394 deletions(-) delete mode 100644 cyclops/models/calibrator.py diff --git a/cyclops/models/calibrator.py b/cyclops/models/calibrator.py deleted file mode 100644 index e5e08b9bc..000000000 --- a/cyclops/models/calibrator.py +++ /dev/null @@ -1,394 +0,0 @@ -"""Calibrate model scores into probabilities.""" - -import warnings -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Tuple, Type - -import numpy as np -import numpy.typing as npt -import pandas as pd -from sklearn.isotonic import IsotonicRegression -from sklearn.metrics import roc_auc_score -from sklearn.model_selection import StratifiedShuffleSplit - - -class Calibrator(ABC): - """Class that is able to calibrate ``y_pred_proba`` scores into probabilities.""" - - @abstractmethod - def fit( - self, - y_pred_proba: npt.NDArray[np.float_], - y_true: npt.NDArray[np.int_], - *args: Any, - **kwargs: Any, - ) -> "Calibrator": - """Fits the calibrator using a reference data set. - - Parameters - ---------- - y_pred_proba: numpy.ndarray - Vector of continuous reference scores/probabilities. - Has to be the same shape as y_true. - y_true : numpy.ndarray - Vector with reference binary targets - 0 or 1. Shape (n,). - - Returns - ------- - self - - """ - raise NotImplementedError - - @abstractmethod - def calibrate( - self, y_pred_proba: npt.NDArray[np.float_], *args: Any, **kwargs: Any - ) -> npt.NDArray[np.float_]: - """Perform calibration of prediction scores. - - Parameters - ---------- - y_pred_proba: numpy.ndarray - Vector of continuous scores/probabilities. - Has to be the same shape as y_true. - - """ - raise NotImplementedError - - -class CalibratorFactory: - """Factory class to aid in construction of Calibrators.""" - - _registry: Dict[str, Type[Calibrator]] = {} - - @classmethod - def register_calibrator(cls, key: str, calibrator: Type[Calibrator]) -> None: - """Register a new calibrator to the index. - - This index associates a certain key with a function that can be used to - construct a new Calibrator instance. - - Parameters - ---------- - key: str - The key used to retrieve a Calibrator. When providing a key that is already - in the index, the value will be overwritten. - calibrator: Type[Calibrator] - A function that - given a ``**kwargs`` argument - create a new instance - of a Calibrator subclass. - - Examples - -------- - >>> CalibratorFactory.register_calibrator("isotonic", IsotonicCalibrator) - - """ - cls._registry[key] = calibrator - - @classmethod - def register(cls, key: str) -> Callable[[Type[Calibrator]], Type[Calibrator]]: - """Register a new calibrator to the index.""" - - def inner_wrapper(wrapped_class: Type[Calibrator]) -> Type[Calibrator]: - if key in cls._registry: - warnings.warn( - f"re-registering calibrator with key '{key}'", stacklevel=2 - ) - - cls._registry[key] = wrapped_class - return wrapped_class - - return inner_wrapper - - @classmethod - def create(cls, key: str = "isotonic", **kwargs: Any) -> Calibrator: - """Create a new Calibrator given a key value and optional keyword args. - - If the provided key equals ``None``, then a new instance of the default - Calibrator (IsotonicCalibrator) will be returned. - - If a non-existent key is provided an ``InvalidArgumentsException`` is raised. - - Parameters - ---------- - key : str, default='isotonic' - The key used to retrieve a Calibrator. When providing a key that is - already in the index, the value will be overwritten. - kwargs : dict - Optional keyword arguments that will be passed along to the function - associated with the key. It can then use these arguments during the - creation of a new Calibrator instance. - - Returns - ------- - calibrator: Calibrator - A new instance of a specific Calibrator subclass. - - Examples - -------- - >>> calibrator = CalibratorFactory.create("isotonic", **{"foo": "bar"}) - - """ - if key not in cls._registry: - raise ValueError( - f"calibrator '{key}' unknown. " - f"Please provide one of the following: {cls._registry.keys()}" - ) - - calibrator_class = cls._registry.get(key) - assert calibrator_class - - return calibrator_class(**kwargs) - - -@CalibratorFactory.register("isotonic") -class IsotonicCalibrator(Calibrator): - """Calibrates using IsotonicRegression model.""" - - def __init__(self) -> None: - """Create a new IsotonicCalibrator.""" - regressor = IsotonicRegression(out_of_bounds="clip", increasing=True) - self._regressor = regressor - - def fit( - self, - y_pred_proba: npt.NDArray[np.float_], - y_true: npt.NDArray[np.int_], - *args: Any, - **kwargs: Any, - ) -> Any: - """Fits the calibrator using a reference data set. - - Parameters - ---------- - y_pred_proba: numpy.ndarray - Vector of continuous reference scores/probabilities. Has to be the same - shape as y_true. - y_true : numpy.ndarray - Vector with reference binary targets - 0 or 1. Shape (n,). - - Returns - ------- - self: IsotonicCalibrator - The instance itself. - - """ - return self._regressor.fit(y_pred_proba, y_true) - - def calibrate( - self, y_pred_proba: npt.NDArray[np.float_], *args: Any, **kwargs: Any - ) -> Any: - """Perform calibration of prediction scores. - - Parameters - ---------- - y_pred_proba: numpy.ndarray - Vector of continuous scores/probabilities. - Has to be the same shape as ``y_true``. - - Returns - ------- - calibrated_scores: numpy.ndarray - Vector of calibrated scores/probabilities. - - """ - return self._regressor.predict(y_pred_proba) - - -class NoopCalibrator(Calibrator): - """A Calibrator subclass that simply returns the inputs unaltered.""" - - def fit( - self, - y_pred_proba: npt.NDArray[np.float_], - y_true: npt.NDArray[np.int_], - *args: Any, - **kwargs: Any, - ) -> Any: - """Fit nothing and just return the calibrator.""" - return self - - def calibrate( - self, - y_pred_proba: npt.NDArray[np.float_], - *args: Any, - **kwargs: Any, - ) -> npt.NDArray[np.float_]: - """Calibrate nothing and just return the original ``y_pred_proba`` inputs.""" - return np.asarray(y_pred_proba) - - -def _get_bin_index_edges(vector_length: int, bin_count: int) -> List[Tuple[int, int]]: - """Generate edges of bins for specified vector length and number of bins required. - - Parameters - ---------- - vector_length : int - The length of the vector that will be binned using bins. - bin_count : int - Number of bins and bin edges that will be generated. - - Returns - ------- - bin_index_edges : list of tuples with bin edges (indexes) - See the example below for best intuition. - - Examples - -------- - >>> _get_bin_edge_indexes(20, 4) - [(0, 5), (5, 10), (10, 15), (15, 20)] - - """ - if vector_length <= 2 * bin_count: - bin_count = vector_length // 2 - if bin_count < 2: - raise ValueError( - "cannot split into minimum of 2 bins. Current sample size " - f"is {vector_length}, please increase sample size. " - ) - - bin_width = vector_length // bin_count - bin_edges = np.asarray(range(0, vector_length + 1, bin_width)) - bin_edges[-1] = vector_length - bin_index_left = bin_edges[:-1] - bin_index_right = bin_edges[1:] - bin_index_edges = [(x, y) for x, y in zip(bin_index_left, bin_index_right)] # noqa: C416 - - return bin_index_edges # noqa: RET504 - - -def _calculate_expected_calibration_error( - y_true: npt.NDArray[np.int_], - y_pred_proba: npt.NDArray[np.float_], - bin_index_edges: List[Tuple[int, int]], -) -> float: - terms = [] - - y_pred_proba, y_true = np.asarray(y_pred_proba), np.asarray(y_true) - - # sort both y_pred_proba and y_true, just to make sure - sort_index = y_pred_proba.argsort() - y_pred_proba = y_pred_proba[sort_index] - y_true = y_true[sort_index] - - for left_edge, right_edge in bin_index_edges: - bin_proba = y_pred_proba[left_edge:right_edge] - bin_true = y_true[left_edge:right_edge] - mean_bin_proba = np.mean(bin_proba) - mean_bin_true = np.mean(bin_true) - weight = len(bin_proba) / len(y_pred_proba) - terms.append(weight * abs(mean_bin_proba - mean_bin_true)) - - expected_calibration_error = float(np.sum(terms)) - - return expected_calibration_error # noqa: RET504 - - -def needs_calibration( - y_true: npt.NDArray[np.int_], - y_pred_proba: npt.NDArray[np.float_], - calibrator: Calibrator, - bin_count: int = 10, - split_count: int = 10, -) -> bool: - """Return whether prediction scores benefits from additional calibration or not. - - Performs probability calibration in cross validation loop. For each fold a - difference of Expected Calibration Error (ECE) between non calibrated and calibrated - probabilites is calculated. If in any of the folds the difference is lower than zero - (i.e. ECE of calibrated probability is larger than that of non-calibrated) - returns ``False``. Otherwise - returns ``True``. - - Parameters - ---------- - calibrator : Calibrator - The Calibrator to use during testing. - y_true : np.array - Series with reference binary targets - ``0`` or ``1``. Shape ``(n,)``. - y_pred_proba : np.array - Series or DataFrame of continuous reference scores/probabilities. - Has to be the same shape as ``y_true``. - bin_count : int - Desired amount of bins to calculate ECE on. - split_count : int - Desired number of splits to make, i.e. number of times to evaluate calibration. - - Returns - ------- - needs_calibration: bool - ``True`` when the scores benefit from calibration, ``False`` otherwise. - - Examples - -------- - >>> import numpy as np - >>> from cyclops.estimate.calibrator import IsotonicCalibrator, needs_calibration - >>> np.random.seed(1) - >>> y_true = np.random.binomial(1, 0.5, 10) - >>> y_pred_proba = np.linspace(0, 1, 10) - >>> calibrator = IsotonicCalibrator() - >>> needs_calibration(y_true, y_pred_proba, calibrator, bin_count=2, split_count=3) - True - - """ - if y_true.dtype == "object": - if pd.isnull(y_true).any(): - raise ValueError( - "target values contain NaN. " - "Please ensure reference targets do not contain NaN values." - ) - elif np.isnan(y_true).any(): - raise ValueError( - "target values contain NaN. " - "Please ensure reference targets do not contain NaN values." - ) - - if np.isnan(y_pred_proba).any(): - raise ValueError( - "predicted probabilities contain NaN. " - "Please ensure reference predicted probabilities do not contain NaN values." - ) - - # Check if we have a single class in y_true. This would crash the AUROC check below. - # If we do only have a single class in y_true, no calibration will be required. - if len(np.unique(y_true)) == 1: - return False - - if roc_auc_score(y_true, y_pred_proba, multi_class="ovr") > 0.999: - return False - - sss = StratifiedShuffleSplit(n_splits=split_count, test_size=0.1, random_state=42) - - list_y_true_test = [] - list_y_pred_proba_test = [] - list_calibrated_y_pred_proba_test = [] - - for train, test in sss.split(y_pred_proba, y_true): - if isinstance(y_pred_proba, pd.DataFrame): - y_pred_proba_train, y_true_train = ( - y_pred_proba.iloc[train, :], - y_true[train], - ) - y_pred_proba_test, y_true_test = y_pred_proba.iloc[test, :], y_true[test] - else: - y_pred_proba_train, y_true_train = y_pred_proba[train], y_true[train] - y_pred_proba_test, y_true_test = y_pred_proba[test], y_true[test] - - calibrator.fit(y_pred_proba_train, y_true_train) - calibrated_y_pred_proba_test = calibrator.calibrate(y_pred_proba_test) - - list_y_true_test.append(y_true_test) - list_y_pred_proba_test.append(y_pred_proba_test) - list_calibrated_y_pred_proba_test.append(calibrated_y_pred_proba_test) - - vec_y_true_test = np.concatenate(list_y_true_test) - vec_y_pred_proba_test = np.concatenate(list_y_pred_proba_test) - vec_calibrated_y_pred_proba_test = np.concatenate(list_calibrated_y_pred_proba_test) - - bin_index_edges = _get_bin_index_edges(len(vec_y_pred_proba_test), bin_count) - ece_before_calibration = _calculate_expected_calibration_error( - vec_y_true_test, vec_y_pred_proba_test, bin_index_edges - ) - ece_after_calibration = _calculate_expected_calibration_error( - vec_y_true_test, vec_calibrated_y_pred_proba_test, bin_index_edges - ) - - return ece_before_calibration > ece_after_calibration From ba5f70fcb9338d173680b386620b676d65a38ef9 Mon Sep 17 00:00:00 2001 From: Amrit K Date: Mon, 4 Mar 2024 10:59:12 -0500 Subject: [PATCH 3/4] Fix docstring examples, typo --- .../evaluate/metrics/experimental/functional/roc.py | 2 +- .../metrics/experimental/precision_recall_curve.py | 12 ++++++------ cyclops/evaluate/metrics/experimental/roc.py | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/cyclops/evaluate/metrics/experimental/functional/roc.py b/cyclops/evaluate/metrics/experimental/functional/roc.py index 0034faaa8..a9b7c9aac 100644 --- a/cyclops/evaluate/metrics/experimental/functional/roc.py +++ b/cyclops/evaluate/metrics/experimental/functional/roc.py @@ -130,7 +130,7 @@ def binary_roc( ------- ROCCurve A named tuple containing the false positive rate (FPR), true positive rate - (TPR) and thresholds. The FPR and TPR are arrays of of shape + (TPR) and thresholds. The FPR and TPR are arrays of shape `(num_thresholds + 1,)` and the thresholds are an array of shape `(num_thresholds,)`. diff --git a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py index b6a89ca50..6c486310d 100644 --- a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py @@ -55,14 +55,14 @@ class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_c >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryPrecisionRecallCurve(thresholds=None) >>> metric(target, preds) - (Array([0.5 , 0.6 , 0.5 , 0.6666667, - 0.5 , 1. , 1. ], dtype=float32), Array([1. , 1. , 0.6666667 , 0.6666667 , - 0.33333334, 0.33333334, 0. ], dtype=float32), Array([0.11, 0.22, 0.33, 0.73, 0.84, 0.92], dtype=float64)) + PRCurve(precision=Array([0.5 , 0.6 , 0.5 , 0.6666667, + 0.5 , 1. , 1. ], dtype=float32), recall=Array([1. , 1. , 0.6666667 , 0.6666667 , + 0.33333334, 0.33333334, 0. ], dtype=float32), thresholds=Array([0.11, 0.22, 0.33, 0.73, 0.84, 0.92], dtype=float64)) >>> metric = BinaryPrecisionRecallCurve(thresholds=5) >>> metric(target, preds) - (Array([0.5 , 0.5 , 0.6666667, 0.5 , - 0. , 1. ], dtype=float32), Array([1. , 0.6666667 , 0.6666667 , 0.33333334, - 0. , 0. ], dtype=float32), Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32)) + PRCurve(precision=Array([0.5 , 0.5 , 0.6666667, 0.5 , + 0. , 1. ], dtype=float32), recall=Array([1. , 0.6666667 , 0.6666667 , 0.33333334, + 0. , 0. ], dtype=float32), thresholds=Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32)) """ # noqa: W505 diff --git a/cyclops/evaluate/metrics/experimental/roc.py b/cyclops/evaluate/metrics/experimental/roc.py index 9a0c3f3a0..766f1c9f5 100644 --- a/cyclops/evaluate/metrics/experimental/roc.py +++ b/cyclops/evaluate/metrics/experimental/roc.py @@ -43,14 +43,14 @@ class BinaryROC(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"): >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryROC(thresholds=None) >>> metric(target, preds) - (Array([0. , 0. , 0.33333334, 0.33333334, - 0.6666667 , 0.6666667 , 1. ], dtype=float32), Array([0. , 0.33333334, 0.33333334, 0.6666667 , - 0.6666667 , 1. , 1. ], dtype=float32), Array([1. , 0.92, 0.84, 0.73, 0.33, 0.22, 0.11], dtype=float64)) + ROCCurve(fpr=Array([0. , 0. , 0.33333334, 0.33333334, + 0.6666667 , 0.6666667 , 1. ], dtype=float32), tpr=Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 0.6666667 , 1. , 1. ], dtype=float32), thresholds=Array([1. , 0.92, 0.84, 0.73, 0.33, 0.22, 0.11], dtype=float64)) >>> metric = BinaryROC(thresholds=5) >>> metric(target, preds) - (Array([0. , 0.33333334, 0.33333334, 0.6666667 , - 1. ], dtype=float32), Array([0. , 0.33333334, 0.6666667 , 0.6666667 , - 1. ], dtype=float32), Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32)) + ROCCurve(fpr=Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 1. ], dtype=float32), tpr=Array([0. , 0.33333334, 0.6666667 , 0.6666667 , + 1. ], dtype=float32), thresholds=Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32)) """ # noqa: W505 From 838bc96b4d0a70fe7ada13cbf8692c0bb65cb71c Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Tue, 5 Mar 2024 13:40:26 -0500 Subject: [PATCH 4/4] =?UTF-8?q?Use=20namedtuple=20to=20store=20curve=20res?= =?UTF-8?q?ults=20(ROC,=20PR)=20for=20non-experimental=20=E2=80=A6=20(#574?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Use namedtuple to store curve results (ROC, PR) for non-experimental metrics * Remove commented out doc example * handle named tuples in `_apply_function_recursively` and fix doctests --------- Co-authored-by: Franklin <41602287+fcogidi@users.noreply.github.com> --- .../evaluate/metrics/functional/__init__.py | 2 + .../functional/precision_recall_curve.py | 98 +++++++--------- cyclops/evaluate/metrics/functional/roc.py | 110 ++++++------------ .../metrics/precision_recall_curve.py | 90 +++++++------- cyclops/evaluate/metrics/roc.py | 110 ++++++++---------- cyclops/evaluate/metrics/utils.py | 25 +++- cyclops/report/plot/classification.py | 12 +- 7 files changed, 197 insertions(+), 250 deletions(-) diff --git a/cyclops/evaluate/metrics/functional/__init__.py b/cyclops/evaluate/metrics/functional/__init__.py index 14eee5b9e..ed64063d1 100644 --- a/cyclops/evaluate/metrics/functional/__init__.py +++ b/cyclops/evaluate/metrics/functional/__init__.py @@ -37,12 +37,14 @@ recall, ) from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # noqa: F401 + PRCurve, binary_precision_recall_curve, multiclass_precision_recall_curve, multilabel_precision_recall_curve, precision_recall_curve, ) from cyclops.evaluate.metrics.functional.roc import ( # noqa: F401 + ROCCurve, binary_roc_curve, multiclass_roc_curve, multilabel_roc_curve, diff --git a/cyclops/evaluate/metrics/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/functional/precision_recall_curve.py index bbe0bcc78..a0f9b69e3 100644 --- a/cyclops/evaluate/metrics/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/functional/precision_recall_curve.py @@ -1,6 +1,6 @@ """Functions for computing the precision-recall curve for different input types.""" -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -15,6 +15,14 @@ ) +class PRCurve(NamedTuple): + """Named tuple with Precision-Recall curve (Precision, Recall and thresholds).""" + + precision: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + recall: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + + def _format_thresholds( thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, ) -> Optional[npt.NDArray[np.float_]]: @@ -279,7 +287,7 @@ def binary_precision_recall_curve( preds: npt.ArrayLike, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, pos_label: int = 1, -) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: +) -> PRCurve: """Compute precision-recall curve for binary input. Parameters @@ -301,13 +309,10 @@ def binary_precision_recall_curve( Returns ------- - precision : numpy.ndarray - Precision scores such that element i is the precision of predictions - with score >= thresholds[i]. - recall : numpy.ndarray - Recall scores in descending order. - thresholds : numpy.ndarray - Thresholds used for computing the precision and recall scores. + PRCurve + A named tuple containing the precision (element i is the precision of predictions + with score >= thresholds[i]), recall (scores in descending order) + and thresholds used to compute the precision-recall curve. Examples -------- @@ -335,13 +340,14 @@ def binary_precision_recall_curve( thresholds = _format_thresholds(thresholds) state = _binary_precision_recall_curve_update(target, preds, thresholds) - - return _binary_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _binary_precision_recall_curve_compute( state, thresholds, pos_label=pos_label, ) + return PRCurve(precision_, recall_, thresholds_) + def _multiclass_precision_recall_curve_format( target: npt.ArrayLike, @@ -572,14 +578,7 @@ def multiclass_precision_recall_curve( preds: npt.ArrayLike, num_classes: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for multiclass problems. Parameters @@ -600,18 +599,13 @@ def multiclass_precision_recall_curve( Returns ------- - precision : numpy.ndarray or list of numpy.ndarray - Precision scores where element i is the precision score corresponding - to the threshold i. If state is a tuple of the target and predicted - probabilities, then precision is a list of arrays, where each array - corresponds to the precision scores for a class. - recall : numpy.ndarray or list of numpy.ndarray - Recall scores where element i is the recall score corresponding to - the threshold i. If state is a tuple of the target and predicted - probabilities, then recall is a list of arrays, where each array - corresponds to the recall scores for a class. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used for computing the precision and recall scores. + PRcurve + A named tuple containing the precision, recall, and thresholds. + Precision and recall are arrays where element i is the precision and + recall score corresponding to threshold i. If state is a tuple of the + target and predicted probabilities, then precision and recall are lists + of arrays, where each array corresponds to the precision and recall + scores for a class. Examples -------- @@ -652,11 +646,12 @@ def multiclass_precision_recall_curve( thresholds=thresholds, ) - return _multiclass_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _multiclass_precision_recall_curve_compute( state, thresholds, # type: ignore num_classes, ) + return PRCurve(precision_, recall_, thresholds_) def _multilabel_precision_recall_curve_format( @@ -868,14 +863,7 @@ def multilabel_precision_recall_curve( preds: npt.ArrayLike, num_labels: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for multilabel input. Parameters @@ -897,16 +885,18 @@ def multilabel_precision_recall_curve( Returns ------- - precision : numpy.ndarray or List[numpy.ndarray] + PRCurve + A named tuple with the following: + - ``precision``: numpy.ndarray or List[numpy.ndarray]. Precision values for each label. If ``thresholds`` is None, then precision is a list of arrays, one for each label. Otherwise, precision is a single array with shape (``num_labels``, len(``thresholds``)). - recall : numpy.ndarray or List[numpy.ndarray] + - ``recall``: numpy.ndarray or List[numpy.ndarray]. Recall values for each label. If ``thresholds`` is None, then recall is a list of arrays, one for each label. Otherwise, recall is a single array with shape (``num_labels``, len(``thresholds``)). - thresholds : numpy.ndarray or List[numpy.ndarray] + - ``thresholds``: numpy.ndarray or List[numpy.ndarray]. If ``thresholds`` is None, then thresholds is a list of arrays, one for each label. Otherwise, thresholds is a single array with shape (len(``thresholds``,). @@ -950,11 +940,12 @@ def multilabel_precision_recall_curve( thresholds=thresholds, ) - return _multilabel_precision_recall_curve_compute( + precision_, recall_, thresholds_ = _multilabel_precision_recall_curve_compute( state, thresholds, # type: ignore num_labels, ) + return PRCurve(precision_, recall_, thresholds_) def precision_recall_curve( @@ -965,14 +956,7 @@ def precision_recall_curve( pos_label: int = 1, num_classes: Optional[int] = None, num_labels: Optional[int] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> PRCurve: """Compute the precision-recall curve for different tasks/input types. Parameters @@ -997,17 +981,19 @@ def precision_recall_curve( Returns ------- - precision : numpy.ndarray + PRCurve + A named tuple with the following: + - ``precision``: numpy.ndarray or List[numpy.ndarray]. The precision scores where ``precision[i]`` is the precision score for - ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel', + ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilabel', then ``precision`` is a list of numpy arrays, where ``precision[i]`` is the precision scores for class or label ``i``. - recall : numpy.ndarray + - ``recall``: numpy.ndarray or List[numpy.ndarray]. The recall scores where ``recall[i]`` is the recall score for ``scores >= thresholds[i]``. If ``task`` is 'multiclass' or 'multilaabel', then ``recall`` is a list of numpy arrays, where ``recall[i]`` is the recall scores for class or label ``i``. - thresholds : numpy.ndarray + - ``thresholds``: numpy.ndarray or List[numpy.ndarray]. Thresholds used for computing the precision and recall scores. Raises diff --git a/cyclops/evaluate/metrics/functional/roc.py b/cyclops/evaluate/metrics/functional/roc.py index 24eb1dd16..b8935ded5 100644 --- a/cyclops/evaluate/metrics/functional/roc.py +++ b/cyclops/evaluate/metrics/functional/roc.py @@ -1,6 +1,6 @@ """Functions for computing the receiver operating characteristic (ROC) curve.""" import logging -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, NamedTuple, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -23,6 +23,14 @@ setup_logging(print_level="WARN", logger=LOGGER) +class ROCCurve(NamedTuple): + """Named tuple to store ROC curve (FPR, TPR and thresholds).""" + + fpr: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + tpr: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + thresholds: Union[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]] + + def _roc_compute_from_confmat( confmat: npt.NDArray[Any], thresholds: npt.NDArray[np.float_], @@ -144,7 +152,7 @@ def binary_roc_curve( preds: npt.ArrayLike, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, pos_label: int = 1, -) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: +) -> ROCCurve: """Compute the ROC curve for binary classification tasks. Parameters @@ -166,12 +174,9 @@ def binary_roc_curve( Returns ------- - fpr : numpy.ndarray - False positive rate. - tpr : numpy.ndarray - True positive rate. - thresholds : numpy.ndarray - Thresholds used to compute fpr and tpr. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. Examples -------- @@ -197,8 +202,9 @@ def binary_roc_curve( thresholds = _format_thresholds(thresholds) state = _binary_precision_recall_curve_update(target, preds, thresholds) + fpr, tpr, thresholds = _binary_roc_compute(state, thresholds, pos_label) - return _binary_roc_compute(state, thresholds=thresholds, pos_label=pos_label) + return ROCCurve(fpr, tpr, thresholds) def _multiclass_roc_compute( @@ -272,14 +278,7 @@ def multiclass_roc_curve( preds: npt.ArrayLike, num_classes: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for multiclass classification tasks. Parameters @@ -301,17 +300,11 @@ def multiclass_roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``threshold`` is not None, ``fpr`` is a 1d numpy - array. Otherwise, ``fpr`` is a list of 1d numpy arrays, one for each - class. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``threshold`` is not None, ``tpr`` is a 1d numpy - array. Otherwise, ``tpr`` is a list of 1d numpy arrays, one for each class. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. ``threshold`` is not None, - thresholds is a 1d numpy array. Otherwise, thresholds is a list of - 1d numpy arrays, one for each class. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each label. Examples -------- @@ -352,8 +345,9 @@ def multiclass_roc_curve( num_classes=num_classes, thresholds=thresholds, ) + fpr_, tpr_, thresholds_ = _multiclass_roc_compute(state, num_classes, thresholds) - return _multiclass_roc_compute(state, num_classes, thresholds) + return ROCCurve(fpr=fpr_, tpr=tpr_, thresholds=thresholds_) def _multilabel_roc_compute( @@ -427,14 +421,7 @@ def multilabel_roc_curve( preds: npt.ArrayLike, num_labels: int, thresholds: Optional[Union[int, List[float], npt.NDArray[np.float_]]] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for multilabel classification tasks. Parameters @@ -456,17 +443,11 @@ def multilabel_roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``threshold`` is not None, ``fpr`` is a 1d numpy - array. Otherwise, ``fpr`` is a list of 1d numpy arrays, one for each - label. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``threshold`` is not None, ``tpr`` is a 1d numpy - array. Otherwise, ``tpr`` is a list of 1d numpy arrays, one for each label. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. ``threshold`` is not None, - thresholds is a 1d numpy array. Otherwise, thresholds is a list of - 1d numpy arrays, one for each label. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each label. Examples -------- @@ -502,8 +483,9 @@ def multilabel_roc_curve( num_labels=num_labels, thresholds=thresholds, ) + fpr_, tpr_, thresholds_ = _multilabel_roc_compute(state, num_labels, thresholds) - return _multilabel_roc_compute(state, num_labels, thresholds) + return ROCCurve(fpr=fpr_, tpr=tpr_, thresholds=thresholds_) def roc_curve( @@ -514,14 +496,7 @@ def roc_curve( pos_label: int = 1, num_classes: Optional[int] = None, num_labels: Optional[int] = None, -) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], -]: +) -> ROCCurve: """Compute the ROC curve for different tasks/input types. Parameters @@ -558,22 +533,11 @@ def roc_curve( Returns ------- - fpr : numpy.ndarray or list of numpy.ndarray - False positive rate. If ``task`` is 'binary' or ``threshold`` is not None, - ``fpr`` is a 1d numpy array. If ``task`` is 'multiclass' or 'multilabel', - and ``threshold`` is None, then ``fpr`` is a list of 1d numpy - arrays, one for each class or label. - tpr : numpy.ndarray or list of numpy.ndarray - True positive rate. If ``task`` is 'binary' or ``threshold`` is not None, - ``tpr`` is a 1d numpy array. If ``task`` is 'multiclass' or 'multilabel', - and ``threshold`` is None, then ``tpr`` is a list of 1d numpy - arrays, one for each class or label. - thresholds : numpy.ndarray or list of numpy.ndarray - Thresholds used to compute fpr and tpr. If ``task`` is 'binary' or - ``threshold`` is not None, ``thresholds`` is a 1d numpy array. If - ``task`` is 'multiclass' or 'multilabel', and ``threshold`` is None, - then ``thresholds`` is a list of 1d numpy arrays, one for each class - or label. + ROCCurve + A named tuple containing the false positive rate, true positive rate, + and thresholds used to compute the ROC curve. If ``threshold`` is not None, + ``fpr``, ``tpr`` and ``thresholds`` are 1d numpy arrays, else they are lists + of 1d numpy arrays, one for each label. Raises ------ diff --git a/cyclops/evaluate/metrics/precision_recall_curve.py b/cyclops/evaluate/metrics/precision_recall_curve.py index 64bf08833..9a5ce76b5 100644 --- a/cyclops/evaluate/metrics/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/precision_recall_curve.py @@ -1,11 +1,12 @@ """Classes for computing precision-recall curves.""" -from typing import Any, List, Literal, Optional, Tuple, Type, Union +from typing import Any, List, Literal, Optional, Type, Union import numpy as np import numpy.typing as npt from cyclops.evaluate.metrics.functional.precision_recall_curve import ( # type: ignore # noqa: E501 + PRCurve, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format, _binary_precision_recall_curve_update, @@ -42,14 +43,14 @@ class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_c >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = BinaryPrecisionRecallCurve(thresholds=3) >>> metric(target, preds) - (array([0.5, 1. , 0. ]), array([1. , 0.5, 0. ]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([0.5, 1. , 0. ]), recall=array([1. , 0.5, 0. ]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 0, 1], [1, 1, 0, 0]] >>> preds = [[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0.5 , 0.66666667, 0. ]), array([1. , 0.5, 0. ]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([0.5 , 0.66666667, 0. ]), recall=array([1. , 0.5, 0. ]), thresholds=array([0. , 0.5, 1. ])) """ @@ -101,7 +102,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -111,11 +112,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _binary_precision_recall_curve_compute( + precision, recall, thresholds = _binary_precision_recall_curve_compute( state=state, thresholds=self.thresholds, pos_label=self.pos_label, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -181,11 +183,11 @@ class MulticlassPrecisionRecallCurve( >>> preds = [[0.1, 0.6, 0.3], [0.05, 0.95, 0.0], [0.5, 0.3, 0.2], [0.2, 0.5, 0.3]] >>> metric = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=3) >>> metric(target, preds) - (array([[0.5 , 0. , 0. , 1. ], + PRCurve(precision=array([[0.5 , 0. , 0. , 1. ], [0.25 , 0.33333333, 0. , 1. ], - [0.25 , 0. , 0. , 1. ]]), array([[1., 0., 0., 0.], + [0.25 , 0. , 0. , 1. ]]), recall=array([[1., 0., 0., 0.], [1., 1., 0., 0.], - [1., 0., 0., 0.]]), array([0. , 0.5, 1. ])) + [1., 0., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 2, 0], [1, 2, 0, 1]] >>> preds = [ @@ -195,11 +197,11 @@ class MulticlassPrecisionRecallCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.375, 0.5 , 0. , 1. ], + PRCurve(precision=array([[0.375, 0.5 , 0. , 1. ], [0.375, 0.4 , 0. , 1. ], - [0.25 , 0. , 0. , 1. ]]), array([[1. , 0.33333333, 0. , 0. ], + [0.25 , 0. , 0. , 1. ]]), recall=array([[1. , 0.33333333, 0. , 0. ], [1. , 0.66666667, 0. , 0. ], - [1. , 0. , 0. , 0. ]]), array([0. , 0.5, 1. ])) + [1. , 0. , 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) """ @@ -253,14 +255,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -270,11 +265,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multiclass_precision_recall_curve_compute( + precision, recall, thresholds = _multiclass_precision_recall_curve_compute( state=state, thresholds=self.thresholds, # type: ignore[arg-type] num_classes=self.num_classes, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -340,18 +336,18 @@ class MultilabelPrecisionRecallCurve( >>> preds = [[0.1, 0.9], [0.8, 0.2]] >>> metric = MultilabelPrecisionRecallCurve(num_labels=2, thresholds=3) >>> metric(target, preds) - (array([[0.5, 1. , 0. , 1. ], - [0.5, 1. , 0. , 1. ]]), array([[1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 1. , 0. , 1. ], + [0.5, 1. , 0. , 1. ]]), recall=array([[1., 1., 0., 0.], + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[[0, 1], [1, 0]], [[1, 0], [0, 1]]] >>> preds = [[[0.1, 0.9], [0.8, 0.2]], [[0.2, 0.8], [0.7, 0.3]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.5, 0.5, 0. , 1. ], - [0.5, 0.5, 0. , 1. ]]), array([[1. , 0.5, 0. , 0. ], - [1. , 0.5, 0. , 0. ]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 0.5, 0. , 1. ], + [0.5, 0.5, 0. , 1. ]]), recall=array([[1. , 0.5, 0. , 0. ], + [1. , 0.5, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) """ @@ -405,14 +401,7 @@ def update_state(self, target: npt.ArrayLike, preds: npt.ArrayLike) -> None: def compute( self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> PRCurve: """Compute the precision-recall curve from the state.""" if self.thresholds is None: state = ( @@ -422,11 +411,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multilabel_precision_recall_curve_compute( + precision, recall, thresholds = _multilabel_precision_recall_curve_compute( state, thresholds=self.thresholds, # type: ignore[arg-type] num_labels=self.num_labels, ) + return PRCurve(precision, recall, thresholds) def __setattr__(self, name: str, value: Any) -> None: """Set the attribute ``name`` to ``value``. @@ -502,15 +492,15 @@ class PrecisionRecallCurve( >>> preds = [0.6, 0.2, 0.3, 0.8] >>> metric = PrecisionRecallCurve(task="binary", thresholds=None) >>> metric(target, preds) - (array([0.75 , 0.66666667, 0.5 , 0. , 1. ]), array([1. , 0.66666667, 0.33333333, 0. , 0. ]), array([0.2, 0.3, 0.6, 0.8])) + PRCurve(precision=array([0.75 , 0.66666667, 0.5 , 0. , 1. ]), recall=array([1. , 0.66666667, 0.33333333, 0. , 0. ]), thresholds=array([0.2, 0.3, 0.6, 0.8])) >>> metric.reset_state() >>> target = [[1, 0, 1, 1], [0, 0, 0, 1]] >>> preds = [[0.5, 0.4, 0.1, 0.3], [0.9, 0.6, 0.45, 0.8]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0.5 , 0.42857143, 0.33333333, 0.4 , 0.5 , - 0.33333333, 0.5 , 0. , 1. ]), array([1. , 0.75, 0.5 , 0.5 , 0.5 , 0.25, 0.25, 0. , 0. ]), array([0.1 , 0.3 , 0.4 , 0.45, 0.5 , 0.6 , 0.8 , 0.9 ])) + PRCurve(precision=array([0.5 , 0.42857143, 0.33333333, 0.4 , 0.5 , + 0.33333333, 0.5 , 0. , 1. ]), recall=array([1. , 0.75, 0.5 , 0.5 , 0.5 , 0.25, 0.25, 0. , 0. ]), thresholds=array([0.1 , 0.3 , 0.4 , 0.45, 0.5 , 0.6 , 0.8 , 0.9 ])) >>> # (multiclass) >>> from cyclops.evaluate.metrics import PrecisionRecallCurve @@ -518,11 +508,11 @@ class PrecisionRecallCurve( >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.2, 0.2, 0.6]] >>> metric = PrecisionRecallCurve(task="multiclass", num_classes=3, thresholds=3) >>> metric(target, preds) - (array([[0.25, 0. , 0. , 1. ], + PRCurve(precision=array([[0.25, 0. , 0. , 1. ], [0.25, 0.5 , 0. , 1. ], - [0.5 , 1. , 0. , 1. ]]), array([[1., 0., 0., 0.], + [0.5 , 1. , 0. , 1. ]]), recall=array([[1., 0., 0., 0.], [1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[0, 1, 2, 2], [1, 2, 0, 1]] >>> preds = [ @@ -532,11 +522,11 @@ class PrecisionRecallCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.25 , 0. , 0. , 1. ], + PRCurve(precision=array([[0.25 , 0. , 0. , 1. ], [0.375, 0.5 , 0. , 1. ], - [0.375, 0.5 , 0. , 1. ]]), array([[1. , 0. , 0. , 0. ], + [0.375, 0.5 , 0. , 1. ]]), recall=array([[1. , 0. , 0. , 0. ], [1. , 0.66666667, 0. , 0. ], - [1. , 0.66666667, 0. , 0. ]]), array([0. , 0.5, 1. ])) + [1. , 0.66666667, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) >>> # (multilabel) >>> from cyclops.evaluate.metrics import PrecisionRecallCurve @@ -544,18 +534,18 @@ class PrecisionRecallCurve( >>> preds = [[0.1, 0.9], [0.8, 0.2]] >>> metric = PrecisionRecallCurve(task="multilabel", num_labels=2, thresholds=3) >>> metric(target, preds) - (array([[0.5, 1. , 0. , 1. ], - [0.5, 1. , 0. , 1. ]]), array([[1., 1., 0., 0.], - [1., 1., 0., 0.]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 1. , 0. , 1. ], + [0.5, 1. , 0. , 1. ]]), recall=array([[1., 1., 0., 0.], + [1., 1., 0., 0.]]), thresholds=array([0. , 0.5, 1. ])) >>> metric.reset_state() >>> target = [[[0, 1], [1, 0]], [[1, 0], [0, 1]]] >>> preds = [[[0.1, 0.9], [0.8, 0.2]], [[0.1, 0.9], [0.8, 0.2]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0.5, 0.5, 0. , 1. ], - [0.5, 0.5, 0. , 1. ]]), array([[1. , 0.5, 0. , 0. ], - [1. , 0.5, 0. , 0. ]]), array([0. , 0.5, 1. ])) + PRCurve(precision=array([[0.5, 0.5, 0. , 1. ], + [0.5, 0.5, 0. , 1. ]]), recall=array([[1. , 0.5, 0. , 0. ], + [1. , 0.5, 0. , 0. ]]), thresholds=array([0. , 0.5, 1. ])) """ diff --git a/cyclops/evaluate/metrics/roc.py b/cyclops/evaluate/metrics/roc.py index 1b2774c8b..58587856c 100644 --- a/cyclops/evaluate/metrics/roc.py +++ b/cyclops/evaluate/metrics/roc.py @@ -1,10 +1,13 @@ """Classes for computing ROC metrics.""" -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Union import numpy as np import numpy.typing as npt +from cyclops.evaluate.metrics.functional.roc import ( + ROCCurve as ROCCurveData, +) from cyclops.evaluate.metrics.functional.roc import ( _binary_roc_compute, _multiclass_roc_compute, @@ -39,22 +42,22 @@ class BinaryROCCurve(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve" >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = BinaryROCCurve() >>> metric(target, preds) - (array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) + ROCCurve(fpr=array([0. , 0. , 0.5, 0.5, 1. ]), tpr=array([0. , 0.5, 0.5, 1. , 1. ]), thresholds=array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) + ROCCurve(fpr=array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), tpr=array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), thresholds=array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -63,13 +66,12 @@ def compute( ) else: state = self.confmat # type: ignore[attr-defined] - - return _binary_roc_compute( - state, - thresholds=self.thresholds, - pos_label=self.pos_label, + fpr_, tpr_, thresholds_ = _binary_roc_compute( + state, thresholds=self.thresholds, pos_label=self.pos_label ) + return ROCCurveData(fpr_, tpr_, thresholds_) + class MulticlassROCCurve( MulticlassPrecisionRecallCurve, @@ -101,11 +103,11 @@ class MulticlassROCCurve( >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.9, 0.1, 0]] >>> metric = MulticlassROCCurve(num_classes=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0.33333333, 0.33333333, 1. ], - [0. , 0. , 0. , 1. ]]), array([[0. , 0.5, 0.5, 1. ], + [0. , 0. , 0. , 1. ]]), tpr=array([[0. , 0.5, 0.5, 1. ], [0. , 1. , 1. , 1. ], - [0. , 0. , 1. , 1. ]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0. , 0. , 1. , 1. ]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [ @@ -115,26 +117,19 @@ class MulticlassROCCurve( >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0.25, 0.5 , 1. ], + ROCCurve(fpr=array([[0. , 0.25, 0.5 , 1. ], [0. , 0. , 0.25, 1. ], - [0. , 0.25, 0.5 , 1. ]]), array([[0. , 0.25, 0.5 , 1. ], + [0. , 0.25, 0.5 , 1. ]]), tpr=array([[0. , 0.25, 0.5 , 1. ], [0. , 0. , 0.25, 1. ], - [0. , 0. , 0. , 0. ]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0. , 0. , 0. , 0. ]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -143,13 +138,12 @@ def compute( ) else: state = self.confmat # type: ignore[attr-defined] - - return _multiclass_roc_compute( - state=state, - num_classes=self.num_classes, - thresholds=self.thresholds, + fpr_, tpr_, thresholds_ = _multiclass_roc_compute( + state, thresholds=self.thresholds, num_classes=self.num_classes ) + return ROCCurveData(fpr_, tpr_, thresholds_) + class MultilabelROCCurve( MultilabelPrecisionRecallCurve, @@ -175,37 +169,30 @@ class MultilabelROCCurve( >>> preds = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> metric = MultilabelROCCurve(num_labels=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]] >>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) - """ + """ # noqa: W505 name: str = "ROC Curve" - def compute( + def compute( # type: ignore self, - ) -> Union[ - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]], - Tuple[ - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - List[npt.NDArray[np.float_]], - ], - ]: + ) -> ROCCurveData: """Compute the ROC curve from the state variables.""" if self.thresholds is None: state = ( @@ -215,11 +202,12 @@ def compute( else: state = self.confmat # type: ignore[attr-defined] - return _multilabel_roc_compute( + fpr_, tpr_, thresholds_ = _multilabel_roc_compute( state=state, num_labels=self.num_labels, thresholds=self.thresholds, ) + return ROCCurveData(fpr_, tpr_, thresholds_) class ROCCurve(Metric, registry_key="roc_curve", force_register=True): @@ -258,14 +246,14 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [0.1, 0.4, 0.35, 0.8] >>> metric = ROCCurve(task="binary", thresholds=None) >>> metric(target, preds) - (array([0. , 0. , 0.5, 0.5, 1. ]), array([0. , 0.5, 0.5, 1. , 1. ]), array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) + ROCCurve(fpr=array([0. , 0. , 0.5, 0.5, 1. ]), tpr=array([0. , 0.5, 0.5, 1. , 1. ]), thresholds=array([1. , 0.8 , 0.4 , 0.35, 0.1 ])) >>> metric.reset_state() >>> target = [[1, 1, 0, 0], [0, 0, 1, 1]] >>> preds = [[0.1, 0.2, 0.3, 0.4], [0.6, 0.5, 0.4, 0.3]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) + ROCCurve(fpr=array([0. , 0.25, 0.5 , 0.75, 1. , 1. , 1. ]), tpr=array([0. , 0. , 0. , 0.25, 0.5 , 0.75, 1. ]), thresholds=array([1. , 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])) >>> # (multiclass) >>> from cyclops.evaluate.metrics import ROCCurve @@ -273,22 +261,22 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]] >>> metric = ROCCurve(task="multiclass", num_classes=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0.5, 0.5, 1. ], - [0. , 0. , 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0. , 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 1.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [1, 2] >>> preds = [[[0.05, 0.75, 0.2]], [[0.1, 0.8, 0.1]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0., 0., 0., 1.], + ROCCurve(fpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([[0., 0., 0., 0.], + [0., 0., 0., 1.]]), tpr=array([[0., 0., 0., 0.], [0., 1., 1., 1.], - [0., 0., 0., 1.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 1.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> # (multilabel) >>> from cyclops.evaluate.metrics import ROCCurve @@ -296,22 +284,22 @@ class ROCCurve(Metric, registry_key="roc_curve", force_register=True): >>> preds = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] >>> metric = ROCCurve(task="multilabel", num_labels=3, thresholds=4) >>> metric(target, preds) - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) >>> metric.reset_state() >>> target = [[[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]] >>> preds = [[[0.1, 0.9, 0.8], [0.05, 0.95, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]] >>> for t, p in zip(target, preds): ... metric.update_state(t, p) >>> metric.compute() - (array([[0. , 0. , 0. , 1. ], + ROCCurve(fpr=array([[0. , 0. , 0. , 1. ], [0. , 0. , 0. , 0. ], - [0. , 0.5, 0.5, 1. ]]), array([[0., 0., 0., 1.], + [0. , 0.5, 0.5, 1. ]]), tpr=array([[0., 0., 0., 1.], [0., 1., 1., 1.], - [0., 0., 0., 0.]]), array([1. , 0.66666667, 0.33333333, 0. ])) + [0., 0., 0., 0.]]), thresholds=array([1. , 0.66666667, 0.33333333, 0. ])) """ # noqa: W505 diff --git a/cyclops/evaluate/metrics/utils.py b/cyclops/evaluate/metrics/utils.py index 3bff5b85a..86f7fc9c3 100644 --- a/cyclops/evaluate/metrics/utils.py +++ b/cyclops/evaluate/metrics/utils.py @@ -1,6 +1,16 @@ """Utility functions for metrics.""" -from typing import Any, Callable, List, Literal, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import numpy.typing as npt @@ -295,10 +305,15 @@ def _apply_function_recursively( """ data_type = type(data) - if isinstance(data, (list, tuple, set)): - return data_type( - [_apply_function_recursively(el, func, *args, **kwargs) for el in data], - ) + is_namedtuple_ = ( + isinstance(data, tuple) + and hasattr(data, "_asdict") + and hasattr(data, "_fields") + ) + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + if is_namedtuple_ or is_sequence: + out = [_apply_function_recursively(el, func, *args, **kwargs) for el in data] + return data_type(*out) if is_namedtuple_ else data_type(out) if isinstance(data, Mapping): return data_type( { diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index 69aa612d8..255dee72a 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -8,7 +8,9 @@ import plotly.graph_objs as go from plotly.subplots import make_subplots -from cyclops.evaluate.metrics.experimental.functional import PRCurve, ROCCurve +from cyclops.evaluate.metrics.experimental.functional import PRCurve as PRCurveExp +from cyclops.evaluate.metrics.experimental.functional import ROCCurve as ROCCurveExp +from cyclops.evaluate.metrics.functional import PRCurve, ROCCurve from cyclops.report.plot.base import Plotter from cyclops.report.plot.utils import ( bar_plot, @@ -93,7 +95,7 @@ def _set_class_names(self, class_names: List[str]) -> None: def roc_curve( self, - roc_curve: ROCCurve, + roc_curve: Union[ROCCurve, ROCCurveExp], auroc: Optional[Union[float, List[float], npt.NDArray[np.float_]]] = None, title: Optional[str] = "ROC Curve", layout: Optional[go.Layout] = None, @@ -188,7 +190,7 @@ def roc_curve( def roc_curve_comparison( self, - roc_curves: Dict[str, ROCCurve], + roc_curves: Dict[str, Union[ROCCurve, ROCCurveExp]], aurocs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None, @@ -293,7 +295,7 @@ def roc_curve_comparison( def precision_recall_curve( self, - precision_recall_curve: PRCurve, + precision_recall_curve: Union[PRCurve, PRCurveExp], title: Optional[str] = "Precision-Recall Curve", layout: Optional[go.Layout] = None, **plot_kwargs: Any, @@ -357,7 +359,7 @@ def precision_recall_curve( def precision_recall_curve_comparison( self, - precision_recall_curves: Dict[str, PRCurve], + precision_recall_curves: Dict[str, Union[PRCurve, PRCurveExp]], auprcs: Optional[ Dict[str, Union[float, List[float], npt.NDArray[np.float_]]] ] = None,