Commit

Add docstrings and type hints
TimKoornstra committed Mar 7, 2024
1 parent 5802bc9 commit eef1b0c
Showing 1 changed file with 175 additions and 19 deletions.
194 changes: 175 additions & 19 deletions src/data/loader.py
@@ -4,6 +4,7 @@
from collections import defaultdict
import logging
import os
from typing import Dict, Tuple, Optional, List, Set

# > Third party dependencies
import tensorflow as tf
@@ -12,18 +13,33 @@

# > Local dependencies
from data.generator import DataGenerator
from setup.config import Config
from utils.text import Tokenizer, normalize_text


class DataLoader:
"""loader for dataset at given location, preprocess images and text
according to parameters"""
"""
Loader for a dataset at a given location. Preprocesses images and text
according to the configured parameters.
Parameters
----------
img_size : Tuple[int, int, int]
The size of the input images (height, width, channels).
augment_model : tf.keras.Sequential
The model used for data augmentation.
config : Config
The configuration object containing various settings.
charlist : Optional[List[str]], optional
The list of characters to use for tokenization, by default None.
"""

def __init__(self,
img_size,
augment_model,
config,
charlist=None,
img_size: Tuple[int, int, int],
augment_model: tf.keras.Sequential,
config: Config,
charlist: Optional[List[str]] = None,
):

self.augment_model = augment_model
@@ -55,7 +71,24 @@ def __init__(self,
self.datasets = self._fill_datasets_dict(
file_names, labels, sample_weights)
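
# Illustrative sketch of constructing a DataLoader. `config` is assumed to
# be a Config instance created elsewhere (its constructor is not shown in
# this diff), and the empty Sequential stands in for a real augmentation
# model; all names here are illustrative.
augmenter = tf.keras.Sequential([])
loader = DataLoader(img_size=(64, 1024, 3),
                    augment_model=augmenter,
                    config=config,  # assumed Config instance, defined elsewhere
                    charlist=list("abcdefghijklmnopqrstuvwxyz "))
train_ds = loader.datasets["train"]  # a tf.data.Dataset, or None if absent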

def _process_raw_data(self):
def _process_raw_data(self) -> Tuple[Dict[str, List[str]],
Dict[str, List[str]],
Dict[str, List[str]],
Tokenizer]:
"""
Process the raw data and create file names, labels, sample weights,
and tokenizer.
Returns
-------
Tuple[Dict[str, List[str]],
Dict[str, List[str]],
Dict[str, List[str]],
Tokenizer]
A tuple containing dictionaries of file names, labels, sample
weights, and the tokenizer.
"""

# Initialize character set and data partitions with corresponding
# labels
file_names_dict = defaultdict(list)
@@ -84,17 +117,38 @@ def _fill_datasets_dict(self, partitions, labels, sample_weights)

return file_names_dict, labels_dict, sample_weights_dict, tokenizer
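
# Sketch of the accumulation pattern used above: defaultdict(list) lets
# each partition key be appended to without prior initialization.
d = defaultdict(list)
d["train"].append("images/scan_001.png")  # no KeyError on first access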

def _fill_datasets_dict(self, partitions, labels, sample_weights):
def _fill_datasets_dict(self,
partitions: Dict[str, List[str]],
labels: Dict[str, List[str]],
sample_weights: Dict[str, List[str]]) \
-> Dict[str, tf.data.Dataset]:
"""
Initializes data generators for different dataset partitions and
updates character set and tokenizer based on the dataset.
Initialize data generators for different dataset partitions and
update character set and tokenizer based on the dataset.
Parameters
----------
partitions : Dict[str, List[str]]
A dictionary containing lists of file names for each partition.
labels : Dict[str, List[str]]
A dictionary containing lists of labels for each partition.
sample_weights : Dict[str, List[str]]
A dictionary containing lists of sample weights for each partition.
Returns
-------
Dict[str, tf.data.Dataset]
A dictionary containing datasets for each partition.
"""

# Create datasets for different partitions
datasets = defaultdict(lambda: None)

for partition in ['train', 'evaluation', 'validation',
'test', 'inference']:
# Special case for the evaluation partition: there is no
# evaluation_list in the config, so it is inferred from the
# validation_list and train_list in __init__
if partition == "evaluation":
partition_list = self.evaluation_list
else:
@@ -118,7 +172,25 @@ def _fill_datasets_dict(self, partitions, labels, sample_weights):

return datasets
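
# Sketch of the lookup behaviour callers get from this dictionary: absent
# partitions resolve to None instead of raising KeyError, thanks to the
# defaultdict(lambda: None) initialization above.
datasets = defaultdict(lambda: None)
datasets["train"] = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for name in ("train", "inference"):
    if datasets[name] is not None:
        print(f"{name} is available")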

def _create_data(self, partition_name, text_file):
def _create_data(self,
partition_name: str,
text_file: str) -> Tuple[List[str], List[str], List[str]]:
"""
Create data for a specific partition from a text file.
Parameters
----------
partition_name : str
The name of the partition.
text_file : str
The path to the text file containing the data.
Returns
-------
Tuple[List[str], List[str], List[str]]
A tuple containing lists of file names, labels, and sample weights.
"""

# Define the lists for the current partition
labels, partitions, sample_weights = [], [], []

@@ -168,7 +240,31 @@ def _create_data(self, partition_name, text_file):

return partitions, labels, sample_weights
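
# Sketch of the per-line shape this method appears to consume, based on
# the visible helpers: the last field is the ground truth (see
# _get_ground_truth) and an optional extra field carries a sample weight.
# The tab separator is an assumption; the parsing code itself is collapsed
# out of this diff view.
example_line = "images/scan_001.png\tA quick brown fox\n"
fields = example_line.strip().split("\t")
file_name = fields[0]
ground_truth = fields[-1] if len(fields) > 1 else ""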

def _process_line(self, line, partition_name, characters):
def _process_line(self,
line: str,
partition_name: str,
characters: Set[str]) \
-> Tuple[Optional[Tuple[str, str, float]],
Optional[str]]:
"""
Process a single line from the data file.
Parameters
----------
line : str
The line to process.
partition_name : str
The name of the partition.
characters : Set[str]
The set of characters to use for validation.
Returns
-------
Tuple[Optional[Tuple[str, str, float]], Optional[str]]
A tuple containing the processed data (file name, ground truth,
sample weight) and the flaw (if any).
"""

# Strip the line of leading and trailing whitespace
line = line.strip()

@@ -210,7 +306,26 @@ def _process_line(self, line, partition_name, characters):

return (file_name, ground_truth, sample_weight), None
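
# Sketch of the (data, flaw) return convention used here: exactly one of
# the two elements is None, so a caller can branch on the flaw message.
# The values below are illustrative.
data, flaw = ("images/scan_001.png", "A quick brown fox", 1.0), None
if flaw is not None:
    logging.warning("Skipping line: %s", flaw)
else:
    file_name, ground_truth, sample_weight = data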

def _get_ground_truth(self, fields, partition_name):
def _get_ground_truth(self,
fields: List[str],
partition_name: str) \
-> Tuple[Optional[str], Optional[str]]:
"""
Extract the ground truth from the fields.
Parameters
----------
fields : List[str]
The fields from the line.
partition_name : str
The name of the partition.
Returns
-------
Tuple[Optional[str], Optional[str]]
A tuple containing the ground truth and the flaw (if any).
"""

# Collect the ground truth and skip lines with empty ground truth
# unless it's an inference partition
ground_truth = fields[-1] if len(fields) > 1 else ""
@@ -223,12 +338,35 @@

return ground_truth, None
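
# Worked example of the extraction above: with more than one field the
# last one is taken as the ground truth; a lone field yields "".
fields = ["images/scan_001.png", "A quick brown fox"]
ground_truth = fields[-1] if len(fields) > 1 else ""  # "A quick brown fox"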

def _is_valid_ground_truth(self, ground_truth, partition_name, characters):
def _is_valid_ground_truth(self,
ground_truth: str,
partition_name: str,
characters: Set[str]) -> bool:
"""
Check if the ground truth is valid.
Parameters
----------
ground_truth : str
The ground truth to check.
partition_name : str
The name of the partition.
characters : Set[str]
The set of characters.
Returns
-------
bool
True if the ground truth is valid, False otherwise.
"""

# Check for unsupported characters in the ground truth
# Evaluation partition is allowed to have unsupported characters for a
# more realistic evaluation
# and update the character set if the partition is 'train'
unsupported_characters = set(ground_truth) - characters
if unsupported_characters:

# Unsupported characters are allowed in the validation, inference,
# and test partitions, but not in the evaluation partition
if partition_name in ['validation', 'inference', 'test']:
return True
elif partition_name == 'train' and not self.charlist:
@@ -238,7 +376,22 @@ def _is_valid_ground_truth(self, ground_truth, partition_name, characters):
return False
return True
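
# Worked example of the set-difference check above: any character of the
# ground truth outside the supported set counts as unsupported.
characters = set("abcdefghijklmnopqrstuvwxyz ")
unsupported_characters = set("héllo") - characters
print(unsupported_characters)  # {'é'}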

def _get_sample_weight(self, fields):
def _get_sample_weight(self,
fields: List[str]) -> float:
"""
Extract the sample weight from the fields.
Parameters
----------
fields : List[str]
The fields from the line.
Returns
-------
float
The sample weight.
"""

# Extract the sample weight from the fields
sample_weight = 1.0
if len(fields) > 2:
@@ -248,13 +401,16 @@ def _get_sample_weight(self, fields):
pass
return sample_weight
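
# Worked example of the fallback behaviour: a malformed weight field
# leaves the default of 1.0 in place rather than raising. parse_weight is
# an illustrative helper, not part of the module.
def parse_weight(field: str) -> float:
    try:
        return float(field)
    except ValueError:
        return 1.0

print(parse_weight("0.5"), parse_weight("n/a"))  # 0.5 1.0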

def get_filename(self, partition, item_id):
def get_filename(self, partition: str, item_id: int) -> str:
""" Get the filename for the given partition and item id """
return self.raw_data[partition][0][item_id]

def get_ground_truth(self, partition, item_id):
def get_ground_truth(self, partition: str, item_id: int) -> str:
""" Get the ground truth for the given partition and item id """
return self.raw_data[partition][1][item_id]

def get_train_batches(self) -> int:
""" Get the number of batches for training """
# raw_data['train'] holds (file_names, labels, sample_weights), so the
# sample count is the length of the file-name list, not of the tuple
return int(np.ceil(len(self.raw_data['train'][0])
/ self.config['batch_size']))
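
# Worked example of the batch computation: 1050 training samples with a
# batch size of 32 give ceil(1050 / 32) = 33 batches.
print(int(np.ceil(1050 / 32)))  # 33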

