Skip to content

Commit

Permalink
pass the fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Dec 19, 2023
1 parent a7027c6 commit 3782525
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
44 changes: 25 additions & 19 deletions gunpowder/nodes/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from gunpowder.coordinate import Coordinate
from gunpowder.batch_request import BatchRequest

from itertools import product

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -136,25 +138,29 @@ def __expand(self, a, from_roi, to_roi, value):
pass # handled later
else:
diff = Coordinate(b.shape) - Coordinate(a.shape)
if len(diff) == 3: # (C Y X)
b[:, : diff[1], diff[2] :] = a[:, : diff[1], :][:, ::-1, :] # Y
b[:, diff[1] :, : diff[2]] = a[:, :, : diff[2]][:, :, ::-1] # X
b[:, : diff[1], : diff[2]] = a[:, : diff[1], : diff[2]][
:, ::-1, ::-1
]
elif len(diff) == 4: # (C Z Y X)
b[:, : diff[1], diff[2] :, diff[3] :] = a[:, : diff[1], :, :][
:, ::-1, :, :
] # Z
b[:, diff[1] :, : diff[2], diff[3] :] = a[:, :, : diff[2], :][
:, :, ::-1, :
] # Y
b[:, diff[1] :, diff[2] :, : diff[3]] = a[:, :, :, : diff[3]][
:, :, :, ::-1
] # X
b[:, : diff[1], : diff[2], : diff[3]] = a[
:, : diff[1], : diff[2], : diff[3]
][:, ::-1, ::-1, ::-1]
slices = [
(
(slice(None),) * num_channels
+ tuple(
slice(diff[i], None) if d == 1 else slice(None, diff[i])
for i, d in enumerate(selected_dims)
),
(slice(None),) * num_channels
+ tuple(
slice(None, diff[i]) if d == 1 else slice(None)
for i, d in enumerate(selected_dims)
),
(slice(None),) * num_channels
+ tuple(
slice(None, None, -1) if d == 1 else slice(None)
for i, d in enumerate(selected_dims)
),
)
for selected_dims in product((0, 1), repeat=from_roi.dims)
]
for output_slices, input_slices, rev_slices in slices:
b[output_slices] = a[input_slices][rev_slices]

logger.debug("shifting 'from' by " + str(shift))
a_in_b = from_roi.shift(shift).to_slices()
logger.debug("target shape is " + str(b.shape))
Expand Down
6 changes: 5 additions & 1 deletion tests/cases/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,8 @@ def test_output(mode):
)
elif mode == "reflect":
octants = [100 * 1 * 5 * 10 for _ in range(8)]
assert np.sum(data) == np.sum(octants), data.shape
assert np.sum(data) == np.sum(octants), (
np.sum(data),
np.sum(octants),
data,
)

0 comments on commit 3782525

Please sign in to comment.