Skip to content

Commit 0b2bb89

Browse files
committed
Add and use more utilities
1 parent 12cde10 commit 0b2bb89

File tree

3 files changed

+87
-30
lines changed

3 files changed

+87
-30
lines changed

fast_plate_ocr/utils.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,29 @@
22
Utility functions module
33
"""
44

5+
import logging
56
import os
7+
import pathlib
8+
import random
9+
import time
10+
from collections.abc import Iterator
11+
from contextlib import contextmanager
612

713
import cv2
14+
import keras
815
import numpy as np
916
import numpy.typing as npt
10-
11-
from fast_plate_ocr.config import MAX_PLATE_SLOTS, MODEL_ALPHABET, PAD_CHAR
17+
from keras.src.activations import softmax
18+
19+
from fast_plate_ocr.config import (
20+
DEFAULT_IMG_HEIGHT,
21+
DEFAULT_IMG_WIDTH,
22+
MAX_PLATE_SLOTS,
23+
MODEL_ALPHABET,
24+
PAD_CHAR,
25+
VOCABULARY_SIZE,
26+
)
27+
from fast_plate_ocr.custom import cat_acc_metric, cce_loss, plate_acc_metric, top_3_k_metric
1228
from fast_plate_ocr.custom_types import Framework
1329

1430

@@ -49,7 +65,9 @@ def set_keras_backend(framework: Framework) -> None:
4965
os.environ["KERAS_BACKEND"] = framework
5066

5167

52-
def read_plate_image(image_path: str, img_height: int, img_width: int) -> npt.NDArray:
68+
def read_plate_image(
69+
image_path: str, img_height: int = DEFAULT_IMG_HEIGHT, img_width: int = DEFAULT_IMG_WIDTH
70+
) -> npt.NDArray:
5371
"""
5472
Read and resize a license plate image.
5573
@@ -62,3 +80,59 @@ def read_plate_image(image_path: str, img_height: int, img_width: int) -> npt.ND
6280
img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_LINEAR)
6381
img = np.expand_dims(img, -1)
6482
return img
83+
84+
85+
def load_keras_model(
    model_path: pathlib.Path,
    vocab_size: int = VOCABULARY_SIZE,
    max_plate_slots: int = MAX_PLATE_SLOTS,
) -> keras.Model:
    """
    Load the keras OCR model, supplying the project's custom loss, metrics and softmax
    activation as ``custom_objects``.

    :param model_path: Path of the saved model, forwarded to ``keras.models.load_model``.
    :param vocab_size: Vocabulary size used to build the custom loss and metrics.
    :param max_plate_slots: Number of plate slots used to build the per-slot metrics.
    :return: The model returned by ``keras.models.load_model``.
    """
    # The two accuracy metrics are built from the same pair of keyword arguments.
    slot_kwargs = {"max_plate_slots": max_plate_slots, "vocabulary_size": vocab_size}
    return keras.models.load_model(
        model_path,
        custom_objects={
            "cce": cce_loss(vocabulary_size=vocab_size),
            "cat_acc": cat_acc_metric(**slot_kwargs),
            "plate_acc": plate_acc_metric(**slot_kwargs),
            "top_3_k": top_3_k_metric(vocabulary_size=vocab_size),
            "softmax": softmax,
        },
    )
102+
103+
104+
IMG_EXTENSIONS: set[str] = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"}
105+
"""Valid image extensions for the scope of this script."""
106+
107+
108+
def load_images_from_folder(
    img_dir: pathlib.Path,
    width: int = DEFAULT_IMG_WIDTH,
    height: int = DEFAULT_IMG_HEIGHT,
    shuffle: bool = False,
    limit: int | None = None,
) -> list[npt.NDArray]:
    """
    Return all images read from a directory. This uses the same read function used during training.

    Only regular files directly inside ``img_dir`` whose extension is listed in
    ``IMG_EXTENSIONS`` are considered; the extension check is case-insensitive, so
    ``plate.JPG`` is picked up as well as ``plate.jpg``.

    :param img_dir: Directory to scan (not recursive).
    :param width: Width passed to ``read_plate_image`` as ``img_width``.
    :param height: Height passed to ``read_plate_image`` as ``img_height``.
    :param shuffle: If true, shuffle the selected paths before reading them.
    :param limit: Maximum number of images to read; ``None`` means no limit and ``0``
        yields an empty list.
    :return: The images, one array per selected file, as returned by ``read_plate_image``.
    """
    image_paths = sorted(
        str(f.resolve())
        for f in img_dir.iterdir()
        # Lower-case the suffix so upper-case extensions are not silently skipped.
        if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS
    )
    # Compare against None explicitly: a plain truthiness test would treat limit=0
    # as "no limit" and return every image instead of none.
    if limit is not None:
        image_paths = image_paths[:limit]
    # The limit is applied to the sorted listing first, so shuffling only reorders
    # the selected subset (same as the behaviour this helper replaced).
    if shuffle:
        random.shuffle(image_paths)
    return [read_plate_image(path, img_height=height, img_width=width) for path in image_paths]
127+
128+
129+
@contextmanager
def log_time_taken(process_name: str) -> Iterator[None]:
    """
    Time the body of a ``with`` block and log how long it took, in milliseconds.

    The message is logged at INFO level through the root ``logging`` module functions,
    and it is emitted even when the wrapped block raises.

    :param process_name: Label inserted into the log message to identify the timed code.
    """
    started_at = time.perf_counter()
    try:
        yield
    finally:
        # perf_counter() reports seconds; scale the difference to milliseconds.
        elapsed_ms = 1000 * (time.perf_counter() - started_at)
        logging.info("Computation time of '%s' = %.3fms", process_name, elapsed_ms)

valid.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@
55
import pathlib
66

77
import click
8-
import keras
9-
from keras.src.activations import softmax
108
from torch.utils.data import DataLoader
119

10+
from fast_plate_ocr import utils
1211
from fast_plate_ocr.config import MAX_PLATE_SLOTS, MODEL_ALPHABET, PAD_CHAR, VOCABULARY_SIZE
1312

1413
# Custom metrics / losses
15-
from fast_plate_ocr.custom import cat_acc_metric, cce_loss, plate_acc_metric, top_3_k_metric
1614
from fast_plate_ocr.dataset import LicensePlateDataset
1715

1816

@@ -79,14 +77,7 @@ def valid(
7977
pad_char: str,
8078
) -> None:
8179
"""Validate a model for a given annotated data."""
82-
custom_objects = {
83-
"cce": cce_loss(vocabulary_size=vocab_size),
84-
"cat_acc": cat_acc_metric(max_plate_slots=plate_slots, vocabulary_size=vocab_size),
85-
"plate_acc": plate_acc_metric(max_plate_slots=plate_slots, vocabulary_size=vocab_size),
86-
"top_3_k": top_3_k_metric(vocabulary_size=vocab_size),
87-
"softmax": softmax,
88-
}
89-
model = keras.models.load_model(model_path, custom_objects=custom_objects)
80+
model = utils.load_keras_model(model_path, vocab_size=vocab_size, max_plate_slots=plate_slots)
9081
val_torch_dataset = LicensePlateDataset(
9182
annotations_file=annotations,
9283
max_plate_slots=plate_slots,

visualize_augmentation.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111
import numpy as np
1212
import numpy.typing as npt
1313

14+
from fast_plate_ocr import utils
1415
from fast_plate_ocr.augmentation import TRAIN_AUGMENTATION
1516
from fast_plate_ocr.config import DEFAULT_IMG_HEIGHT, DEFAULT_IMG_WIDTH
16-
from fast_plate_ocr.utils import read_plate_image
17-
18-
IMG_EXTENSIONS: set[str] = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"}
19-
"""Valid image extensions for the scope of this script."""
2017

2118

2219
def _set_seed(seed: int | None) -> None:
@@ -29,18 +26,13 @@ def _set_seed(seed: int | None) -> None:
2926
def load_images(
3027
img_dir: pathlib.Path,
3128
num_images: int,
32-
shuffle_img: bool,
29+
shuffle: bool,
3330
height: int,
3431
width: int,
3532
) -> tuple[list[npt.NDArray[np.uint8]], list[npt.NDArray[np.uint8]]]:
36-
img_paths = sorted(f for f in img_dir.iterdir() if f.is_file() and f.suffix in IMG_EXTENSIONS)
37-
img_paths = img_paths[:num_images]
38-
if shuffle_img:
39-
random.shuffle(img_paths)
40-
images = [
41-
read_plate_image(image_path=str(img), img_height=height, img_width=width)
42-
for img in img_paths
43-
]
33+
images = utils.load_images_from_folder(
34+
img_dir, height=height, width=width, shuffle=shuffle, limit=num_images
35+
)
4436
augmented_images = [TRAIN_AUGMENTATION(image=i)["image"] for i in images]
4537
return images, augmented_images
4638

@@ -94,7 +86,7 @@ def display_images(
9486
help="Maximum number of images to visualize.",
9587
)
9688
@click.option(
97-
"--shuffle_img",
89+
"--shuffle",
9890
"-s",
9991
is_flag=True,
10092
default=False,
@@ -146,7 +138,7 @@ def display_images(
146138
def visualize_augmentation(
147139
img_dir: pathlib.Path,
148140
num_images: int,
149-
shuffle_img: bool,
141+
shuffle: bool,
150142
columns: int,
151143
rows: int,
152144
height: int,
@@ -155,7 +147,7 @@ def visualize_augmentation(
155147
show_original: bool,
156148
) -> None:
157149
_set_seed(seed)
158-
images, augmented_images = load_images(img_dir, num_images, shuffle_img, height, width)
150+
images, augmented_images = load_images(img_dir, num_images, shuffle, height, width)
159151
display_images(images, augmented_images, columns, rows, show_original)
160152

161153

0 commit comments

Comments
 (0)