Skip to content

Commit 76df538

Browse files
committed
Fix bug in plate accuracy metric and add test
1 parent 0b2bb89 commit 76df538

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

fast_plate_ocr/custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def plate_acc(y_true, y_pred):
4444
"""
4545
y_true = ops.reshape(y_true, newshape=(-1, max_plate_slots, vocabulary_size))
4646
y_pred = ops.reshape(y_pred, newshape=(-1, max_plate_slots, vocabulary_size))
47-
et = ops.equal(ops.argmax(y_true), ops.argmax(y_pred))
47+
et = ops.equal(ops.argmax(y_true, axis=-1), ops.argmax(y_pred, axis=-1))
4848
return ops.mean(ops.cast(ops.all(et, axis=-1, keepdims=False), dtype="float32"))
4949

5050
return plate_acc

test/fast_lp_ocr/test_custom.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pytest
1414
import torch
1515

16-
from fast_plate_ocr.custom import cat_acc_metric
16+
from fast_plate_ocr.custom import cat_acc_metric, plate_acc_metric
1717

1818

1919
@pytest.mark.parametrize(
@@ -22,6 +22,63 @@
2222
(torch.tensor([[[1, 0]] * 6]), torch.tensor([[[0.9, 0.1]] * 6]), 1.0),
2323
],
2424
)
25-
def test_cat_acc(y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float):
25+
def test_cat_acc(y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float) -> None:
2626
actual_accuracy = cat_acc_metric(2, 1)(y_true, y_pred)
2727
assert actual_accuracy == expected_accuracy
28+
29+
30+
@pytest.mark.parametrize(
31+
"y_true, y_pred, expected_accuracy",
32+
[
33+
(
34+
torch.tensor(
35+
[
36+
[
37+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
38+
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
39+
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
40+
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
41+
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
42+
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
43+
],
44+
[
45+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
46+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
47+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
48+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
49+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
50+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
51+
],
52+
]
53+
),
54+
torch.tensor(
55+
[
56+
[
57+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
58+
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
59+
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
60+
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
61+
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
62+
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
63+
],
64+
[
65+
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
66+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
67+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
68+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
69+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
70+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
71+
],
72+
]
73+
),
74+
# First batch slice plate was recognized completely correct but second one wasn't
75+
# So 50% of plates were recognized correctly
76+
0.5,
77+
),
78+
],
79+
)
80+
def test_plate_accuracy(
81+
y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float
82+
) -> None:
83+
actual_accuracy = plate_acc_metric(y_true.shape[1], y_true.shape[2])(y_true, y_pred).item()
84+
assert actual_accuracy == expected_accuracy

0 commit comments

Comments
 (0)