Skip to content

Commit

Permalink
add support for the get_transform analog to set_transform (#264)
Browse files Browse the repository at this point in the history
* add support for getting the aggregated scale/transform for individual arrays

* Revert "formatting changes"

This reverts the formatting changes that were accidentally included with commit  0c21a73.

* address ziwen's comments

* black formatting changes

* update test_get_transform_image test name

* factor out retrieving the list of all applicable transforms and improve docs

* use get_effective_scale to simplify the scale property

* use the first array's path in `scale` property to get the highest resolution image metadata

* Separate test into 2, and cover cases where both the fov and image have scale and translation transforms

* Return effective scale and transform as simple lists of floats instead of TransformationMeta

* Return plain floats instead of numpy floats

* fix test docstring

* Handle case where `self.metadata.multiscales[0].coordinate_transformations` is `None`
  • Loading branch information
pattonw authored Dec 20, 2024
1 parent 05428df commit c8822bc
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 12 deletions.
111 changes: 99 additions & 12 deletions iohub/ngff/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,19 +969,9 @@ def scale(self) -> list[float]:
Helper function for scale transform metadata of
highest resolution scale.
"""
scale = [1] * self.data.ndim
transforms = (
self.metadata.multiscales[0].datasets[0].coordinate_transformations
return self.get_effective_scale(
self.metadata.multiscales[0].datasets[0].path
)
for trans in transforms:
if trans.type == "scale":
if len(trans.scale) != len(scale):
raise RuntimeError(
f"Length of scale transformation {len(trans.scale)} "
f"does not match data dimension {len(scale)}."
)
scale = [s1 * s2 for s1, s2 in zip(scale, trans.scale)]
return scale

@property
def axis_names(self) -> list[str]:
Expand Down Expand Up @@ -1010,6 +1000,103 @@ def get_axis_index(self, axis_name: str) -> int:
"""
return self.axis_names.index(axis_name.lower())

def _get_all_transforms(
self, image: str | Literal["*"]
) -> list[TransformationMeta]:
"""Get all transforms metadata
for one image array or the whole FOV.
Parameters
----------
image : str | Literal["*"]
Name of one image array (e.g. "0") to query,
or "*" for the whole FOV
Returns
-------
list[TransformationMeta]
All transforms applicable to this image or FOV.
"""
transforms: list[TransformationMeta] = (
[
t
for t in self.metadata.multiscales[
0
].coordinate_transformations
]
if self.metadata.multiscales[0].coordinate_transformations
is not None
else []
)
if image != "*" and image in self:
for i, dataset_meta in enumerate(
self.metadata.multiscales[0].datasets
):
if dataset_meta.path == image:
transforms.extend(
self.metadata.multiscales[0]
.datasets[i]
.coordinate_transformations
)
elif image != "*":
raise ValueError(f"Key {image} not recognized.")
return transforms

def get_effective_scale(
self,
image: str | Literal["*"],
) -> list[float]:
"""Get the effective coordinate scale metadata
for one image array or the whole FOV.
Parameters
----------
image : str | Literal["*"]
Name of one image array (e.g. "0") to query,
or "*" for the whole FOV
Returns
-------
list[float]
A list of floats representing the total scale
for the image or FOV for each axis.
"""
transforms = self._get_all_transforms(image)

full_scale = np.ones(len(self.axes), dtype=float)
for transform in transforms:
if transform.type == "scale":
full_scale *= np.array(transform.scale)

return [float(x) for x in full_scale]

def get_effective_translation(
self,
image: str | Literal["*"],
) -> TransformationMeta:
"""Get the effective coordinate translation metadata
for one image array or the whole FOV.
Parameters
----------
image : str | Literal["*"]
Name of one image array (e.g. "0") to query,
or "*" for the whole FOV
Returns
-------
list[float]
A list of floats representing the total translation
for the image or FOV for each axis.
"""
transforms = self._get_all_transforms(image)
full_translation = np.zeros(len(self.axes), dtype=float)
for transform in transforms:
if transform.type == "translation":
full_translation += np.array(transform.translation)

return [float(x) for x in full_translation]

def set_transform(
self,
image: str | Literal["*"],
Expand Down
94 changes: 94 additions & 0 deletions tests/ngff/test_ngff.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,100 @@ def test_set_transform_image(ch_shape_dtype, arr_name):
]


input_transformations = [
([TransformationMeta(type="identity")], []),
([TransformationMeta(type="scale", scale=(1.0, 2.0, 3.0, 4.0, 5.0))], []),
(
[
TransformationMeta(
type="translation", translation=(1.0, 2.0, 3.0, 4.0, 5.0)
)
],
[],
),
(
[
TransformationMeta(type="scale", scale=(2.0, 2.0, 2.0, 2.0, 2.0)),
TransformationMeta(
type="translation", translation=(1.0, 1.0, 1.0, 1.0, 1.0)
),
],
[
TransformationMeta(type="scale", scale=(2.0, 2.0, 2.0, 2.0, 2.0)),
TransformationMeta(
type="translation", translation=(1.0, 1.0, 1.0, 1.0, 1.0)
),
],
),
]
target_scales = [
[1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 2.0, 3.0, 4.0, 5.0],
[1.0, 1.0, 1.0, 1.0, 1.0],
[4.0, 4.0, 4.0, 4.0, 4.0],
]
target_translations = [
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 2.0, 3.0, 4.0, 5.0],
[2.0, 2.0, 2.0, 2.0, 2.0],
]


@pytest.mark.parametrize(
"transforms",
[
(saved, target)
for saved, target in zip(input_transformations, target_scales)
],
)
@given(
ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
arr_name=short_alpha_numeric,
)
def test_get_effective_scale_image(transforms, ch_shape_dtype, arr_name):
"""Test `iohub.ngff.Position.get_effective_scale()`"""
(fov_transform, img_transform), expected_scale = transforms
channel_names, shape, dtype = ch_shape_dtype
with TemporaryDirectory() as temp_dir:
store_path = os.path.join(temp_dir, "ome.zarr")
with open_ome_zarr(
store_path, layout="fov", mode="w-", channel_names=channel_names
) as dataset:
dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype)
dataset.set_transform(image="*", transform=fov_transform)
dataset.set_transform(image=arr_name, transform=img_transform)
scale = dataset.get_effective_scale(image=arr_name)
assert scale == expected_scale


@pytest.mark.parametrize(
"transforms",
[
(saved, target)
for saved, target in zip(input_transformations, target_translations)
],
)
@given(
ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
arr_name=short_alpha_numeric,
)
def test_get_effective_translation_image(transforms, ch_shape_dtype, arr_name):
"""Test `iohub.ngff.Position.get_effective_translation()`"""
(fov_transform, img_transform), expected_translation = transforms
channel_names, shape, dtype = ch_shape_dtype
with TemporaryDirectory() as temp_dir:
store_path = os.path.join(temp_dir, "ome.zarr")
with open_ome_zarr(
store_path, layout="fov", mode="w-", channel_names=channel_names
) as dataset:
dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype)
dataset.set_transform(image="*", transform=fov_transform)
dataset.set_transform(image=arr_name, transform=img_transform)
translation = dataset.get_effective_translation(image=arr_name)
assert translation == expected_translation


@given(
ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
arr_name=short_alpha_numeric,
Expand Down

0 comments on commit c8822bc

Please sign in to comment.