Skip to content

Commit

Permalink
update implementation of multilabel confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi committed Feb 16, 2024
1 parent 8c380c8 commit dfa89cf
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 66 deletions.
31 changes: 18 additions & 13 deletions cyclops/evaluate/metrics/experimental/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Confusion matrix."""

from types import ModuleType
from typing import Any, Optional, Tuple, Union

Expand Down Expand Up @@ -276,10 +277,7 @@ def _compute_metric(self) -> Array:
)


class MultilabelConfusionMatrix(
_AbstractConfusionMatrix,
registry_key="multilabel_confusion_matrix",
):
class MultilabelConfusionMatrix(Metric, registry_key="multilabel_confusion_matrix"):
"""Confusion matrix for multilabel classification tasks.
Parameters
Expand Down Expand Up @@ -329,6 +327,8 @@ class MultilabelConfusionMatrix(
"""

name: str = "Confusion Matrix"

def __init__(
self,
num_labels: int,
Expand All @@ -352,7 +352,11 @@ def __init__(
self.normalize = normalize
self.ignore_index = ignore_index

self._create_state(size=num_labels)
self.add_state_default_factory(
"confmat",
lambda xp: xp.zeros((num_labels, 2, 2), dtype=xp.int64, device=self.device), # type: ignore
dist_reduce_fn="sum",
)

def _update_state(self, target: Array, preds: Array) -> None:
"""Update the state variables."""
Expand All @@ -365,21 +369,22 @@ def _update_state(self, target: Array, preds: Array) -> None:
target, preds = _multilabel_confusion_matrix_format_arrays(
target,
preds,
self.num_labels,
threshold=self.threshold,
ignore_index=self.ignore_index,
xp=xp,
)
tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds, xp=xp)
self._update_stat_scores(tn=tn, fp=fp, fn=fn, tp=tp)
confmat = _multilabel_confusion_matrix_update_state(
target,
preds,
self.num_labels,
xp=xp,
)
self.confmat += confmat # type: ignore

def _compute_metric(self) -> Array:
"""Compute the confusion matrix."""
tn, fp, fn, tp = self._final_state()
return _multilabel_confusion_matrix_compute(
tp=tp,
fp=fp,
tn=tn,
fn=fn,
num_labels=self.num_labels,
self.confmat, # type: ignore
normalize=self.normalize,
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions for computing the confusion matrix for classification tasks."""

# mypy: disable-error-code="no-any-return"
from types import ModuleType
from typing import Literal, Optional, Tuple, Union
Expand All @@ -9,6 +10,7 @@
bincount,
clone,
flatten,
moveaxis,
remove_ignore_index,
safe_divide,
sigmoid,
Expand Down Expand Up @@ -599,6 +601,7 @@ def _multilabel_confusion_matrix_validate_arrays(
def _multilabel_confusion_matrix_format_arrays(
target: Array,
preds: Array,
num_labels: int,
threshold: float = 0.5,
ignore_index: Optional[int] = None,
*,
Expand All @@ -613,48 +616,41 @@ def _multilabel_confusion_matrix_format_arrays(
preds = sigmoid(preds) # convert logits to probabilities
preds = to_int(preds > threshold)

preds = xp.reshape(preds, shape=(*preds.shape[:2], -1))
target = xp.reshape(target, shape=(*target.shape[:2], -1))
preds = xp.reshape(moveaxis(preds, 1, -1), shape=(-1, num_labels))
target = xp.reshape(moveaxis(target, 1, -1), shape=(-1, num_labels))

if ignore_index is not None:
idx = target == ignore_index
target = clone(target)
target[idx] = -1
preds = clone(preds)
idx = target == ignore_index
target[idx] = -4 * num_labels
preds[idx] = -4 * num_labels

return target, preds


def _multilabel_confusion_matrix_update_state(
target: Array,
preds: Array,
num_labels: int,
*,
xp: ModuleType,
) -> Tuple[Array, Array, Array, Array]:
) -> Array:
"""Compute the statistics for the given `target` and `preds` arrays."""
sum_axis = (0, -1)
tp = squeeze_all(xp.sum(to_int((target == preds) & (target == 1)), axis=sum_axis))
fn = squeeze_all(xp.sum(to_int((target != preds) & (target == 1)), axis=sum_axis))
fp = squeeze_all(xp.sum(to_int((target != preds) & (target == 0)), axis=sum_axis))
tn = squeeze_all(xp.sum(to_int((target == preds) & (target == 0)), axis=sum_axis))

return tn, fp, fn, tp
unique_mapping = (2 * target + preds) + 4 * flatten(
xp.arange(num_labels, device=apc.device(preds)),
)
unique_mapping = unique_mapping[unique_mapping >= 0]
bins = bincount(unique_mapping, minlength=4 * num_labels)
return xp.reshape(bins, shape=(num_labels, 2, 2))


def _multilabel_confusion_matrix_compute(
tn: Array,
fp: Array,
fn: Array,
tp: Array,
num_labels: int,
confmat: Array,
normalize: Optional[str] = None,
) -> Array:
"""Compute the confusion matrix from the given stat scores."""
xp = apc.array_namespace(tn, fp, fn, tp)

confmat = squeeze_all(
xp.reshape(xp.stack([tn, fp, fn, tp], axis=-1), shape=(-1, num_labels, 2, 2)),
)

xp = apc.array_namespace(confmat)
return _normalize_confusion_matrix(confmat, normalize=normalize, xp=xp)


Expand Down Expand Up @@ -768,17 +764,19 @@ class over the number of true samples for each class.
target, preds = _multilabel_confusion_matrix_format_arrays(
target,
preds,
num_labels,
threshold=threshold,
ignore_index=ignore_index,
xp=xp,
)
tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds, xp=xp)
confmat = _multilabel_confusion_matrix_update_state(
target,
preds,
num_labels,
xp=xp,
)

return _multilabel_confusion_matrix_compute(
tn,
fp,
fn,
tp,
num_labels,
confmat,
normalize=normalize,
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functional API for the matthews correlation coefficient (MCC) metric."""

from typing import Optional, Tuple, Union

import array_api_compat as apc
Expand All @@ -13,7 +14,6 @@
_multiclass_confusion_matrix_update_state,
_multiclass_confusion_matrix_validate_args,
_multiclass_confusion_matrix_validate_arrays,
_multilabel_confusion_matrix_compute,
_multilabel_confusion_matrix_format_arrays,
_multilabel_confusion_matrix_update_state,
_multilabel_confusion_matrix_validate_args,
Expand All @@ -25,6 +25,7 @@
def _mcc_reduce(confmat: Array) -> Array:
"""Reduce an un-normalized confusion matrix into the matthews corrcoef."""
xp = apc.array_namespace(confmat)

# convert multilabel into binary
confmat = xp.sum(confmat, axis=0) if confmat.ndim == 3 else confmat

Expand All @@ -36,10 +37,10 @@ def _mcc_reduce(confmat: Array) -> Array:
if tp + tn == 0 and fp + fn != 0:
return xp.asarray(-1.0, dtype=xp.float32, device=apc.device(confmat)) # type: ignore[no-any-return]

tk = xp.sum(confmat, axis=-1, dtype=xp.float32)
pk = xp.sum(confmat, axis=-2, dtype=xp.float32)
c = xp.astype(xp.linalg.trace(confmat), xp.float32)
s = xp.sum(confmat, dtype=xp.float32)
tk = xp.sum(confmat, axis=-1, dtype=xp.float32) # tn + fp and tp + fn
pk = xp.sum(confmat, axis=-2, dtype=xp.float32) # tn + fn and tp + fp
c = xp.astype(xp.linalg.trace(confmat), xp.float32) # tn and tp
s = xp.sum(confmat, dtype=xp.float32) # tn + tp + fn + fp

cov_ytyp = c * s - sum(tk * pk)
cov_ypyp = s**2 - sum(pk * pk)
Expand Down Expand Up @@ -333,18 +334,16 @@ def multilabel_mcc(
target, preds = _multilabel_confusion_matrix_format_arrays(
target,
preds,
num_labels,
threshold=threshold,
ignore_index=ignore_index,
xp=xp,
)
tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds, xp=xp)

confmat = _multilabel_confusion_matrix_compute(
tn,
fp,
fn,
tp,
confmat = _multilabel_confusion_matrix_update_state(
target,
preds,
num_labels,
normalize=None,
xp=xp,
)

return _mcc_reduce(confmat)
13 changes: 2 additions & 11 deletions cyclops/evaluate/metrics/experimental/matthews_corr_coef.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Matthews Correlation Coefficient (MCC) metric."""

from typing import Any, Optional, Tuple, Union

from cyclops.evaluate.metrics.experimental.confusion_matrix import (
Expand All @@ -8,7 +9,6 @@
)
from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import (
_binary_confusion_matrix_compute,
_multilabel_confusion_matrix_compute,
)
from cyclops.evaluate.metrics.experimental.functional.matthews_corr_coef import (
_mcc_reduce,
Expand Down Expand Up @@ -175,13 +175,4 @@ def __init__(

def _compute_metric(self) -> Array:
"""Compute the confusion matrix."""
tn, fp, fn, tp = self._final_state()
confmat = _multilabel_confusion_matrix_compute(
tp=tp,
fp=fp,
tn=tn,
fn=fn,
num_labels=self.num_labels,
normalize=self.normalize,
)
return _mcc_reduce(confmat)
return _mcc_reduce(self.confmat) # type: ignore

0 comments on commit dfa89cf

Please sign in to comment.