Skip to content

Commit

Permalink
replace custom padding code with np.pad
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Jan 2, 2024
1 parent c6928bd commit 0fb29c8
Showing 1 changed file with 6 additions and 38 deletions.
44 changes: 6 additions & 38 deletions gunpowder/nodes/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,43 +127,11 @@ def __expand(self, a, from_roi, to_roi, value):
)

num_channels = len(a.shape) - from_roi.dims
channel_shapes = a.shape[:num_channels]
b = np.zeros(channel_shapes + to_roi.shape, dtype=a.dtype)
shift = -to_roi.offset
lower_pad = from_roi.begin - to_roi.begin
upper_pad = to_roi.end - from_roi.end
pad_width = [(0, 0)] * num_channels + list(zip(lower_pad, upper_pad))
if self.mode == "constant":
if value != 0:
b[:] = value
padded = np.pad(a, pad_width, "constant", constant_values=value)
elif self.mode == "reflect":
if a.shape == b.shape:
pass # handled later
else:
diff = Coordinate(b.shape) - Coordinate(a.shape)
slices = [
(
(slice(None),) * num_channels
+ tuple(
slice(diff[i], None) if d == 1 else slice(None, diff[i])
for i, d in enumerate(selected_dims)
),
(slice(None),) * num_channels
+ tuple(
slice(None, diff[i]) if d == 1 else slice(None)
for i, d in enumerate(selected_dims)
),
(slice(None),) * num_channels
+ tuple(
slice(None, None, -1) if d == 1 else slice(None)
for i, d in enumerate(selected_dims)
),
)
for selected_dims in product((0, 1), repeat=from_roi.dims)
]
for output_slices, input_slices, rev_slices in slices:
b[output_slices] = a[input_slices][rev_slices]

logger.debug("shifting 'from' by " + str(shift))
a_in_b = from_roi.shift(shift).to_slices()
logger.debug("target shape is " + str(b.shape))
logger.debug("target slice is " + str(a_in_b))
b[(slice(None),) * num_channels + a_in_b] = a
return b
padded = np.pad(a, pad_width, "reflect")
return padded

0 comments on commit 0fb29c8

Please sign in to comment.