Skip to content

Commit

Permalink
Updates from code review
Browse files Browse the repository at this point in the history
- Rename arg `minimum_weight` to `min_weight`
- Add `_get_masked_weights()` and `_validate_min_weight()` to `utils.py`
- Update `SpatialAccessor` to use `_get_masked_weights()` and `_validate_min_weight()`
- Replace type annotation `Optional` with `|`
- Extract `_mask_var_with_with_threshold()` from `_averager()` for readability
  • Loading branch information
tomvothecoder committed Sep 5, 2024
1 parent 5b2afff commit b56befe
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 75 deletions.
18 changes: 9 additions & 9 deletions tests/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,16 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)

def test_raises_error_if_minimum_weight_not_between_zero_and_one(
def test_raises_error_if_min_weight_not_between_zero_and_one(
self,
):
# ensure error if minimum_weight less than zero
# ensure error if min_weight less than zero
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=-0.01)
self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=-0.01)

# ensure error if minimum_weight greater than 1
# ensure error if min_weight greater than 1
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=1.01)
self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=1.01)

def test_spatial_average_for_lat_region_and_keep_weights(self):
ds = self.ds.copy()
Expand Down Expand Up @@ -265,7 +265,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):

xr.testing.assert_allclose(result, expected)

def test_spatial_average_with_minimum_weight(self):
def test_spatial_average_with_min_weight(self):
ds = self.ds.copy()

# insert a nan
Expand All @@ -276,7 +276,7 @@ def test_spatial_average_with_minimum_weight(self):
axis=["X", "Y"],
lat_bounds=(-5.0, 5),
lon_bounds=(-170, -120.1),
minimum_weight=1.0,
min_weight=1.0,
)

expected = self.ds.copy()
Expand All @@ -288,15 +288,15 @@ def test_spatial_average_with_minimum_weight(self):

xr.testing.assert_allclose(result, expected)

def test_spatial_average_with_minimum_weight_as_None(self):
def test_spatial_average_with_min_weight_as_None(self):
ds = self.ds.copy()

result = ds.spatial.average(
"ts",
axis=["X", "Y"],
lat_bounds=(-5.0, 5),
lon_bounds=(-170, -120.1),
minimum_weight=None,
min_weight=None,
)

expected = self.ds.copy()
Expand Down
22 changes: 21 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import xarray as xr

from xcdat.utils import compare_datasets, str_to_bool
from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool


class TestCompareDatasets:
Expand Down Expand Up @@ -103,3 +103,23 @@ def test_raises_error_if_str_is_not_a_python_bool(self):

with pytest.raises(ValueError):
str_to_bool("1")


class TestValidateMinWeight:
def test_pass_None_returns_0(self):
result = _validate_min_weight(None)

assert result == 0

def test_returns_error_if_less_than_0(self):
with pytest.raises(ValueError):
_validate_min_weight(-1)

def test_returns_error_if_greater_than_1(self):
with pytest.raises(ValueError):
_validate_min_weight(1.1)

def test_returns_valid_min_weight(self):
result = _validate_min_weight(1)

assert result == 1
159 changes: 94 additions & 65 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Module containing geospatial averaging functions."""
from __future__ import annotations

from functools import reduce
from typing import (
Callable,
Dict,
Hashable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
Expand All @@ -24,7 +25,11 @@
get_dim_keys,
)
from xcdat.dataset import _get_data_var
from xcdat.utils import _if_multidim_dask_array_then_load
from xcdat.utils import (
_get_masked_weights,
_if_multidim_dask_array_then_load,
_validate_min_weight,
)

#: Type alias for a dictionary of axis keys mapped to their bounds.
AxisWeights = Dict[Hashable, xr.DataArray]
Expand Down Expand Up @@ -71,9 +76,9 @@ def average(
axis: List[SpatialAxis] = ["X", "Y"],
weights: Union[Literal["generate"], xr.DataArray] = "generate",
keep_weights: bool = False,
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
minimum_weight: Optional[float] = None,
lat_bounds: RegionAxisBounds | None = None,
lon_bounds: RegionAxisBounds | None = None,
min_weight: float | None = None,
) -> xr.Dataset:
"""
Calculates the spatial average for a rectilinear grid over an optionally
Expand Down Expand Up @@ -112,21 +117,21 @@ def average(
keep_weights : bool, optional
If calculating averages using weights, keep the weights in the
final dataset output, by default False.
lat_bounds : Optional[RegionAxisBounds], optional
lat_bounds : RegionAxisBounds | None, optional
A tuple of floats/ints for the regional latitude lower and upper
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The lower bound cannot be
larger than the upper bound, by default None.
lon_bounds : Optional[RegionAxisBounds], optional
lon_bounds : RegionAxisBounds | None, optional
A tuple of floats/ints for the regional longitude lower and upper
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The lower bound can be larger
than the upper bound (e.g., across the prime meridian, dateline), by
default None.
minimum_weight : optional, float
Fraction of data coverage (i..e, weight) needed to return a
min_weight : optional, float
Fraction of data coverage (i.e, weight) needed to return a
spatial average value. Value must range from 0 to 1, by default None
(equivalent to minimum_weight=0.0).
(equivalent to ``min_weight=0.0``).
Returns
-------
Expand Down Expand Up @@ -186,7 +191,9 @@ def average(
"""
ds = self._dataset.copy()
dv = _get_data_var(ds, data_var)

self._validate_axis_arg(axis)
min_weight = _validate_min_weight(min_weight)

if isinstance(weights, str) and weights == "generate":
if lat_bounds is not None:
Expand All @@ -198,7 +205,7 @@ def average(
self._weights = weights

self._validate_weights(dv, axis)
ds[dv.name] = self._averager(dv, axis, minimum_weight=minimum_weight)
ds[dv.name] = self._averager(dv, axis, min_weight=min_weight)

if keep_weights:
ds[self._weights.name] = self._weights
Expand All @@ -208,9 +215,9 @@ def average(
def get_weights(
self,
axis: List[SpatialAxis],
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
data_var: Optional[str] = None,
lat_bounds: RegionAxisBounds | None = None,
lon_bounds: RegionAxisBounds | None = None,
data_var: str | None = None,
) -> xr.DataArray:
"""
Get area weights for specified axis keys and an optional target domain.
Expand All @@ -229,13 +236,13 @@ def get_weights(
----------
axis : List[SpatialAxis]
List of axis dimensions to average over.
lat_bounds : Optional[RegionAxisBounds]
lat_bounds : RegionAxisBounds | None
Tuple of latitude boundaries for regional selection, by default
None.
lon_bounds : Optional[RegionAxisBounds]
lon_bounds : RegionAxisBounds | None
Tuple of longitude boundaries for regional selection, by default
None.
data_var: Optional[str]
data_var: str | None
The key of the data variable, by default None. Pass this argument
when the dataset has more than one bounds per axis (e.g., "lon"
and "zlon_bnds" for the "X" axis), or you want weights for a
Expand All @@ -256,7 +263,7 @@ def get_weights(
and pressure).
"""
Bounds = TypedDict(
"Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]}
"Bounds", {"weights_method": Callable, "region": np.ndarray | None}
)

axis_bounds: Dict[SpatialAxis, Bounds] = {
Expand Down Expand Up @@ -379,7 +386,7 @@ def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds):
)

def _get_longitude_weights(
self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
) -> xr.DataArray:
"""Gets weights for the longitude axis.
Expand All @@ -406,7 +413,7 @@ def _get_longitude_weights(
----------
domain_bounds : xr.DataArray
The array of bounds for the longitude domain.
region_bounds : Optional[np.ndarray]
region_bounds : np.ndarray | None
The array of bounds for longitude regional selection.
Returns
Expand All @@ -420,7 +427,7 @@ def _get_longitude_weights(
If the there are multiple instances in which the
domain_bounds[:, 0] > domain_bounds[:, 1]
"""
p_meridian_index: Optional[np.ndarray] = None
p_meridian_index: np.ndarray | None = None
d_bounds = domain_bounds.copy()

pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0]
Expand Down Expand Up @@ -454,7 +461,7 @@ def _get_longitude_weights(
return weights

def _get_latitude_weights(
self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
) -> xr.DataArray:
"""Gets weights for the latitude axis.
Expand All @@ -466,7 +473,7 @@ def _get_latitude_weights(
----------
domain_bounds : xr.DataArray
The array of bounds for the latitude domain.
region_bounds : Optional[np.ndarray]
region_bounds : np.ndarray | None
The array of bounds for latitude regional selection.
Returns
Expand Down Expand Up @@ -707,7 +714,7 @@ def _averager(
self,
data_var: xr.DataArray,
axis: List[SpatialAxis],
minimum_weight: Optional[float] = None,
min_weight: float,
):
"""Perform a weighted average of a data variable.
Expand All @@ -726,10 +733,9 @@ def _averager(
Data variable inside a Dataset.
axis : List[SpatialAxis]
List of axis dimensions to average over.
minimum_weight : optional, float
Fraction of data coverage (i..e, weight) needed to return a
spatial average value. Value must range from 0 to 1, by default None
(equivalent to minimum_weight=0.0).
min_weight : float
Fraction of data coverage (i.e, weight) needed to return a
spatial average value. Value must range from 0 to 1.
Returns
-------
Expand All @@ -743,45 +749,68 @@ def _averager(
"""
weights = self._weights.fillna(0)

# ensure required weight is between 0 and 1
if minimum_weight is None:
minimum_weight = 0.0
elif minimum_weight < 0.0:
raise ValueError(
"minimum_weight argument is less than 0. "
"minimum_weight must be between 0 and 1."
)
elif minimum_weight > 1.0:
raise ValueError(
"minimum_weight argument is greater than 1. "
"minimum_weight must be between 0 and 1."
)

# need weights to match data_var dimensionality
if minimum_weight > 0.0:
# TODO: This conditional might not be needed because Xarray will
# automatically broadcast the weights to the data variable for
# operations such as .mean() and .where().
if min_weight > 0.0:
weights, data_var = xr.broadcast(weights, data_var)

# get averaging dimensions
dim = []
dim: List[str] = []
for key in axis:
dim.append(get_dim_keys(data_var, key))
dim.append(get_dim_keys(data_var, key)) # type: ignore

# compute weighed mean
with xr.set_options(keep_attrs=True):
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)

# if weight thresholds applied, calculate fraction of data availability
# replace values that do not meet minimum weight with nan
if minimum_weight > 0.0:
# sum all weights (assuming no missing values exist)
weight_sum_all = weights.sum(dim=dim) # type: ignore
# zero out cells with missing values in data_var
weights = xr.where(~np.isnan(data_var), weights, 0)
# sum all weights (including zero for missing values)
weight_sum_masked = weights.sum(dim=dim) # type: ignore
# get fraction of weight available
frac = weight_sum_masked / weight_sum_all
# nan out values that don't meet specified weight threshold
weighted_mean = xr.where(frac >= minimum_weight, weighted_mean, np.nan)

return weighted_mean
dv_mean = data_var.cf.weighted(weights).mean(dim=dim)

if min_weight > 0.0:
dv_mean = self._mask_var_with_weight_threshold(
dv_mean, dim, weights, min_weight
)

return dv_mean

def _mask_var_with_weight_threshold(
self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float
) -> xr.DataArray:
"""Mask values that do not meet the minimum weight threshold with np.nan.
This function is useful for cases where the weighting of data might be
skewed based on the availability of data. For example, if a portion of
cells in a region has significantly more missing data than other other
regions, it can result in inaccurate calculations of spatial averaging.
Masking values that do not meet the minimum weight threshold ensures
more accurate calculations.
Parameters
----------
dv : xr.DataArray
The weighted variable.
dim: List[str]:
List of axis dimensions to average over.
weights : xr.DataArray
A DataArray containing either the regional weights used for weighted
averaging. ``weights`` must include the same axis dimensions and
dimensional sizes as the data variable.
min_weight : float
Fraction of data coverage (i.e, weight) needed to return a
spatial average value. Value must range from 0 to 1.
Returns
-------
xr.DataArray
The variable with the minimum weight threshold applied.
"""
# Sum all weights, including zero for missing values.
weight_sum_all = weights.sum(dim=dim)

masked_weights = _get_masked_weights(dv, weights)
weight_sum_masked = masked_weights.sum(dim=dim)

# Get fraction of the available weight.
frac = weight_sum_masked / weight_sum_all

# Nan out values that don't meet specified weight threshold.
dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True)
dv_new.name = dv.name

return dv_new
Loading

0 comments on commit b56befe

Please sign in to comment.