Skip to content

Commit

Permalink
Fix inference and test with new tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Aug 28, 2024
1 parent 142acc2 commit 8def198
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 21 deletions.
7 changes: 2 additions & 5 deletions src/modes/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
def perform_inference(config: Config,
model: tf.keras.Model,
inference_dataset: tf.data.Dataset,
charlist: List[str],
data_manager: DataManager) -> None:
"""
Performs inference on a given dataset using a specified model and writes
Expand All @@ -33,8 +32,6 @@ def perform_inference(config: Config,
The Keras model to be used for inference.
inference_dataset : tf.data.Dataset
The dataset on which inference is to be performed.
charlist : List[str]
A list of characters used in the model, for decoding predictions.
data_manager : DataManager
A data manager object used for retrieving additional information needed
during inference (e.g., filenames).
Expand All @@ -48,7 +45,7 @@ def perform_inference(config: Config,
results.
"""

tokenizer = Tokenizer(charlist)
tokenizer = data_manager.tokenizer

with open(config["results_file"], "w", encoding="utf-8") as results_file:
for batch_no, batch in enumerate(inference_dataset):
Expand All @@ -62,7 +59,7 @@ def perform_inference(config: Config,
# Print the predictions and process the CER
for index, (confidence, prediction) in enumerate(y_pred):
# Remove the special characters from the prediction
prediction = prediction.strip().replace('', '')
prediction = prediction.strip().replace('[MASK]', '')

# Format the filename
filename = data_manager.get_filename('inference',
Expand Down
21 changes: 5 additions & 16 deletions src/modes/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
prediction_model: tf.keras.Model,
tokenizer: Tokenizer,
config: Config,
wbs: Optional[Any],
data_manager: DataManager,
chars: List[str]) -> Dict[str, int]:
wbs: Optional[Any]) -> Dict[str, int]:
"""
Processes a batch of data by predicting, calculating Character Error Rate
(CER), and handling Word Beam Search (WBS) if enabled.
Expand All @@ -43,11 +41,6 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
wbs : Optional[Any]
An optional Word Beam Search object for advanced decoding, if
applicable.
data_manager : DataManager
A data manager object for additional operations like
normalization.
chars : List[str]
A list of characters used in the model.
Returns
-------
Expand All @@ -68,7 +61,7 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
# Transpose the predictions for WordBeamSearch
if wbs:
predsbeam = tf.transpose(predictions, perm=[1, 0, 2])
char_str = handle_wbs_results(predsbeam, wbs, chars)
char_str = handle_wbs_results(predsbeam, wbs, tokenizer.token_list)
else:
char_str = None

Expand Down Expand Up @@ -113,7 +106,6 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
def perform_test(config: Config,
model: tf.keras.Model,
test_dataset: tf.data.Dataset,
charlist: List[str],
data_manager: DataManager) -> None:
"""
Performs test run on a dataset using a given model and calculates various
Expand All @@ -128,8 +120,6 @@ def perform_test(config: Config,
The Keras model to be validated.
test_dataset : tf.data.Dataset
The dataset to be used for testing.
charlist : List[str]
A list of characters used in the model.
data_manager : DataManager
A data manager object for additional operations like normalization
and Word Beam Search setup.
Expand All @@ -143,11 +133,10 @@ def perform_test(config: Config,
"""

logging.info("Performing test...")

tokenizer = Tokenizer(charlist)
tokenizer = data_manager.tokenizer

# Setup WordBeamSearch if needed
wbs = setup_word_beam_search(config, charlist) \
wbs = setup_word_beam_search(config, tokenizer.token_list) \
if config["corpus_file"] else None

# Initialize the counters
Expand All @@ -160,7 +149,7 @@ def perform_test(config: Config,

logging.info("Batch %s/%s", batch_no + 1, len(test_dataset))
batch_counter = process_batch((X, y_true), model, tokenizer,
config, wbs, data_manager, charlist)
config, wbs)

# Update the total counter
for key, value in batch_counter.items():
Expand Down

0 comments on commit 8def198

Please sign in to comment.