14
14
from xarray .coding .cftime_offsets import get_date_type
15
15
from xarray .core .common import contains_cftime_datetimes , is_np_datetime_like
16
16
from xarray .core .groupby import DataArrayGroupBy
17
+ from xarray .groupers import SeasonGrouper , UniqueGrouper
17
18
18
19
from xcdat import bounds # noqa: F401
19
20
from xcdat ._logger import _setup_custom_logger
@@ -1091,7 +1092,10 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]]
1091
1092
f"Supported months include: { predefined_months } ."
1092
1093
)
1093
1094
1094
- c_seasons = {"" .join (months ): months for months in custom_seasons }
1095
+ c_seasons = {}
1096
+ for season in custom_seasons :
1097
+ key = "" .join ([month [0 ] for month in season ])
1098
+ c_seasons [key ] = season
1095
1099
1096
1100
return c_seasons
1097
1101
@@ -1130,18 +1134,19 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
1130
1134
self ._freq == "season"
1131
1135
and self ._season_config .get ("custom_seasons" ) is not None
1132
1136
):
1133
- # Get a flat list of all of the months included in the custom
1134
- # seasons to determine if the dataset needs to be subsetted
1135
- # on just those months. For example, if we define a custom season
1136
- # "NDJFM", we should subset the dataset for time coordinates
1137
- # belonging to those months.
1138
1137
months = self ._season_config ["custom_seasons" ].values () # type: ignore
1139
1138
months = list (chain .from_iterable (months ))
1140
1139
1141
1140
if len (months ) != 12 :
1142
1141
ds = self ._subset_coords_for_custom_seasons (ds , months )
1143
1142
1144
- ds = self ._shift_custom_season_years (ds )
1143
+ # FIXME: This causes a bug when accessing `.groups` with
1144
+ # SeasonGrouper(). Also shifting custom seasons is done for
1145
+ # drop_incomplete_seasons and grouping for months that span the
1146
+ # calendar year. The Xarray PR will handle both of these cases
1147
+ # and this method will be removed.
1148
+ # ds = self._shift_custom_season_years(ds)
1149
+ pass
1145
1150
1146
1151
if self ._freq == "season" and self ._season_config .get ("dec_mode" ) == "DJF" :
1147
1152
ds = self ._shift_djf_decembers (ds )
@@ -1153,11 +1158,11 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
1153
1158
):
1154
1159
ds = self ._drop_incomplete_djf (ds )
1155
1160
1156
- if (
1157
- self ._freq == "season"
1158
- and self ._season_config ["drop_incomplete_seasons" ] is True
1159
- ):
1160
- ds = self ._drop_incomplete_seasons (ds )
1161
+ # if (
1162
+ # self._freq == "season"
1163
+ # and self._season_config["drop_incomplete_seasons"] is True
1164
+ # ):
1165
+ # ds = self._drop_incomplete_seasons(ds)
1161
1166
1162
1167
return ds
1163
1168
@@ -1494,8 +1499,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
1494
1499
1495
1500
# Label the time coordinates for grouping weights and the data variable
1496
1501
# values.
1497
- self ._labeled_time = self ._label_time_coords (dv [self .dim ])
1498
- dv = dv .assign_coords ({self .dim : self ._labeled_time })
1502
+ dv_grouped = self ._label_time_coords_for_grouping (dv )
1499
1503
1500
1504
if self ._weighted :
1501
1505
time_bounds = ds .bounds .get_bounds ("T" , var_key = data_var )
@@ -1514,13 +1518,14 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
1514
1518
# Perform weighted average using the formula
1515
1519
# WA = sum(data*weights) / sum(weights). The denominator must be
1516
1520
# included to take into account zero weight for missing data.
1521
+ weights_gb = self ._label_time_coords_for_grouping (weights )
1517
1522
with xr .set_options (keep_attrs = True ):
1518
- dv = self . _group_data ( dv ). sum () / self . _group_data ( weights ) .sum ()
1523
+ dv = dv_grouped . sum () / weights_gb .sum ()
1519
1524
1520
1525
# Restore the data variable's name.
1521
1526
dv .name = data_var
1522
1527
else :
1523
- dv = self . _group_data ( dv ) .mean ()
1528
+ dv = dv_grouped .mean ()
1524
1529
1525
1530
# After grouping and aggregating, the grouped time dimension's
1526
1531
# attributes are removed. Xarray's `keep_attrs=True` option only keeps
@@ -1578,7 +1583,10 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
1578
1583
1579
1584
time_lengths = time_lengths .astype (np .float64 )
1580
1585
1581
- grouped_time_lengths = self ._group_data (time_lengths )
1586
+ grouped_time_lengths = self ._label_time_coords_for_grouping (time_lengths )
1587
+ # FIXME: File "/opt/miniconda3/envs/xcdat_dev_416_xr/lib/python3.12/site-packages/xarray/core/groupby.py", line 639, in _raise_if_not_single_group
1588
+ # raise NotImplementedError(
1589
+ # NotImplementedError: This method is not supported for grouping by multiple variables yet.
1582
1590
weights : xr .DataArray = grouped_time_lengths / grouped_time_lengths .sum ()
1583
1591
weights .name = f"{ self .dim } _wts"
1584
1592
@@ -1670,6 +1678,35 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray:
1670
1678
1671
1679
return time_grouped
1672
1680
1681
+ def _label_time_coords_for_grouping (self , dv : xr .DataArray ) -> DataArrayGroupBy :
1682
+ # Use the TIME_GROUPS dictionary to determine which components
1683
+ # are needed to form the labeled time coordinates.
1684
+ dt_comps = TIME_GROUPS [self ._mode ][self ._freq ]
1685
+ dt_comps_map : Dict [str , UniqueGrouper | SeasonGrouper ] = {
1686
+ comp : UniqueGrouper () for comp in dt_comps if comp != "season"
1687
+ }
1688
+
1689
+ dv_new = dv .copy ()
1690
+ for comp in dt_comps_map .keys ():
1691
+ dv_new .coords [comp ] = dv_new [self .dim ][f"{ self .dim } .{ comp } " ]
1692
+
1693
+ if self ._freq == "season" :
1694
+ custom_seasons = self ._season_config .get ("custom_seasons" )
1695
+ # NOTE: SeasonGrouper() does not drop incomplete seasons yet.
1696
+ # TODO: Add `drop_incomplete` arg once available.
1697
+
1698
+ if custom_seasons is not None :
1699
+ season_keys = list (custom_seasons .keys ())
1700
+ season_grouper = SeasonGrouper (season_keys )
1701
+ else :
1702
+ season_keys = list (SEASON_TO_MONTH .keys ())
1703
+ season_grouper = SeasonGrouper (season_keys )
1704
+
1705
+ dt_comps_map [self .dim ] = season_grouper
1706
+ dv_gb = dv_new .groupby (** dt_comps_map ) # type: ignore
1707
+
1708
+ return dv_gb
1709
+
1673
1710
def _get_df_dt_components (
1674
1711
self , time_coords : xr .DataArray , drop_obsolete_cols : bool
1675
1712
) -> pd .DataFrame :
0 commit comments