Skip to content

Commit

Permalink
ndim instead of len(tf.shape())
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Aug 26, 2024
1 parent 0bbd09d commit 807277f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/data/augment_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tensorflow as tf
from data.gaussian_filter2d import gaussian_filter2d


class ShearXLayer(tf.keras.layers.Layer):
def __init__(self, binary=False, **kwargs):
super(ShearXLayer, self).__init__(**kwargs)
Expand Down Expand Up @@ -587,7 +588,7 @@ def call(self, inputs, training=None):
return inputs

# Get the width and height of the input image
if len(tf.shape(inputs)) < 4:
if inputs.ndim < 4:
# When input does have a batch size dim
original_width = tf.shape(inputs)[1]
original_height = tf.shape(inputs)[0]
Expand Down

0 comments on commit 807277f

Please sign in to comment.