Skip to content

Commit

Permalink
Added support for correct label retrieval for ilastik version>1.3.2 (#3)
Browse files Browse the repository at this point in the history
* fixed bug for label retrieval to support ilastik version >1.3.2

* updated version to 0.0.7 and added support for empty label matrices

* deleted some comments

* formatted test file
  • Loading branch information
cmohl2013 authored Apr 16, 2020
1 parent 187caaa commit a8f3bfa
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 42 deletions.
98 changes: 86 additions & 12 deletions pyilastik/ilastik_storage_version_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,27 @@ def _get_slices_for(p_pos, q_pos, p_shape, q_shape):
return [slice(start, stop) for start, stop in q_slice]


def normalize_dim_order(dim_order, data=None, reverse=False):
'''
transpose tile data to dimension order zyxc or yxc
'''
n_dims = len(dim_order)
assert n_dims in [3, 4]

ref_order = 'zyxc'
if n_dims == 3:
ref_order = 'yxc'

mapping = tuple([dim_order.find(k) for k in ref_order])
if reverse:
mapping = tuple([ref_order.find(k) for k in dim_order])
if data is None:
return mapping

assert n_dims == len(data.shape)
return np.transpose(data, mapping)


class IlastikStorageVersion01(object):

def __init__(self, h5_handle, image_path=None, prediction=False,
Expand All @@ -86,9 +107,8 @@ def __init__(self, h5_handle, image_path=None, prediction=False,
assert version == '0.2'

def ilastik_version(self):

version_str = self.f.get('ilastikVersion')[()].decode()
return int(version_str.replace('.','')[:3])
version_str = self.f.get('ilastikVersion')[()].decode()
return int(version_str.replace('.', '')[:3])

def __iter__(self):
'''
Expand Down Expand Up @@ -153,12 +173,16 @@ def __getitem__(self, i):
prediction = None # TODO

# 1st get the (approximate) labeled image size
shape = self.shape_of_labelmatrix(i)
shape = self.shape_of_original_labelmatrix(i)
n_dims = len(shape)

tile_slice = np.array([[0, s] for s in shape])
labels = self.tile_inner(i, tile_slice)

labels = self.tile(i, tile_slice)
if len(labels) > 0:
labels = normalize_dim_order(
self.original_dimension_order(),
data=labels)

msg = 'dimensions of labelmatrix should be 4 (zyxc) or 3 (yxc)'
assert n_dims in [3, 4], msg
Expand All @@ -167,16 +191,12 @@ def __getitem__(self, i):
# add z dimension if missing
labels = np.expand_dims(labels, axis=0)

version = self.ilastik_version()
if self.skip_image:
if version >= 133:
# if version>=1.3.3
labels = np.transpose(labels, (0, 2, 3, 1))
return original_path, (None, labels, prediction)

msg = ('ilastik versions > 1.3.2 are not supported. '
'Set skip_image=True or use older ilastik version.')
assert version < 133, msg
assert self.ilastik_version() < 133, msg

fname = utils.basename(path)
if self.image_path is not None:
Expand All @@ -202,7 +222,6 @@ def __getitem__(self, i):
labels = np.pad(labels, padding, mode='constant',
constant_values=0)


return original_path, (img, labels, prediction)

def n_dims(self, item_index):
Expand All @@ -219,7 +238,51 @@ def n_dims(self, item_index):
else:
return 0

def original_dimension_order(self):
'''
Dimension orders of label matrices depend on dimensionality
of the pixel dataset (zstack vs 2d, monochannel vs multichannel) and
the ilastik version. Dimension order handling was changed in
ilastik version 1.3.3 (both staorage version 01).
'''
s = self.shape_of_original_labelmatrix(0)

assert len(s) in [3, 4]

if len(s) == 4:

assert s[3] == 1 or s[1] == 1

if self.ilastik_version() < 133:
order = 'zyxc'
else:
if s[1] == 1:
order = 'zcyx'
elif s[3] == 1:
order = 'zyxc'
if len(s) == 3:
if self.ilastik_version() < 133:

assert s[2] == 1
order = 'yxc'
else:
assert s[0] == 1 or s[2] == 1
if s[0] == 1:
order = 'cyx'
else:
order = 'yxc'

return order

def shape_of_labelmatrix(self, item_index):
original_shape = self.shape_of_original_labelmatrix(item_index)
dim_mapping = list(normalize_dim_order(
self.original_dimension_order()))

return original_shape[dim_mapping]

def shape_of_original_labelmatrix(self, item_index):
'''
Label matrix shape is retrieved from label data.
Expand Down Expand Up @@ -295,6 +358,18 @@ def load_block_data(self, item_index, block_index):
return block[()]

def tile(self, item_index, tile_slice):

ordering = list(normalize_dim_order(self.original_dimension_order(),
reverse=True))
slices_corr = tile_slice[ordering]

t = self.tile_inner(item_index, slices_corr)
t_corr = normalize_dim_order(self.original_dimension_order(),
reverse=False,
data=t)
return t_corr

def tile_inner(self, item_index, tile_slice):
'''
Order is (Z, Y, X, C) or (Y, X, C) where C size of C dimension
is always 1 (only one label channel implemented, i.e.
Expand All @@ -320,7 +395,6 @@ def tile(self, item_index, tile_slice):
p_slice = _get_slices_for(pos_q, pos_p, shape_q, shape_p)

labels_q[q_slice] = labels_p[p_slice]

return labels_q

@lru_cache(maxsize=None)
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit a8f3bfa

Please sign in to comment.