From f3b67344dea8b50efa1d1a79e2e659ab21093d72 Mon Sep 17 00:00:00 2001 From: Ankit Jha <45779665+PyExtreme@users.noreply.github.com> Date: Wed, 22 Jan 2020 02:44:40 +0530 Subject: [PATCH] Add Dice Coefficient (#680) * Add Dice Coefficient * Add Dice Coefficient * add test: dice_coefficient * fix minor issues * fix minor issues * add ignore_index: dice_coefficient * fix minor issues: dice_coefficient * add logic: ignore_index * add logic: ignore_index * Updated DiceCoefficient code * Updated docs Co-authored-by: vfdev --- docs/source/metrics.rst | 3 ++ ignite/metrics/__init__.py | 2 +- ignite/metrics/confusion_matrix.py | 46 ++++++++++++++++--- tests/ignite/metrics/test_confusion_matrix.py | 40 ++++++++++++++-- 4 files changed, 81 insertions(+), 10 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index c8572f33367..6f2a7718a5b 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -187,6 +187,7 @@ Complete list of metrics - :class:`~ignite.metrics.Accuracy` - :class:`~ignite.metrics.Average` - :class:`~ignite.metrics.ConfusionMatrix` + - :meth:`~ignite.metrics.DiceCoefficient` - :class:`~ignite.metrics.EpochMetric` - :meth:`~ignite.metrics.Fbeta` - :class:`~ignite.metrics.GeometricAverage` @@ -214,6 +215,8 @@ Complete list of metrics .. autoclass:: ConfusionMatrix +.. autofunction:: DiceCoefficient + .. autoclass:: EpochMetric .. autofunction:: Fbeta diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 88ff8a0304a..c1ec42e0680 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -11,6 +11,6 @@ from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy from ignite.metrics.running_average import RunningAverage from ignite.metrics.metrics_lambda import MetricsLambda -from ignite.metrics.confusion_matrix import ConfusionMatrix, IoU, mIoU +from ignite.metrics.confusion_matrix import ConfusionMatrix, IoU, mIoU, DiceCoefficient from ignite.metrics.accumulation import VariableAccumulation, Average, GeometricAverage from ignite.metrics.fbeta import Fbeta diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index 504aff53fd5..72630768d71 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -116,7 +116,7 @@ def compute(self): def IoU(cm, ignore_index=None): - """Calculates Intersection over Union + """Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric. Args: cm (ConfusionMatrix): instance of confusion matrix metric @@ -164,7 +164,7 @@ def ignore_index_fn(iou_vector): def mIoU(cm, ignore_index=None): - """Calculates mean Intersection over Union + """Calculates mean Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric. Args: cm (ConfusionMatrix): instance of confusion matrix metric @@ -191,8 +191,8 @@ def mIoU(cm, ignore_index=None): def cmAccuracy(cm): - """ - Calculates accuracy using :class:`~ignite.metrics.ConfusionMatrix` metric. + """Calculates accuracy using :class:`~ignite.metrics.ConfusionMatrix` metric. + Args: cm (ConfusionMatrix): instance of confusion matrix metric @@ -205,8 +205,8 @@ def cmAccuracy(cm): def cmPrecision(cm, average=True): - """ - Calculates precision using :class:`~ignite.metrics.ConfusionMatrix` metric. + """Calculates precision using :class:`~ignite.metrics.ConfusionMatrix` metric. + Args: cm (ConfusionMatrix): instance of confusion matrix metric average (bool, optional): if True metric value is averaged over all classes @@ -238,3 +238,37 @@ def cmRecall(cm, average=True): if average: return recall.mean() return recall + + +def DiceCoefficient(cm, ignore_index=None): + """Calculates Dice Coefficient for a given :class:`~ignite.metrics.ConfusionMatrix` metric. + + Args: + cm (ConfusionMatrix): instance of confusion matrix metric + ignore_index (int, optional): index to ignore, e.g. background index + """ + + if not isinstance(cm, ConfusionMatrix): + raise TypeError("Argument cm should be instance of ConfusionMatrix, but given {}".format(type(cm))) + + if ignore_index is not None: + if not (isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes): + raise ValueError("ignore_index should be non-negative integer, but given {}".format(ignore_index)) + + # Increase floating point precision and pass to CPU + cm = cm.type(torch.DoubleTensor) + dice = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15) + + if ignore_index is not None: + + def ignore_index_fn(dice_vector): + if ignore_index >= len(dice_vector): + raise ValueError("ignore_index {} is larger than the length of Dice vector {}" + .format(ignore_index, len(dice_vector))) + indices = list(range(len(dice_vector))) + indices.remove(ignore_index) + return dice_vector[indices] + + return MetricsLambda(ignore_index_fn, dice) + else: + return dice diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py index ab014ed29fe..38521a98266 100644 --- a/tests/ignite/metrics/test_confusion_matrix.py +++ b/tests/ignite/metrics/test_confusion_matrix.py @@ -1,4 +1,3 @@ -from __future__ import division import os import torch @@ -7,8 +6,7 @@ from ignite.exceptions import NotComputableError from ignite.metrics import ConfusionMatrix, IoU, mIoU -from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall - +from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall, DiceCoefficient import pytest @@ -477,6 +475,42 @@ def test_cm_with_average(): np.testing.assert_almost_equal(true_pr, res) +def test_dice_coefficient(): + + y_true, y_pred = get_y_true_y_pred() + th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred) + + true_res = [0, 0, 0] + for index in range(3): + bin_y_true = y_true == index + bin_y_pred = y_pred == index + # dice coefficient: 2*intersection(x, y) / (|x| + |y|) + # union(x, y) = |x| + |y| - intersection(x, y) + intersection = bin_y_true & bin_y_pred + union = bin_y_true | bin_y_pred + true_res[index] = 2.0 * intersection.sum() / (union.sum() + intersection.sum()) + + cm = ConfusionMatrix(num_classes=3) + dice_metric = DiceCoefficient(cm) + + # Update metric + output = (th_y_logits, th_y_true) + cm.update(output) + + res = dice_metric.compute().numpy() + np.testing.assert_allclose(res, true_res) + + for ignore_index in range(3): + cm = ConfusionMatrix(num_classes=3) + dice_metric = DiceCoefficient(cm, ignore_index=ignore_index) + # Update metric + output = (th_y_logits, th_y_true) + cm.update(output) + res = dice_metric.compute().numpy() + true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1:] + assert np.all(res == true_res_), "{}: {} vs {}".format(ignore_index, res, true_res_) + + def _test_distrib_multiclass_images(device): import torch.distributed as dist