diff --git a/tests/cases/intensity_augment.py b/tests/cases/intensity_augment.py index fd8d21cb..c5f54c57 100644 --- a/tests/cases/intensity_augment.py +++ b/tests/cases/intensity_augment.py @@ -9,21 +9,47 @@ Roi, build, ) +import pytest from .helper_sources import ArraySource -def test_shift(): +@pytest.mark.parametrize("slab", [None, (1, -1, -1)]) +@pytest.mark.parametrize("z_section_wise", [None, True]) +def test_shift(slab, z_section_wise): raw_key = ArrayKey("RAW") raw_spec = ArraySpec( roi=Roi((0, 0, 0), (10, 10, 10)), voxel_size=(1, 1, 1), dtype=np.float32 ) - raw_data = np.zeros(raw_spec.roi.shape / raw_spec.voxel_size, dtype=np.float32) + raw_data = np.random.randn(*(raw_spec.roi.shape / raw_spec.voxel_size)).astype( + np.float32 + ) raw_array = Array(raw_data, raw_spec) - pipeline = ArraySource(raw_key, raw_array) + IntensityAugment( - raw_key, scale_min=0, scale_max=0, shift_min=0.5, shift_max=0.5 - ) + if z_section_wise is not None and slab is not None: + with pytest.raises(AssertionError): + pipeline = ArraySource(raw_key, raw_array) + IntensityAugment( + raw_key, + scale_min=0, + scale_max=0, + shift_min=0.5, + shift_max=0.5, + clip=False, + z_section_wise=z_section_wise, + slab=slab, + ) + return + else: + pipeline = ArraySource(raw_key, raw_array) + IntensityAugment( + raw_key, + scale_min=0, + scale_max=0, + shift_min=0.5, + shift_max=0.5, + clip=False, + z_section_wise=z_section_wise, + slab=slab, + ) request = BatchRequest() request.add(raw_key, (10, 10, 10)) @@ -32,5 +58,12 @@ def test_shift(): batch = pipeline.request_batch(request) x = batch.arrays[raw_key].data + + # subtract mean of unshifted data since intensity augment + # scales intensity from the mean + if z_section_wise is not None or slab is not None: + x -= np.mean(raw_data, axis=(1, 2), keepdims=True) + else: + x -= np.mean(raw_data) assert np.isclose(x.min(), 0.5) assert np.isclose(x.max(), 0.5)