Skip to content

Commit 2e736ca

Browse files
committed
Initial prototype using Xarray `SeasonGrouper`
1 parent 8d156c2 commit 2e736ca

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

xcdat/temporal.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from xarray.coding.cftime_offsets import get_date_type
1515
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
1616
from xarray.core.groupby import DataArrayGroupBy
17+
from xarray.groupers import SeasonGrouper, UniqueGrouper
1718

1819
from xcdat import bounds # noqa: F401
1920
from xcdat._logger import _setup_custom_logger
@@ -1091,7 +1092,10 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]]
10911092
f"Supported months include: {predefined_months}."
10921093
)
10931094

1094-
c_seasons = {"".join(months): months for months in custom_seasons}
1095+
c_seasons = {}
1096+
for season in custom_seasons:
1097+
key = "".join([month[0] for month in season])
1098+
c_seasons[key] = season
10951099

10961100
return c_seasons
10971101

@@ -1130,18 +1134,19 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
11301134
self._freq == "season"
11311135
and self._season_config.get("custom_seasons") is not None
11321136
):
1133-
# Get a flat list of all of the months included in the custom
1134-
# seasons to determine if the dataset needs to be subsetted
1135-
# on just those months. For example, if we define a custom season
1136-
# "NDJFM", we should subset the dataset for time coordinates
1137-
# belonging to those months.
11381137
months = self._season_config["custom_seasons"].values() # type: ignore
11391138
months = list(chain.from_iterable(months))
11401139

11411140
if len(months) != 12:
11421141
ds = self._subset_coords_for_custom_seasons(ds, months)
11431142

1144-
ds = self._shift_custom_season_years(ds)
1143+
# FIXME: This causes a bug when accessing `.groups` with
1144+
# SeasonGrouper(). Also shifting custom seasons is done for
1145+
# drop_incomplete_seasons and grouping for months that span the
1146+
# calendar year. The Xarray PR will handle both of these cases
1147+
# and this method will be removed.
1148+
# ds = self._shift_custom_season_years(ds)
1149+
pass
11451150

11461151
if self._freq == "season" and self._season_config.get("dec_mode") == "DJF":
11471152
ds = self._shift_djf_decembers(ds)
@@ -1153,11 +1158,11 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
11531158
):
11541159
ds = self._drop_incomplete_djf(ds)
11551160

1156-
if (
1157-
self._freq == "season"
1158-
and self._season_config["drop_incomplete_seasons"] is True
1159-
):
1160-
ds = self._drop_incomplete_seasons(ds)
1161+
# if (
1162+
# self._freq == "season"
1163+
# and self._season_config["drop_incomplete_seasons"] is True
1164+
# ):
1165+
# ds = self._drop_incomplete_seasons(ds)
11611166

11621167
return ds
11631168

@@ -1494,8 +1499,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
14941499

14951500
# Label the time coordinates for grouping weights and the data variable
14961501
# values.
1497-
self._labeled_time = self._label_time_coords(dv[self.dim])
1498-
dv = dv.assign_coords({self.dim: self._labeled_time})
1502+
dv_grouped = self._label_time_coords_for_grouping(dv)
14991503

15001504
if self._weighted:
15011505
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
@@ -1514,13 +1518,14 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
15141518
# Perform weighted average using the formula
15151519
# WA = sum(data*weights) / sum(weights). The denominator must be
15161520
# included to take into account zero weight for missing data.
1521+
weights_gb = self._label_time_coords_for_grouping(weights)
15171522
with xr.set_options(keep_attrs=True):
1518-
dv = self._group_data(dv).sum() / self._group_data(weights).sum()
1523+
dv = dv_grouped.sum() / weights_gb.sum()
15191524

15201525
# Restore the data variable's name.
15211526
dv.name = data_var
15221527
else:
1523-
dv = self._group_data(dv).mean()
1528+
dv = dv_grouped.mean()
15241529

15251530
# After grouping and aggregating, the grouped time dimension's
15261531
# attributes are removed. Xarray's `keep_attrs=True` option only keeps
@@ -1578,7 +1583,10 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
15781583

15791584
time_lengths = time_lengths.astype(np.float64)
15801585

1581-
grouped_time_lengths = self._group_data(time_lengths)
1586+
grouped_time_lengths = self._label_time_coords_for_grouping(time_lengths)
1587+
# FIXME: File "/opt/miniconda3/envs/xcdat_dev_416_xr/lib/python3.12/site-packages/xarray/core/groupby.py", line 639, in _raise_if_not_single_group
1588+
# raise NotImplementedError(
1589+
# NotImplementedError: This method is not supported for grouping by multiple variables yet.
15821590
weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum()
15831591
weights.name = f"{self.dim}_wts"
15841592

@@ -1670,6 +1678,35 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray:
16701678

16711679
return time_grouped
16721680

1681+
def _label_time_coords_for_grouping(self, dv: xr.DataArray) -> DataArrayGroupBy:
    """Group a data variable's time coordinates using Xarray groupers.

    Builds a grouper mapping from the datetime components required for
    ``self._mode``/``self._freq`` (looked up in ``TIME_GROUPS``), attaches
    those components as coordinates on a copy of ``dv``, and returns the
    resulting ``DataArrayGroupBy`` object. Seasons are grouped with
    ``SeasonGrouper`` (custom seasons when configured, otherwise the
    predefined seasons in ``SEASON_TO_MONTH``); all other components use
    ``UniqueGrouper``.

    Parameters
    ----------
    dv : xr.DataArray
        The data variable with a time dimension (``self.dim``) to group.

    Returns
    -------
    DataArrayGroupBy
        ``dv`` grouped by the datetime components for the current mode
        and frequency.
    """
    # Determine which datetime components label the time coordinates
    # for this mode/frequency. "season" is handled separately below
    # because it needs SeasonGrouper rather than UniqueGrouper.
    dt_comps = TIME_GROUPS[self._mode][self._freq]
    dt_comps_map: Dict[str, UniqueGrouper | SeasonGrouper] = {
        comp: UniqueGrouper() for comp in dt_comps if comp != "season"
    }

    # Attach each component (e.g., "year", "month") as a coordinate so
    # `.groupby()` can group on it by name.
    dv_new = dv.copy()
    for comp in dt_comps_map:
        dv_new.coords[comp] = dv_new[self.dim][f"{self.dim}.{comp}"]

    if self._freq == "season":
        custom_seasons = self._season_config.get("custom_seasons")
        # NOTE: SeasonGrouper() does not drop incomplete seasons yet.
        # TODO: Add `drop_incomplete` arg once available.

        # Use the configured custom seasons when present; otherwise fall
        # back to the predefined seasons. An explicit `is not None` check
        # preserves the distinction from an (unlikely) empty mapping.
        season_source = (
            custom_seasons if custom_seasons is not None else SEASON_TO_MONTH
        )
        dt_comps_map[self.dim] = SeasonGrouper(list(season_source.keys()))

    dv_gb = dv_new.groupby(**dt_comps_map)  # type: ignore

    return dv_gb
1709+
16731710
def _get_df_dt_components(
16741711
self, time_coords: xr.DataArray, drop_obsolete_cols: bool
16751712
) -> pd.DataFrame:

0 commit comments

Comments
 (0)