diff --git a/ml/synthesis/src/components/data_processing/dataset_creation.py b/ml/synthesis/src/components/data_processing/dataset_creation.py index 96f9bfd13..6e017fcaf 100644 --- a/ml/synthesis/src/components/data_processing/dataset_creation.py +++ b/ml/synthesis/src/components/data_processing/dataset_creation.py @@ -45,7 +45,7 @@ def create_datasets(main_df: pd.DataFrame, val_df: pd.DataFrame | None = None) - input_columns = train_df_inputs.columns.tolist() logger.info(f"Input columns: {input_columns}") - train_ds = _df_to_dataset(train_df_inputs, train_df_targets, batch_size=128, repeat=True) + train_ds = _df_to_dataset(train_df_inputs, train_df_targets, batch_size=64, repeat=True) val_ds = _df_to_dataset(val_df_inputs, val_df_targets) del train_df_inputs, val_df_inputs diff --git a/ml/synthesis/src/components/model_generation/models.py b/ml/synthesis/src/components/model_generation/models.py index 8a75cda6a..8b8aa76e0 100644 --- a/ml/synthesis/src/components/model_generation/models.py +++ b/ml/synthesis/src/components/model_generation/models.py @@ -5,16 +5,17 @@ def create_baseline_model(input_shape) -> Model: model = Sequential( [ layers.InputLayer(input_shape=input_shape), + layers.Dense(256, activation="relu", kernel_regularizer="l2"), + layers.Dropout(0.7), layers.Dense(64, activation="relu", kernel_regularizer="l2"), layers.Dropout(0.5), - layers.Dense(16, activation="relu"), - layers.Dense(16, activation="relu"), + layers.Dense(32, activation="relu"), layers.Dense(1), ], ) model.compile( - optimizer=optimizers.Adam(learning_rate=1e-4), + optimizer=optimizers.Adam(learning_rate=8e-5), loss="mse", metrics=["mae"], ) diff --git a/ml/synthesis/src/components/model_generation/training.py b/ml/synthesis/src/components/model_generation/training.py index 58e53d333..d19efadca 100644 --- a/ml/synthesis/src/components/model_generation/training.py +++ b/ml/synthesis/src/components/model_generation/training.py @@ -31,8 +31,8 @@ def train_and_save_baseline_model( model = create_baseline_model(input_shape=sample.shape) effective_fitting_kwargs = dict( - epochs=45, - steps_per_epoch=3000, + epochs=30, + steps_per_epoch=1500, ) if fitting_kwargs: effective_fitting_kwargs.update(fitting_kwargs)