From 2d9b51e861c0dff2962f5dd21086afdaeecce538 Mon Sep 17 00:00:00 2001
From: ankandrew <61120139+ankandrew@users.noreply.github.com>
Date: Sat, 30 Mar 2024 23:44:15 -0300
Subject: [PATCH] Revert "Use AdamW and add weight decay CLI arg"

This reverts commit 98a60b55359733ac325d99c9275f321380da8bc2.
---
 train.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/train.py b/train.py
index d53729b..b4c4748 100644
--- a/train.py
+++ b/train.py
@@ -7,7 +7,7 @@
 
 import click
 from keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard
-from keras.optimizers import AdamW
+from keras.optimizers import Adam
 from keras.src.callbacks import ModelCheckpoint
 from torch.utils.data import DataLoader
 
@@ -107,13 +107,6 @@
     type=int,
     help="Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs.",
 )
-@click.option(
-    "--weight-decay",
-    default=0.01,
-    show_default=True,
-    type=float,
-    help="Weight decay value.",
-)
 def train(
     model_type: str,
     dense: bool,
@@ -128,7 +121,6 @@ def train(
     tensorboard_dir: str,
     early_stopping_patience: int,
     reduce_lr_patience: int,
-    weight_decay: float,
 ) -> None:
     config = load_config_from_yaml(config_file)
     train_torch_dataset = LicensePlateDataset(
@@ -166,7 +158,7 @@ def train(
     )
     model.compile(
         loss=cce_loss(vocabulary_size=config.vocabulary_size),
-        optimizer=AdamW(lr, weight_decay=weight_decay),
+        optimizer=Adam(lr),
         metrics=[
             cat_acc_metric(
                 max_plate_slots=config.max_plate_slots, vocabulary_size=config.vocabulary_size
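
A minimal sketch of the optimizer swap this revert performs, assuming Keras 3's public optimizer API; `lr` is a placeholder and the 0.01 value mirrors the default of the click option removed above.

# Hypothetical sketch, not part of the patch; assumes keras.optimizers.Adam
# and keras.optimizers.AdamW from Keras 3.
from keras.optimizers import Adam, AdamW

lr = 1e-3  # placeholder; the real value comes from the training CLI

# After the revert: plain Adam, no decoupled weight decay.
optimizer = Adam(lr)

# What the reverted commit used: AdamW driven by a --weight-decay CLI arg
# (the removed click option defaulted to 0.01).
old_optimizer = AdamW(lr, weight_decay=0.01)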