Skip to content

Commit 4f5b7fc

Browse files
authored
Fixes #689 (#704)
* Fixes #689 - handles (y_pred, y) or {'y_pred': y_pred, 'y': y, ...} as output argument for update function * Update documentation
1 parent 8c8c3c2 commit 4f5b7fc

33 files changed

+106
-52
lines changed

docs/source/metrics.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ value is then computed using the output of the engine's `process_function`:
1717
metric = Accuracy()
1818
metric.attach(engine, "accuracy")
1919
20-
If the engine's output is not in the format `y_pred, y`, the user can
20+
If the engine's output is not in the format `(y_pred, y)` or `{'y_pred': y_pred, 'y': y, ...}`, the user can
2121
use the `output_transform` argument to transform it:
2222

2323
.. code-block:: python

ignite/contrib/metrics/average_precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class AveragePrecision(EpochMetric):
3030
3131
def activated_output_transform(output):
3232
y_pred, y = output
33-
y_pred = torch.softmax(y_pred)
33+
y_pred = torch.softmax(y_pred, dim=1)
3434
return y_pred, y
3535
3636
avg_precision = AveragePrecision(activated_output_transform)

ignite/contrib/metrics/regression/canberra_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class CanberraMetric(_BaseRegression):
1515
1616
More details can be found in `Botchkarev 2018`__.
1717
18-
- `update` must receive output of the form `(y_pred, y)`.
18+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1919
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2020
2121
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/fractional_absolute_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class FractionalAbsoluteError(_BaseRegression):
1616
1717
More details can be found in `Botchkarev 2018`__.
1818
19-
- `update` must receive output of the form `(y_pred, y)`.
19+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2020
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2121
2222
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/fractional_bias.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class FractionalBias(_BaseRegression):
1616
1717
More details can be found in `Botchkarev 2018`__.
1818
19-
- `update` must receive output of the form `(y_pred, y)`.
19+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2020
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2121
2222
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/geometric_mean_absolute_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class GeometricMeanAbsoluteError(_BaseRegression):
1616
1717
More details can be found in `Botchkarev 2018`__.
1818
19-
- `update` must receive output of the form `(y_pred, y)`.
19+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2020
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2121
2222
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression):
1515
1616
More details can be found in `Botchkarev 2018`__.
1717
18-
- `update` must receive output of the form `(y_pred, y)`.
18+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1919
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2020
2121

ignite/contrib/metrics/regression/manhattan_distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class ManhattanDistance(_BaseRegression):
1515
1616
More details can be found in `Botchkarev 2018`__.
1717
18-
- `update` must receive output of the form `(y_pred, y)`.
18+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1919
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2020
2121
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/maximum_absolute_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class MaximumAbsoluteError(_BaseRegression):
1414
1515
More details can be found in `Botchkarev 2018`__.
1616
17-
- `update` must receive output of the form `(y_pred, y)`.
17+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1818
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
1919
2020
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/mean_absolute_relative_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class MeanAbsoluteRelativeError(_BaseRegression):
1616
1717
More details can be found in the reference `Botchkarev 2018`__.
1818
19-
- `update` must receive output of the form `(y_pred, y)`.
19+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2020
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2121
2222
__ https://arxiv.org/ftp/arxiv/papers/1809/1809.03006.pdf

ignite/contrib/metrics/regression/mean_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class MeanError(_BaseRegression):
1616
1717
More details can be found in the reference `Botchkarev 2018`__.
1818
19-
- `update` must receive output of the form `(y_pred, y)`.
19+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2020
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2121
2222
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/mean_normalized_bias.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class MeanNormalizedBias(_BaseRegression):
1616
1717
More details can be found in the reference `Botchkarev 2018`__.
1818
19-
- `update` must receive output of the form `(y_pred, y)`.
19+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2020
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
2121
2222
__ https://arxiv.org/abs/1809.03006

ignite/contrib/metrics/regression/median_absolute_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class MedianAbsoluteError(_BaseRegressionEpoch):
1818
1919
More details can be found in `Botchkarev 2018`__.
2020
21-
- `update` must receive output of the form `(y_pred, y)`.
21+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2222
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
2323
2424
.. warning::

ignite/contrib/metrics/regression/median_absolute_percentage_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class MedianAbsolutePercentageError(_BaseRegressionEpoch):
2020
2121
More details can be found in `Botchkarev 2018`__.
2222
23-
- `update` must receive output of the form `(y_pred, y)`.
23+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2424
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
2525
2626
.. warning::

ignite/contrib/metrics/regression/median_relative_absolute_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class MedianRelativeAbsoluteError(_BaseRegressionEpoch):
2020
2121
More details can be found in `Botchkarev 2018`__.
2222
23-
- `update` must receive output of the form `(y_pred, y)`.
23+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2424
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
2525
2626
.. warning::

ignite/contrib/metrics/regression/r2_score.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class R2Score(_BaseRegression):
1616
where :math:`A_j` is the ground truth, :math:`P_j` is the predicted value and
1717
:math:`\bar{A}` is the mean of the ground truth.
1818
19-
- `update` must receive output of the form `(y_pred, y)`.
19+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2020
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
2121
"""
2222
def reset(self):

ignite/contrib/metrics/regression/wave_hedges_distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class WaveHedgesDistance(_BaseRegression):
1414
1515
More details can be found in `Botchkarev 2018`__.
1616
17-
- `update` must receive output of the form `(y_pred, y)`.
17+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1818
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
1919
2020
__ https://arxiv.org/abs/1809.03006

ignite/engine/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import time
44
from collections import defaultdict, OrderedDict
5-
from collections import Mapping
5+
from collections.abc import Mapping
66
from enum import Enum
77
import weakref
88
import numbers

ignite/metrics/accumulation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class VariableAccumulation(Metric):
3535
initialized and available, device is set to `cuda`.
3636
3737
"""
38+
_required_output_keys = None
3839

3940
def __init__(self, op, output_transform=lambda x: x, device=None):
4041
if not callable(op):

ignite/metrics/accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class Accuracy(_BaseClassification):
8585
"""
8686
Calculates the accuracy for binary, multiclass and multilabel data.
8787
88-
- `update` must receive output of the form `(y_pred, y)`.
88+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
8989
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
9090
- `y` must be in the following shape (batch_size, ...).
9191
- `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) for multilabel cases.

ignite/metrics/confusion_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class ConfusionMatrix(Metric):
1111
"""Calculates confusion matrix for multi-class data.
1212
13-
- `update` must receive output of the form `(y_pred, y)`.
13+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1414
- `y_pred` must contain logits and has the following shape (batch_size, num_categories, ...)
1515
- `y` should have the following shape (batch_size, ...) and contains ground-truth class indices
1616
with or without the background class. During the computation, argmax of `y_pred` is taken to determine

ignite/metrics/epoch_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class EpochMetric(Metric):
2020
Current implementation does not work with distributed computations. Results are not gather across all devices
2121
and computed results are valid for a single device only.
2222
23-
- `update` must receive output of the form `(y_pred, y)`.
23+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2424
2525
If target shape is `(batch_size, n_classes)` and `n_classes > 1` than it should be binary: e.g. `[[0, 1, 0, 1], ]`.
2626

ignite/metrics/loss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ class Loss(Metric):
1818
form expected by the metric.
1919
This can be useful if, for example, you have a multi-output model and
2020
you want to compute the metric with respect to one of the outputs.
21-
The output is is expected to be a tuple (prediction, target) or
21+
The output is expected to be a tuple `(prediction, target)` or
2222
(prediction, target, kwargs) where kwargs is a dictionary of extra
23-
keywords arguments.
23+
keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`.
2424
batch_size (callable): a callable taking a target tensor that returns the
2525
first dimension size (usually the batch size).
2626
device (str of torch.device, optional): device specification in case of distributed computation usage.
@@ -29,6 +29,7 @@ class Loss(Metric):
2929
initialized and available, device is set to `cuda`.
3030
3131
"""
32+
_required_output_keys = None
3233

3334
def __init__(self, loss_fn, output_transform=lambda x: x,
3435
batch_size=lambda x: len(x), device=None):

ignite/metrics/mean_absolute_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class MeanAbsoluteError(Metric):
1111
"""
1212
Calculates the mean absolute error.
1313
14-
- `update` must receive output of the form `(y_pred, y)`.
14+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1515
"""
1616
@reinit__is_reduced
1717
def reset(self):

ignite/metrics/mean_pairwise_distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class MeanPairwiseDistance(Metric):
1212
"""
1313
Calculates the mean pairwise distance: average of pairwise distances computed on provided batches.
1414
15-
- `update` must receive output of the form `(y_pred, y)`.
15+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1616
"""
1717
def __init__(self, p=2, eps=1e-6, output_transform=lambda x: x, device=None):
1818
super(MeanPairwiseDistance, self).__init__(output_transform, device=device)

ignite/metrics/mean_squared_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class MeanSquaredError(Metric):
1111
"""
1212
Calculates the mean squared error.
1313
14-
- `update` must receive output of the form `(y_pred, y)`.
14+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1515
"""
1616
@reinit__is_reduced
1717
def reset(self):

ignite/metrics/metric.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numbers
22
from abc import ABCMeta, abstractmethod
33
from functools import wraps
4+
from collections.abc import Mapping
45
import warnings
56

67
import torch
@@ -18,12 +19,14 @@ class Metric(metaclass=ABCMeta):
1819
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
1920
form expected by the metric. This can be useful if, for example, you have a multi-output model and
2021
you want to compute the metric with respect to one of the outputs.
22+
By default, metrics require the output as `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
2123
device (str of torch.device, optional): device specification in case of distributed computation usage.
2224
In most of the cases, it can be defined as "cuda:local_rank" or "cuda"
2325
if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is
2426
initialized and available, device is set to `cuda`.
2527
2628
"""
29+
_required_output_keys = ("y_pred", "y")
2730

2831
def __init__(self, output_transform=lambda x: x, device=None):
2932
self._output_transform = output_transform
@@ -110,6 +113,15 @@ def started(self, engine):
110113
@torch.no_grad()
111114
def iteration_completed(self, engine):
112115
output = self._output_transform(engine.state.output)
116+
if isinstance(output, Mapping):
117+
if self._required_output_keys is None:
118+
raise TypeError("Transformed engine output for {} metric should be a tuple/list, but given {}"
119+
.format(self.__class__.__name__, type(output)))
120+
if not all([k in output for k in self._required_output_keys]):
121+
raise ValueError("When transformed engine's output is a mapping, "
122+
"it should contain {} keys, but given {}".format(self._required_output_keys,
123+
list(output.keys())))
124+
output = tuple(output[k] for k in self._required_output_keys)
113125
self.update(output)
114126

115127
def completed(self, engine, name):

ignite/metrics/precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class Precision(_BasePrecisionRecall):
5757
"""
5858
Calculates precision for binary and multiclass data.
5959
60-
- `update` must receive output of the form `(y_pred, y)`.
60+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
6161
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
6262
- `y` must be in the following shape (batch_size, ...).
6363

ignite/metrics/recall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class Recall(_BasePrecisionRecall):
1111
"""
1212
Calculates recall for binary and multiclass data.
1313
14-
- `update` must receive output of the form `(y_pred, y)`.
14+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1515
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
1616
- `y` must be in the following shape (batch_size, ...).
1717

ignite/metrics/root_mean_squared_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class RootMeanSquaredError(MeanSquaredError):
88
"""
99
Calculates the root mean squared error.
1010
11-
- `update` must receive output of the form (y_pred, y).
11+
- `update` must receive output of the form (y_pred, y) or `{'y_pred': y_pred, 'y': y}`.
1212
"""
1313
def compute(self):
1414
mse = super(RootMeanSquaredError, self).compute()

ignite/metrics/running_average.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def log_running_avg_metrics(engine):
3737
print("running avg loss:", engine.state.metrics['running_avg_loss'])
3838
3939
"""
40+
_required_output_keys = None
4041

4142
def __init__(self, src=None, alpha=0.98, output_transform=None, epoch_bound=True, device=None):
4243
if not (isinstance(src, Metric) or src is None):

ignite/metrics/top_k_categorical_accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class TopKCategoricalAccuracy(Metric):
1111
"""
1212
Calculates the top-k categorical accuracy.
1313
14-
- `update` must receive output of the form `(y_pred, y)`.
14+
- `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
1515
"""
1616
def __init__(self, k=5, output_transform=lambda x: x, device=None):
1717
super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device)

0 commit comments

Comments
 (0)