diff --git a/src/data/data_handling.py b/src/data/data_handling.py
index 0f835ec..d7f9797 100644
--- a/src/data/data_handling.py
+++ b/src/data/data_handling.py
@@ -22,8 +22,8 @@ def initialize_data_manager(config: Config,
     config : Config
        A Config containing various arguments to configure the data manager
        (e.g., batch size, image size, lists for training, validation, etc.).
-    charlist : List[str]
-        A list of characters to be used by the data manager.
+    tokenizer : Tokenizer
+        The tokenizer to be used for tokenizing text.
     model : tf.keras.Model
        The Keras model, used to derive input dimensions for the data manager.
     augment_model : tf.keras.Sequential
diff --git a/src/main.py b/src/main.py
index be01444..aa25171 100644
--- a/src/main.py
+++ b/src/main.py
@@ -60,7 +60,7 @@ def main():
     )
 
     # Load the tokenizer if a valid path was found
-    if json_path:
+    if json_path and not config["replace_final_layer"]:
         tokenizer = Tokenizer.load_from_file(json_path)
     else:
         tokenizer = None  # Indicate that a new tokenizer will be created later
@@ -91,7 +91,7 @@ def main():
     # layers, or adjusting for float32
     model = customize_model(model, config, tokenizer)
 
-    # Save the charlist
+    # Save the tokenizer
     tokenizer.save_to_json(os.path.join(config["output"], "tokenizer.json"))
diff --git a/src/model/replacing.py b/src/model/replacing.py
index 23a5f36..06dd1a6 100644
--- a/src/model/replacing.py
+++ b/src/model/replacing.py
@@ -128,10 +128,7 @@ def replace_final_layer(model: tf.keras.models.Model,
         inputs=model.inputs, outputs=model.get_layer(last_layer).output
     )
 
-    # Account for the mask and OOV tokens
-    units = number_characters + 2
-
-    x = layers.Dense(units, activation="softmax", name="dense_out",
+    x = layers.Dense(number_characters, activation="softmax", name="dense_out",
                      kernel_initializer=initializer)(prediction_model.output)
 
     # Add a linear activation layer with float32 data type
diff --git a/src/setup/arg_parser.py b/src/setup/arg_parser.py
index 72c5739..76f2062 100644
--- a/src/setup/arg_parser.py
+++ b/src/setup/arg_parser.py
@@ -33,10 +33,10 @@ def get_arg_parser():
     general_args.add_argument('--seed', metavar='seed', type=int, default=42,
                               help="Seed for random number generators to "
                                    "ensure reproducibility. Default: 42.")
-    general_args.add_argument('--charlist', metavar='charlist ', type=str,
-                              default=None, help="Path to a file containing "
-                                   "the list of characters to be recognized. "
-                                   "Required for inference and validation.")
+    general_args.add_argument('--tokenizer', metavar='tokenizer', type=str,
+                              default=None, help="Path to a tokenizer file for "
+                                   "the model. Required for inference and "
+                                   "validation.")
     general_args.add_argument('--test_list', metavar='test_list', type=str,
                               default=None, help="File(s) with "
                                    "textline locations and transcriptions for "
diff --git a/src/setup/config.py b/src/setup/config.py
index 00941cd..1f39d88 100644
--- a/src/setup/config.py
+++ b/src/setup/config.py
@@ -159,7 +159,7 @@ def organize_args(self, args: argparse.Namespace) -> dict:
             "config_file": args.config_file,
             "batch_size": args.batch_size,
             "seed": args.seed,
-            "charlist": args.charlist
+            "tokenizer": args.tokenizer
         },
         "training": {
             "epochs": args.epochs,
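
A note on the main.py change, since the new condition is doing real work: when --replace_final_layer is requested, reusing a saved tokenizer would pin the new output head to the old vocabulary, so the load is skipped and a fresh tokenizer is built later. A minimal sketch of that branch with the reasoning as comments; load_or_defer_tokenizer and the tokenizer_cls indirection are hypothetical scaffolding so the sketch runs outside the repo, while json_path, config["replace_final_layer"], and Tokenizer.load_from_file come straight from the diff:

from typing import Optional

def load_or_defer_tokenizer(json_path: Optional[str], config: dict,
                            tokenizer_cls) -> Optional[object]:
    # Keep the saved vocabulary only when the output head is kept too:
    # the existing dense_out layer was trained against these token ids.
    if json_path and not config["replace_final_layer"]:
        return tokenizer_cls.load_from_file(json_path)
    # A replaced final layer implies a new vocabulary, so a fresh
    # tokenizer is created later and written out as tokenizer.json.
    return None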
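
The replacing.py hunk is the behavioural core of the change: the old head reserved two extra softmax units (number_characters + 2) for the mask and OOV tokens, whereas the Tokenizer is now assumed to count those special tokens in its own vocabulary size. A minimal, self-contained sketch of the resized head; the backbone and the vocabulary size of 80 are placeholders, and only Dense(number_characters, ...) with the dense_out name is taken from the diff:

import tensorflow as tf
from tensorflow.keras import layers

number_characters = 80  # placeholder, e.g. the tokenizer's vocabulary size

# Stand-in backbone producing per-timestep features.
inputs = tf.keras.Input(shape=(None, 64))
features = layers.Dense(128, activation="relu")(inputs)

# One softmax unit per token id; no manual "+ 2" for mask/OOV tokens,
# since those ids are assumed to live inside the vocabulary already.
outputs = layers.Dense(number_characters, activation="softmax",
                       name="dense_out")(features)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()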
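
Since --charlist disappears, anything scripting against the CLI needs the rename. A tiny standalone repro of the new option as declared in the arg_parser.py hunk (a bare ArgumentParser stands in for the project's get_arg_parser()):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tokenizer', metavar='tokenizer', type=str,
                    default=None, help="Path to a tokenizer file for "
                                       "the model. Required for inference "
                                       "and validation.")

# Replaces an old `--charlist chars.txt` style invocation:
args = parser.parse_args(['--tokenizer', 'output/tokenizer.json'])
print(args.tokenizer)  # output/tokenizer.json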