Skip to content

Commit

Permalink
-attempt to fix apply_transform_test. @ieivanov revert if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Sep 23, 2024
1 parent 9ee6acb commit 961995c
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 57 deletions.
16 changes: 8 additions & 8 deletions iohub/ngff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@ def apply_transform_to_czyx_and_save(
"""

# TODO: temporary fix to slumkit issue
if _is_nested(input_channel_indices):
input_channel_indices = [
int(x) for x in input_channel_indices if x.isdigit()
]
if _is_nested(output_channel_indices):
output_channel_indices = [
int(x) for x in output_channel_indices if x.isdigit()
]
# if _is_nested(input_channel_indices):
# input_channel_indices = [
# int(x) for x in input_channel_indices if x.isdigit()
# ]
# if _is_nested(output_channel_indices):
# output_channel_indices = [
# int(x) for x in output_channel_indices if x.isdigit()
# ]

# Check if input_time_indices should be added to the func kwargs
# This is needed when a different processing is needed for each time point,
Expand Down
152 changes: 103 additions & 49 deletions tests/ngff/test_ngff_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,68 @@ def plate_setup(draw):
return position_keys, channel_names, shape, chunks, scale, dtype


@st.composite
def apply_transform_czyx_setup(draw):
"""
Composite strategy to generate plate setup
along with valid channel and time indices
Returns
-------
Tuple containing:
- position_keys
- channel_names
- shape
- chunks
- scale
- dtype
- channel_indices
- time_indices
"""
# Generate plate setup parameters
position_keys, channel_names, shape, chunks, scale, dtype = draw(
plate_setup()
)
T, C = shape[:2]

# Define a helper strategy to generate channel indices based on C
channel_indices_strategy = st.one_of(
st.builds(
slice,
st.integers(min_value=0, max_value=0),
st.integers(min_value=1, max_value=C),
st.just(1),
),
st.lists(
st.integers(min_value=0, max_value=C - 1),
min_size=1,
max_size=min(3, C),
),
)

time_indices_strategy = st.one_of(
st.lists(
st.integers(min_value=0, max_value=T - 1),
min_size=1,
max_size=min(3, T),
),
)

# Generate input and output channel indices based on C
channel_indices = draw(channel_indices_strategy)
time_indices = draw(time_indices_strategy)

return (
position_keys,
channel_names,
shape,
chunks,
scale,
dtype,
channel_indices,
time_indices,
)


@st.composite
def process_single_position_setup(draw):
"""
Expand All @@ -233,21 +295,20 @@ def process_single_position_setup(draw):
- chunks
- scale
- dtype
- input_channel_indices
- output_channel_indices
- input_time_indices
- output_time_indices
- channel_indices
- time_indices
"""
# Generate plate setup parameters
position_keys, channel_names, shape, chunks, scale, dtype = draw(
plate_setup()
)

T, C = shape[:2]

# Define a helper strategy to generate channel indices based on C
channel_indices_strategy = st.one_of(
st.none(),
st.lists(st.slices(size=C), min_size=1, max_size=min(3, C)),
st.lists(st.slices(size=C + 1), min_size=1, max_size=min(3, C)),
st.lists(
st.lists(st.integers(min_value=0, max_value=C - 1)),
min_size=1,
Expand All @@ -265,10 +326,8 @@ def process_single_position_setup(draw):
)

# Generate input and output channel indices based on C
input_channel_indices = draw(channel_indices_strategy)
output_channel_indices = draw(channel_indices_strategy)
input_time_indices = draw(time_indices_strategy)
output_time_indices = draw(time_indices_strategy)
channel_indices = draw(channel_indices_strategy)
time_indices = draw(time_indices_strategy)

return (
position_keys,
Expand All @@ -277,10 +336,8 @@ def process_single_position_setup(draw):
chunks,
scale,
dtype,
input_channel_indices,
output_channel_indices,
input_time_indices,
output_time_indices,
channel_indices,
time_indices,
)


Expand Down Expand Up @@ -319,7 +376,10 @@ def verify_transformation(
output_store_path: Path,
position_key_tuple: Tuple[str, str, str],
shape: Tuple[int, ...],
time_indices: list[int],
channel_indices: list[int],
transform_func,
**kwargs,
):
with open_ome_zarr(input_store_path) as input_dataset, open_ome_zarr(
output_store_path
Expand All @@ -329,18 +389,17 @@ def verify_transformation(
output_position = output_dataset[position_key_tuple]

T, C, Z, Y, X = shape
for t in range(T):
for c in range(C):
input_data = input_position.data.oindex[t, c][:]
output_data = output_position.data.oindex[t, c][:]

expected_data = transform_func(input_data)
np.testing.assert_array_almost_equal(
output_data,
expected_data,
err_msg=f"Mismatch in position \
{position_key_tuple}, time {t}, channel {c}.",
)

for t_in in time_indices:
input_data = input_position.data.oindex[t_in, channel_indices][:]
output_data = output_position.data.oindex[t_in, channel_indices][:]
expected_data = transform_func(input_data, **kwargs)
np.testing.assert_array_almost_equal(
output_data,
expected_data,
err_msg=f"Mismatch in position \
{position_key_tuple}, time {t_in}",
)


@given(
Expand Down Expand Up @@ -414,10 +473,10 @@ def test_create_empty_plate(plate_setup, extra_channels):


@given(
setup=process_single_position_setup(),
setup=apply_transform_czyx_setup(),
constant=st.integers(min_value=1, max_value=5),
)
@settings(max_examples=3)
@settings(max_examples=1, deadline=1200)
def test_apply_transform_to_zyx_and_save(setup, constant):
(
position_keys,
Expand All @@ -426,17 +485,10 @@ def test_apply_transform_to_zyx_and_save(setup, constant):
chunks,
scale,
dtype,
input_channel_indices,
output_channel_indices,
input_time_indices,
output_time_indices,
channel_indices,
time_indices,
) = setup

input_channel_indices = input_channel_indices[0]
output_channel_indices = output_channel_indices[0]
input_time_indices = input_time_indices[0]
output_time_indices = output_time_indices[0]

# Use the enhanced context manager to get both input and output store paths
with _temp_ome_zarr_stores(
position_keys=position_keys,
Expand All @@ -458,15 +510,15 @@ def test_apply_transform_to_zyx_and_save(setup, constant):
*position_key_tuple
)

for t_in, t_out in zip(input_time_indices, output_time_indices):
for t_in in time_indices:
apply_transform_to_czyx_and_save(
func=dummy_transform,
input_position_path=Path(input_position_path),
output_position_path=Path(output_position_path),
input_channel_indices=input_channel_indices,
output_channel_indices=output_channel_indices,
input_channel_indices=channel_indices,
output_channel_indices=channel_indices,
input_time_index=t_in,
output_time_index=t_out,
output_time_index=t_in,
**kwargs,
)

Expand All @@ -476,16 +528,19 @@ def test_apply_transform_to_zyx_and_save(setup, constant):
output_store_path,
position_key_tuple,
shape,
time_indices,
channel_indices,
dummy_transform,
**kwargs,
)


@settings(max_examples=3)
@given(
setup=process_single_position_setup(),
constant=st.integers(min_value=1, max_value=5),
num_processes=st.integers(min_value=1, max_value=4),
)
@settings(max_examples=3, deadline=1200)
def test_process_single_position(setup, constant, num_processes):
(
position_keys,
Expand All @@ -494,10 +549,8 @@ def test_process_single_position(setup, constant, num_processes):
chunks,
scale,
dtype,
input_channel_indices,
output_channel_indices,
input_time_indices,
output_time_indices,
channel_indices,
time_indices,
) = setup

# Use the enhanced context manager to get both input and output store paths
Expand Down Expand Up @@ -528,10 +581,10 @@ def test_process_single_position(setup, constant, num_processes):
func=dummy_transform,
input_position_path=input_position_path,
output_position_path=output_position_path,
input_channel_indices=input_channel_indices,
output_channel_indices=output_channel_indices,
input_time_indices=input_time_indices,
output_time_indices=output_time_indices,
input_channel_indices=channel_indices,
output_channel_indices=channel_indices,
input_time_indices=time_indices,
output_time_indices=time_indices,
num_processes=num_processes,
**kwargs,
)
Expand All @@ -543,4 +596,5 @@ def test_process_single_position(setup, constant, num_processes):
position_key_tuple,
shape,
dummy_transform,
**kwargs,
)

0 comments on commit 961995c

Please sign in to comment.