From 961995ca5458a524955ce410dfff92607c38d8d1 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 23 Sep 2024 08:40:23 -0700 Subject: [PATCH] -attempt to fix apply_transform_test. @ieivanov revert if needed --- iohub/ngff/utils.py | 16 ++-- tests/ngff/test_ngff_utils.py | 152 +++++++++++++++++++++++----------- 2 files changed, 111 insertions(+), 57 deletions(-) diff --git a/iohub/ngff/utils.py b/iohub/ngff/utils.py index 84d6ec65..60b813e0 100644 --- a/iohub/ngff/utils.py +++ b/iohub/ngff/utils.py @@ -184,14 +184,14 @@ def apply_transform_to_czyx_and_save( """ # TODO: temporary fix to slumkit issue - if _is_nested(input_channel_indices): - input_channel_indices = [ - int(x) for x in input_channel_indices if x.isdigit() - ] - if _is_nested(output_channel_indices): - output_channel_indices = [ - int(x) for x in output_channel_indices if x.isdigit() - ] + # if _is_nested(input_channel_indices): + # input_channel_indices = [ + # int(x) for x in input_channel_indices if x.isdigit() + # ] + # if _is_nested(output_channel_indices): + # output_channel_indices = [ + # int(x) for x in output_channel_indices if x.isdigit() + # ] # Check if input_time_indices should be added to the func kwargs # This is needed when a different processing is needed for each time point, diff --git a/tests/ngff/test_ngff_utils.py b/tests/ngff/test_ngff_utils.py index 8f1512ad..f9c92148 100644 --- a/tests/ngff/test_ngff_utils.py +++ b/tests/ngff/test_ngff_utils.py @@ -218,6 +218,68 @@ def plate_setup(draw): return position_keys, channel_names, shape, chunks, scale, dtype +@st.composite +def apply_transform_czyx_setup(draw): + """ + Composite strategy to generate plate setup + along with valid channel and time indices + Returns + ------- + Tuple containing: + - position_keys + - channel_names + - shape + - chunks + - scale + - dtype + - channel_indices + - time_indices + """ + # Generate plate setup parameters + position_keys, channel_names, shape, chunks, scale, dtype = draw( + plate_setup() + ) + T, C = shape[:2] + + # Define a helper strategy to generate channel indices based on C + channel_indices_strategy = st.one_of( + st.builds( + slice, + st.integers(min_value=0, max_value=0), + st.integers(min_value=1, max_value=C), + st.just(1), + ), + st.lists( + st.integers(min_value=0, max_value=C - 1), + min_size=1, + max_size=min(3, C), + ), + ) + + time_indices_strategy = st.one_of( + st.lists( + st.integers(min_value=0, max_value=T - 1), + min_size=1, + max_size=min(3, T), + ), + ) + + # Generate input and output channel indices based on C + channel_indices = draw(channel_indices_strategy) + time_indices = draw(time_indices_strategy) + + return ( + position_keys, + channel_names, + shape, + chunks, + scale, + dtype, + channel_indices, + time_indices, + ) + + @st.composite def process_single_position_setup(draw): """ @@ -233,21 +295,20 @@ def process_single_position_setup(draw): - chunks - scale - dtype - - input_channel_indices - - output_channel_indices - - input_time_indices - - output_time_indices + - channel_indices + - time_indices """ # Generate plate setup parameters position_keys, channel_names, shape, chunks, scale, dtype = draw( plate_setup() ) + T, C = shape[:2] # Define a helper strategy to generate channel indices based on C channel_indices_strategy = st.one_of( st.none(), - st.lists(st.slices(size=C), min_size=1, max_size=min(3, C)), + st.lists(st.slices(size=C + 1), min_size=1, max_size=min(3, C)), st.lists( st.lists(st.integers(min_value=0, max_value=C - 1)), min_size=1, @@ -265,10 +326,8 @@ def process_single_position_setup(draw): ) # Generate input and output channel indices based on C - input_channel_indices = draw(channel_indices_strategy) - output_channel_indices = draw(channel_indices_strategy) - input_time_indices = draw(time_indices_strategy) - output_time_indices = draw(time_indices_strategy) + channel_indices = draw(channel_indices_strategy) + time_indices = draw(time_indices_strategy) return ( position_keys, @@ -277,10 +336,8 @@ def process_single_position_setup(draw): chunks, scale, dtype, - input_channel_indices, - output_channel_indices, - input_time_indices, - output_time_indices, + channel_indices, + time_indices, ) @@ -319,7 +376,10 @@ def verify_transformation( output_store_path: Path, position_key_tuple: Tuple[str, str, str], shape: Tuple[int, ...], + time_indices: list[int], + channel_indices: list[int], transform_func, + **kwargs, ): with open_ome_zarr(input_store_path) as input_dataset, open_ome_zarr( output_store_path @@ -329,18 +389,17 @@ def verify_transformation( output_position = output_dataset[position_key_tuple] T, C, Z, Y, X = shape - for t in range(T): - for c in range(C): - input_data = input_position.data.oindex[t, c][:] - output_data = output_position.data.oindex[t, c][:] - - expected_data = transform_func(input_data) - np.testing.assert_array_almost_equal( - output_data, - expected_data, - err_msg=f"Mismatch in position \ - {position_key_tuple}, time {t}, channel {c}.", - ) + + for t_in in time_indices: + input_data = input_position.data.oindex[t_in, channel_indices][:] + output_data = output_position.data.oindex[t_in, channel_indices][:] + expected_data = transform_func(input_data, **kwargs) + np.testing.assert_array_almost_equal( + output_data, + expected_data, + err_msg=f"Mismatch in position \ + {position_key_tuple}, time {t_in}", + ) @given( @@ -414,10 +473,10 @@ def test_create_empty_plate(plate_setup, extra_channels): @given( - setup=process_single_position_setup(), + setup=apply_transform_czyx_setup(), constant=st.integers(min_value=1, max_value=5), ) -@settings(max_examples=3) +@settings(max_examples=1, deadline=1200) def test_apply_transform_to_zyx_and_save(setup, constant): ( position_keys, @@ -426,17 +485,10 @@ def test_apply_transform_to_zyx_and_save(setup, constant): chunks, scale, dtype, - input_channel_indices, - output_channel_indices, - input_time_indices, - output_time_indices, + channel_indices, + time_indices, ) = setup - input_channel_indices = input_channel_indices[0] - output_channel_indices = output_channel_indices[0] - input_time_indices = input_time_indices[0] - output_time_indices = output_time_indices[0] - # Use the enhanced context manager to get both input and output store paths with _temp_ome_zarr_stores( position_keys=position_keys, @@ -458,15 +510,15 @@ def test_apply_transform_to_zyx_and_save(setup, constant): *position_key_tuple ) - for t_in, t_out in zip(input_time_indices, output_time_indices): + for t_in in time_indices: apply_transform_to_czyx_and_save( func=dummy_transform, input_position_path=Path(input_position_path), output_position_path=Path(output_position_path), - input_channel_indices=input_channel_indices, - output_channel_indices=output_channel_indices, + input_channel_indices=channel_indices, + output_channel_indices=channel_indices, input_time_index=t_in, - output_time_index=t_out, + output_time_index=t_in, **kwargs, ) @@ -476,16 +528,19 @@ def test_apply_transform_to_zyx_and_save(setup, constant): output_store_path, position_key_tuple, shape, + time_indices, + channel_indices, dummy_transform, + **kwargs, ) -@settings(max_examples=3) @given( setup=process_single_position_setup(), constant=st.integers(min_value=1, max_value=5), num_processes=st.integers(min_value=1, max_value=4), ) +@settings(max_examples=3, deadline=1200) def test_process_single_position(setup, constant, num_processes): ( position_keys, @@ -494,10 +549,8 @@ def test_process_single_position(setup, constant, num_processes): chunks, scale, dtype, - input_channel_indices, - output_channel_indices, - input_time_indices, - output_time_indices, + channel_indices, + time_indices, ) = setup # Use the enhanced context manager to get both input and output store paths @@ -528,10 +581,10 @@ def test_process_single_position(setup, constant, num_processes): func=dummy_transform, input_position_path=input_position_path, output_position_path=output_position_path, - input_channel_indices=input_channel_indices, - output_channel_indices=output_channel_indices, - input_time_indices=input_time_indices, - output_time_indices=output_time_indices, + input_channel_indices=channel_indices, + output_channel_indices=channel_indices, + input_time_indices=time_indices, + output_time_indices=time_indices, num_processes=num_processes, **kwargs, ) @@ -543,4 +596,5 @@ def test_process_single_position(setup, constant, num_processes): position_key_tuple, shape, dummy_transform, + **kwargs, )