2
2
Utility functions module
3
3
"""
4
4
5
+ import logging
5
6
import os
7
+ import pathlib
8
+ import random
9
+ import time
10
+ from collections .abc import Iterator
11
+ from contextlib import contextmanager
6
12
7
13
import cv2
14
+ import keras
8
15
import numpy as np
9
16
import numpy .typing as npt
10
-
11
- from fast_plate_ocr .config import MAX_PLATE_SLOTS , MODEL_ALPHABET , PAD_CHAR
17
+ from keras .src .activations import softmax
18
+
19
+ from fast_plate_ocr .config import (
20
+ DEFAULT_IMG_HEIGHT ,
21
+ DEFAULT_IMG_WIDTH ,
22
+ MAX_PLATE_SLOTS ,
23
+ MODEL_ALPHABET ,
24
+ PAD_CHAR ,
25
+ VOCABULARY_SIZE ,
26
+ )
27
+ from fast_plate_ocr .custom import cat_acc_metric , cce_loss , plate_acc_metric , top_3_k_metric
12
28
from fast_plate_ocr .custom_types import Framework
13
29
14
30
@@ -49,7 +65,9 @@ def set_keras_backend(framework: Framework) -> None:
49
65
os .environ ["KERAS_BACKEND" ] = framework
50
66
51
67
52
- def read_plate_image (image_path : str , img_height : int , img_width : int ) -> npt .NDArray :
68
+ def read_plate_image (
69
+ image_path : str , img_height : int = DEFAULT_IMG_HEIGHT , img_width : int = DEFAULT_IMG_WIDTH
70
+ ) -> npt .NDArray :
53
71
"""
54
72
Read and resize a license plate image.
55
73
@@ -62,3 +80,59 @@ def read_plate_image(image_path: str, img_height: int, img_width: int) -> npt.ND
62
80
img = cv2 .resize (img , (img_width , img_height ), interpolation = cv2 .INTER_LINEAR )
63
81
img = np .expand_dims (img , - 1 )
64
82
return img
83
+
84
+
85
+ def load_keras_model (
86
+ model_path : pathlib .Path ,
87
+ vocab_size : int = VOCABULARY_SIZE ,
88
+ max_plate_slots : int = MAX_PLATE_SLOTS ,
89
+ ) -> keras .Model :
90
+ """
91
+ Utility helper function to load the keras OCR model.
92
+ """
93
+ custom_objects = {
94
+ "cce" : cce_loss (vocabulary_size = vocab_size ),
95
+ "cat_acc" : cat_acc_metric (max_plate_slots = max_plate_slots , vocabulary_size = vocab_size ),
96
+ "plate_acc" : plate_acc_metric (max_plate_slots = max_plate_slots , vocabulary_size = vocab_size ),
97
+ "top_3_k" : top_3_k_metric (vocabulary_size = vocab_size ),
98
+ "softmax" : softmax ,
99
+ }
100
+ model = keras .models .load_model (model_path , custom_objects = custom_objects )
101
+ return model
102
+
103
+
104
+ IMG_EXTENSIONS : set [str ] = {".jpg" , ".jpeg" , ".png" , ".bmp" , ".gif" , ".tiff" , ".webp" }
105
+ """Valid image extensions for the scope of this script."""
106
+
107
+
108
+ def load_images_from_folder (
109
+ img_dir : pathlib .Path ,
110
+ width : int = DEFAULT_IMG_WIDTH ,
111
+ height : int = DEFAULT_IMG_HEIGHT ,
112
+ shuffle : bool = False ,
113
+ limit : int | None = None ,
114
+ ) -> list [npt .NDArray ]:
115
+ """
116
+ Return all images read from a directory. This uses the same read function used during training.
117
+ """
118
+ image_paths = sorted (
119
+ str (f .resolve ()) for f in img_dir .iterdir () if f .is_file () and f .suffix in IMG_EXTENSIONS
120
+ )
121
+ if limit :
122
+ image_paths = image_paths [:limit ]
123
+ if shuffle :
124
+ random .shuffle (image_paths )
125
+ images = [read_plate_image (i , img_height = height , img_width = width ) for i in image_paths ]
126
+ return images
127
+
128
+
129
+ @contextmanager
130
+ def log_time_taken (process_name : str ) -> Iterator [None ]:
131
+ """A concise context manager to time code snippets and log the result."""
132
+ time_start : float = time .perf_counter ()
133
+ try :
134
+ yield
135
+ finally :
136
+ time_end : float = time .perf_counter ()
137
+ time_elapsed : float = time_end - time_start
138
+ logging .info ("Computation time of '%s' = %.3fms" , process_name , 1000 * time_elapsed )
0 commit comments