Skip to content

Commit

Permalink
compatibility with minimal deskew w/ @edyoshikun
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Sep 20, 2024
1 parent 671bf3d commit 3a0887e
Showing 1 changed file with 37 additions and 41 deletions.
78 changes: 37 additions & 41 deletions iohub/ngff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ def create_empty_plate(

def apply_transform_to_czyx_and_save(
func: Callable,
position_key: str,
input_store_path: Path,
output_store_path: Path,
input_position_path: Path,
output_position_path: Path,
input_channel_indices: Union[list[int], slice],
output_channel_indices: Union[list[int], slice],
input_time_index: int,
Expand All @@ -137,12 +136,12 @@ def apply_transform_to_czyx_and_save(
func : Callable
The function to be applied to the data.
Should take a CZYX array and return a transformed CZYX array.
position_key : str
The label of the position to process, e.g. "A/1/0".
input_store_path : Path
The path to input OME-Zarr Store.
output_store_path : Path
The path to output OME-Zarr Store.
input_position_path : Path
The path to input OME-Zarr position store
(e.g., input_store_path.zarr/A/1/0).
output_position_path : Path
The path to output OME-Zarr position store
(e.g., output_store_path.zarr/A/1/0).
input_channel_indices : Union[list[int], slice]
The channel indices to process. Acceptable values:
- Slices: slice(0, 2).
Expand All @@ -163,8 +162,8 @@ def apply_transform_to_czyx_and_save(
Using slices for input_channel_indices:
apply_transform_to_zyx_and_save(
func=some_function,
position=some_position,
output_store_path=Path("/path/to/output"),
input_position_path=Path("/path/to/input.zarr/A/1/0"),
output_position_path=Path("/path/to/output.zarr/A/1/0"),
input_channel_indices=slice(0, 2),
output_channel_indices=[0],
input_time_index=0,
Expand All @@ -174,8 +173,8 @@ def apply_transform_to_czyx_and_save(
Using list for input_channel_indices:
apply_transform_to_zyx_and_save(
func=some_function,
position=some_position,
output_store_path=Path("/path/to/output"),
input_position_path=Path("/path/to/input.zarr/A/1/0"),
output_store_path=Path("/path/to/output.zarr/A/1/0"),
input_channel_indices=[0, 1, 2, 3, 4],
output_channel_indices=[0, 1, 2],
input_time_index=0,
Expand Down Expand Up @@ -203,20 +202,15 @@ def apply_transform_to_czyx_and_save(

# Process CZYX given with the given indices
# if input_channel_indices is not None and len(input_channel_indices) > 0:
click.echo(
f"""Processing t={input_time_index}
and channels {input_channel_indices}"""
)
input_dataset = open_ome_zarr(input_store_path / position_key)
click.echo(f"Processing t={input_time_index}, c={input_channel_indices}")
input_dataset = open_ome_zarr(input_position_path)
czyx_data = input_dataset.data.oindex[
input_time_index, input_channel_indices
]
if not _check_nan_n_zeros(czyx_data):
transformed_czyx = func(czyx_data, **kwargs)
# Write to file
with open_ome_zarr(
output_store_path / position_key, mode="r+"
) as output_dataset:
with open_ome_zarr(output_position_path, mode="r+") as output_dataset:
output_dataset[0].oindex[
output_time_index, output_channel_indices
] = transformed_czyx
Expand All @@ -230,9 +224,8 @@ def apply_transform_to_czyx_and_save(

def process_single_position(
func: Callable,
position_key: str,
input_store_path: Path,
output_store_path: Path,
input_position_path: Path,
output_position_path: Path,
input_channel_indices: Union[list[slice], list[list[int]]] = None,
output_channel_indices: Union[list[slice], list[list[int]]] = None,
input_time_indices: list[int] = None,
Expand All @@ -250,12 +243,12 @@ def process_single_position(
func : CZYX -> CZYX Callable
The function to be applied to the data.
Should take a CZYX array and return a transformed CZYX array.
position_key : str
The label of the position to process, e.g. "A/1/0".
input_position_path : Path
The path to the input OME-Zarr store (e.g., input_store_path.zarr).
output_store_path : Path
The path to the output OME-Zarr store (e.g., output_store_path.zarr).
The path to the input OME-Zarr position store
(e.g., input_store_path.zarr/A/1/0).
output_position_path : Path
The path to the output OME-Zarr position store
(e.g., output_store_path.zarr/A/1/0).
input_time_indices : list[int], optional
If not provided, all timepoints will be processed.
output_time_indices : list[int], optional
Expand Down Expand Up @@ -316,18 +309,20 @@ def process_single_position(
"""
# Function to be applied
click.echo(f"Function to be applied: \t{func}")
click.echo(f"Input data path:\t{input_store_path}")
click.echo(f"Output data path:\t{output_store_path}")
click.echo(f"Input data path:\t{input_position_path}")
click.echo(f"Output data path:\t{output_position_path}")

# Get the reader
with open_ome_zarr(input_store_path / position_key) as input_dataset:
with open_ome_zarr(input_position_path) as input_dataset:
input_data_shape = input_dataset.data.shape

# Process time indices
if input_time_indices is None:
input_time_indices = list(range(input_data_shape[0]))
output_time_indices = input_time_indices
assert input_time_indices is list, "input_time_indices must be a list"
assert (
type(input_time_indices) is list
), "input_time_indices must be a list"
if output_time_indices is None:
output_time_indices = input_time_indices

Expand All @@ -336,7 +331,7 @@ def process_single_position(
input_channel_indices = [[c] for c in range(input_data_shape[1])]
output_channel_indices = input_channel_indices
assert (
input_channel_indices is list
type(input_channel_indices) is list
), "input_channel_indices must be a list"
if output_channel_indices is None:
output_channel_indices = input_channel_indices
Expand All @@ -352,28 +347,29 @@ def process_single_position(

# Write extra metadata to the output store
extra_metadata = kwargs.pop("extra_metadata", None)
with open_ome_zarr(output_store_path, mode="r+") as output_dataset:
with open_ome_zarr(output_position_path, mode="r+") as output_dataset:
output_dataset.zattrs["extra_metadata"] = extra_metadata

# Loop through (T, C), applying transform and writing as we go
iterable = itertools.product(
zip(input_channel_indices, output_channel_indices),
zip(input_time_indices, output_time_indices),
)
partial_apply_transform_to_zyx_and_save = partial(
flat_iterable = ((a[0], a[1], b[0], b[1]) for a, b in iterable)

partial_apply_transform_to_czyx_and_save = partial(
apply_transform_to_czyx_and_save,
func,
position_key,
input_store_path,
output_store_path,
input_position_path,
output_position_path,
**kwargs,
)

click.echo(f"\nStarting multiprocess pool with {num_processes} processes")
with mp.Pool(num_processes) as p:
p.starmap(
partial_apply_transform_to_zyx_and_save,
iterable,
partial_apply_transform_to_czyx_and_save,
flat_iterable,
)


Expand Down

0 comments on commit 3a0887e

Please sign in to comment.