From 64f35a98208e3716fedae3d246a0a34e5ddb5393 Mon Sep 17 00:00:00 2001 From: Tim Koornstra Date: Wed, 28 Aug 2024 12:15:50 +0200 Subject: [PATCH] [MASK] -> [PAD] --- src/modes/inference.py | 3 --- src/utils/decoding.py | 2 +- src/utils/text.py | 8 ++++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/modes/inference.py b/src/modes/inference.py index 69dffd7..5c7f2fc 100644 --- a/src/modes/inference.py +++ b/src/modes/inference.py @@ -58,9 +58,6 @@ def perform_inference(config: Config, # Print the predictions and process the CER for index, (confidence, prediction) in enumerate(y_pred): - # Remove the special characters from the prediction - prediction = prediction.strip().replace('[MASK]', '') - # Format the filename filename = data_manager.get_filename('inference', (batch_no * diff --git a/src/utils/decoding.py b/src/utils/decoding.py index 1d88f07..ab3fa3a 100644 --- a/src/utils/decoding.py +++ b/src/utils/decoding.py @@ -98,7 +98,7 @@ def decode_batch_predictions(pred: np.ndarray, tokenizer: Tokenizer, decoded_array += 1 # Shift the index by 1 to account for the blank character # Normalize the confidence score based on the number of timesteps - text = tokenizer.decode(decoded_array).strip().replace("[MASK]", "") + text = tokenizer.decode(decoded_array).strip().replace("[PAD]", "") # Calculate the effective steps for each sample in the batch # That is before the first blank character diff --git a/src/utils/text.py b/src/utils/text.py index d0d8845..27a9772 100644 --- a/src/utils/text.py +++ b/src/utils/text.py @@ -43,7 +43,7 @@ def _initialize_string_lookup_layers(self): num_oov_indices=1, oov_token='[UNK]', encoding="UTF-8", - mask_token='[MASK]' + mask_token='[PAD]' ) self.num_to_token = tf.keras.layers.StringLookup( vocabulary=self.token_to_num.get_vocabulary(), @@ -51,7 +51,7 @@ def _initialize_string_lookup_layers(self): oov_token='', encoding="UTF-8", invert=True, - mask_token='[MASK]' + mask_token='[PAD]' ) self.token_list = self.token_to_num.get_vocabulary() @@ -257,11 +257,11 @@ def preprocess_text(text: str) -> str: Notes ----- This function performs operations like stripping whitespace, replacing - specific characters (e.g., '[MASK]'), and removing certain control tags using + specific characters (e.g., '[PAD]'), and removing certain control tags using the `remove_tags` function. """ - text = text.strip().replace('[MASK]', '') + text = text.strip().replace('[PAD]', '') text = remove_tags(text) return text