Skip to content

Commit

Permalink
--replace_final_layer with new tokenizer
Browse files (browse the repository at this point in the history)
  • Loading branch information
TimKoornstra committed Aug 28, 2024
1 parent 8169d03 commit 142acc2
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/data/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def initialize_data_manager(config: Config,
config : Config
A Config containing various arguments to configure the data manager
(e.g., batch size, image size, lists for training, validation, etc.).
- charlist : List[str]
- A list of characters to be used by the data manager.
+ tokenizer : Tokenizer
+ The tokenizer to be used for tokenizing text.
model : tf.keras.Model
The Keras model, used to derive input dimensions for the data manager.
augment_model : tf.keras.Sequential
Expand Down
4 changes: 2 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main():
)

# Load the tokenizer if a valid path was found
if json_path:
if json_path and not config["replace_final_layer"]:
tokenizer = Tokenizer.load_from_file(json_path)
else:
tokenizer = None # Indicate that a new tokenizer will be created later
Expand Down Expand Up @@ -91,7 +91,7 @@ def main():
# layers, or adjusting for float32
model = customize_model(model, config, tokenizer)

# Save the charlist
# Save the tokenizer
tokenizer.save_to_json(os.path.join(config["output"],
"tokenizer.json"))

Expand Down
5 changes: 1 addition & 4 deletions src/model/replacing.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,7 @@ def replace_final_layer(model: tf.keras.models.Model,
inputs=model.inputs, outputs=model.get_layer(last_layer).output
)

# Account for the mask and OOV tokens
units = number_characters + 2

x = layers.Dense(units, activation="softmax", name="dense_out",
x = layers.Dense(number_characters, activation="softmax", name="dense_out",
kernel_initializer=initializer)(prediction_model.output)

# Add a linear activation layer with float32 data type
Expand Down
8 changes: 4 additions & 4 deletions src/setup/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def get_arg_parser():
general_args.add_argument('--seed', metavar='seed', type=int, default=42,
help="Seed for random number generators to "
"ensure reproducibility. Default: 42.")
general_args.add_argument('--charlist', metavar='charlist ', type=str,
default=None, help="Path to a file containing "
"the list of characters to be recognized. "
"Required for inference and validation.")
general_args.add_argument('--tokenizer', metavar='tokenizer', type=str,
default=None, help="Path to a tokenizer file for "
"the model. Required for inference and "
"validation.")
general_args.add_argument('--test_list', metavar='test_list',
type=str, default=None, help="File(s) with "
"textline locations and transcriptions for "
Expand Down
2 changes: 1 addition & 1 deletion src/setup/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def organize_args(self, args: argparse.Namespace) -> dict:
"config_file": args.config_file,
"batch_size": args.batch_size,
"seed": args.seed,
"charlist": args.charlist
"tokenizer": args.tokenizer
},
"training": {
"epochs": args.epochs,
Expand Down

0 comments on commit 142acc2

Please sign in to comment.