Skip to content

Commit

Permalink
Add support for shifting months in seasons spanning calendar year
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Apr 4, 2024
1 parent 6c98bbd commit aada458
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 53 deletions.
154 changes: 102 additions & 52 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,35 @@ def test_weighted_seasonal_averages_with_JFD(self):

assert result.identical(expected)

def test_raises_error_with_incorrect_custom_seasons_argument(self):
    """Check that malformed ``custom_seasons`` raise ``ValueError``.

    Two malformed first seasons are tried in turn, each combined with the
    same three well-formed seasons, and ``group_average`` must raise for
    both of them.
    """
    well_formed_seasons = (
        ("Apr", "May", "Jun"),
        ("Jul", "Aug", "Sep"),
        ("Oct", "Nov", "Dec"),
    )
    malformed_first_seasons = (
        # Contains a month that is not a 3-letter string.
        ("J", "Feb", "Mar"),
        # Contains the same month more than once.
        ("Jan", "Jan", "Mar"),
    )

    for first_season in malformed_first_seasons:
        # Build fresh nested lists for every call so no state is shared
        # between the two cases.
        custom_seasons = [list(first_season), *map(list, well_formed_seasons)]

        with pytest.raises(ValueError):
            self.ds.temporal.group_average(
                "ts",
                "season",
                season_config={"custom_seasons": custom_seasons},
            )

def test_weighted_custom_seasonal_averages(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -789,40 +818,41 @@ def test_weighted_custom_seasonal_averages(self):

assert result.identical(expected)

def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons(
self,
):
def test_weighted_seasonal_averages_drops_incomplete_seasons(self):
ds = self.ds.copy()
ds["time"].values[:] = np.array(
[
"2000-11-16T12:00:00.000000000",
"2000-12-16T12:00:00.000000000",
"2001-01-16T00:00:00.000000000",
"2001-02-16T00:00:00.000000000",
"2001-03-16T00:00:00.000000000",
],
dtype="datetime64[ns]",
)

custom_seasons = [
["Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov"],
["Dec", "Jan", "Feb", "Mar"],
["Nov", "Dec", "Jan", "Feb", "Mar"],
]

result = ds.temporal.group_average(
"ts",
"season",
season_config={
"custom_seasons": custom_seasons,
"drop_incomplete_seasons": True,
# "drop_incomplete_seasons": True,
},
)
expected = ds.copy()
expected = expected.drop_dims("time")
expected["ts"] = xr.DataArray(
name="ts",
data=np.array([[[1.5]], [[1.0]], [[1.0]], [[2.0]]]),
data=np.array([[[1.3933333]]]),
coords={
"lat": expected.lat,
"lon": expected.lon,
"time": xr.DataArray(
data=np.array(
[
cftime.DatetimeGregorian(2000, 2, 1),
cftime.DatetimeGregorian(2000, 5, 1),
cftime.DatetimeGregorian(2000, 8, 1),
cftime.DatetimeGregorian(2001, 2, 1),
],
),
data=np.array([cftime.datetime(2001, 1, 1)], dtype=object),
dims=["time"],
attrs={
"axis": "T",
Expand All @@ -838,50 +868,70 @@ def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_
"operation": "temporal_avg",
"mode": "group_average",
"freq": "season",
"custom_seasons": [
"JanFebMar",
"AprMayJun",
"JulAugSep",
"OctNovDec",
],
"custom_seasons": ["NovDecJanFebMar"],
"weighted": "True",
},
)

assert result.identical(expected)
xr.testing.assert_allclose(result, expected)
assert result.ts.attrs == expected.ts.attrs

@pytest.mark.xfail
def test_weighted_custom_seasonal_averages_with_single_custom_season(self):
    # NOTE(review): presumably a placeholder for a test that has not been
    # written yet -- the body fails unconditionally and the ``xfail`` marker
    # makes pytest report that failure as expected. Confirm before relying on it.
    assert 0
def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years(
self,
):
ds = self.ds.copy()
ds["time"].values[:] = np.array(
[
"2000-11-16T12:00:00.000000000",
"2000-12-16T12:00:00.000000000",
"2001-01-16T00:00:00.000000000",
"2001-02-16T00:00:00.000000000",
"2001-03-16T00:00:00.000000000",
],
dtype="datetime64[ns]",
)

def test_raises_error_with_incorrect_custom_seasons_argument(self):
# Test raises error with non-3 letter strings
with pytest.raises(ValueError):
custom_seasons = [
["J", "Feb", "Mar"],
["Apr", "May", "Jun"],
["Jul", "Aug", "Sep"],
["Oct", "Nov", "Dec"],
]
self.ds.temporal.group_average(
"ts",
"season",
season_config={"custom_seasons": custom_seasons},
)
custom_seasons = [
["Nov", "Dec", "Jan", "Feb", "Mar"],
]

# Test raises error if duplicate month(s) were found
with pytest.raises(ValueError):
custom_seasons = [
["Jan", "Jan", "Mar"],
["Apr", "May", "Jun"],
["Jul", "Aug", "Sep"],
["Oct", "Nov", "Dec"],
]
self.ds.temporal.group_average(
"ts",
"season",
season_config={"custom_seasons": custom_seasons},
)
result = ds.temporal.group_average(
"ts",
"season",
season_config={"custom_seasons": custom_seasons},
)
expected = ds.copy()
expected = expected.drop_dims("time")
expected["ts"] = xr.DataArray(
name="ts",
data=np.array([[[1.3933333]]]),
coords={
"lat": expected.lat,
"lon": expected.lon,
"time": xr.DataArray(
data=np.array([cftime.datetime(2001, 1, 1)], dtype=object),
dims=["time"],
attrs={
"axis": "T",
"long_name": "time",
"standard_name": "time",
"bounds": "time_bnds",
},
),
},
dims=["time", "lat", "lon"],
attrs={
"test_attr": "test",
"operation": "temporal_avg",
"mode": "group_average",
"freq": "season",
"custom_seasons": ["NovDecJanFebMar"],
"weighted": "True",
},
)

xr.testing.assert_allclose(result, expected)
assert result.ts.attrs == expected.ts.attrs

def test_weighted_monthly_averages(self):
ds = self.ds.copy()
Expand Down
61 changes: 60 additions & 1 deletion xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
11: "Nov",
12: "Dec",
}
# Reverse lookup of ``MONTH_INT_TO_STR``: maps each month abbreviation
# (e.g. "Dec") back to its month number (e.g. 12).
MONTH_STR_TO_INT = {abbrev: number for number, abbrev in MONTH_INT_TO_STR.items()}

# A dictionary mapping pre-defined seasons to their middle month. This
# dictionary is used during the creation of datetime objects, which don't
Expand Down Expand Up @@ -1459,6 +1460,7 @@ def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame:

if custom_seasons is not None:
df_new = self._map_months_to_custom_seasons(df_new)
df_new = self._shift_spanning_months(df_new)
else:
if dec_mode == "DJF":
df_new = self._shift_decembers(df_new)
Expand Down Expand Up @@ -1501,6 +1503,64 @@ def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame:

return df_new

def _shift_spanning_months(self, df: pd.DataFrame) -> pd.DataFrame:
    """Shift months in a season spanning the previous year to the next year.

    A season spans the previous year if it includes the month of "Jan" and
    "Jan" is not the first month of the season. For example, let's say we
    define the custom season ``["Nov", "Dec", "Jan", "Feb", "Mar"]`` to
    represent the southern hemisphere growing season, "NDJFM".

      - ["Nov", "Dec"] are from the previous year since they are listed
        before "Jan".
      - ["Jan", "Feb", "Mar"] are from the current year.

    Therefore, we need to shift ["Nov", "Dec"] a year forward in order for
    xarray to group seasons correctly. Only the "year" column is modified;
    the "month" column is left untouched. Refer to the examples section
    below for a visual demonstration.

    The custom seasons are read from ``self._season_config["custom_seasons"]``,
    which is iterated with ``.values()`` and is therefore expected to be a
    mapping whose values are lists of 3-letter month abbreviations.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame of xarray datetime components produced using the
        "season" frequency. It must contain "year" and "month" columns.

    Returns
    -------
    pd.DataFrame
        A copy of the DataFrame of xarray datetime components with the
        months spanning the previous year shifted over to the next year.

    Raises
    ------
    KeyError
        If a month in a custom season is not a key of ``MONTH_STR_TO_INT``.

    Examples
    --------
    Before and after shifting months for "NDJFM" seasons:

    >>> # Before shifting months
    >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1),
    >>>  (2001, "NDJFM", 2), (2001, "NDJFM", 3)]

    >>> # After shifting months
    >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1),
    >>>  (2001, "NDJFM", 2), (2001, "NDJFM", 3)]
    """
    df_new = df.copy()
    custom_seasons = self._season_config["custom_seasons"]

    # Month numbers listed before "Jan" in the season that contains "Jan".
    span_months: List[int] = []
    for months in custom_seasons.values():  # type: ignore
        month_nums = [MONTH_STR_TO_INT[month] for month in months]

        if 1 in month_nums:
            span_months = month_nums[: month_nums.index(1)]
            # Stop at the first season that contains "Jan"; later seasons
            # are not inspected.
            break

    if span_months:
        is_span_month = df_new["month"].isin(span_months)
        df_new.loc[is_span_month, "year"] = df_new["year"] + 1

    return df_new

def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame:
"""Shifts Decembers over to the next year for "DJF" seasons in-place.
Expand Down Expand Up @@ -1534,7 +1594,6 @@ def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame:
>>> # "DJF" (shifted Decembers)
>>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12),
>>> (2001, "DJF", 1), (2001, "DJF", 2)]
"""
df_season.loc[df_season["month"] == 12, "year"] = df_season["year"] + 1

Expand Down

0 comments on commit aada458

Please sign in to comment.