diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index 8a1fdc85..710313d1 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -1,8 +1,9 @@ -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import xarray as xr +import xcdat as xc from xcdat.axis import get_dim_keys from xcdat.regridder.base import BaseRegridder, _preserve_bounds @@ -105,8 +106,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: ds, data_var, output_data, - dst_lat_bnds, - dst_lon_bnds, self._input_grid, self._output_grid, ) @@ -228,38 +227,90 @@ def _build_dataset( ds: xr.Dataset, data_var: str, output_data: np.ndarray, - dst_lat_bnds, - dst_lon_bnds, input_grid: xr.Dataset, output_grid: xr.Dataset, ) -> xr.Dataset: - input_data_var = ds[data_var] + """Build a new xarray Dataset with the given output data and coordinates. + + Parameters + ---------- + ds : xr.Dataset + The input dataset containing the data variable to be regridded. + data_var : str + The name of the data variable in the input dataset to be regridded. + output_data : np.ndarray + The regridded data to be included in the output dataset. + input_grid : xr.Dataset + The input grid dataset containing the original grid information. + output_grid : xr.Dataset + The output grid dataset containing the new grid information. - output_coords: dict[str, xr.DataArray] = {} - output_data_vars: dict[str, xr.DataArray] = {} + Returns + ------- + xr.Dataset + A new dataset containing the regridded data variable with updated + coordinates and attributes. + """ + dv_input = ds[data_var] - dims = list(input_data_var.dims) + output_coords = _get_output_coords(dv_input, output_grid) output_da = xr.DataArray( output_data, - dims=dims, + dims=dv_input.dims, coords=output_coords, attrs=ds[data_var].attrs.copy(), name=data_var, ) - output_data_vars[data_var] = output_da - - output_ds = xr.Dataset( - output_data_vars, - attrs=input_grid.attrs.copy(), - ) - + output_ds = output_da.to_dataset() + output_ds.attrs = input_grid.attrs.copy() output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"]) return output_ds +def _get_output_coords( + dv_input: xr.DataArray, output_grid: xr.Dataset +) -> Dict[str, xr.DataArray]: + """ + Generate the output coordinates for regridding based on the input data + variable and output grid. + + Parameters + ---------- + dv_input : xr.DataArray + The input data variable containing the original coordinates. + output_grid : xr.Dataset + The dataset containing the target grid coordinates. + + Returns + ------- + Dict[str, xr.DataArray] + A dictionary where keys are coordinate names and values are the + corresponding coordinates from the output grid or input data variable, + aligned with the dimensions of the input data variable. + """ + output_coords: Dict[str, xr.DataArray] = {} + + # First get the X and Y axes from the output grid. + for key in ["X", "Y"]: + input_coord = xc.get_dim_coords(dv_input, key) # type: ignore + output_coord = xc.get_dim_coords(output_grid, key) # type: ignore + + output_coords[str(input_coord.name)] = output_coord # type: ignore + + # Get the remaining axes the input data variable (e.g., "time"). + for dim in dv_input.dims: + if dim not in output_coords: + output_coords[str(dim)] = dv_input[dim] + + # Sort the coords to align with the input data variable dims. + output_coords = {str(dim): output_coords[str(dim)] for dim in dv_input.dims} + + return output_coords + + def _map_latitude( src: np.ndarray, dst: np.ndarray ) -> Tuple[List[np.ndarray], List[np.ndarray]]: @@ -553,12 +604,15 @@ def _get_dimension(input_data_var, cf_axis_name): def _get_bounds_ensure_dtype(ds, axis): + bounds = None + try: - name = ds.cf.bounds[axis][0] - except (KeyError, IndexError) as e: - raise RuntimeError(f"Could not determine {axis!r} bounds") from e - else: - bounds = ds[name] + bounds = ds.bounds.get_bounds(axis) + except KeyError: + pass + + if bounds is None: + raise RuntimeError(f"Could not determine {axis!r} bounds") if bounds.dtype != np.float32: bounds = bounds.astype(np.float32)