From aada458d60063fbcfa645a92cc9586f3fa8b156c Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Thu, 2 Mar 2023 16:53:37 -0800 Subject: [PATCH] Add support for shifting months in seasons spanning calendar year --- tests/test_temporal.py | 154 +++++++++++++++++++++++++++-------------- xcdat/temporal.py | 61 +++++++++++++++- 2 files changed, 162 insertions(+), 53 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 12149641..06696c41 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -731,6 +731,35 @@ def test_weighted_seasonal_averages_with_JFD(self): assert result.identical(expected) + def test_raises_error_with_incorrect_custom_seasons_argument(self): + # Test raises error with non-3 letter strings + with pytest.raises(ValueError): + custom_seasons = [ + ["J", "Feb", "Mar"], + ["Apr", "May", "Jun"], + ["Jul", "Aug", "Sep"], + ["Oct", "Nov", "Dec"], + ] + self.ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + + # Test raises error if duplicate month(s) were found + with pytest.raises(ValueError): + custom_seasons = [ + ["Jan", "Jan", "Mar"], + ["Apr", "May", "Jun"], + ["Jul", "Aug", "Sep"], + ["Oct", "Nov", "Dec"], + ] + self.ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + def test_weighted_custom_seasonal_averages(self): ds = self.ds.copy() @@ -789,13 +818,21 @@ def test_weighted_custom_seasonal_averages(self): assert result.identical(expected) - def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons( - self, - ): + def test_weighted_seasonal_averages_drops_incomplete_seasons(self): ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-11-16T12:00:00.000000000", + "2000-12-16T12:00:00.000000000", + "2001-01-16T00:00:00.000000000", + "2001-02-16T00:00:00.000000000", + "2001-03-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) + custom_seasons = 
[ - ["Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov"], - ["Dec", "Jan", "Feb", "Mar"], + ["Nov", "Dec", "Jan", "Feb", "Mar"], ] result = ds.temporal.group_average( @@ -803,26 +840,19 @@ def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_ "season", season_config={ "custom_seasons": custom_seasons, - "drop_incomplete_seasons": True, + # "drop_incomplete_seasons": True, }, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[1.5]], [[1.0]], [[1.0]], [[2.0]]]), + data=np.array([[[1.3933333]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( - data=np.array( - [ - cftime.DatetimeGregorian(2000, 2, 1), - cftime.DatetimeGregorian(2000, 5, 1), - cftime.DatetimeGregorian(2000, 8, 1), - cftime.DatetimeGregorian(2001, 2, 1), - ], - ), + data=np.array([cftime.datetime(2001, 1, 1)], dtype=object), dims=["time"], attrs={ "axis": "T", @@ -838,50 +868,70 @@ def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_ "operation": "temporal_avg", "mode": "group_average", "freq": "season", - "custom_seasons": [ - "JanFebMar", - "AprMayJun", - "JulAugSep", - "OctNovDec", - ], + "custom_seasons": ["NovDecJanFebMar"], "weighted": "True", }, ) - assert result.identical(expected) + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs - @pytest.mark.xfail - def test_weighted_custom_seasonal_averages_with_single_custom_season(self): - assert 0 + def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( + self, + ): + ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-11-16T12:00:00.000000000", + "2000-12-16T12:00:00.000000000", + "2001-01-16T00:00:00.000000000", + "2001-02-16T00:00:00.000000000", + "2001-03-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) - def test_raises_error_with_incorrect_custom_seasons_argument(self): - # Test raises error with non-3 
letter strings - with pytest.raises(ValueError): - custom_seasons = [ - ["J", "Feb", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) + custom_seasons = [ + ["Nov", "Dec", "Jan", "Feb", "Mar"], + ] - # Test raises error if duplicate month(s) were found - with pytest.raises(ValueError): - custom_seasons = [ - ["Jan", "Jan", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) + result = ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1.3933333]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array([cftime.datetime(2001, 1, 1)], dtype=object), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "custom_seasons": ["NovDecJanFebMar"], + "weighted": "True", + }, + ) + + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs def test_weighted_monthly_averages(self): ds = self.ds.copy() diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 109e22ff..b0f88242 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -103,6 +103,7 @@ 11: "Nov", 12: "Dec", } +MONTH_STR_TO_INT = {v: k for k, v in MONTH_INT_TO_STR.items()} # A dictionary mapping pre-defined seasons to their middle month. 
This # dictionary is used during the creation of datetime objects, which don't @@ -1459,6 +1460,7 @@ def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: if custom_seasons is not None: df_new = self._map_months_to_custom_seasons(df_new) + df_new = self._shift_spanning_months(df_new) else: if dec_mode == "DJF": df_new = self._shift_decembers(df_new) @@ -1501,6 +1503,64 @@ def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: return df_new + def _shift_spanning_months(self, df: pd.DataFrame) -> pd.DataFrame: + """Shifts months in seasons spanning the previous year to the next year. + + A season spans the previous year if it includes the month of "Jan" and + "Jan" is not the first month of the season. For example, let's say we + define ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]`` to + represent the southern hemisphere growing seasons, "NDJFM". + - ["Nov", "Dec"] are from the previous year since they are listed + before "Jan". + - ["Jan", "Feb", "Mar"] are from the current year. + + Therefore, we need to shift ["Nov", "Dec"] a year forward in order for + xarray to group seasons correctly. Refer to the examples section below + for a visual demonstration. + + Parameters + ---------- + df : pd.DataFrame + The DataFrame of xarray datetime components produced using the + "season" frequency. + + Returns + ------- + pd.DataFrame + The DataFrame of xarray datetime components with months spanning + the previous year shifted over to the next year. 
+ + Examples + -------- + + Before and after shifting months for "NDJFM" seasons: + + >>> # Before shifting months + >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + + >>> # After shifting months + >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + """ + df_new = df.copy() + custom_seasons = self._season_config["custom_seasons"] + + span_months: List[int] = [] + for months in custom_seasons.values(): # type: ignore + month_nums = [MONTH_STR_TO_INT[month] for month in months] + try: + jan_index = month_nums.index(1) + span_months = span_months + month_nums[:jan_index] + break + except ValueError: + continue + + if len(span_months) > 0: + df_new.loc[df_new["month"].isin(span_months), "year"] = df_new["year"] + 1 + + return df_new + def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame: """Shifts Decembers over to the next year for "DJF" seasons in-place. @@ -1534,7 +1594,6 @@ def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame: >>> # "DJF" (shifted Decembers) >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), >>> (2001, "DJF", 1), (2001, "DJF", 2)] - """ df_season.loc[df_season["month"] == 12, "year"] = df_season["year"] + 1