From 9f6152d09a7d6acbe703c74afbac01004db50422 Mon Sep 17 00:00:00 2001 From: Tim Koornstra Date: Fri, 8 Mar 2024 09:43:21 +0100 Subject: [PATCH] Fix augmentations --- src/data/augmentation.py | 49 ++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/src/data/augmentation.py b/src/data/augmentation.py index 0f66e172..4f3099a2 100644 --- a/src/data/augmentation.py +++ b/src/data/augmentation.py @@ -68,7 +68,8 @@ def save_augment_steps_plot(aug_model: tf.keras.Sequential, # Calculate histogram for the original image histograms = [] - original_hist = tf.histogram_fixed_width(sample_image[0], [0.0, 1.0], nbins=256) + original_hist = tf.histogram_fixed_width( + sample_image[0], [0.0, 1.0], nbins=256) histograms.append(original_hist) # Apply each augmentation layer to the image @@ -81,7 +82,8 @@ def save_augment_steps_plot(aug_model: tf.keras.Sequential, dtype=tf.float32) # Calculate histogram for the augmented image - hist = tf.histogram_fixed_width(sample_image_float32, [0.0, 1.0], nbins=256) + hist = tf.histogram_fixed_width( + sample_image_float32, [0.0, 1.0], nbins=256) histograms.append(hist) # Plot the original and each augmentation step @@ -186,24 +188,23 @@ def get_augment_selection(config: Config, channels: int) -> list: """ augment_selection = [] - binarize_present = True if config["do_binarize_sauvola"] or \ - config["do_binarize_otsu"] else False + binarize_present = True if config["aug_binarize_sauvola"] or \ + config["aug_binarize_otsu"] else False - if config["distort_jpeg"]: - logging.info("Selected data augment: distort_jpeg") + if config["aug_distort_jpeg"]: + logging.info("Selected data augment: JPEG distortion") augment_selection.append(DistortImageLayer()) - if config["elastic_transform"]: - logging.info("Selected data augment: elastic_transform") - augment_selection.append( - ElasticTransformLayer(binary=binarize_present)) + if config["aug_elastic_transform"]: + logging.info("Selected data augment: elastic transform") + augment_selection.append(ElasticTransformLayer()) - if config["random_crop"]: - logging.info("Selected data augment: random_crop") + if config["aug_random_crop"]: + logging.info("Selected data augment: random vertical crop") augment_selection.append(RandomVerticalCropLayer()) - if config["random_width"]: - logging.info("Selected data augment: random_width") + if config["aug_random_width"]: + logging.info("Selected data augment: random width") augment_selection.append(RandomWidthLayer(binary=binarize_present)) # For some reason, the original adds a 50px pad to the width here @@ -212,32 +213,32 @@ def get_augment_selection(config: Config, channels: int) -> list: binary=binarize_present, name="extra_resize_with_pad")) - if config["do_random_shear"]: + if config["aug_random_shear"]: # Apply padding to make sure that shear does not cut off img augment_selection.append(ResizeWithPadLayer(target_height=64, additional_width=64, binary=binarize_present)) - logging.info("Selected data augment: shear_x") + logging.info("Selected data augment: random shear along x-axis") augment_selection.append(ShearXLayer(binary=binarize_present)) - if config["do_binarize_sauvola"]: - logging.info("Selected data augment: binarize_sauvola") + if config["aug_binarize_sauvola"]: + logging.info("Selected data augment: Sauvola binarization") augment_selection.append(BinarizeLayer(method='sauvola', channels=channels, window_size=51)) - if config["do_binarize_otsu"]: - logging.info("Selected data augment: binarize_otsu") + if config["aug_binarize_otsu"]: + logging.info("Selected data augment: Otsu binarization") augment_selection.append(BinarizeLayer(method="otsu", channels=channels)) - if config["do_blur"]: - logging.info("Selected data augment: blur_image") + if config["aug_blur"]: + logging.info("Selected data augment: image blur") augment_selection.append(BlurImageLayer()) - if config["do_invert"]: - logging.info("Selected data augment: aug_invert") + if config["aug_invert"]: + logging.info("Selected data augment: invert image") augment_selection.append(InvertImageLayer(channels=channels)) return augment_selection