Commit f098795

anmolsjoshi authored and vfdev-5 committed
Metric sugar (#484)
* Added indexing to Metric, with tests
* Added confusion matrix to docs
* Update requirements.txt
* Update metrics.rst
1 parent 6b8b16b commit f098795

File tree

4 files changed: 81 additions, 4 deletions

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-sphinx
+sphinx==1.8.5
 -e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme

docs/source/metrics.rst

Lines changed: 17 additions & 1 deletion
@@ -41,7 +41,7 @@ Metrics could be combined together to form new metrics. This could be done throu
 as ``metric1 + metric2``, use PyTorch operators, such as ``(metric1 + metric2).pow(2).mean()``,
 or use a lambda function, such as ``MetricsLambda(lambda a, b: torch.mean(a + b), metric1, metric2)``.
 
-for example:
+For example:
 
 .. code-block:: python
 
@@ -54,6 +54,16 @@ for example:
 that `average=False`, i.e. to use the unaveraged precision and recall,
 otherwise we will not be computing F-beta metrics.
 
+Metrics also support indexing operations (if the metric's result is a vector/matrix/tensor). For example, this can be useful to compute a mean metric (e.g. precision, recall or IoU) while ignoring the background:
+
+.. code-block:: python
+
+    cm = ConfusionMatrix(num_classes=10)
+    iou_metric = IoU(cm)
+    iou_no_bg_metric = iou_metric[:9]  # We assume that the background index is 9
+    mean_iou_no_bg_metric = iou_no_bg_metric.mean()
+    # mean_iou_no_bg_metric.compute() -> tensor(0.12345)
+
 
 .. currentmodule:: ignite.metrics
 
@@ -83,3 +93,9 @@ for example:
 .. autoclass:: RunningAverage
 
 .. autoclass:: MetricsLambda
+
+.. autoclass:: ConfusionMatrix
+
+.. autofunction:: IoU
+
+.. autofunction:: mIoU

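The indexing sugar documented above is not limited to IoU: any metric whose computed result is a tensor can be sliced. A minimal sketch, assuming a setup where the last class of an unaveraged Precision is the background (the variable names and the background-index choice are illustrative, not part of this commit):

    from ignite.metrics import Precision

    precision = Precision(average=False)       # per-class precision vector
    precision_no_bg = precision[:-1]           # drop the last class, assumed to be background
    mean_precision_no_bg = precision_no_bg.mean()

    # Once attached to an evaluator engine, mean_precision_no_bg.compute()
    # yields a scalar averaged over the foreground classes only.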
ignite/metrics/metric.py

Lines changed: 4 additions & 0 deletions
@@ -142,3 +142,7 @@ def fn(x, *args, **kwargs):
         def wrapper(*args, **kwargs):
             return MetricsLambda(fn, self, *args, **kwargs)
         return wrapper
+
+    def __getitem__(self, index):
+        from ignite.metrics import MetricsLambda
+        return MetricsLambda(lambda x: x[index], self)

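The four lines added above are the whole feature: __getitem__ defers to MetricsLambda, so indexing a metric builds a lazily evaluated wrapper rather than slicing anything immediately. A minimal sketch of the equivalence (the cm/iou variable names are illustrative):

    from ignite.metrics import ConfusionMatrix, IoU, MetricsLambda

    cm = ConfusionMatrix(num_classes=10)
    iou = IoU(cm)

    # The two metrics below are equivalent; nothing is computed until the
    # attached engine finishes and .compute() is called on the result.
    iou_no_bg_sugar = iou[:9]
    iou_no_bg_explicit = MetricsLambda(lambda x: x[:9], iou)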
tests/ignite/metrics/test_metric.py

Lines changed: 59 additions & 2 deletions
@@ -1,12 +1,12 @@
 import sys
-from ignite.metrics import Metric, Precision, Recall
+from ignite.metrics import Metric, Precision, Recall, ConfusionMatrix
 from ignite.engine import Engine, State
 import torch
 from mock import MagicMock
 
 from pytest import approx, raises
 import numpy as np
-from sklearn.metrics import precision_score, recall_score, f1_score
+from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
 
 
 def test_no_transform():
@@ -388,3 +388,60 @@ def compute_f1(y_pred, y):
         return f1
 
     _test(f1, "f1", compute_true_value_fn=compute_f1)
+
+
+def test_indexing_metric():
+    def _test(ignite_metric, sklearn_metic, sklearn_args, index, num_classes=5):
+        y_pred = torch.rand(15, 10, num_classes).float()
+        y = torch.randint(0, num_classes, size=(15, 10)).long()
+
+        def update_fn(engine, batch):
+            y_pred, y = batch
+            return y_pred, y
+
+        metrics = {'metric': ignite_metric[index],
+                   'metric_wo_index': ignite_metric}
+
+        validator = Engine(update_fn)
+
+        for name, metric in metrics.items():
+            metric.attach(validator, name)
+
+        def data(y_pred, y):
+            for i in range(y_pred.shape[0]):
+                yield (y_pred[i], y[i])
+
+        d = data(y_pred, y)
+        state = validator.run(d, max_epochs=1)
+
+        sklearn_output = sklearn_metic(y.view(-1).numpy(),
+                                       y_pred.view(-1, num_classes).argmax(dim=1).numpy(),
+                                       **sklearn_args)
+
+        assert (state.metrics['metric_wo_index'][index] == state.metrics['metric']).all()
+        assert np.allclose(state.metrics['metric'].numpy(), sklearn_output)
+
+    num_classes = 5
+
+    labels = list(range(0, num_classes, 2))
+    _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels)
+    labels = list(range(num_classes - 1, 0, -2))
+    _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels)
+    labels = [1]
+    _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels)
+
+    labels = list(range(0, num_classes, 2))
+    _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels)
+    labels = list(range(num_classes - 1, 0, -2))
+    _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels)
+    labels = [1]
+    _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels)
+
+    # np.ix_ is used to allow a 2D slice of the matrix. This is required to get accurate results from
+    # ConfusionMatrix, which must be sliced the same way row-wise and column-wise.
+    labels = list(range(0, num_classes, 2))
+    _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))
+    labels = list(range(num_classes - 1, 0, -2))
+    _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))
+    labels = [1]
+    _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))

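The np.ix_ comment in the test is worth unpacking: removing classes from a confusion matrix has to drop the same indices from both rows and columns, and np.ix_ builds the open mesh that makes such a symmetric 2D slice possible. A standalone sketch with made-up values:

    import numpy as np

    cm = np.arange(25).reshape(5, 5)      # stand-in for a 5-class confusion matrix
    labels = [0, 2, 4]

    # Keep only entries where both the true and the predicted class are in labels.
    sub = cm[np.ix_(labels, labels)]
    print(sub.shape)                      # (3, 3)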
0 commit comments
