diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index 108e40bd..3dcbf886 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -107,10 +107,11 @@ jobs: pytest - name: Upload Coverage Report - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: file: "tests_coverage_reports/coverage.xml" fail_ci_if_error: true + token: ${{ secrets.CODECOV_TOKEN }} # `build-result` is a workaround to skipped matrix jobs in `build` not being considered "successful", # which can block PR merges if matrix jobs are required status checks. diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 193a22a6..c7253f4f 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -496,11 +496,36 @@ def test_regrid_input_mask(self): output_data = regridder.horizontal("ts", self.coarse_2d_ds) + # np.nan != np.nan, replace with 1e20 + output_data = output_data.fillna(1e20) + + expected_output = np.array( + [ + [1e20] * 4, + [1.0] * 4, + [1.0] * 4, + [1e20] * 4, + ], + dtype=np.float32, + ) + + assert np.all(output_data.ts.values == expected_output) + + @pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning") + def test_regrid_input_mask_unmapped_to_nan(self): + regridder = regrid2.Regrid2Regridder( + self.coarse_2d_ds, self.fine_2d_ds, unmapped_to_nan=False + ) + + self.coarse_2d_ds["mask"] = (("lat", "lon"), [[0, 0], [1, 1], [0, 0]]) + + output_data = regridder.horizontal("ts", self.coarse_2d_ds) + expected_output = np.array( [ [0.0] * 4, - [0.70710677] * 4, - [0.70710677] * 4, + [1.0] * 4, + [1.0] * 4, [0.0] * 4, ], dtype=np.float32, @@ -690,7 +715,7 @@ def test_regrid(self): assert "time_bnds" in output @pytest.mark.parametrize( - "name,value,attr_name", + "name,value,_", [ ("periodic", True, "_periodic"), ("extrap_method", "inverse_dist", "_extrap_method"), @@ -700,14 +725,15 @@ def test_regrid(self): ("ignore_degenerate", False, "_ignore_degenerate"), ], ) - def test_flags(self, name, value, attr_name): + def test_flags(self, name, value, _): ds = self.ds.copy() options = {name: value} regridder = xesmf.XESMFRegridder(ds, self.new_grid, "bilinear", **options) - assert getattr(regridder, attr_name) == value + assert name in regridder._extra_options + assert regridder._extra_options[name] == value def test_no_variable(self): ds = self.ds.copy() diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 3c139c03..823215d7 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import numpy as np import xarray as xr @@ -8,7 +8,13 @@ class Regrid2Regridder(BaseRegridder): - def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: Any): + def __init__( + self, + input_grid: xr.Dataset, + output_grid: xr.Dataset, + unmapped_to_nan=True, + **options: Any, + ): """ Pure python implementation of the regrid2 horizontal regridder from CDMS2's regrid2 module. @@ -47,6 +53,8 @@ def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: A """ super().__init__(input_grid, output_grid, **options) + self._unmapped_to_nan = unmapped_to_nan + def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: """Placeholder for base class.""" raise NotImplementedError() @@ -66,20 +74,31 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: dst_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y") dst_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X") - src_mask = self._input_grid.get("mask", None) + src_mask_da = self._input_grid.get("mask", None) + + # DataArray to np.ndarray, handle error when None + try: + src_mask = src_mask_da.values # type: ignore + except AttributeError: + src_mask = None - # apply source mask to input data - if src_mask is not None: - masked_value = self._input_grid.attrs.get("_FillValue", None) + nan_replace = input_data_var.encoding.get("_FillValue", None) - if masked_value is None: - masked_value = self._input_grid.attrs.get("missing_value", 0.0) + if nan_replace is None: + nan_replace = input_data_var.encoding.get("missing_value", 1e20) - # Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20 - input_data_var = input_data_var.where(src_mask != 0.0, masked_value) + # exclude alternative of NaN values if there are any + input_data_var = input_data_var.where(input_data_var != nan_replace) + # horizontal regrid output_data = _regrid( - input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds + input_data_var, + src_lat_bnds, + src_lon_bnds, + dst_lat_bnds, + dst_lon_bnds, + src_mask, + unmapped_to_nan=self._unmapped_to_nan, ) output_ds = _build_dataset( @@ -101,7 +120,13 @@ def _regrid( src_lon_bnds: np.ndarray, dst_lat_bnds: np.ndarray, dst_lon_bnds: np.ndarray, + src_mask: Optional[np.ndarray], + omitted=None, + unmapped_to_nan=True, ) -> np.ndarray: + if omitted is None: + omitted = np.nan + lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds) lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds) @@ -114,6 +139,11 @@ def _regrid( y_length = len(lat_mapping) x_length = len(lon_mapping) + if src_mask is None: + input_data_shape = input_data.shape + + src_mask = np.ones((input_data_shape[y_index], input_data_shape[x_index])) + other_dims = { x: y for x, y in input_data_var.sizes.items() if x not in (y_name, x_name) } @@ -122,6 +152,7 @@ def _regrid( data_shape = [y_length * x_length] + other_sizes # output data is always float32 in original code output_data = np.zeros(data_shape, dtype=np.float32) + output_mask = np.ones(data_shape, dtype=np.float32) is_2d = input_data_var.ndim <= 2 @@ -129,14 +160,23 @@ def _regrid( # TODO: how common is lon by lat data? may need to reshape for y in range(y_length): y_seg = np.take(input_data, lat_mapping[y], axis=y_index) + y_mask_seg = np.take(src_mask, lat_mapping[y], axis=0) for x in range(x_length): x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap") + x_mask_seg = np.take(y_mask_seg, lon_mapping[x], axis=1, mode="wrap") - cell_weight = np.dot(lat_weights[y], lon_weights[x]) + cell_weights = np.multiply( + np.dot(lat_weights[y], lon_weights[x]), x_mask_seg + ) + + cell_weight = np.sum(cell_weights) output_seg_index = y * x_length + x + if cell_weight == 0.0: + output_mask[output_seg_index] = 0.0 + # using the `out` argument is more performant, places data directly into # array memory rather than allocating a new variable. wasn't working for # single element output, needs further investigation as we may not need @@ -144,23 +184,30 @@ def _regrid( if is_2d: output_data[output_seg_index] = np.divide( np.sum( - np.multiply(x_seg, cell_weight), + np.multiply(x_seg, cell_weights), axis=(y_index, x_index), ), - np.sum(cell_weight), + cell_weight, ) else: output_seg = output_data[output_seg_index] np.divide( np.sum( - np.multiply(x_seg, cell_weight), + np.multiply(x_seg, cell_weights), axis=(y_index, x_index), ), - np.sum(cell_weight), + cell_weight, out=output_seg, ) + if cell_weight <= 0.0: + output_data[output_seg_index] = omitted + + # default for unmapped is nan due to division by zero, use output mask to repalce + if not unmapped_to_nan: + output_data[output_mask == 0.0] = 0.0 + output_data_shape = [y_length, x_length] + other_sizes output_data = output_data.reshape(output_data_shape) @@ -208,7 +255,9 @@ def _build_dataset( return output_ds -def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: +def _map_latitude( + src: np.ndarray, dst: np.ndarray +) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ Map source to destination latitude. @@ -230,7 +279,7 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: Returns ------- - Tuple[List, List] + Tuple[List[np.ndarray], List[np.ndarray]] A tuple of cell mappings and cell weights. """ src_south, src_north = _extract_bounds(src) @@ -255,14 +304,25 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: ] # convert latitude to cell weight (difference of height above/below equator) - weights = [ - (np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1)) - for x, y in bounds - ] + weights = _get_latitude_weights(bounds) return mapping, weights +def _get_latitude_weights( + bounds: List[Tuple[np.ndarray, np.ndarray]] +) -> List[np.ndarray]: + weights = [] + + for x, y in bounds: + cell_weight = np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y)) + cell_weight = cell_weight.reshape((-1, 1)) + + weights.append(cell_weight) + + return weights + + def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: """ Map source to destination longitude. @@ -347,12 +407,12 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: Parameters ---------- bounds : np.ndarray - Dataset containing axis with bounds. + A numpy array of bounds values. Returns ------- Tuple[np.ndarray, np.ndarray] - A tuple containing the lower and upper bounds for the axis. + A tuple containing the lower and upper bounds for the axis. """ if bounds[0, 0] < bounds[0, 1]: lower = bounds[:, 0] diff --git a/xcdat/regridder/xesmf.py b/xcdat/regridder/xesmf.py index 90239469..7ad9dba0 100644 --- a/xcdat/regridder/xesmf.py +++ b/xcdat/regridder/xesmf.py @@ -28,6 +28,7 @@ def __init__( extrap_dist_exponent: Optional[float] = None, extrap_num_src_pnts: Optional[int] = None, ignore_degenerate: bool = True, + unmapped_to_nan: bool = True, **options: Any, ): """Extension of ``xESMF`` regridder. @@ -74,6 +75,8 @@ def __init__( This only applies to "conservative" and "conservative_normed" regridding methods. + unmapped_to_nan : bool + Sets values of unmapped points to `np.nan` instead of 0 (ESMF default). **options : Any Additional arguments passed to the underlying ``xesmf.XESMFRegridder`` constructor. @@ -126,11 +129,17 @@ def __init__( ) self._method = method - self._periodic = periodic - self._extrap_method = extrap_method - self._extrap_dist_exponent = extrap_dist_exponent - self._extrap_num_src_pnts = extrap_num_src_pnts - self._ignore_degenerate = ignore_degenerate + + # Re-pack xesmf arguments, broken out for validation/documentation + options.update( + periodic=periodic, + extrap_method=extrap_method, + extrap_dist_exponent=extrap_dist_exponent, + extrap_num_src_pnts=extrap_num_src_pnts, + ignore_degenerate=ignore_degenerate, + unmapped_to_nan=unmapped_to_nan, + ) + self._extra_options = options def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: @@ -150,11 +159,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: self._input_grid, self._output_grid, method=self._method, - periodic=self._periodic, - extrap_method=self._extrap_method, - extrap_dist_exponent=self._extrap_dist_exponent, - extrap_num_src_pnts=self._extrap_num_src_pnts, - ignore_degenerate=self._ignore_degenerate, **self._extra_options, )