Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Matthews Correlation Coefficient (MCC) metric #550

Merged
merged 18 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cyclops/evaluate/metrics/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
)
from cyclops.evaluate.metrics.experimental.mae import MeanAbsoluteError
from cyclops.evaluate.metrics.experimental.mape import MeanAbsolutePercentageError
from cyclops.evaluate.metrics.experimental.matthews_corr_coef import (
BinaryMCC,
MulticlassMCC,
MultilabelMCC,
)
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
from cyclops.evaluate.metrics.experimental.mse import MeanSquaredError
from cyclops.evaluate.metrics.experimental.negative_predictive_value import (
Expand Down
1 change: 1 addition & 0 deletions cyclops/evaluate/metrics/experimental/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Confusion matrix."""

from types import ModuleType
from typing import Any, Optional, Tuple, Union

Expand Down
5 changes: 5 additions & 0 deletions cyclops/evaluate/metrics/experimental/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
from cyclops.evaluate.metrics.experimental.functional.mape import (
mean_absolute_percentage_error,
)
from cyclops.evaluate.metrics.experimental.functional.matthews_corr_coef import (
binary_mcc,
multiclass_mcc,
multilabel_mcc,
)
from cyclops.evaluate.metrics.experimental.functional.mse import mean_squared_error
from cyclops.evaluate.metrics.experimental.functional.negative_predictive_value import (
binary_npv,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions for computing the confusion matrix for classification tasks."""

# mypy: disable-error-code="no-any-return"
from types import ModuleType
from typing import Literal, Optional, Tuple, Union
Expand Down
355 changes: 355 additions & 0 deletions cyclops/evaluate/metrics/experimental/functional/matthews_corr_coef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,355 @@
"""Functional API for the matthews correlation coefficient (MCC) metric."""

from typing import Optional, Tuple, Union

import array_api_compat as apc

from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import (
_binary_confusion_matrix_compute,
_binary_confusion_matrix_format_arrays,
_binary_confusion_matrix_update_state,
_binary_confusion_matrix_validate_args,
_binary_confusion_matrix_validate_arrays,
_multiclass_confusion_matrix_format_arrays,
_multiclass_confusion_matrix_update_state,
_multiclass_confusion_matrix_validate_args,
_multiclass_confusion_matrix_validate_arrays,
_multilabel_confusion_matrix_compute,
_multilabel_confusion_matrix_format_arrays,
_multilabel_confusion_matrix_update_state,
_multilabel_confusion_matrix_validate_args,
_multilabel_confusion_matrix_validate_arrays,
)
from cyclops.evaluate.metrics.experimental.utils.types import Array


def _mcc_reduce(confmat: Array) -> Array:
    """Reduce an un-normalized confusion matrix into the matthews corrcoef.

    Parameters
    ----------
    confmat : Array
        Un-normalized confusion matrix. A 3-D input (multilabel) is summed over
        its first axis before the coefficient is computed. A matrix holding
        exactly four entries is handled as the binary case.

    Returns
    -------
    Array
        The matthews correlation coefficient as a float32 array. In the binary
        case `1.0` is returned when there are no false predictions but some
        true ones, and `-1.0` when there are no true predictions but some
        false ones. When the denominator is zero, the non-binary case returns
        `0.0` and the binary case falls back to an epsilon-regularized ratio.

    """
    xp = apc.array_namespace(confmat)
    # convert multilabel into binary
    if confmat.ndim == 3:
        confmat = xp.sum(confmat, axis=0)

    def _scalar(value: float) -> Array:
        """Build a float32 scalar array on the same device as `confmat`."""
        return xp.asarray(value, dtype=xp.float32, device=apc.device(confmat))  # type: ignore[no-any-return]

    is_binary = int(apc.size(confmat) or 0) == 4
    if is_binary:
        # Index explicitly instead of unpacking: iteration over arrays is not
        # part of the array API standard, integer indexing is.
        flat = xp.reshape(xp.astype(confmat, xp.float64), (-1,))
        tn, fp, fn, tp = flat[0], flat[1], flat[2], flat[3]
        if tp + tn != 0 and fp + fn == 0:
            return _scalar(1.0)

        if tp + tn == 0 and fp + fn != 0:
            return _scalar(-1.0)

    tk = xp.sum(confmat, axis=-1, dtype=xp.float64)  # tn + fp and tp + fn
    pk = xp.sum(confmat, axis=-2, dtype=xp.float64)  # tn + fn and tp + fp
    c = xp.astype(xp.linalg.trace(confmat), xp.float64)  # tn and tp
    s = xp.sum(confmat, dtype=xp.float64)  # tn + tp + fn + fp

    # Use the namespace reduction rather than Python's builtin `sum`, which
    # would loop over the array element by element.
    cov_ytyp = c * s - xp.sum(tk * pk)
    cov_ypyp = s**2 - xp.sum(pk * pk)
    cov_ytyt = s**2 - xp.sum(tk * tk)

    numerator = cov_ytyp
    denom = cov_ypyp * cov_ytyt

    if denom == 0:
        if not is_binary:
            return _scalar(0.0)

        # For a 2x2 matrix the denominator is zero only when a whole row or
        # column is empty, so at least one of tp/tn and one of fp/fn is zero.
        a = tp + tn
        b = fp + fn

        eps = xp.asarray(
            xp.finfo(xp.float32).eps,
            dtype=xp.float32,
            device=apc.device(confmat),
        )
        numerator = xp.sqrt(eps) * (a - b)
        denom = (
            (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps)
        )
    return xp.astype(numerator / xp.sqrt(denom), xp.float32)  # type: ignore[no-any-return]


def binary_mcc(
    target: Array,
    preds: Array,
    threshold: float = 0.5,
    ignore_index: Optional[int] = None,
) -> Array:
    """Compute the matthews correlation coefficient for binary classification.

    Parameters
    ----------
    target : Array
        An array object that is compatible with the Python array API standard
        and contains the ground truth labels. The expected shape of the array
        is `(N, ...)`, where `N` is the number of samples.
    preds : Array
        An array object that is compatible with the Python array API standard and
        contains the predictions of a binary classifier. The expected shape of the
        array is `(N, ...)` where `N` is the number of samples. If `preds` contains
        floating point values that are not in the range `[0, 1]`, a sigmoid function
        will be applied to each value before thresholding.
    threshold : float, default=0.5
        The threshold to use when converting probabilities to binary predictions.
    ignore_index : int, optional, default=None
        Specifies a target value that is ignored and does not contribute to the
        metric. If `None`, ignore nothing.

    Returns
    -------
    Array
        The matthews correlation coefficient.

    Raises
    ------
    ValueError
        If `target` and `preds` have different shapes.
    ValueError
        If `target` and `preds` are not array-API-compatible.
    ValueError
        If `target` or `preds` are empty.
    ValueError
        If `target` or `preds` are not numeric arrays.
    ValueError
        If `threshold` is not a float in the [0,1] range.
    ValueError
        If `ignore_index` is not `None` or an integer.

    Examples
    --------
    >>> import numpy.array_api as anp
    >>> from cyclops.evaluate.metrics.experimental.functional import binary_mcc
    >>> target = anp.asarray([0, 1, 0, 1, 0, 1])
    >>> preds = anp.asarray([0, 0, 1, 1, 0, 1])
    >>> binary_mcc(target, preds)
    Array(0.33333334, dtype=float32)
    >>> target = anp.asarray([0, 1, 0, 1, 0, 1])
    >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
    >>> binary_mcc(target, preds)
    Array(0.33333334, dtype=float32)

    """
    # The coefficient is derived from raw counts, so the confusion matrix is
    # always requested un-normalized.
    _binary_confusion_matrix_validate_args(
        threshold=threshold,
        normalize=None,
        ignore_index=ignore_index,
    )
    namespace = _binary_confusion_matrix_validate_arrays(target, preds, ignore_index)

    formatted_target, formatted_preds = _binary_confusion_matrix_format_arrays(
        target,
        preds,
        threshold,
        ignore_index,
        xp=namespace,
    )
    true_neg, false_pos, false_neg, true_pos = _binary_confusion_matrix_update_state(
        formatted_target,
        formatted_preds,
        xp=namespace,
    )

    return _mcc_reduce(
        _binary_confusion_matrix_compute(
            true_neg,
            false_pos,
            false_neg,
            true_pos,
            normalize=None,
        ),
    )


def multiclass_mcc(
    target: Array,
    preds: Array,
    num_classes: int,
    ignore_index: Optional[Union[int, Tuple[int]]] = None,
) -> Array:
    """Compute the matthews correlation coefficient for multiclass classification.

    Parameters
    ----------
    target : Array
        The target array of shape `(N, ...)`, where `N` is the number of samples.
    preds : Array
        The prediction array with shape `(N, ...)`, for integer inputs, or
        `(N, C, ...)`, for float inputs, where `N` is the number of samples and
        `C` is the number of classes.
    num_classes : int
        The number of classes.
    ignore_index : int, Tuple[int], optional, default=None
        Specifies a target value(s) that is ignored and does not contribute to the
        metric. If `None`, ignore nothing.

    Returns
    -------
    Array
        The matthews correlation coefficient.

    Raises
    ------
    ValueError
        If `target` and `preds` are not array-API-compatible.
    ValueError
        If `target` or `preds` are empty.
    ValueError
        If `target` or `preds` are not numeric arrays.
    ValueError
        If `num_classes` is not an integer larger than 1.
    ValueError
        If `ignore_index` is not `None`, an integer or a tuple of integers.
    ValueError
        If `preds` contains floats but `target` does not have one dimension less than
        `preds`.
    ValueError
        If the second dimension of `preds` is not equal to `num_classes`.
    ValueError
        If when `target` has one dimension less than `preds`, the shape of `preds` is
        not `(N, C, ...)` while the shape of `target` is `(N, ...)`.
    ValueError
        If when `target` and `preds` have the same number of dimensions, they
        do not have the same shape.
    RuntimeError
        If `target` contains values that are not in the range [0, `num_classes`).

    Examples
    --------
    >>> import numpy.array_api as anp
    >>> from cyclops.evaluate.metrics.experimental.functional import multiclass_mcc
    >>> target = anp.asarray([2, 1, 0, 0])
    >>> preds = anp.asarray([2, 1, 0, 1])
    >>> multiclass_mcc(target, preds, num_classes=3)
    Array(0.7, dtype=float32)
    >>> target = anp.asarray([2, 1, 0, 0])
    >>> preds = anp.asarray(
    ...     [
    ...         [0.16, 0.26, 0.58],
    ...         [0.22, 0.61, 0.17],
    ...         [0.71, 0.09, 0.20],
    ...         [0.05, 0.82, 0.13],
    ...     ]
    ... )
    >>> multiclass_mcc(target, preds, num_classes=3)
    Array(0.7, dtype=float32)

    """
    # The coefficient is derived from raw counts, so validation is run with
    # normalization disabled.
    _multiclass_confusion_matrix_validate_args(
        num_classes,
        normalize=None,
        ignore_index=ignore_index,
    )
    namespace = _multiclass_confusion_matrix_validate_arrays(
        target,
        preds,
        num_classes,
        ignore_index=ignore_index,
    )

    formatted_target, formatted_preds = _multiclass_confusion_matrix_format_arrays(
        target,
        preds,
        ignore_index=ignore_index,
        xp=namespace,
    )
    return _mcc_reduce(
        _multiclass_confusion_matrix_update_state(
            formatted_target,
            formatted_preds,
            num_classes,
            xp=namespace,
        ),
    )


def multilabel_mcc(
    target: Array,
    preds: Array,
    num_labels: int,
    threshold: float = 0.5,
    ignore_index: Optional[int] = None,
) -> Array:
    """Compute the matthews correlation coefficient for multilabel classification.

    Parameters
    ----------
    target : Array
        The target array of shape `(N, L, ...)`, where `N` is the number of samples
        and `L` is the number of labels.
    preds : Array
        The prediction array of shape `(N, L, ...)`, where `N` is the number of
        samples and `L` is the number of labels. If `preds` contains floats that
        are not in the range [0,1], they will be converted to probabilities using
        the sigmoid function.
    num_labels : int
        The number of labels.
    threshold : float, default=0.5
        The threshold to use for binarizing the predictions.
    ignore_index : int, optional, default=None
        Specifies a target value that is ignored and does not contribute to the
        metric. If `None`, ignore nothing.

    Returns
    -------
    Array
        The matthews correlation coefficient.

    Raises
    ------
    ValueError
        If `target` and `preds` are not array-API-compatible.
    ValueError
        If `target` or `preds` are empty.
    ValueError
        If `target` or `preds` are not numeric arrays.
    ValueError
        If `threshold` is not a float in the [0,1] range.
    ValueError
        If `ignore_index` is not `None` or a non-negative integer.
    ValueError
        If `num_labels` is not an integer larger than 1.
    ValueError
        If `target` and `preds` do not have the same shape.
    ValueError
        If the second dimension of `preds` is not equal to `num_labels`.
    RuntimeError
        If `target` contains values that are not in the range [0, 1].

    Examples
    --------
    >>> import numpy.array_api as anp
    >>> from cyclops.evaluate.metrics.experimental.functional import multilabel_mcc
    >>> target = anp.asarray([[0, 1, 0], [1, 0, 1]])
    >>> preds = anp.asarray([[0, 0, 1], [1, 0, 1]])
    >>> multilabel_mcc(target, preds, num_labels=3)
    Array(0.33333334, dtype=float32)
    >>> target = anp.asarray([[0, 1, 0], [1, 0, 1]])
    >>> preds = anp.asarray([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
    >>> multilabel_mcc(target, preds, num_labels=3)
    Array(0.33333334, dtype=float32)

    """
    # The coefficient is derived from raw counts, so the confusion matrix is
    # always requested un-normalized.
    _multilabel_confusion_matrix_validate_args(
        num_labels,
        threshold=threshold,
        normalize=None,
        ignore_index=ignore_index,
    )
    namespace = _multilabel_confusion_matrix_validate_arrays(
        target,
        preds,
        num_labels,
        ignore_index=ignore_index,
    )

    formatted_target, formatted_preds = _multilabel_confusion_matrix_format_arrays(
        target,
        preds,
        threshold=threshold,
        ignore_index=ignore_index,
        xp=namespace,
    )
    (
        true_neg,
        false_pos,
        false_neg,
        true_pos,
    ) = _multilabel_confusion_matrix_update_state(
        formatted_target,
        formatted_preds,
        xp=namespace,
    )

    return _mcc_reduce(
        _multilabel_confusion_matrix_compute(
            true_neg,
            false_pos,
            false_neg,
            true_pos,
            num_labels,
            normalize=None,
        ),
    )
Loading
Loading