Skip to content

Commit

Permalink
Add Dice Coefficient (#680)
Browse files Browse the repository at this point in the history
* Add Dice Coefficient

* Add Dice Coefficient

* add test: dice_coefficient

* fix minor issues

* fix minor issues

* add ignore_index: dice_coefficient

* fix minor issues: dice_coefficient

* add logic: ignore_index

* add logic: ignore_index

* Updated DiceCoefficient code

* Updated docs

Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
PyExtreme and vfdev-5 committed Jan 21, 2020
1 parent 5b1bcd3 commit f3b6734
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ Complete list of metrics
- :class:`~ignite.metrics.Accuracy`
- :class:`~ignite.metrics.Average`
- :class:`~ignite.metrics.ConfusionMatrix`
- :meth:`~ignite.metrics.DiceCoefficient`
- :class:`~ignite.metrics.EpochMetric`
- :meth:`~ignite.metrics.Fbeta`
- :class:`~ignite.metrics.GeometricAverage`
Expand Down Expand Up @@ -214,6 +215,8 @@ Complete list of metrics

.. autoclass:: ConfusionMatrix

.. autofunction:: DiceCoefficient

.. autoclass:: EpochMetric

.. autofunction:: Fbeta
Expand Down
2 changes: 1 addition & 1 deletion ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.confusion_matrix import ConfusionMatrix, IoU, mIoU
from ignite.metrics.confusion_matrix import ConfusionMatrix, IoU, mIoU, DiceCoefficient
from ignite.metrics.accumulation import VariableAccumulation, Average, GeometricAverage
from ignite.metrics.fbeta import Fbeta
46 changes: 40 additions & 6 deletions ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def compute(self):


def IoU(cm, ignore_index=None):
"""Calculates Intersection over Union
"""Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric.
Args:
cm (ConfusionMatrix): instance of confusion matrix metric
Expand Down Expand Up @@ -164,7 +164,7 @@ def ignore_index_fn(iou_vector):


def mIoU(cm, ignore_index=None):
"""Calculates mean Intersection over Union
"""Calculates mean Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric.
Args:
cm (ConfusionMatrix): instance of confusion matrix metric
Expand All @@ -191,8 +191,8 @@ def mIoU(cm, ignore_index=None):


def cmAccuracy(cm):
"""
Calculates accuracy using :class:`~ignite.metrics.ConfusionMatrix` metric.
"""Calculates accuracy using :class:`~ignite.metrics.ConfusionMatrix` metric.
Args:
cm (ConfusionMatrix): instance of confusion matrix metric
Expand All @@ -205,8 +205,8 @@ def cmAccuracy(cm):


def cmPrecision(cm, average=True):
"""
Calculates precision using :class:`~ignite.metrics.ConfusionMatrix` metric.
"""Calculates precision using :class:`~ignite.metrics.ConfusionMatrix` metric.
Args:
cm (ConfusionMatrix): instance of confusion matrix metric
average (bool, optional): if True metric value is averaged over all classes
Expand Down Expand Up @@ -238,3 +238,37 @@ def cmRecall(cm, average=True):
if average:
return recall.mean()
return recall


def DiceCoefficient(cm, ignore_index=None):
    """Calculates Dice Coefficient for a given :class:`~ignite.metrics.ConfusionMatrix` metric.

    Args:
        cm (ConfusionMatrix): instance of confusion matrix metric
        ignore_index (int, optional): index to ignore, e.g. background index
    """
    if not isinstance(cm, ConfusionMatrix):
        raise TypeError("Argument cm should be instance of ConfusionMatrix, but given {}".format(type(cm)))

    if ignore_index is not None and not (
            isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes):
        raise ValueError("ignore_index should be non-negative integer, but given {}".format(ignore_index))

    # Increase floating point precision and pass to CPU
    cm = cm.type(torch.DoubleTensor)
    # Per-class Dice: 2 * diagonal / (row sums + column sums); the epsilon avoids 0/0.
    dice = 2.0 * cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-15)

    if ignore_index is None:
        return dice

    def drop_ignored(dice_vector):
        # Remove the ignored class entry from the per-class Dice vector.
        if ignore_index >= len(dice_vector):
            raise ValueError("ignore_index {} is larger than the length of Dice vector {}"
                             .format(ignore_index, len(dice_vector)))
        keep = [i for i in range(len(dice_vector)) if i != ignore_index]
        return dice_vector[keep]

    return MetricsLambda(drop_ignored, dice)
40 changes: 37 additions & 3 deletions tests/ignite/metrics/test_confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from __future__ import division
import os
import torch

Expand All @@ -7,8 +6,7 @@

from ignite.exceptions import NotComputableError
from ignite.metrics import ConfusionMatrix, IoU, mIoU
from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall

from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall, DiceCoefficient
import pytest


Expand Down Expand Up @@ -477,6 +475,42 @@ def test_cm_with_average():
np.testing.assert_almost_equal(true_pr, res)


def test_dice_coefficient():
    """Check DiceCoefficient against a NumPy reference, with and without ignore_index."""
    y_true, y_pred = get_y_true_y_pred()
    th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

    # Reference per-class Dice: 2*intersection / (|x| + |y|),
    # where |x| + |y| = union + intersection.
    true_res = []
    for index in range(3):
        bin_y_true = y_true == index
        bin_y_pred = y_pred == index
        intersection = (bin_y_true & bin_y_pred).sum()
        union = (bin_y_true | bin_y_pred).sum()
        true_res.append(2.0 * intersection / (union + intersection))

    cm = ConfusionMatrix(num_classes=3)
    dice_metric = DiceCoefficient(cm)

    # Update metric and compare against the reference.
    cm.update((th_y_logits, th_y_true))
    res = dice_metric.compute().numpy()
    np.testing.assert_allclose(res, true_res)

    # Ignoring any single class should drop exactly that entry of the vector.
    for ignore_index in range(3):
        cm = ConfusionMatrix(num_classes=3)
        dice_metric = DiceCoefficient(cm, ignore_index=ignore_index)
        cm.update((th_y_logits, th_y_true))
        res = dice_metric.compute().numpy()
        true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1:]
        assert np.all(res == true_res_), "{}: {} vs {}".format(ignore_index, res, true_res_)


def _test_distrib_multiclass_images(device):

import torch.distributed as dist
Expand Down

0 comments on commit f3b6734

Please sign in to comment.