Skip to content

Commit

Permalink
Move normalize, refactor DataGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Mar 6, 2024
1 parent 749fe75 commit 28fe06e
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 151 deletions.
22 changes: 12 additions & 10 deletions src/data/augment_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,8 @@ def call(self, inputs, training=None):


class ElasticTransformLayer(tf.keras.layers.Layer):
def __init__(self, binary=False, **kwargs):
def __init__(self, **kwargs):
super(ElasticTransformLayer, self).__init__(**kwargs)
self.fill_value = 1 if binary else 0

def call(self, inputs, training=None):
"""
Expand Down Expand Up @@ -277,7 +276,7 @@ def __init__(self, target_height, target_width=None,
raise ValueError("Either target_width or additional_width must be "
"specified")

self.fill_value = 1.0 if binary else 0.0
self.binary = binary

def estimate_background_color(self, image):
"""
Expand Down Expand Up @@ -372,15 +371,18 @@ def call(self, inputs, training=None):
padding = [[0, 0], [top_pad, bottom_pad],
[left_pad, right_pad], [0, 0]]

# Estimate background color
background_color = self.estimate_background_color(inputs)
if self.binary:
# Estimate background color
background_color = self.estimate_background_color(inputs)

# Reduce it to a scalar of the same dtype
background_color_scalar = tf.reduce_mean(background_color)
# Reduce it to a scalar of the same dtype
background_color_scalar = tf.reduce_mean(background_color)

# Ensure the scalar is the correct type, matching resized_img
background_color_scalar = tf.cast(background_color_scalar,
dtype=tf.float32)
# Ensure the scalar is the correct type, matching resized_img
background_color_scalar = tf.cast(background_color_scalar,
dtype=tf.float32)
else:
background_color_scalar = 0.0

# Pad the image
padded_img = tf.pad(resized_img, paddings=padding, mode="CONSTANT",
Expand Down
130 changes: 73 additions & 57 deletions src/data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,99 +11,115 @@
import numpy as np


class DataGenerator(tf.keras.utils.Sequence):

class DataGenerator:
def __init__(self,
tokenizer,
batch_size,
augment_model,
height=64,
channels=1,
augment_model=None,
is_training=True,
is_training=False,
):
self.batch_size = batch_size
self.tokenizer = tokenizer
self.augment_model = augment_model
self.height = height
self.channels = channels
self.augment_model = augment_model
self.is_training = is_training

def load_images(self, image_info_tuple: Tuple[str, str]) -> (
Tuple)[np.ndarray, np.ndarray]:
"""
Load and preprocess images.
Loads and preprocesses a single image, and encodes its label.
The image path and the label are read from the given tuple.
"""

# Load and preprocess the image
image = self._load_and_preprocess_image(image_info_tuple[0])

# Encode the label
encoded_label = self.tokenizer(image_info_tuple[1])

# Ensure the image width is sufficient for CTC decoding
image = self._ensure_width_for_ctc(image, encoded_label)

# Invert the image and shift its values: for pixels in [0, 1],
# 0.5 - image yields values in [-0.5, 0.5], centered around 0
image = 0.5 - image

# Transpose the image
image = tf.transpose(image, perm=[1, 0, 2])

return image, encoded_label

def _load_and_preprocess_image(self, image_path: str) -> tf.Tensor:
"""
Loads and preprocesses a single image.
Parameters
----------
- image_info_tuple (tuple): Tuple containing the file path (string)
and label(string) of the image.
image_path: str
The path to the image file.
Returns
-------
- Tuple: A tuple containing the preprocessed image (numpy.ndarray) and
encoded label (numpy.ndarray).
Raises
------
- ValueError: If the number of channels is not 1, 3, or 4.
Notes
-----
- This function uses TensorFlow operations to read, decode, and
preprocess images.
- Preprocessing steps include resizing, channel manipulation,
distortion (if specified), elastic transform, cropping, shearing, and
label encoding.
Example:
>>> loader = ImageLoader()
>>> image_info_tuple = ("/path/to/image.png", "label")
>>> preprocessed_image, encoded_label =
... loader.load_images(image_info_tuple)
tf.Tensor
A preprocessed image tensor ready for training.
"""

image = tf.io.read_file(image_info_tuple[0])
# 1. Load the Image
image = tf.io.read_file(image_path)

try:
image = tf.image.decode_png(image, channels=self.channels)
image = tf.image.decode_image(image, channels=self.channels,
expand_animations=False)
except ValueError:
logging.error("Invalid number of channels. "
"Supported values are 1, 3, or 4.")

# 2. Resize the Image and Normalize Pixel Values to [0, 1]
image = tf.image.resize(image, (self.height, 99999),
preserve_aspect_ratio=True) / 255.0

# Add batch dimension
# 3. Apply Data Augmentations
# Add batch dimension (required for augmentation model)
image = tf.expand_dims(image, 0)

# Apply augmentations
if self.augment_model is not None:
for layer in self.augment_model.layers:
# Mandatory resize_with_pad layer
if layer.name == "extra_resize_with_pad":
image = layer(image, training=True)
continue
for layer in self.augment_model.layers:
# The mandatory 'extra_resize_with_pad' layer always runs with
# training=True; all other layers follow self.is_training
if layer.name == "extra_resize_with_pad":
image = layer(image, training=True)
else:
image = layer(image, training=self.is_training)

# Remove batch dimension
image = image[0]
image = tf.cast(image, dtype=tf.float32)
return tf.cast(image[0], tf.float32)

image_width = tf.shape(image)[1]
def _ensure_width_for_ctc(self, image, encoded_label):
"""Resizes the image if necessary to accommodate the encoded label
during CTC decoding.
"""

label = image_info_tuple[1]
encoded_label = self.tokenizer(label)
# Calculate the required width for the image
required_width = len(encoded_label)

label_counter = 0
num_repetitions = 0
last_char = None

for char in encoded_label:
label_counter += 1
if char == last_char:
label_counter += 1
num_repetitions += 1
last_char = char
label_width = label_counter
if image_width < label_width*16:
image_width = label_width * 16
image = tf.image.resize_with_pad(image, self.height, image_width)
image = 0.5 - image
image = tf.transpose(image, perm=[1, 0, 2])
return image, encoded_label

# Add repetitions
required_width += num_repetitions

# Convert to pixels
pixels_per_column = 16
required_width *= pixels_per_column

# Mandatory cast to float32
image = tf.cast(image, tf.float32)

if tf.shape(image)[1] < required_width:
image = tf.image.resize_with_pad(
image, self.height, required_width)

return image
54 changes: 6 additions & 48 deletions src/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,25 @@
# > Standard library
import logging
import os
import re
import json


# > Third party dependencies
import tensorflow as tf
from tensorflow.data import AUTOTUNE
import numpy as np


# > Local dependencies
from data.generator import DataGenerator
from utils.text import Tokenizer
from utils.text import Tokenizer, normalize_text


class DataLoader:
"""loader for dataset at given location, preprocess images and text
according to parameters"""
DTYPE = 'float32'
currIdx = 0
charList = []
samples = []
validation_dataset = []

def __init__(self,
batch_size,
img_size,
augment_model,
char_list=[],
train_list='',
validation_list='',
Expand All @@ -39,15 +31,11 @@ def __init__(self,
multiply=1,
check_missing_files=True,
replace_final_layer=False,
use_mask=False,
augment_model=None
use_mask=False
):
self.currIdx = 0
self.batch_size = batch_size
self.imgSize = img_size
self.samples = []
self.height = img_size[0]
self.width = img_size[1]
self.augment_model = augment_model
self.channels = img_size[2]
self.partition = []
self.injected_charlist = char_list
Expand All @@ -60,36 +48,7 @@ def __init__(self,
self.check_missing_files = check_missing_files
self.replace_final_layer = replace_final_layer
self.use_mask = use_mask
self.augment_model = augment_model

@staticmethod
def normalize(input: str, replacements: str) -> str:
"""
Normalize text using a json file with replacements
Parameters
----------
input : str
Input string to normalize
replacements : str
Path to json file with replacements, where key is the string to
replace and value is the replacement. Example: {"a": "b"} will
replace all "a" with "b" in the input string.
Returns
-------
str
Normalized string
"""

with open(replacements, 'r') as f:
replacements = json.load(f)
for key, value in replacements.items():
input = input.replace(key, value)

input = re.sub(r"\s+", " ", input)

return input.strip()
self.charList = char_list

def init_data_generator(self,
files,
Expand Down Expand Up @@ -220,7 +179,6 @@ def get_generators(self):
train_params = {
'tokenizer': self.tokenizer,
'height': self.height,
'batch_size': self.batch_size,
'channels': self.channels,
'augment_model': self.augment_model
}
Expand Down Expand Up @@ -331,7 +289,7 @@ def create_data(self, characters, labels, partitions, partition_name,
elif self.normalization_file and \
(partition_name == 'train'
or partition_name == 'evaluation'):
ground_truth = self.normalize(line_parts[1],
ground_truth = normalize_text(line_parts[1],
self.normalization_file)
else:
ground_truth = line_parts[1]
Expand Down
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main():
augmentation_model)
training_dataset, evaluation_dataset, validation_dataset, \
test_dataset, inference_dataset, _, train_batches, \
validation_labels = loader.generators()
validation_labels = loader.get_generators()

# Replace the charlist with the one from the data loader
charlist = loader.charList
Expand Down
6 changes: 3 additions & 3 deletions src/modes/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from utils.calculate import calc_95_confidence_interval, calculate_cers, \
increment_counters, calculate_edit_distances
from utils.decoding import decode_batch_predictions
from utils.text import preprocess_text, Tokenizer
from utils.text import preprocess_text, Tokenizer, normalize_text
from utils.wbs import setup_word_beam_search, handle_wbs_results


Expand Down Expand Up @@ -85,7 +85,7 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
original_text = preprocess_text(orig_texts[index])\
.replace("[UNK]", "�")
normalized_original = None if not config["normalization_file"] else \
loader.normalize(original_text, config["normalization_file"])
normalize_text(original_text, config["normalization_file"])

# Calculate edit distances
distances = calculate_edit_distances(prediction, original_text)
Expand Down Expand Up @@ -149,7 +149,7 @@ def perform_test(config: Config,
prediction_model = get_prediction_model(model)

# Setup WordBeamSearch if needed
wbs = setup_word_beam_search(config, charlist, dataloader) \
wbs = setup_word_beam_search(config, charlist) \
if config["corpus_file"] else None

# Initialize the counters
Expand Down
6 changes: 3 additions & 3 deletions src/modes/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
calculate_edit_distances, update_statistics, increment_counters
from utils.decoding import decode_batch_predictions
from utils.print import print_predictions, display_statistics
from utils.text import preprocess_text, Tokenizer
from utils.text import preprocess_text, Tokenizer, normalize_text
from utils.wbs import setup_word_beam_search, handle_wbs_results


Expand Down Expand Up @@ -86,7 +86,7 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
prediction = preprocess_text(prediction)
original_text = preprocess_text(y_true[index])
normalized_original = None if not config["normalization_file"] else \
loader.normalize(original_text, config["normalization_file"])
normalize_text(original_text, config["normalization_file"])

# Calculate edit distances here so we can use them for printing the
# predictions
Expand Down Expand Up @@ -166,7 +166,7 @@ def perform_validation(config: Config,
prediction_model = get_prediction_model(model)

# Setup WordBeamSearch if needed
wbs = setup_word_beam_search(config, charlist, dataloader) \
wbs = setup_word_beam_search(config, charlist) \
if config["corpus_file"] else None

# Initialize variables for CER calculation
Expand Down
Loading

0 comments on commit 28fe06e

Please sign in to comment.