Skip to content

Commit

Permalink
Add invert augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
MMaas3 committed Oct 20, 2023
1 parent fbff93a commit 707c923
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import elasticdeform.tf as etf
import tensorflow_addons as tfa
from skimage.filters import threshold_otsu, threshold_sauvola
from tensorflow.python.ops import bitwise_ops


class DataGenerator(tf.keras.utils.Sequence):

Expand All @@ -27,6 +29,7 @@ def __init__(self,
channels=1,
do_random_shear=False,
do_blur=False,
do_invert=False
):
print(height)

Expand All @@ -42,6 +45,7 @@ def __init__(self,
self.channels = channels
self.do_random_shear = do_random_shear
self.do_blur = do_blur
self.do_invert = do_invert

def elastic_transform(self, original):
displacement_val = tf.random.normal([2, 3, 3]) * 5
Expand All @@ -66,6 +70,12 @@ def binarize_otsu(self, tensor):

return tf.convert_to_tensor((np_array > otsu_threshold) * 1)

def invert(self, tensor):
if str(tensor.numpy().dtype).startswith("uint") or str(tensor.numpy().dtype).startswith("int"):
return tf.convert_to_tensor(255 - tensor.numpy())
else:
return tf.convert_to_tensor(1 - tensor.numpy())

def blur(self, tensor):
return tfa.image.gaussian_filter2d(tensor, sigma=[3.0, 20.0], filter_shape=(10, 10))

Expand Down Expand Up @@ -141,6 +151,9 @@ def load_images(self, image_path):
if self.do_blur:
image = self.blur(image)

if self.do_invert:
image = self.invert(image)

label = image_path[1]
encodedLabel = self.utils.char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))

Expand Down

0 comments on commit 707c923

Please sign in to comment.