From b6179b2e6aae377530417e4d1407c95fe10ea17e Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:24:09 -0500 Subject: [PATCH] Add specificity metrics to experimental module --- .../evaluate/metrics/experimental/__init__.py | 5 + .../experimental/functional/__init__.py | 5 + .../experimental/functional/specificity.py | 443 ++++++++++++++++ .../metrics/experimental/specificity.py | 186 +++++++ .../metrics/experimental/test_specificity.py | 482 ++++++++++++++++++ 5 files changed, 1121 insertions(+) create mode 100644 cyclops/evaluate/metrics/experimental/functional/specificity.py create mode 100644 cyclops/evaluate/metrics/experimental/specificity.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_specificity.py diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index 2ffa7d6f9..c4b00ba99 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -26,3 +26,8 @@ MultilabelPrecision, MultilabelRecall, ) +from cyclops.evaluate.metrics.experimental.specificity import ( + BinarySpecificity, + MulticlassSpecificity, + MultilabelSpecificity, +) diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 54d6ab902..7a4962ea0 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -25,3 +25,8 @@ multilabel_precision, multilabel_recall, ) +from cyclops.evaluate.metrics.experimental.functional.specificity import ( + binary_specificity, + multiclass_specificity, + multilabel_specificity, +) diff --git a/cyclops/evaluate/metrics/experimental/functional/specificity.py b/cyclops/evaluate/metrics/experimental/functional/specificity.py new file mode 100644 index 000000000..4b5de6074 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/specificity.py @@ -0,0 +1,443 @@ +"""Methods for computing specificity scores for classification tasks.""" +from typing import Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.functional._stat_scores import ( + _binary_stat_scores_format_arrays, + _binary_stat_scores_update_state, + _binary_stat_scores_validate_args, + _binary_stat_scores_validate_arrays, + _multiclass_stat_scores_format_arrays, + _multiclass_stat_scores_update_state, + _multiclass_stat_scores_validate_args, + _multiclass_stat_scores_validate_arrays, + _multilabel_stat_scores_format_arrays, + _multilabel_stat_scores_update_state, + _multilabel_stat_scores_validate_arrays, +) +from cyclops.evaluate.metrics.experimental.utils.ops import ( + _adjust_weight_apply_average, + safe_divide, + squeeze_all, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +def _specificity_compute( + average: Literal["micro", "macro", "weighted", "none"], + is_multilabel: bool, + *, + tp: Array, + fp: Array, + tn: Array, + fn: Array, +) -> Array: + xp = apc.array_namespace(tp, fp, tn) + if average == "micro": + tn = xp.sum(tn, axis=0) + fp = xp.sum(fp, axis=0) + return safe_divide(tn, tn + fp) + + score = safe_divide(tn, tn + fp) + return _adjust_weight_apply_average( + score, + average, + is_multilabel=is_multilabel, + tp=tp, + fp=fp, + fn=fn, + xp=xp, + ) + + +def _binary_specificity_compute(*, fp: Array, tn: Array) -> Array: + return squeeze_all(safe_divide(tn, 
tn + fp)) + + +def binary_specificity( + target: Array, + preds: Array, + threshold: float = 0.5, + ignore_index: Optional[int] = None, +) -> Array: + """Measure how well a binary classifier identifies negative samples. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels. The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the predictions of a binary classifier. The expected shape of the + array is `(N, ...)` where `N` is the number of samples. If `preds` contains + floating point values that are not in the range `[0, 1]`, a sigmoid function + will be applied to each value before thresholding. + ignore_index : int, optional, default=None + Specifies a target class that is ignored when computing the specificity score. + Ignoring a target class means that the corresponding predictions do not + contribute to the specificity score. + + Returns + ------- + Array + An array API compatible object containing the specificity score. + + Raises + ------ + ValueError + If the arrays `target` and `preds` are not compatible with the Python + array API standard. + ValueError + If `target` or `preds` are empty. + ValueError + If `target` or `preds` are not numeric arrays. + ValueError + If `target` and `preds` have different shapes. + RuntimeError + If `target` contains values that are not in {0, 1}. + RuntimeError + If `preds` contains integer values that are not in {0, 1}. + ValueError + If `threshold` is not a float in the range [0, 1]. + ValueError + If `ignore_index` is not `None` or an integer. + + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import binary_specificity + >>> import numpy.array_api as anp + >>> target = anp.asarray([1, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([1, 0, 1, 1, 0, 1]) + >>> binary_specificity(target, preds) + Array(0.5, dtype=float32) + >>> binary_specificity(target, preds, ignore_index=0) + Array(0., dtype=float32) + >>> target = anp.asarray([1, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([0.61, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_specificity(target, preds) + Array(0.5, dtype=float32) + >>> binary_specificity(target, preds, threshold=0.8) + Array(0.5, dtype=float32) + + """ + _binary_stat_scores_validate_args( + threshold=threshold, + ignore_index=ignore_index, + ) + xp = _binary_stat_scores_validate_arrays( + target, + preds, + ignore_index=ignore_index, + ) + target, preds = _binary_stat_scores_format_arrays( + target, + preds, + threshold=threshold, + ignore_index=ignore_index, + xp=xp, + ) + tn, fp, _, _ = _binary_stat_scores_update_state(target, preds, xp=xp) + return _binary_specificity_compute(fp=fp, tn=tn) + + +def multiclass_specificity( + target: Array, + preds: Array, + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Array: + """Measure how well a classifier identifies negative samples. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels. The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the predictions of a classifier. 
If `preds` contains integer values + the expected shape of the array is `(N, ...)`, where `N` is the number of + samples. If `preds` contains floating point values the expected shape of the + array is `(N, C, ...)` where `N` is the number of samples and `C` is the + number of classes. + num_classes : int + The number of classes in the classification task. + top_k : int, default=1 + The number of highest probability or logit score predictions to consider + when computing the specificity score. By default, only the top prediction is + considered. This parameter is ignored if `preds` contains integer values. + average : {'micro', 'macro', 'weighted', 'none'}, optional, default='micro' + Specifies the type of averaging to apply to the specificity scores. Should + be one of the following: + - `'micro'`: Compute the specificity score globally by considering all + predictions and all targets. + - `'macro'`: Compute the specificity score for each class individually and + then take the unweighted mean of the specificity scores. + - `'weighted'`: Compute the specificity score for each class individually + and then take the mean of the specificity scores weighted by the support + (the number of true positives + the number of false negatives) for each + class. + - `'none'` or `None`: Compute the specificity score for each class individually + and return the scores as an array. + ignore_index : int or tuple of int, optional, default=None + Specifies a target class that is ignored when computing the specificity score. + Ignoring a target class means that the corresponding predictions do not + contribute to the specificity score. + + + Returns + ------- + Array + An array API compatible object containing the specificity score(s). + + Raises + ------ + ValueError + If the arrays `target` and `preds` are not compatible with the Python + array API standard. + ValueError + If `target` or `preds` are empty. + ValueError + If `target` or `preds` are not numeric arrays. + ValueError + If `preds` has one more dimension than `target` but `preds` does not + contain floating point values. + ValueError + If `preds` has one more dimension than `target` and the second dimension + (first dimension, if `preds` is a scalar) of `preds` is not equal to + `num_classes`. In the multidimensional case (i.e., `preds` has more than + two dimensions), the rest of the dimensions must be the same for `target` + and `preds`. + ValueError + If `preds` and `target` have the same number of dimensions but not the + same shape. + RuntimeError + If `target` or `preds` contain values that are not in + {0, 1, ..., num_classes-1} or `target` contains more values than specified + in `ignore_index`. + ValueError + If `num_classes` is not a positive integer greater than two. + ValueError + If `top_k` is not a positive integer. + ValueError + If `top_k` is greater than the number of classes. + ValueError + If `average` is not one of {`'micro'`, `'macro'`, `'weighted'`, `'none'`, + `None`}. + ValueError + If `ignore_index` is not `None`, an integer, or a tuple of integers. + + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multiclass_specificity + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([2, 1, 0, 0]) + >>> preds = anp.asarray([2, 1, 0, 1]) + >>> multiclass_specificity(target, preds, num_classes=3) + Array(0.875, dtype=float32) + >>> target = anp.asarray([2, 1, 0, 0]) + >>> preds = anp.asarray( + ... 
[[0.1, 0.1, 0.8], [0.2, 0.7, 0.1], [0.9, 0.1, 0.0], [0.4, 0.6, 0.0]],
+    ... )
+    >>> multiclass_specificity(target, preds, num_classes=3)
+    Array(0.875, dtype=float32)
+    >>> multiclass_specificity(target, preds, num_classes=3, top_k=2)
+    Array(0.5, dtype=float32)
+    >>> multiclass_specificity(target, preds, num_classes=3, average=None)
+    Array([1. , 0.6666667, 1. ], dtype=float32)
+    >>> multiclass_specificity(target, preds, num_classes=3, average="macro")
+    Array(0.88888896, dtype=float32)
+    >>> multiclass_specificity(target, preds, num_classes=3, average="weighted")
+    Array(0.9166667, dtype=float32)
+    >>> multiclass_specificity(target, preds, num_classes=3, ignore_index=0)
+    Array(1., dtype=float32)
+    >>> multiclass_specificity(
+    ...     target, preds, num_classes=3, average=None, ignore_index=(1, 2),
+    ... )
+    Array([0. , 0.5, 1. ], dtype=float32)
+
+    """
+    _multiclass_stat_scores_validate_args(
+        num_classes,
+        top_k=top_k,
+        average=average,
+        ignore_index=ignore_index,
+    )
+    xp = _multiclass_stat_scores_validate_arrays(
+        target,
+        preds,
+        num_classes,
+        top_k=top_k,
+        ignore_index=ignore_index,
+    )
+
+    target, preds = _multiclass_stat_scores_format_arrays(
+        target,
+        preds,
+        top_k=top_k,
+        xp=xp,
+    )
+    tn, fp, fn, tp = _multiclass_stat_scores_update_state(
+        target,
+        preds,
+        num_classes,
+        top_k=top_k,
+        average=average,
+        ignore_index=ignore_index,
+        xp=xp,
+    )
+    return _specificity_compute(
+        average,  # type: ignore[arg-type]
+        is_multilabel=False,
+        tp=tp,
+        fp=fp,
+        tn=tn,
+        fn=fn,
+    )
+
+
+def multilabel_specificity(
+    target: Array,
+    preds: Array,
+    num_labels: int,
+    threshold: float = 0.5,
+    top_k: int = 1,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    ignore_index: Optional[int] = None,
+) -> Array:
+    """Measure how well a classifier identifies negative samples.
+
+    Parameters
+    ----------
+    target : Array
+        An array object that is compatible with the Python array API standard
+        and contains the ground truth labels. The expected shape of the array
+        is `(N, L, ...)`, where `N` is the number of samples and `L` is the
+        number of labels.
+    preds : Array
+        An array object that is compatible with the Python array API standard and
+        contains the predictions of a classifier. The expected shape of the array
+        is `(N, L, ...)`, where `N` is the number of samples and `L` is the
+        number of labels. If `preds` contains floating point values that are not
+        in the range `[0, 1]`, a sigmoid function will be applied to each value
+        before thresholding.
+    num_labels : int
+        The number of labels in the classification task.
+    threshold : float, optional, default=0.5
+        The threshold used to convert probabilities to binary values.
+    top_k : int, optional, default=1
+        The number of highest probability predictions to assign the value `1`
+        (all other predictions are assigned the value `0`). By default, only the
+        highest probability prediction is considered. This parameter is ignored
+        if `preds` does not contain floating point values.
+    average : {'micro', 'macro', 'weighted', 'none'}, optional, default='macro'
+        Specifies the type of averaging to apply to the specificity scores. Should
+        be one of the following:
+        - `'micro'`: Compute the specificity score globally by considering all
+          predictions and all targets.
+        - `'macro'`: Compute the specificity score for each label individually and then
+          take the unweighted mean of the specificity scores.
+ - `'weighted'`: Compute the specificity score for each label individually + and then take the mean of the specificity scores weighted by the support + (the number of true positives + the number of false negatives) for each + label. + - `'none'` or `None`: Compute the specificity score for each label individually + and return the scores as an array. + ignore_index : int, optional, default=None + Specifies value in `target` that is ignored when computing the specificity + score. + + Raises + ------ + ValueError + If the arrays `target` and `preds` are not compatible with the Python + array API standard. + ValueError + If `target` or `preds` are empty. + ValueError + If `target` or `preds` are not numeric arrays. + ValueError + If `target` and `preds` have different shapes. + ValueError + If the second dimension of `target` and `preds` is not equal to `num_labels`. + RuntimeError + If `target` contains values that are not in {0, 1} or not in `ignore_index`. + RuntimeError + If `preds` contains integer values that are not in {0, 1}. + ValueError + If `num_labels` is not a positive integer greater than two. + ValueError + If `threshold` is not a float in the range [0, 1]. + ValueError + If `top_k` is not a positive integer. + ValueError + If `top_k` is greater than the number of labels. + ValueError + If `average` is not one of {`'micro'`, `'macro'`, `'weighted'`, `'none'`, + `None`}. + ValueError + If `ignore_index` is not `None` or an integer. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multilabel_specificity + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([[0, 1, 0], [1, 0, 1]]) + >>> preds = anp.asarray([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_specificity(target, preds, num_labels=3) + Array(0.6666667, dtype=float32) + >>> target = anp.asarray([[1, 0, 1, 0], [1, 1, 0, 1]]) + >>> preds = anp.asarray([[0.11, 0.58, 0.22, 0.84], [0.73, 0.47, 0.33, 0.92]]) + >>> multilabel_specificity(target, preds, num_labels=4) + Array(0.25, dtype=float32) + >>> multilabel_specificity(target, preds, num_labels=4, top_k=2) + Array(0.25, dtype=float32) + >>> multilabel_specificity(target, preds, num_labels=4, threshold=0.7) + Array(0.5, dtype=float32) + >>> multilabel_specificity(target, preds, num_labels=4, average=None) + Array([0., 0., 1., 0.], dtype=float32) + >>> multilabel_specificity(target, preds, num_labels=4, average="micro") + Array(0.33333334, dtype=float32) + >>> multilabel_specificity(target, preds, num_labels=4, average="weighted") + Array(0.2, dtype=float32) + >>> multilabel_specificity( + ... target, preds, num_labels=4, average=None, ignore_index=1, + ... 
) + Array([0., 0., 1., 0.], dtype=float32) + + """ + xp = _multilabel_stat_scores_validate_arrays( + target, + preds, + num_labels, + ignore_index=ignore_index, + ) + target, preds = _multilabel_stat_scores_format_arrays( + target, + preds, + top_k=top_k, + threshold=threshold, + ignore_index=ignore_index, + xp=xp, + ) + tn, fp, fn, tp = _multilabel_stat_scores_update_state(target, preds, xp=xp) + return _specificity_compute( + average, # type: ignore[arg-type] + is_multilabel=True, + tp=tp, + fp=fp, + tn=tn, + fn=fn, + ) diff --git a/cyclops/evaluate/metrics/experimental/specificity.py b/cyclops/evaluate/metrics/experimental/specificity.py new file mode 100644 index 000000000..e4b433094 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/specificity.py @@ -0,0 +1,186 @@ +"""Classes for computing specificity scores for classification tasks.""" +from cyclops.evaluate.metrics.experimental._stat_scores import ( + _AbstractBinaryStatScores, + _AbstractMulticlassStatScores, + _AbstractMultilabelStatScores, +) +from cyclops.evaluate.metrics.experimental.functional.specificity import ( + _binary_specificity_compute, + _specificity_compute, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class BinarySpecificity(_AbstractBinaryStatScores, registry_key="binary_specificity"): + """The proportion of actual negatives that are correctly identified. + + Parameters + ---------- + threshold : float, default=0.5 + Threshold for converting probabilities into binary values. + ignore_index : int, optional + Values in the target array to ignore when computing the metric. + **kwargs + Additional keyword arguments common to all metrics. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental import BinarySpecificity + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 0, 1]) + >>> preds = anp.asarray([0, 1, 1, 1]) + >>> metric = BinarySpecificity() + >>> metric(target, preds) + Array(0.5, dtype=float32) + >>> metric.reset() + >>> target = [[0, 1, 0, 1], [1, 0, 1, 0]] + >>> preds = [[0, 1, 1, 1], [1, 0, 1, 0]] + >>> for t, p in zip(target, preds): + ... metric.update(anp.asarray(t), anp.asarray(p)) + >>> metric.compute() + Array(0.75, dtype=float32) + + """ + + name: str = "Specificity Score" + + def _compute_metric(self) -> Array: + """Compute the specificity score.""" + tn, fp, _, _ = self._final_state() + return _binary_specificity_compute(fp=fp, tn=tn) + + +class MulticlassSpecificity( + _AbstractMulticlassStatScores, + registry_key="multiclass_specificity", +): + """The proportion of actual negatives that are correctly identified. + + Parameters + ---------- + num_classes : int + The number of classes in the classification task. + top_k : int, default=1 + The number of highest probability or logit score predictions to consider + when computing the specificity score. By default, only the top prediction is + considered. This parameter is ignored if `preds` contains integer values. + average : {'micro', 'macro', 'weighted', 'none'}, optional, default='micro' + Specifies the type of averaging to apply to the specificity scores. Should + be one of the following: + - `'micro'`: Compute the specificity score globally by considering all + predictions and all targets. + - `'macro'`: Compute the specificity score for each class individually and + then take the unweighted mean of the specificity scores. 
+ - `'weighted'`: Compute the specificity score for each class individually + and then take the mean of the specificity scores weighted by the support + (the number of true positives + the number of false negatives) for + each class. + - `'none'` or `None`: Compute the specificity score for each class individually + and return the scores as an array. + ignore_index : int or tuple of int, optional, default=None + Specifies a target class that is ignored when computing the specificity score. + Ignoring a target class means that the corresponding predictions do not + contribute to the specificity score. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental import MulticlassSpecificity + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 2, 2, 2]) + >>> preds = anp.asarray([0, 0, 2, 2, 1]) + >>> metric = MulticlassSpecificity(num_classes=3) + >>> metric(target, preds) + Array(0.8, dtype=float32) + >>> metric.reset() + >>> target = [[0, 1, 2], [2, 1, 0]] + >>> preds = [[[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.2, 0.6, 0.2]], + ... [[0.1, 0.8, 0.1], [0.05, 0.95, 0], [0.2, 0.6, 0.2]]] + >>> for t, p in zip(target, preds): + ... metric.update(anp.asarray(t), anp.asarray(p)) + >>> metric.compute() + Array(0.6666667, dtype=float32) + + """ + + name: str = "Specificity Score" + + def _compute_metric(self) -> Array: + """Compute the specificity score(s).""" + tn, fp, fn, tp = self._final_state() + return _specificity_compute( + self.average, # type: ignore[arg-type] + is_multilabel=False, + tp=tp, + fp=fp, + tn=tn, + fn=fn, + ) + + +class MultilabelSpecificity( + _AbstractMultilabelStatScores, + registry_key="multilabel_specificity", +): + """The proportion of actual negatives that are correctly identified. + + Parameters + ---------- + num_labels : int + The number of labels in the classification task. + threshold : float, optional, default=0.5 + The threshold used to convert probabilities to binary values. + top_k : int, optional, default=1 + The number of highest probability predictions to assign the value `1` + (all other predictions are assigned the value `0`). By default, only the + highest probability prediction is considered. This parameter is ignored + if `preds` does not contain floating point values. + average : {'micro', 'macro', 'weighted', 'none'}, optional, default='macro' + Specifies the type of averaging to apply to the specificity scores. Should + be one of the following: + - `'micro'`: Compute the specificity score globally by considering all + predictions and all targets. + - `'macro'`: Compute the specificity score for each label individually and + then take the unweighted mean of the specificity scores. + - `'weighted'`: Compute the specificity score for each label individually + and then take the mean of the specificity scores weighted by the support + (the number of true positives + the number of false negatives) for each + label. + - `'none'` or `None`: Compute the specificity score for each label individually + and return the scores as an array. + ignore_index : int, optional, default=None + Specifies a value in the target array(s) that is ignored when computing + the specificity score. 
+ + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental import MultilabelSpecificity + >>> import numpy.array_api as anp + >>> target = anp.asarray([[0, 1, 1], [1, 0, 0]]) + >>> preds = anp.asarray([[0, 1, 0], [1, 0, 1]]) + >>> metric = MultilabelSpecificity(num_labels=3) + >>> metric(target, preds) + Array(0.6666667, dtype=float32) + >>> metric.reset() + >>> target = [[[0, 1, 1], [1, 0, 0]], [[1, 0, 0], [0, 1, 1]]] + >>> preds = [[[0.05, 0.95, 0], [0.1, 0.8, 0.1]], + ... [[0.1, 0.8, 0.1], [0.05, 0.95, 0]]] + >>> for t, p in zip(target, preds): + ... metric.update(anp.asarray(t), anp.asarray(p)) + >>> metric.compute() + Array(0.6666667, dtype=float32) + + """ + + name: str = "Specificity Score" + + def _compute_metric(self) -> Array: + """Compute the specificity score(s).""" + tn, fp, fn, tp = self._final_state() + return _specificity_compute( + self.average, # type: ignore[arg-type] + is_multilabel=True, + tp=tp, + fp=fp, + tn=tn, + fn=fn, + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_specificity.py b/tests/cyclops/evaluate/metrics/experimental/test_specificity.py new file mode 100644 index 000000000..b905a15d2 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_specificity.py @@ -0,0 +1,482 @@ +"""Test specificity recall metrics.""" +from functools import partial +from typing import Literal, Optional + +import array_api_compat as apc +import array_api_compat.torch +import numpy as np +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification.specificity import ( + binary_specificity as tm_binary_specificity, +) +from torchmetrics.functional.classification.specificity import ( + multiclass_specificity as tm_multiclass_specificity, +) +from torchmetrics.functional.classification.specificity import ( + multilabel_specificity as tm_multilabel_specificity, +) + +from cyclops.evaluate.metrics.experimental.functional.specificity import ( + binary_specificity, + multiclass_specificity, + multilabel_specificity, +) +from cyclops.evaluate.metrics.experimental.specificity import ( + BinarySpecificity, + MulticlassSpecificity, + MultilabelSpecificity, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS, THRESHOLD +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .testers import MetricTester, _inject_ignore_index + + +def _binary_specificity_reference( + target, + preds, + threshold, + ignore_index, +) -> torch.Tensor: + """Compute binary specificity using torchmetrics.""" + return tm_binary_specificity( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + threshold=threshold, + ignore_index=ignore_index, + ) + + +class TestBinarySpecificity(MetricTester): + """Test binary specificity metric class and function.""" + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_specificity_function_with_numpy_array_api_arrays( + self, + inputs, + ignore_index, + ) -> None: + """Test function for binary specificity using `numpy.array_api` arrays.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_specificity, + metric_args={"threshold": THRESHOLD, 
"ignore_index": ignore_index}, + reference_metric=partial( + _binary_specificity_reference, + threshold=THRESHOLD, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_specificity_class_with_numpy_array_api_arrays( + self, + inputs, + ignore_index, + ) -> None: + """Test class for binary specificity using `numpy.array_api` arrays.""" + target, preds = inputs + + if ( + preds.ndim == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinarySpecificity, + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index}, + reference_metric=partial( + _binary_specificity_reference, + threshold=THRESHOLD, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_specificity_class_with_torch_tensors( + self, + inputs, + ignore_index, + ) -> None: + """Test binary specificity class with torch tensors.""" + target, preds = inputs + + if ( + preds.ndim == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinarySpecificity, + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index}, + reference_metric=partial( + _binary_specificity_reference, + threshold=THRESHOLD, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_specificity_reference( + target, + preds, + num_classes=NUM_CLASSES, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted"]] = "micro", + ignore_index=None, +) -> torch.Tensor: + """Compute multiclass specificity using torchmetrics.""" + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_specificity( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes=num_classes, + top_k=top_k, + average=average, + ignore_index=ignore_index, + ) + + +class TestMulticlassSpecificity(MetricTester): + """Test multiclass specificity metric class and function.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)) + @pytest.mark.parametrize("top_k", [1, 2]) + @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_specificity_function_with_numpy_array_api_arrays( + self, + inputs, + top_k, + average, + ignore_index, + 
) -> None: + """Test function for multiclass specificity using `numpy.array_api` arrays.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + if top_k > 1 and not is_floating_point(preds): + with pytest.raises(ValueError): + multiclass_specificity( + target, + preds, + num_classes=NUM_CLASSES, + top_k=top_k, + average=average, + ignore_index=ignore_index, + ) + else: + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_specificity, + metric_args={ + "num_classes": NUM_CLASSES, + "top_k": top_k, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_specificity_reference, + num_classes=NUM_CLASSES, + top_k=top_k, + average=average, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)) + @pytest.mark.parametrize("top_k", [1, 2]) + @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_specificity_class_with_numpy_array_api_arrays( + self, + inputs, + top_k, + average, + ignore_index, + ) -> None: + """Test class for multiclass specificity using `numpy.array_api` arrays.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + if top_k > 1 and not is_floating_point(preds): + with pytest.raises(ValueError): + metric = MulticlassSpecificity( + num_classes=NUM_CLASSES, + top_k=top_k, + average=average, + ignore_index=ignore_index, + ) + metric(target, preds) + else: + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassSpecificity, + metric_args={ + "num_classes": NUM_CLASSES, + "top_k": top_k, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_specificity_reference, + num_classes=NUM_CLASSES, + top_k=top_k, + average=average, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)) + @pytest.mark.parametrize("top_k", [1, 2]) + @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_specificity_class_with_torch_tensors( + self, + inputs, + top_k, + average, + ignore_index, + ) -> None: + """Test multiclass specificity class with torch tensors.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + if top_k > 1 and not is_floating_point(preds): + with pytest.raises(ValueError): + metric = MulticlassSpecificity( + num_classes=NUM_CLASSES, + top_k=top_k, + average=average, + ignore_index=ignore_index, + ) + metric(target, preds) + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassSpecificity, + reference_metric=partial( + _multiclass_specificity_reference, + num_classes=NUM_CLASSES, + top_k=top_k, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "top_k": top_k, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) + + +def _multilabel_specificity_reference( + target, + preds, + threshold, + num_labels=NUM_LABELS, + average: Optional[Literal["micro", 
"macro", "weighted"]] = "macro", + ignore_index=None, +) -> torch.Tensor: + """Compute multilabel specificity using torchmetrics.""" + return tm_multilabel_specificity( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_labels=num_labels, + threshold=threshold, + average=average, + ignore_index=ignore_index, + ) + + +class TestMultilabelSpecificity(MetricTester): + """Test multilabel specificity function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)) + @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_specificity_with_numpy_array_api_arrays( + self, + inputs, + average, + ignore_index, + ) -> None: + """Test function for multilabel specificity with `numpy.array_api` arrays.""" + target, preds = inputs + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multilabel_specificity, + reference_metric=partial( + _multilabel_specificity_reference, + num_labels=NUM_LABELS, + threshold=THRESHOLD, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "threshold": THRESHOLD, + "num_labels": NUM_LABELS, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)) + @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_specificity_class_with_numpy_array_api_arrays( + self, + inputs, + average, + ignore_index, + ) -> None: + """Test class for multilabel specificity with `numpy.array_api` arrays.""" + target, preds = inputs + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelSpecificity, + reference_metric=partial( + _multilabel_specificity_reference, + num_labels=NUM_LABELS, + threshold=THRESHOLD, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "threshold": THRESHOLD, + "num_labels": NUM_LABELS, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)) + @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_specificity_class_with_torch_tensors( + self, + inputs, + average, + ignore_index, + ) -> None: + """Test class for multilabel specificity with torch tensors.""" + target, preds = inputs + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelSpecificity, + reference_metric=partial( + _multilabel_specificity_reference, + num_labels=NUM_LABELS, + threshold=THRESHOLD, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "threshold": THRESHOLD, + "num_labels": NUM_LABELS, + "average": average, + "ignore_index": ignore_index, + }, + ) + + +def test_top_k_multilabel_specificity(): + """Test top-k multilabel specificity.""" + target = anp.asarray([[0, 1, 1, 0], [1, 0, 1, 0]]) + preds = anp.asarray([[0.1, 0.9, 0.8, 0.3], [0.9, 0.1, 0.8, 0.3]]) + expected_result = anp.asarray([1.0, 1.0, 0.0, 1.0], dtype=anp.float32) + + result = multilabel_specificity(target, preds, num_labels=4, average=None, top_k=2) + assert np.allclose(result, expected_result) + + metric = MultilabelSpecificity(num_labels=4, average=None, top_k=2) + metric(target, preds) + 
class_result = metric.compute() + assert np.allclose(class_result, expected_result) + metric.reset() + + preds = anp.asarray( + [ + [[0.57, 0.63], [0.33, 0.55], [0.73, 0.55], [0.36, 0.66]], + [[0.78, 0.94], [0.47, 0.31], [0.14, 0.28], [0.35, 0.81]], + ], + ) + target = anp.asarray( + [[[0, 0], [1, 1], [0, 1], [0, 0]], [[0, 1], [0, 1], [1, 0], [0, 0]]], + ) + expected_result = anp.asarray([0.0, 0.0, 0.5, 0.5], dtype=anp.float32) + + result = multilabel_specificity(target, preds, num_labels=4, average=None, top_k=2) + assert np.allclose(result, expected_result) + + class_result = metric(target, preds) + assert np.allclose(class_result, expected_result)
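
For reviewers who want to exercise the new metrics end to end, the snippet below strings together the functional and class-based interfaces introduced by this patch. It is a minimal usage sketch, not part of the patch itself: the inputs and expected outputs are copied from the doctests above, and the import paths are the ones registered in the two updated `__init__.py` files.

```python
# Minimal usage sketch (not part of the patch). Inputs and expected values are
# taken from the doctests in the new specificity modules.
import numpy.array_api as anp

from cyclops.evaluate.metrics.experimental import BinarySpecificity
from cyclops.evaluate.metrics.experimental.functional import (
    binary_specificity,
    multiclass_specificity,
    multilabel_specificity,
)

# Functional API: one-shot computation on array-API-compatible arrays.
target = anp.asarray([1, 1, 0, 1, 0, 1])
preds = anp.asarray([0.61, 0.22, 0.84, 0.73, 0.33, 0.92])
print(binary_specificity(target, preds))  # Array(0.5, dtype=float32)

target = anp.asarray([2, 1, 0, 0])
preds = anp.asarray(
    [[0.1, 0.1, 0.8], [0.2, 0.7, 0.1], [0.9, 0.1, 0.0], [0.4, 0.6, 0.0]],
)
print(multiclass_specificity(target, preds, num_classes=3, average="macro"))
# Array(0.88888896, dtype=float32)

target = anp.asarray([[0, 1, 0], [1, 0, 1]])
preds = anp.asarray([[0, 0, 1], [1, 0, 1]])
print(multilabel_specificity(target, preds, num_labels=3))
# Array(0.6666667, dtype=float32)

# Class API: accumulate state over batches with update(), then call compute().
metric = BinarySpecificity()
targets = [[0, 1, 0, 1], [1, 0, 1, 0]]
batch_preds = [[0, 1, 1, 1], [1, 0, 1, 0]]
for t, p in zip(targets, batch_preds):
    metric.update(anp.asarray(t), anp.asarray(p))
print(metric.compute())  # Array(0.75, dtype=float32)
```

The class-based metrics accumulate true-negative and false-positive counts across `update()` calls and only perform the `tn / (tn + fp)` division in `compute()`, which is why the two-batch `BinarySpecificity` example matches the pooled doctest result of 0.75.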