Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add weight threshold option for spatial averaging #672

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
54 changes: 54 additions & 0 deletions tests/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)

def test_raises_error_if_min_weight_not_between_zero_and_one(self):
    """Check that ``average`` raises ValueError for out-of-range ``min_weight``."""
    # The first value lies just below zero, the second just above one.
    for out_of_range in (-0.01, 1.01):
        with pytest.raises(ValueError):
            self.ds.spatial.average(
                "ts", axis=["X", "Y"], min_weight=out_of_range
            )

def test_spatial_average_for_lat_region_and_keep_weights(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -254,6 +265,49 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):

xr.testing.assert_allclose(result, expected)

def test_spatial_average_with_min_weight(self):
    """Check the regional average when ``min_weight=1.0`` and NaNs are present.

    NaNs are inserted into the first time step of ``ts``; the expected
    result is NaN for that time step and 1.0 for the remaining two.
    """
    # Deep copy: a shallow ``copy()`` shares the underlying arrays with
    # ``self.ds``, so the in-place assignment below would otherwise also
    # write NaNs into the shared fixture data.
    ds = self.ds.copy(deep=True)

    # insert a nan
    ds["ts"][0, :, 2] = np.nan

    result = ds.spatial.average(
        "ts",
        axis=["X", "Y"],
        lat_bounds=(-5.0, 5),
        lon_bounds=(-170, -120.1),
        min_weight=1.0,
    )

    expected = self.ds.copy()
    expected["ts"] = xr.DataArray(
        data=np.array([np.nan, 1.0, 1.0]),
        coords={"time": expected.time},
        dims="time",
    )

    xr.testing.assert_allclose(result, expected)

def test_spatial_average_with_min_weight_as_None(self):
    """Check the regional average of ``ts`` when ``min_weight=None`` is passed."""
    source = self.ds.copy()

    # Region of interest, kept together so the call below stays compact.
    region = {
        "lat_bounds": (-5.0, 5),
        "lon_bounds": (-170, -120.1),
    }
    result = source.spatial.average(
        "ts", axis=["X", "Y"], min_weight=None, **region
    )

    expected = self.ds.copy()
    expected_ts = xr.DataArray(
        data=np.array([2.25, 1.0, 1.0]),
        coords={"time": expected.time},
        dims="time",
    )
    expected["ts"] = expected_ts

    xr.testing.assert_allclose(result, expected)

def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self):
ds = self.ds.copy()

Expand Down
22 changes: 21 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import xarray as xr

from xcdat.utils import compare_datasets, str_to_bool
from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool


class TestCompareDatasets:
Expand Down Expand Up @@ -103,3 +103,23 @@ def test_raises_error_if_str_is_not_a_python_bool(self):

with pytest.raises(ValueError):
str_to_bool("1")


class TestValidateMinWeight:
    """Unit tests for the ``_validate_min_weight`` helper."""

    def test_pass_None_returns_0(self):
        # ``None`` is expected to come back as a threshold of 0.
        assert _validate_min_weight(None) == 0

    def test_returns_error_if_less_than_0(self):
        # A negative threshold is expected to be rejected.
        with pytest.raises(ValueError):
            _validate_min_weight(-1)

    def test_returns_error_if_greater_than_1(self):
        # A threshold above 1 is expected to be rejected.
        with pytest.raises(ValueError):
            _validate_min_weight(1.1)

    def test_returns_valid_min_weight(self):
        # An in-range value is expected to be returned as given.
        validated = _validate_min_weight(1)

        assert validated == 1
121 changes: 99 additions & 22 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Module containing geospatial averaging functions."""
from __future__ import annotations

from functools import reduce
Expand Down Expand Up @@ -27,7 +30,11 @@
get_dim_keys,
)
from xcdat.dataset import _get_data_var
from xcdat.utils import _if_multidim_dask_array_then_load
from xcdat.utils import (
_get_masked_weights,
_if_multidim_dask_array_then_load,
_validate_min_weight,
)

#: Type alias for a dictionary of axis keys mapped to their bounds.
AxisWeights = Dict[Hashable, xr.DataArray]
Expand Down Expand Up @@ -74,8 +81,9 @@ def average(
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"),
weights: Union[Literal["generate"], xr.DataArray] = "generate",
keep_weights: bool = False,
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
lat_bounds: RegionAxisBounds | None = None,
lon_bounds: RegionAxisBounds | None = None,
min_weight: float | None = None,
) -> xr.Dataset:
"""
Calculates the spatial average for a rectilinear grid over an optionally
Expand Down Expand Up @@ -114,17 +122,21 @@ def average(
keep_weights : bool, optional
If calculating averages using weights, keep the weights in the
final dataset output, by default False.
lat_bounds : Optional[RegionAxisBounds], optional
lat_bounds : RegionAxisBounds | None, optional
A tuple of floats/ints for the regional latitude lower and upper
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The lower bound cannot be
larger than the upper bound, by default None.
lon_bounds : Optional[RegionAxisBounds], optional
lon_bounds : RegionAxisBounds | None, optional
A tuple of floats/ints for the regional longitude lower and upper
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The lower bound can be larger
than the upper bound (e.g., across the prime meridian, dateline), by
default None.
min_weight : float | None, optional
Fraction of data coverage (i.e., weight) needed to return a
spatial average value. Value must range from 0 to 1, by default None
(equivalent to ``min_weight=0.0``).

Returns
-------
Expand Down Expand Up @@ -184,7 +196,9 @@ def average(
"""
ds = self._dataset.copy()
dv = _get_data_var(ds, data_var)

self._validate_axis_arg(axis)
min_weight = _validate_min_weight(min_weight)

if isinstance(weights, str) and weights == "generate":
if lat_bounds is not None:
Expand All @@ -196,7 +210,7 @@ def average(
self._weights = weights

self._validate_weights(dv, axis)
ds[dv.name] = self._averager(dv, axis)
ds[dv.name] = self._averager(dv, axis, min_weight=min_weight)

if keep_weights:
ds[self._weights.name] = self._weights
Expand All @@ -206,9 +220,9 @@ def average(
def get_weights(
self,
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
data_var: Optional[str] = None,
lat_bounds: RegionAxisBounds | None = None,
lon_bounds: RegionAxisBounds | None = None,
data_var: str | None = None,
) -> xr.DataArray:
"""
Get area weights for specified axis keys and an optional target domain.
Expand All @@ -227,13 +241,13 @@ def get_weights(
----------
axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
List of axis dimensions to average over.
lat_bounds : Optional[RegionAxisBounds]
lat_bounds : RegionAxisBounds | None
Tuple of latitude boundaries for regional selection, by default
None.
lon_bounds : Optional[RegionAxisBounds]
lon_bounds : RegionAxisBounds | None
Tuple of longitude boundaries for regional selection, by default
None.
data_var: Optional[str]
data_var: str | None
The key of the data variable, by default None. Pass this argument
when the dataset has more than one bounds per axis (e.g., "lon"
and "zlon_bnds" for the "X" axis), or you want weights for a
Expand Down Expand Up @@ -377,7 +391,7 @@ def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds):
)

def _get_longitude_weights(
self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
) -> xr.DataArray:
"""Gets weights for the longitude axis.

Expand All @@ -404,7 +418,7 @@ def _get_longitude_weights(
----------
domain_bounds : xr.DataArray
The array of bounds for the longitude domain.
region_bounds : Optional[np.ndarray]
region_bounds : np.ndarray | None
The array of bounds for longitude regional selection.

Returns
Expand All @@ -418,7 +432,7 @@ def _get_longitude_weights(
If there are multiple instances in which the
domain_bounds[:, 0] > domain_bounds[:, 1]
"""
p_meridian_index: Optional[np.ndarray] = None
p_meridian_index: np.ndarray | None = None
d_bounds = domain_bounds.copy()

pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0]
Expand Down Expand Up @@ -450,7 +464,7 @@ def _get_longitude_weights(
return weights

def _get_latitude_weights(
self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
) -> xr.DataArray:
"""Gets weights for the latitude axis.

Expand All @@ -462,7 +476,7 @@ def _get_latitude_weights(
----------
domain_bounds : xr.DataArray
The array of bounds for the latitude domain.
region_bounds : Optional[np.ndarray]
region_bounds : np.ndarray | None
The array of bounds for latitude regional selection.

Returns
Expand Down Expand Up @@ -702,7 +716,10 @@ def _validate_weights(
)

def _averager(
self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...]
self,
data_var: xr.DataArray,
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
min_weight: float,
):
"""Perform a weighted average of a data variable.

Expand All @@ -721,6 +738,9 @@ def _averager(
Data variable inside a Dataset.
axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
List of axis dimensions to average over.
min_weight : float
Fraction of data coverage (i.e., weight) needed to return a
spatial average value. Value must range from 0 to 1.

Returns
-------
Expand All @@ -734,11 +754,68 @@ def _averager(
"""
weights = self._weights.fillna(0)
tomvothecoder marked this conversation as resolved.
Show resolved Hide resolved

dim = []
# TODO: This conditional might not be needed because Xarray will
# automatically broadcast the weights to the data variable for
# operations such as .mean() and .where().
if min_weight > 0.0:
weights, data_var = xr.broadcast(weights, data_var)
Comment on lines +754 to +758
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# TODO: This conditional might not be needed because Xarray will
# automatically broadcast the weights to the data variable for
# operations such as .mean() and .where().
if min_weight > 0.0:
weights, data_var = xr.broadcast(weights, data_var)

Xarray will automatically broadcast to align the shapes of weights and data_var.


dim: List[str] = []
for key in axis:
dim.append(get_dim_keys(data_var, key))
dim.append(get_dim_keys(data_var, key)) # type: ignore

with xr.set_options(keep_attrs=True):
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)
dv_mean = data_var.cf.weighted(weights).mean(dim=dim)

if min_weight > 0.0:
dv_mean = self._mask_var_with_weight_threshold(
dv_mean, dim, weights, min_weight
)

return dv_mean

def _mask_var_with_weight_threshold(
    self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float
) -> xr.DataArray:
    """Mask values that do not meet the minimum weight threshold with np.nan.

    This function is useful for cases where the weighting of data might be
    skewed based on the availability of data. For example, if a portion of
    cells in a region has significantly more missing data than other
    regions, it can result in inaccurate calculations of spatial averaging.
    Masking values that do not meet the minimum weight threshold ensures
    more accurate calculations.

    Parameters
    ----------
    dv : xr.DataArray
        The weighted variable.
    dim : List[str]
        List of axis dimensions to average over.
    weights : xr.DataArray
        A DataArray containing the regional weights used for weighted
        averaging. ``weights`` must include the same axis dimensions and
        dimensional sizes as the data variable.
    min_weight : float
        Fraction of data coverage (i.e., weight) needed to return a
        spatial average value. Value must range from 0 to 1.

    Returns
    -------
    xr.DataArray
        The variable with the minimum weight threshold applied.
    """
    # Sum all weights, including zero for missing values.
    weight_sum_all = weights.sum(dim=dim)

    # NOTE(review): ``_averager`` passes the already-averaged variable as
    # ``dv``. Presumably ``_get_masked_weights`` zeroes the weights where
    # ``dv`` is missing -- confirm it does not need the pre-average data to
    # detect missing cells.
    masked_weights = _get_masked_weights(dv, weights)
    weight_sum_masked = masked_weights.sum(dim=dim)

    # Get fraction of the available weight.
    frac = weight_sum_masked / weight_sum_all

    # Nan out values that don't meet specified weight threshold.
    dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True)
    # Carry the original variable name over to the masked result.
    dv_new.name = dv.name

    return dv_new
Loading
Loading