diff --git a/src/data/loader.py b/src/data/loader.py index 1490eaeb..28841746 100644 --- a/src/data/loader.py +++ b/src/data/loader.py @@ -152,7 +152,7 @@ def _create_data(self, # Extract the filename and ground truth from the fields file_name = fields[0] - # Skip missing files unless explicitly included + # Skip missing files if not os.path.exists(file_name): logging.warning(f"Missing: {file_name} in {file_path}." f" Skipping for {partition_name}...") @@ -162,10 +162,14 @@ def _create_data(self, # ground truth unless it's an inference partition ground_truth = fields[-1] if len(fields) > 1 else "" - if partition_name != "inference" and not ground_truth: - logging.warning(f"Empty ground truth in {line}. " - f"Skipping for {partition_name}...") - continue + if not ground_truth: + if partition_name == "inference": + ground_truth = "INFERENCE" + else: + logging.warning(f"Empty ground truth in {line}. " + f"Skipping for {partition_name}..." + ) + continue # Normalize the ground truth if a normalization file is # provided and the partition is either 'train' or @@ -180,9 +184,10 @@ def _create_data(self, # Check for unsupported characters in the ground truth # Evaluation partition is allowed to have unsupported # characters for a more realistic evaluation - if any(char not in characters for char in ground_truth) \ - and partition_name != "validation": - if partition_name == 'train' and not self.charlist: + if any(char not in characters for char in ground_truth): + if partition_name in ['evaluation', 'inference']: + pass + elif partition_name == 'train' and not self.charlist: characters.update(set(ground_truth)) else: logging.warning("Unsupported character in %s. " diff --git a/src/modes/test.py b/src/modes/test.py index 65bc7b1c..c7b46405 100644 --- a/src/modes/test.py +++ b/src/modes/test.py @@ -157,9 +157,11 @@ def perform_test(config: Config, n_items = 0 for batch_no, batch in enumerate(test_dataset): - logging.info("Batch %s/%s", batch_no + 1, len(test_dataset)) + # Unpack the batch and ignore the third element (sample weights) + X, y_true, _ = batch - batch_counter = process_batch(batch, prediction_model, tokenizer, + logging.info("Batch %s/%s", batch_no + 1, len(test_dataset)) + batch_counter = process_batch((X, y_true), prediction_model, tokenizer, config, wbs, dataloader, charlist) # Update the total counter