From 3160750e6d2b9507580a48154504b3aa668b7616 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Mon, 29 Jan 2024 19:11:51 -0500 Subject: [PATCH] Integrate experimental metrics with other modules (#549) * integrate experimental metrics with other modules * add average precision metric to experimental metrics package * fix tutorials * Add type hints and keyword arguments to metrics classes * Update nbsphinx version to 0.9.3 * Update nbconvert version to 7.14.2 * Fix type annotations and formatting issues * Update kernel display name in mortality_prediction.ipynb * Add guard clause to prevent module execution on import * Update `torch_distributed.py` with type hints * Add multiclass and multilabel average precision metrics * Change jupyter kernel * Fix type annotations for metric values in ClassificationPlotter --------- Co-authored-by: Amrit K --- cyclops/evaluate/evaluator.py | 33 +- cyclops/evaluate/fairness/config.py | 5 +- cyclops/evaluate/fairness/evaluator.py | 89 ++- .../evaluate/metrics/experimental/__init__.py | 5 + .../evaluate/metrics/experimental/auroc.py | 25 +- .../metrics/experimental/average_precision.py | 272 +++++++ .../distributed_backends/torch_distributed.py | 10 +- .../evaluate/metrics/experimental/f_score.py | 12 +- .../experimental/functional/__init__.py | 5 + .../metrics/experimental/functional/auroc.py | 6 +- .../functional/average_precision.py | 677 ++++++++++++++++++ cyclops/evaluate/metrics/experimental/mae.py | 11 +- cyclops/evaluate/metrics/experimental/mape.py | 8 +- cyclops/evaluate/metrics/experimental/mse.py | 13 +- .../experimental/negative_predictive_value.py | 6 +- .../metrics/experimental/precision_recall.py | 30 +- .../experimental/precision_recall_curve.py | 17 +- cyclops/evaluate/metrics/experimental/roc.py | 18 +- .../evaluate/metrics/experimental/smape.py | 8 +- .../metrics/experimental/specificity.py | 12 +- .../evaluate/metrics/experimental/wmape.py | 7 +- cyclops/evaluate/metrics/factory.py | 27 +- cyclops/report/plot/classification.py | 34 +- cyclops/tasks/classification.py | 38 +- .../kaggle/heart_failure_prediction.ipynb | 66 +- .../mimiciv/mortality_prediction.ipynb | 55 +- .../tutorials/nihcxr/cxr_classification.ipynb | 90 +-- .../nihcxr/generate_nihcxr_report.py | 96 +-- .../tutorials/synthea/los_prediction.ipynb | 62 +- poetry.lock | 35 +- pyproject.toml | 3 +- .../experimental/test_average_precision.py | 503 +++++++++++++ .../experimental/test_precision_recall.py | 2 + 33 files changed, 1894 insertions(+), 386 deletions(-) create mode 100644 cyclops/evaluate/metrics/experimental/average_precision.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/average_precision.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_average_precision.py diff --git a/cyclops/evaluate/evaluator.py b/cyclops/evaluate/evaluator.py index e7af9a2c0..3763cf37e 100644 --- a/cyclops/evaluate/evaluator.py +++ b/cyclops/evaluate/evaluator.py @@ -1,5 +1,4 @@ """Evaluate one or more models on a dataset.""" - import logging import warnings from dataclasses import asdict @@ -16,7 +15,9 @@ ) from cyclops.evaluate.fairness.config import FairnessConfig from cyclops.evaluate.fairness.evaluator import evaluate_fairness -from cyclops.evaluate.metrics.metric import Metric, MetricCollection +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict +from cyclops.evaluate.metrics.experimental.utils.types import Array from 
cyclops.evaluate.utils import _format_column_names, choose_split from cyclops.utils.log import setup_logging @@ -27,7 +28,7 @@ def evaluate( dataset: Union[str, Dataset, DatasetDict], - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection], + metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict], target_columns: Union[str, List[str]], prediction_columns: Union[str, List[str]], ignore_columns: Optional[Union[str, List[str]]] = None, @@ -47,7 +48,7 @@ def evaluate( The dataset to evaluate on. If a string, the dataset will be loaded using `datasets.load_dataset`. If `DatasetDict`, the `split` argument must be specified. - metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection] + metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict] The metrics to compute. target_columns : Union[str, List[str]] The name of the column(s) containing the target values. A string value @@ -202,28 +203,28 @@ def _load_data( def _prepare_metrics( - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection], -) -> MetricCollection: + metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict], +) -> MetricDict: """Prepare metrics for evaluation.""" - # TODO: wrap in BootstrappedMetric if computing confidence intervals + # TODO [fcogidi]: wrap in BootstrappedMetric if computing confidence intervals if isinstance(metrics, (Metric, Sequence, Dict)) and not isinstance( metrics, - MetricCollection, + MetricDict, ): - return MetricCollection(metrics) - if isinstance(metrics, MetricCollection): + return MetricDict(metrics) # type: ignore[arg-type] + if isinstance(metrics, MetricDict): return metrics raise TypeError( f"Invalid type for `metrics`: {type(metrics)}. " "Expected one of: Metric, Sequence[Metric], Dict[str, Metric], " - "MetricCollection.", + "MetricDict.", ) def _compute_metrics( dataset: Dataset, - metrics: MetricCollection, + metrics: MetricDict, slice_spec: SliceSpec, target_columns: Union[str, List[str]], prediction_columns: Union[str, List[str]], @@ -266,8 +267,8 @@ def _compute_metrics( RuntimeWarning, stacklevel=1, ) - metric_output = { - metric_name: float("NaN") for metric_name in metrics + metric_output: Dict[str, Array] = { + metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined,misc] } elif ( batch_size is None or batch_size < 0 @@ -293,10 +294,10 @@ def _compute_metrics( ) # update the metric state - metrics.update_state(targets, predictions) + metrics.update(targets, predictions) metric_output = metrics.compute() - metrics.reset_state() + metrics.reset() model_name: str = "model_for_%s" % prediction_column results.setdefault(model_name, {}) diff --git a/cyclops/evaluate/fairness/config.py b/cyclops/evaluate/fairness/config.py index 3f220f4b4..f6e2aaebe 100644 --- a/cyclops/evaluate/fairness/config.py +++ b/cyclops/evaluate/fairness/config.py @@ -5,14 +5,15 @@ from datasets import Dataset, config -from cyclops.evaluate.metrics.metric import Metric, MetricCollection +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict @dataclass class FairnessConfig: """Configuration for fairness metrics.""" - metrics: Union[str, Callable[..., Any], Metric, MetricCollection] + metrics: Union[str, Callable[..., Any], Metric, MetricDict] dataset: Dataset groups: Union[str, List[str]] target_columns: Union[str, List[str]] diff --git a/cyclops/evaluate/fairness/evaluator.py 
b/cyclops/evaluate/fairness/evaluator.py index 1296f0e56..cd44a07f1 100644 --- a/cyclops/evaluate/fairness/evaluator.py +++ b/cyclops/evaluate/fairness/evaluator.py @@ -1,5 +1,4 @@ """Fairness evaluator.""" - import inspect import itertools import logging @@ -7,8 +6,8 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union +import array_api_compat.numpy import numpy as np -import numpy.typing as npt import pandas as pd from datasets import Dataset, config from datasets.features import Features @@ -21,15 +20,14 @@ get_columns_as_numpy_array, set_decode, ) -from cyclops.evaluate.metrics.factory import create_metric -from cyclops.evaluate.metrics.functional.precision_recall_curve import ( +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( _format_thresholds, + _validate_thresholds, ) -from cyclops.evaluate.metrics.metric import Metric, MetricCollection, OperatorMetric -from cyclops.evaluate.metrics.utils import ( - _check_thresholds, - _get_value_if_singleton_array, -) +from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric +from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict +from cyclops.evaluate.metrics.experimental.utils.types import Array +from cyclops.evaluate.metrics.factory import create_metric from cyclops.evaluate.utils import _format_column_names from cyclops.utils.log import setup_logging @@ -39,7 +37,7 @@ def evaluate_fairness( - metrics: Union[str, Callable[..., Any], Metric, MetricCollection], + metrics: Union[str, Callable[..., Any], Metric, MetricDict], dataset: Dataset, groups: Union[str, List[str]], target_columns: Union[str, List[str]], @@ -62,7 +60,7 @@ def evaluate_fairness( Parameters ---------- - metrics : Union[str, Callable[..., Any], Metric, MetricCollection] + metrics : Union[str, Callable[..., Any], Metric, MetricDict] The metric or metrics to compute. If a string, it should be the name of a metric provided by CyclOps. If a callable, it should be a function that takes target, prediction, and optionally threshold/thresholds as arguments @@ -147,18 +145,14 @@ def evaluate_fairness( raise TypeError( "Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.", ) + _validate_thresholds(thresholds) - _check_thresholds(thresholds) - fmt_thresholds: npt.NDArray[np.float_] = _format_thresholds( # type: ignore - thresholds, - ) - - metrics_: Union[Callable[..., Any], MetricCollection] = _format_metrics( + metrics_: Union[Callable[..., Any], MetricDict] = _format_metrics( metrics, metric_name, **(metric_kwargs or {}), ) - + fmt_thresholds = _format_thresholds(thresholds, xp=array_api_compat.numpy) fmt_groups: List[str] = _format_column_names(groups) fmt_target_columns: List[str] = _format_column_names(target_columns) fmt_prediction_columns: List[str] = _format_column_names(prediction_columns) @@ -361,15 +355,15 @@ def warn_too_many_unique_values( def _format_metrics( - metrics: Union[str, Callable[..., Any], Metric, MetricCollection], + metrics: Union[str, Callable[..., Any], Metric, MetricDict], metric_name: Optional[str] = None, **metric_kwargs: Any, -) -> Union[Callable[..., Any], Metric, MetricCollection]: +) -> Union[Callable[..., Any], Metric, MetricDict]: """Format the metrics argument. Parameters ---------- - metrics : Union[str, Callable[..., Any], Metric, MetricCollection] + metrics : Union[str, Callable[..., Any], Metric, MetricDict] The metrics to use for computing the metric results. 
metric_name : str, optional, default=None The name of the metric. This is only used if `metrics` is a callable. @@ -379,23 +373,23 @@ def _format_metrics( Returns ------- - Union[Callable[..., Any], Metric, MetricCollection] + Union[Callable[..., Any], Metric, MetricDict] The formatted metrics. Raises ------ TypeError - If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection`. + If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict`. """ if isinstance(metrics, str): - metrics = create_metric(metric_name=metrics, **metric_kwargs) + metrics = create_metric(metric_name=metrics, experimental=True, **metric_kwargs) if isinstance(metrics, Metric): if metric_name is not None and isinstance(metrics, OperatorMetric): # single metric created from arithmetic operation, with given name - return MetricCollection({metric_name: metrics}) - return MetricCollection(metrics) - if isinstance(metrics, MetricCollection): + return MetricDict({metric_name: metrics}) + return MetricDict(metrics) + if isinstance(metrics, MetricDict): return metrics if callable(metrics): if metric_name is None: @@ -407,7 +401,7 @@ def _format_metrics( return metrics raise TypeError( - f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection`, or " + f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict`, or " f"`Callable`, but got {type(metrics)}.", ) @@ -701,7 +695,7 @@ def _get_slice_spec( def _compute_metrics( # noqa: C901, PLR0912 - metrics: Union[Callable[..., Any], MetricCollection], + metrics: Union[Callable[..., Any], MetricDict], dataset: Dataset, target_columns: List[str], prediction_column: str, @@ -713,7 +707,7 @@ def _compute_metrics( # noqa: C901, PLR0912 Parameters ---------- - metrics : Union[Callable, MetricCollection] + metrics : Union[Callable, MetricDict] The metrics to compute. dataset : Dataset The dataset to compute the metrics on. @@ -738,12 +732,19 @@ def _compute_metrics( # noqa: C901, PLR0912 "Encountered empty dataset while computing metrics. " "The metric values will be set to `None`." ) - if isinstance(metrics, MetricCollection): + if isinstance(metrics, MetricDict): if threshold is not None: # set the threshold for each metric in the collection for name, metric in metrics.items(): - if hasattr(metric, "threshold"): + if isinstance(metric, Metric) and hasattr(metric, "threshold"): metric.threshold = threshold + elif isinstance(metric, OperatorMetric): + if hasattr(metric.metric_a, "threshold") and hasattr( + metric.metric_b, + "threshold", + ): + metric.metric_a.threshold = threshold + metric.metric_b.threshold = threshold # type: ignore[union-attr] else: LOGGER.warning( "Metric %s does not have a threshold attribute. 
" @@ -754,7 +755,7 @@ def _compute_metrics( # noqa: C901, PLR0912 if len(dataset) == 0: warnings.warn(empty_dataset_msg, RuntimeWarning, stacklevel=1) results: Dict[str, Any] = { - metric_name: float("NaN") for metric_name in metrics + metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined] } elif ( batch_size is None or batch_size <= 0 @@ -779,11 +780,11 @@ def _compute_metrics( # noqa: C901, PLR0912 columns=prediction_column, ) - metrics.update_state(targets, predictions) + metrics.update(targets, predictions) results = metrics.compute() - metrics.reset_state() + metrics.reset() return results if callable(metrics): @@ -817,26 +818,26 @@ def _compute_metrics( # noqa: C901, PLR0912 return {metric_name.title(): output} raise TypeError( - "The `metrics` argument must be a string, a Metric, a MetricCollection, " + "The `metrics` argument must be a string, a Metric, a MetricDict, " f"or a callable. Got {type(metrics)}.", ) def _get_metric_results_for_prediction_and_slice( - metrics: Union[Callable[..., Any], MetricCollection], + metrics: Union[Callable[..., Any], MetricDict], dataset: Dataset, target_columns: List[str], prediction_column: str, slice_name: str, batch_size: Optional[int] = config.DEFAULT_MAX_BATCH_SIZE, metric_name: Optional[str] = None, - thresholds: Optional[npt.NDArray[np.float_]] = None, + thresholds: Optional[Array] = None, ) -> Dict[str, Dict[str, Any]]: """Compute metrics for a slice of a dataset. Parameters ---------- - metrics : Union[Callable, MetricCollection] + metrics : Union[Callable, MetricDict] The metrics to compute. dataset : Dataset The dataset to compute the metrics on. @@ -850,7 +851,7 @@ def _get_metric_results_for_prediction_and_slice( The batch size to use for the computation. metric_name : Optional[str] The name of the metric to compute. - thresholds : Optional[List[float]] + thresholds : Optional[Array] The thresholds to use for the metrics. 
Returns @@ -873,7 +874,7 @@ def _get_metric_results_for_prediction_and_slice( return {slice_name: metric_output} results: Dict[str, Dict[str, Any]] = {} - for threshold in thresholds: + for threshold in thresholds: # type: ignore[attr-defined] metric_output = _compute_metrics( metrics=metrics, dataset=dataset, @@ -969,11 +970,7 @@ def _compute_parity_metrics( ) parity_results[key].setdefault(slice_name, {}).update( - { - parity_metric_name: _get_value_if_singleton_array( - parity_metric_value, - ), - }, + {parity_metric_name: parity_metric_value}, ) return parity_results diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index ec6c72609..3a5b9974a 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -9,6 +9,11 @@ MulticlassAUROC, MultilabelAUROC, ) +from cyclops.evaluate.metrics.experimental.average_precision import ( + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) from cyclops.evaluate.metrics.experimental.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, diff --git a/cyclops/evaluate/metrics/experimental/auroc.py b/cyclops/evaluate/metrics/experimental/auroc.py index 17c6af31f..bd139cb77 100644 --- a/cyclops/evaluate/metrics/experimental/auroc.py +++ b/cyclops/evaluate/metrics/experimental/auroc.py @@ -1,5 +1,5 @@ """Classes for computing the area under the ROC curve.""" -from typing import List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union from cyclops.evaluate.metrics.experimental.functional.auroc import ( _binary_auroc_compute, @@ -18,7 +18,7 @@ from cyclops.evaluate.metrics.experimental.utils.types import Array -class BinaryAUROC(BinaryPrecisionRecallCurve): +class BinaryAUROC(BinaryPrecisionRecallCurve, registry_key="binary_auroc"): """Area under the Receiver Operating Characteristic (ROC) curve. Parameters @@ -37,6 +37,8 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the AUROC. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments that are common to all metrics. Examples -------- @@ -59,9 +61,10 @@ def __init__( max_fpr: Optional[float] = None, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize the BinaryAUROC metric.""" - super().__init__(thresholds=thresholds, ignore_index=ignore_index) + super().__init__(thresholds=thresholds, ignore_index=ignore_index, **kwargs) _binary_auroc_validate_args( max_fpr=max_fpr, thresholds=thresholds, @@ -70,7 +73,7 @@ def __init__( self.max_fpr = max_fpr def _compute_metric(self) -> Array: # type: ignore[override] - """Compute the AUROC.""" "" + """Compute the AUROC.""" state = ( (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] if self.thresholds is None @@ -79,7 +82,7 @@ def _compute_metric(self) -> Array: # type: ignore[override] return _binary_auroc_compute(state, thresholds=self.thresholds, max_fpr=self.max_fpr) # type: ignore -class MulticlassAUROC(MulticlassPrecisionRecallCurve): +class MulticlassAUROC(MulticlassPrecisionRecallCurve, registry_key="multiclass_auroc"): """Area under the Receiver Operating Characteristic (ROC) curve. 
Parameters @@ -105,6 +108,8 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve): ignore_index : int or Tuple[int], optional, default=None The value(s) in `target` that should be ignored when computing the AUROC. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments that are common to all metrics. Examples -------- @@ -140,12 +145,14 @@ def __init__( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", ignore_index: Optional[Union[int, Tuple[int]]] = None, + **kwargs: Any, ) -> None: """Initialize the MulticlassAUROC metric.""" super().__init__( num_classes, thresholds=thresholds, ignore_index=ignore_index, + **kwargs, ) _multiclass_auroc_validate_args( num_classes=num_classes, @@ -170,9 +177,11 @@ def _compute_metric(self) -> Array: # type: ignore[override] ) -class MultilabelAUROC(MultilabelPrecisionRecallCurve): +class MultilabelAUROC(MultilabelPrecisionRecallCurve, registry_key="multilabel_auroc"): """Area under the Receiver Operating Characteristic (ROC) curve. + Parameters + ---------- num_labels : int The number of labels in the multilabel classification problem. thresholds : Union[int, List[float], Array], optional, default=None @@ -195,6 +204,8 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the AUROC. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments that are common to all metrics. Examples -------- @@ -227,12 +238,14 @@ def __init__( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize the MultilabelAUROC metric.""" super().__init__( num_labels, thresholds=thresholds, ignore_index=ignore_index, + **kwargs, ) _multilabel_auroc_validate_args( num_labels=num_labels, diff --git a/cyclops/evaluate/metrics/experimental/average_precision.py b/cyclops/evaluate/metrics/experimental/average_precision.py new file mode 100644 index 000000000..f8f8692ac --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/average_precision.py @@ -0,0 +1,272 @@ +"""Classes for computing area under the Average Precision (AUPRC).""" + +from typing import Any, List, Literal, Optional, Tuple, Union + +from cyclops.evaluate.metrics.experimental.functional.average_precision import ( + _binary_average_precision_compute, + _multiclass_average_precision_compute, + _multiclass_average_precision_validate_args, + _multilabel_average_precision_compute, + _multilabel_average_precision_validate_args, +) +from cyclops.evaluate.metrics.experimental.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class BinaryAveragePrecision( + BinaryPrecisionRecallCurve, + registry_key="binary_average_precision", +): + """A summary of the precision-recall curve via a weighted mean of the points. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. 
+ + Parameters + ---------- + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision + >>> target = anp.asarray([0, 1, 0, 1]) + >>> preds = anp.asarray([0.1, 0.4, 0.35, 0.8]) + >>> metric = BinaryAveragePrecision(thresholds=3) + >>> metric(target, preds) + Array(0.75, dtype=float32) + >>> metric.reset() + >>> target = [[0, 1, 0, 1], [1, 1, 0, 0]] + >>> preds = [[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]] + >>> for t, p in zip(target, preds): + ... metric.update(anp.asarray(t), anp.asarray(p)) + >>> metric.compute() + Array(0.5833334, dtype=float32) + + """ + + name: str = "Average Precision Score" + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + + return _binary_average_precision_compute( + state, + self.thresholds, # type: ignore + pos_label=1, + ) + + +class MulticlassAveragePrecision( + MulticlassPrecisionRecallCurve, + registry_key="multiclass_average_precision", +): + """A summary of the precision-recall curve via a weighted mean of the points. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"macro"`: compute the average precision score for each class and average + over the classes. + - `"weighted"`: computes the average of the precision for each class and + average over the classwise scores using the support of each class as + weights. + - `"none"`: do not average over the classwise scores. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. 
+ **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MulticlassAveragePrecision + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> metric = MulticlassAveragePrecision( + ... num_classes=3, thresholds=None, average=None, + ... ) + >>> metric(target, preds) + Array([0.33333334, 0.5 , 0.5 ], dtype=float32) + + """ + + name: str = "Average Precision Score" + + def __init__( + self, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, + **kwargs: Any, + ) -> None: + """Initialize a `MulticlassAveragePrecision` instance.""" + super().__init__(num_classes, thresholds, ignore_index=ignore_index, **kwargs) + _multiclass_average_precision_validate_args( + num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.average = average # type: ignore[assignment] + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + + return _multiclass_average_precision_compute( + state, + self.num_classes, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, # type: ignore[arg-type] + ) + + +class MultilabelAveragePrecision( + MultilabelPrecisionRecallCurve, + registry_key="multilabel_average_precision", +): + """A summary of the precision-recall curve via a weighted mean of the points. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"micro", "macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"micro"`: computes the average precision score globally by summing over + the average precision scores for each label. + - `"macro"`: compute the average precision score for each label and average + over the labels. + - `"weighted"`: computes the average of the precision for each label and + average over the labelwise scores using the support of each label as + weights. + - `"none"`: do not average over the labelwise scores. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. 
+ **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MultilabelAveragePrecision + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> metric = MultilabelAveragePrecision( + ... num_labels=3, thresholds=None, average=None, + ... ) + >>> metric(target, preds) + Array([1. , 0.5833334, 0.5 ], dtype=float32) + """ + + name: str = "Average Precision Score" + + def __init__( + self, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Initialize a `MultilabelAveragePrecision` instance.""" + super().__init__( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + **kwargs, + ) + _multilabel_average_precision_validate_args( + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.average = average + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + + return _multilabel_average_precision_compute( + state, + self.num_labels, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, + ignore_index=self.ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py index 851b74f8a..25b0ea1cd 100644 --- a/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py +++ b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py @@ -10,8 +10,10 @@ if TYPE_CHECKING: import torch import torch.distributed as torch_dist + from torch import Tensor else: torch = import_optional_module("torch", error="warn") + Tensor = import_optional_module("torch", attribute="Tensor", error="warn") torch_dist = import_optional_module("torch.distributed", error="warn") @@ -47,13 +49,13 @@ def world_size(self) -> int: """Return the world size of the current process group.""" return torch_dist.get_world_size() - def _simple_all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: + def _simple_all_gather(self, data: Tensor) -> List[Tensor]: """Gather tensors of the same shape from all processes.""" gathered_data = [torch.zeros_like(data) for _ in range(self.world_size)] torch_dist.all_gather(gathered_data, data) # type: ignore[no-untyped-call] return gathered_data - def all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: # type: ignore[override] + def all_gather(self, data: Tensor) -> List[Tensor]: # type: ignore[override] """Gather Arrays from current proccess and return as a list. 
Parameters @@ -95,3 +97,7 @@ def all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: # type: ignore[ slice_param = [slice(dim_size) for dim_size in item_size] gathered_data[idx] = gathered_data[idx][slice_param] return gathered_data + + +if __name__ == "__main__": # prevent execution of module on import + pass diff --git a/cyclops/evaluate/metrics/experimental/f_score.py b/cyclops/evaluate/metrics/experimental/f_score.py index 1092e499c..7e9bc7a20 100644 --- a/cyclops/evaluate/metrics/experimental/f_score.py +++ b/cyclops/evaluate/metrics/experimental/f_score.py @@ -28,7 +28,7 @@ class BinaryFBetaScore(_AbstractBinaryStatScores, registry_key="binary_fbeta_sco Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -106,6 +106,8 @@ class MulticlassFBetaScore( Specifies a target class that is ignored when computing the F-beta score. Ignoring a target class means that the corresponding predictions do not contribute to the F-beta score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -203,6 +205,8 @@ class MultilabelFBetaScore( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the F-beta score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -277,7 +281,7 @@ class BinaryF1Score(BinaryFBetaScore, registry_key="binary_f1_score"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -341,6 +345,8 @@ class MulticlassF1Score(MulticlassFBetaScore, registry_key="multiclass_f1_score" Specifies a target class that is ignored when computing the F1 score. Ignoring a target class means that the corresponding predictions do not contribute to the F1 score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -413,6 +419,8 @@ class MultilabelF1Score( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the F1 score. + **kwargs : Any + Additional keyword arguments common to all metrics. 
Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index e24543e64..1a2e5902b 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -9,6 +9,11 @@ multiclass_auroc, multilabel_auroc, ) +from cyclops.evaluate.metrics.experimental.functional.average_precision import ( + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, diff --git a/cyclops/evaluate/metrics/experimental/functional/auroc.py b/cyclops/evaluate/metrics/experimental/functional/auroc.py index c6e7c83c5..7abe73990 100644 --- a/cyclops/evaluate/metrics/experimental/functional/auroc.py +++ b/cyclops/evaluate/metrics/experimental/functional/auroc.py @@ -1,5 +1,6 @@ """Functions for computing the area under the ROC curve (AUROC).""" import warnings +from types import ModuleType from typing import List, Literal, Optional, Tuple, Union import array_api_compat as apc @@ -194,6 +195,8 @@ def _reduce_auroc( tpr: Union[Array, List[Array]], average: Optional[Literal["macro", "weighted", "none"]] = None, weights: Optional[Array] = None, + *, + xp: ModuleType, ) -> Array: """Compute the area under the ROC curve and apply `average` method. @@ -225,7 +228,6 @@ def _reduce_auroc( If the AUROC for one or more classes is `nan` and ``average`` is not ``none``. """ - xp = apc.array_namespace((fpr[0], tpr[0]) if isinstance(fpr, list) else (fpr, tpr)) if apc.is_array_api_obj(fpr) and apc.is_array_api_obj(tpr): res = _auc_compute(fpr, tpr, 1.0, axis=1) # type: ignore else: @@ -288,6 +290,7 @@ def _multiclass_auroc_compute( weights=xp.astype(bincount(state[0], minlength=num_classes), xp.float32) if thresholds is None else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + xp=xp, ) @@ -492,6 +495,7 @@ def _multilabel_auroc_compute( ) if thresholds is None else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + xp=xp, ) diff --git a/cyclops/evaluate/metrics/experimental/functional/average_precision.py b/cyclops/evaluate/metrics/experimental/functional/average_precision.py new file mode 100644 index 000000000..257fd119c --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/average_precision.py @@ -0,0 +1,677 @@ +"""Functions for computing average precision (AUPRC) for classification tasks.""" +import warnings +from types import ModuleType +from typing import List, Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format_arrays, + _binary_precision_recall_curve_update, + _binary_precision_recall_curve_validate_args, + _binary_precision_recall_curve_validate_arrays, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format_arrays, + _multiclass_precision_recall_curve_update, + _multiclass_precision_recall_curve_validate_args, + _multiclass_precision_recall_curve_validate_arrays, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format_arrays, + _multilabel_precision_recall_curve_update, + _multilabel_precision_recall_curve_validate_args, + _multilabel_precision_recall_curve_validate_arrays, +) +from 
cyclops.evaluate.metrics.experimental.utils.ops import ( + _diff, + bincount, + flatten, + remove_ignore_index, + safe_divide, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +def _binary_average_precision_compute( + state: Union[Tuple[Array, Array], Array], + thresholds: Optional[Array], + pos_label: int = 1, +) -> Array: + """Compute average precision for binary classification task. + + Parameters + ---------- + state : Array or Tuple[Array, Array] + State from which the precision-recall curve can be computed. Can be + either a tuple of (target, preds) or a multi-threshold confusion matrix. + thresholds : Array, optional + Thresholds used for computing the precision and recall scores. If not None, + must be a 1D numpy array of floats in the [0, 1] range and monotonically + increasing. + pos_label : int, optional, default=1 + The label of the positive class. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + ValueError + If ``thresholds`` is None. + + """ + precision, recall, _ = _binary_precision_recall_curve_compute( + state, + thresholds, + pos_label, + ) + xp = apc.array_namespace(precision, recall) + return -xp.sum(_diff(recall) * precision[:-1], dtype=xp.float32) # type: ignore + + +def binary_average_precision( + target: Array, + preds: Array, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, +) -> Array: + """Compute average precision score for binary classification task. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, 1]. The expected + shape of the array is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for the positive class. The expected + shape of the array is `(N, ...)` where `N` is the number of samples. If + `preds` contains floating point values that are not in the range `[0, 1]`, + a sigmoid function will be applied to each value before thresholding. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision. If `None`, all values in `target` are used. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. 
+ ValueError + If `ignore_index` is not `None` or an integer. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... binary_average_precision + ... ) + >>> target = anp.asarray([0, 1, 1, 0]) + >>> preds = anp.asarray([0, 0.5, 0.7, 0.8]) + >>> binary_average_precision(target, preds, thresholds=None) + Array(0.5833334, dtype=float32) + + """ + _binary_precision_recall_curve_validate_args(thresholds, ignore_index) + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + thresholds, + ignore_index, + ) + target, preds, thresholds = _binary_precision_recall_curve_format_arrays( + target, + preds, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _binary_precision_recall_curve_update( + target, + preds, + thresholds=thresholds, + xp=xp, + ) + return _binary_average_precision_compute(state, thresholds, pos_label=1) + + +def _reduce_average_precision( + precision: Union[Array, List[Array]], + recall: Union[Array, List[Array]], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + weights: Optional[Array] = None, + *, + xp: ModuleType, +) -> Array: + """Reduce the precision-recall curve to a single average precision score. + + Applies the specified `average` after computing the average precision score + for each class/label. + + Parameters + ---------- + precision : Array or List[Array] + The precision values for each class/label, computed at different thresholds. + recall : Array or List[Array] + The recall values for each class/label, computed at different thresholds. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"macro"`: computes the average precision score for each class/label and + average over the scores. + - `"weighted"`: computes the average of the precision score for each + class/label and average over the classwise/labelwise scores using + `weights` as weights. + - `"none"`: do not average over the classwise/labelwise scores. + weights : Array, optional, default=None + The weights to use for computing the weighted average precision score. + xp : ModuleType + The array API module to use for computations. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + ValueError + If `average` is not `"macro"`, `"weighted"` or `"none"` or `None` or + average is `"weighted"` and `weights` is `None`. 
+ """ + if apc.is_array_api_obj(precision) and apc.is_array_api_obj(recall): + avg_prec = -xp.sum( + (recall[:, 1:] - recall[:, :-1]) * precision[:, :-1], # type: ignore + axis=1, + dtype=xp.float32, + ) + else: + avg_prec = xp.stack( + [ + -xp.sum((rec[1:] - rec[:-1]) * prec[:-1], dtype=xp.float32) + for prec, rec in zip(precision, recall) # type: ignore[arg-type] + ], + ) + if average is None or average == "none": + return avg_prec # type: ignore[no-any-return] + if xp.any(xp.isnan(avg_prec)): + warnings.warn( + f"Average precision score for one or more classes was `nan`. Ignoring these classes in {average}-average", + UserWarning, + stacklevel=1, + ) + idx = ~xp.isnan(avg_prec) + if average == "macro": + return xp.mean(avg_prec[idx]) # type: ignore[no-any-return] + if average == "weighted" and weights is not None: + weights = safe_divide(weights[idx], xp.sum(weights[idx])) + return xp.sum(avg_prec[idx] * weights, dtype=xp.float32) # type: ignore[no-any-return] + raise ValueError( + "Received an incompatible combinations of inputs to make reduction.", + ) + + +def _multiclass_average_precision_validate_args( + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> None: + """Validate the arguments for the `multiclass_average_precision` function.""" + _multiclass_precision_recall_curve_validate_args( + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_averages = ["macro", "weighted", "none"] + if average is not None and average not in allowed_averages: + raise ValueError( + f"Expected `average` to be one of {allowed_averages}, got {average}.", + ) + + +def _multiclass_average_precision_compute( + state: Union[Tuple[Array, Array], Array], + num_classes: int, + thresholds: Optional[Array], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", +) -> Array: + """Compute the average precision score for multiclass classification task.""" + precision, recall, _ = _multiclass_precision_recall_curve_compute( + state, + num_classes, + thresholds=thresholds, + average=None, + ) + xp = apc.array_namespace(state) + return _reduce_average_precision( + precision, + recall, + average=average, + weights=xp.astype(bincount(state[0], minlength=num_classes), xp.float32) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1, dtype=xp.float32), # type: ignore + xp=xp, + ) + + +def multiclass_average_precision( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Array: + """Compute the average precision score for multiclass classification task. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, `num_classes`] + (except if `ignore_index` is specified). The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. 
+ preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for each sample. The expected shape + of the array is `(N, C, ...)` where `N` is the number of samples and `C` + is the number of classes. If `preds` contains floating point values that + are not in the range `[0, 1]`, a softmax function will be applied to each + value before thresholding. + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"macro"`: compute the average precision score for each class and average + over the classes. + - `"weighted"`: computes the average of the precision for each class and + average over the classwise scores using the support of each class as + weights. + - `"none"`: do not average over the classwise scores. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `num_classes` is not an integer larger than 1. + ValueError + If `ignore_index` is not `None`, an integer or a tuple of integers. + ValueError + If `average` is not `"macro"`, `"weighted"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `preds` does not have one more dimension than `target`. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + ValueError + If the second dimension of `preds` is not equal to `num_classes`. + ValueError + If the first dimension of `preds` is not equal to the first dimension of + `target` or the third dimension of `preds` is not equal to the second + dimension of `target`. + RuntimeError + If `target` contains more unique values than `num_classes` or `num_classes` + plus the number of values in `ignore_index` if `ignore_index` is not `None`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multiclass_average_precision, + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... 
[0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> multiclass_average_precision( + ... target, preds, num_classes=3, thresholds=None, average=None, + ... ) + Array([0.33333334, 0.5 , 0.5 ], dtype=float32) + >>> multiclass_average_precision( + ... target, preds, num_classes=3, thresholds=None, average="macro", + ... ) + Array(0.44444445, dtype=float32) + >>> multiclass_average_precision( + ... target, preds, num_classes=3, thresholds=None, average="weighted", + ... ) + Array(0.44444448, dtype=float32) + """ + _multiclass_average_precision_validate_args( + num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multiclass_precision_recall_curve_validate_arrays( + target, + preds, + num_classes, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays( + target, + preds, + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multiclass_precision_recall_curve_update( + target, + preds, + num_classes, + thresholds=thresholds, + xp=xp, + ) + return _multiclass_average_precision_compute( + state, + num_classes, + thresholds=thresholds, + average=average, + ) + + +def _multilabel_average_precision_validate_args( + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> None: + """Validate the arguments for the `multilabel_average_precision` function.""" + _multilabel_precision_recall_curve_validate_args( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_averages = ["micro", "macro", "weighted", "none"] + if average is not None and average not in allowed_averages: + raise ValueError( + f"Expected `average` to be one of {allowed_averages}, got {average}.", + ) + + +def _multilabel_average_precision_compute( + state: Union[Tuple[Array, Array], Array], + num_labels: int, + thresholds: Optional[Array], + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the average precision score for multilabel classification task.""" + xp = apc.array_namespace(state) + if average == "micro": + if apc.is_array_api_obj(state) and thresholds is not None: + state = xp.sum(state, axis=1) + else: + target, preds = flatten(state[0]), flatten(state[1]) + target, preds = remove_ignore_index(target, preds, ignore_index) + state = (target, preds) + return _binary_average_precision_compute(state, thresholds) + + precision, recall, _ = _multilabel_precision_recall_curve_compute( + state, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + return _reduce_average_precision( + precision, + recall, + average=average, + weights=xp.sum(xp.astype(state[0] == 1, xp.int32), axis=0, dtype=xp.float32) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore + xp=xp, + ) + + +def multilabel_average_precision( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the average precision score for multilabel classification task. 
+ + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + target : Array + The target array of shape `(N, L, ...)` containing the ground truth labels + in the range [0, 1], where `N` is the number of samples and `L` is the + number of labels. + preds : Array + The prediction array of shape `(N, L, ...)` containing the probability/logit + scores for each sample, where `N` is the number of samples and `L` is the + number of labels. If `preds` contains floating point values that are not + in the range [0,1], they will be converted to probabilities using the + sigmoid function. + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"micro", "macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"micro"`: computes the average precision score globally by summing over + the average precision scores for each label. + - `"macro"`: compute the average precision score for each label and average + over the labels. + - `"weighted"`: computes the average of the precision for each label and + average over the labelwise scores using the support of each label as + weights. + - `"none"`: do not average over the labelwise scores. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + ValueError + If `num_labels` is not an integer larger than 1. + ValueError + If `average` is not `"micro"`, `"macro"`, `"weighted"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. 
+ ValueError + If the second dimension of `preds` is not equal to `num_labels`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multilabel_average_precision, + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> multilabel_average_precision( + ... target, preds, num_labels=3, thresholds=None, average=None, + ... ) + Array([1. , 0.5833334, 0.5 ], dtype=float32) + >>> multilabel_average_precision( + ... target, preds, num_labels=3, thresholds=None, average="micro", + ... ) + Array(0.58452386, dtype=float32) + >>> multilabel_average_precision( + ... target, preds, num_labels=3, thresholds=None, average="macro", + ... ) + Array(0.6944445, dtype=float32) + >>> multilabel_average_precision( + ... target, preds, num_labels=3, thresholds=None, average="weighted", + ... ) + Array(0.6666667, dtype=float32) + """ + _multilabel_average_precision_validate_args( + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multilabel_precision_recall_curve_validate_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multilabel_precision_recall_curve_format_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multilabel_precision_recall_curve_update( + target, + preds, + num_labels, + thresholds=thresholds, + xp=xp, + ) + return _multilabel_average_precision_compute( + state, + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/mae.py b/cyclops/evaluate/metrics/experimental/mae.py index 3221d3340..dab2f5a5d 100644 --- a/cyclops/evaluate/metrics/experimental/mae.py +++ b/cyclops/evaluate/metrics/experimental/mae.py @@ -1,4 +1,6 @@ """Mean Absolute Error metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.mae import ( _mean_absolute_error_compute, _mean_absolute_error_update, @@ -10,6 +12,11 @@ class MeanAbsoluteError(Metric): """Mean Absolute Error. + Parameters + ---------- + **kwargs : Any + Keyword arguments to pass to the `Metric` base class. + Examples -------- >>> import numpy.array_api as anp @@ -24,8 +31,8 @@ class MeanAbsoluteError(Metric): name: str = "Mean Absolute Error" - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self.add_state_default_factory( "sum_abs_error", lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore diff --git a/cyclops/evaluate/metrics/experimental/mape.py b/cyclops/evaluate/metrics/experimental/mape.py index dede691f1..6d9d4afbf 100644 --- a/cyclops/evaluate/metrics/experimental/mape.py +++ b/cyclops/evaluate/metrics/experimental/mape.py @@ -1,4 +1,6 @@ """Mean Absolute Percentage Error (MAPE) metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.mape import ( _mean_absolute_percentage_error_compute, _mean_absolute_percentage_error_update, @@ -15,6 +17,8 @@ class MeanAbsolutePercentageError(Metric): epsilon : float, optional, default=1.17e-06 Specifies the lower bound for target values. Any target value below epsilon is set to epsilon (avoids division by zero errors). 
+ **kwargs : Any + Keyword arguments to pass to the `Metric` base class. Examples -------- @@ -30,8 +34,8 @@ class MeanAbsolutePercentageError(Metric): name: str = "Mean Absolute Percentage Error" - def __init__(self, epsilon: float = 1.17e-6) -> None: - super().__init__() + def __init__(self, epsilon: float = 1.17e-6, **kwargs: Any) -> None: + super().__init__(**kwargs) if not isinstance(epsilon, float): raise TypeError(f"Expected `epsilon` to be a float. Got {type(epsilon)}") self.epsilon = epsilon diff --git a/cyclops/evaluate/metrics/experimental/mse.py b/cyclops/evaluate/metrics/experimental/mse.py index b8ef4b435..6210055a2 100644 --- a/cyclops/evaluate/metrics/experimental/mse.py +++ b/cyclops/evaluate/metrics/experimental/mse.py @@ -1,4 +1,6 @@ """Mean Squared Error metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.mse import ( _mean_squared_error_compute, _mean_squared_error_update, @@ -17,6 +19,8 @@ class MeanSquaredError(Metric): to `False`, returns the root mean squared error. num_outputs : int, optional, default=1 Number of outputs in multioutput setting. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. Examples -------- @@ -43,8 +47,13 @@ class MeanSquaredError(Metric): name: str = "Mean Squared Error" - def __init__(self, squared: bool = True, num_outputs: int = 1) -> None: - super().__init__() + def __init__( + self, + squared: bool = True, + num_outputs: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) if not isinstance(squared, bool): raise TypeError(f"Expected `squared` to be a boolean. Got {type(squared)}") if not isinstance(num_outputs, int) and num_outputs > 0: diff --git a/cyclops/evaluate/metrics/experimental/negative_predictive_value.py b/cyclops/evaluate/metrics/experimental/negative_predictive_value.py index 7a5f1e5ee..99602555d 100644 --- a/cyclops/evaluate/metrics/experimental/negative_predictive_value.py +++ b/cyclops/evaluate/metrics/experimental/negative_predictive_value.py @@ -20,7 +20,7 @@ class BinaryNPV(_AbstractBinaryStatScores, registry_key="binary_npv"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -82,6 +82,8 @@ class MulticlassNPV( Specifies a target class that is ignored when computing the negative predictive value. Ignoring a target class means that the corresponding predictions do not contribute to the negative predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -151,6 +153,8 @@ class MultilabelNPV( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the negative predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/precision_recall.py b/cyclops/evaluate/metrics/experimental/precision_recall.py index d704aff89..57253a1d6 100644 --- a/cyclops/evaluate/metrics/experimental/precision_recall.py +++ b/cyclops/evaluate/metrics/experimental/precision_recall.py @@ -20,7 +20,7 @@ class BinaryPrecision(_AbstractBinaryStatScores, registry_key="binary_precision" Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. 
- **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -59,7 +59,7 @@ class BinaryPPV(BinaryPrecision, registry_key="binary_ppv"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -114,6 +114,8 @@ class MulticlassPrecision( Specifies a target class that is ignored when computing the precision score. Ignoring a target class means that the corresponding predictions do not contribute to the precision score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -179,6 +181,8 @@ class MulticlassPPV(MulticlassPrecision, registry_key="multiclass_ppv"): Specifies a target class that is ignored when computing the positive predictive value. Ignoring a target class means that the corresponding predictions do not contribute to the positive predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -235,6 +239,8 @@ class MultilabelPrecision( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the precision score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -301,6 +307,8 @@ class MultilabelPPV(MultilabelPrecision, registry_key="multilabel_ppv"): ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the positive predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -334,7 +342,7 @@ class BinaryRecall(_AbstractBinaryStatScores, registry_key="binary_recall"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -373,7 +381,7 @@ class BinarySensitivity(BinaryRecall, registry_key="binary_sensitivity"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -407,7 +415,7 @@ class BinaryTPR(BinaryRecall, registry_key="binary_tpr"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -459,6 +467,8 @@ class MulticlassRecall(_AbstractMulticlassStatScores, registry_key="multiclass_r Specifies a target class that is ignored when computing the recall score. Ignoring a target class means that the corresponding predictions do not contribute to the recall score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -523,6 +533,8 @@ class MulticlassSensitivity(MulticlassRecall, registry_key="multiclass_sensitivi Specifies a target class that is ignored when computing the sensitivity score. Ignoring a target class means that the corresponding predictions do not contribute to the sensitivity score. + **kwargs : Any + Additional keyword arguments common to all metrics. 
Examples -------- @@ -575,6 +587,8 @@ class MulticlassTPR(MulticlassRecall, registry_key="multiclass_tpr"): Specifies a target class that is ignored when computing the true positive rate. Ignoring a target class means that the corresponding predictions do not contribute to the true positive rate. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -628,6 +642,8 @@ class MultilabelRecall(_AbstractMultilabelStatScores, registry_key="multilabel_r ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the recall score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -694,6 +710,8 @@ class MultilabelSensitivity(MultilabelRecall, registry_key="multilabel_sensitivi ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the sensitivity score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -748,6 +766,8 @@ class MultilabelTPR(MultilabelRecall, registry_key="multilabel_tpr"): ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the true positive rate. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py index 46bfba20e..6567e407f 100644 --- a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py @@ -1,6 +1,6 @@ """Classes for computing the precision-recall curve.""" from types import ModuleType -from typing import List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union import array_api_compat as apc @@ -43,6 +43,8 @@ class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_c ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the precision and recall. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. Examples -------- @@ -69,9 +71,10 @@ def __init__( self, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize a `BinaryPrecisionRecallCurve` instance.""" - super().__init__() + super().__init__(**kwargs) _binary_precision_recall_curve_validate_args(thresholds, ignore_index) self.ignore_index = ignore_index self.thresholds = thresholds @@ -173,6 +176,8 @@ class MulticlassPrecisionRecallCurve( ignore_index : int or Tuple[int], optional, default=None The value(s) in `target` that should be ignored when computing the precision and recall. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. 
Examples -------- @@ -219,9 +224,10 @@ def __init__( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, ignore_index: Optional[Union[int, Tuple[int]]] = None, + **kwargs: Any, ) -> None: """Initialize a `MulticlassPrecisionRecallCurve` instance.""" - super().__init__() + super().__init__(**kwargs) _multiclass_precision_recall_curve_validate_args( num_classes, thresholds=thresholds, @@ -345,6 +351,8 @@ class MultilabelPrecisionRecallCurve( ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the precision and recall. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. Examples -------- @@ -385,9 +393,10 @@ def __init__( num_labels: int, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize a `MultilabelPrecisionRecallCurve` instance.""" - super().__init__() + super().__init__(**kwargs) _multilabel_precision_recall_curve_validate_args( num_labels, thresholds=thresholds, diff --git a/cyclops/evaluate/metrics/experimental/roc.py b/cyclops/evaluate/metrics/experimental/roc.py index 942cc4e89..6c6fbecb5 100644 --- a/cyclops/evaluate/metrics/experimental/roc.py +++ b/cyclops/evaluate/metrics/experimental/roc.py @@ -15,7 +15,7 @@ from cyclops.evaluate.metrics.experimental.utils.types import Array -class BinaryROC(BinaryPrecisionRecallCurve): +class BinaryROC(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"): """The receiver operating characteristic (ROC) curve. Parameters @@ -31,6 +31,8 @@ class BinaryROC(BinaryPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the ROC curve. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -62,7 +64,10 @@ def _compute_metric(self) -> Tuple[Array, Array, Array]: return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] -class MulticlassROC(MulticlassPrecisionRecallCurve): +class MulticlassROC( + MulticlassPrecisionRecallCurve, + registry_key="multiclass_roc_curve", +): """The reciever operator characteristics (ROC) curve. Parameters @@ -89,6 +94,8 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): ignore_index : int or Tuple[int], optional, default=None The value(s) in `target` that should be ignored when computing the ROC curve. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -145,7 +152,10 @@ def _compute_metric( ) -class MultilabelROC(MultilabelPrecisionRecallCurve): +class MultilabelROC( + MultilabelPrecisionRecallCurve, + registry_key="multilabel_roc_curve", +): """The reciever operator characteristics (ROC) curve. Parameters @@ -163,6 +173,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the ROC Curve. If `None`, all values in `target` are used. + **kwargs + Additional keyword arguments common to all metrics. 
Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/smape.py b/cyclops/evaluate/metrics/experimental/smape.py index df2392ce4..a7e61c027 100644 --- a/cyclops/evaluate/metrics/experimental/smape.py +++ b/cyclops/evaluate/metrics/experimental/smape.py @@ -1,4 +1,6 @@ """Symmetric Mean Absolute Percentage Error metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.smape import ( _symmetric_mean_absolute_percentage_error_compute, _symmetric_mean_absolute_percentage_error_update, @@ -15,6 +17,8 @@ class SymmetricMeanAbsolutePercentageError(Metric): epsilon : float, optional, default=1.17e-6 Specifies the lower bound for target values. Any target value below epsilon is set to epsilon (avoids division by zero errors). + **kwargs : Any + Keyword arguments to pass to the `Metric` base class. Examples -------- @@ -32,8 +36,8 @@ class SymmetricMeanAbsolutePercentageError(Metric): name: str = "Symmetric Mean Absolute Percentage Error" - def __init__(self, epsilon: float = 1.17e-6) -> None: - super().__init__() + def __init__(self, epsilon: float = 1.17e-6, **kwargs: Any) -> None: + super().__init__(**kwargs) if not isinstance(epsilon, float): raise TypeError(f"Expected `epsilon` to be a float. Got {type(epsilon)}") self.epsilon = epsilon diff --git a/cyclops/evaluate/metrics/experimental/specificity.py b/cyclops/evaluate/metrics/experimental/specificity.py index b289b046d..768b8939e 100644 --- a/cyclops/evaluate/metrics/experimental/specificity.py +++ b/cyclops/evaluate/metrics/experimental/specificity.py @@ -20,7 +20,7 @@ class BinarySpecificity(_AbstractBinaryStatScores, registry_key="binary_specific Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -81,6 +81,8 @@ class MulticlassSpecificity( Specifies a target class that is ignored when computing the specificity score. Ignoring a target class means that the corresponding predictions do not contribute to the specificity score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -150,6 +152,8 @@ class MultilabelSpecificity( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the specificity score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -196,7 +200,7 @@ class BinaryTNR(BinarySpecificity, registry_key="binary_tnr"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -249,6 +253,8 @@ class MulticlassTNR(MulticlassSpecificity, registry_key="multiclass_tnr"): Specifies a target class that is ignored when computing the true negative rate. Ignoring a target class means that the corresponding predictions do not contribute to the true negative rate. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -303,6 +309,8 @@ class MultilabelTNR(MultilabelSpecificity, registry_key="multilabel_tnr"): ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the true negative rate. + **kwargs : Any + Additional keyword arguments common to all metrics. 
Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/wmape.py b/cyclops/evaluate/metrics/experimental/wmape.py index a24e8eac5..fc37cf6a6 100644 --- a/cyclops/evaluate/metrics/experimental/wmape.py +++ b/cyclops/evaluate/metrics/experimental/wmape.py @@ -1,5 +1,6 @@ """Weighted Mean Absolute Percentage Error metric.""" from types import ModuleType +from typing import Any from cyclops.evaluate.metrics.experimental.functional.wmape import ( _weighted_mean_absolute_percentage_error_compute, @@ -17,6 +18,8 @@ class WeightedMeanAbsolutePercentageError(Metric): epsilon : float, optional, default=1.17e-6 Specifies the lower bound for target values. Any target value below epsilon is set to epsilon (avoids division by zero errors). + **kwargs : Any + Keyword arguments to pass to the `Metric` base class. Examples -------- @@ -34,8 +37,8 @@ class WeightedMeanAbsolutePercentageError(Metric): name: str = "Weighted Mean Absolute Percentage Error" - def __init__(self, epsilon: float = 1.17e-6) -> None: - super().__init__() + def __init__(self, epsilon: float = 1.17e-6, **kwargs: Any) -> None: + super().__init__(**kwargs) if not isinstance(epsilon, float): raise TypeError(f"Expected `epsilon` to be a float. Got {type(epsilon)}") self.epsilon = epsilon diff --git a/cyclops/evaluate/metrics/factory.py b/cyclops/evaluate/metrics/factory.py index e83d76b55..bdcd6048f 100644 --- a/cyclops/evaluate/metrics/factory.py +++ b/cyclops/evaluate/metrics/factory.py @@ -1,18 +1,28 @@ """Factory for creating metrics.""" from difflib import get_close_matches -from typing import Any, List +from typing import Any, List, Union +from cyclops.evaluate.metrics.experimental.metric import ( + _METRIC_REGISTRY as _EXPERIMENTAL_METRIC_REGISTRY, +) +from cyclops.evaluate.metrics.experimental.metric import Metric as ExperimentalMetric from cyclops.evaluate.metrics.metric import _METRIC_REGISTRY, Metric -def create_metric(metric_name: str, **kwargs: Any) -> Metric: +def create_metric( + metric_name: str, + experimental: bool = False, + **kwargs: Any, +) -> Union[Metric, ExperimentalMetric]: """Create a metric instance from a name. Parameters ---------- metric_name : str The name of the metric. + experimental : bool + Whether to use metrics from `cyclops.evaluate.metrics.experimental`. **kwargs : Any The keyword arguments to pass to the metric constructor. @@ -22,11 +32,20 @@ def create_metric(metric_name: str, **kwargs: Any) -> Metric: The metric instance. 
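+
+    Examples
+    --------
+    A minimal usage sketch, assuming `binary_accuracy` is registered as an
+    experimental metric (it is the key used in the updated tutorial notebooks):
+
+    >>> from cyclops.evaluate.metrics.factory import create_metric
+    >>> metric = create_metric("binary_accuracy", experimental=True)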
""" - metric_class = _METRIC_REGISTRY.get(metric_name, None) + metric_class = ( + _METRIC_REGISTRY.get(metric_name, None) + if not experimental + else _EXPERIMENTAL_METRIC_REGISTRY.get(metric_name, None) + ) if metric_class is None: + registry_keys: List[str] = ( + list(_METRIC_REGISTRY.keys()) + if not experimental + else list(_EXPERIMENTAL_METRIC_REGISTRY.keys()) # type: ignore[arg-type] + ) similar_keys_list: List[str] = get_close_matches( metric_name, - _METRIC_REGISTRY.keys(), + registry_keys, n=5, ) similar_keys: str = ", ".join(similar_keys_list) diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index e176ec386..04e0130f6 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -131,7 +131,7 @@ def roc_curve( if auroc is not None: assert isinstance( auroc, - float, + (float, np.floating), ), "AUROCs must be a float for binary tasks" name = f"Model (AUC = {auroc:.2f})" else: @@ -227,7 +227,7 @@ def roc_curve_comparison( if aurocs and slice_name in aurocs: assert isinstance( aurocs[slice_name], - float, + (float, np.floating), ), "AUROCs must be a float for binary tasks" name = f"{slice_name} (AUC = {aurocs[slice_name]:.2f})" else: @@ -401,7 +401,7 @@ def precision_recall_curve_comparison( if auprcs and slice_name in auprcs: assert isinstance( auprcs[slice_name], - float, + (float, np.floating), ), "AUPRCs must be a float for binary tasks" name = f"{slice_name} (AUC = {auprcs[slice_name]:.2f})" else: @@ -483,8 +483,10 @@ def metrics_value( """ if self.task_type == "binary": assert all( - not isinstance(value, (list, np.ndarray)) for value in metrics.values() - ), ("Metrics must not be of type list or np.ndarray for" "binary tasks") + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) + for value in metrics.values() + ), "Metrics must not be of type list or np.ndarray for binary tasks" trace = bar_plot( x=list(metrics.keys()), # type: ignore[arg-type] y=list(metrics.values()), # type: ignore[arg-type] @@ -705,7 +707,8 @@ def metrics_comparison_radar( for slice_name, metrics in slice_metrics.items(): metric_names = list(metrics.keys()) assert all( - not isinstance(value, (list, np.ndarray)) + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) for value in metrics.values() ), ( "Generic metrics must not be of type list or np.ndarray for" @@ -725,7 +728,9 @@ def metrics_comparison_radar( radial_data: List[float] = [] theta_data: List[float] = [] for metric_name, metric_values in metrics.items(): - if isinstance(metric_values, (list, np.ndarray)): + if isinstance(metric_values, list) or ( + isinstance(metric_values, np.ndarray) and metric_values.ndim > 0 + ): assert ( len(metric_values) == self.class_num ), "Metric values must be of length class_num for \ @@ -736,7 +741,7 @@ def metrics_comparison_radar( for i in range(self.class_num) ] theta_data.extend(theta) # type: ignore[arg-type] - elif isinstance(metric_values, float): + elif isinstance(metric_values, (float, np.floating)): radial_data.append(metric_values) theta_data.append(metric_name) # type: ignore[arg-type] else: @@ -807,10 +812,11 @@ def metrics_comparison_bar( metric_names = list(metrics.keys()) metric_values = list(metrics.values()) assert all( - not isinstance(value, (list, np.ndarray)) + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) for value in metrics.values() ), ( - "Generic metrics must not be of type list or np.ndarray 
for" + "Generic metrics must not be of type list or np.ndarray for " "binary tasks" ) trace.append( @@ -856,7 +862,10 @@ def metrics_comparison_bar( metric_names = list(metrics.keys()) for num in range(self.class_num): for metric_name in metric_names: - if isinstance(metrics[metric_name], (list, np.ndarray)): + if isinstance(metrics[metric_name], list) or ( + isinstance(metrics[metric_name], np.ndarray) + and metrics[metric_name].ndim > 0 + ): metric_values = metrics[metric_name][num] # type: ignore else: metric_values = metrics[metric_name] # type: ignore @@ -926,7 +935,8 @@ def metrics_comparison_scatter( metric_names = list(metrics.keys()) metric_values = list(metrics.values()) assert all( - not isinstance(value, (list, np.ndarray)) + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) for value in metrics.values() ), ( "Generic metrics must not be of type list or np.ndarray for" diff --git a/cyclops/tasks/classification.py b/cyclops/tasks/classification.py index 0772f13d3..d5344bb82 100644 --- a/cyclops/tasks/classification.py +++ b/cyclops/tasks/classification.py @@ -14,8 +14,8 @@ from cyclops.data.slicer import SliceSpec from cyclops.evaluate.evaluator import evaluate from cyclops.evaluate.fairness.config import FairnessConfig +from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict from cyclops.evaluate.metrics.factory import create_metric -from cyclops.evaluate.metrics.metric import MetricCollection from cyclops.models.catalog import ( _img_model_keys, _model_names_mapping, @@ -261,7 +261,7 @@ def predict( def evaluate( self, dataset: Union[Dataset, DatasetDict], - metrics: Union[List[str], MetricCollection], + metrics: Union[List[str], MetricDict], model_names: Optional[Union[str, List[str]]] = None, transforms: Optional[ColumnTransformer] = None, prediction_column_prefix: str = "predictions", @@ -278,7 +278,7 @@ def evaluate( ---------- dataset : Union[Dataset, DatasetDict] HuggingFace dataset. - metrics : Union[List[str], MetricCollection] + metrics : Union[List[str], MetricDict] Metrics to be evaluated. 
model_names : Union[str, List[str]], optional Model names to be evaluated, if not specified all fitted models \ @@ -315,9 +315,9 @@ def evaluate( if splits_mapping is None: splits_mapping = {"test": "test"} if isinstance(metrics, list) and len(metrics): - metrics_collection = MetricCollection( + metrics_collection = MetricDict( [ - create_metric( + create_metric( # type: ignore[misc] m, task=self.task_type, num_labels=len(self.task_features), @@ -325,7 +325,7 @@ def evaluate( for m in metrics ], ) - elif isinstance(metrics, MetricCollection): + elif isinstance(metrics, MetricDict): metrics_collection = metrics if isinstance(model_names, str): model_names = [model_names] @@ -345,6 +345,22 @@ def evaluate( only_predictions=False, splits_mapping=splits_mapping, ) + + # select the probability scores of the positive class since metrics + # expect a single column of probabilities + dataset = dataset.map( # type: ignore[union-attr] + lambda examples: { + f"{prediction_column_prefix}.{model_name}": np.array( # noqa: B023 + examples, + )[ + :, + 1, + ].tolist(), + }, + batched=True, + batch_size=batch_size, + input_columns=f"{prediction_column_prefix}.{model_name}", + ) results = evaluate( dataset=dataset, metrics=metrics_collection, @@ -448,7 +464,7 @@ def predict( def evaluate( self, dataset: Union[Dataset, DatasetDict], - metrics: Union[List[str], MetricCollection], + metrics: Union[List[str], MetricDict], model_names: Optional[Union[str, List[str]]] = None, transforms: Optional[Compose] = None, prediction_column_prefix: str = "predictions", @@ -465,7 +481,7 @@ def evaluate( ---------- dataset : Union[Dataset, DatasetDict] HuggingFace dataset. - metrics : Union[List[str], MetricCollection] + metrics : Union[List[str], MetricDict] Metrics to be evaluated. 
model_names : Union[str, List[str]], optional Model names to be evaluated, required if more than one model exists, \ @@ -515,9 +531,9 @@ def add_missing_labels(examples: Dict[str, Any]) -> Dict[str, Any]: dataset = dataset.map(add_missing_labels) if isinstance(metrics, list) and len(metrics): - metrics_collection = MetricCollection( + metrics_collection = MetricDict( [ - create_metric( + create_metric( # type: ignore[misc] m, task=self.task_type, num_labels=len(self.task_target), @@ -525,7 +541,7 @@ def add_missing_labels(examples: Dict[str, Any]) -> Dict[str, Any]: for m in metrics ], ) - elif isinstance(metrics, MetricCollection): + elif isinstance(metrics, MetricDict): metrics_collection = metrics if isinstance(model_names, str): model_names = [model_names] diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb index 69d0642e3..44acea15b 100644 --- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb +++ b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb @@ -44,7 +44,8 @@ "from cyclops.data.df.feature import TabularFeatures\n", "from cyclops.data.slicer import SliceSpec\n", "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", + "from cyclops.evaluate.metrics import create_metric\n", + "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n", "from cyclops.models.catalog import create_model\n", "from cyclops.report import ModelCardReport\n", "from cyclops.report.plot.classification import ClassificationPlotter\n", @@ -697,7 +698,7 @@ "\n", "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n", "\n", - "The standard performance metrics can be created using the `MetricCollection` object." + "The standard performance metrics can be created using the `MetricDict` object." ] }, { @@ -709,17 +710,19 @@ "outputs": [], "source": [ "metric_names = [\n", - " \"accuracy\",\n", - " \"precision\",\n", - " \"recall\",\n", - " \"f1_score\",\n", - " \"auroc\",\n", - " \"average_precision\",\n", - " \"roc_curve\",\n", - " \"precision_recall_curve\",\n", + " \"binary_accuracy\",\n", + " \"binary_precision\",\n", + " \"binary_recall\",\n", + " \"binary_f1_score\",\n", + " \"binary_auroc\",\n", + " \"binary_average_precision\",\n", + " \"binary_roc_curve\",\n", + " \"binary_precision_recall_curve\",\n", "]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "metric_collection = MetricCollection(metrics)" + "metrics = [\n", + " create_metric(metric_name, experimental=True) for metric_name in metric_names\n", + "]\n", + "metric_collection = MetricDict(metrics)" ] }, { @@ -762,7 +765,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A `MetricCollection` can also be defined for the fairness metrics." + "A `MetricDict` can also be defined for the fairness metrics." 
] }, { @@ -773,21 +776,15 @@ }, "outputs": [], "source": [ - "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"binary\",\n", - ")\n", - "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"binary\",\n", - ")\n", + "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n", + "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n", "\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", + "fpr = -specificity + 1\n", + "fnr = -sensitivity + 1\n", "\n", "ber = (fpr + fnr) / 2\n", "\n", - "fairness_metric_collection = MetricCollection(\n", + "fairness_metric_collection = MetricDict(\n", " {\n", " \"Sensitivity\": sensitivity,\n", " \"Specificity\": specificity,\n", @@ -858,8 +855,13 @@ "source": [ "results_female, _ = heart_failure_prediction_task.evaluate(\n", " dataset=dataset[\"test\"],\n", - " metrics=MetricCollection(\n", - " {\"BinaryAccuracy\": create_metric(metric_name=\"accuracy\", task=\"binary\")},\n", + " metrics=MetricDict(\n", + " {\n", + " \"BinaryAccuracy\": create_metric(\n", + " metric_name=\"binary_accuracy\",\n", + " experimental=True,\n", + " ),\n", + " },\n", " ),\n", " model_names=model_name,\n", " transforms=preprocessor,\n", @@ -889,7 +891,7 @@ "model_name = f\"model_for_preds.{model_name}\"\n", "results_flat = flatten_results_dict(\n", " results=results,\n", - " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")\n", "results_female_flat = flatten_results_dict(\n", @@ -910,7 +912,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -930,7 +932,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -963,7 +965,7 @@ "source": [ "# extracting the ROC curves and AUROC results for all the slices\n", "roc_curves = {\n", - " slice_name: slice_results[\"BinaryROCCurve\"]\n", + " slice_name: slice_results[\"BinaryROC\"]\n", " for slice_name, slice_results in results[model_name].items()\n", "}\n", "aurocs = {\n", @@ -1036,7 +1038,7 @@ "overall_performance = {\n", " metric_name: metric_value\n", " for metric_name, metric_value in results[model_name][\"overall\"].items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", "}" ] }, @@ -1070,7 +1072,7 @@ " slice_name: {\n", " metric_name: metric_value\n", " for metric_name, metric_value in slice_results.items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", " }\n", " for slice_name, slice_results in results[model_name].items()\n", "}" diff --git a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb index 09fa6649f..b480f2dfd 100644 --- a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb +++ b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb @@ -48,7 +48,8 @@ "from cyclops.data.df.feature import TabularFeatures\n", "from 
cyclops.data.slicer import SliceSpec\n", "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", + "from cyclops.evaluate.metrics import create_metric\n", + "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n", "from cyclops.models.catalog import create_model\n", "from cyclops.report import ModelCardReport\n", "from cyclops.report.plot.classification import ClassificationPlotter\n", @@ -767,7 +768,7 @@ "\n", "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n", "\n", - "The standard performance metrics can be created using the `MetricCollection` object." + "The standard performance metrics can be created using the `MetricDict` object." ] }, { @@ -777,17 +778,19 @@ "outputs": [], "source": [ "metric_names = [\n", - " \"accuracy\",\n", - " \"precision\",\n", - " \"recall\",\n", - " \"f1_score\",\n", - " \"auroc\",\n", - " \"average_precision\",\n", - " \"roc_curve\",\n", - " \"precision_recall_curve\",\n", + " \"binary_accuracy\",\n", + " \"binary_precision\",\n", + " \"binary_recall\",\n", + " \"binary_f1_score\",\n", + " \"binary_auroc\",\n", + " \"binary_average_precision\",\n", + " \"binary_roc_curve\",\n", + " \"binary_precision_recall_curve\",\n", "]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "metric_collection = MetricCollection(metrics)" + "metrics = [\n", + " create_metric(metric_name, experimental=True) for metric_name in metric_names\n", + "]\n", + "metric_collection = MetricDict(metrics)" ] }, { @@ -830,7 +833,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A `MetricCollection` can also be defined for the fairness metrics." + "A `MetricDict` can also be defined for the fairness metrics." 
] }, { @@ -839,18 +842,12 @@ "metadata": {}, "outputs": [], "source": [ - "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"binary\",\n", - ")\n", - "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"binary\",\n", - ")\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", + "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n", + "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n", + "fpr = -specificity + 1 # __rsub__ is not implemented for metrics\n", + "fnr = -sensitivity + 1\n", "ber = (fpr + fnr) / 2\n", - "fairness_metric_collection = MetricCollection(\n", + "fairness_metric_collection = MetricDict(\n", " {\n", " \"Sensitivity\": sensitivity,\n", " \"Specificity\": specificity,\n", @@ -929,7 +926,7 @@ "model_name = f\"model_for_preds.{model_name}\"\n", "results_flat = flatten_results_dict(\n", " results=results,\n", - " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")" ] @@ -954,7 +951,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -987,7 +984,7 @@ "source": [ "# extracting the ROC curves and AUROC results for all the slices\n", "roc_curves = {\n", - " slice_name: slice_results[\"BinaryROCCurve\"]\n", + " slice_name: slice_results[\"BinaryROC\"]\n", " for slice_name, slice_results in results[model_name].items()\n", "}\n", "aurocs = {\n", @@ -1060,7 +1057,7 @@ "overall_performance = {\n", " metric_name: metric_value\n", " for metric_name, metric_value in results[model_name][\"overall\"].items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", "}" ] }, @@ -1094,7 +1091,7 @@ " slice_name: {\n", " metric_name: metric_value\n", " for metric_name, metric_value in slice_results.items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", " }\n", " for slice_name, slice_results in results[model_name].items()\n", "}" diff --git a/docs/source/tutorials/nihcxr/cxr_classification.ipynb b/docs/source/tutorials/nihcxr/cxr_classification.ipynb index 32974c7e9..c4e34479f 100644 --- a/docs/source/tutorials/nihcxr/cxr_classification.ipynb +++ b/docs/source/tutorials/nihcxr/cxr_classification.ipynb @@ -29,7 +29,6 @@ "\n", "import shutil\n", "from functools import partial\n", - "from typing import Optional\n", "\n", "import numpy as np\n", "import plotly.express as px\n", @@ -45,7 +44,6 @@ "from cyclops.data.utils import apply_transforms\n", "from cyclops.evaluate import evaluator\n", "from cyclops.evaluate.metrics.factory import create_metric\n", - "from cyclops.evaluate.metrics.stat_scores import MultilabelStatScores\n", "from cyclops.models.wrappers import PTModel\n", "from cyclops.report import ModelCardReport" ] @@ -217,77 +215,35 @@ "]\n", "\n", "\n", - "class MultilabelPositivePredictiveValue(\n", - " MultilabelStatScores,\n", - " registry_key=\"positive_predictive_value\",\n", - "):\n", - " \"\"\"Compute the recall score for multilabel classification tasks.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " 
num_labels: int,\n", - " threshold: float = 0.5,\n", - " top_k: Optional[int] = None,\n", - " ) -> None:\n", - " \"\"\"Initialize the metric.\"\"\"\n", - " super().__init__(\n", - " num_labels=num_labels,\n", - " threshold=threshold,\n", - " top_k=top_k,\n", - " labelwise=True,\n", - " )\n", - "\n", - " def compute(self): # type: ignore[override]\n", - " \"\"\"Compute the recall score from the state.\"\"\"\n", - " tp, fp, tn, fn = self._final_state()\n", - " return tp / (tp + fp)\n", - "\n", - "\n", - "class MultilabelNegativePredictiveValue(\n", - " MultilabelStatScores,\n", - " registry_key=\"negative_predictive_value\",\n", - "):\n", - " \"\"\"Compute the recall score for multilabel classification tasks.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " num_labels: int,\n", - " threshold: float = 0.5,\n", - " top_k: Optional[int] = None,\n", - " ) -> None:\n", - " \"\"\"Initialize the metric.\"\"\"\n", - " super().__init__(\n", - " num_labels=num_labels,\n", - " threshold=threshold,\n", - " top_k=top_k,\n", - " labelwise=True,\n", - " )\n", - "\n", - " def compute(self): # type: ignore[override]\n", - " \"\"\"Compute the recall score from the state.\"\"\"\n", - " tp, fp, tn, fn = self._final_state()\n", - " return tn / (tn + fn)\n", - "\n", - "\n", - "ppv = MultilabelPositivePredictiveValue(\n", - " num_labels=len(pathologies),\n", + "num_labels = len(pathologies)\n", + "ppv = create_metric(\n", + " metric_name=\"multilabel_ppv\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", "\n", - "npv = MultilabelNegativePredictiveValue(\n", - " num_labels=len(pathologies),\n", + "npv = create_metric(\n", + " metric_name=\"multilabel_npv\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", "\n", "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"multilabel\",\n", - " num_labels=len(pathologies),\n", + " metric_name=\"multilabel_specificity\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", "\n", "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"multilabel\",\n", - " num_labels=len(pathologies),\n", + " metric_name=\"multilabel_sensitivity\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", + "\n", "# create the slice functions\n", "slice_spec = SliceSpec(spec_list=slices)\n", "\n", @@ -479,15 +435,15 @@ "for name, metric in results_flat.items():\n", " split, name = name.split(\"/\") # noqa: PLW2901\n", " descriptions = {\n", - " \"MultilabelPositivePredictiveValue\": \"The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.\",\n", - " \"MultilabelNegativePredictiveValue\": \"The proportion of correctly predicted negative instances among all instances predicted as negative.\",\n", + " \"MultilabelPPV\": \"The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.\",\n", + " \"MultilabelNPV\": \"The proportion of correctly predicted negative instances among all instances predicted as negative.\",\n", " \"MultilabelSensitivity\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", " \"MultilabelSpecificity\": \"The proportion of actual negative instances that are correctly predicted.\",\n", " }\n", " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist() if isinstance(metric, np.generic) else metric,\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", diff --git a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py index 516b276c3..584bf8be8 100644 --- a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py +++ b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py @@ -3,10 +3,9 @@ # get args from command line import argparse from functools import partial -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import numpy as np -import numpy.typing as npt import plotly.express as px from torchvision.transforms import Compose from torchxrayvision.models import DenseNet @@ -20,7 +19,6 @@ from cyclops.data.utils import apply_transforms from cyclops.evaluate import evaluator from cyclops.evaluate.metrics.factory import create_metric -from cyclops.evaluate.metrics.stat_scores import MultilabelStatScores from cyclops.models.wrappers import PTModel # type: ignore[attr-defined] from cyclops.report import ModelCardReport # type: ignore[attr-defined] @@ -92,80 +90,40 @@ {"Patient Gender": {"value": "F"}}, ] +num_labels = len(pathologies) +ppv = create_metric( + metric_name="multilabel_ppv", + experimental=True, + num_labels=num_labels, + average=None, +) -class MultilabelPositivePredictiveValue( - MultilabelStatScores, - registry_key="positive_predictive_value", -): - """Compute the recall score for multilabel classification tasks.""" - - def __init__( - self, - num_labels: int, - threshold: float = 0.5, - top_k: Optional[int] = None, - ) -> None: - """Initialize the metric.""" - super().__init__( - num_labels=num_labels, - threshold=threshold, - top_k=top_k, - labelwise=True, - ) - - def compute(self) -> npt.NDArray[np.int_]: - """Compute the recall score from the state.""" - tp, fp, tn, fn = self._final_state() - return tp / (tp + fp) # type: ignore[return-value] - - -class MultilabelNegativePredictiveValue( - MultilabelStatScores, - registry_key="negative_predictive_value", -): - """Compute the recall score for multilabel classification tasks.""" - - def __init__( - self, - num_labels: int, - threshold: float = 0.5, - top_k: Optional[int] = None, - ) -> None: - """Initialize the metric.""" - super().__init__( - num_labels=num_labels, - threshold=threshold, - top_k=top_k, - labelwise=True, - ) - - def compute(self) -> npt.NDArray[np.int_]: - """Compute the recall score from the state.""" - tp, fp, tn, fn = self._final_state() - return tn / (tn + fn) # type: ignore[return-value] - - -ppv = MultilabelPositivePredictiveValue(num_labels=len(pathologies)) - -npv = MultilabelNegativePredictiveValue(num_labels=len(pathologies)) +npv = create_metric( + metric_name="multilabel_npv", + experimental=True, + num_labels=num_labels, + average=None, +) specificity = create_metric( - metric_name="specificity", - task="multilabel", - num_labels=len(pathologies), + metric_name="multilabel_specificity", + experimental=True, + num_labels=num_labels, + average=None, ) sensitivity = create_metric( - metric_name="sensitivity", - task="multilabel", - num_labels=len(pathologies), + metric_name="multilabel_sensitivity", + experimental=True, + 
num_labels=num_labels, + average=None, ) # create the slice functions slice_spec = SliceSpec(spec_list=slices_sex) nih_eval_results_gender = evaluator.evaluate( dataset=nih_ds, - metrics=[ppv, npv, sensitivity, specificity], + metrics=[ppv, npv, sensitivity, specificity], # type: ignore[list-item] target_columns=pathologies, prediction_columns="predictions.densenet", ignore_columns="image", @@ -208,7 +166,7 @@ def compute(self) -> npt.NDArray[np.int_]: nih_eval_results_age = evaluator.evaluate( dataset=nih_ds, - metrics=[ppv, npv, sensitivity, specificity], + metrics=[ppv, npv, sensitivity, specificity], # type: ignore[list-item] target_columns=pathologies, prediction_columns="predictions.densenet", ignore_columns="image", @@ -286,15 +244,15 @@ def compute(self) -> npt.NDArray[np.int_]: for name, metric in results_flat.items(): split, name = name.split("/") # noqa: PLW2901 descriptions = { - "MultilabelPositivePredictiveValue": "The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.", - "MultilabelNegativePredictiveValue": "The proportion of correctly predicted negative instances among all instances predicted as negative.", + "MultilabelPPV": "The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.", + "MultilabelNPV": "The proportion of correctly predicted negative instances among all instances predicted as negative.", "MultilabelSensitivity": "The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.", "MultilabelSpecificity": "The proportion of actual negative instances that are correctly predicted.", } report.log_quantitative_analysis( "performance", name=name, - value=metric, + value=metric.tolist() if isinstance(metric, np.generic) else metric, description=descriptions[name], metric_slice=split, pass_fail_thresholds=0.7, diff --git a/docs/source/tutorials/synthea/los_prediction.ipynb b/docs/source/tutorials/synthea/los_prediction.ipynb index 1c0bf4166..9f32ddaf7 100644 --- a/docs/source/tutorials/synthea/los_prediction.ipynb +++ b/docs/source/tutorials/synthea/los_prediction.ipynb @@ -55,7 +55,8 @@ "from cyclops.data.df.feature import TabularFeatures\n", "from cyclops.data.slicer import SliceSpec\n", "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", + "from cyclops.evaluate.metrics import create_metric\n", + "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n", "from cyclops.models.catalog import create_model\n", "from cyclops.report import ModelCardReport\n", "from cyclops.report.plot.classification import ClassificationPlotter\n", @@ -908,7 +909,7 @@ "\n", "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n", "\n", - "The standard performance metrics can be created using the `MetricCollection` object." + "The standard performance metrics can be created using the `MetricDict` object." 
] }, { @@ -921,17 +922,19 @@ "outputs": [], "source": [ "metric_names = [\n", - " \"accuracy\",\n", - " \"precision\",\n", - " \"recall\",\n", - " \"f1_score\",\n", - " \"auroc\",\n", - " \"roc_curve\",\n", - " \"precision_recall_curve\",\n", - " \"stat_scores\",\n", + " \"binary_accuracy\",\n", + " \"binary_precision\",\n", + " \"binary_recall\",\n", + " \"binary_f1_score\",\n", + " \"binary_auroc\",\n", + " \"binary_roc_curve\",\n", + " \"binary_precision_recall_curve\",\n", + " \"binary_confusion_matrix\",\n", "]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "metric_collection = MetricCollection(metrics)" + "metrics = [\n", + " create_metric(metric_name, experimental=True) for metric_name in metric_names\n", + "]\n", + "metric_collection = MetricDict(metrics)" ] }, { @@ -979,7 +982,7 @@ "id": "67bd7806-c480-4c47-8e33-6612c2ede93e", "metadata": {}, "source": [ - "A `MetricCollection` can also be defined for the fairness metrics." + "A `MetricDict` can also be defined for the fairness metrics." ] }, { @@ -991,18 +994,14 @@ }, "outputs": [], "source": [ - "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"binary\",\n", - ")\n", - "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"binary\",\n", - ")\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", + "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n", + "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n", + "fpr = (\n", + " -specificity + 1\n", + ") # rsub is not supported due to limitations in the array API standard\n", + "fnr = -sensitivity + 1\n", "ber = (fpr + fnr) / 2\n", - "fairness_metric_collection = MetricCollection(\n", + "fairness_metric_collection = MetricDict(\n", " {\n", " \"Sensitivity\": sensitivity,\n", " \"Specificity\": specificity,\n", @@ -1095,7 +1094,7 @@ "model_name = f\"model_for_preds.{model_name}\"\n", "results_flat = flatten_results_dict(\n", " results=results,\n", - " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")" ] @@ -1111,7 +1110,7 @@ "source": [ "for name, metric in results_flat.items():\n", " split, name = name.split(\"/\") # noqa: PLW2901\n", - " if name == \"BinaryStatScores\":\n", + " if name == \"BinaryConfusionMatrix\":\n", " continue\n", " descriptions = {\n", " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", @@ -1123,7 +1122,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -1163,7 +1162,7 @@ "source": [ "# extracting the ROC curves and AUROC results for all the slices\n", "roc_curves = {\n", - " slice_name: slice_results[\"BinaryROCCurve\"]\n", + " slice_name: slice_results[\"BinaryROC\"]\n", " for slice_name, slice_results in results[model_name].items()\n", "}\n", "aurocs = {\n", @@ -1181,8 +1180,7 @@ "outputs": [], "source": [ "# Plot confusion matrix\n", - "tp, fp, tn, fn, _ = results[model_name][\"overall\"][\"BinaryStatScores\"]\n", - "confusion_matrix = np.array([[tn, fp], [fn, tp]])\n", + "confusion_matrix = results[model_name][\"overall\"][\"BinaryConfusionMatrix\"]\n", "conf_plot = plotter.plot_confusion_matrix(\n", " 
confusion_matrix,\n", ")\n", @@ -1225,7 +1223,7 @@ " metric_name: metric_value\n", " for metric_name, metric_value in results[model_name][\"overall\"].items()\n", " if metric_name\n", - " not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\", \"BinaryStatScores\"]\n", + " not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\", \"BinaryConfusionMatrix\"]\n", "}" ] }, @@ -1262,7 +1260,7 @@ " metric_name: metric_value\n", " for metric_name, metric_value in slice_results.items()\n", " if metric_name\n", - " not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\", \"BinaryStatScores\"]\n", + " not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\", \"BinaryConfusionMatrix\"]\n", " }\n", " for slice_name, slice_results in results[model_name].items()\n", "}" diff --git a/poetry.lock b/poetry.lock index a114b6e21..7d7b1f444 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2422,7 +2422,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -3170,16 +3169,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3692,13 +3681,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.13.0" +version = "7.14.2" description = "Converting Jupyter Notebooks" optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.13.0-py3-none-any.whl", hash = "sha256:22521cfcc10ba5755e44acb6a70d2bd8a891ce7aed6746481e10cd548b169e19"}, - {file = "nbconvert-7.13.0.tar.gz", hash = "sha256:c6f61c86fca5b28bd17f4f9a308248e59fa2b54919e1589f6cc3575c5dfec2bd"}, + {file = "nbconvert-7.14.2-py3-none-any.whl", hash = "sha256:db28590cef90f7faf2ebbc71acd402cbecf13d29176df728c0a9025a49345ea1"}, + {file = "nbconvert-7.14.2.tar.gz", hash = "sha256:a7f8808fd4e082431673ac538400218dd45efd076fbeb07cc6e5aa5a3a4e949e"}, ] [package.dependencies] @@ -3780,13 +3769,13 @@ toolchain = ["black", "blacken-docs", "flake8", "isort", "jupytext", "mypy", "py [[package]] name = "nbsphinx" -version = "0.8.12" +version = "0.9.3" description = "Jupyter Notebook Tools for Sphinx" optional = false python-versions = ">=3.6" files = [ - {file = "nbsphinx-0.8.12-py3-none-any.whl", hash = "sha256:c15b681c7fce287000856f91fe1edac50d29f7b0c15bbc746fbe55c8eb84750b"}, - {file = "nbsphinx-0.8.12.tar.gz", hash = "sha256:76570416cdecbeb21dbf5c3d6aa204ced6c1dd7ebef4077b5c21b8c6ece9533f"}, + {file = "nbsphinx-0.9.3-py3-none-any.whl", hash = "sha256:6e805e9627f4a358bd5720d5cbf8bf48853989c79af557afd91a5f22e163029f"}, + {file = "nbsphinx-0.9.3.tar.gz", hash = "sha256:ec339c8691b688f8676104a367a4b8cf3ea01fd089dc28d24dec22d563b11562"}, ] [package.dependencies] @@ -4515,8 +4504,6 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -5021,7 +5008,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -5029,15 +5015,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -5054,7 +5033,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, 
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -5062,7 +5040,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, diff --git a/pyproject.toml b/pyproject.toml index bc2801268..e893cbd6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ alibi = { version = "^0.9.4", optional = true, extras = ["shap"] } alibi-detect = { version = "^0.11.0", optional = true, extras = ["torch"] } llvmlite = { version = "^0.40.0", optional = true } sphinx-book-theme = "^1.1.0" +nbsphinx = "^0.9.3" [tool.poetry.group.xgboost] optional = true @@ -117,7 +118,7 @@ sphinx-autodoc-typehints = "^1.24.0" myst-parser = "^2.0.0" sphinx-copybutton = "^0.5.0" sphinx-autoapi = "^2.0.0" -nbsphinx = "^0.8.11" +nbsphinx = "^0.9.3" ipython = "^8.8.0" ipykernel = "^6.23.0" kaggle = "^1.5.13" diff --git a/tests/cyclops/evaluate/metrics/experimental/test_average_precision.py b/tests/cyclops/evaluate/metrics/experimental/test_average_precision.py new file mode 100644 index 000000000..5d3d7704e --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_average_precision.py @@ -0,0 +1,503 @@ +"""Test average precision metric.""" +from functools import partial + +import array_api_compat as apc +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification import ( + binary_average_precision as tm_binary_average_precision, +) +from torchmetrics.functional.classification import ( + multiclass_average_precision as tm_multiclass_average_precision, +) +from torchmetrics.functional.classification import ( + multilabel_average_precision as tm_multilabel_average_precision, +) + +from cyclops.evaluate.metrics.experimental.average_precision import ( + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) +from cyclops.evaluate.metrics.experimental.functional.average_precision import ( + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds +from .testers import MetricTester, _inject_ignore_index + + +def 
_binary_average_precision_reference( + target, + preds, + thresholds, + ignore_index, +) -> torch.Tensor: + """Return the reference binary average precision.""" + return tm_binary_average_precision( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestBinaryAveragePrecision(MetricTester): + """Test binary average precision function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_average_precision_function_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test function for binary average precision using array_api arrays.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_average_precision, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_average_precision_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_average_precision_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for binary average precision using array_api arrays.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAveragePrecision, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_average_precision_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_average_precision_with_torch_tensors( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test binary 
average precision class with torch tensors.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAveragePrecision, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_average_precision_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_average_precision_reference( + target, + preds, + num_classes=NUM_CLASSES, + thresholds=None, + average="macro", + ignore_index=None, +) -> torch.Tensor: + """Return the reference multiclass average precision.""" + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_average_precision( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes, + average=average, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMulticlassAveragePrecision(MetricTester): + """Test multiclass average precision function and class.""" + + atol = 3e-8 + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multiclass_average_precision_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multiclass average precision.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_average_precision, + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + 
@pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_average_precision_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass average precision.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassAveragePrecision, + reference_metric=partial( + _multiclass_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "average": average, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_average_precision_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass average precision.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassAveragePrecision, + reference_metric=partial( + _multiclass_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) + + +def _multilabel_average_precision_reference( + preds, + target, + num_labels=NUM_LABELS, + thresholds=None, + average="macro", + ignore_index=None, +) -> torch.Tensor: + """Return the reference multilabel average precision.""" + return tm_multilabel_average_precision( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_labels, + average=average, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMultilabelAveragePrecision(MetricTester): + """Test multilabel average precision function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def 
test_multilabel_average_precision_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multilabel average precision.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multilabel_average_precision, + reference_metric=partial( + _multilabel_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "average": average, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_average_precision_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel average precision.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAveragePrecision, + reference_metric=partial( + _multilabel_average_precision_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "average": average, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_average_precision_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel average precision.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAveragePrecision, + reference_metric=partial( + _multilabel_average_precision_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "average": average, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py index 8eb2a8e84..14c3c3a96 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py @@ -328,6 +328,8 @@ def _multiclass_precision_recall_reference( class TestMulticlassPrecision(MetricTester): """Test multiclass precision metric class and function.""" + atol = 6e-8 + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)) 
@pytest.mark.parametrize("top_k", [1, 2]) @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"])