Skip to content

Commit

Permalink
Save all model layers as trainable by default
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Mar 15, 2024
1 parent 85a12d3 commit 63043ee
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/model/custom_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,16 @@ def _save_model(self, subdir: str):
os.makedirs(outputdir, exist_ok=True)
model_path = os.path.join(outputdir, "model.keras")

# Save model
self.model.save(model_path)
# Create a copy of the model
unfrozen_model = tf.keras.models.clone_model(self.model)
unfrozen_model.set_weights(self.model.get_weights())

# Unfreeze all layers in the copied model
for layer in unfrozen_model.layers:
layer.trainable = True

# Save the unfrozen model
unfrozen_model.save(model_path)

# Save additional files
if self.charlist:
Expand Down
3 changes: 3 additions & 0 deletions src/setup/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ def arg_future_warning(args):
logger.warning("Argument will lose support in May 2024: "
"--config_file_output. The configuration will be saved "
"to output/config.json by default.")
if args.thaw:
logging.warning("Argument will lose support in May 2024: --thaw. "
"Models are saved with all layers thawed by default.")


def get_args():
Expand Down

0 comments on commit 63043ee

Please sign in to comment.