Skip to content

Commit

Permalink
Simplified type casting and removes unreachable lines.
Browse files Browse the repository at this point in the history
  • Loading branch information
MoseleyS committed Sep 7, 2023
1 parent 046616a commit 68f4a3d
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions improver/nbhood/nbhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,9 @@ def _calculate_neighbourhood(
if self.sum_only:
area_sum = None
else:
area_sum = self._do_nbhood_sum(valid_data_mask, np.int64)
area_sum = self._do_nbhood_sum(valid_data_mask)
# Where data are all ones in nbhood, result will be same as area_sum
data = self._do_nbhood_sum(data, out_data_dtype, max_extreme=area_sum)
data = self._do_nbhood_sum(data, max_extreme=area_sum)

if not self.sum_only:
with np.errstate(divide="ignore", invalid="ignore"):
Expand All @@ -325,18 +325,15 @@ def _calculate_neighbourhood(
# For points where all data in the neighbourhood is masked,
# set result to nan
data[area_sum == 0] = np.nan
data = data.clip(min_val, max_val).astype(out_data_dtype)
data = data.clip(min_val, max_val)

if self.re_mask:
data = np.ma.masked_array(data, data_mask, copy=False)

return data
return data.astype(out_data_dtype)

def _do_nbhood_sum(
self,
data: np.ndarray,
out_data_dtype: Type,
max_extreme: Optional[Union[int, np.ndarray]] = None,
self, data: np.ndarray, max_extreme: Optional[Union[int, np.ndarray]] = None,
) -> np.ndarray:
"""Calculate the sum-in-area from an array.
As this can be expensive, the method first checks for the extreme cases where the data are:
Expand Down Expand Up @@ -386,12 +383,8 @@ def _do_nbhood_sum(
# Determine default array for the extremes around the edges, or everywhere
if isinstance(when_all_extremes, np.ndarray):
untrimmed = when_all_extremes
elif when_all_extremes is None:
raise NotImplementedError(
"Don't know what to do when default is None. Shouldn't get here."
)
else:
untrimmed = np.full(data_shape, when_all_extremes, dtype=out_data_dtype)
untrimmed = np.full(data_shape, when_all_extremes)
if size:
# Trim to the calculated box
data = data[ystart:ystop, xstart:xstop]
Expand All @@ -408,7 +401,7 @@ def _do_nbhood_sum(
if data.shape != data_shape:
untrimmed[ystart:ystop, xstart:xstop] = data
data = untrimmed
return data.astype(out_data_dtype)
return data

def process(self, cube: Cube, mask_cube: Optional[Cube] = None) -> Cube:
"""
Expand Down

0 comments on commit 68f4a3d

Please sign in to comment.