Skip to content

Commit

Permalink
Increases shrinking with padding test array sizes to ensure we take t…
Browse files Browse the repository at this point in the history
…he right pathway through the code. Adjusts how the shrinking works when padding is also needed to ensure we always use as much of the original data as we need to.
  • Loading branch information
MoseleyS committed Sep 13, 2023
1 parent 83e4451 commit 7548887
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 19 deletions.
35 changes: 29 additions & 6 deletions improver/nbhood/nbhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def _do_nbhood_sum(
ystop, xstop = data.shape
size = data.size
when_all_extremes = 0
half_nb_size = (self.nb_size // 2) + 2 # rounded up
half_nb_size = self.nb_size // 2
for _extreme, _when_all_extremes in ((0, 0), (1, max_extreme)):
if _when_all_extremes is None:
# We can't take this shortcut
Expand All @@ -369,6 +369,14 @@ def _do_nbhood_sum(
nonextreme_indices.min(0),
nonextreme_indices.max(0) + 1,
)
if (
(_ystart - self.nb_size < 0)
or (_ystop + self.nb_size > data_shape[-2])
or (_xstart - self.nb_size < 0)
or (_xstop + self.nb_size > data_shape[-1])
):
# Cannot safely crop this domain with enough buffer to preserve the result.
continue
_ystart = max(0, _ystart - half_nb_size)
_ystop = min(data_shape[0], _ystop + half_nb_size)
_xstart = max(0, _xstart - half_nb_size)
Expand All @@ -383,19 +391,34 @@ def _do_nbhood_sum(
_xstart,
_xstop,
)
square_buffer_kwargs = {"mode": "constant"}
if size != data.size:
# Determine default array for the extremes around the edges, or everywhere
if isinstance(when_all_extremes, np.ndarray):
untrimmed = when_all_extremes.astype(data.dtype)
else:
untrimmed = np.full(data_shape, when_all_extremes, dtype=data.dtype)
if size:
# Trim to the calculated box
data = data[ystart:ystop, xstart:xstop]
if (
self.neighbourhood_method == "square"
and min(ystart, xstart, data_shape[-2] - ystop, data_shape[-1] - xstop)
>= half_nb_size + 1
):
# Trim to the calculated box plus padding
# and rely on boxsum to return the right sized array
data = data[
ystart - half_nb_size - 1 : ystop + half_nb_size,
xstart - half_nb_size - 1 : xstop + half_nb_size,
]
square_buffer_kwargs = {}
else:
# Trim to the calculated box without padding and let boxsum do its padding.
# Circular neighbourhoods will ignore the square_buffer_kwargs.
data = data[ystart:ystop, xstart:xstop]

# Calculate neighbourhood totals for input data.
if size:
# Calculate neighbourhood totals for input data if it isn't all constant.
if self.neighbourhood_method == "square":
data = boxsum(data, self.nb_size, mode="constant")
data = boxsum(data, self.nb_size, **square_buffer_kwargs)
elif self.neighbourhood_method == "circular":
data = correlate(data, self.kernel, mode="nearest")
else:
Expand Down
26 changes: 13 additions & 13 deletions improver_tests/nbhood/nbhood/test_NeighbourhoodProcessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ def test_annulus_square(self):
"""Test the _calculate_neighbourhood method with a square neighbourhood where the data
are ones with a central block of zeros, which will trigger the array-shrinking optimisation
AND the padding method."""
data = np.ones((8, 8), dtype=self.data.dtype)
data[3:5, 3:5] = 0
expected_array = np.ones((8, 8), dtype=self.data.dtype)
expected_array[3:5, 3:5] = 5 / 9 # centre
expected_array[2::3, 3:5] = 7 / 9 # edges (y)
expected_array[3:5, 2::3] = 7 / 9 # edges (x)
expected_array[2::3, 2::3] = 8 / 9 # corners
data = np.ones((10, 10), dtype=self.data.dtype)
data[4:6, 4:6] = 0
expected_array = np.ones_like(data, dtype=self.data.dtype)
expected_array[4:6, 4:6] = 5 / 9 # centre
expected_array[3:7:3, 4:6] = 7 / 9 # edges (y)
expected_array[4:6, 3:7:3] = 7 / 9 # edges (x)
expected_array[3:7:3, 3:7:3] = 8 / 9 # corners
plugin = NeighbourhoodProcessing("square", self.RADIUS)
plugin.nb_size = self.nbhood_size
result = plugin._calculate_neighbourhood(data)
Expand All @@ -246,12 +246,12 @@ def test_annulus_circular(self):
"""Test the _calculate_neighbourhood method with a circular neighbourhood where the data
are ones with a central block of zeros, which will trigger the array-shrinking optimisation
AND the padding method."""
data = np.ones((8, 8), dtype=self.data.dtype)
data[3:5, 3:5] = 0
expected_array = np.ones((8, 8), dtype=self.data.dtype)
expected_array[3:5, 3:5] = 0.4 # centre
expected_array[2::3, 3:5] = 0.8 # edges (y)
expected_array[3:5, 2::3] = 0.8 # edges (x)
data = np.ones((10, 10), dtype=self.data.dtype)
data[4:6, 4:6] = 0
expected_array = np.ones_like(data, dtype=self.data.dtype)
expected_array[4:6, 4:6] = 0.4 # centre
expected_array[3:7:3, 4:6] = 0.8 # edges (y)
expected_array[4:6, 3:7:3] = 0.8 # edges (x)
plugin = NeighbourhoodProcessing("circular", self.RADIUS)
plugin.kernel = self.circular_kernel
plugin.nb_size = self.nbhood_size
Expand Down

0 comments on commit 7548887

Please sign in to comment.