Skip to content

Commit

Permalink
[MASK] -> [PAD]
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Aug 28, 2024
1 parent 8def198 commit 64f35a9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
3 changes: 0 additions & 3 deletions src/modes/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ def perform_inference(config: Config,

# Print the predictions and process the CER
for index, (confidence, prediction) in enumerate(y_pred):
# Remove the special characters from the prediction
prediction = prediction.strip().replace('[MASK]', '')

# Format the filename
filename = data_manager.get_filename('inference',
(batch_no *
Expand Down
2 changes: 1 addition & 1 deletion src/utils/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def decode_batch_predictions(pred: np.ndarray, tokenizer: Tokenizer,
decoded_array += 1 # Shift the index by 1 to account for the blank character

# Normalize the confidence score based on the number of timesteps
text = tokenizer.decode(decoded_array).strip().replace("[MASK]", "")
text = tokenizer.decode(decoded_array).strip().replace("[PAD]", "")

# Calculate the effective steps for each sample in the batch
# That is before the first blank character
Expand Down
8 changes: 4 additions & 4 deletions src/utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def _initialize_string_lookup_layers(self):
num_oov_indices=1,
oov_token='[UNK]',
encoding="UTF-8",
mask_token='[MASK]'
mask_token='[PAD]'
)
self.num_to_token = tf.keras.layers.StringLookup(
vocabulary=self.token_to_num.get_vocabulary(),
num_oov_indices=0,
oov_token='',
encoding="UTF-8",
invert=True,
mask_token='[MASK]'
mask_token='[PAD]'
)

self.token_list = self.token_to_num.get_vocabulary()
Expand Down Expand Up @@ -257,11 +257,11 @@ def preprocess_text(text: str) -> str:
Notes
-----
This function performs operations like stripping whitespace, replacing
specific characters (e.g., '[MASK]'), and removing certain control tags using
specific characters (e.g., '[PAD]'), and removing certain control tags using
the `remove_tags` function.
"""

text = text.strip().replace('[MASK]', '')
text = text.strip().replace('[PAD]', '')
text = remove_tags(text)
return text

Expand Down

0 comments on commit 64f35a9

Please sign in to comment.