|
1 | 1 | import sys
|
2 |
| -from ignite.metrics import Metric, Precision, Recall |
| 2 | +from ignite.metrics import Metric, Precision, Recall, ConfusionMatrix |
3 | 3 | from ignite.engine import Engine, State
|
4 | 4 | import torch
|
5 | 5 | from mock import MagicMock
|
6 | 6 |
|
7 | 7 | from pytest import approx, raises
|
8 | 8 | import numpy as np
|
9 |
| -from sklearn.metrics import precision_score, recall_score, f1_score |
| 9 | +from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix |
10 | 10 |
|
11 | 11 |
|
12 | 12 | def test_no_transform():
|
@@ -388,3 +388,60 @@ def compute_f1(y_pred, y):
|
388 | 388 | return f1
|
389 | 389 |
|
390 | 390 | _test(f1, "f1", compute_true_value_fn=compute_f1)
|
| 391 | + |
| 392 | + |
| 393 | +def test_indexing_metric(): |
| 394 | + def _test(ignite_metric, sklearn_metic, sklearn_args, index, num_classes=5): |
| 395 | + y_pred = torch.rand(15, 10, num_classes).float() |
| 396 | + y = torch.randint(0, num_classes, size=(15, 10)).long() |
| 397 | + |
| 398 | + def update_fn(engine, batch): |
| 399 | + y_pred, y = batch |
| 400 | + return y_pred, y |
| 401 | + |
| 402 | + metrics = {'metric': ignite_metric[index], |
| 403 | + 'metric_wo_index': ignite_metric} |
| 404 | + |
| 405 | + validator = Engine(update_fn) |
| 406 | + |
| 407 | + for name, metric in metrics.items(): |
| 408 | + metric.attach(validator, name) |
| 409 | + |
| 410 | + def data(y_pred, y): |
| 411 | + for i in range(y_pred.shape[0]): |
| 412 | + yield (y_pred[i], y[i]) |
| 413 | + |
| 414 | + d = data(y_pred, y) |
| 415 | + state = validator.run(d, max_epochs=1) |
| 416 | + |
| 417 | + sklearn_output = sklearn_metic(y.view(-1).numpy(), |
| 418 | + y_pred.view(-1, num_classes).argmax(dim=1).numpy(), |
| 419 | + **sklearn_args) |
| 420 | + |
| 421 | + assert (state.metrics['metric_wo_index'][index] == state.metrics['metric']).all() |
| 422 | + assert (np.allclose(state.metrics['metric'].numpy(), sklearn_output)) |
| 423 | + |
| 424 | + num_classes = 5 |
| 425 | + |
| 426 | + labels = list(range(0, num_classes, 2)) |
| 427 | + _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels) |
| 428 | + labels = list(range(num_classes - 1, 0, -2)) |
| 429 | + _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels) |
| 430 | + labels = [1] |
| 431 | + _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels) |
| 432 | + |
| 433 | + labels = list(range(0, num_classes, 2)) |
| 434 | + _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels) |
| 435 | + labels = list(range(num_classes - 1, 0, -2)) |
| 436 | + _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels) |
| 437 | + labels = [1] |
| 438 | + _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels) |
| 439 | + |
| 440 | + # np.ix_ is used to allow for a 2D slice of a matrix. This is required to get accurate result from |
| 441 | + # ConfusionMatrix. ConfusionMatrix must be sliced the same row-wise and column-wise. |
| 442 | + labels = list(range(0, num_classes, 2)) |
| 443 | + _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels)) |
| 444 | + labels = list(range(num_classes - 1, 0, -2)) |
| 445 | + _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels)) |
| 446 | + labels = [1] |
| 447 | + _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels)) |
0 commit comments