From cdab73db0f3d03ecc4d6f984e456cd882d0e6c27 Mon Sep 17 00:00:00 2001 From: Tim Koornstra Date: Thu, 7 Mar 2024 12:13:33 +0100 Subject: [PATCH] Fix generator test --- src/data/generator.py | 2 +- tests/test_datagenerator.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/data/generator.py b/src/data/generator.py index 84fa3442..d7962341 100644 --- a/src/data/generator.py +++ b/src/data/generator.py @@ -25,7 +25,7 @@ def __init__(self, self.channels = channels self.is_training = is_training - def load_images(self, image_info_tuple: Tuple[str, str, float]) -> ( + def load_images(self, image_info_tuple: Tuple[str, str, str]) -> ( Tuple)[np.ndarray, np.ndarray]: """ Loads, preprocesses a single image, and encodes its label. diff --git a/tests/test_datagenerator.py b/tests/test_datagenerator.py index 1b7000f7..4b806d77 100644 --- a/tests/test_datagenerator.py +++ b/tests/test_datagenerator.py @@ -53,7 +53,8 @@ def test_load_images(self): # Set up a mock image file and label image_path = "path/to/mock_image.png" label = "mock_label" - image_info_tuple = (image_path, label) + sample_weight = "1.0" + image_info_tuple = (image_path, label, sample_weight) dummy_augment_model = tf.keras.Sequential([]) tokenizer = self.Tokenizer(chars=["ABC"], use_mask=False) @@ -66,13 +67,14 @@ def test_load_images(self): with unittest.mock.patch.object(tf.image, 'decode_image', return_value=tf.ones([100, 100, 3]) ): - preprocessed_image, encoded_label = dg.load_images( - image_info_tuple) + preprocessed_image, encoded_label, sample_weights \ + = dg.load_images(image_info_tuple) # Assert the shape of the preprocessed image self.assertEqual(preprocessed_image.shape, (304, 64, 3)) self.assertIsInstance(preprocessed_image, tf.Tensor) self.assertIsInstance(encoded_label, tf.Tensor) + self.assertIsInstance(sample_weights, tf.Tensor) if __name__ == '__main__':