diff --git a/gunpowder/nodes/pad.py b/gunpowder/nodes/pad.py index 827825d8..81150835 100644 --- a/gunpowder/nodes/pad.py +++ b/gunpowder/nodes/pad.py @@ -7,6 +7,8 @@ from gunpowder.coordinate import Coordinate from gunpowder.batch_request import BatchRequest +from itertools import product + logger = logging.getLogger(__name__) @@ -136,25 +138,29 @@ def __expand(self, a, from_roi, to_roi, value): pass # handled later else: diff = Coordinate(b.shape) - Coordinate(a.shape) - if len(diff) == 3: # (C Y X) - b[:, : diff[1], diff[2] :] = a[:, : diff[1], :][:, ::-1, :] # Y - b[:, diff[1] :, : diff[2]] = a[:, :, : diff[2]][:, :, ::-1] # X - b[:, : diff[1], : diff[2]] = a[:, : diff[1], : diff[2]][ - :, ::-1, ::-1 - ] - elif len(diff) == 4: # (C Z Y X) - b[:, : diff[1], diff[2] :, diff[3] :] = a[:, : diff[1], :, :][ - :, ::-1, :, : - ] # Z - b[:, diff[1] :, : diff[2], diff[3] :] = a[:, :, : diff[2], :][ - :, :, ::-1, : - ] # Y - b[:, diff[1] :, diff[2] :, : diff[3]] = a[:, :, :, : diff[3]][ - :, :, :, ::-1 - ] # X - b[:, : diff[1], : diff[2], : diff[3]] = a[ - :, : diff[1], : diff[2], : diff[3] - ][:, ::-1, ::-1, ::-1] + slices = [ + ( + (slice(None),) * num_channels + + tuple( + slice(diff[i], None) if d == 1 else slice(None, diff[i]) + for i, d in enumerate(selected_dims) + ), + (slice(None),) * num_channels + + tuple( + slice(None, diff[i]) if d == 1 else slice(None) + for i, d in enumerate(selected_dims) + ), + (slice(None),) * num_channels + + tuple( + slice(None, None, -1) if d == 1 else slice(None) + for i, d in enumerate(selected_dims) + ), + ) + for selected_dims in product((0, 1), repeat=from_roi.dims) + ] + for output_slices, input_slices, rev_slices in slices: + b[output_slices] = a[input_slices][rev_slices] + logger.debug("shifting 'from' by " + str(shift)) a_in_b = from_roi.shift(shift).to_slices() logger.debug("target shape is " + str(b.shape)) diff --git a/tests/cases/pad.py b/tests/cases/pad.py index 2ee8968d..c66332fe 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -68,4 +68,8 @@ def test_output(mode): ) elif mode == "reflect": octants = [100 * 1 * 5 * 10 for _ in range(8)] - assert np.sum(data) == np.sum(octants), data.shape + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + data, + )