Commit 2a17c7e

Add ModelCheckpoint callback

1 parent b1f9e90

1 file changed: +22 −22 lines

train.py

Lines changed: 22 additions & 22 deletions
@@ -2,12 +2,13 @@
 Script for training the License Plate OCR models.
 """
 
-import os
 import pathlib
+from datetime import datetime
 
 import click
 from keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard
 from keras.optimizers import Adam
+from keras.src.callbacks import ModelCheckpoint
 from torch.utils.data import DataLoader
 
 from fast_plate_ocr.augmentation import TRAIN_AUGMENTATION
@@ -67,8 +68,8 @@
 )
 @click.option(
     "--output-dir",
-    default=None,
-    type=str,
+    default="./trained-models",
+    type=click.Path(dir_okay=True, path_type=pathlib.Path),
     help="Output directory where model will be saved.",
 )
 @click.option(
@@ -114,7 +115,7 @@ def train(
     val_annotations: pathlib.Path,
     lr: float,
     batch_size: int,
-    output_dir: str,
+    output_dir: pathlib.Path,
     epochs: int,
     tensorboard: bool,
     tensorboard_dir: str,
@@ -169,43 +170,42 @@ def train(
         ],
     )
 
+    output_dir /= datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    output_dir.mkdir(parents=True, exist_ok=True)
+    model_file_path = output_dir / (f"{model_type}-" + "{epoch:02d}-{val_plate_acc:.3f}.keras")
+
     callbacks = [
         # Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs
         ReduceLROnPlateau(
            "val_plate_acc",
-            verbose=1,
             patience=reduce_lr_patience,
             factor=0.5,
             min_lr=1e-5,
+            verbose=1,
        ),
        # Stop training when 'val_plate_acc' doesn't improve for X epochs
        EarlyStopping(
            monitor="val_plate_acc",
            patience=early_stopping_patience,
            mode="max",
-            restore_best_weights=True,
+            restore_best_weights=False,
+            verbose=1,
+        ),
+        # We don't use EarlyStopping restore_best_weights=True because it won't restore the best
+        # weights when it didn't manage to EarlyStop but finished all epochs
+        ModelCheckpoint(
+            model_file_path,
+            monitor="val_plate_acc",
+            mode="max",
+            save_best_only=True,
+            verbose=1,
        ),
    ]
 
    if tensorboard:
        callbacks.append(TensorBoard(log_dir=tensorboard_dir))
 
-    history = model.fit(
-        train_dataloader, epochs=epochs, validation_data=val_dataloader, callbacks=callbacks
-    )
-
-    best_vpa = max(history.history["val_plate_acc"])
-    epochs = len(history.epoch)
-    model_name = f"cnn-ocr_{best_vpa:.4}-vpa_epochs-{epochs}"
-    # Make dir for trained model
-    if output_dir is None:
-        model_folder = f"./trained/{model_name}"
-        if not os.path.exists(model_folder):
-            os.makedirs(model_folder)
-        output_path = model_folder
-    else:
-        output_path = output_dir
-    model.save(os.path.join(output_path, f"{model_name}.keras"))
+    model.fit(train_dataloader, epochs=epochs, validation_data=val_dataloader, callbacks=callbacks)
 
 
 if __name__ == "__main__":
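
For reference, here is a minimal, self-contained sketch of the checkpointing pattern this commit adopts: a ModelCheckpoint whose filepath template interpolates the epoch and the monitored metric, paired with EarlyStopping(restore_best_weights=False), so the best weights always end up on disk even when training runs through all epochs without stopping early. The toy data, model, directory name, and the standard val_accuracy metric below are illustrative assumptions only; the repo's script monitors its custom val_plate_acc metric and feeds PyTorch dataloaders instead.

# Illustrative sketch (not from the commit): same ModelCheckpoint pattern on a toy model.
import pathlib
from datetime import datetime

import numpy as np
from keras import layers, models
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Timestamped output directory, mirroring the pathlib-based approach in the diff.
output_dir = pathlib.Path("./trained-models") / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir.mkdir(parents=True, exist_ok=True)
# Keras fills in {epoch} and the monitored metric when writing each checkpoint file.
checkpoint_path = output_dir / "toy-{epoch:02d}-{val_accuracy:.3f}.keras"

model = models.Sequential([layers.Input(shape=(16,)), layers.Dense(2, activation="softmax")])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Random stand-in data; the real script uses torch DataLoaders over plate images.
x = np.random.rand(256, 16).astype("float32")
y = np.random.randint(0, 2, size=(256,))

callbacks = [
    # restore_best_weights=False: the best weights live on disk via ModelCheckpoint,
    # which also covers the case where training finishes all epochs without early stopping.
    EarlyStopping(monitor="val_accuracy", patience=5, mode="max", restore_best_weights=False),
    ModelCheckpoint(str(checkpoint_path), monitor="val_accuracy", mode="max", save_best_only=True, verbose=1),
]

model.fit(x, y, validation_split=0.2, epochs=10, callbacks=callbacks)

Two side notes: the sketch imports ModelCheckpoint from keras.callbacks, the public path, whereas the diff pulls it from keras.src.callbacks, Keras' internal package; and after training, the highest-scoring checkpoint in the output directory can be reloaded with keras.models.load_model().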
