Skip to content

Commit 1ba3af6

Browse files
committed
Add LR scheduler and Early stopping CLI args
1 parent 6f73cae commit 1ba3af6

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

train.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,20 @@
134134
type=str,
135135
help="Padding char for plates with length less than '--plate-slots'.",
136136
)
137+
@click.option(
138+
"--early-stopping-patience",
139+
default=120,
140+
show_default=True,
141+
type=int,
142+
help="Stop training when 'val_plate_acc' doesn't improve for X epochs.",
143+
)
144+
@click.option(
145+
"--reduce-lr-patience",
146+
default=100,
147+
show_default=True,
148+
type=int,
149+
help="Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs.",
150+
)
137151
def train(
138152
model_type: str,
139153
dense: bool,
@@ -151,6 +165,8 @@ def train(
151165
alphabet: str,
152166
vocab_size: int,
153167
pad_char: str,
168+
early_stopping_patience: int,
169+
reduce_lr_patience: int,
154170
) -> None:
155171
train_torch_dataset = LicensePlateDataset(
156172
annotations_file=annotations,
@@ -201,9 +217,20 @@ def train(
201217

202218
callbacks = [
203219
# Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs
204-
ReduceLROnPlateau("val_plate_acc", verbose=1, patience=35, factor=0.5, min_lr=1e-5),
220+
ReduceLROnPlateau(
221+
"val_plate_acc",
222+
verbose=1,
223+
patience=reduce_lr_patience,
224+
factor=0.5,
225+
min_lr=1e-5,
226+
),
205227
# Stop training when 'val_plate_acc' doesn't improve for X epochs
206-
EarlyStopping(monitor="val_plate_acc", patience=50, mode="max", restore_best_weights=True),
228+
EarlyStopping(
229+
monitor="val_plate_acc",
230+
patience=early_stopping_patience,
231+
mode="max",
232+
restore_best_weights=True,
233+
),
207234
]
208235

209236
if tensorboard:

0 commit comments

Comments
 (0)