From eef1b0c4f47cad8c209f9587aac5f34a464a3531 Mon Sep 17 00:00:00 2001
From: Tim Koornstra
Date: Thu, 7 Mar 2024 15:19:15 +0100
Subject: [PATCH] Add docstrings and type hints

---
 src/data/loader.py | 194 ++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 175 insertions(+), 19 deletions(-)

diff --git a/src/data/loader.py b/src/data/loader.py
index 36a87aea..2a75aa66 100644
--- a/src/data/loader.py
+++ b/src/data/loader.py
@@ -4,6 +4,7 @@
 from collections import defaultdict
 import logging
 import os
+from typing import Dict, Tuple, Optional, List, Set
 
 # > Third party dependencies
 import tensorflow as tf
@@ -12,18 +13,33 @@
 # > Local dependencies
 from data.generator import DataGenerator
+from setup.config import Config
 from utils.text import Tokenizer, normalize_text
 
 
 class DataLoader:
-    """loader for dataset at given location, preprocess images and text
-    according to parameters"""
+    """
+    Loader for a dataset at a given location. Preprocesses the images
+    and text of the dataset according to the parameters given in the
+    configuration.
+
+    Parameters
+    ----------
+    img_size : Tuple[int, int, int]
+        The size of the input images (height, width, channels).
+    augment_model : tf.keras.Sequential
+        The model used for data augmentation.
+    config : Config
+        The configuration object containing the data loading settings.
+    charlist : Optional[List[str]], optional
+        The list of characters to use for tokenization, by default None.
+    """
 
     def __init__(self,
-                 img_size,
-                 augment_model,
-                 config,
-                 charlist=None,
+                 img_size: Tuple[int, int, int],
+                 augment_model: tf.keras.Sequential,
+                 config: Config,
+                 charlist: Optional[List[str]] = None,
                  ):
 
         self.augment_model = augment_model
@@ -55,7 +71,24 @@ def __init__(self,
         self.datasets = self._fill_datasets_dict(
             file_names, labels, sample_weights)
 
-    def _process_raw_data(self):
+    def _process_raw_data(self) -> Tuple[Dict[str, List[str]],
+                                         Dict[str, List[str]],
+                                         Dict[str, List[str]],
+                                         Tokenizer]:
+        """
+        Process the raw data and create file names, labels, sample
+        weights, and the tokenizer.
+
+        Returns
+        -------
+        Tuple[Dict[str, List[str]],
+              Dict[str, List[str]],
+              Dict[str, List[str]],
+              Tokenizer]
+            A tuple containing dictionaries of file names, labels, sample
+            weights, and the tokenizer.
+        """
+
         # Initialize character set and data partitions with corresponding
         # labels
         file_names_dict = defaultdict(list)
@@ -84,10 +117,28 @@ def _process_raw_data(self):
 
         return file_names_dict, labels_dict, sample_weights_dict, tokenizer
 
-    def _fill_datasets_dict(self, partitions, labels, sample_weights):
+    def _fill_datasets_dict(self,
+                            partitions: Dict[str, List[str]],
+                            labels: Dict[str, List[str]],
+                            sample_weights: Dict[str, List[str]]) \
+            -> Dict[str, tf.data.Dataset]:
         """
-        Initializes data generators for different dataset partitions and
-        updates character set and tokenizer based on the dataset.
+        Initialize data generators for different dataset partitions and
+        update the character set and tokenizer based on the dataset.
+
+        Parameters
+        ----------
+        partitions : Dict[str, List[str]]
+            A dictionary containing lists of file names for each partition.
+        labels : Dict[str, List[str]]
+            A dictionary containing lists of labels for each partition.
+        sample_weights : Dict[str, List[str]]
+            A dictionary containing lists of sample weights for each partition.
+
+        Returns
+        -------
+        Dict[str, tf.data.Dataset]
+            A dictionary containing datasets for each partition.
""" # Create datasets for different partitions @@ -95,6 +146,9 @@ def _fill_datasets_dict(self, partitions, labels, sample_weights): for partition in ['train', 'evaluation', 'validation', 'test', 'inference']: + # Special case for evaluation partition, since there is no + # evaluation_list in the config, but the evaluation_list is + # inferred from the validation_list and train_list in the init if partition == "evaluation": partition_list = self.evaluation_list else: @@ -118,7 +172,25 @@ def _fill_datasets_dict(self, partitions, labels, sample_weights): return datasets - def _create_data(self, partition_name, text_file): + def _create_data(self, + partition_name: str, + text_file: str) -> Tuple[List[str], List[str], List[str]]: + """ + Create data for a specific partition from a text file. + + Parameters + ---------- + partition_name : str + The name of the partition. + text_file : str + The path to the text file containing the data. + + Returns + ------- + Tuple[List[str], List[str], List[str]] + A tuple containing lists of file names, labels, and sample weights. + """ + # Define the lists for the current partition labels, partitions, sample_weights = [], [], [] @@ -168,7 +240,31 @@ def _create_data(self, partition_name, text_file): return partitions, labels, sample_weights - def _process_line(self, line, partition_name, characters): + def _process_line(self, + line: str, + partition_name: str, + characters: Set[str]) \ + -> Tuple[Optional[Tuple[str, str, float]], + Optional[str]]: + """ + Process a single line from the data file. + + Parameters + ---------- + line : str + The line to process. + partition_name : str + The name of the partition. + characters : Set[str] + The set of characters to use for validation. + + Returns + ------- + Tuple[Optional[Tuple[str, str, float]], Optional[str]] + A tuple containing the processed data (file name, ground truth, + sample weight) and the flaw (if any). + """ + # Strip the line of leading and trailing whitespace line = line.strip() @@ -210,7 +306,26 @@ def _process_line(self, line, partition_name, characters): return (file_name, ground_truth, sample_weight), None - def _get_ground_truth(self, fields, partition_name): + def _get_ground_truth(self, + fields: List[str], + partition_name: str) \ + -> Tuple[Optional[str], Optional[str]]: + """ + Extract the ground truth from the fields. + + Parameters + ---------- + fields : List[str] + The fields from the line. + partition_name : str + The name of the partition. + + Returns + ------- + Tuple[Optional[str], Optional[str]] + A tuple containing the ground truth and the flaw (if any). + """ + # Collect the ground truth and skip lines with empty ground truth # unless it's an inference partition ground_truth = fields[-1] if len(fields) > 1 else "" @@ -223,12 +338,35 @@ def _get_ground_truth(self, fields, partition_name): return ground_truth, None - def _is_valid_ground_truth(self, ground_truth, partition_name, characters): + def _is_valid_ground_truth(self, + ground_truth: str, + partition_name: str, + characters: Set[str]) -> bool: + """ + Check if the ground truth is valid. + + Parameters + ---------- + ground_truth : str + The ground truth to check. + partition_name : str + The name of the partition. + characters : Set[str] + The set of characters. + + Returns + ------- + bool + True if the ground truth is valid, False otherwise. 
+ """ + # Check for unsupported characters in the ground truth - # Evaluation partition is allowed to have unsupported characters for a - # more realistic evaluation + # and update the character set if the partition is 'train' unsupported_characters = set(ground_truth) - characters if unsupported_characters: + + # Unsupported characters are allowed in the validation, inference, + # and test partitions, but not in the evaluation partition if partition_name in ['validation', 'inference', 'test']: return True elif partition_name == 'train' and not self.charlist: @@ -238,7 +376,22 @@ def _is_valid_ground_truth(self, ground_truth, partition_name, characters): return False return True - def _get_sample_weight(self, fields): + def _get_sample_weight(self, + fields: List[str]) -> float: + """ + Extract the sample weight from the fields. + + Parameters + ---------- + fields : List[str] + The fields from the line. + + Returns + ------- + float + The sample weight. + """ + # Extract the sample weight from the fields sample_weight = 1.0 if len(fields) > 2: @@ -248,13 +401,16 @@ def _get_sample_weight(self, fields): pass return sample_weight - def get_filename(self, partition, item_id): + def get_filename(self, partition: str, item_id: int): + """ Get the filename for the given partition and item id """ return self.raw_data[partition][0][item_id] - def get_ground_truth(self, partition, item_id): + def get_ground_truth(self, partition: str, item_id: int): + """ Get the ground truth for the given partition and item id """ return self.raw_data[partition][1][item_id] def get_train_batches(self): + """ Get the number of batches for training """ return int(np.ceil(len(self.raw_data['train']) / self.config['batch_size']))