Skip to content

Commit

Permalink
Reduces size of data array handled in nbhood to exclude rows and colu…
Browse files Browse the repository at this point in the history
…mns of zeros.
  • Loading branch information
MoseleyS committed Sep 6, 2023
1 parent cd2c901 commit b1519e5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
30 changes: 28 additions & 2 deletions improver/nbhood/nbhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ def _calculate_neighbourhood(
Array containing the smoothed field after the
neighbourhood method has been applied.
"""
# Determine the smallest box containing all non-zero values with a neighbourhood-sized
# buffer and quit if there are none.
data_shape = data.shape
nonzero_indices = np.argwhere(data)
if nonzero_indices.size == 0:
# No non-zero values, so just return data
return data
(ystart, xstart), (ystop, xstop) = (
nonzero_indices.min(0),
nonzero_indices.max(0) + 1,
)
ystart = max(0, ystart - self.nb_size)
ystop = min(data_shape[0], ystop + self.nb_size)
xstart = max(0, xstart - self.nb_size)
xstop = min(data_shape[1], xstop + self.nb_size)

if not self.sum_only:
min_val = np.nanmin(data)
max_val = np.nanmax(data)
Expand All @@ -307,6 +323,11 @@ def _calculate_neighbourhood(
valid_data_mask = np.ones(data.shape, dtype=np.int64)
valid_data_mask[data_mask] = 0
data[data_mask] = 0

# Trim to the calculated box
data = data[ystart:ystop, xstart:xstop]
valid_data_mask = valid_data_mask[ystart:ystop, xstart:xstop]

# Calculate neighbourhood totals for input data.
if self.neighbourhood_method == "square":
data = boxsum(data, self.nb_size, mode="constant")
Expand Down Expand Up @@ -336,6 +357,12 @@ def _calculate_neighbourhood(
data_dtype = np.float32
data = data.astype(data_dtype)

# Expand data to the full size again
if data.shape != data_shape:
untrimmed = np.zeros(data_shape, dtype=data_dtype)
untrimmed[ystart:ystop, xstart:xstop] = data
data = untrimmed

if self.re_mask:
data = np.ma.masked_array(data, data_mask, copy=False)

Expand Down Expand Up @@ -378,8 +405,7 @@ def process(self, cube: Cube, mask_cube: Optional[Cube] = None) -> Cube:
grid_cells = distance_to_number_of_grid_cells(cube, self.radius)
if self.neighbourhood_method == "circular":
self.kernel = circular_kernel(grid_cells, self.weighted_mode)
elif self.neighbourhood_method == "square":
self.nb_size = 2 * grid_cells + 1
self.nb_size = 2 * grid_cells + 1

try:
mask_cube_data = mask_cube.data
Expand Down
2 changes: 1 addition & 1 deletion improver/nbhood/use_nbhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def process(self, cube: Cube, mask_cube: Cube) -> Cube:
prev_x_y_slice = x_y_slice

cube_slices = iris.cube.CubeList([])
# Apply each mask in in mask_cube to the 2D input slice.
# Apply each mask in mask_cube to the 2D input slice.
for mask_slice in mask_cube.slices_over(self.coord_for_masking):
output_cube = plugin(x_y_slice, mask_cube=mask_slice)
coord_object = mask_slice.coord(self.coord_for_masking).copy()
Expand Down
4 changes: 4 additions & 0 deletions improver_tests/nbhood/nbhood/test_NeighbourhoodProcessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_basic_circular(self):
)
plugin = NeighbourhoodProcessing("circular", self.RADIUS)
plugin.kernel = self.circular_kernel
plugin.nb_size = self.nbhood_size
result = plugin._calculate_neighbourhood(self.data)
self.assertArrayAlmostEqual(result.data, expected_array)

Expand All @@ -165,6 +166,7 @@ def test_basic_weighted_circular(self):
)
plugin = NeighbourhoodProcessing("circular", self.RADIUS)
plugin.kernel = self.weighted_circular_kernel
plugin.nb_size = self.nbhood_size
result = plugin._calculate_neighbourhood(self.data)
self.assertArrayAlmostEqual(result.data, expected_array)

Expand Down Expand Up @@ -199,6 +201,7 @@ def test_basic_circular_sum(self):
)
plugin = NeighbourhoodProcessing("circular", self.RADIUS, sum_only=True)
plugin.kernel = self.circular_kernel
plugin.nb_size = self.nbhood_size
result = plugin._calculate_neighbourhood(self.data)
self.assertArrayAlmostEqual(result.data, expected_array)

Expand Down Expand Up @@ -229,6 +232,7 @@ def test_masked_array_re_mask_true_circular(self):
input_data = np.ma.masked_where(self.mask == 0, self.data_for_masked_tests)
plugin = NeighbourhoodProcessing("circular", self.RADIUS)
plugin.kernel = self.circular_kernel
plugin.nb_size = self.nbhood_size
result = plugin._calculate_neighbourhood(input_data)

self.assertArrayAlmostEqual(result.data, expected_array)
Expand Down

0 comments on commit b1519e5

Please sign in to comment.