Skip to content

Commit

Permalink
implement binarization methods
Browse files Browse the repository at this point in the history
  • Loading branch information
MMaas3 committed Oct 20, 2023
1 parent b1f6cb9 commit b1da4df
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ metrics==0.3.3
tf_keras_vis==0.8.5
elasticdeform==0.5.0
blinker==1.4
scikit-image==0.22.0
24 changes: 22 additions & 2 deletions src/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
# > Local dependencies

# > Third party libraries
import cv2
import tensorflow as tf
import elasticdeform.tf as etf
import tensorflow_addons as tfa
from skimage.filters import threshold_otsu, threshold_sauvola

class DataGenerator(tf.keras.utils.Sequence):

Expand Down Expand Up @@ -44,8 +46,26 @@ def elastic_transform(self, original):
X_deformed = etf.deform_grid(original, displacement_val, axis=(0, 1), order=3)
return X_deformed

def load_images(self, imagePath):
image = tf.io.read_file(imagePath[0])
def binarize_sauvola(self, tensor):
np_array = tensor.numpy()
window_size = 51

sauvola_thresh = threshold_sauvola(np_array, window_size=window_size)
binary_sauvola = (np_array > sauvola_thresh) * 1

return tf.convert_to_tensor(binary_sauvola)

def binarize_otsu(self, tensor):
np_array = tensor.numpy()

np_array = cv2.cvtColor(np_array, cv2.COLOR_RGB2GRAY)

otsu_threshold = threshold_otsu(np_array)

return tf.convert_to_tensor((np_array > otsu_threshold) * 1)

def load_images(self, image_path):
image = tf.io.read_file(image_path[0])
image = tf.image.decode_png(image, channels=self.channels)
image = tf.image.resize(image, (self.height, 99999), preserve_aspect_ratio=True) / 255.0
if self.distort_jpeg:
Expand Down

0 comments on commit b1da4df

Please sign in to comment.