Skip to content

Commit

Permalink
added auto global normalization option to minibatch
Browse files Browse the repository at this point in the history
  • Loading branch information
cmohl2013 committed Oct 1, 2020
1 parent 1fce33e commit 8d23860
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
25 changes: 21 additions & 4 deletions yapic_io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,27 @@ def image_dimensions(self, image_nr):
'''
return self.pixel_connector.image_dimensions(image_nr)


def _smallest_image_size_xy(self):
Z = 10
X = 500
Y = 500
for i in range(self.n_images):
size_z, size_x, size_y = self.image_dimensions(i)[-3:]
if size_z < Z:
Z = size_z
if size_x < X:
X = size_x
if size_y < Y:
Y = size_y
return (Z, X, Y)


def pixel_statistics(self,
channels,
upper=99,
lower=1,
tile_size_zxy=(1, 50, 50),
tile_size_zxy=None,
n_tiles=1000):
'''
Performs random sampling of n tiles and calculates upper and lower
Expand All @@ -107,10 +123,11 @@ def pixel_statistics(self,
[(lower_01, upper_01), (lower_02, upper_02), (lower_03, upper_03), ...]
'''

tile_size_zxy = (1, 30, 30)
if tile_size_zxy is None:
tile_size_zxy = self._smallest_image_size_xy()
percentiles = np.zeros((n_tiles, len(channels), 2))
msg = '\n\nCalculate global pixel statistics ({} tiles)...\n'.format(
n_tiles)
msg = ('\n\nCalculate global pixel statistics'
'({} tiles of size {})...\n').format(n_tiles, tile_size_zxy)
sys.stdout.write(msg)
for i in range(n_tiles):
image_nr, pos_zxy = self._random_pos_izxy(None, tile_size_zxy)
Expand Down
5 changes: 3 additions & 2 deletions yapic_io/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ def set_normalize_mode(self, mode_str, minmax=None):
self.normalize_mode = mode_str

if mode_str == 'global':
assert minmax is not None, \
'normalization range (min, max) required'

if minmax is None:
# calculate upper and lower percentile automatically
minmax = self.dataset.pixel_statistics(self.channels)
n_channels = self.dataset.pixel_connector.image_dimensions(0)[0]

if len(minmax) == 2:
Expand Down
16 changes: 16 additions & 0 deletions yapic_io/tests/test_training_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,22 @@ def test_normalize_multichannel(self):
assert_array_equal(np.unique(pixels_normalized_zscore), [0])
assert_array_equal(np.unique(pixels_normalized_local), [0])

def test_normalize_global_auto(self):

img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
label_path = os.path.join(base_path,
'../test_data/tiffconnector_1/labels/')
c = TiffConnector(img_path, label_path)
d = Dataset(c)

size = (1, 5, 4)
pad = (0, 0, 0)

m = TrainingBatch(d, size, padding_zxy=pad)
assert m.global_norm_minmax is None
m.set_normalize_mode('global')
assert len(m.global_norm_minmax) == 3

def test_normalize_global_multichannel(self):

img_path = os.path.join(base_path,
Expand Down

0 comments on commit 8d23860

Please sign in to comment.