4
4
5
5
import logging
6
6
import pathlib
7
- from contextlib import nullcontext
8
7
9
8
import click
10
9
import cv2
11
10
import keras
12
11
import numpy as np
13
12
14
- import fast_plate_ocr .common .utils
15
13
from fast_plate_ocr .train .model .config import load_config_from_yaml
16
14
from fast_plate_ocr .train .utilities import utils
15
+ from fast_plate_ocr .train .utilities .utils import postprocess_model_output
17
16
18
17
logging .basicConfig (
19
18
level = logging .INFO , format = "%(asctime)s - %(message)s" , datefmt = "%Y-%m-%d %H:%M:%S"
42
41
type = click .Path (exists = True , dir_okay = True , file_okay = False , path_type = pathlib .Path ),
43
42
help = "Directory containing the images to make predictions from." ,
44
43
)
45
- @click .option (
46
- "-t" ,
47
- "--time" ,
48
- default = True ,
49
- is_flag = True ,
50
- help = "Log time taken to run predictions." ,
51
- )
52
44
@click .option (
53
45
"-l" ,
54
46
"--low-conf-thresh" ,
55
47
type = float ,
56
- default = 0.2 ,
48
+ default = 0.35 ,
57
49
show_default = True ,
58
50
help = "Threshold for displaying low confidence characters." ,
59
51
)
52
+ @click .option (
53
+ "-l" ,
54
+ "--filter-conf" ,
55
+ type = float ,
56
+ help = "Display plates that any of the plate characters are below this number." ,
57
+ )
60
58
def visualize_predictions (
61
59
model_path : pathlib .Path ,
62
60
config_file : pathlib .Path ,
63
61
img_dir : pathlib .Path ,
64
62
low_conf_thresh : float ,
65
- time : bool ,
63
+ filter_conf : float | None ,
66
64
):
67
65
"""
68
66
Visualize OCR model predictions on unlabeled data.
@@ -75,20 +73,19 @@ def visualize_predictions(
75
73
img_dir , width = config .img_width , height = config .img_height
76
74
)
77
75
for image in images :
78
- with (
79
- fast_plate_ocr .common .utils .log_time_taken ("Prediction time" ) if time else nullcontext ()
80
- ):
81
- x = np .expand_dims (image , 0 )
82
- prediction = model (x , training = False )
83
- prediction = keras .ops .stop_gradient (prediction ).numpy ()
84
- utils .display_predictions (
85
- image = image ,
76
+ x = np .expand_dims (image , 0 )
77
+ prediction = model (x , training = False )
78
+ prediction = keras .ops .stop_gradient (prediction ).numpy ()
79
+ plate , probs = postprocess_model_output (
86
80
prediction = prediction ,
87
81
alphabet = config .alphabet ,
88
- plate_slots = config .max_plate_slots ,
82
+ max_plate_slots = config .max_plate_slots ,
89
83
vocab_size = config .vocabulary_size ,
90
- low_conf_thresh = low_conf_thresh ,
91
84
)
85
+ if not filter_conf or (filter_conf and np .any (probs < filter_conf )):
86
+ utils .display_predictions (
87
+ image = image , plate = plate , probs = probs , low_conf_thresh = low_conf_thresh
88
+ )
92
89
cv2 .destroyAllWindows ()
93
90
94
91
0 commit comments