Skip to content

Commit

Permalink
Update DataLoader docstrings and type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Mar 7, 2024
1 parent 23fcebc commit 74d356b
Showing 1 changed file with 60 additions and 13 deletions.
73 changes: 60 additions & 13 deletions src/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,59 @@
from typing import Tuple

# > Local dependencies
from utils.text import Tokenizer

# > Third party libraries
import tensorflow as tf
import numpy as np


class DataLoader:
def __init__(self,
tokenizer,
augment_model,
height=64,
channels=1,
is_training=False,
):
tokenizer: Tokenizer,
augment_model: tf.keras.Sequential,
height: int = 64,
channels: int = 1,
is_training: bool = False):
"""
Initializes the DataLoader.
Parameters
----------
tokenizer: Tokenizer
The tokenizer used for encoding labels.
augment_model: tf.keras.Sequential
The model used for data augmentation.
height : int, optional
The height of the preprocessed image (default is 64).
channels : int, optional
The number of channels in the image (default is 1).
is_training : bool, optional
Indicates whether the DataLoader is used for training (default is
False).
"""

self.tokenizer = tokenizer
self.augment_model = augment_model
self.height = height
self.channels = channels
self.is_training = is_training

def load_images(self, image_info_tuple: Tuple[str, str, str]) -> (
Tuple)[np.ndarray, np.ndarray]:
def load_images(self,
image_info_tuple: Tuple[str, str, str]) \
-> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Loads, preprocesses a single image, and encodes its label.
Unpacks the tuple for readability.
Parameters
----------
image_info_tuple : Tuple[str, str, str]
A tuple containing image path, label, and sample weight.
Returns
-------
Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
A tuple containing the preprocessed image, encoded label, and
sample weight.
"""

# Load and preprocess the image
Expand Down Expand Up @@ -66,6 +93,11 @@ def _load_and_preprocess_image(self, image_path: str) -> tf.Tensor:
-------
tf.Tensor
A preprocessed image tensor ready for training.
Raises
------
ValueError
If the number of channels is not 1, 3, or 4.
"""

# 1. Load the Image
Expand Down Expand Up @@ -96,9 +128,24 @@ def _load_and_preprocess_image(self, image_path: str) -> tf.Tensor:

return tf.cast(image[0], tf.float32)

def _ensure_width_for_ctc(self, image, encoded_label):
"""Resizes the image if necessary to accommodate the encoded label
during CTC decoding.
def _ensure_width_for_ctc(self,
image: tf.Tensor,
encoded_label: tf.Tensor) -> tf.Tensor:
"""
Resizes the image if necessary to accommodate the encoded label during
CTC decoding.
Parameters
----------
image : tf.Tensor
The preprocessed image tensor.
encoded_label : tf.Tensor
The encoded label.
Returns
-------
tf.Tensor
The resized image tensor.
"""

# Calculate the required width for the image
Expand Down

0 comments on commit 74d356b

Please sign in to comment.