
Commit

Fix ElasticTransform bug, manager repeat
TimKoornstra committed Mar 8, 2024
Parent: 9f6152d · Commit: 60f2190
Showing 2 changed files with 11 additions and 9 deletions.
src/data/augment_layers.py (4 additions, 1 deletion)
@@ -67,7 +67,6 @@ def call(self, inputs, training=None):
         )
 
         # Apply the shear transformation
-        # Fill value is set to 0 to ensure that binarization is not affected
         sheared_image = tf.raw_ops.ImageProjectiveTransformV3(
             images=inputs,
             transforms=shear_matrix_tf,
@@ -136,6 +135,10 @@ def call(self, inputs, training=None):
         x_deformed = tf.clip_by_value(x_deformed,
                                       clip_value_min=0.0,
                                       clip_value_max=1.0)
 
+        # HACK: Reshape to original shape to force a shape
+        x_deformed = tf.reshape(x_deformed, tf.shape(inputs))
+
         return x_deformed
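
Note on the tf.reshape addition: deformation chains built from gather/interpolation ops can return a tensor whose static shape is partially or fully unknown, which trips up downstream Keras layers at graph-construction time. Below is a minimal sketch of the failure mode and the fix, using tf.py_function as a stand-in for the shape-losing deformation (not the repository's actual elastic-transform code):

import tensorflow as tf

def deform_with_shape_fix(inputs: tf.Tensor) -> tf.Tensor:
    # Stand-in for an elastic deformation: tf.py_function returns a tensor
    # of unknown static shape, much like some interpolation chains do.
    deformed = tf.py_function(lambda x: x, [inputs], tf.float32)
    # The commit's fix: reshape to the input's runtime shape so the output
    # regains a known rank and is guaranteed to match the input exactly.
    return tf.reshape(deformed, tf.shape(inputs))
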
src/data/manager.py (7 additions, 8 deletions)
@@ -159,7 +159,7 @@ def _fill_datasets_dict(self,
             partition_list = self.config[f"{partition}_list"]
             if partition_list:
                 # Create dataset for the current partition
-                datasets[partition] = self.create_dataset(
+                datasets[partition] = self._create_dataset(
                     files=partitions[partition],
                     labels=labels[partition],
                     sample_weights=sample_weights[partition],
@@ -410,11 +410,11 @@ def get_train_batches(self):
         return int(np.ceil(len(self.raw_data['train'])
                            / self.config['batch_size']))
 
-    def create_dataset(self,
-                       files: List[str],
-                       labels: List[str],
-                       sample_weights: List[str],
-                       partition_name: str) -> tf.data.Dataset:
+    def _create_dataset(self,
+                        files: List[str],
+                        labels: List[str],
+                        sample_weights: List[str],
+                        partition_name: str) -> tf.data.Dataset:
         """
         Create a dataset for a specific partition.
@@ -453,8 +453,7 @@ def create_dataset(self,
         dataset = tf.data.Dataset.from_tensor_slices(data)
         if is_training:
             # Add additional repeat and shuffle for training
-            dataset = dataset.repeat(self.config["aug_multiply"])\
-                .shuffle(len(files))
+            dataset = dataset.repeat().shuffle(len(files))
 
         dataset = (dataset
                    .map(data_loader.load_images,
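
Two notes on the manager.py changes. The leading underscore in _create_dataset marks the method as internal, matching its call site in _fill_datasets_dict. The repeat change is behavioral: dataset.repeat() with no count cycles indefinitely instead of aug_multiply times, so each epoch must be bounded by an explicit step count, presumably the ceil(samples / batch_size) value that get_train_batches computes above. A minimal sketch of the resulting pattern, assuming training goes through model.fit (the toy dataset and the commented-out model call are placeholders, not the repository's objects):

import numpy as np
import tensorflow as tf

samples = np.arange(10, dtype=np.float32)
# Infinite shuffled stream, mirroring the commit's repeat().shuffle(...)
train_ds = (tf.data.Dataset.from_tensor_slices(samples)
            .repeat()
            .shuffle(len(samples))
            .batch(2))

# An infinite dataset never signals the end of an epoch on its own, so the
# step count must be passed explicitly, e.g. int(np.ceil(10 / 2)) == 5:
# model.fit(train_ds, epochs=3, steps_per_epoch=5)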
