Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Matthews Correlation Coefficient (MCC) metric #550

Merged
merged 18 commits into from
Feb 20, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
revert implementation update & add print statements for debugging
  • Loading branch information
fcogidi committed Feb 16, 2024

Verified

This commit was signed with the committer’s verified signature.
commit 2612934d59705d3bd27f23652c6248db405173f7
30 changes: 13 additions & 17 deletions cyclops/evaluate/metrics/experimental/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -277,7 +277,10 @@ def _compute_metric(self) -> Array:
)


class MultilabelConfusionMatrix(Metric, registry_key="multilabel_confusion_matrix"):
class MultilabelConfusionMatrix(
_AbstractConfusionMatrix,
registry_key="multilabel_confusion_matrix",
):
"""Confusion matrix for multilabel classification tasks.

Parameters
@@ -327,8 +330,6 @@ class MultilabelConfusionMatrix(Metric, registry_key="multilabel_confusion_matri

"""

name: str = "Confusion Matrix"

def __init__(
self,
num_labels: int,
@@ -352,11 +353,7 @@ def __init__(
self.normalize = normalize
self.ignore_index = ignore_index

self.add_state_default_factory(
"confmat",
lambda xp: xp.zeros((num_labels, 2, 2), dtype=xp.int64, device=self.device), # type: ignore
dist_reduce_fn="sum",
)
self._create_state(size=num_labels)

def _update_state(self, target: Array, preds: Array) -> None:
"""Update the state variables."""
@@ -369,22 +366,21 @@ def _update_state(self, target: Array, preds: Array) -> None:
target, preds = _multilabel_confusion_matrix_format_arrays(
target,
preds,
self.num_labels,
threshold=self.threshold,
ignore_index=self.ignore_index,
xp=xp,
)
confmat = _multilabel_confusion_matrix_update_state(
target,
preds,
self.num_labels,
xp=xp,
)
self.confmat += confmat # type: ignore
tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds, xp=xp)
self._update_stat_scores(tn=tn, fp=fp, fn=fn, tp=tp)

def _compute_metric(self) -> Array:
"""Compute the confusion matrix."""
tn, fp, fn, tp = self._final_state()
return _multilabel_confusion_matrix_compute(
self.confmat, # type: ignore
tp=tp,
fp=fp,
tn=tn,
fn=fn,
num_labels=self.num_labels,
normalize=self.normalize,
)
Original file line number Diff line number Diff line change
@@ -10,7 +10,6 @@
bincount,
clone,
flatten,
moveaxis,
remove_ignore_index,
safe_divide,
sigmoid,
@@ -601,7 +600,6 @@ def _multilabel_confusion_matrix_validate_arrays(
def _multilabel_confusion_matrix_format_arrays(
target: Array,
preds: Array,
num_labels: int,
threshold: float = 0.5,
ignore_index: Optional[int] = None,
*,
@@ -616,41 +614,48 @@ def _multilabel_confusion_matrix_format_arrays(
preds = sigmoid(preds) # convert logits to probabilities
preds = to_int(preds > threshold)

preds = xp.reshape(moveaxis(preds, 1, -1), shape=(-1, num_labels))
target = xp.reshape(moveaxis(target, 1, -1), shape=(-1, num_labels))
preds = xp.reshape(preds, shape=(*preds.shape[:2], -1))
target = xp.reshape(target, shape=(*target.shape[:2], -1))

if ignore_index is not None:
target = clone(target)
preds = clone(preds)
idx = target == ignore_index
target[idx] = -4 * num_labels
preds[idx] = -4 * num_labels
target = clone(target)
target[idx] = -1

return target, preds


def _multilabel_confusion_matrix_update_state(
target: Array,
preds: Array,
num_labels: int,
*,
xp: ModuleType,
) -> Array:
) -> Tuple[Array, Array, Array, Array]:
"""Compute the statistics for the given `target` and `preds` arrays."""
unique_mapping = (2 * target + preds) + 4 * flatten(
xp.arange(num_labels, device=apc.device(preds)),
)
unique_mapping = unique_mapping[unique_mapping >= 0]
bins = bincount(unique_mapping, minlength=4 * num_labels)
return xp.reshape(bins, shape=(num_labels, 2, 2))
sum_axis = (0, -1)
tp = squeeze_all(xp.sum(to_int((target == preds) & (target == 1)), axis=sum_axis))
fn = squeeze_all(xp.sum(to_int((target != preds) & (target == 1)), axis=sum_axis))
fp = squeeze_all(xp.sum(to_int((target != preds) & (target == 0)), axis=sum_axis))
tn = squeeze_all(xp.sum(to_int((target == preds) & (target == 0)), axis=sum_axis))

return tn, fp, fn, tp


def _multilabel_confusion_matrix_compute(
confmat: Array,
tn: Array,
fp: Array,
fn: Array,
tp: Array,
num_labels: int,
normalize: Optional[str] = None,
) -> Array:
"""Compute the confusion matrix from the given stat scores."""
xp = apc.array_namespace(confmat)
xp = apc.array_namespace(tn, fp, fn, tp)

confmat = squeeze_all(
xp.reshape(xp.stack([tn, fp, fn, tp], axis=-1), shape=(-1, num_labels, 2, 2)),
)

return _normalize_confusion_matrix(confmat, normalize=normalize, xp=xp)


@@ -764,19 +769,17 @@ class over the number of true samples for each class.
target, preds = _multilabel_confusion_matrix_format_arrays(
target,
preds,
num_labels,
threshold=threshold,
ignore_index=ignore_index,
xp=xp,
)
confmat = _multilabel_confusion_matrix_update_state(
target,
preds,
num_labels,
xp=xp,
)
tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds, xp=xp)

return _multilabel_confusion_matrix_compute(
confmat,
tn,
fp,
fn,
tp,
num_labels,
normalize=normalize,
)
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
_multiclass_confusion_matrix_update_state,
_multiclass_confusion_matrix_validate_args,
_multiclass_confusion_matrix_validate_arrays,
_multilabel_confusion_matrix_compute,
_multilabel_confusion_matrix_format_arrays,
_multilabel_confusion_matrix_update_state,
_multilabel_confusion_matrix_validate_args,
@@ -25,9 +26,10 @@
def _mcc_reduce(confmat: Array) -> Array:
"""Reduce an un-normalized confusion matrix into the Matthews correlation coefficient."""
xp = apc.array_namespace(confmat)

# convert multilabel into binary
confmat = xp.sum(confmat, axis=0) if confmat.ndim == 3 else confmat
print("confmat: ", confmat)
print("numel: ", apc.size(confmat))

if int(apc.size(confmat) or 0) == 4: # binary case
tn, fp, fn, tp = xp.reshape(xp.astype(confmat, xp.float32), (-1,))
@@ -38,16 +40,25 @@ def _mcc_reduce(confmat: Array) -> Array:
return xp.asarray(-1.0, dtype=xp.float32, device=apc.device(confmat)) # type: ignore[no-any-return]

tk = xp.sum(confmat, axis=-1, dtype=xp.float32) # tn + fp and tp + fn
print("tk: ", tk)
pk = xp.sum(confmat, axis=-2, dtype=xp.float32) # tn + fn and tp + fp
print("pk: ", pk)
c = xp.astype(xp.linalg.trace(confmat), xp.float32) # tn and tp
print("c: ", c)
s = xp.sum(confmat, dtype=xp.float32) # tn + tp + fn + fp
print("s: ", s)

cov_ytyp = c * s - sum(tk * pk)
print("cov_ytyp: ", cov_ytyp)
cov_ypyp = s**2 - sum(pk * pk)
print("cov_ypyp: ", cov_ypyp)
cov_ytyt = s**2 - sum(tk * tk)
print("cov_ytyt: ", cov_ytyt)

numerator = cov_ytyp
print("numerator: ", numerator)
denom = cov_ypyp * cov_ytyt
print("denom: ", denom)

if denom == 0 and int(apc.size(confmat) or 0) == 4:
if tp == 0 or tn == 0:
@@ -61,8 +72,11 @@ def _mcc_reduce(confmat: Array) -> Array:
dtype=xp.float32,
device=apc.device(confmat),
)
print("eps: ", eps)
numerator = xp.sqrt(eps) * (a - b)
print("numerator: ", numerator)
denom = (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps)
print("denom: ", denom)
elif denom == 0:
return xp.asarray(0.0, dtype=xp.float32, device=apc.device(confmat)) # type: ignore[no-any-return]
return numerator / xp.sqrt(denom) # type: ignore[no-any-return]
@@ -334,16 +348,18 @@ def multilabel_mcc(
target, preds = _multilabel_confusion_matrix_format_arrays(
target,
preds,
num_labels,
threshold=threshold,
ignore_index=ignore_index,
xp=xp,
)
confmat = _multilabel_confusion_matrix_update_state(
target,
preds,
tn, fp, fn, tp = _multilabel_confusion_matrix_update_state(target, preds, xp=xp)

confmat = _multilabel_confusion_matrix_compute(
tn,
fp,
fn,
tp,
num_labels,
xp=xp,
normalize=None,
)

return _mcc_reduce(confmat)
12 changes: 11 additions & 1 deletion cyclops/evaluate/metrics/experimental/matthews_corr_coef.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
)
from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import (
_binary_confusion_matrix_compute,
_multilabel_confusion_matrix_compute,
)
from cyclops.evaluate.metrics.experimental.functional.matthews_corr_coef import (
_mcc_reduce,
@@ -175,4 +176,13 @@ def __init__(

def _compute_metric(self) -> Array:
"""Compute the Matthews correlation coefficient from the accumulated confusion matrix."""
return _mcc_reduce(self.confmat) # type: ignore
tn, fp, fn, tp = self._final_state()
confmat = _multilabel_confusion_matrix_compute(
tp=tp,
fp=fp,
tn=tn,
fn=fn,
num_labels=self.num_labels,
normalize=self.normalize,
)
return _mcc_reduce(confmat)
Loading