diff --git a/src/main.py b/src/main.py index a1db3545..ea03a98e 100644 --- a/src/main.py +++ b/src/main.py @@ -205,6 +205,7 @@ def main(): print("creating new model") model_generator = VGSLModelGenerator( model=args.model, + name=args.model_name, channels=model_channels, output_classes=len(char_list) + 2 if args.use_mask else len(char_list) + 1 diff --git a/src/vgsl_model_generator.py b/src/vgsl_model_generator.py index 1f7548f8..8efab367 100644 --- a/src/vgsl_model_generator.py +++ b/src/vgsl_model_generator.py @@ -102,7 +102,6 @@ def __init__(self, self._initializer = initializers.GlorotNormal(seed=42) self._channel_axis = -1 self.model_library = VGSLModelGenerator.get_model_libary() - self.model_name = name if name else model if model is None: raise ValueError("No model provided. Please provide a model name " @@ -115,6 +114,8 @@ def __init__(self, self.init_model_from_string(model_string, channels, output_classes) + self.model_name = name if name else model + except KeyError: raise KeyError("Model not found in model library") else: @@ -123,9 +124,7 @@ def __init__(self, self.init_model_from_string(model, channels, output_classes) - - # TODO: Add model_name argument to arg_parser.py - self.model_name = "custom_model" + self.model_name = name if name else "custom_model" except (TypeError, AttributeError) as e: raise ("Something is wrong with the input string, "