From ea08915db364ffec747e0a80d6c9296e56b09469 Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Fri, 28 Jun 2024 12:54:47 -0700 Subject: [PATCH 1/8] initial attempt at #531 (for spatial averaging) --- tests/test_spatial.py | 34 +++++++++++++++++++++++++++++ xcdat/spatial.py | 51 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index fe0361cd..4f27b226 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -140,6 +140,17 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims( with pytest.raises(ValueError): self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights) + def test_raises_error_if_required_weight_not_between_zero_and_one( + self, + ): + # ensure error if required_weight less than zero + with pytest.raises(ValueError): + self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=-0.01) + + # ensure error if required_weight greater than 1 + with pytest.raises(ValueError): + self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=1.01) + def test_spatial_average_for_lat_region_and_keep_weights(self): ds = self.ds.copy() @@ -254,6 +265,29 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self): xr.testing.assert_allclose(result, expected) + def test_spatial_average_with_required_weight(self): + ds = self.ds.copy() + + # insert a nan + ds["ts"][0, :, 2] = np.nan + + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + lat_bounds=(-5.0, 5), + lon_bounds=(-170, -120.1), + required_weight=1.0, + ) + + expected = self.ds.copy() + expected["ts"] = xr.DataArray( + data=np.array([np.nan, 1.0, 1.0]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self): ds = self.ds.copy() diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 15bec956..1beab80a 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -76,6 +76,7 @@ def average( keep_weights: bool = False, lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, + required_weight: Optional[float] = 0.0, ) -> xr.Dataset: """ Calculates the spatial average for a rectilinear grid over an optionally @@ -125,6 +126,9 @@ def average( ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. + required_weight : optional, float + Fraction of data coverage (i..e, weight) needed to return a + spatial average value. Value must range from 0 to 1. Returns ------- @@ -196,7 +200,7 @@ def average( self._weights = weights self._validate_weights(dv, axis) - ds[dv.name] = self._averager(dv, axis) + ds[dv.name] = self._averager(dv, axis, required_weight=required_weight) if keep_weights: ds[self._weights.name] = self._weights @@ -702,7 +706,10 @@ def _validate_weights( ) def _averager( - self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] + self, + data_var: xr.DataArray, + axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], + required_weight: Optional[float] = 0.0, ): """Perform a weighted average of a data variable. @@ -721,6 +728,9 @@ def _averager( Data variable inside a Dataset. axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. + required_weight : optional, float + Fraction of data coverage (i..e, weight) needed to return a + spatial average value. Value must range from 0 to 1. Returns ------- @@ -734,11 +744,48 @@ def _averager( """ weights = self._weights.fillna(0) + # ensure required weight is between 0 and 1 + if required_weight is None: + required_weight = 0.0 + + if required_weight < 0.0: + raise ValueError( + "required_weight argment is less than zero. " + "required_weight must be between 0 and 1." + ) + + if required_weight > 1.0: + raise ValueError( + "required_weight argment is greater than zero. " + "required_weight must be between 0 and 1." + ) + + # need weights to match data_var dimensionality + if required_weight > 0.0: + weights, data_var = xr.broadcast(weights, data_var) + + # get averaging dimensions dim = [] for key in axis: dim.append(get_dim_keys(data_var, key)) + # compute weighed mean with xr.set_options(keep_attrs=True): weighted_mean = data_var.cf.weighted(weights).mean(dim=dim) + # if weight thresholds applied, calculate fraction of data availability + # replace values that do not meet minimum weight with nan + if required_weight > 0.0: + # sum all weights (assuming no missing values exist) + print(dim) + weight_sum_all = weights.sum(dim=dim) # type: ignore + # zero out cells with missing values in data_var + weights = xr.where(~np.isnan(data_var), weights, 0) + # sum all weights (including zero for missing values) + weight_sum_masked = weights.sum(dim=dim) # type: ignore + # get fraction of weight available + frac = weight_sum_masked / weight_sum_all + # nan out values that don't meet specified weight threshold + weighted_mean = xr.where(frac >= required_weight, weighted_mean, np.nan) + return weighted_mean From 127e7117e2c4a8d354524c2e949305268896af96 Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Fri, 28 Jun 2024 13:08:26 -0700 Subject: [PATCH 2/8] cleanup print statement and complete code coverage --- tests/test_spatial.py | 20 ++++++++++++++++++++ xcdat/spatial.py | 1 - 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index 4f27b226..ea9f44a6 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -288,6 +288,26 @@ def test_spatial_average_with_required_weight(self): xr.testing.assert_allclose(result, expected) + def test_spatial_average_with_required_weight_as_None(self): + ds = self.ds.copy() + + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + lat_bounds=(-5.0, 5), + lon_bounds=(-170, -120.1), + required_weight=None, + ) + + expected = self.ds.copy() + expected["ts"] = xr.DataArray( + data=np.array([2.25, 1.0, 1.0]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self): ds = self.ds.copy() diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 1beab80a..07b8eab6 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -777,7 +777,6 @@ def _averager( # replace values that do not meet minimum weight with nan if required_weight > 0.0: # sum all weights (assuming no missing values exist) - print(dim) weight_sum_all = weights.sum(dim=dim) # type: ignore # zero out cells with missing values in data_var weights = xr.where(~np.isnan(data_var), weights, 0) From a2fee04dfc6470405a84b9026efdf7b702bd4a9b Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Fri, 23 Aug 2024 11:21:24 -0700 Subject: [PATCH 3/8] Apply review suggestion. Co-authored-by: Tom Vo --- xcdat/spatial.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 07b8eab6..d34d34de 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -747,19 +747,17 @@ def _averager( # ensure required weight is between 0 and 1 if required_weight is None: required_weight = 0.0 - - if required_weight < 0.0: + elif required_weight < 0.0: raise ValueError( - "required_weight argment is less than zero. " + "required_weight argument is less than 0. " "required_weight must be between 0 and 1." ) - - if required_weight > 1.0: + elif required_weight > 1.0: raise ValueError( - "required_weight argment is greater than zero. " + "required_weight argument is greater than 1. " "required_weight must be between 0 and 1." ) - + # need weights to match data_var dimensionality if required_weight > 0.0: weights, data_var = xr.broadcast(weights, data_var) From 0602ea6dab38c23a00ea5a68db9701416c1f554e Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Fri, 23 Aug 2024 11:34:37 -0700 Subject: [PATCH 4/8] update required_weight argument (to minimum_weight) --- xcdat/spatial.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/xcdat/spatial.py b/xcdat/spatial.py index d34d34de..7dc4075f 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -76,7 +76,7 @@ def average( keep_weights: bool = False, lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, - required_weight: Optional[float] = 0.0, + minimum_weight: Optional[float] = None, ) -> xr.Dataset: """ Calculates the spatial average for a rectilinear grid over an optionally @@ -126,9 +126,10 @@ def average( ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. - required_weight : optional, float + minimum_weight : optional, float Fraction of data coverage (i..e, weight) needed to return a - spatial average value. Value must range from 0 to 1. + spatial average value. Value must range from 0 to 1, by default None + (equivalent to minimum_weight=0.0). Returns ------- @@ -200,7 +201,7 @@ def average( self._weights = weights self._validate_weights(dv, axis) - ds[dv.name] = self._averager(dv, axis, required_weight=required_weight) + ds[dv.name] = self._averager(dv, axis, minimum_weight=minimum_weight) if keep_weights: ds[self._weights.name] = self._weights @@ -709,7 +710,7 @@ def _averager( self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], - required_weight: Optional[float] = 0.0, + minimum_weight: Optional[float] = None, ): """Perform a weighted average of a data variable. @@ -728,9 +729,10 @@ def _averager( Data variable inside a Dataset. axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. - required_weight : optional, float + minimum_weight : optional, float Fraction of data coverage (i..e, weight) needed to return a - spatial average value. Value must range from 0 to 1. + spatial average value. Value must range from 0 to 1, by default None + (equivalent to minimum_weight=0.0). Returns ------- @@ -745,21 +747,21 @@ def _averager( weights = self._weights.fillna(0) # ensure required weight is between 0 and 1 - if required_weight is None: - required_weight = 0.0 - elif required_weight < 0.0: + if minimum_weight is None: + minimum_weight = 0.0 + elif minimum_weight < 0.0: raise ValueError( - "required_weight argument is less than 0. " - "required_weight must be between 0 and 1." + "minimum_weight argument is less than 0. " + "minimum_weight must be between 0 and 1." ) - elif required_weight > 1.0: + elif minimum_weight > 1.0: raise ValueError( - "required_weight argument is greater than 1. " - "required_weight must be between 0 and 1." + "minimum_weight argument is greater than 1. " + "minimum_weight must be between 0 and 1." ) - + # need weights to match data_var dimensionality - if required_weight > 0.0: + if minimum_weight > 0.0: weights, data_var = xr.broadcast(weights, data_var) # get averaging dimensions @@ -773,7 +775,7 @@ def _averager( # if weight thresholds applied, calculate fraction of data availability # replace values that do not meet minimum weight with nan - if required_weight > 0.0: + if minimum_weight > 0.0: # sum all weights (assuming no missing values exist) weight_sum_all = weights.sum(dim=dim) # type: ignore # zero out cells with missing values in data_var @@ -783,6 +785,6 @@ def _averager( # get fraction of weight available frac = weight_sum_masked / weight_sum_all # nan out values that don't meet specified weight threshold - weighted_mean = xr.where(frac >= required_weight, weighted_mean, np.nan) + weighted_mean = xr.where(frac >= minimum_weight, weighted_mean, np.nan) return weighted_mean From 1c5c3c1af136fda7f86a767323ae2c9367427334 Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Fri, 23 Aug 2024 11:51:34 -0700 Subject: [PATCH 5/8] update tests for minimum_weight parameter --- tests/test_spatial.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index ea9f44a6..ea2ad27f 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -140,16 +140,16 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims( with pytest.raises(ValueError): self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights) - def test_raises_error_if_required_weight_not_between_zero_and_one( + def test_raises_error_if_minimum_weight_not_between_zero_and_one( self, ): - # ensure error if required_weight less than zero + # ensure error if minimum_weight less than zero with pytest.raises(ValueError): - self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=-0.01) + self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=-0.01) - # ensure error if required_weight greater than 1 + # ensure error if minimum_weight greater than 1 with pytest.raises(ValueError): - self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=1.01) + self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=1.01) def test_spatial_average_for_lat_region_and_keep_weights(self): ds = self.ds.copy() @@ -265,7 +265,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self): xr.testing.assert_allclose(result, expected) - def test_spatial_average_with_required_weight(self): + def test_spatial_average_with_minimum_weight(self): ds = self.ds.copy() # insert a nan @@ -276,7 +276,7 @@ def test_spatial_average_with_required_weight(self): axis=["X", "Y"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1), - required_weight=1.0, + minimum_weight=1.0, ) expected = self.ds.copy() @@ -288,7 +288,7 @@ def test_spatial_average_with_required_weight(self): xr.testing.assert_allclose(result, expected) - def test_spatial_average_with_required_weight_as_None(self): + def test_spatial_average_with_minimum_weight_as_None(self): ds = self.ds.copy() result = ds.spatial.average( @@ -296,7 +296,7 @@ def test_spatial_average_with_required_weight_as_None(self): axis=["X", "Y"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1), - required_weight=None, + minimum_weight=None, ) expected = self.ds.copy() From 636366b2ae3bd19c0cbe2846b8dc8e93dcf2140c Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Thu, 5 Sep 2024 10:46:47 -0700 Subject: [PATCH 6/8] Updates from code review - Rename arg `minimum_weight` to `min_weight` - Add `_get_masked_weights()` and `_validate_min_weight()` to `utils.py` - Update `SpatialAccessor` to use `_get_masked_weights()` and `_validate_min_weight()` - Replace type annotation `Optional` with `|` - Extract `_mask_var_with_with_threshold()` from `_averager()` for readability --- tests/test_spatial.py | 18 ++--- tests/test_utils.py | 22 +++++- xcdat/spatial.py | 160 +++++++++++++++++++++++++----------------- xcdat/utils.py | 60 ++++++++++++++++ 4 files changed, 185 insertions(+), 75 deletions(-) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index ea2ad27f..244bcf31 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -140,16 +140,16 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims( with pytest.raises(ValueError): self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights) - def test_raises_error_if_minimum_weight_not_between_zero_and_one( + def test_raises_error_if_min_weight_not_between_zero_and_one( self, ): - # ensure error if minimum_weight less than zero + # ensure error if min_weight less than zero with pytest.raises(ValueError): - self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=-0.01) + self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=-0.01) - # ensure error if minimum_weight greater than 1 + # ensure error if min_weight greater than 1 with pytest.raises(ValueError): - self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=1.01) + self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=1.01) def test_spatial_average_for_lat_region_and_keep_weights(self): ds = self.ds.copy() @@ -265,7 +265,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self): xr.testing.assert_allclose(result, expected) - def test_spatial_average_with_minimum_weight(self): + def test_spatial_average_with_min_weight(self): ds = self.ds.copy() # insert a nan @@ -276,7 +276,7 @@ def test_spatial_average_with_minimum_weight(self): axis=["X", "Y"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1), - minimum_weight=1.0, + min_weight=1.0, ) expected = self.ds.copy() @@ -288,7 +288,7 @@ def test_spatial_average_with_minimum_weight(self): xr.testing.assert_allclose(result, expected) - def test_spatial_average_with_minimum_weight_as_None(self): + def test_spatial_average_with_min_weight_as_None(self): ds = self.ds.copy() result = ds.spatial.average( @@ -296,7 +296,7 @@ def test_spatial_average_with_minimum_weight_as_None(self): axis=["X", "Y"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1), - minimum_weight=None, + min_weight=None, ) expected = self.ds.copy() diff --git a/tests/test_utils.py b/tests/test_utils.py index 1d4dcbe8..30d3cbfb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ import pytest import xarray as xr -from xcdat.utils import compare_datasets, str_to_bool +from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool class TestCompareDatasets: @@ -103,3 +103,23 @@ def test_raises_error_if_str_is_not_a_python_bool(self): with pytest.raises(ValueError): str_to_bool("1") + + +class TestValidateMinWeight: + def test_pass_None_returns_0(self): + result = _validate_min_weight(None) + + assert result == 0 + + def test_returns_error_if_less_than_0(self): + with pytest.raises(ValueError): + _validate_min_weight(-1) + + def test_returns_error_if_greater_than_1(self): + with pytest.raises(ValueError): + _validate_min_weight(1.1) + + def test_returns_valid_min_weight(self): + result = _validate_min_weight(1) + + assert result == 1 diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 7dc4075f..20d8cc0c 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -1,5 +1,8 @@ """Module containing geospatial averaging functions.""" +<<<<<<< HEAD +======= +>>>>>>> 34b570d6 (Updates from code review) from __future__ import annotations from functools import reduce @@ -9,7 +12,6 @@ Hashable, List, Literal, - Optional, Tuple, TypedDict, Union, @@ -27,7 +29,11 @@ get_dim_keys, ) from xcdat.dataset import _get_data_var -from xcdat.utils import _if_multidim_dask_array_then_load +from xcdat.utils import ( + _get_masked_weights, + _if_multidim_dask_array_then_load, + _validate_min_weight, +) #: Type alias for a dictionary of axis keys mapped to their bounds. AxisWeights = Dict[Hashable, xr.DataArray] @@ -74,9 +80,9 @@ def average( axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"), weights: Union[Literal["generate"], xr.DataArray] = "generate", keep_weights: bool = False, - lat_bounds: Optional[RegionAxisBounds] = None, - lon_bounds: Optional[RegionAxisBounds] = None, - minimum_weight: Optional[float] = None, + lat_bounds: RegionAxisBounds | None = None, + lon_bounds: RegionAxisBounds | None = None, + min_weight: float | None = None, ) -> xr.Dataset: """ Calculates the spatial average for a rectilinear grid over an optionally @@ -115,21 +121,21 @@ def average( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - lat_bounds : Optional[RegionAxisBounds], optional + lat_bounds : RegionAxisBounds | None, optional A tuple of floats/ints for the regional latitude lower and upper boundaries. This arg is used when calculating axis weights, but is ignored if ``weights`` are supplied. The lower bound cannot be larger than the upper bound, by default None. - lon_bounds : Optional[RegionAxisBounds], optional + lon_bounds : RegionAxisBounds | None, optional A tuple of floats/ints for the regional longitude lower and upper boundaries. This arg is used when calculating axis weights, but is ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. - minimum_weight : optional, float - Fraction of data coverage (i..e, weight) needed to return a + min_weight : optional, float + Fraction of data coverage (i.e, weight) needed to return a spatial average value. Value must range from 0 to 1, by default None - (equivalent to minimum_weight=0.0). + (equivalent to ``min_weight=0.0``). Returns ------- @@ -189,7 +195,9 @@ def average( """ ds = self._dataset.copy() dv = _get_data_var(ds, data_var) + self._validate_axis_arg(axis) + min_weight = _validate_min_weight(min_weight) if isinstance(weights, str) and weights == "generate": if lat_bounds is not None: @@ -201,7 +209,7 @@ def average( self._weights = weights self._validate_weights(dv, axis) - ds[dv.name] = self._averager(dv, axis, minimum_weight=minimum_weight) + ds[dv.name] = self._averager(dv, axis, min_weight=min_weight) if keep_weights: ds[self._weights.name] = self._weights @@ -211,9 +219,9 @@ def average( def get_weights( self, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], - lat_bounds: Optional[RegionAxisBounds] = None, - lon_bounds: Optional[RegionAxisBounds] = None, - data_var: Optional[str] = None, + lat_bounds: RegionAxisBounds | None = None, + lon_bounds: RegionAxisBounds | None = None, + data_var: str | None = None, ) -> xr.DataArray: """ Get area weights for specified axis keys and an optional target domain. @@ -232,13 +240,13 @@ def get_weights( ---------- axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. - lat_bounds : Optional[RegionAxisBounds] + lat_bounds : RegionAxisBounds | None Tuple of latitude boundaries for regional selection, by default None. - lon_bounds : Optional[RegionAxisBounds] + lon_bounds : RegionAxisBounds | None Tuple of longitude boundaries for regional selection, by default None. - data_var: Optional[str] + data_var: str | None The key of the data variable, by default None. Pass this argument when the dataset has more than one bounds per axis (e.g., "lon" and "zlon_bnds" for the "X" axis), or you want weights for a @@ -259,7 +267,7 @@ def get_weights( and pressure). """ Bounds = TypedDict( - "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]} + "Bounds", {"weights_method": Callable, "region": np.ndarray | None} ) axis_bounds: Dict[SpatialAxis, Bounds] = { @@ -382,7 +390,7 @@ def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds): ) def _get_longitude_weights( - self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None ) -> xr.DataArray: """Gets weights for the longitude axis. @@ -409,7 +417,7 @@ def _get_longitude_weights( ---------- domain_bounds : xr.DataArray The array of bounds for the longitude domain. - region_bounds : Optional[np.ndarray] + region_bounds : np.ndarray | None The array of bounds for longitude regional selection. Returns @@ -423,7 +431,7 @@ def _get_longitude_weights( If the there are multiple instances in which the domain_bounds[:, 0] > domain_bounds[:, 1] """ - p_meridian_index: Optional[np.ndarray] = None + p_meridian_index: np.ndarray | None = None d_bounds = domain_bounds.copy() pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0] @@ -455,7 +463,7 @@ def _get_longitude_weights( return weights def _get_latitude_weights( - self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None ) -> xr.DataArray: """Gets weights for the latitude axis. @@ -467,7 +475,7 @@ def _get_latitude_weights( ---------- domain_bounds : xr.DataArray The array of bounds for the latitude domain. - region_bounds : Optional[np.ndarray] + region_bounds : np.ndarray | None The array of bounds for latitude regional selection. Returns @@ -710,7 +718,7 @@ def _averager( self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], - minimum_weight: Optional[float] = None, + min_weight: float, ): """Perform a weighted average of a data variable. @@ -729,10 +737,9 @@ def _averager( Data variable inside a Dataset. axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. - minimum_weight : optional, float - Fraction of data coverage (i..e, weight) needed to return a - spatial average value. Value must range from 0 to 1, by default None - (equivalent to minimum_weight=0.0). + min_weight : float + Fraction of data coverage (i.e, weight) needed to return a + spatial average value. Value must range from 0 to 1. Returns ------- @@ -746,45 +753,68 @@ def _averager( """ weights = self._weights.fillna(0) - # ensure required weight is between 0 and 1 - if minimum_weight is None: - minimum_weight = 0.0 - elif minimum_weight < 0.0: - raise ValueError( - "minimum_weight argument is less than 0. " - "minimum_weight must be between 0 and 1." - ) - elif minimum_weight > 1.0: - raise ValueError( - "minimum_weight argument is greater than 1. " - "minimum_weight must be between 0 and 1." - ) - - # need weights to match data_var dimensionality - if minimum_weight > 0.0: + # TODO: This conditional might not be needed because Xarray will + # automatically broadcast the weights to the data variable for + # operations such as .mean() and .where(). + if min_weight > 0.0: weights, data_var = xr.broadcast(weights, data_var) - # get averaging dimensions - dim = [] + dim: List[str] = [] for key in axis: - dim.append(get_dim_keys(data_var, key)) + dim.append(get_dim_keys(data_var, key)) # type: ignore - # compute weighed mean with xr.set_options(keep_attrs=True): - weighted_mean = data_var.cf.weighted(weights).mean(dim=dim) - - # if weight thresholds applied, calculate fraction of data availability - # replace values that do not meet minimum weight with nan - if minimum_weight > 0.0: - # sum all weights (assuming no missing values exist) - weight_sum_all = weights.sum(dim=dim) # type: ignore - # zero out cells with missing values in data_var - weights = xr.where(~np.isnan(data_var), weights, 0) - # sum all weights (including zero for missing values) - weight_sum_masked = weights.sum(dim=dim) # type: ignore - # get fraction of weight available - frac = weight_sum_masked / weight_sum_all - # nan out values that don't meet specified weight threshold - weighted_mean = xr.where(frac >= minimum_weight, weighted_mean, np.nan) - - return weighted_mean + dv_mean = data_var.cf.weighted(weights).mean(dim=dim) + + if min_weight > 0.0: + dv_mean = self._mask_var_with_weight_threshold( + dv_mean, dim, weights, min_weight + ) + + return dv_mean + + def _mask_var_with_weight_threshold( + self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float + ) -> xr.DataArray: + """Mask values that do not meet the minimum weight threshold with np.nan. + + This function is useful for cases where the weighting of data might be + skewed based on the availability of data. For example, if a portion of + cells in a region has significantly more missing data than other other + regions, it can result in inaccurate calculations of spatial averaging. + Masking values that do not meet the minimum weight threshold ensures + more accurate calculations. + + Parameters + ---------- + dv : xr.DataArray + The weighted variable. + dim: List[str]: + List of axis dimensions to average over. + weights : xr.DataArray + A DataArray containing either the regional weights used for weighted + averaging. ``weights`` must include the same axis dimensions and + dimensional sizes as the data variable. + min_weight : float + Fraction of data coverage (i.e, weight) needed to return a + spatial average value. Value must range from 0 to 1. + + Returns + ------- + xr.DataArray + The variable with the minimum weight threshold applied. + """ + # Sum all weights, including zero for missing values. + weight_sum_all = weights.sum(dim=dim) + + masked_weights = _get_masked_weights(dv, weights) + weight_sum_masked = masked_weights.sum(dim=dim) + + # Get fraction of the available weight. + frac = weight_sum_masked / weight_sum_all + + # Nan out values that don't meet specified weight threshold. + dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True) + dv_new.name = dv.name + + return dv_new diff --git a/xcdat/utils.py b/xcdat/utils.py index 83596561..a2f674fa 100644 --- a/xcdat/utils.py +++ b/xcdat/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import json from typing import Dict, List, Optional, Union @@ -132,3 +134,61 @@ def _if_multidim_dask_array_then_load( return obj.load() return None + + +def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray: + """Get weights with missing data (`np.nan`) receiving no weight (zero). + + Parameters + ---------- + dv : xr.DataArray + The variable. + weights : xr.DataArray + A DataArray containing either the regional or temporal weights used for + weighted averaging. ``weights`` must include the same axis dimensions + and dimensional sizes as the data variable. + + Returns + ------- + xr.DataArray + The masked weights. + """ + masked_weights = xr.where(dv.copy().isnull(), 0.0, weights) + + return masked_weights + + +def _validate_min_weight(min_weight: float | None) -> float: + """Validate the ``min_weight`` value. + + Parameters + ---------- + min_weight : float | None + Fraction of data coverage (i..e, weight) needed to return a + spatial average value. Value must range from 0 to 1. + + Returns + ------- + float + The required weight percentage. + + Raises + ------ + ValueError + If the `min_weight` argument is less than 0. + ValueError + If the `min_weight` argument is greater than 1. + """ + if min_weight is None: + return 0.0 + elif min_weight < 0.0: + raise ValueError( + "min_weight argument is less than 0. " "min_weight must be between 0 and 1." + ) + elif min_weight > 1.0: + raise ValueError( + "min_weight argument is greater than 1. " + "min_weight must be between 0 and 1." + ) + + return min_weight From 82cbfe272d570fb3aaf4d09f635956cbb97d02ad Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Thu, 5 Sep 2024 11:29:33 -0700 Subject: [PATCH 7/8] Fix `TypeError` for optional `region` arg in `TypedDict` --- xcdat/spatial.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 20d8cc0c..f3508116 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -12,6 +12,7 @@ Hashable, List, Literal, + Optional, Tuple, TypedDict, Union, @@ -267,7 +268,7 @@ def get_weights( and pressure). """ Bounds = TypedDict( - "Bounds", {"weights_method": Callable, "region": np.ndarray | None} + "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]} ) axis_bounds: Dict[SpatialAxis, Bounds] = { From 9331a02db967f95aa9f4be8a76d93465619bbc56 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Thu, 21 Nov 2024 12:35:58 -0800 Subject: [PATCH 8/8] Remove rebase comment --- xcdat/spatial.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xcdat/spatial.py b/xcdat/spatial.py index f3508116..d105666e 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -1,8 +1,5 @@ """Module containing geospatial averaging functions.""" -<<<<<<< HEAD -======= ->>>>>>> 34b570d6 (Updates from code review) from __future__ import annotations from functools import reduce