Commit 9c4d49b

Replace DataLoader params by config

TimKoornstra committed Mar 7, 2024
1 parent b77817c commit 9c4d49b
Showing 3 changed files with 51 additions and 58 deletions.
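The change in a nutshell: the DataLoader's long keyword-argument list is collapsed into the single config mapping that the rest of the pipeline already carries. A minimal before/after sketch of the call site, assuming a dict-like Config holding the keys shown in this diff:

# Before: each setting forwarded as its own keyword argument (abridged).
loader = DataLoader(batch_size=config["batch_size"],
                    img_size=img_size,
                    charlist=charlist,
                    train_list=config["train_list"],
                    use_mask=config["use_mask"],
                    augment_model=augment_model)

# After: the loader keeps the whole mapping and reads
# config["batch_size"], config["train_list"], etc. on demand.
loader = DataLoader(img_size=img_size,
                    charlist=charlist,
                    augment_model=augment_model,
                    config=config)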
15 changes: 4 additions & 11 deletions src/data/data_handling.py
@@ -16,7 +16,7 @@
 def initialize_data_loader(config: Config,
                            charlist: List[str],
                            model: tf.keras.Model,
-                           augment_model) -> DataLoader:
+                           augment_model: tf.keras.Sequential) -> DataLoader:
     """
     Initializes a data loader with specified parameters and based on the input
     shape of a given model.
@@ -30,6 +30,8 @@ def initialize_data_loader(config: Config,
         A list of characters to be used by the data loader.
     model : tf.keras.Model
         The Keras model, used to derive input dimensions for the data loader.
+    augment_model : tf.keras.Sequential
+        The Keras model used for data augmentation.
 
     Returns
     -------
@@ -49,19 +51,10 @@ def initialize_data_loader(config: Config,
     img_size = (model_height, config["width"], model_channels)
 
     return DataLoader(
-        batch_size=config["batch_size"],
         img_size=img_size,
-        train_list=config["train_list"],
-        test_list=config["test_list"],
-        validation_list=config["validation_list"],
-        inference_list=config["inference_list"],
         charlist=charlist,
-        multiply=config["aug_multiply"],
-        check_missing_files=config["check_missing_files"],
-        replace_final_layer=config["replace_final_layer"],
-        normalization_file=config["normalization_file"],
-        use_mask=config["use_mask"],
         augment_model=augment_model,
+        config=config
     )


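For reference, a sketch of driving the slimmed-down factory. The key names are the ones this diff shows being read; the paths and values are illustrative stand-ins for the repo's Config object, not its actual defaults, and charlist, model, and augment_model are assumed to be built as elsewhere in the repo:

# Hypothetical stand-in for the repo's Config; keys as read in this diff.
config = {
    "width": 1024,
    "batch_size": 32,
    "train_list": "data/train.txt",       # illustrative path
    "validation_list": "data/val.txt",    # illustrative path
    "test_list": "",
    "inference_list": "",
    "do_validate": True,
    "aug_multiply": 1,
    "check_missing_files": True,
    "replace_final_layer": False,
    "normalization_file": None,
    "use_mask": False,
}

data_loader = initialize_data_loader(config, charlist, model, augment_model)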
66 changes: 29 additions & 37 deletions src/data/loader.py
@@ -20,41 +20,24 @@ class DataLoader:
     according to parameters"""
 
     def __init__(self,
-                 batch_size,
                  img_size,
                  augment_model,
+                 config,
                  charlist=None,
-                 train_list='',
-                 validation_list='',
-                 test_list='',
-                 inference_list='',
-                 normalization_file=None,
-                 multiply=1,
-                 check_missing_files=True,
-                 replace_final_layer=False,
-                 use_mask=False
                  ):
 
-        # TODO: Change most of these to use config
-        self.batch_size = batch_size
-        self.height = img_size[0]
         self.augment_model = augment_model
+        self.height = img_size[0]
         self.channels = img_size[2]
+        self.config = config
 
-        # TODO: Make this more clear
         self.injected_charlist = charlist
-        self.train_list = train_list
-        self.validation_list = validation_list
-        self.test_list = test_list
-        self.inference_list = inference_list
-        self.normalization_file = normalization_file
-        self.multiply = multiply
-        self.check_missing_files = check_missing_files
-        self.replace_final_layer = replace_final_layer
-        self.use_mask = use_mask
-        self.charlist = charlist
 
         self.evaluation_list = None
-        if train_list and validation_list:
-            self.evaluation_list = validation_list
+        if self.config['do_validate']:
+            self.evaluation_list = self.config['validation_list']
 
         partitions, labels, self.tokenizer = self._process_raw_data()
         self.raw_data = {split: (partitions[split], labels[split])
@@ -72,7 +55,10 @@ def _process_raw_data(self):

         for partition in ['train', 'evaluation', 'validation',
                           'test', 'inference']:
-            partition_list = getattr(self, f"{partition}_list", None)
+            if partition == "evaluation":
+                partition_list = self.evaluation_list
+            else:
+                partition_list = self.config[f"{partition}_list"]
             if partition_list:
                 include_unsupported_chars = partition in ['validation', 'test']
                 use_multiply = partition == 'train'
@@ -87,14 +73,14 @@
             )
 
         # Determine the character list for the tokenizer
-        if self.injected_charlist and not self.replace_final_layer:
+        if self.injected_charlist and not self.config['replace_final_layer']:
             logging.info('Using injected charlist')
             self.charlist = self.injected_charlist
         else:
             self.charlist = sorted(list(characters))
 
         # Initialize the tokenizer
-        tokenizer = Tokenizer(self.charlist, self.use_mask)
+        tokenizer = Tokenizer(self.charlist, self.config['use_mask'])
 
         return partitions, labels, tokenizer

@@ -109,7 +95,10 @@ def _fill_datasets_dict(self, partitions, labels):

         for partition in ['train', 'evaluation', 'validation',
                           'test', 'inference']:
-            partition_list = getattr(self, f"{partition}_list", None)
+            if partition == "evaluation":
+                partition_list = self.evaluation_list
+            else:
+                partition_list = self.config[f"{partition}_list"]
             if partition_list:
                 # Create dataset for the current partition
                 files = list(zip(partitions[partition], labels[partition]))
@@ -119,15 +108,16 @@
                 augment_model=self.augment_model,
                 height=self.height,
                 channels=self.channels,
-                batch_size=self.batch_size,
+                batch_size=self.config['batch_size'],
                 is_training=partition == 'train',
                 deterministic=partition != 'train'
             )
 
         return datasets
 
     def get_train_batches(self):
-        return int(np.ceil(len(self.datasets['train']) / self.batch_size))
+        return int(np.ceil(len(self.raw_data['train'])
+                           / self.config['batch_size']))
 
     def create_data(self, characters, labels, partitions, partition_name,
                     data_files,
@@ -175,20 +165,22 @@ def create_data(self, characters, labels, partitions, partition_name,
             # filename
             file_name = line_parts[0]
             # Skip missing files unless explicitly included.
-            if not include_missing_files and self.check_missing_files \
-                    and not os.path.exists(file_name):
+            if not include_missing_files and \
+                    self.config['check_missing_files'] and \
+                    not os.path.exists(file_name):
                 logging.warning(f"Missing: {file_name} in {file_path}. "
                                 f"Skipping for {partition_name}...")
                 continue
 
             # Determine the ground truth text.
             if is_inference:
                 ground_truth = 'to be determined'
-            elif self.normalization_file and \
+            elif self.config['normalization_file'] and \
                     (partition_name == 'train'
                      or partition_name == 'evaluation'):
-                ground_truth = normalize_text(line_parts[1],
-                                              self.normalization_file)
+                ground_truth = normalize_text(
+                    line_parts[1],
+                    self.config['normalization_file'])
             else:
                 ground_truth = line_parts[1]

@@ -216,7 +208,7 @@
                 valid_lines += 1
                 if use_multiply:
                     # Multiply the data if requested
-                    for _ in range(self.multiply):
+                    for _ in range(self.config['aug_multiply']):
                         partitions[partition_name].append(file_name)
                         labels[partition_name].append(ground_truth)
                         processed_files.append([file_name, ground_truth])
@@ -226,7 +218,7 @@
                     labels[partition_name].append(ground_truth)
                     processed_files.append([file_name, ground_truth])
                 if (not self.injected_charlist or
-                        self.replace_final_layer) \
+                        self.config['replace_final_layer']) \
                         and partition_name == 'train':
                     characters = characters.union(
                         set(char for label in ground_truth for char in label))
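Putting the loader hunks together, the constructor that emerges keeps only what cannot come from config. A sketch assembled from the changes above, not a verbatim copy of the resulting file:

class DataLoader:
    def __init__(self, img_size, augment_model, config, charlist=None):
        self.augment_model = augment_model
        self.height = img_size[0]
        self.channels = img_size[2]
        self.config = config
        self.injected_charlist = charlist

        # The evaluation split is now gated on an explicit config flag
        # instead of on which file lists happen to be non-empty.
        self.evaluation_list = None
        if self.config['do_validate']:
            self.evaluation_list = self.config['validation_list']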
28 changes: 18 additions & 10 deletions tests/test_dataloader.py
@@ -45,6 +45,16 @@ def setUpClass(cls):
             level=logging.ERROR,
         )
 
+        cls.dummy_config = {
+            "do_validate": False,
+            "train_list": "",
+            "validation_list": "",
+            "test_list": "",
+            "inference_list": "",
+            "replace_final_layer": False,
+            "use_mask": False,
+        }
+
         # Determine the directory of this file
         current_file_dir = Path(__file__).resolve().parent

@@ -102,21 +112,19 @@ def _remove_temp_file(self, filename):

     def test_initialization(self):
         # Only provide the required arguments for initialization and check them
-        batch_size = 32
-        img_size = (256, 256, 3)
-
-        data_loader = self.DataLoader(batch_size=batch_size,
-                                      img_size=img_size,
+        test_config = self.dummy_config.copy()
+        test_config.update({
+            "batch_size": 32,
+            "img_size": (256, 256, 3),
+        })
+
+        data_loader = self.DataLoader(img_size=test_config["img_size"],
+                                      config=test_config,
                                       augment_model=None,
                                       charlist=["a", "b", "c"])
         self.assertIsInstance(data_loader, self.DataLoader,
                               "DataLoader not instantiated correctly")
 
-        # Check the values
-        self.assertEqual(data_loader.batch_size, batch_size,
-                         f"batch_size not set correctly. Expected: "
-                         f"{batch_size}, got: {data_loader.batch_size}")

# def test_create_data_simple(self):
# # Sample data
# chars = set()
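Finally, a note on the get_train_batches change in loader.py: the batch count is now a plain ceiling division over the raw training data rather than a division over the already-built tf.data dataset, whose length is already expressed in batches once batching is applied. A tiny self-contained check of the arithmetic:

import numpy as np

# 100 training samples at batch_size 32 -> ceil(100 / 32) = 4 batches.
num_samples = 100
batch_size = 32
print(int(np.ceil(num_samples / batch_size)))  # prints 4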
