def _smallest_image_size_xy(self, max_size_zxy=(10, 500, 500)):
    '''
    Returns the smallest extent found along each of the z, x and y axes
    over all images of the dataset, capped per axis by max_size_zxy.

    Despite the ``_xy`` suffix of the name, all three spatial axes are
    covered and a (z, x, y) tuple is returned.

    The result is a tile size that does not exceed any image of the
    dataset along any axis.

    :param max_size_zxy: upper limit (z, x, y) for the returned size.
                         It is also the return value if the dataset
                         contains no images.
    :type max_size_zxy: sequence of three numbers
    :returns: tuple (size_z, size_x, size_y)
    '''
    smallest_z, smallest_x, smallest_y = max_size_zxy

    for image_nr in range(self.n_images):
        # The last three entries of the image dimensions are read as
        # z, x, y; leading entries (e.g. channels) are ignored. The
        # explicit unpacking fails loudly on images with fewer than
        # three dimensions.
        size_z, size_x, size_y = self.image_dimensions(image_nr)[-3:]
        smallest_z = min(smallest_z, size_z)
        smallest_x = min(smallest_x, size_x)
        smallest_y = min(smallest_y, size_y)

    return (smallest_z, smallest_x, smallest_y)
def test_normalize_global_auto(self):
    '''
    Selecting 'global' normalization without passing a minmax range
    must fill global_norm_minmax automatically.

    The expected length of 3 is presumably one entry per channel of
    the tiffconnector_1 test images (to be confirmed against the test
    data).
    '''
    connector = TiffConnector(
        os.path.join(base_path, '../test_data/tiffconnector_1/im/'),
        os.path.join(base_path, '../test_data/tiffconnector_1/labels/'))
    dataset = Dataset(connector)

    batch = TrainingBatch(dataset, (1, 5, 4), padding_zxy=(0, 0, 0))

    # nothing is set before a normalization mode has been chosen
    assert batch.global_norm_minmax is None

    # no minmax given: the range has to be determined automatically
    batch.set_normalize_mode('global')
    assert len(batch.global_norm_minmax) == 3