Revert "Use AdamW and add weight decay CLI arg"
This reverts commit 98a60b5.
ankandrew committed Mar 31, 2024
1 parent f8452b3 · commit 2d9b51e
Showing 1 changed file with 2 additions and 10 deletions.
train.py: 12 changes (2 additions & 10 deletions)
@@ -7,7 +7,7 @@

 import click
 from keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard
-from keras.optimizers import AdamW
+from keras.optimizers import Adam
 from keras.src.callbacks import ModelCheckpoint
 from torch.utils.data import DataLoader
@@ -107,13 +107,6 @@
     type=int,
     help="Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs.",
 )
-@click.option(
-    "--weight-decay",
-    default=0.01,
-    show_default=True,
-    type=float,
-    help="Weight decay value.",
-)
 def train(
     model_type: str,
     dense: bool,
@@ -128,7 +121,6 @@ def train(
     tensorboard_dir: str,
     early_stopping_patience: int,
     reduce_lr_patience: int,
-    weight_decay: float,
 ) -> None:
     config = load_config_from_yaml(config_file)
     train_torch_dataset = LicensePlateDataset(
@@ -166,7 +158,7 @@
     )
     model.compile(
         loss=cce_loss(vocabulary_size=config.vocabulary_size),
-        optimizer=AdamW(lr, weight_decay=weight_decay),
+        optimizer=Adam(lr),
         metrics=[
             cat_acc_metric(
                 max_plate_slots=config.max_plate_slots, vocabulary_size=config.vocabulary_size
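For reference, the deleted @click.option block wired a --weight-decay flag into the train command. A minimal standalone sketch of the same click pattern, trimmed to just that flag (the bare train stub is illustrative, not the project's real command):

import click

@click.command()
@click.option(
    "--weight-decay",
    default=0.01,
    show_default=True,
    type=float,
    help="Weight decay value.",
)
def train(weight_decay: float) -> None:
    # click maps the --weight-decay flag to the weight_decay parameter.
    click.echo(f"weight decay: {weight_decay}")

if __name__ == "__main__":
    train()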
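The functional change is the optimizer swap: AdamW adds decoupled weight decay on top of Adam, so reverting to Adam(lr) drops the weight-decay term entirely. A minimal sketch of the two setups, assuming Keras 3's public optimizer API and a tiny stand-in model (the real project builds its OCR model from a config):

import keras
from keras import layers

# Illustrative stand-in model, not the project's architecture.
model = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(10)])
lr = 1e-3

# Before the revert: Adam with decoupled weight decay (AdamW).
optimizer = keras.optimizers.AdamW(learning_rate=lr, weight_decay=0.01)

# After the revert: plain Adam, no weight-decay term.
optimizer = keras.optimizers.Adam(learning_rate=lr)

model.compile(loss="mse", optimizer=optimizer)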