From 8def198f8fbbb09a3d1f3d24c2520b137f4247a1 Mon Sep 17 00:00:00 2001
From: Tim Koornstra
Date: Wed, 28 Aug 2024 12:01:07 +0200
Subject: [PATCH] Fix inference and test with new tokenizer

---
 src/modes/inference.py |  7 ++-----
 src/modes/test.py      | 21 +++++----------------
 2 files changed, 7 insertions(+), 21 deletions(-)

diff --git a/src/modes/inference.py b/src/modes/inference.py
index a182574..69dffd7 100644
--- a/src/modes/inference.py
+++ b/src/modes/inference.py
@@ -17,7 +17,6 @@
 def perform_inference(config: Config,
                       model: tf.keras.Model,
                       inference_dataset: tf.data.Dataset,
-                      charlist: List[str],
                       data_manager: DataManager) -> None:
     """
     Performs inference on a given dataset using a specified model and writes
@@ -33,8 +32,6 @@ def perform_inference(config: Config,
         The Keras model to be used for inference.
     inference_dataset : tf.data.Dataset
         The dataset on which inference is to be performed.
-    charlist : List[str]
-        A list of characters used in the model, for decoding predictions.
     data_manager : DataManager
         A data manager object used for retrieving additional information
         needed during inference (e.g., filenames).
@@ -48,7 +45,7 @@ def perform_inference(config: Config,
         results.
     """

-    tokenizer = Tokenizer(charlist)
+    tokenizer = data_manager.tokenizer

     with open(config["results_file"], "w", encoding="utf-8") as results_file:
         for batch_no, batch in enumerate(inference_dataset):
@@ -62,7 +59,7 @@ def perform_inference(config: Config,
             # Print the predictions and process the CER
             for index, (confidence, prediction) in enumerate(y_pred):
                 # Remove the special characters from the prediction
-                prediction = prediction.strip().replace('', '')
+                prediction = prediction.strip().replace('[MASK]', '')

                 # Format the filename
                 filename = data_manager.get_filename('inference',
diff --git a/src/modes/test.py b/src/modes/test.py
index 0104f0d..23533e9 100644
--- a/src/modes/test.py
+++ b/src/modes/test.py
@@ -21,9 +21,7 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
                   prediction_model: tf.keras.Model,
                   tokenizer: Tokenizer,
                   config: Config,
-                  wbs: Optional[Any],
-                  data_manager: DataManager,
-                  chars: List[str]) -> Dict[str, int]:
+                  wbs: Optional[Any]) -> Dict[str, int]:
     """
     Processes a batch of data by predicting, calculating Character Error
     Rate (CER), and handling Word Beam Search (WBS) if enabled.
@@ -43,11 +41,6 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
     wbs : Optional[Any]
         An optional Word Beam Search object for advanced decoding, if
         applicable.
-    data_manager : DataManager
-        A data data_manager object for additional operations like
-        normalization.
-    chars : List[str]
-        A list of characters used in the model.

     Returns
     -------
@@ -68,7 +61,7 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
     # Transpose the predictions for WordBeamSearch
     if wbs:
         predsbeam = tf.transpose(predictions, perm=[1, 0, 2])
-        char_str = handle_wbs_results(predsbeam, wbs, chars)
+        char_str = handle_wbs_results(predsbeam, wbs, tokenizer.token_list)
     else:
         char_str = None

@@ -113,7 +106,6 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
 def perform_test(config: Config,
                  model: tf.keras.Model,
                  test_dataset: tf.data.Dataset,
-                 charlist: List[str],
                  data_manager: DataManager) -> None:
     """
     Performs test run on a dataset using a given model and calculates various
@@ -128,8 +120,6 @@ def perform_test(config: Config,
         The Keras model to be validated.
     test_dataset : tf.data.Dataset
         The dataset to be used for testing.
-    charlist : List[str]
-        A list of characters used in the model.
     data_manager : DataManager
         A data data_manager object for additional operations like
         normalization and Word Beam Search setup.
@@ -143,11 +133,10 @@ def perform_test(config: Config,
     """

     logging.info("Performing test...")
-
-    tokenizer = Tokenizer(charlist)
+    tokenizer = data_manager.tokenizer

     # Setup WordBeamSearch if needed
-    wbs = setup_word_beam_search(config, charlist) \
+    wbs = setup_word_beam_search(config, tokenizer.token_list) \
         if config["corpus_file"] else None

     # Initialize the counters
@@ -160,7 +149,7 @@ def perform_test(config: Config,
         logging.info("Batch %s/%s", batch_no + 1, len(test_dataset))

         batch_counter = process_batch((X, y_true), model, tokenizer,
-                                      config, wbs, data_manager, charlist)
+                                      config, wbs)

         # Update the total counter
         for key, value in batch_counter.items():
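
Note: the hunks above assume that DataManager now owns the Tokenizer and
that the Tokenizer exposes its vocabulary as token_list, so that
data_manager.tokenizer can replace the old Tokenizer(charlist) call sites
and tokenizer.token_list can replace the removed charlist/chars parameters
passed to the Word Beam Search helpers. A minimal sketch of that assumed
interface follows; the encode/decode helpers are hypothetical and only
illustrate the shape, since the real classes live elsewhere in this repo:

    from typing import List


    class Tokenizer:
        """Maps characters to integer tokens and back."""

        def __init__(self, tokens: List[str]):
            # token_list is the attribute the patch passes to
            # setup_word_beam_search() and handle_wbs_results().
            self.token_list = list(tokens)
            self._token_to_id = {t: i for i, t in enumerate(self.token_list)}

        def encode(self, text: str) -> List[int]:
            # Hypothetical helper: map each known character to its token id.
            return [self._token_to_id[c] for c in text
                    if c in self._token_to_id]

        def decode(self, ids: List[int]) -> str:
            # Hypothetical helper: map token ids back to characters.
            return "".join(self.token_list[i] for i in ids)


    class DataManager:
        """Owns the tokenizer, so callers no longer receive a charlist."""

        def __init__(self, charlist: List[str]):
            # Call sites now read data_manager.tokenizer instead of
            # constructing Tokenizer(charlist) themselves.
            self.tokenizer = Tokenizer(charlist)

Under these assumptions, tokenizer.token_list is a drop-in replacement for
the removed charlist/chars arguments, which is why setup_word_beam_search
and handle_wbs_results can receive it unchanged.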