Skip to content

Commit

Permalink
[PR]: Update Regrid2 missing and fill value behaviors to align with C…
Browse files Browse the repository at this point in the history
…DAT and add `unmapped_to_nan` arg for output data (#613)

Co-authored-by: tomvothecoder <tomvothecoder@gmail.com>
Co-authored-by: Jiwoo Lee <lee1043@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 10, 2024
1 parent 472c04b commit 1f4d22a
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 40 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ jobs:
pytest
- name: Upload Coverage Report
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
file: "tests_coverage_reports/coverage.xml"
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}

# `build-result` is a workaround to skipped matrix jobs in `build` not being considered "successful",
# which can block PR merges if matrix jobs are required status checks.
Expand Down
36 changes: 31 additions & 5 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,36 @@ def test_regrid_input_mask(self):

output_data = regridder.horizontal("ts", self.coarse_2d_ds)

# np.nan != np.nan, replace with 1e20
output_data = output_data.fillna(1e20)

expected_output = np.array(
[
[1e20] * 4,
[1.0] * 4,
[1.0] * 4,
[1e20] * 4,
],
dtype=np.float32,
)

assert np.all(output_data.ts.values == expected_output)

@pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning")
def test_regrid_input_mask_unmapped_to_nan(self):
regridder = regrid2.Regrid2Regridder(
self.coarse_2d_ds, self.fine_2d_ds, unmapped_to_nan=False
)

self.coarse_2d_ds["mask"] = (("lat", "lon"), [[0, 0], [1, 1], [0, 0]])

output_data = regridder.horizontal("ts", self.coarse_2d_ds)

expected_output = np.array(
[
[0.0] * 4,
[0.70710677] * 4,
[0.70710677] * 4,
[1.0] * 4,
[1.0] * 4,
[0.0] * 4,
],
dtype=np.float32,
Expand Down Expand Up @@ -690,7 +715,7 @@ def test_regrid(self):
assert "time_bnds" in output

@pytest.mark.parametrize(
"name,value,attr_name",
"name,value,_",
[
("periodic", True, "_periodic"),
("extrap_method", "inverse_dist", "_extrap_method"),
Expand All @@ -700,14 +725,15 @@ def test_regrid(self):
("ignore_degenerate", False, "_ignore_degenerate"),
],
)
def test_flags(self, name, value, attr_name):
def test_flags(self, name, value, _):
ds = self.ds.copy()

options = {name: value}

regridder = xesmf.XESMFRegridder(ds, self.new_grid, "bilinear", **options)

assert getattr(regridder, attr_name) == value
assert name in regridder._extra_options
assert regridder._extra_options[name] == value

def test_no_variable(self):
ds = self.ds.copy()
Expand Down
108 changes: 84 additions & 24 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple

import numpy as np
import xarray as xr
Expand All @@ -8,7 +8,13 @@


class Regrid2Regridder(BaseRegridder):
def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: Any):
def __init__(
self,
input_grid: xr.Dataset,
output_grid: xr.Dataset,
unmapped_to_nan=True,
**options: Any,
):
"""
Pure python implementation of the regrid2 horizontal regridder from
CDMS2's regrid2 module.
Expand Down Expand Up @@ -47,6 +53,8 @@ def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: A
"""
super().__init__(input_grid, output_grid, **options)

self._unmapped_to_nan = unmapped_to_nan

def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
"""Placeholder for base class."""
raise NotImplementedError()
Expand All @@ -66,20 +74,31 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
dst_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y")
dst_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X")

src_mask = self._input_grid.get("mask", None)
src_mask_da = self._input_grid.get("mask", None)

# DataArray to np.ndarray, handle error when None
try:
src_mask = src_mask_da.values # type: ignore
except AttributeError:
src_mask = None

# apply source mask to input data
if src_mask is not None:
masked_value = self._input_grid.attrs.get("_FillValue", None)
nan_replace = input_data_var.encoding.get("_FillValue", None)

if masked_value is None:
masked_value = self._input_grid.attrs.get("missing_value", 0.0)
if nan_replace is None:
nan_replace = input_data_var.encoding.get("missing_value", 1e20)

# Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20
input_data_var = input_data_var.where(src_mask != 0.0, masked_value)
# exclude alternative of NaN values if there are any
input_data_var = input_data_var.where(input_data_var != nan_replace)

# horizontal regrid
output_data = _regrid(
input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds
input_data_var,
src_lat_bnds,
src_lon_bnds,
dst_lat_bnds,
dst_lon_bnds,
src_mask,
unmapped_to_nan=self._unmapped_to_nan,
)

output_ds = _build_dataset(
Expand All @@ -101,7 +120,13 @@ def _regrid(
src_lon_bnds: np.ndarray,
dst_lat_bnds: np.ndarray,
dst_lon_bnds: np.ndarray,
src_mask: Optional[np.ndarray],
omitted=None,
unmapped_to_nan=True,
) -> np.ndarray:
if omitted is None:
omitted = np.nan

lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds)
lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds)

Expand All @@ -114,6 +139,11 @@ def _regrid(
y_length = len(lat_mapping)
x_length = len(lon_mapping)

if src_mask is None:
input_data_shape = input_data.shape

src_mask = np.ones((input_data_shape[y_index], input_data_shape[x_index]))

other_dims = {
x: y for x, y in input_data_var.sizes.items() if x not in (y_name, x_name)
}
Expand All @@ -122,45 +152,62 @@ def _regrid(
data_shape = [y_length * x_length] + other_sizes
# output data is always float32 in original code
output_data = np.zeros(data_shape, dtype=np.float32)
output_mask = np.ones(data_shape, dtype=np.float32)

is_2d = input_data_var.ndim <= 2

# TODO: need to optimize further, investigate using ufuncs and dask arrays
# TODO: how common is lon by lat data? may need to reshape
for y in range(y_length):
y_seg = np.take(input_data, lat_mapping[y], axis=y_index)
y_mask_seg = np.take(src_mask, lat_mapping[y], axis=0)

for x in range(x_length):
x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap")
x_mask_seg = np.take(y_mask_seg, lon_mapping[x], axis=1, mode="wrap")

cell_weight = np.dot(lat_weights[y], lon_weights[x])
cell_weights = np.multiply(
np.dot(lat_weights[y], lon_weights[x]), x_mask_seg
)

cell_weight = np.sum(cell_weights)

output_seg_index = y * x_length + x

if cell_weight == 0.0:
output_mask[output_seg_index] = 0.0

# using the `out` argument is more performant, places data directly into
# array memory rather than allocating a new variable. wasn't working for
# single element output, needs further investigation as we may not need
# branch
if is_2d:
output_data[output_seg_index] = np.divide(
np.sum(
np.multiply(x_seg, cell_weight),
np.multiply(x_seg, cell_weights),
axis=(y_index, x_index),
),
np.sum(cell_weight),
cell_weight,
)
else:
output_seg = output_data[output_seg_index]

np.divide(
np.sum(
np.multiply(x_seg, cell_weight),
np.multiply(x_seg, cell_weights),
axis=(y_index, x_index),
),
np.sum(cell_weight),
cell_weight,
out=output_seg,
)

if cell_weight <= 0.0:
output_data[output_seg_index] = omitted

# default for unmapped is nan due to division by zero, use output mask to replace
if not unmapped_to_nan:
output_data[output_mask == 0.0] = 0.0

output_data_shape = [y_length, x_length] + other_sizes

output_data = output_data.reshape(output_data_shape)
Expand Down Expand Up @@ -208,7 +255,9 @@ def _build_dataset(
return output_ds


def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
def _map_latitude(
src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Map source to destination latitude.
Expand All @@ -230,7 +279,7 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
Returns
-------
Tuple[List, List]
Tuple[List[np.ndarray], List[np.ndarray]]
A tuple of cell mappings and cell weights.
"""
src_south, src_north = _extract_bounds(src)
Expand All @@ -255,14 +304,25 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
]

# convert latitude to cell weight (difference of height above/below equator)
weights = [
(np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1))
for x, y in bounds
]
weights = _get_latitude_weights(bounds)

return mapping, weights


def _get_latitude_weights(
    bounds: List[Tuple[np.ndarray, np.ndarray]]
) -> List[np.ndarray]:
    """
    Convert pairs of latitude bounds into column vectors of cell weights.

    For every ``(first, second)`` pair the weight is
    ``sin(first) - sin(second)``, with both bounds interpreted as degrees.

    Parameters
    ----------
    bounds : List[Tuple[np.ndarray, np.ndarray]]
        Pairs of latitude bounds in degrees. NOTE(review): the caller appears
        to pass one pair per destination cell, ordered so that the difference
        is non-negative -- confirm against ``_map_latitude``.

    Returns
    -------
    List[np.ndarray]
        One array per input pair, reshaped to a column vector of shape
        ``(n, 1)``. The column shape presumably lets it be combined with a
        matching longitude weight array via ``np.dot`` -- confirm against
        the caller.
    """
    return [
        (np.sin(np.deg2rad(first)) - np.sin(np.deg2rad(second))).reshape((-1, 1))
        for first, second in bounds
    ]


def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
"""
Map source to destination longitude.
Expand Down Expand Up @@ -347,12 +407,12 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
Parameters
----------
bounds : np.ndarray
Dataset containing axis with bounds.
A numpy array of bounds values.
Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing the lower and upper bounds for the axis.
A tuple containing the lower and upper bounds for the axis.
"""
if bounds[0, 0] < bounds[0, 1]:
lower = bounds[:, 0]
Expand Down
24 changes: 14 additions & 10 deletions xcdat/regridder/xesmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
extrap_dist_exponent: Optional[float] = None,
extrap_num_src_pnts: Optional[int] = None,
ignore_degenerate: bool = True,
unmapped_to_nan: bool = True,
**options: Any,
):
"""Extension of ``xESMF`` regridder.
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(
This only applies to "conservative" and "conservative_normed"
regridding methods.
unmapped_to_nan : bool
Sets values of unmapped points to `np.nan` instead of 0 (ESMF default).
**options : Any
Additional arguments passed to the underlying ``xesmf.XESMFRegridder``
constructor.
Expand Down Expand Up @@ -126,11 +129,17 @@ def __init__(
)

self._method = method
self._periodic = periodic
self._extrap_method = extrap_method
self._extrap_dist_exponent = extrap_dist_exponent
self._extrap_num_src_pnts = extrap_num_src_pnts
self._ignore_degenerate = ignore_degenerate

# Re-pack xesmf arguments, broken out for validation/documentation
options.update(
periodic=periodic,
extrap_method=extrap_method,
extrap_dist_exponent=extrap_dist_exponent,
extrap_num_src_pnts=extrap_num_src_pnts,
ignore_degenerate=ignore_degenerate,
unmapped_to_nan=unmapped_to_nan,
)

self._extra_options = options

def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
Expand All @@ -150,11 +159,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
self._input_grid,
self._output_grid,
method=self._method,
periodic=self._periodic,
extrap_method=self._extrap_method,
extrap_dist_exponent=self._extrap_dist_exponent,
extrap_num_src_pnts=self._extrap_num_src_pnts,
ignore_degenerate=self._ignore_degenerate,
**self._extra_options,
)

Expand Down

0 comments on commit 1f4d22a

Please sign in to comment.