Skip to content

Commit

Permalink
Adds a circular kernel overlapping the edge to check what is going on…
Browse files Browse the repository at this point in the history
… there. Ensures that the `when_all_extremes` array is copied if it is used to prevent it being modified in place.
  • Loading branch information
MoseleyS committed Sep 7, 2023
1 parent 68f4a3d commit 30e5a76
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion improver/nbhood/nbhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def _do_nbhood_sum(
if size != data.size:
# Determine default array for the extremes around the edges, or everywhere
if isinstance(when_all_extremes, np.ndarray):
untrimmed = when_all_extremes
untrimmed = when_all_extremes.copy()
else:
untrimmed = np.full(data_shape, when_all_extremes)
if size:
Expand Down
21 changes: 21 additions & 0 deletions improver_tests/nbhood/nbhood/test_NeighbourhoodProcessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ def test_basic_circular(self):
result = plugin._calculate_neighbourhood(self.data)
self.assertArrayAlmostEqual(result.data, expected_array)

def test_edge_circular(self):
"""Test the _calculate_neighbourhood method with a circular neighbourhood that crosses the
edge. The zero is now in the left column and the "nearest" method means that this zero
is repeated in the sum, so the final calculation is 3 / 5 instead of 4 / 5."""
data = np.ones_like(self.data)
data[:, :3] = self.data[:, 2:]
expected_array = np.array(
[
[1.0, 1.0, 1.0, 1.0, 1.0],
[0.8, 1.0, 1.0, 1.0, 1.0],
[0.6, 0.8, 1.0, 1.0, 1.0],
[0.8, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0],
]
)
plugin = NeighbourhoodProcessing("circular", self.RADIUS)
plugin.kernel = self.circular_kernel
plugin.nb_size = max(plugin.kernel.shape)
result = plugin._calculate_neighbourhood(data)
self.assertArrayAlmostEqual(result.data, expected_array)

def test_basic_weighted_circular(self):
"""Test the _calculate_neighbourhood method with a
weighted circular neighbourhood."""
Expand Down

0 comments on commit 30e5a76

Please sign in to comment.