From 7716efcb990ef1916eea7670121f3e23c242a57b Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Wed, 1 Mar 2023 09:32:08 -0800 Subject: [PATCH 01/17] Add baseline unit tests - Remove logic for requiring all 12 months to be used --- tests/test_temporal.py | 150 +++++++++++++++++++++++++++++++++++++---- xcdat/temporal.py | 4 -- 2 files changed, 136 insertions(+), 18 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index e5489b1b..388a97eb 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -789,6 +789,71 @@ def test_weighted_custom_seasonal_averages(self): xr.testing.assert_identical(result, expected) + def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons( + self, + ): + ds = self.ds.copy() + custom_seasons = [ + ["Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov"], + ["Dec", "Jan", "Feb", "Mar"], + ] + + result = ds.temporal.group_average( + "ts", + "season", + season_config={ + "custom_seasons": custom_seasons, + "drop_incomplete_seasons": True, + }, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1.5]], [[1.0]], [[1.0]], [[2.0]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 2, 1), + cftime.DatetimeGregorian(2000, 5, 1), + cftime.DatetimeGregorian(2000, 8, 1), + cftime.DatetimeGregorian(2001, 2, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "custom_seasons": [ + "JanFebMar", + "AprMayJun", + "JulAugSep", + "OctNovDec", + ], + "weighted": "True", + }, + ) + + assert result.identical(expected) + + @pytest.mark.xfail + def test_weighted_custom_seasonal_averages_with_single_custom_season(self): + assert 0 + def test_raises_error_with_incorrect_custom_seasons_argument(self): # Test raises error with non-3 letter strings with pytest.raises(ValueError): @@ -804,20 +869,6 @@ def test_raises_error_with_incorrect_custom_seasons_argument(self): season_config={"custom_seasons": custom_seasons}, ) - # Test raises error with missing month(s) - with pytest.raises(ValueError): - custom_seasons = [ - ["Feb", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) - # Test raises error if duplicate month(s) were found with pytest.raises(ValueError): custom_seasons = [ @@ -1330,6 +1381,77 @@ def test_weighted_custom_seasonal_climatology(self): xr.testing.assert_identical(result, expected) + @pytest.mark.xfail + def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons( + self, + ): + ds = self.ds.copy() + + custom_seasons = [ + ["Jan", "Feb", "Mar"], + ["Apr", "May", "Jun"], + ["Jul", "Aug", "Sep"], + ["Oct", "Nov", "Dec"], + ] + result = ds.temporal.climatology( + "ts", + "season", + season_config={ + "custom_seasons": custom_seasons, + "drop_incomplete_seasons": True, + }, + ) + + expected = ds.copy() + expected = expected.drop_dims("time") + expected_time = xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(1, 2, 1), + cftime.DatetimeGregorian(1, 5, 1), + cftime.DatetimeGregorian(1, 8, 1), + cftime.DatetimeGregorian(1, 11, 1), + ], + ), + coords={ + "time": np.array( + [ + cftime.DatetimeGregorian(1, 2, 1), + cftime.DatetimeGregorian(1, 5, 1), + cftime.DatetimeGregorian(1, 8, 1), + cftime.DatetimeGregorian(1, 11, 1), + ], + ), + }, + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ) + + expected["ts"] = xr.DataArray( + name="ts", + data=np.ones((4, 4, 4)), + coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time}, + dims=["time", "lat", "lon"], + attrs={ + "operation": "temporal_avg", + "mode": "climatology", + "freq": "season", + "weighted": "True", + "custom_seasons": [ + "JanFebMar", + "AprMayJun", + "JulAugSep", + "OctNovDec", + ], + }, + ) + + assert result.identical(expected) + def test_weighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month") diff --git a/xcdat/temporal.py b/xcdat/temporal.py index ce44bbfd..7242d897 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -997,10 +997,6 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]] predefined_months = list(MONTH_INT_TO_STR.values()) input_months = list(chain.from_iterable(custom_seasons)) - if len(input_months) != len(predefined_months): - raise ValueError( - "Exactly 12 months were not passed in the list of custom seasons." - ) if len(input_months) != len(set(input_months)): raise ValueError( "Duplicate month(s) were found in the list of custom seasons." From ade02bac717fa6df2514eabbb4d1a1a707fa3c6a Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Thu, 2 Mar 2023 16:53:37 -0800 Subject: [PATCH 02/17] Add support for shifting months in seasons spanning calendar year --- tests/test_temporal.py | 154 +++++++++++++++++++++++++++-------------- xcdat/temporal.py | 61 +++++++++++++++- 2 files changed, 162 insertions(+), 53 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 388a97eb..7c85a346 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -731,6 +731,35 @@ def test_weighted_seasonal_averages_with_JFD(self): xr.testing.assert_identical(result, expected) + def test_raises_error_with_incorrect_custom_seasons_argument(self): + # Test raises error with non-3 letter strings + with pytest.raises(ValueError): + custom_seasons = [ + ["J", "Feb", "Mar"], + ["Apr", "May", "Jun"], + ["Jul", "Aug", "Sep"], + ["Oct", "Nov", "Dec"], + ] + self.ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + + # Test raises error if duplicate month(s) were found + with pytest.raises(ValueError): + custom_seasons = [ + ["Jan", "Jan", "Mar"], + ["Apr", "May", "Jun"], + ["Jul", "Aug", "Sep"], + ["Oct", "Nov", "Dec"], + ] + self.ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + def test_weighted_custom_seasonal_averages(self): ds = self.ds.copy() @@ -789,13 +818,21 @@ def test_weighted_custom_seasonal_averages(self): xr.testing.assert_identical(result, expected) - def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons( - self, - ): + def test_weighted_seasonal_averages_drops_incomplete_seasons(self): ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-11-16T12:00:00.000000000", + "2000-12-16T12:00:00.000000000", + "2001-01-16T00:00:00.000000000", + "2001-02-16T00:00:00.000000000", + "2001-03-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) + custom_seasons = [ - ["Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov"], - ["Dec", "Jan", "Feb", "Mar"], + ["Nov", "Dec", "Jan", "Feb", "Mar"], ] result = ds.temporal.group_average( @@ -803,26 +840,19 @@ def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_ "season", season_config={ "custom_seasons": custom_seasons, - "drop_incomplete_seasons": True, + # "drop_incomplete_seasons": True, }, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[1.5]], [[1.0]], [[1.0]], [[2.0]]]), + data=np.array([[[1.3933333]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( - data=np.array( - [ - cftime.DatetimeGregorian(2000, 2, 1), - cftime.DatetimeGregorian(2000, 5, 1), - cftime.DatetimeGregorian(2000, 8, 1), - cftime.DatetimeGregorian(2001, 2, 1), - ], - ), + data=np.array([cftime.datetime(2001, 1, 1)], dtype=object), dims=["time"], attrs={ "axis": "T", @@ -838,50 +868,70 @@ def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_ "operation": "temporal_avg", "mode": "group_average", "freq": "season", - "custom_seasons": [ - "JanFebMar", - "AprMayJun", - "JulAugSep", - "OctNovDec", - ], + "custom_seasons": ["NovDecJanFebMar"], "weighted": "True", }, ) - assert result.identical(expected) + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs - @pytest.mark.xfail - def test_weighted_custom_seasonal_averages_with_single_custom_season(self): - assert 0 + def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( + self, + ): + ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-11-16T12:00:00.000000000", + "2000-12-16T12:00:00.000000000", + "2001-01-16T00:00:00.000000000", + "2001-02-16T00:00:00.000000000", + "2001-03-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) - def test_raises_error_with_incorrect_custom_seasons_argument(self): - # Test raises error with non-3 letter strings - with pytest.raises(ValueError): - custom_seasons = [ - ["J", "Feb", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) + custom_seasons = [ + ["Nov", "Dec", "Jan", "Feb", "Mar"], + ] - # Test raises error if duplicate month(s) were found - with pytest.raises(ValueError): - custom_seasons = [ - ["Jan", "Jan", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) + result = ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1.3933333]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array([cftime.datetime(2001, 1, 1)], dtype=object), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "custom_seasons": ["NovDecJanFebMar"], + "weighted": "True", + }, + ) + + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs def test_weighted_monthly_averages(self): ds = self.ds.copy() diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 7242d897..890a3f7b 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -104,6 +104,7 @@ 11: "Nov", 12: "Dec", } +MONTH_STR_TO_INT = {v: k for k, v in MONTH_INT_TO_STR.items()} # A dictionary mapping pre-defined seasons to their middle month. This # dictionary is used during the creation of datetime objects, which don't @@ -1421,6 +1422,7 @@ def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: if custom_seasons is not None: df_new = self._map_months_to_custom_seasons(df_new) + df_new = self._shift_spanning_months(df_new) else: if dec_mode == "DJF": df_new = self._shift_decembers(df_new) @@ -1463,6 +1465,64 @@ def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: return df_new + def _shift_spanning_months(self, df: pd.DataFrame) -> pd.DataFrame: + """Shifts months in seasons spanning the previous year to the next year. + + A season spans the previous year if it includes the month of "Jan" and + "Jan" is not the first month of the season. For example, let's say we + define ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]`` to + represent the southern hemisphere growing seasons, "NDJFM". + - ["Nov", "Dec"] are from the previous year since they are listed + before "Jan". + - ["Jan", "Feb", "Mar"] are from the current year. + + Therefore, we need to shift ["Nov", "Dec"] a year forward in order for + xarray to group seasons correctly. Refer to the examples section below + for a visual demonstration. + + Parameters + ---------- + df : pd.Dataframe + The DataFrame of xarray datetime components produced using the + "season" frequency". + + Returns + ------- + pd.DataFrame + The DataFrame of xarray dataetime copmonents with months spanning + previous year shifted over to the next year. + + Examples + -------- + + Before and after shifting months for "NDJFM" seasons: + + >>> # Before shifting months + >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + + >>> # After shifting months + >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 1), (2001, "NDJFM", 2)] + """ + df_new = df.copy() + custom_seasons = self._season_config["custom_seasons"] + + span_months: List[int] = [] + for months in custom_seasons.values(): # type: ignore + month_nums = [MONTH_STR_TO_INT[month] for month in months] + try: + jan_index = month_nums.index(1) + span_months = span_months + month_nums[:jan_index] + break + except ValueError: + continue + + if len(span_months) > 0: + df_new.loc[df_new["month"].isin(span_months), "year"] = df_new["year"] + 1 + + return df_new + def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame: """Shifts Decembers over to the next year for "DJF" seasons in-place. @@ -1496,7 +1556,6 @@ def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame: >>> # "DJF" (shifted Decembers) >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), >>> (2001, "DJF", 1), (2001, "DJF", 2)] - """ df_season.loc[df_season["month"] == 12, "year"] = df_season["year"] + 1 From 2a87cf0ec5495fec7a16b72f78e9bf124d4b61da Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Thu, 2 Mar 2023 17:00:10 -0800 Subject: [PATCH 03/17] Add explaination of code for spanning months --- xcdat/temporal.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 890a3f7b..c3b83344 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1509,11 +1509,17 @@ def _shift_spanning_months(self, df: pd.DataFrame) -> pd.DataFrame: custom_seasons = self._season_config["custom_seasons"] span_months: List[int] = [] + + # Loop over the custom seasons and get the list of months for the + # current season. Convert those months to their integer representations. + # If 1 ("Jan") is in the list of months and it is NOT the first element, + # then get all elements before it (aka the spanning months). for months in custom_seasons.values(): # type: ignore month_nums = [MONTH_STR_TO_INT[month] for month in months] try: jan_index = month_nums.index(1) - span_months = span_months + month_nums[:jan_index] + if jan_index != 0: + span_months = span_months + month_nums[:jan_index] break except ValueError: continue From 6b4ec6528c87fd8fecad61ad30f4705059c8c8a9 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Fri, 3 Mar 2023 15:51:42 -0800 Subject: [PATCH 04/17] Replace `drop_incomplete_djf` with `drop_incomplete_seasons` --- tests/test_temporal.py | 57 +++++------ xcdat/temporal.py | 225 ++++++++++++++++++++++++----------------- 2 files changed, 162 insertions(+), 120 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 7c85a346..da37f7b7 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -577,7 +577,7 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() # Drop the incomplete DJF seasons @@ -615,7 +615,7 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -629,7 +629,7 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": False}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") @@ -666,7 +666,7 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "False", + "drop_incomplete_seasons": "False", }, ) @@ -818,7 +818,7 @@ def test_weighted_custom_seasonal_averages(self): xr.testing.assert_identical(result, expected) - def test_weighted_seasonal_averages_drops_incomplete_seasons(self): + def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): ds = self.ds.copy() ds["time"].values[:] = np.array( [ @@ -831,28 +831,26 @@ def test_weighted_seasonal_averages_drops_incomplete_seasons(self): dtype="datetime64[ns]", ) - custom_seasons = [ - ["Nov", "Dec", "Jan", "Feb", "Mar"], - ] + custom_seasons = [["Nov", "Dec"], ["Feb", "Mar", "Apr"]] result = ds.temporal.group_average( "ts", "season", season_config={ + "drop_incomplete_seasons": True, "custom_seasons": custom_seasons, - # "drop_incomplete_seasons": True, }, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[1.3933333]]]), + data=np.array([[[1.5]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( - data=np.array([cftime.datetime(2001, 1, 1)], dtype=object), + data=np.array([cftime.datetime(2000, 12, 1)], dtype=object), dims=["time"], attrs={ "axis": "T", @@ -868,13 +866,12 @@ def test_weighted_seasonal_averages_drops_incomplete_seasons(self): "operation": "temporal_avg", "mode": "group_average", "freq": "season", - "custom_seasons": ["NovDecJanFebMar"], + "custom_seasons": ["NovDec", "FebMarApr"], "weighted": "True", }, ) - xr.testing.assert_allclose(result, expected) - assert result.ts.attrs == expected.ts.attrs + assert result.identical(expected) def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( self, @@ -1214,7 +1211,7 @@ def test_weighted_seasonal_climatology_with_DJF(self): result = ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() @@ -1256,7 +1253,7 @@ def test_weighted_seasonal_climatology_with_DJF(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -1269,7 +1266,7 @@ def test_chunked_weighted_seasonal_climatology_with_DJF(self): result = ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() @@ -1311,7 +1308,7 @@ def test_chunked_weighted_seasonal_climatology_with_DJF(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -1432,7 +1429,7 @@ def test_weighted_custom_seasonal_climatology(self): xr.testing.assert_identical(result, expected) @pytest.mark.xfail - def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons( + def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years( self, ): ds = self.ds.copy() @@ -2076,7 +2073,7 @@ def test_weighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() @@ -2113,7 +2110,7 @@ def test_weighted_seasonal_departures_with_DJF(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -2127,7 +2124,7 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "season", weighted=True, keep_weights=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() @@ -2164,7 +2161,7 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) expected["time_wts"] = xr.DataArray( @@ -2202,7 +2199,7 @@ def test_unweighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=False, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() @@ -2239,7 +2236,7 @@ def test_unweighted_seasonal_departures_with_DJF(self): "freq": "season", "weighted": "False", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -3462,7 +3459,7 @@ def test_raises_error_with_incorrect_mode_arg(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3478,7 +3475,7 @@ def test_raises_error_if_freq_arg_is_not_supported_by_operation(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3490,7 +3487,7 @@ def test_raises_error_if_freq_arg_is_not_supported_by_operation(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3502,7 +3499,7 @@ def test_raises_error_if_freq_arg_is_not_supported_by_operation(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3528,7 +3525,7 @@ def test_raises_error_if_december_mode_is_not_supported(self): weighted=True, season_config={ "dec_mode": "unsupported", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index c3b83344..a200bc49 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -67,7 +67,7 @@ "SeasonConfigInput", { "dec_mode": Literal["DJF", "JFD"], - "drop_incomplete_djf": bool, + "drop_incomplete_seasons": bool, "custom_seasons": Optional[List[List[str]]], }, total=False, @@ -77,7 +77,7 @@ "SeasonConfigAttr", { "dec_mode": Literal["DJF", "JFD"], - "drop_incomplete_djf": bool, + "drop_incomplete_seasons": bool, "custom_seasons": Optional[Dict[str, List[str]]], }, total=False, @@ -85,7 +85,7 @@ DEFAULT_SEASON_CONFIG: SeasonConfigInput = { "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, } @@ -296,11 +296,18 @@ def group_average( Xarray labels the season with December as "DJF", but it is actually "JFD". - * "drop_incomplete_djf" (bool, by default False) - If the "dec_mode" is "DJF", this flag drops (True) or keeps - (False) time coordinates that fall under incomplete DJF seasons - Incomplete DJF seasons include the start year Jan/Feb and the - end year Dec. + * "drop_incomplete_seasons" (bool, by default False) + Seasons are considered incomplete if they do not have all of + the required months to form the season. For example, if we have + the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", + "2001-02-16"] and we want to group seasons by "ND" ("Nov", + "Dec") and "JFM" ("Jan", "Feb", "Mar"). + + * ["2000-11-16", "2000-12-16"] is considered a complete "ND" + season since both "Nov" and "Dec" are present. + * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. Configs for custom seasons: @@ -350,7 +357,7 @@ def group_average( >>> "season", >>> season_config={ >>> "dec_mode": "DJF", - >>> "drop_incomplete_season": True + >>> "drop_incomplete_seasons": True >>> } >>> ) >>> ds_season.ts @@ -386,7 +393,7 @@ def group_average( 'freq': 'season', 'weighted': 'True', 'dec_mode': 'DJF', - 'drop_incomplete_djf': 'False' + 'drop_incomplete_seasons': 'False' } """ self._set_data_var_attrs(data_var) @@ -461,6 +468,21 @@ def climatology( predefined seasons are passed, configs for custom seasons are ignored and vice versa. + General configs: + + * "drop_incomplete_seasons" (bool, by default False) + Seasons are considered incomplete if they do not have all of + the required months to form the season. For example, if we have + the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", + "2001-02-16"] and we want to group seasons by "ND" ("Nov", + "Dec") and "JFM" ("Jan", "Feb", "Mar"). + + * ["2000-11-16", "2000-12-16"] is considered a complete "ND" + season since both "Nov" and "Dec" are present. + * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. + Configs for predefined seasons: * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") @@ -471,12 +493,6 @@ def climatology( Xarray labels the season with December as "DJF", but it is actually "JFD". - * "drop_incomplete_djf" (bool, by default False) - If the "dec_mode" is "DJF", this flag drops (True) or keeps - (False) time coordinates that fall under incomplete DJF seasons - Incomplete DJF seasons include the start year Jan/Feb and the - end year Dec. - Configs for custom seasons: * "custom_seasons" ([List[List[str]]], by default None) @@ -529,7 +545,7 @@ def climatology( >>> "season", >>> season_config={ >>> "dec_mode": "DJF", - >>> "drop_incomplete_season": True + >>> "drop_incomplete_seasons": True >>> } >>> ) >>> ds_season.ts @@ -565,7 +581,7 @@ def climatology( 'freq': 'season', 'weighted': 'True', 'dec_mode': 'DJF', - 'drop_incomplete_djf': 'False' + 'drop_incomplete_seasons': 'False' } """ self._set_data_var_attrs(data_var) @@ -648,6 +664,21 @@ def departures( predefined seasons are passed, configs for custom seasons are ignored and vice versa. + General configs: + + * "drop_incomplete_seasons" (bool, by default False) + Seasons are considered incomplete if they do not have all of + the required months to form the season. For example, if we have + the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", + "2001-02-16"] and we want to group seasons by "ND" ("Nov", + "Dec") and "JFM" ("Jan", "Feb", "Mar"). + + * ["2000-11-16", "2000-12-16"] is considered a complete "ND" + season since both "Nov" and "Dec" are present. + * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. + Configs for predefined seasons: * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") @@ -658,12 +689,6 @@ def departures( Xarray labels the season with December as "DJF", but it is actually "JFD". - * "drop_incomplete_djf" (bool, by default False) - If the "dec_mode" is "DJF", this flag drops (True) or keeps - (False) time coordinates that fall under incomplete DJF seasons - Incomplete DJF seasons include the start year Jan/Feb and the - end year Dec. - Configs for custom seasons: * "custom_seasons" ([List[List[str]]], by default None) @@ -731,7 +756,7 @@ def departures( 'frequency': 'season', 'weighted': 'True', 'dec_mode': 'DJF', - 'drop_incomplete_djf': 'False' + 'drop_incomplete_seasons': 'False' } """ # 1. Set the attributes for this instance of `TemporalAccessor`. @@ -941,9 +966,12 @@ def _set_arg_attrs( ) custom_seasons = season_config.get("custom_seasons", None) dec_mode = season_config.get("dec_mode", "DJF") - drop_incomplete_djf = season_config.get("drop_incomplete_djf", False) self._season_config: SeasonConfigAttr = {} + self._season_config["drop_incomplete_seasons"] = season_config.get( + "drop_incomplete_seasons", False + ) + if custom_seasons is None: if dec_mode not in ("DJF", "JFD"): raise ValueError( @@ -952,8 +980,6 @@ def _set_arg_attrs( ) self._season_config["dec_mode"] = dec_mode - if dec_mode == "DJF": - self._season_config["drop_incomplete_djf"] = drop_incomplete_djf else: self._season_config["custom_seasons"] = self._form_seasons(custom_seasons) @@ -1032,10 +1058,9 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """ if ( self._freq == "season" - and self._season_config.get("dec_mode") == "DJF" - and self._season_config.get("drop_incomplete_djf") is True + and self._season_config.get("drop_incomplete_seasons") is True ): - ds = self._drop_incomplete_djf(ds) + ds = self._drop_incomplete_seasons(ds) if ( self._freq == "day" @@ -1051,53 +1076,63 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: return ds - def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: - """Drops incomplete DJF seasons within a continuous time series. + def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: + """Drops incomplete seasons within a continuous time series. - This method assumes that the time series is continuous and removes the - leading and trailing incomplete seasons (e.g., the first January and - February of a time series that are not complete, because the December of - the previous year is missing). This method does not account for or - remove missing time steps anywhere else. + Seasons are considered incomplete if they do not have all of the + required months to form the season. For example, if we have the time + coordinates ["2000-11-16", "2000-12-16", "2001-01-16", "2001-02-16"] + and we want to group seasons by "ND" ("Nov", "Dec") and "JFM" ("Jan", + "Feb", "Mar"). + - ["2000-11-16", "2000-12-16"] is considered a complete "ND" season + since both "Nov" and "Dec" are present. + - ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. Parameters ---------- - dataset : xr.Dataset - The dataset with some possibly incomplete DJF seasons. + df : pd.DataFrame + A DataFrame of seasonal datetime components with potentially + incomplete seasons. Returns ------- - xr.Dataset - The dataset with only complete DJF seasons. + pd.DataFrame + A DataFrame of seasonal datetime components with only complete + seasons. """ - # Separate the dataset into two datasets, one with and one without - # the time dimension. This is necessary because the xarray .where() - # method concatenates the time dimension to non-time dimension data - # vars, which is not a desired behavior. - ds = dataset.copy() - ds_time = ds.get([v for v in ds.data_vars if self.dim in ds[v].dims]) # type: ignore - ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore - - start_year, end_year = ( - ds[self.dim].dt.year.values[0], - ds[self.dim].dt.year.values[-1], - ) - incomplete_seasons = ( - f"{int(start_year):04d}-01", - f"{int(start_year):04d}-02", - f"{int(end_year):04d}-12", - ) - - for year_month in incomplete_seasons: - try: - coord_pt = ds.loc[dict(time=year_month)][self.dim][0] - ds_time = ds_time.where(ds_time[self.dim] != coord_pt, drop=True) - except (KeyError, IndexError): - continue - - ds_final = xr.merge((ds_time, ds_no_time)) - - return ds_final + # Algorithm + # Prereq - This needs to be done AFTER time coordinates are labeled + # and BEFORE obsoelete columns are dropped because custom seasons can be + # assigned to the time coordiantes first. + # 1. Get the count of months per season (pre-defined seasons by xarray + # all have 3), otherwise use custom seasons count + # 2. Label all time coordinates by groups + # 3. Group the time coordinates by group and the get count + # 4. Drop time coordinates where count != expected count for season + ds_new = ds.copy() + time_coords = ds[self.dim].copy() + + # Transform the time coords into a DataFrame of seasonal datetime + # components based on the grouping mode. + df = self._get_df_dt_components(time_coords, drop_obsolete_cols=False) + + # Add a column for the expected count of months for that season + # For example, "NovDec" is split into ["Nov", "Dec"] which equals an + # expected count of 2 months. + df["expected_months"] = df["season"].str.split(r"(?<=.)(?=[A-Z])").str.len() + # Add a column for the actual count of months for that season. + df["actual_months"] = df.groupby(["season"])["year"].transform("count") + + # Get the incomplete seasons and drop the time coordinates that are in + # those incomplete seasons. + indexes_to_drop = df[df["expected_months"] != df["actual_months"]].index + if len(indexes_to_drop) > 0: + coords_to_drop = time_coords.values[indexes_to_drop] + ds_new = ds_new.where(~time_coords.isin(coords_to_drop), drop=True) + + return ds_new def _drop_leap_days(self, ds: xr.Dataset): """Drop leap days from time coordinates. @@ -1291,9 +1326,9 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray: This methods labels time coordinates for grouping by first extracting specific xarray datetime components from time coordinates and storing them in a pandas DataFrame. After processing (if necessary) is performed - on the DataFrame, it is converted to a numpy array of datetime - objects. This numpy serves as the data source for the final - DataArray of labeled time coordinates. + on the DataFrame, it is converted to a numpy array of datetime objects. + This numpy array serves as the data source for the final DataArray of + labeled time coordinates. Parameters ---------- @@ -1329,7 +1364,9 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray: >>> Coordinates: >>> * time (time) datetime64[ns] 2000-01-01T00:00:00 ... 2000-04-01T00:00:00 """ - df_dt_components: pd.DataFrame = self._get_df_dt_components(time_coords) + df_dt_components: pd.DataFrame = self._get_df_dt_components( + time_coords, drop_obsolete_cols=True + ) dt_objects = self._convert_df_to_dt(df_dt_components) time_grouped = xr.DataArray( @@ -1343,7 +1380,9 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray: return time_grouped - def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: + def _get_df_dt_components( + self, time_coords: xr.DataArray, drop_obsolete_cols: bool + ) -> pd.DataFrame: """Returns a DataFrame of xarray datetime components. This method extracts the applicable xarray datetime components from each @@ -1364,6 +1403,12 @@ def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: ---------- time_coords : xr.DataArray The time coordinates. + drop_obsolete_cols : bool + Drop obsolete columns after processing seasonal DataFrame when + ``self._freq="season"``. Set to False to keep datetime columns + needed for preprocessing the dataset (e.g,. removing incomplete + seasons), and set to True to remove obsolete columns when needing + to group time coordinates. Returns ------- @@ -1394,12 +1439,15 @@ def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: if self._mode in ["climatology", "departures"]: df["year"] = time_coords[f"{self.dim}.year"].values df["month"] = time_coords[f"{self.dim}.month"].values - - if self._mode == "group_average": + elif self._mode == "group_average": df["month"] = time_coords[f"{self.dim}.month"].values df = self._process_season_df(df) + if drop_obsolete_cols: + df = self._drop_obsolete_columns(df) + df = self._map_seasons_to_mid_months(df) + return df def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: @@ -1408,13 +1456,13 @@ def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: Parameters ---------- - df : pd.DataFrame - A DataFrame of xarray datetime components. + df : xr.DataArray + A DataFrame of seasonal datetime components. Returns ------- pd.DataFrame - A DataFrame of processed xarray datetime components. + A DataFrame of seasonal datetime components. """ df_new = df.copy() custom_seasons = self._season_config.get("custom_seasons") @@ -1427,8 +1475,6 @@ def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: if dec_mode == "DJF": df_new = self._shift_decembers(df_new) - df_new = self._drop_obsolete_columns(df_new) - df_new = self._map_seasons_to_mid_months(df_new) return df_new def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: @@ -1743,17 +1789,16 @@ def _add_operation_attrs(self, data_var: xr.DataArray) -> xr.DataArray: ) if self._freq == "season": - custom_seasons = self._season_config.get("custom_seasons") + data_var.attrs["drop_incomplete_seasons"] = self._season_config.get( + "drop_incomplete_seasons" + ) - if custom_seasons is None: + custom_seasons = self._season_config.get("custom_seasons") + if custom_seasons is not None: + data_var.attrs["custom_seasons"] = list(custom_seasons.keys()) + else: dec_mode = self._season_config.get("dec_mode") - drop_incomplete_djf = self._season_config.get("drop_incomplete_djf") - data_var.attrs["dec_mode"] = dec_mode - if dec_mode == "DJF": - data_var.attrs["drop_incomplete_djf"] = str(drop_incomplete_djf) - else: - data_var.attrs["custom_seasons"] = list(custom_seasons.keys()) return data_var From fbad335480d49d749920433a1d1989cd2c7786a3 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Tue, 7 Mar 2023 11:55:03 -0800 Subject: [PATCH 05/17] Fix `_drop_incomplete_seasons()` adding dims to variables - Add conditional that determines whether subsetting time coordinates is necessary with custom seasons - Update docstrings for `season_config` - Add tests --- tests/test_temporal.py | 79 +++++++++++--------------- xcdat/temporal.py | 124 ++++++++++++++++++++++++++--------------- 2 files changed, 111 insertions(+), 92 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index da37f7b7..e9ead770 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -571,27 +571,28 @@ def test_weighted_annual_averages_with_chunking(self): assert result.ts.attrs == expected.ts.attrs assert result.time.attrs == expected.time.attrs - def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): + def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( + self, + ): ds = self.ds.copy() result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() - # Drop the incomplete DJF seasons - expected = expected.isel(time=slice(2, -1)) expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[1]], [[1]], [[1]], [[2.0]]]), + data=np.array([[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -614,35 +615,33 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "DJF", - "drop_incomplete_seasons": "True", }, ) - xr.testing.assert_identical(result, expected) + assert result.identical(expected) - def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( - self, - ): - ds = self.ds.copy() + def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): + ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=True) result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) + expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]]), + data=np.ones((4, 4, 4)), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ - cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -660,13 +659,12 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons }, dims=["time", "lat", "lon"], attrs={ - "test_attr": "test", "operation": "temporal_avg", "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "True", "dec_mode": "DJF", - "drop_incomplete_seasons": "False", }, ) @@ -725,6 +723,7 @@ def test_weighted_seasonal_averages_with_JFD(self): "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) @@ -806,13 +805,14 @@ def test_weighted_custom_seasonal_averages(self): "operation": "temporal_avg", "mode": "group_average", "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "False", "custom_seasons": [ "JanFebMar", "AprMayJun", "JulAugSep", "OctNovDec", ], - "weighted": "True", }, ) @@ -866,8 +866,9 @@ def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): "operation": "temporal_avg", "mode": "group_average", "freq": "season", - "custom_seasons": ["NovDec", "FebMarApr"], "weighted": "True", + "drop_incomplete_seasons": "True", + "custom_seasons": ["NovDec", "FebMarApr"], }, ) @@ -922,8 +923,9 @@ def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( "operation": "temporal_avg", "mode": "group_average", "freq": "season", - "custom_seasons": ["NovDecJanFebMar"], "weighted": "True", + "drop_incomplete_seasons": "False", + "custom_seasons": ["NovDecJanFebMar"], }, ) @@ -1252,8 +1254,8 @@ def test_weighted_seasonal_climatology_with_DJF(self): "mode": "climatology", "freq": "season", "weighted": "True", - "dec_mode": "DJF", "drop_incomplete_seasons": "True", + "dec_mode": "DJF", }, ) @@ -1359,6 +1361,7 @@ def test_weighted_seasonal_climatology_with_JFD(self): "mode": "climatology", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) @@ -1417,6 +1420,7 @@ def test_weighted_custom_seasonal_climatology(self): "mode": "climatology", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "custom_seasons": [ "JanFebMar", "AprMayJun", @@ -1428,24 +1432,18 @@ def test_weighted_custom_seasonal_climatology(self): xr.testing.assert_identical(result, expected) - @pytest.mark.xfail def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years( self, ): ds = self.ds.copy() - custom_seasons = [ - ["Jan", "Feb", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] + custom_seasons = [["Nov", "Dec", "Jan", "Feb", "Mar"]] result = ds.temporal.climatology( "ts", "season", season_config={ + "drop_incomplete_seasons": False, "custom_seasons": custom_seasons, - "drop_incomplete_seasons": True, }, ) @@ -1453,21 +1451,11 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea expected = expected.drop_dims("time") expected_time = xr.DataArray( data=np.array( - [ - cftime.DatetimeGregorian(1, 2, 1), - cftime.DatetimeGregorian(1, 5, 1), - cftime.DatetimeGregorian(1, 8, 1), - cftime.DatetimeGregorian(1, 11, 1), - ], + [cftime.DatetimeGregorian(1, 1, 1)], ), coords={ "time": np.array( - [ - cftime.DatetimeGregorian(1, 2, 1), - cftime.DatetimeGregorian(1, 5, 1), - cftime.DatetimeGregorian(1, 8, 1), - cftime.DatetimeGregorian(1, 11, 1), - ], + [cftime.DatetimeGregorian(1, 1, 1)], ), }, attrs={ @@ -1488,12 +1476,8 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea "mode": "climatology", "freq": "season", "weighted": "True", - "custom_seasons": [ - "JanFebMar", - "AprMayJun", - "JulAugSep", - "OctNovDec", - ], + "drop_incomplete_seasons": "False", + "custom_seasons": ["NovDecJanFebMar"], }, ) @@ -2235,8 +2219,8 @@ def test_unweighted_seasonal_departures_with_DJF(self): "mode": "departures", "freq": "season", "weighted": "False", - "dec_mode": "DJF", "drop_incomplete_seasons": "True", + "dec_mode": "DJF", }, ) @@ -2286,6 +2270,7 @@ def test_unweighted_seasonal_departures_with_JFD(self): "mode": "departures", "freq": "season", "weighted": "False", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index a200bc49..5bf87294 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -66,8 +66,8 @@ SeasonConfigInput = TypedDict( "SeasonConfigInput", { - "dec_mode": Literal["DJF", "JFD"], "drop_incomplete_seasons": bool, + "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[List[List[str]]], }, total=False, @@ -76,16 +76,16 @@ SeasonConfigAttr = TypedDict( "SeasonConfigAttr", { - "dec_mode": Literal["DJF", "JFD"], "drop_incomplete_seasons": bool, + "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[Dict[str, List[str]]], }, total=False, ) DEFAULT_SEASON_CONFIG: SeasonConfigInput = { - "dec_mode": "DJF", "drop_incomplete_seasons": False, + "dec_mode": "DJF", "custom_seasons": None, } @@ -286,16 +286,6 @@ def group_average( predefined seasons are passed, configs for custom seasons are ignored and vice versa. - Configs for predefined seasons: - - * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") - The mode for the season that includes December. - - * "DJF": season includes the previous year December. - * "JFD": season includes the same year December. - Xarray labels the season with December as "DJF", but it is - actually "JFD". - * "drop_incomplete_seasons" (bool, by default False) Seasons are considered incomplete if they do not have all of the required months to form the season. For example, if we have @@ -309,11 +299,19 @@ def group_average( season because it only has "Jan" and "Feb". Therefore, these time coordinates are dropped. - Configs for custom seasons: + * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") + The mode for the season that includes December in the list of + list of pre-defined seasons ("DJF"/"JFD", "MAM", "JJA", "SON"). + This config is ignored if the ``custom_seasons`` config is set. + + * "DJF": season includes the previous year December. + * "JFD": season includes the same year December. + Xarray labels the season with December as "DJF", but it is + actually "JFD". * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. + representing a custom season. This config overrides the `decod * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -468,8 +466,6 @@ def climatology( predefined seasons are passed, configs for custom seasons are ignored and vice versa. - General configs: - * "drop_incomplete_seasons" (bool, by default False) Seasons are considered incomplete if they do not have all of the required months to form the season. For example, if we have @@ -483,21 +479,19 @@ def climatology( season because it only has "Jan" and "Feb". Therefore, these time coordinates are dropped. - Configs for predefined seasons: - * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") - The mode for the season that includes December. + The mode for the season that includes December in the list of + list of pre-defined seasons ("DJF"/"JFD", "MAM", "JJA", "SON"). + This config is ignored if the ``custom_seasons`` config is set. * "DJF": season includes the previous year December. * "JFD": season includes the same year December. Xarray labels the season with December as "DJF", but it is actually "JFD". - Configs for custom seasons: - * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. + representing a custom season. This config overrides the `decod * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -1058,7 +1052,22 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """ if ( self._freq == "season" - and self._season_config.get("drop_incomplete_seasons") is True + and self._season_config.get("custom_seasons") is not None + ): + # Get a flat list of all of the months included in the custom + # seasons to determine if the dataset needs to be subsetted + # on just those months. For example, if we define a custom season + # "NDJFM", we should subset the dataset for time coordinates + # belonging to those months. + months = self._season_config["custom_seasons"].values() # type: ignore + months = list(chain.from_iterable(months.values())) # type: ignore + + if len(months) != 12: + ds = self._subset_coords_for_custom_seasons(ds, months) + + if ( + self._freq == "season" + and self._season_config["drop_incomplete_seasons"] is True ): ds = self._drop_incomplete_seasons(ds) @@ -1076,6 +1085,34 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: return ds + def _subset_coords_for_custom_seasons( + self, ds: xr.Dataset, months: List[str] + ) -> xr.Dataset: + """Subsets time coordinates to the months included in custom seasons. + + Parameters + ---------- + ds : xr.Dataset + The dataset. + months : List[str] + A list of months included in custom seasons. + Example: ["Nov", "Dec", "Jan"] + + Returns + ------- + xr.Dataset + The dataset with time coordinate subsetted to months used in + custom seasons. + """ + months_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) + + coords_by_month = ds.time.groupby(f"{self.dim}.month").groups + months_idxs = {k: coords_by_month[k] for k in months_ints} + months_idxs = sorted(list(chain.from_iterable(months_idxs.values()))) # type: ignore + ds_new = ds.isel({f"{self.dim}": months_idxs}) + + return ds_new + def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: """Drops incomplete seasons within a continuous time series. @@ -1102,37 +1139,34 @@ def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: A DataFrame of seasonal datetime components with only complete seasons. """ - # Algorithm - # Prereq - This needs to be done AFTER time coordinates are labeled - # and BEFORE obsoelete columns are dropped because custom seasons can be - # assigned to the time coordiantes first. - # 1. Get the count of months per season (pre-defined seasons by xarray - # all have 3), otherwise use custom seasons count - # 2. Label all time coordinates by groups - # 3. Group the time coordinates by group and the get count - # 4. Drop time coordinates where count != expected count for season - ds_new = ds.copy() - time_coords = ds[self.dim].copy() - # Transform the time coords into a DataFrame of seasonal datetime # components based on the grouping mode. + time_coords = ds[self.dim].copy() df = self._get_df_dt_components(time_coords, drop_obsolete_cols=False) - # Add a column for the expected count of months for that season - # For example, "NovDec" is split into ["Nov", "Dec"] which equals an - # expected count of 2 months. + # Get the expected and actual number of months for each season group. df["expected_months"] = df["season"].str.split(r"(?<=.)(?=[A-Z])").str.len() - # Add a column for the actual count of months for that season. - df["actual_months"] = df.groupby(["season"])["year"].transform("count") + df["actual_months"] = df.groupby(["year", "season"])["year"].transform("count") # Get the incomplete seasons and drop the time coordinates that are in # those incomplete seasons. indexes_to_drop = df[df["expected_months"] != df["actual_months"]].index if len(indexes_to_drop) > 0: + # The dataset needs to be split into a dataset with and a dataset + # without the time dimension because the xarray `.where()` method + # concatenates the time dimension to non-time dimension data vars, + # which is an undesired behavior. + ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore + ds_time = ds.get([v for v in ds.data_vars if self.dim in ds[v].dims]) # type: ignore + coords_to_drop = time_coords.values[indexes_to_drop] - ds_new = ds_new.where(~time_coords.isin(coords_to_drop), drop=True) + ds_time = ds_time.where(~time_coords.isin(coords_to_drop), drop=True) - return ds_new + ds_new = xr.merge([ds_time, ds_no_time]) + + return ds_new + + return ds def _drop_leap_days(self, ds: xr.Dataset): """Drop leap days from time coordinates. @@ -1789,8 +1823,8 @@ def _add_operation_attrs(self, data_var: xr.DataArray) -> xr.DataArray: ) if self._freq == "season": - data_var.attrs["drop_incomplete_seasons"] = self._season_config.get( - "drop_incomplete_seasons" + data_var.attrs["drop_incomplete_seasons"] = str( + self._season_config["drop_incomplete_seasons"] ) custom_seasons = self._season_config.get("custom_seasons") From 4a896badd3bf66d6926cd770b52a664ccb65aaa1 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Thu, 4 Apr 2024 13:24:55 -0700 Subject: [PATCH 06/17] Add deprecation warning for "drop_incomplete_djf" --- xcdat/temporal.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 5bf87294..6275e4b1 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1,5 +1,6 @@ """Module containing temporal functions.""" +import warnings from datetime import datetime from itertools import chain from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union, get_args @@ -66,6 +67,7 @@ SeasonConfigInput = TypedDict( "SeasonConfigInput", { + "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[List[List[str]]], @@ -76,6 +78,7 @@ SeasonConfigAttr = TypedDict( "SeasonConfigAttr", { + "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[Dict[str, List[str]]], @@ -249,6 +252,11 @@ def group_average( Time bounds are used for generating weights to calculate weighted group averages (refer to the ``weighted`` parameter documentation below). + .. deprecated:: v0.7.0 + The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` + is being deprecated. Please use ``"drop_incomplete_seasons"`` + instead. + Parameters ---------- data_var: str @@ -420,6 +428,11 @@ def climatology( Time bounds are used for generating weights to calculate weighted climatology (refer to the ``weighted`` parameter documentation below). + .. deprecated:: v0.7.0 + The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` + is being deprecated. Please use ``"drop_incomplete_seasons"`` + instead. + Parameters ---------- data_var: str @@ -611,6 +624,11 @@ def departures( Time bounds are used for generating weights to calculate weighted climatology (refer to the ``weighted`` parameter documentation below). + .. deprecated:: v0.7.0 + The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` + is being deprecated. Please use ``"drop_incomplete_seasons"`` + instead. + Parameters ---------- data_var: str @@ -962,9 +980,20 @@ def _set_arg_attrs( dec_mode = season_config.get("dec_mode", "DJF") self._season_config: SeasonConfigAttr = {} - self._season_config["drop_incomplete_seasons"] = season_config.get( - "drop_incomplete_seasons", False - ) + + # TODO: Deprecate `drop_incomplete_djf`. + drop_incomplete_djf = season_config.get("drop_incomplete_djf", None) + if drop_incomplete_djf is not None: + warnings.warn( + "The `season_config` argument 'drop_incomplete_djf' is being " + "deprecated. Please use 'drop_incomplete_seasons' instead.", + DeprecationWarning, + ) + self._season_config["drop_incomplete_seasons"] = drop_incomplete_djf + else: + self._season_config["drop_incomplete_seasons"] = season_config.get( + "drop_incomplete_seasons", False + ) if custom_seasons is None: if dec_mode not in ("DJF", "JFD"): From 9f496e5584bb4e1d635fcd28ac2f96a2fe29a85d Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Thu, 4 Apr 2024 14:44:25 -0700 Subject: [PATCH 07/17] Fix tests --- tests/test_temporal.py | 120 ++++++++++++++++++++++++++++++++--------- xcdat/temporal.py | 24 +++++---- 2 files changed, 108 insertions(+), 36 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index e9ead770..9ba0763e 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -1,4 +1,5 @@ import logging +import warnings import cftime import numpy as np @@ -620,7 +621,7 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=True) @@ -872,7 +873,7 @@ def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( self, @@ -1151,7 +1152,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000", "01-01-2000"), ) @@ -1159,7 +1160,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000"), ) @@ -1169,7 +1170,7 @@ def test_subsets_climatology_based_on_reference_period(self): result = ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -1201,7 +1202,7 @@ def test_subsets_climatology_based_on_reference_period(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -1261,6 +1262,70 @@ def test_weighted_seasonal_climatology_with_DJF(self): xr.testing.assert_identical(result, expected) + def test_raises_deprecation_warning_with_drop_incomplete_djf_season_config(self): + # NOTE: This will test will also cover the other public APIs that + # have drop_incomplete_djf as a season_config arg. + ds = self.ds.copy() + + with warnings.catch_warnings(record=True) as w: + result = ds.temporal.climatology( + "ts", + "season", + season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + ) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert str(w[0].message) == ( + "The `season_config` argument 'drop_incomplete_djf' is being deprecated. " + "Please use 'drop_incomplete_seasons' instead." + ) + + expected = ds.copy() + expected = expected.drop_dims("time") + expected_time = xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(1, 1, 1), + cftime.DatetimeGregorian(1, 4, 1), + cftime.DatetimeGregorian(1, 7, 1), + cftime.DatetimeGregorian(1, 10, 1), + ], + ), + coords={ + "time": np.array( + [ + cftime.DatetimeGregorian(1, 1, 1), + cftime.DatetimeGregorian(1, 4, 1), + cftime.DatetimeGregorian(1, 7, 1), + cftime.DatetimeGregorian(1, 10, 1), + ], + ), + }, + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ) + expected["ts"] = xr.DataArray( + name="ts", + data=np.ones((4, 4, 4)), + coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time}, + dims=["time", "lat", "lon"], + attrs={ + "operation": "temporal_avg", + "mode": "climatology", + "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "True", + "dec_mode": "DJF", + }, + ) + + xr.testing.assert_identical(result, expected) + @requires_dask def test_chunked_weighted_seasonal_climatology_with_DJF(self): ds = self.ds.copy().chunk({"time": 2}) @@ -1468,7 +1533,7 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea expected["ts"] = xr.DataArray( name="ts", - data=np.ones((4, 4, 4)), + data=np.ones((1, 4, 4)), coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time}, dims=["time", "lat", "lon"], attrs={ @@ -1481,7 +1546,7 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month") @@ -1902,7 +1967,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.departures( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000", "01-01-2000"), ) @@ -1910,7 +1975,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.departures( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000"), ) @@ -1921,7 +1986,7 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): "ts", "season", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -1929,13 +1994,14 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[np.nan]], [[np.nan]], [[np.nan]]]), + data=np.array([[[0.0]], [[0.0]], [[np.nan]], [[np.nan]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -1959,7 +2025,7 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "False", }, ) @@ -1974,7 +2040,7 @@ def test_monthly_departures_relative_to_climatology_reference_period_with_same_o "ts", "month", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -2057,20 +2123,21 @@ def test_weighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -2094,7 +2161,7 @@ def test_weighted_seasonal_departures_with_DJF(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_seasons": "True", + "drop_incomplete_seasons": "False", }, ) @@ -2108,20 +2175,21 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "season", weighted=True, keep_weights=True, - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -2145,16 +2213,17 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_seasons": "True", + "drop_incomplete_seasons": "False", }, ) expected["time_wts"] = xr.DataArray( name="ts", - data=np.array([1.0, 1.0, 1.0, 1.0]), + data=np.array([0.52542373, 1.0, 1.0, 1.0, 0.47457627]), coords={ "time_original": xr.DataArray( data=np.array( [ + "2000-01-16T12:00:00.000000000", "2000-03-16T12:00:00.000000000", "2000-06-16T00:00:00.000000000", "2000-09-16T00:00:00.000000000", @@ -2174,7 +2243,7 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): dims=["time_original"], ) - xr.testing.assert_identical(result, expected) + xr.testing.assert_allclose(result, expected) def test_unweighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -2183,20 +2252,21 @@ def test_unweighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=False, - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -2219,7 +2289,7 @@ def test_unweighted_seasonal_departures_with_DJF(self): "mode": "departures", "freq": "season", "weighted": "False", - "drop_incomplete_seasons": "True", + "drop_incomplete_seasons": "False", "dec_mode": "DJF", }, ) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 6275e4b1..f7efcaaa 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -67,7 +67,6 @@ SeasonConfigInput = TypedDict( "SeasonConfigInput", { - "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[List[List[str]]], @@ -78,7 +77,6 @@ SeasonConfigAttr = TypedDict( "SeasonConfigAttr", { - "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[Dict[str, List[str]]], @@ -319,7 +317,7 @@ def group_average( * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. This config overrides the `decod + representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -504,7 +502,7 @@ def climatology( * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. This config overrides the `decod + representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -971,7 +969,9 @@ def _set_arg_attrs( # "season" frequency specific configuration attributes. for key in season_config.keys(): - if key not in DEFAULT_SEASON_CONFIG.keys(): + # TODO: Deprecate `drop_incomplete_djf`. + valid_keys = list(DEFAULT_SEASON_CONFIG.keys()) + ["drop_incomplete_djf"] + if key not in valid_keys: raise KeyError( f"'{key}' is not a supported season config. Supported " f"configs include: {DEFAULT_SEASON_CONFIG.keys()}." @@ -989,7 +989,7 @@ def _set_arg_attrs( "deprecated. Please use 'drop_incomplete_seasons' instead.", DeprecationWarning, ) - self._season_config["drop_incomplete_seasons"] = drop_incomplete_djf + self._season_config["drop_incomplete_seasons"] = drop_incomplete_djf # type: ignore else: self._season_config["drop_incomplete_seasons"] = season_config.get( "drop_incomplete_seasons", False @@ -1089,7 +1089,7 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: # "NDJFM", we should subset the dataset for time coordinates # belonging to those months. months = self._season_config["custom_seasons"].values() # type: ignore - months = list(chain.from_iterable(months.values())) # type: ignore + months = list(chain.from_iterable(months)) if len(months) != 12: ds = self._subset_coords_for_custom_seasons(ds, months) @@ -1133,12 +1133,14 @@ def _subset_coords_for_custom_seasons( The dataset with time coordinate subsetted to months used in custom seasons. """ - months_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) + month_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) coords_by_month = ds.time.groupby(f"{self.dim}.month").groups - months_idxs = {k: coords_by_month[k] for k in months_ints} - months_idxs = sorted(list(chain.from_iterable(months_idxs.values()))) # type: ignore - ds_new = ds.isel({f"{self.dim}": months_idxs}) + month_to_time_idx = { + k: coords_by_month[k] for k in month_ints if k in coords_by_month + } + month_to_time_idx = sorted(list(chain.from_iterable(month_to_time_idx.values()))) # type: ignore + ds_new = ds.isel({f"{self.dim}": month_to_time_idx}) return ds_new From 7cb69103d491fccef0f97f5db0673b66190415c5 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Thu, 4 Apr 2024 15:22:33 -0700 Subject: [PATCH 08/17] Add test to cover missing line in `_drop_incomplete_season` --- tests/test_temporal.py | 65 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 9ba0763e..723925fb 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -210,6 +210,7 @@ def test_averages_for_monthly_time_series(self): ) xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -227,6 +228,7 @@ def test_averages_for_monthly_time_series(self): }, ) xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs def test_averages_for_daily_time_series(self): ds = xr.Dataset( @@ -819,6 +821,68 @@ def test_weighted_custom_seasonal_averages(self): xr.testing.assert_identical(result, expected) + def test_weighted_seasonal_averages_with_custom_seasons_and_all_complete_seasons( + self, + ): + ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-01-16T12:00:00.000000000", + "2000-02-15T12:00:00.000000000", + "2000-03-16T12:00:00.000000000", + "2000-06-16T00:00:00.000000000", + "2000-09-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) + + result = ds.temporal.group_average( + "ts", + "season", + season_config={ + "custom_seasons": [["Jan", "Mar", "Jun"], ["Feb", "Sep"]], + "drop_incomplete_seasons": True, + }, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1.34065934]], [[1.47457627]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 3, 1), + cftime.DatetimeGregorian(2000, 9, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "True", + "custom_seasons": ["JanMarJun", "FebSep"], + }, + ) + + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs + def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): ds = self.ds.copy() ds["time"].values[:] = np.array( @@ -2244,6 +2308,7 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): ) xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs def test_unweighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() From 1cf436447234dff03cbb5e2f9816077070b1dd8a Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Tue, 30 Jul 2024 12:30:32 -0700 Subject: [PATCH 09/17] Add clear error message when no complete seasons are found --- tests/test_temporal.py | 15 +++++++++++++++ xcdat/temporal.py | 11 +++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 723925fb..c748e7c7 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -762,6 +762,21 @@ def test_raises_error_with_incorrect_custom_seasons_argument(self): season_config={"custom_seasons": custom_seasons}, ) + def test_raises_error_with_dataset_that_has_no_complete_seasons(self): + ds = self.ds.copy() + ds = ds.isel(time=slice(0, 1)) + custom_seasons = [["Dec", "Jan"]] + + with pytest.raises(RuntimeError): + ds.temporal.group_average( + "ts", + "season", + season_config={ + "custom_seasons": custom_seasons, + "drop_incomplete_seasons": True, + }, + ) + def test_weighted_custom_seasonal_averages(self): ds = self.ds.copy() diff --git a/xcdat/temporal.py b/xcdat/temporal.py index f7efcaaa..57a524f9 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -250,7 +250,7 @@ def group_average( Time bounds are used for generating weights to calculate weighted group averages (refer to the ``weighted`` parameter documentation below). - .. deprecated:: v0.7.0 + .. deprecated:: v0.8.0 The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` is being deprecated. Please use ``"drop_incomplete_seasons"`` instead. @@ -1182,7 +1182,14 @@ def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: # Get the incomplete seasons and drop the time coordinates that are in # those incomplete seasons. indexes_to_drop = df[df["expected_months"] != df["actual_months"]].index - if len(indexes_to_drop) > 0: + + if len(indexes_to_drop) == len(time_coords): + raise RuntimeError( + "No time coordinates remain with `drop_incomplete_seasons=True`. " + "Check the dataset has at least one complete season and/or " + "specify `drop_incomplete_seasons=False` instead." + ) + elif len(indexes_to_drop) > 0: # The dataset needs to be split into a dataset with and a dataset # without the time dimension because the xarray `.where()` method # concatenates the time dimension to non-time dimension data vars, From a05088423c509a02ab1c41ee7b021d59f3fc131d Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Tue, 30 Jul 2024 12:34:18 -0700 Subject: [PATCH 10/17] Update deprecation version for `drop_incomplete_djf` --- xcdat/temporal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 57a524f9..36d1f8b6 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -426,7 +426,7 @@ def climatology( Time bounds are used for generating weights to calculate weighted climatology (refer to the ``weighted`` parameter documentation below). - .. deprecated:: v0.7.0 + .. deprecated:: v0.8.0 The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` is being deprecated. Please use ``"drop_incomplete_seasons"`` instead. @@ -622,7 +622,7 @@ def departures( Time bounds are used for generating weights to calculate weighted climatology (refer to the ``weighted`` parameter documentation below). - .. deprecated:: v0.7.0 + .. deprecated:: v0.8.0 The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` is being deprecated. Please use ``"drop_incomplete_seasons"`` instead. From 5ace7a89ca79238714bf25ed49fd89ef369412b1 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Thu, 1 Aug 2024 10:50:49 -0700 Subject: [PATCH 11/17] Update static time dim ref to CF interpreted dim key --- xcdat/temporal.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 36d1f8b6..675acf42 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1135,11 +1135,13 @@ def _subset_coords_for_custom_seasons( """ month_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) - coords_by_month = ds.time.groupby(f"{self.dim}.month").groups + coords_by_month = ds[self.dim].groupby(f"{self.dim}.month").groups month_to_time_idx = { k: coords_by_month[k] for k in month_ints if k in coords_by_month } - month_to_time_idx = sorted(list(chain.from_iterable(month_to_time_idx.values()))) # type: ignore + month_to_time_idx = sorted( + list(chain.from_iterable(month_to_time_idx.values())) + ) # type: ignore ds_new = ds.isel({f"{self.dim}": month_to_time_idx}) return ds_new @@ -1224,7 +1226,9 @@ def _drop_leap_days(self, ds: xr.Dataset): ------- xr.Dataset """ - ds = ds.sel(**{self.dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))}) + ds = ds.sel( # type: ignore + **{self.dim: ~((ds[self.dim].dt.month == 2) & (ds[self.dim].dt.day == 29))} + ) return ds def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: From 25e31f7e4903ddc325fd4431cd6a00813a0d4431 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 14 Oct 2024 12:51:02 -0700 Subject: [PATCH 12/17] Add `drop_incomplete_djf` back --- tests/test_temporal.py | 52 ++++++++++- xcdat/temporal.py | 195 +++++++++++++++++++++++++++++++---------- 2 files changed, 199 insertions(+), 48 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index c748e7c7..7951d3f3 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -673,6 +673,56 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): xr.testing.assert_identical(result, expected) + def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_djf(self): + ds = self.ds.copy() + + result = ds.temporal.group_average( + "ts", + "season", + season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + ) + expected = ds.copy() + # Drop the incomplete DJF seasons + expected = expected.isel(time=slice(2, -1)) + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1]], [[1]], [[1]], [[2.0]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 4, 1), + cftime.DatetimeGregorian(2000, 7, 1), + cftime.DatetimeGregorian(2000, 10, 1), + cftime.DatetimeGregorian(2001, 1, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "weighted": "True", + "dec_mode": "DJF", + "drop_incomplete_djf": "True", + }, + ) + + xr.testing.assert_identical(result, expected) + def test_weighted_seasonal_averages_with_JFD(self): ds = self.ds.copy() @@ -1398,7 +1448,7 @@ def test_raises_deprecation_warning_with_drop_incomplete_djf_season_config(self) "mode": "climatology", "freq": "season", "weighted": "True", - "drop_incomplete_seasons": "True", + "drop_incomplete_djf": "True", "dec_mode": "DJF", }, ) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 675acf42..dde7ff7d 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -67,6 +67,8 @@ SeasonConfigInput = TypedDict( "SeasonConfigInput", { + # TODO: Deprecate incomplete_djf. + "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[List[List[str]]], @@ -77,6 +79,8 @@ SeasonConfigAttr = TypedDict( "SeasonConfigAttr", { + # TODO: Deprecate incomplete_djf. + "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[Dict[str, List[str]]], @@ -85,6 +89,8 @@ ) DEFAULT_SEASON_CONFIG: SeasonConfigInput = { + # TODO: Deprecate incomplete_djf. + "drop_incomplete_djf": False, "drop_incomplete_seasons": False, "dec_mode": "DJF", "custom_seasons": None, @@ -287,23 +293,32 @@ def group_average( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - season_config: SeasonConfigInput, optional + season_config : SeasonConfigInput, optional A dictionary for "season" frequency configurations. If configs for predefined seasons are passed, configs for custom seasons are ignored and vice versa. * "drop_incomplete_seasons" (bool, by default False) Seasons are considered incomplete if they do not have all of - the required months to form the season. For example, if we have + the required months to form the season. This argument supersedes + "drop_incomplete_djf". For example, if we have the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", "2001-02-16"] and we want to group seasons by "ND" ("Nov", "Dec") and "JFM" ("Jan", "Feb", "Mar"). * ["2000-11-16", "2000-12-16"] is considered a complete "ND" - season since both "Nov" and "Dec" are present. + season since both "Nov" and "Dec" are present. * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" - season because it only has "Jan" and "Feb". Therefore, these - time coordinates are dropped. + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. + + * "drop_incomplete_djf" (bool, by default False) + If the "dec_mode" is "DJF", this flag drops (True) or keeps + (False) time coordinates that fall under incomplete DJF seasons + Incomplete DJF seasons include the start year Jan/Feb and the + end year Dec. This argument is superceded by + "drop_incomplete_seasons" and will be deprecated in a future + release. * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") The mode for the season that includes December in the list of @@ -472,23 +487,32 @@ def climatology( 'yyyy-mm-dd'. For example, ``('1850-01-01', '1899-12-31')``. If no value is provided, the climatological reference period will be the full period covered by the dataset. - season_config: SeasonConfigInput, optional + season_config : SeasonConfigInput, optional A dictionary for "season" frequency configurations. If configs for predefined seasons are passed, configs for custom seasons are ignored and vice versa. * "drop_incomplete_seasons" (bool, by default False) Seasons are considered incomplete if they do not have all of - the required months to form the season. For example, if we have + the required months to form the season. This argument supersedes + "drop_incomplete_djf". For example, if we have the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", "2001-02-16"] and we want to group seasons by "ND" ("Nov", "Dec") and "JFM" ("Jan", "Feb", "Mar"). * ["2000-11-16", "2000-12-16"] is considered a complete "ND" - season since both "Nov" and "Dec" are present. + season since both "Nov" and "Dec" are present. * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" - season because it only has "Jan" and "Feb". Therefore, these - time coordinates are dropped. + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. + + * "drop_incomplete_djf" (bool, by default False) + If the "dec_mode" is "DJF", this flag drops (True) or keeps + (False) time coordinates that fall under incomplete DJF seasons + Incomplete DJF seasons include the start year Jan/Feb and the + end year Dec. This argument is superceded by + "drop_incomplete_seasons" and will be deprecated in a future + release. * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") The mode for the season that includes December in the list of @@ -669,7 +693,7 @@ def departures( ``('1850-01-01', '1899-12-31')``. If no value is provided, the climatological reference period will be the full period covered by the dataset. - season_config: SeasonConfigInput, optional + season_config : SeasonConfigInput, optional A dictionary for "season" frequency configurations. If configs for predefined seasons are passed, configs for custom seasons are ignored and vice versa. @@ -678,16 +702,25 @@ def departures( * "drop_incomplete_seasons" (bool, by default False) Seasons are considered incomplete if they do not have all of - the required months to form the season. For example, if we have + the required months to form the season. This argument supersedes + "drop_incomplete_djf". For example, if we have the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", "2001-02-16"] and we want to group seasons by "ND" ("Nov", "Dec") and "JFM" ("Jan", "Feb", "Mar"). * ["2000-11-16", "2000-12-16"] is considered a complete "ND" - season since both "Nov" and "Dec" are present. + season since both "Nov" and "Dec" are present. * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" - season because it only has "Jan" and "Feb". Therefore, these - time coordinates are dropped. + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. + + * "drop_incomplete_djf" (bool, by default False) + If the "dec_mode" is "DJF", this flag drops (True) or keeps + (False) time coordinates that fall under incomplete DJF seasons + Incomplete DJF seasons include the start year Jan/Feb and the + end year Dec. This argument is superceded by + "drop_incomplete_seasons" and will be deprecated in a future + release. Configs for predefined seasons: @@ -967,44 +1000,46 @@ def _set_arg_attrs( self._is_valid_reference_period(reference_period) self._reference_period = reference_period - # "season" frequency specific configuration attributes. + self._set_season_config_attr(season_config) + + def _set_season_config_attr(self, season_config: SeasonConfigInput): for key in season_config.keys(): - # TODO: Deprecate `drop_incomplete_djf`. - valid_keys = list(DEFAULT_SEASON_CONFIG.keys()) + ["drop_incomplete_djf"] - if key not in valid_keys: + if key not in DEFAULT_SEASON_CONFIG: raise KeyError( f"'{key}' is not a supported season config. Supported " f"configs include: {DEFAULT_SEASON_CONFIG.keys()}." ) - custom_seasons = season_config.get("custom_seasons", None) - dec_mode = season_config.get("dec_mode", "DJF") self._season_config: SeasonConfigAttr = {} + self._season_config["drop_incomplete_seasons"] = season_config.get( + "drop_incomplete_seasons", False + ) - # TODO: Deprecate `drop_incomplete_djf`. - drop_incomplete_djf = season_config.get("drop_incomplete_djf", None) - if drop_incomplete_djf is not None: - warnings.warn( - "The `season_config` argument 'drop_incomplete_djf' is being " - "deprecated. Please use 'drop_incomplete_seasons' instead.", - DeprecationWarning, - ) - self._season_config["drop_incomplete_seasons"] = drop_incomplete_djf # type: ignore + custom_seasons = season_config.get("custom_seasons", None) + if custom_seasons is not None: + self._season_config["custom_seasons"] = self._form_seasons(custom_seasons) else: - self._season_config["drop_incomplete_seasons"] = season_config.get( - "drop_incomplete_seasons", False - ) - - if custom_seasons is None: + dec_mode = season_config.get("dec_mode", "DJF") if dec_mode not in ("DJF", "JFD"): raise ValueError( "Incorrect 'dec_mode' key value for `season_config`. " "Supported modes include 'DJF' or 'JFD'." ) + self._season_config["dec_mode"] = dec_mode - else: - self._season_config["custom_seasons"] = self._form_seasons(custom_seasons) + # TODO: Deprecate incomplete_djf. + drop_incomplete_djf = season_config.get("drop_incomplete_djf", False) + if drop_incomplete_djf is not False: + warnings.warn( + "The `season_config` argument 'drop_incomplete_djf' is being " + "deprecated. Please use 'drop_incomplete_seasons' instead.", + DeprecationWarning, + stacklevel=2, + ) + + if dec_mode == "DJF": + self._season_config["drop_incomplete_djf"] = drop_incomplete_djf def _is_valid_reference_period(self, reference_period: Tuple[str, str]): try: @@ -1094,18 +1129,28 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: if len(months) != 12: ds = self._subset_coords_for_custom_seasons(ds, months) + if ( + self._freq == "day" + and self._mode in ["climatology", "departures"] + and self.calendar in ["gregorian", "proleptic_gregorian", "standard"] + ): + ds = self._drop_leap_days(ds) + if ( self._freq == "season" and self._season_config["drop_incomplete_seasons"] is True ): ds = self._drop_incomplete_seasons(ds) + # TODO: Deprecate incomplete_djf. Only run this is drop_incomplete_seasons + # is False and drop_incomplete_djf is True. if ( - self._freq == "day" - and self._mode in ["climatology", "departures"] - and self.calendar in ["gregorian", "proleptic_gregorian", "standard"] + self._freq == "season" + and self._season_config.get("dec_mode") == "DJF" + and self._season_config.get("drop_incomplete_djf") is True + and self._season_config.get("drop_incomplete_seasons") is False ): - ds = self._drop_leap_days(ds) + ds = self._drop_incomplete_djf(ds) if self._mode == "climatology" and self._reference_period is not None: ds = ds.sel( @@ -1140,12 +1185,59 @@ def _subset_coords_for_custom_seasons( k: coords_by_month[k] for k in month_ints if k in coords_by_month } month_to_time_idx = sorted( - list(chain.from_iterable(month_to_time_idx.values())) - ) # type: ignore + list(chain.from_iterable(month_to_time_idx.values())) # type: ignore + ) ds_new = ds.isel({f"{self.dim}": month_to_time_idx}) return ds_new + def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: + """Drops incomplete DJF seasons within a continuous time series. + + This method assumes that the time series is continuous and removes the + leading and trailing incomplete seasons (e.g., the first January and + February of a time series that are not complete, because the December of + the previous year is missing). This method does not account for or + remove missing time steps anywhere else. + + Parameters + ---------- + dataset : xr.Dataset + The dataset with some possibly incomplete DJF seasons. + Returns + ------- + xr.Dataset + The dataset with only complete DJF seasons. + """ + # Separate the dataset into two datasets, one with and one without + # the time dimension. This is necessary because the xarray .where() + # method concatenates the time dimension to non-time dimension data + # vars, which is not a desired behavior. + ds = dataset.copy() + ds_time = ds.get([v for v in ds.data_vars if self.dim in ds[v].dims]) # type: ignore + ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore + + start_year, end_year = ( + ds[self.dim].dt.year.values[0], + ds[self.dim].dt.year.values[-1], + ) + incomplete_seasons = ( + f"{int(start_year):04d}-01", + f"{int(start_year):04d}-02", + f"{int(end_year):04d}-12", + ) + + for year_month in incomplete_seasons: + try: + coord_pt = ds.loc[dict(time=year_month)][self.dim][0] + ds_time = ds_time.where(ds_time[self.dim] != coord_pt, drop=True) + except (KeyError, IndexError): + continue + + ds_final = xr.merge((ds_time, ds_no_time)) + + return ds_final + def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: """Drops incomplete seasons within a continuous time series. @@ -1196,6 +1288,9 @@ def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: # without the time dimension because the xarray `.where()` method # concatenates the time dimension to non-time dimension data vars, # which is an undesired behavior. + # FIXME: Figure out if this code block is still necessary + # https://github.com/pydata/xarray/issues/1234 + # https://github.com/pydata/xarray/issues/8796#issuecomment-1974878267 ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore ds_time = ds.get([v for v in ds.data_vars if self.dim in ds[v].dims]) # type: ignore @@ -1226,7 +1321,7 @@ def _drop_leap_days(self, ds: xr.Dataset): ------- xr.Dataset """ - ds = ds.sel( # type: ignore + ds = ds.sel( **{self.dim: ~((ds[self.dim].dt.month == 2) & (ds[self.dim].dt.day == 29))} ) return ds @@ -1865,9 +1960,15 @@ def _add_operation_attrs(self, data_var: xr.DataArray) -> xr.DataArray: ) if self._freq == "season": - data_var.attrs["drop_incomplete_seasons"] = str( - self._season_config["drop_incomplete_seasons"] - ) + drop_incomplete_seasons = self._season_config["drop_incomplete_seasons"] + drop_incomplete_djf = self._season_config.get("drop_incomplete_djf", False) + + # TODO: Deprecate drop_incomplete_djf. This attr is only set if the + # user does not set drop_incomplete_seasons. + if drop_incomplete_seasons is False and drop_incomplete_djf is not False: + data_var.attrs["drop_incomplete_djf"] = str(drop_incomplete_djf) + else: + data_var.attrs["drop_incomplete_seasons"] = str(drop_incomplete_seasons) custom_seasons = self._season_config.get("custom_seasons") if custom_seasons is not None: From 2b47972def87763ebcf0df452b4f4cb267b3f89a Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 14 Oct 2024 13:43:42 -0700 Subject: [PATCH 13/17] Move deprecation warning to inside DJF conditional --- xcdat/temporal.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index dde7ff7d..8002dbac 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1030,15 +1030,15 @@ def _set_season_config_attr(self, season_config: SeasonConfigInput): # TODO: Deprecate incomplete_djf. drop_incomplete_djf = season_config.get("drop_incomplete_djf", False) - if drop_incomplete_djf is not False: - warnings.warn( - "The `season_config` argument 'drop_incomplete_djf' is being " - "deprecated. Please use 'drop_incomplete_seasons' instead.", - DeprecationWarning, - stacklevel=2, - ) - if dec_mode == "DJF": + if drop_incomplete_djf is not False: + warnings.warn( + "The `season_config` argument 'drop_incomplete_djf' is being " + "deprecated. Please use 'drop_incomplete_seasons' instead.", + DeprecationWarning, + stacklevel=2, + ) + self._season_config["drop_incomplete_djf"] = drop_incomplete_djf def _is_valid_reference_period(self, reference_period: Tuple[str, str]): From 05496ed1763bcee5fd12d2912cbe5cf01bf99e25 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Wed, 23 Oct 2024 14:53:39 -0700 Subject: [PATCH 14/17] Update comment about .where concatenating time din --- xcdat/temporal.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 8002dbac..0163dac1 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1286,9 +1286,8 @@ def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: elif len(indexes_to_drop) > 0: # The dataset needs to be split into a dataset with and a dataset # without the time dimension because the xarray `.where()` method - # concatenates the time dimension to non-time dimension data vars, - # which is an undesired behavior. - # FIXME: Figure out if this code block is still necessary + # adds the time dimension to non-time dimension data vars when + # broadcasting, which is a behavior we do not desire. # https://github.com/pydata/xarray/issues/1234 # https://github.com/pydata/xarray/issues/8796#issuecomment-1974878267 ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore From 0b6852f0659030c9d4d763c4673416a95f8a4b3c Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Thu, 24 Oct 2024 13:31:42 -0700 Subject: [PATCH 15/17] Refactor logic for shifting months to use Xarray instead of Pandas - Months are also shifted in the `_preprocess_dataset()` method now. Before months were being shifted twice, once when dropping incomplete seasons or DJF, and a second time when labeling time coordinates. --- xcdat/temporal.py | 299 ++++++++++++++++++++++------------------------ 1 file changed, 145 insertions(+), 154 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 0163dac1..810e1c63 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1101,9 +1101,12 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]] def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """Preprocess the dataset based on averaging settings. - Preprocessing operations include: - - Drop incomplete DJF seasons (leading/trailing) - - Drop leap days + Operations include: + 1. Drop leap days for daily climatologies. + 2. Subset the dataset based on the reference period. + 3. Shift years for custom seasons spanning the calendar year. + 4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons. + 5. Drop incomplete seasons if specified. Parameters ---------- @@ -1114,6 +1117,18 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: ------- xr.Dataset """ + if ( + self._freq == "day" + and self._mode in ["climatology", "departures"] + and self.calendar in ["gregorian", "proleptic_gregorian", "standard"] + ): + ds = self._drop_leap_days(ds) + + if self._mode == "climatology" and self._reference_period is not None: + ds = ds.sel( + {self.dim: slice(self._reference_period[0], self._reference_period[1])} + ) + if ( self._freq == "season" and self._season_config.get("custom_seasons") is not None @@ -1129,12 +1144,17 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: if len(months) != 12: ds = self._subset_coords_for_custom_seasons(ds, months) - if ( - self._freq == "day" - and self._mode in ["climatology", "departures"] - and self.calendar in ["gregorian", "proleptic_gregorian", "standard"] - ): - ds = self._drop_leap_days(ds) + ds = self._shift_custom_season_years(ds) + + if self._freq == "season" and self._season_config.get("dec_mode") == "DJF": + ds = self._shift_djf_decembers(ds) + + # TODO: Deprecate incomplete_djf. + if ( + self._season_config.get("drop_incomplete_djf") is True + and self._season_config.get("drop_incomplete_seasons") is False + ): + ds = self._drop_incomplete_djf(ds) if ( self._freq == "season" @@ -1142,21 +1162,6 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: ): ds = self._drop_incomplete_seasons(ds) - # TODO: Deprecate incomplete_djf. Only run this is drop_incomplete_seasons - # is False and drop_incomplete_djf is True. - if ( - self._freq == "season" - and self._season_config.get("dec_mode") == "DJF" - and self._season_config.get("drop_incomplete_djf") is True - and self._season_config.get("drop_incomplete_seasons") is False - ): - ds = self._drop_incomplete_djf(ds) - - if self._mode == "climatology" and self._reference_period is not None: - ds = ds.sel( - {self.dim: slice(self._reference_period[0], self._reference_period[1])} - ) - return ds def _subset_coords_for_custom_seasons( @@ -1191,6 +1196,119 @@ def _subset_coords_for_custom_seasons( return ds_new + def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset: + """Shifts the year for custom seasons spanning the calendar year. + + A season spans the calendar year if it includes "Jan" and "Jan" is not + the first month. For example, for + ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]``: + - ["Nov", "Dec"] are from the previous year. + - ["Jan", "Feb", "Mar"] are from the current year. + + Therefore, ["Nov", "Dec"] need to be shifted a year forward for correct + grouping. + + Parameters + ---------- + ds : xr.Dataset + The Dataset with time coordinates. + + Returns + ------- + xr.Dataset + The Dataset with shifted time coordinates. + + Examples + -------- + + Before and after shifting months for "NDJFM" seasons: + + >>> # Before shifting months + >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + + >>> # After shifting months + >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + """ + ds_new = ds.copy() + custom_seasons = self._season_config["custom_seasons"] + + span_months: List[int] = [] + + # Identify the months that span across years in custom seasons. + # This is done by checking if "Jan" is not the first month in the + # custom season and getting all months before "Jan". + for months in custom_seasons.values(): # type: ignore + month_nums = [MONTH_STR_TO_INT[month] for month in months] + if 1 in month_nums: + jan_index = month_nums.index(1) + + if jan_index != 0: + span_months.extend(month_nums[:jan_index]) + break + + if span_months: + time_coords = ds_new[self.dim].copy() + idxs = np.where(time_coords.dt.month.isin(span_months))[0] + + if isinstance(time_coords.values[0], cftime.datetime): + for idx in idxs: + time_coords.values[idx] = time_coords.values[idx].replace( + year=time_coords.values[idx].year + 1 + ) + else: + for idx in idxs: + time_coords.values[idx] = pd.Timestamp( + time_coords.values[idx] + ) + pd.DateOffset(years=1) + + ds_new = ds_new.assign_coords({self.dim: time_coords}) + + return ds_new + + def _shift_djf_decembers(self, ds: xr.Dataset) -> xr.Dataset: + """Shifts Decembers to the next year for "DJF" seasons. + + This ensures correct grouping for "DJF" seasons by shifting Decembers + to the next year. Without this, grouping defaults to "JFD", which + is the native Xarray behavior. + + Parameters + ---------- + ds : xr.Dataset + The Dataset with time coordinates. + + Returns + ------- + xr.Dataset + The Dataset with shifted time coordinates. + + Examples + -------- + + Comparison of "JFD" and "DJF" seasons: + + >>> # "JFD" (native xarray behavior) + >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12), + >>> (2001, "DJF", 1), (2001, "DJF", 2)] + + >>> # "DJF" (shifted Decembers) + >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), + >>> (2001, "DJF", 1), (2001, "DJF", 2)] + """ + ds_new = ds.copy() + time_coords = ds_new[self.dim].copy() + dec_indexes = time_coords.dt.month == 12 + + time_coords.values[dec_indexes] = [ + time.replace(year=time.year + 1) for time in time_coords.values[dec_indexes] + ] + + ds_new = ds_new.assign_coords({self.dim: time_coords}) + + return ds_new + def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: """Drops incomplete DJF seasons within a continuous time series. @@ -1612,7 +1730,9 @@ def _get_df_dt_components( elif self._mode == "group_average": df["month"] = time_coords[f"{self.dim}.month"].values - df = self._process_season_df(df) + custom_seasons = self._season_config.get("custom_seasons") + if custom_seasons is not None: + df = self._map_months_to_custom_seasons(df) if drop_obsolete_cols: df = self._drop_obsolete_columns(df) @@ -1620,33 +1740,6 @@ def _get_df_dt_components( return df - def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Processes a DataFrame of datetime components for the season frequency. - - Parameters - ---------- - df : xr.DataArray - A DataFrame of seasonal datetime components. - - Returns - ------- - pd.DataFrame - A DataFrame of seasonal datetime components. - """ - df_new = df.copy() - custom_seasons = self._season_config.get("custom_seasons") - dec_mode = self._season_config.get("dec_mode") - - if custom_seasons is not None: - df_new = self._map_months_to_custom_seasons(df_new) - df_new = self._shift_spanning_months(df_new) - else: - if dec_mode == "DJF": - df_new = self._shift_decembers(df_new) - - return df_new - def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: """Maps the month column in the DataFrame to a custom season. @@ -1681,108 +1774,6 @@ def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: return df_new - def _shift_spanning_months(self, df: pd.DataFrame) -> pd.DataFrame: - """Shifts months in seasons spanning the previous year to the next year. - - A season spans the previous year if it includes the month of "Jan" and - "Jan" is not the first month of the season. For example, let's say we - define ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]`` to - represent the southern hemisphere growing seasons, "NDJFM". - - ["Nov", "Dec"] are from the previous year since they are listed - before "Jan". - - ["Jan", "Feb", "Mar"] are from the current year. - - Therefore, we need to shift ["Nov", "Dec"] a year forward in order for - xarray to group seasons correctly. Refer to the examples section below - for a visual demonstration. - - Parameters - ---------- - df : pd.Dataframe - The DataFrame of xarray datetime components produced using the - "season" frequency". - - Returns - ------- - pd.DataFrame - The DataFrame of xarray dataetime copmonents with months spanning - previous year shifted over to the next year. - - Examples - -------- - - Before and after shifting months for "NDJFM" seasons: - - >>> # Before shifting months - >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1), - >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] - - >>> # After shifting months - >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1), - >>> (2001, "NDJFM", 1), (2001, "NDJFM", 2)] - """ - df_new = df.copy() - custom_seasons = self._season_config["custom_seasons"] - - span_months: List[int] = [] - - # Loop over the custom seasons and get the list of months for the - # current season. Convert those months to their integer representations. - # If 1 ("Jan") is in the list of months and it is NOT the first element, - # then get all elements before it (aka the spanning months). - for months in custom_seasons.values(): # type: ignore - month_nums = [MONTH_STR_TO_INT[month] for month in months] - try: - jan_index = month_nums.index(1) - if jan_index != 0: - span_months = span_months + month_nums[:jan_index] - break - except ValueError: - continue - - if len(span_months) > 0: - df_new.loc[df_new["month"].isin(span_months), "year"] = df_new["year"] + 1 - - return df_new - - def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame: - """Shifts Decembers over to the next year for "DJF" seasons in-place. - - For "DJF" seasons, Decembers must be shifted over to the next year in - order for the xarray groupby operation to correctly label and group the - corresponding time coordinates. If the aren't shifted over, grouping is - incorrectly performed with the native xarray "DJF" season (which is - actually "JFD"). - - Parameters - ---------- - df_season : pd.DataFrame - The DataFrame of xarray datetime components produced using the - "season" frequency. - - Returns - ------- - pd.DataFrame - The DataFrame of xarray datetime components with Decembers shifted - over to the next year. - - Examples - -------- - - Comparison of "JFD" and "DJF" seasons: - - >>> # "JFD" (native xarray behavior) - >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12), - >>> (2001, "DJF", 1), (2001, "DJF", 2)] - - >>> # "DJF" (shifted Decembers) - >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), - >>> (2001, "DJF", 1), (2001, "DJF", 2)] - """ - df_season.loc[df_season["month"] == 12, "year"] = df_season["year"] + 1 - - return df_season - def _map_seasons_to_mid_months(self, df: pd.DataFrame) -> pd.DataFrame: """Maps the season column values to the integer of its middle month. From 8d156c2a9311d3652748a343176f93f6d1756d85 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Tue, 12 Nov 2024 12:37:17 -0800 Subject: [PATCH 16/17] Add todo comments --- xcdat/temporal.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 810e1c63..821cef18 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -335,7 +335,6 @@ def group_average( representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') - * Each month must be included once in a custom season * Order of the months in each custom season does not matter * Custom seasons can vary in length @@ -529,7 +528,6 @@ def climatology( representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') - * Each month must be included once in a custom season * Order of the months in each custom season does not matter * Custom seasons can vary in length @@ -739,7 +737,6 @@ def departures( representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') - * Each month must be included once in a custom season * Order of the months in each custom season does not matter * Custom seasons can vary in length @@ -1381,6 +1378,11 @@ def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: pd.DataFrame A DataFrame of seasonal datetime components with only complete seasons. + + Notes + ----- + TODO: Refactor this method to use pure Xarray/NumPy operations, rather + than Pandas. """ # Transform the time coords into a DataFrame of seasonal datetime # components based on the grouping mode. From f2648ffb1530a7f201f06739eca38aebf9871b90 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Wed, 20 Nov 2024 13:06:50 -0800 Subject: [PATCH 17/17] Clean up logic in various private methods - Methods include `_subset_coords_for_custom_seasons()` and `_shift_custom_season_years()` --- xcdat/temporal.py | 72 +++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 821cef18..3367e06a 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1099,11 +1099,12 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """Preprocess the dataset based on averaging settings. Operations include: - 1. Drop leap days for daily climatologies. - 2. Subset the dataset based on the reference period. - 3. Shift years for custom seasons spanning the calendar year. - 4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons. - 5. Drop incomplete seasons if specified. + 1. Drop leap days for daily climatologies. + 2. Subset the dataset based on the reference period. + 3. Shift years for custom seasons spanning the calendar year. + 4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons, + if specified. + 5. Drop incomplete seasons if specified. Parameters ---------- @@ -1141,6 +1142,9 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: if len(months) != 12: ds = self._subset_coords_for_custom_seasons(ds, months) + # The years for time coordinates needs to be shifted by 1 for months + # that span the calendar because Xarray groups seasons by months + # in the same year, rather than the previous year. ds = self._shift_custom_season_years(ds) if self._freq == "season" and self._season_config.get("dec_mode") == "DJF": @@ -1180,16 +1184,8 @@ def _subset_coords_for_custom_seasons( The dataset with time coordinate subsetted to months used in custom seasons. """ - month_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) - - coords_by_month = ds[self.dim].groupby(f"{self.dim}.month").groups - month_to_time_idx = { - k: coords_by_month[k] for k in month_ints if k in coords_by_month - } - month_to_time_idx = sorted( - list(chain.from_iterable(month_to_time_idx.values())) # type: ignore - ) - ds_new = ds.isel({f"{self.dim}": month_to_time_idx}) + month_ints = [MONTH_STR_TO_INT[month] for month in months] + ds_new = ds.sel({self.dim: ds[self.dim].dt.month.isin(month_ints)}) return ds_new @@ -1231,34 +1227,31 @@ def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset: ds_new = ds.copy() custom_seasons = self._season_config["custom_seasons"] + # Identify months that span across years in custom seasons by getting + # the months before "Jan" if "Jan" is not the first month of the season. + # Note: Only one custom season can span the calendar year. span_months: List[int] = [] - - # Identify the months that span across years in custom seasons. - # This is done by checking if "Jan" is not the first month in the - # custom season and getting all months before "Jan". for months in custom_seasons.values(): # type: ignore - month_nums = [MONTH_STR_TO_INT[month] for month in months] - if 1 in month_nums: - jan_index = month_nums.index(1) + month_ints = [MONTH_STR_TO_INT[month] for month in months] - if jan_index != 0: - span_months.extend(month_nums[:jan_index]) + if 1 in month_ints and month_ints.index(1) != 0: + span_months.extend(month_ints[: month_ints.index(1)]) break if span_months: time_coords = ds_new[self.dim].copy() - idxs = np.where(time_coords.dt.month.isin(span_months))[0] + indexes = time_coords.dt.month.isin(span_months) if isinstance(time_coords.values[0], cftime.datetime): - for idx in idxs: - time_coords.values[idx] = time_coords.values[idx].replace( - year=time_coords.values[idx].year + 1 - ) + time_coords.values[indexes] = [ + time.replace(year=time.year + 1) + for time in time_coords.values[indexes] + ] else: - for idx in idxs: - time_coords.values[idx] = pd.Timestamp( - time_coords.values[idx] - ) + pd.DateOffset(years=1) + time_coords.values[indexes] = [ + pd.Timestamp(time) + pd.DateOffset(years=1) + for time in time_coords.values[indexes] + ] ds_new = ds_new.assign_coords({self.dim: time_coords}) @@ -1298,9 +1291,16 @@ def _shift_djf_decembers(self, ds: xr.Dataset) -> xr.Dataset: time_coords = ds_new[self.dim].copy() dec_indexes = time_coords.dt.month == 12 - time_coords.values[dec_indexes] = [ - time.replace(year=time.year + 1) for time in time_coords.values[dec_indexes] - ] + if isinstance(time_coords.values[0], cftime.datetime): + time_coords.values[dec_indexes] = [ + time.replace(year=time.year + 1) + for time in time_coords.values[dec_indexes] + ] + else: + time_coords.values[dec_indexes] = [ + pd.Timestamp(time) + pd.DateOffset(years=1) + for time in time_coords.values[dec_indexes] + ] ds_new = ds_new.assign_coords({self.dim: time_coords})