Skip to content

Commit

Permalink
Rolls back some of my complicated changes and goes with the simpler a…
Browse files Browse the repository at this point in the history
…pproach instead.
  • Loading branch information
MoseleyS committed Sep 13, 2023
1 parent fb58424 commit bf2fce1
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions improver/nbhood/nbhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,12 @@ def _do_nbhood_sum(
size = data.size
extreme = 0
when_all_extremes = 0
half_nb_size = self.nb_size // 2
half_nb_size = (self.nb_size // 2)
for _extreme, _when_all_extremes in ((0, 0), (1, max_extreme)):
if _when_all_extremes is None:
# We can't take this shortcut
if _when_all_extremes is None or issubclass(data.dtype.type, np.complexfloating):
# We can't take this shortcut if we don't have either a default value/array,
# or the data values are complex, as comparisons with non-complex values are
# tricky.
continue
nonextreme_indices = np.argwhere(data != _extreme)
if nonextreme_indices.size == 0:
Expand All @@ -370,14 +372,6 @@ def _do_nbhood_sum(
nonextreme_indices.min(0),
nonextreme_indices.max(0) + 1,
)
if (
(_ystart - self.nb_size < 0)
or (_ystop + self.nb_size > data_shape[-2])
or (_xstart - self.nb_size < 0)
or (_xstop + self.nb_size > data_shape[-1])
):
# Cannot safely crop this domain with enough buffer to preserve the result.
continue
_ystart = max(0, _ystart - half_nb_size)
_ystop = min(data_shape[0], _ystop + half_nb_size)
_xstart = max(0, _xstart - half_nb_size)
Expand All @@ -393,34 +387,21 @@ def _do_nbhood_sum(
_xstart,
_xstop,
)
square_buffer_kwargs = {"mode": "constant", "constant_value": extreme}
if size != data.size:
# Determine default array for the extremes around the edges, or everywhere
if isinstance(when_all_extremes, np.ndarray):
untrimmed = when_all_extremes.astype(data.dtype)
else:
untrimmed = np.full(data_shape, when_all_extremes, dtype=data.dtype)
if (
self.neighbourhood_method == "square"
and min(ystart, xstart, data_shape[-2] - ystop, data_shape[-1] - xstop)
>= half_nb_size + 1
):
# Trim to the calculated box plus padding
# and rely on boxsum to return the right sized array
data = data[
ystart - half_nb_size - 1 : ystop + half_nb_size,
xstart - half_nb_size - 1 : xstop + half_nb_size,
]
square_buffer_kwargs = {}
else:
# Trim to the calculated box without padding and let boxsum do its padding.
# Circular neighbourhoods will ignore the square_buffer_kwargs.
data = data[ystart:ystop, xstart:xstop]

if size:
# Calculate neighbourhood totals for input data if it isn't all constant.
# Trim to the calculated box
data = data[ystart:ystop, xstart:xstop]

# Calculate neighbourhood totals for input data.
if self.neighbourhood_method == "square":
data = boxsum(data, self.nb_size, **square_buffer_kwargs)
data = boxsum(
data, self.nb_size, mode="constant", constant_values=extreme
)
elif self.neighbourhood_method == "circular":
data = correlate(data, self.kernel, mode="nearest")
else:
Expand Down

0 comments on commit bf2fce1

Please sign in to comment.