From e067a6c69a98d7d81506d88822ddb3450f2ce12c Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 20 Nov 2024 09:22:18 -0800 Subject: [PATCH 1/4] Fixes preserving coordinates in regrid2 output --- xcdat/regridder/regrid2.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 8a1fdc85..9d1f456e 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -238,11 +238,24 @@ def _build_dataset( output_coords: dict[str, xr.DataArray] = {} output_data_vars: dict[str, xr.DataArray] = {} - dims = list(input_data_var.dims) + for dim in input_data_var.dims: + dim = str(dim) + + try: + axis_name = [x for x, y in ds.cf.axes.items() if dim in y][0] + except Exception: + raise ValueError( + f"Could not determine axis name for dimension {dim}" + ) from None + + if axis_name in ["X", "Y"]: + output_coords[dim] = output_grid.cf[axis_name] + else: + output_coords[dim] = input_data_var.cf[axis_name] output_da = xr.DataArray( output_data, - dims=dims, + dims=input_data_var.dims, coords=output_coords, attrs=ds[data_var].attrs.copy(), name=data_var, From e2d259eb419fc5793333cd66437919b8dc6c6ba5 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Wed, 4 Dec 2024 15:22:31 -0800 Subject: [PATCH 2/4] Add test and code cleanup - Fix `_get_bounds_ensure_dtype` to determine `bounds` with axis that has `standard_name` attr (in addition to `axis` attr check) - Remove unused `dst_lat_bnds` and `dst_lon_bnds` args for `_build_dataset()` - Add unit test to cover `ValueError` in `regrid2.py` `_build_dataset()` --- tests/test_regrid.py | 12 ++++++++++++ xcdat/regridder/regrid2.py | 34 ++++++++++++++++++++-------------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 861b0124..9cfd1b6f 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -517,6 +517,18 @@ def test_unknown_variable(self): with pytest.raises(KeyError): regridder.horizontal("unknown", self.coarse_2d_ds) + def test_raises_error_if_axis_name_for_dim_cannot_be_determined(self): + ds = self.coarse_2d_ds.copy() + ds["lat"].attrs["standard_name"] = "latitude" + ds["lat"].attrs.pop("axis") + + regridder = regrid2.Regrid2Regridder(ds, self.fine_2d_ds) + + with pytest.raises( + ValueError, match="Could not determine axis name for dimension" + ): + regridder.horizontal("ts", ds) + @pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning") def test_regrid_input_mask(self): regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 9d1f456e..1c49e934 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr -from xcdat.axis import get_dim_keys +from xcdat.axis import CF_ATTR_MAP, get_dim_keys from xcdat.regridder.base import BaseRegridder, _preserve_bounds @@ -105,8 +105,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: ds, data_var, output_data, - dst_lat_bnds, - dst_lon_bnds, self._input_grid, self._output_grid, ) @@ -228,8 +226,6 @@ def _build_dataset( ds: xr.Dataset, data_var: str, output_data: np.ndarray, - dst_lat_bnds, - dst_lon_bnds, input_grid: xr.Dataset, output_grid: xr.Dataset, ) -> xr.Dataset: @@ -242,11 +238,13 @@ def _build_dataset( dim = str(dim) try: - axis_name = [x for x, y in ds.cf.axes.items() if dim in y][0] - except Exception: + axis_name = [ + cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims + ][0] + except IndexError as e: raise ValueError( f"Could not determine axis name for dimension {dim}" - ) from None + ) from e if axis_name in ["X", "Y"]: output_coords[dim] = output_grid.cf[axis_name] @@ -566,12 +564,20 @@ def _get_dimension(input_data_var, cf_axis_name): def _get_bounds_ensure_dtype(ds, axis): - try: - name = ds.cf.bounds[axis][0] - except (KeyError, IndexError) as e: - raise RuntimeError(f"Could not determine {axis!r} bounds") from e - else: - bounds = ds[name] + cf_keys = CF_ATTR_MAP[axis].values() + + bounds = None + + for key in cf_keys: + try: + name = ds.cf.bounds[key][0] + except (KeyError, IndexError): + pass + else: + bounds = ds[name] + + if bounds is None: + raise RuntimeError(f"Could not determine {axis!r} bounds") if bounds.dtype != np.float32: bounds = bounds.astype(np.float32) From a8732a82dcfdf6c12624efbcdaadecf312282c99 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Fri, 6 Dec 2024 11:35:12 -0800 Subject: [PATCH 3/4] Refactor logic for preserving coordinates with regrid2 --- tests/test_regrid.py | 12 ---- xcdat/regridder/regrid2.py | 113 +++++++++++++++++++++++++------------ 2 files changed, 76 insertions(+), 49 deletions(-) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 9cfd1b6f..861b0124 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -517,18 +517,6 @@ def test_unknown_variable(self): with pytest.raises(KeyError): regridder.horizontal("unknown", self.coarse_2d_ds) - def test_raises_error_if_axis_name_for_dim_cannot_be_determined(self): - ds = self.coarse_2d_ds.copy() - ds["lat"].attrs["standard_name"] = "latitude" - ds["lat"].attrs.pop("axis") - - regridder = regrid2.Regrid2Regridder(ds, self.fine_2d_ds) - - with pytest.raises( - ValueError, match="Could not determine axis name for dimension" - ): - regridder.horizontal("ts", ds) - @pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning") def test_regrid_input_mask(self): regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 1c49e934..dab1f260 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -1,11 +1,16 @@ -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import xarray as xr -from xcdat.axis import CF_ATTR_MAP, get_dim_keys +import xcdat as xc +from xcdat.axis import VAR_NAME_MAP, get_dim_keys from xcdat.regridder.base import BaseRegridder, _preserve_bounds +# Spatial axes keys used to map to the axes in an input data variable to build +# the output variable. +VALID_SPATIAL_AXES_KEYS = ["X", "Y"] + VAR_NAME_MAP["X"] + VAR_NAME_MAP["Y"] + class Regrid2Regridder(BaseRegridder): def __init__( @@ -229,48 +234,87 @@ def _build_dataset( input_grid: xr.Dataset, output_grid: xr.Dataset, ) -> xr.Dataset: - input_data_var = ds[data_var] + """Build a new xarray Dataset with the given output data and coordinates. - output_coords: dict[str, xr.DataArray] = {} - output_data_vars: dict[str, xr.DataArray] = {} + Parameters + ---------- + ds : xr.Dataset + The input dataset containing the data variable to be regridded. + data_var : str + The name of the data variable in the input dataset to be regridded. + output_data : np.ndarray + The regridded data to be included in the output dataset. + input_grid : xr.Dataset + The input grid dataset containing the original grid information. + output_grid : xr.Dataset + The output grid dataset containing the new grid information. - for dim in input_data_var.dims: - dim = str(dim) + Returns + ------- + xr.Dataset + A new dataset containing the regridded data variable with updated + coordinates and attributes. + """ + dv_input = ds[data_var] - try: - axis_name = [ - cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims - ][0] - except IndexError as e: - raise ValueError( - f"Could not determine axis name for dimension {dim}" - ) from e - - if axis_name in ["X", "Y"]: - output_coords[dim] = output_grid.cf[axis_name] - else: - output_coords[dim] = input_data_var.cf[axis_name] + output_coords = _get_output_coords(dv_input, output_grid) output_da = xr.DataArray( output_data, - dims=input_data_var.dims, + dims=dv_input.dims, coords=output_coords, attrs=ds[data_var].attrs.copy(), name=data_var, ) - output_data_vars[data_var] = output_da - - output_ds = xr.Dataset( - output_data_vars, - attrs=input_grid.attrs.copy(), - ) - + output_ds = output_da.to_dataset() + output_ds.attrs = input_grid.attrs.copy() output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"]) return output_ds +def _get_output_coords( + dv_input: xr.DataArray, output_grid: xr.Dataset +) -> Dict[str, xr.DataArray]: + """ + Generate the output coordinates for regridding based on the input data + variable and output grid. + + Parameters + ---------- + dv_input : xr.DataArray + The input data variable containing the original coordinates. + output_grid : xr.Dataset + The dataset containing the target grid coordinates. + + Returns + ------- + Dict[str, xr.DataArray] + A dictionary where keys are coordinate names and values are the + corresponding coordinates from the output grid or input data variable, + aligned with the dimensions of the input data variable. + """ + output_coords: Dict[str, xr.DataArray] = {} + + # First get the X and Y axes from the output grid. + for key in ["X", "Y"]: + input_coord = xc.get_dim_coords(dv_input, key) # type: ignore + output_coord = xc.get_dim_coords(output_grid, key) # type: ignore + + output_coords[str(input_coord.name)] = output_coord # type: ignore + + # Get the remaining axes the input data variable (e.g., "time"). + for dim in dv_input.dims: + if dim not in output_coords: + output_coords[str(dim)] = dv_input[dim] + + # Sort the coords to align with the input data variable dims. + output_coords = {str(dim): output_coords[str(dim)] for dim in dv_input.dims} + + return output_coords + + def _map_latitude( src: np.ndarray, dst: np.ndarray ) -> Tuple[List[np.ndarray], List[np.ndarray]]: @@ -564,17 +608,12 @@ def _get_dimension(input_data_var, cf_axis_name): def _get_bounds_ensure_dtype(ds, axis): - cf_keys = CF_ATTR_MAP[axis].values() - bounds = None - for key in cf_keys: - try: - name = ds.cf.bounds[key][0] - except (KeyError, IndexError): - pass - else: - bounds = ds[name] + try: + bounds = ds.bounds.get_bounds(axis) + except KeyError: + pass if bounds is None: raise RuntimeError(f"Could not determine {axis!r} bounds") From e149b9d6428faf1471dfb687949dc29748aa78ae Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Fri, 6 Dec 2024 11:41:38 -0800 Subject: [PATCH 4/4] Remove unused constant var --- xcdat/regridder/regrid2.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index dab1f260..710313d1 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -4,13 +4,9 @@ import xarray as xr import xcdat as xc -from xcdat.axis import VAR_NAME_MAP, get_dim_keys +from xcdat.axis import get_dim_keys from xcdat.regridder.base import BaseRegridder, _preserve_bounds -# Spatial axes keys used to map to the axes in an input data variable to build -# the output variable. -VALID_SPATIAL_AXES_KEYS = ["X", "Y"] + VAR_NAME_MAP["X"] + VAR_NAME_MAP["Y"] - class Regrid2Regridder(BaseRegridder): def __init__(