Skip to content

Commit 3207fd5

Browse files
committed
Add option to visualize only predictions with low char prob
1 parent 2d194e2 commit 3207fd5

File tree

2 files changed

+20
-31
lines changed

2 files changed

+20
-31
lines changed

fast_plate_ocr/cli/visualize_predictions.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44

55
import logging
66
import pathlib
7-
from contextlib import nullcontext
87

98
import click
109
import cv2
1110
import keras
1211
import numpy as np
1312

14-
import fast_plate_ocr.common.utils
1513
from fast_plate_ocr.train.model.config import load_config_from_yaml
1614
from fast_plate_ocr.train.utilities import utils
15+
from fast_plate_ocr.train.utilities.utils import postprocess_model_output
1716

1817
logging.basicConfig(
1918
level=logging.INFO, format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
@@ -42,27 +41,26 @@
4241
type=click.Path(exists=True, dir_okay=True, file_okay=False, path_type=pathlib.Path),
4342
help="Directory containing the images to make predictions from.",
4443
)
45-
@click.option(
46-
"-t",
47-
"--time",
48-
default=True,
49-
is_flag=True,
50-
help="Log time taken to run predictions.",
51-
)
5244
@click.option(
5345
"-l",
5446
"--low-conf-thresh",
5547
type=float,
56-
default=0.2,
48+
default=0.35,
5749
show_default=True,
5850
help="Threshold for displaying low confidence characters.",
5951
)
52+
@click.option(
53+
"-l",
54+
"--filter-conf",
55+
type=float,
56+
help="Display plates that any of the plate characters are below this number.",
57+
)
6058
def visualize_predictions(
6159
model_path: pathlib.Path,
6260
config_file: pathlib.Path,
6361
img_dir: pathlib.Path,
6462
low_conf_thresh: float,
65-
time: bool,
63+
filter_conf: float | None,
6664
):
6765
"""
6866
Visualize OCR model predictions on unlabeled data.
@@ -75,20 +73,19 @@ def visualize_predictions(
7573
img_dir, width=config.img_width, height=config.img_height
7674
)
7775
for image in images:
78-
with (
79-
fast_plate_ocr.common.utils.log_time_taken("Prediction time") if time else nullcontext()
80-
):
81-
x = np.expand_dims(image, 0)
82-
prediction = model(x, training=False)
83-
prediction = keras.ops.stop_gradient(prediction).numpy()
84-
utils.display_predictions(
85-
image=image,
76+
x = np.expand_dims(image, 0)
77+
prediction = model(x, training=False)
78+
prediction = keras.ops.stop_gradient(prediction).numpy()
79+
plate, probs = postprocess_model_output(
8680
prediction=prediction,
8781
alphabet=config.alphabet,
88-
plate_slots=config.max_plate_slots,
82+
max_plate_slots=config.max_plate_slots,
8983
vocab_size=config.vocabulary_size,
90-
low_conf_thresh=low_conf_thresh,
9184
)
85+
if not filter_conf or (filter_conf and np.any(probs < filter_conf)):
86+
utils.display_predictions(
87+
image=image, plate=plate, probs=probs, low_conf_thresh=low_conf_thresh
88+
)
9289
cv2.destroyAllWindows()
9390

9491

fast_plate_ocr/train/utilities/utils.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,21 +117,13 @@ def low_confidence_positions(probs, thresh=0.3) -> npt.NDArray:
117117

118118
def display_predictions(
119119
image: npt.NDArray,
120-
prediction: npt.NDArray,
121-
alphabet: str,
122-
plate_slots: int,
123-
vocab_size: int,
120+
plate: str,
121+
probs: npt.NDArray,
124122
low_conf_thresh: float,
125123
) -> None:
126124
"""
127125
Display plate and corresponding prediction.
128126
"""
129-
plate, probs = postprocess_model_output(
130-
prediction=prediction,
131-
alphabet=alphabet,
132-
max_plate_slots=plate_slots,
133-
vocab_size=vocab_size,
134-
)
135127
plate_str = "".join(plate)
136128
logging.info("Plate: %s", plate_str)
137129
logging.info("Confidence: %s", probs)

0 commit comments

Comments
 (0)