Remove most args from deprecation zone
use_mask to be replaced with a new tokenizer implementation
TimKoornstra committed Aug 27, 2024
1 parent acaed98 commit febcb76
Showing 8 changed files with 11 additions and 112 deletions.
README.md (2 changes: 1 addition & 1 deletion)

````diff
@@ -362,7 +362,7 @@ cd src/visualize
 
 ```bash
 python3 main.py
-    --existing_model /path/to/existing/model
+    --model /path/to/existing/model
     --sample_image /path/to/sample/img
 ```
````
configs/default.json (11 changes: 0 additions & 11 deletions)

```diff
@@ -19,17 +19,6 @@
         "greedy": false,
         "wbs_smoothing": 0.1
     },
-    "depr": {
-        "channels": 3,
-        "config_file_output": null,
-        "do_inference": false,
-        "do_train": true,
-        "height": 64,
-        "no_auto": false,
-        "output_charlist": null,
-        "thaw": false,
-        "use_mask": true
-    },
     "general": {
         "batch_size": 4,
         "charlist": null,
```
configs/training.json (3 changes: 0 additions & 3 deletions)

```diff
@@ -17,9 +17,6 @@
         "beam_width": 10,
         "greedy": false
     },
-    "depr": {
-        "channels": 1
-    },
     "general": {
         "batch_size": 64,
         "gpu": "0",
```
src/main.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -95,7 +95,7 @@ def main():
         decay_rate=config["decay_rate"],
         decay_steps=config["decay_steps"],
         train_batches=data_manager.get_train_batches(),
-        do_train=config["do_train"],
+        do_train=config["train_list"],
         warmup_ratio=config["warmup_ratio"],
         epochs=config["epochs"],
         decay_per_epoch=config["decay_per_epoch"],
```
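The call site now passes `config["train_list"]` where the boolean `config["do_train"]` used to go, presumably relying on truthiness: a non-empty train list enables training. A minimal sketch of that idea (the config values here are hypothetical, not from the repository):

```python
# Sketch: a non-empty train_list behaves like do_train=True, an empty or
# None value like do_train=False. Assumption: the consumer only performs
# a truthiness check, so no explicit bool() conversion is required.
config = {"train_list": "data/train.txt"}  # hypothetical config value
print(bool(config["train_list"]))  # True: a train list was provided

config["train_list"] = None  # no training data given
print(bool(config["train_list"]))  # False: training stays disabled
```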
src/model/management.py (12 changes: 5 additions & 7 deletions)

```diff
@@ -100,14 +100,12 @@ def customize_model(model: tf.keras.Model,
         model = replace_final_layer(model, len(charlist), model.name,
                                     use_mask=config["use_mask"])
 
-    # Freeze or thaw layers if specified
-    if any([config["thaw"], config["freeze_conv_layers"],
-            config["freeze_recurrent_layers"], config["freeze_dense_layers"]]):
+    # Freeze layers if specified
+    if any([config["freeze_conv_layers"],
+            config["freeze_recurrent_layers"],
+            config["freeze_dense_layers"]]):
         for layer in model.layers:
-            if config["thaw"]:
-                layer.trainable = True
-                logging.info("Thawing layer: %s", layer.name)
-            elif config["freeze_conv_layers"] and \
+            if config["freeze_conv_layers"] and \
                     (layer.name.lower().startswith("conv") or
                      layer.name.lower().startswith("residual")):
                 logging.info("Freezing layer: %s", layer.name)
```
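With `--thaw` gone, this code path can only freeze layers; nothing here sets `trainable` back to True anymore. A self-contained sketch of the surviving logic, distilled from the diff above (the helper name and the toy model are illustrative, not from the repository):

```python
import logging

import tensorflow as tf

logging.basicConfig(level=logging.INFO)


def freeze_conv_layers(model: tf.keras.Model) -> None:
    """Freeze conv/residual layers by name prefix.

    Distilled from the diff above; with --thaw removed, nothing in this
    path ever re-enables training on a layer.
    """
    for layer in model.layers:
        name = layer.name.lower()
        if name.startswith("conv") or name.startswith("residual"):
            logging.info("Freezing layer: %s", layer.name)
            layer.trainable = False


# Hypothetical usage on a toy model whose layer name matches the prefix:
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 1)),
    tf.keras.layers.Conv2D(4, 3, name="conv_demo"),
])
freeze_conv_layers(model)
print([(layer.name, layer.trainable) for layer in model.layers])
```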
src/setup/arg_parser.py (77 changes: 0 additions & 77 deletions)

```diff
@@ -238,96 +238,20 @@ def get_arg_parser():
     # Deprecation zone
     depr_args = parser.add_argument_group(
         'Deprecation zone', 'These arguments will be removed in the future')
-    depr_args.add_argument('--do_train', help='enable the training. '
-                           'Use this flag if you want to train.',
-                           action='store_true')
-    depr_args.add_argument('--do_inference', help='inference',
-                           action='store_true')
     depr_args.add_argument('--use_mask', help='whether or not to mask certain '
                            'parts of the data. Defaults to true when '
                            'batch_size > 1', action='store_true')
-    depr_args.add_argument('--no_auto', action='store_true',
-                           help='No Auto disabled automatic "fixing" of '
-                           'certain parameters')
-    depr_args.add_argument('--height', metavar='height', type=int, default=64,
-                           help='rescale everything to this height before '
-                           'training, default 64')
-    depr_args.add_argument('--channels', metavar='channels', type=int,
-                           default=3, help='number of channels to use. 1 for '
-                           'grey-scale/binary images, three for color images, '
-                           '4 for png\'s with transparency')
-    depr_args.add_argument('--output_charlist', metavar='output_charlist',
-                           type=str, default=None, help="Path to save the "
-                           "character list used during training/inference. "
-                           "If not specified, the charlist is saved to"
-                           "'output/charlist.txt'.")
-    depr_args.add_argument('--config_file_output',
-                           metavar='config_file_output', type=str,
-                           default=None, help="Path to save the "
-                           "configuration file. If not specified, the "
-                           "configuration is set to 'output/config.json'.")
-    depr_args.add_argument('--thaw', action='store_true',
-                           help="Unfreeze convolutional layers in an "
-                           "existing model for further training.")
-    depr_args.add_argument('--existing_model', metavar='existing_model',
-                           type=str, default=None, help="Path to an existing "
-                           "model to continue training, validation, testing, "
-                           "or inferencing. Used as a starting point.")
 
     return parser
 
 
-def fix_args(args):
-    if not args.no_auto and args.train_list:
-        logging.warning('--do_train implied by providing a train_list')
-        args.__dict__['do_train'] = True
-    if not args.no_auto and args.batch_size > 1:
-        logging.warning('--batch_size > 1, setting use_mask=True')
-        args.__dict__['use_mask'] = True
-    if not args.no_auto and args.inference_list:
-        logging.warning('--do_inference implied by providing a inference_list')
-        args.__dict__['do_inference'] = True
-    if not args.no_auto and args.existing_model:
-        args.__dict__['model'] = args.existing_model
-
-
 def arg_future_warning(args):
     logger = logging.getLogger(__name__)
 
     # May 2024
-    if args.do_train:
-        logger.warning("Argument will lose support in May 2024: --do_train. "
-                       "Training will be enabled by providing a train_list. ")
-    if args.do_inference:
-        logger.warning("Argument will lose support in May 2024: "
-                       "--do_inference. Inference will be enabled by "
-                       "providing an inference_list. ")
     if args.use_mask:
         logger.warning("Argument will lose support in May 2024: --use_mask. "
                        "Masking will be enabled by default.")
-    if args.no_auto:
-        logger.warning("Argument will lose support in May 2024: --no_auto.")
-    if args.height:
-        logger.warning("Argument will lose support in May 2024: --height. "
-                       "Height will be inferred from the VGSL spec.")
-    if args.channels:
-        logger.warning("Argument will lose support in May 2024: --channels. "
-                       "Channels will be inferred from the VGSL spec.")
-    if args.output_charlist:
-        logger.warning("Argument will lose support in May 2024: "
-                       "--output_charlist. The charlist will be saved to "
-                       "output/charlist.txt by default.")
-    if args.config_file_output:
-        logger.warning("Argument will lose support in May 2024: "
-                       "--config_file_output. The configuration will be saved "
-                       "to output/config.json by default.")
-    if args.thaw:
-        logging.warning("Argument will lose support in May 2024: --thaw. "
-                        "Models are saved with all layers thawed by default.")
-    if args.existing_model:
-        logger.warning("Argument will lose support in May 2024: "
-                       "--existing_model. The --model argument can be used "
-                       "to load or create a model instead.")
 
 
 def check_required_args(args, explicit):
@@ -377,7 +301,6 @@ def get_args():
 
     # TODO: remove after deprecation period
     arg_future_warning(args)
-    fix_args(args)
     check_required_args(args, explicit)
 
     if args.steps_per_epoch:
```
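With `fix_args()` removed, argparse namespace entries are no longer patched behind the caller's back; the old booleans now follow directly from the primary arguments. A sketch of that contract (the invocation and the derived names are illustrative, not from the repository):

```python
import argparse

# Sketch (assumption): downstream code derives the old flags from the
# primary arguments instead of reading auto-patched namespace entries.
parser = argparse.ArgumentParser()
parser.add_argument("--train_list", type=str, default=None)
parser.add_argument("--inference_list", type=str, default=None)
parser.add_argument("--model", type=str, default=None)

args = parser.parse_args(["--train_list", "train.txt"])  # hypothetical call

do_train = bool(args.train_list)          # replaces the removed --do_train
do_inference = bool(args.inference_list)  # replaces the removed --do_inference
print(do_train, do_inference)             # True False
```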
src/setup/config.py (12 changes: 1 addition & 11 deletions)

```diff
@@ -129,8 +129,7 @@ def save(self, output_file: str = None) -> None:
         """
 
         if not output_file:
-            output_file = self.args.config_file_output or \
-                f"{self.args.output}/config.json"
+            output_file = f"{self.args.output}/config.json"
         try:
             with open(output_file, "w", encoding="utf-8") as file:
                 json.dump(self.config, file, indent=4, sort_keys=True)
@@ -222,16 +221,7 @@ def organize_args(self, args: argparse.Namespace) -> dict:
                 "deterministic": args.deterministic
             },
             "depr": {
-                "do_train": args.do_train,
-                "do_inference": args.do_inference,
                 "use_mask": args.use_mask,
-                "no_auto": args.no_auto,
-                "height": args.height,
-                "channels": args.channels,
-                "output_charlist": args.output_charlist,
-                "config_file_output": args.config_file_output,
-                "thaw": args.thaw,
-                "existing_model": args.existing_model,
             }
         }
```
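With `use_mask` as the lone survivor, the organized config keeps a single "depr" entry until the new tokenizer implementation lands. A minimal sketch of the resulting shape (the namespace value is hypothetical):

```python
import argparse

# Sketch (assumption): use_mask is the only argument left in the "depr"
# section of the organized config after this commit.
args = argparse.Namespace(use_mask=True)  # hypothetical parsed arguments
config = {"depr": {"use_mask": args.use_mask}}
print(config)  # {'depr': {'use_mask': True}}
```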
src/utils/text.py (4 changes: 3 additions & 1 deletion)

```diff
@@ -38,7 +38,9 @@ class Tokenizer:
     Decodes the tokenized sequences back into text.
     """
 
-    def __init__(self, chars: list, use_mask: bool = False,
+    def __init__(self,
+                 chars: list,
+                 use_mask: bool = False,
                  num_oov_indices: int = 1):
         """
         Initializes the Tokenizer with a given character list and mask option.
```
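The constructor change is purely a reflow to one parameter per line; behaviour is unchanged, and `use_mask` stays until the new tokenizer implementation replaces it. A hypothetical instantiation, assuming src/ is on the import path:

```python
from utils.text import Tokenizer  # assumes src/ is on PYTHONPATH

# Hypothetical values; use_mask remains supported until the new tokenizer
# implementation replaces it.
tokenizer = Tokenizer(chars=list("abcdefghij"),
                      use_mask=True,
                      num_oov_indices=1)
```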
