diff --git a/iohub/ngff/nodes.py b/iohub/ngff/nodes.py index 32ec1e7..e96a127 100644 --- a/iohub/ngff/nodes.py +++ b/iohub/ngff/nodes.py @@ -969,19 +969,9 @@ def scale(self) -> list[float]: Helper function for scale transform metadata of highest resolution scale. """ - scale = [1] * self.data.ndim - transforms = ( - self.metadata.multiscales[0].datasets[0].coordinate_transformations + return self.get_effective_scale( + self.metadata.multiscales[0].datasets[0].path ) - for trans in transforms: - if trans.type == "scale": - if len(trans.scale) != len(scale): - raise RuntimeError( - f"Length of scale transformation {len(trans.scale)} " - f"does not match data dimension {len(scale)}." - ) - scale = [s1 * s2 for s1, s2 in zip(scale, trans.scale)] - return scale @property def axis_names(self) -> list[str]: @@ -1010,6 +1000,103 @@ def get_axis_index(self, axis_name: str) -> int: """ return self.axis_names.index(axis_name.lower()) + def _get_all_transforms( + self, image: str | Literal["*"] + ) -> list[TransformationMeta]: + """Get all transforms metadata + for one image array or the whole FOV. + + Parameters + ---------- + image : str | Literal["*"] + Name of one image array (e.g. "0") to query, + or "*" for the whole FOV + + Returns + ------- + list[TransformationMeta] + All transforms applicable to this image or FOV. + """ + transforms: list[TransformationMeta] = ( + [ + t + for t in self.metadata.multiscales[ + 0 + ].coordinate_transformations + ] + if self.metadata.multiscales[0].coordinate_transformations + is not None + else [] + ) + if image != "*" and image in self: + for i, dataset_meta in enumerate( + self.metadata.multiscales[0].datasets + ): + if dataset_meta.path == image: + transforms.extend( + self.metadata.multiscales[0] + .datasets[i] + .coordinate_transformations + ) + elif image != "*": + raise ValueError(f"Key {image} not recognized.") + return transforms + + def get_effective_scale( + self, + image: str | Literal["*"], + ) -> list[float]: + """Get the effective coordinate scale metadata + for one image array or the whole FOV. + + Parameters + ---------- + image : str | Literal["*"] + Name of one image array (e.g. "0") to query, + or "*" for the whole FOV + + Returns + ------- + list[float] + A list of floats representing the total scale + for the image or FOV for each axis. + """ + transforms = self._get_all_transforms(image) + + full_scale = np.ones(len(self.axes), dtype=float) + for transform in transforms: + if transform.type == "scale": + full_scale *= np.array(transform.scale) + + return [float(x) for x in full_scale] + + def get_effective_translation( + self, + image: str | Literal["*"], + ) -> TransformationMeta: + """Get the effective coordinate translation metadata + for one image array or the whole FOV. + + Parameters + ---------- + image : str | Literal["*"] + Name of one image array (e.g. "0") to query, + or "*" for the whole FOV + + Returns + ------- + list[float] + A list of floats representing the total translation + for the image or FOV for each axis. + """ + transforms = self._get_all_transforms(image) + full_translation = np.zeros(len(self.axes), dtype=float) + for transform in transforms: + if transform.type == "translation": + full_translation += np.array(transform.translation) + + return [float(x) for x in full_translation] + def set_transform( self, image: str | Literal["*"], diff --git a/tests/ngff/test_ngff.py b/tests/ngff/test_ngff.py index cf0ce4f..1e3ff9e 100644 --- a/tests/ngff/test_ngff.py +++ b/tests/ngff/test_ngff.py @@ -485,6 +485,100 @@ def test_set_transform_image(ch_shape_dtype, arr_name): ] +input_transformations = [ + ([TransformationMeta(type="identity")], []), + ([TransformationMeta(type="scale", scale=(1.0, 2.0, 3.0, 4.0, 5.0))], []), + ( + [ + TransformationMeta( + type="translation", translation=(1.0, 2.0, 3.0, 4.0, 5.0) + ) + ], + [], + ), + ( + [ + TransformationMeta(type="scale", scale=(2.0, 2.0, 2.0, 2.0, 2.0)), + TransformationMeta( + type="translation", translation=(1.0, 1.0, 1.0, 1.0, 1.0) + ), + ], + [ + TransformationMeta(type="scale", scale=(2.0, 2.0, 2.0, 2.0, 2.0)), + TransformationMeta( + type="translation", translation=(1.0, 1.0, 1.0, 1.0, 1.0) + ), + ], + ), +] +target_scales = [ + [1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 1.0, 1.0, 1.0, 1.0], + [4.0, 4.0, 4.0, 4.0, 4.0], +] +target_translations = [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [2.0, 2.0, 2.0, 2.0, 2.0], +] + + +@pytest.mark.parametrize( + "transforms", + [ + (saved, target) + for saved, target in zip(input_transformations, target_scales) + ], +) +@given( + ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(), + arr_name=short_alpha_numeric, +) +def test_get_effective_scale_image(transforms, ch_shape_dtype, arr_name): + """Test `iohub.ngff.Position.get_effective_scale()`""" + (fov_transform, img_transform), expected_scale = transforms + channel_names, shape, dtype = ch_shape_dtype + with TemporaryDirectory() as temp_dir: + store_path = os.path.join(temp_dir, "ome.zarr") + with open_ome_zarr( + store_path, layout="fov", mode="w-", channel_names=channel_names + ) as dataset: + dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype) + dataset.set_transform(image="*", transform=fov_transform) + dataset.set_transform(image=arr_name, transform=img_transform) + scale = dataset.get_effective_scale(image=arr_name) + assert scale == expected_scale + + +@pytest.mark.parametrize( + "transforms", + [ + (saved, target) + for saved, target in zip(input_transformations, target_translations) + ], +) +@given( + ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(), + arr_name=short_alpha_numeric, +) +def test_get_effective_translation_image(transforms, ch_shape_dtype, arr_name): + """Test `iohub.ngff.Position.get_effective_translation()`""" + (fov_transform, img_transform), expected_translation = transforms + channel_names, shape, dtype = ch_shape_dtype + with TemporaryDirectory() as temp_dir: + store_path = os.path.join(temp_dir, "ome.zarr") + with open_ome_zarr( + store_path, layout="fov", mode="w-", channel_names=channel_names + ) as dataset: + dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype) + dataset.set_transform(image="*", transform=fov_transform) + dataset.set_transform(image=arr_name, transform=img_transform) + translation = dataset.get_effective_translation(image=arr_name) + assert translation == expected_translation + + @given( ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(), arr_name=short_alpha_numeric,