Skip to content

Commit

Permalink
Fix test and inference modes
Browse files (browse the repository at this point in the history)
  • Loading branch information
TimKoornstra committed Mar 7, 2024
1 parent abc6e74 commit f7bea35
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
21 changes: 13 additions & 8 deletions src/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _create_data(self,
# Extract the filename and ground truth from the fields
file_name = fields[0]

- # Skip missing files unless explicitly included
+ # Skip missing files
if not os.path.exists(file_name):
logging.warning(f"Missing: {file_name} in {file_path}."
f" Skipping for {partition_name}...")
Expand All @@ -162,10 +162,14 @@ def _create_data(self,
# ground truth unless it's an inference partition
ground_truth = fields[-1] if len(fields) > 1 else ""

- if partition_name != "inference" and not ground_truth:
- logging.warning(f"Empty ground truth in {line}. "
- f"Skipping for {partition_name}...")
- continue
+ if not ground_truth:
+ if partition_name == "inference":
+ ground_truth = "INFERENCE"
+ else:
+ logging.warning(f"Empty ground truth in {line}. "
+ f"Skipping for {partition_name}..."
+ )
+ continue

# Normalize the ground truth if a normalization file is
# provided and the partition is either 'train' or
Expand All @@ -180,9 +184,10 @@ def _create_data(self,
# Check for unsupported characters in the ground truth
# Evaluation partition is allowed to have unsupported
# characters for a more realistic evaluation
- if any(char not in characters for char in ground_truth) \
- and partition_name != "validation":
- if partition_name == 'train' and not self.charlist:
+ if any(char not in characters for char in ground_truth):
+ if partition_name in ['evaluation', 'inference']:
+ pass
+ elif partition_name == 'train' and not self.charlist:
characters.update(set(ground_truth))
else:
logging.warning("Unsupported character in %s. "
Expand Down
6 changes: 4 additions & 2 deletions src/modes/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ def perform_test(config: Config,
n_items = 0

for batch_no, batch in enumerate(test_dataset):
- logging.info("Batch %s/%s", batch_no + 1, len(test_dataset))
+ # Unpack the batch and ignore the third element (sample weights)
+ X, y_true, _ = batch

- batch_counter = process_batch(batch, prediction_model, tokenizer,
+ logging.info("Batch %s/%s", batch_no + 1, len(test_dataset))
+ batch_counter = process_batch((X, y_true), prediction_model, tokenizer,
config, wbs, dataloader, charlist)

# Update the total counter
Expand Down

0 comments on commit f7bea35

Please sign in to comment.