|
13 | 13 | import pytest
|
14 | 14 | import torch
|
15 | 15 |
|
16 |
| -from fast_plate_ocr.custom import cat_acc_metric |
| 16 | +from fast_plate_ocr.custom import cat_acc_metric, plate_acc_metric |
17 | 17 |
|
18 | 18 |
|
19 | 19 | @pytest.mark.parametrize(
|
|
22 | 22 | (torch.tensor([[[1, 0]] * 6]), torch.tensor([[[0.9, 0.1]] * 6]), 1.0),
|
23 | 23 | ],
|
24 | 24 | )
|
25 |
| -def test_cat_acc(y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float): |
| 25 | +def test_cat_acc(y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float) -> None: |
26 | 26 | actual_accuracy = cat_acc_metric(2, 1)(y_true, y_pred)
|
27 | 27 | assert actual_accuracy == expected_accuracy
|
| 28 | + |
| 29 | + |
| 30 | +@pytest.mark.parametrize( |
| 31 | + "y_true, y_pred, expected_accuracy", |
| 32 | + [ |
| 33 | + ( |
| 34 | + torch.tensor( |
| 35 | + [ |
| 36 | + [ |
| 37 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 38 | + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 39 | + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 40 | + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 41 | + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 42 | + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 43 | + ], |
| 44 | + [ |
| 45 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 46 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 47 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 48 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 49 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 50 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 51 | + ], |
| 52 | + ] |
| 53 | + ), |
| 54 | + torch.tensor( |
| 55 | + [ |
| 56 | + [ |
| 57 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 58 | + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 59 | + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 60 | + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 61 | + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 62 | + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 63 | + ], |
| 64 | + [ |
| 65 | + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 66 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 67 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 68 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 69 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 70 | + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 71 | + ], |
| 72 | + ] |
| 73 | + ), |
| 74 | + # First batch slice plate was recognized completely correct but second one wasn't |
| 75 | + # So 50% of plates were recognized correctly |
| 76 | + 0.5, |
| 77 | + ), |
| 78 | + ], |
| 79 | +) |
| 80 | +def test_plate_accuracy( |
| 81 | + y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float |
| 82 | +) -> None: |
| 83 | + actual_accuracy = plate_acc_metric(y_true.shape[1], y_true.shape[2])(y_true, y_pred).item() |
| 84 | + assert actual_accuracy == expected_accuracy |
0 commit comments