Fix incorrect length train batches bug + some extra logging
TimKoornstra committed Mar 13, 2024
1 parent 0079214 commit e2baf48
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/data/manager.py
@@ -64,16 +64,19 @@ def __init__(self,
 
         # Process the raw data and create file names, labels, sample weights,
         # and tokenizer
+        logging.info("Processing raw data...")
         file_names, labels, sample_weights, self.tokenizer \
             = self._process_raw_data()
+
         self.raw_data = {split: (file_names[split], labels[split],
                                  sample_weights[split])
                          for split in ['train', 'evaluation', 'validation',
                                        'test', 'inference']}
 
         # Fill the datasets dictionary with datasets for different partitions
-        self.datasets = self._fill_datasets_dict(
-            file_names, labels, sample_weights)
+        logging.info("Creating datasets...")
+        self.datasets = self._fill_datasets_dict(file_names, labels,
+                                                 sample_weights)
 
     def _process_raw_data(self) -> Tuple[Dict[str, List[str]],
                                          Dict[str, List[str]],
@@ -362,7 +365,7 @@ def _is_valid_ground_truth(self,
         if unsupported_characters:
             # Unsupported characters are allowed in the validation, inference,
             # and test partitions, but not in the evaluation partition
-            if partition_name in ['validation', 'inference', 'test']:
+            if partition_name in ('validation', 'inference', 'test'):
                 return True
 
         if partition_name == 'train' and not self.charlist:
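
Note on the list-to-tuple change above: membership tests give identical results for both containers, so this is a stylistic tweak (CPython can also constant-fold a tuple literal in an `in` test, a minor optimization a list literal does not get). A minimal sketch of the equivalence:

    # Equivalent membership tests; only the container literal differs.
    partition_name = 'inference'
    in_list = partition_name in ['validation', 'inference', 'test']
    in_tuple = partition_name in ('validation', 'inference', 'test')
    assert in_list == in_tuple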
@@ -406,7 +409,7 @@ def get_ground_truth(self, partition: str, item_id: int):
 
     def get_train_batches(self):
         """ Get the number of batches for training """
-        return int(np.ceil(len(self.raw_data['train'])
+        return int(np.ceil(len(self.raw_data['train'][0])
                            / self.config['batch_size']))
 
     def _create_dataset(self,
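
This hunk is the bug named in the commit title. As the `__init__` hunk above shows, `self.raw_data['train']` is a 3-tuple `(file_names, labels, sample_weights)`, so `len()` on it always returned 3 and the batch count was wrong for any real dataset; indexing `[0]` measures the file list instead. A minimal sketch with made-up sizes:

    import numpy as np

    # Mimics the structure built in __init__: each split maps to a 3-tuple.
    raw_data = {'train': (['img.png'] * 1000,  # file_names
                          ['gt'] * 1000,       # labels
                          [1.0] * 1000)}       # sample_weights
    batch_size = 32

    buggy = int(np.ceil(len(raw_data['train']) / batch_size))     # ceil(3/32) == 1
    fixed = int(np.ceil(len(raw_data['train'][0]) / batch_size))  # ceil(1000/32) == 32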
@@ -452,7 +455,7 @@ def _create_dataset(self,
         dataset = tf.data.Dataset.from_tensor_slices(data)
         if is_training:
             # Add additional repeat and shuffle for training
-            dataset = dataset.repeat().shuffle(len(files))
+            dataset = dataset.repeat().shuffle(len(files), seed=42)
 
         dataset = (dataset
                    .map(data_loader.load_images,
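
Fixing the shuffle seed makes the training order reproducible: with an op-level seed, a freshly built pipeline shuffles the same way on every run, while the default `reshuffle_each_iteration=True` still reorders each epoch (deterministically, given the seed). A minimal standalone sketch, not the repository's code:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)

    def first_pass(seed):
        # A fresh iterator over a seeded shuffle yields the same order every run.
        return list(ds.shuffle(10, seed=seed).as_numpy_iterator())

    assert first_pass(42) == first_pass(42)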