|
2 | 2 | Script for training the License Plate OCR models.
|
3 | 3 | """
|
4 | 4 |
|
5 |
| -import os |
6 | 5 | import pathlib
|
| 6 | +from datetime import datetime |
7 | 7 |
|
8 | 8 | import click
|
9 | 9 | from keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard
|
10 | 10 | from keras.optimizers import Adam
|
| 11 | +from keras.src.callbacks import ModelCheckpoint |
11 | 12 | from torch.utils.data import DataLoader
|
12 | 13 |
|
13 | 14 | from fast_plate_ocr.augmentation import TRAIN_AUGMENTATION
|
|
67 | 68 | )
|
68 | 69 | @click.option(
|
69 | 70 | "--output-dir",
|
70 |
| - default=None, |
71 |
| - type=str, |
| 71 | + default="./trained-models", |
| 72 | + type=click.Path(dir_okay=True, path_type=pathlib.Path), |
72 | 73 | help="Output directory where model will be saved.",
|
73 | 74 | )
|
74 | 75 | @click.option(
|
@@ -114,7 +115,7 @@ def train(
|
114 | 115 | val_annotations: pathlib.Path,
|
115 | 116 | lr: float,
|
116 | 117 | batch_size: int,
|
117 |
| - output_dir: str, |
| 118 | + output_dir: pathlib.Path, |
118 | 119 | epochs: int,
|
119 | 120 | tensorboard: bool,
|
120 | 121 | tensorboard_dir: str,
|
@@ -169,43 +170,42 @@ def train(
|
169 | 170 | ],
|
170 | 171 | )
|
171 | 172 |
|
| 173 | + output_dir /= datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
| 174 | + output_dir.mkdir(parents=True, exist_ok=True) |
| 175 | + model_file_path = output_dir / (f"{model_type}-" + "{epoch:02d}-{val_plate_acc:.3f}.keras") |
| 176 | + |
172 | 177 | callbacks = [
|
173 | 178 | # Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs
|
174 | 179 | ReduceLROnPlateau(
|
175 | 180 | "val_plate_acc",
|
176 |
| - verbose=1, |
177 | 181 | patience=reduce_lr_patience,
|
178 | 182 | factor=0.5,
|
179 | 183 | min_lr=1e-5,
|
| 184 | + verbose=1, |
180 | 185 | ),
|
181 | 186 | # Stop training when 'val_plate_acc' doesn't improve for X epochs
|
182 | 187 | EarlyStopping(
|
183 | 188 | monitor="val_plate_acc",
|
184 | 189 | patience=early_stopping_patience,
|
185 | 190 | mode="max",
|
186 |
| - restore_best_weights=True, |
| 191 | + restore_best_weights=False, |
| 192 | + verbose=1, |
| 193 | + ), |
| 194 | + # We don't use EarlyStopping's restore_best_weights=True because it only restores the |
| 195 | + # best weights when training actually stops early, not when all epochs run to completion |
| 196 | + ModelCheckpoint( |
| 197 | + model_file_path, |
| 198 | + monitor="val_plate_acc", |
| 199 | + mode="max", |
| 200 | + save_best_only=True, |
| 201 | + verbose=1, |
187 | 202 | ),
|
188 | 203 | ]
|
189 | 204 |
|
190 | 205 | if tensorboard:
|
191 | 206 | callbacks.append(TensorBoard(log_dir=tensorboard_dir))
|
192 | 207 |
|
193 |
| - history = model.fit( |
194 |
| - train_dataloader, epochs=epochs, validation_data=val_dataloader, callbacks=callbacks |
195 |
| - ) |
196 |
| - |
197 |
| - best_vpa = max(history.history["val_plate_acc"]) |
198 |
| - epochs = len(history.epoch) |
199 |
| - model_name = f"cnn-ocr_{best_vpa:.4}-vpa_epochs-{epochs}" |
200 |
| - # Make dir for trained model |
201 |
| - if output_dir is None: |
202 |
| - model_folder = f"./trained/{model_name}" |
203 |
| - if not os.path.exists(model_folder): |
204 |
| - os.makedirs(model_folder) |
205 |
| - output_path = model_folder |
206 |
| - else: |
207 |
| - output_path = output_dir |
208 |
| - model.save(os.path.join(output_path, f"{model_name}.keras")) |
| 208 | + model.fit(train_dataloader, epochs=epochs, validation_data=val_dataloader, callbacks=callbacks) |
209 | 209 |
|
210 | 210 |
|
211 | 211 | if __name__ == "__main__":
|
|
0 commit comments