From 60f21905b24f7bd64a19e1386664e90e94058f02 Mon Sep 17 00:00:00 2001
From: Tim Koornstra
Date: Fri, 8 Mar 2024 12:22:32 +0100
Subject: [PATCH] Fix ElasticTransform bug, manager repeat

---
 src/data/augment_layers.py |  5 ++++-
 src/data/manager.py        | 15 +++++++--------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/data/augment_layers.py b/src/data/augment_layers.py
index 757e1afb..961bc94c 100644
--- a/src/data/augment_layers.py
+++ b/src/data/augment_layers.py
@@ -67,7 +67,6 @@ def call(self, inputs, training=None):
         )
 
         # Apply the shear transformation
-        # Fill value is set to 0 to ensure that binarization is not affected
         sheared_image = tf.raw_ops.ImageProjectiveTransformV3(
             images=inputs,
             transforms=shear_matrix_tf,
@@ -136,6 +135,10 @@ def call(self, inputs, training=None):
 
         x_deformed = tf.clip_by_value(x_deformed,
                                       clip_value_min=0.0,
                                       clip_value_max=1.0)
+
+        # HACK: Reshape to original shape to force a shape
+        x_deformed = tf.reshape(x_deformed, tf.shape(inputs))
+
         return x_deformed
diff --git a/src/data/manager.py b/src/data/manager.py
index 3070866d..302549c2 100644
--- a/src/data/manager.py
+++ b/src/data/manager.py
@@ -159,7 +159,7 @@ def _fill_datasets_dict(self,
             partition_list = self.config[f"{partition}_list"]
             if partition_list:
                 # Create dataset for the current partition
-                datasets[partition] = self.create_dataset(
+                datasets[partition] = self._create_dataset(
                     files=partitions[partition],
                     labels=labels[partition],
                     sample_weights=sample_weights[partition],
@@ -410,11 +410,11 @@ def get_train_batches(self):
         return int(np.ceil(len(self.raw_data['train']) /
                            self.config['batch_size']))
 
-    def create_dataset(self,
-                       files: List[str],
-                       labels: List[str],
-                       sample_weights: List[str],
-                       partition_name: str) -> tf.data.Dataset:
+    def _create_dataset(self,
+                        files: List[str],
+                        labels: List[str],
+                        sample_weights: List[str],
+                        partition_name: str) -> tf.data.Dataset:
         """
         Create a dataset for a specific partition.
 
@@ -453,8 +453,7 @@ def create_dataset(self,
         dataset = tf.data.Dataset.from_tensor_slices(data)
         if is_training:
             # Add additional repeat and shuffle for training
-            dataset = dataset.repeat(self.config["aug_multiply"])\
-                .shuffle(len(files))
+            dataset = dataset.repeat().shuffle(len(files))
 
         dataset = (dataset
                    .map(data_loader.load_images,