Deprecate --existing_model, make model library more lenient
TimKoornstra committed Mar 15, 2024
1 parent f19e107 · commit 151861a
Showing 10 changed files with 26 additions and 20 deletions.
configs/default.json (1 change: 0 additions & 1 deletion)
@@ -56,7 +56,6 @@
         "normalization_file": null
     },
     "model": {
-        "existing_model": null,
         "freeze_conv_layers": false,
         "freeze_dense_layers": false,
         "freeze_recurrent_layers": false,
configs/finetuning.json (2 changes: 1 addition & 1 deletion)
@@ -36,10 +36,10 @@
         "normalization_file": "/path/to/normalization.json"
     },
     "model": {
-        "existing_model": "/path/to/model/",
         "freeze_conv_layers": true,
         "freeze_dense_layers": false,
         "freeze_recurrent_layers": false,
+        "model": "/path/to/model/",
         "model_name": "My-Finetuned-Loghi-Model",
         "replace_final_layer": true,
         "replace_recurrent_layer": null,
configs/inference.json (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@
         "deterministic": false
     },
     "model": {
-        "existing_model": "/path/to/model/",
+        "model": "/path/to/model/",
         "use_float32": false
     }
 }
configs/testing.json (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@
         "normalization_file": "/path/to/normalization.json"
     },
     "model": {
-        "existing_model": "/path/to/model/",
+        "model": "/path/to/model/",
         "use_float32": false
     },
     "training": {
configs/validation.json (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@
         "normalization_file": "/path/to/normalization.json"
     },
     "model": {
-        "existing_model": "/path/to/model/",
+        "model": "/path/to/model/",
         "use_float32": false
     },
     "training": {
src/main.py (4 changes: 2 additions & 2 deletions)
@@ -49,9 +49,9 @@ def main():
     os.makedirs(config["output"], exist_ok=True)
 
     # Get the initial character list
-    if config["existing_model"] or config["charlist"]:
+    if os.path.isdir(config["model"]) or config["charlist"]:
         charlist, removed_padding = load_initial_charlist(
-            config["charlist"], config["existing_model"],
+            config["charlist"], config["model"],
             config["output"], config["replace_final_layer"])
     else:
         charlist = []
src/model/management.py (9 changes: 5 additions & 4 deletions)
@@ -86,7 +86,7 @@ def customize_model(model: tf.keras.Model,
                                       use_mask=config["use_mask"])
 
     # Replace the final layer if specified
-    if config["replace_final_layer"] or not config["existing_model"]:
+    if config["replace_final_layer"] or not os.path.isdir(config["model"]):
         new_classes = len(charlist) + \
             2 if config["use_mask"] else len(charlist) + 1
         logging.info("Replacing final layer with %s classes", new_classes)
@@ -191,9 +191,10 @@ def load_or_create_model(config: Config,
         The loaded or newly created Keras model.
     """
 
-    if config["existing_model"]:
-        model = load_model_from_directory(
-            config["existing_model"], custom_objects=custom_objects)
+    # Check if config["model"] is a directory
+    if os.path.isdir(config["model"]):
+        model = load_model_from_directory(config["model"],
+                                          custom_objects=custom_objects)
         if config["model_name"]:
             model._name = config["model_name"]
     else:
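The net effect of the management.py change is that config["model"] now does double duty: a value that points at a saved-model directory is loaded, and anything else falls through to model creation. The following is a minimal standalone sketch of that dispatch, with an invented helper name and illustrative inputs, not the repository's actual function:

import os

def resolve_model_source(model_arg: str) -> str:
    # Mirrors the new check in load_or_create_model: a value pointing at an
    # existing directory is treated as a saved model to load; anything else
    # is handed to the VGSL model generator (library name or spec string).
    if os.path.isdir(model_arg):
        return "load saved model from directory"
    return "create model from VGSL spec or library name"

print(resolve_model_source("/path/to/model/"))              # "load ..." only if the directory exists
print(resolve_model_source("None,64,None,1 Cr3,3,24 ..."))  # always "create ..."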
src/model/vgsl_model_generator.py (2 changes: 1 addition & 1 deletion)
@@ -108,7 +108,7 @@ def __init__(self,
             raise ValueError("No model provided. Please provide a model name "
                              "from the model library or a VGSL-spec string.")
 
-        if model_spec.startswith("model"):
+        if model_spec in self.model_library.keys():
             try:
                 logging.info("Pulling model from model library")
                 model_string = self.model_library[model_spec]
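The loosened check above matches a model-library entry by exact key lookup instead of by a name that merely starts with "model". A small self-contained illustration of the difference; the library contents and names below are invented for the example, and only the membership test mirrors the new code:

model_library = {
    "model10": "None,64,None,1 Cr3,3,24 ...",      # placeholder VGSL strings
    "recommended": "None,64,None,1 Cr3,3,24 ...",  # name without the "model" prefix
}

def resolve(model_spec: str) -> str:
    # New behaviour: only keys actually present in the library are pulled
    # from it, and they no longer need to start with "model".
    if model_spec in model_library:
        return "library model: " + model_spec
    # Old behaviour (model_spec.startswith("model")) would have sent any
    # string beginning with "model" down the library path, found or not.
    return "custom VGSL spec: " + model_spec

print(resolve("recommended"))              # found in the library under the new check
print(resolve("None,64,None,1 Cr3,3,24"))  # treated as a VGSL-spec string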
src/setup/arg_parser.py (20 changes: 13 additions & 7 deletions)
@@ -137,10 +137,6 @@ def get_arg_parser():
     model_args.add_argument('--use_float32', action='store_true',
                             help="Use 32-bit float precision in the model. "
                             "Can improve performance at the cost of memory.")
-    model_args.add_argument('--existing_model', metavar='existing_model',
-                            type=str, default=None, help="Path to an existing "
-                            "model to continue training, validation, testing, "
-                            "or inferencing. Used as a starting point.")
     model_args.add_argument('--model_name', metavar='model_name', type=str,
                             default=None, help="Custom name for the model, "
                             "used in outputs. Default: None (uses the model "
@@ -273,20 +269,26 @@ def get_arg_parser():
     depr_args.add_argument('--thaw', action='store_true',
                            help="Unfreeze convolutional layers in an "
                            "existing model for further training.")
+    depr_args.add_argument('--existing_model', metavar='existing_model',
+                           type=str, default=None, help="Path to an existing "
+                           "model to continue training, validation, testing, "
+                           "or inferencing. Used as a starting point.")
 
     return parser
 
 
 def fix_args(args):
     if not args.no_auto and args.train_list:
-        logging.warning('do_train implied by providing a train_list')
+        logging.warning('--do_train implied by providing a train_list')
         args.__dict__['do_train'] = True
     if not args.no_auto and args.batch_size > 1:
-        logging.warning('batch_size > 1, setting use_mask=True')
+        logging.warning('--batch_size > 1, setting use_mask=True')
         args.__dict__['use_mask'] = True
     if not args.no_auto and args.inference_list:
-        logging.warning('do_inference implied by providing a inference_list')
+        logging.warning('--do_inference implied by providing a inference_list')
         args.__dict__['do_inference'] = True
+    if not args.no_auto and args.existing_model:
+        args.__dict__['model'] = args.existing_model
 
 
 def arg_future_warning(args):
@@ -322,6 +324,10 @@ def arg_future_warning(args):
     if args.thaw:
         logging.warning("Argument will lose support in May 2024: --thaw. "
                         "Models are saved with all layers thawed by default.")
+    if args.existing_model:
+        logging.warning("Argument will lose support in May 2024: "
+                        "--existing_model. The --model argument can be used "
+                        "to load or create a model instead.")
 
 
 def get_args():
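Taken together, the arg_parser.py changes keep --existing_model parseable but route its value into --model and emit a deprecation warning. A trimmed-down, runnable sketch of that compatibility path; the parser here is reduced to the three flags involved and is not the project's full argument parser:

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default=None)
parser.add_argument('--existing_model', type=str, default=None)
parser.add_argument('--no_auto', action='store_true')

args = parser.parse_args(['--existing_model', '/path/to/model/'])

# Deprecation notice, as in arg_future_warning
if args.existing_model:
    logging.warning("Argument will lose support in May 2024: --existing_model. "
                    "The --model argument can be used instead.")

# Backwards-compatible remapping, as in fix_args
if not args.no_auto and args.existing_model:
    args.__dict__['model'] = args.existing_model

print(args.model)  # /path/to/model/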
src/setup/config.py (2 changes: 1 addition & 1 deletion)
@@ -191,7 +191,6 @@ def organize_args(self, args: argparse.Namespace) -> dict:
             "model": {
                 "model": args.model,
                 "use_float32": args.use_float32,
-                "existing_model": args.existing_model,
                 "model_name": args.model_name,
                 "replace_final_layer": args.replace_final_layer,
                 "replace_recurrent_layer": args.replace_recurrent_layer,
@@ -232,6 +231,7 @@ def organize_args(self, args: argparse.Namespace) -> dict:
                 "output_charlist": args.output_charlist,
                 "config_file_output": args.config_file_output,
                 "thaw": args.thaw,
+                "existing_model": args.existing_model,
             }
         }
