Skip to content

Commit

Permalink
Add baseline unit tests
Browse files Browse the repository at this point in the history
- Remove logic for requiring all 12 months to be used
  • Loading branch information
tomvothecoder committed Apr 4, 2024
1 parent 472c04b commit 6c98bbd
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 18 deletions.
150 changes: 136 additions & 14 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,71 @@ def test_weighted_custom_seasonal_averages(self):

assert result.identical(expected)

def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons(
self,
):
ds = self.ds.copy()
custom_seasons = [
["Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov"],
["Dec", "Jan", "Feb", "Mar"],
]

result = ds.temporal.group_average(
"ts",
"season",
season_config={
"custom_seasons": custom_seasons,
"drop_incomplete_seasons": True,
},
)
expected = ds.copy()
expected = expected.drop_dims("time")
expected["ts"] = xr.DataArray(
name="ts",
data=np.array([[[1.5]], [[1.0]], [[1.0]], [[2.0]]]),
coords={
"lat": expected.lat,
"lon": expected.lon,
"time": xr.DataArray(
data=np.array(
[
cftime.DatetimeGregorian(2000, 2, 1),
cftime.DatetimeGregorian(2000, 5, 1),
cftime.DatetimeGregorian(2000, 8, 1),
cftime.DatetimeGregorian(2001, 2, 1),
],
),
dims=["time"],
attrs={
"axis": "T",
"long_name": "time",
"standard_name": "time",
"bounds": "time_bnds",
},
),
},
dims=["time", "lat", "lon"],
attrs={
"test_attr": "test",
"operation": "temporal_avg",
"mode": "group_average",
"freq": "season",
"custom_seasons": [
"JanFebMar",
"AprMayJun",
"JulAugSep",
"OctNovDec",
],
"weighted": "True",
},
)

assert result.identical(expected)

@pytest.mark.xfail
def test_weighted_custom_seasonal_averages_with_single_custom_season(self):
assert 0

def test_raises_error_with_incorrect_custom_seasons_argument(self):
# Test raises error with non-3 letter strings
with pytest.raises(ValueError):
Expand All @@ -804,20 +869,6 @@ def test_raises_error_with_incorrect_custom_seasons_argument(self):
season_config={"custom_seasons": custom_seasons},
)

# Test raises error with missing month(s)
with pytest.raises(ValueError):
custom_seasons = [
["Feb", "Mar"],
["Apr", "May", "Jun"],
["Jul", "Aug", "Sep"],
["Oct", "Nov", "Dec"],
]
self.ds.temporal.group_average(
"ts",
"season",
season_config={"custom_seasons": custom_seasons},
)

# Test raises error if duplicate month(s) were found
with pytest.raises(ValueError):
custom_seasons = [
Expand Down Expand Up @@ -1330,6 +1381,77 @@ def test_weighted_custom_seasonal_climatology(self):

assert result.identical(expected)

@pytest.mark.xfail
def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years_and_drop_incomplete_seasons(
self,
):
ds = self.ds.copy()

custom_seasons = [
["Jan", "Feb", "Mar"],
["Apr", "May", "Jun"],
["Jul", "Aug", "Sep"],
["Oct", "Nov", "Dec"],
]
result = ds.temporal.climatology(
"ts",
"season",
season_config={
"custom_seasons": custom_seasons,
"drop_incomplete_seasons": True,
},
)

expected = ds.copy()
expected = expected.drop_dims("time")
expected_time = xr.DataArray(
data=np.array(
[
cftime.DatetimeGregorian(1, 2, 1),
cftime.DatetimeGregorian(1, 5, 1),
cftime.DatetimeGregorian(1, 8, 1),
cftime.DatetimeGregorian(1, 11, 1),
],
),
coords={
"time": np.array(
[
cftime.DatetimeGregorian(1, 2, 1),
cftime.DatetimeGregorian(1, 5, 1),
cftime.DatetimeGregorian(1, 8, 1),
cftime.DatetimeGregorian(1, 11, 1),
],
),
},
attrs={
"axis": "T",
"long_name": "time",
"standard_name": "time",
"bounds": "time_bnds",
},
)

expected["ts"] = xr.DataArray(
name="ts",
data=np.ones((4, 4, 4)),
coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time},
dims=["time", "lat", "lon"],
attrs={
"operation": "temporal_avg",
"mode": "climatology",
"freq": "season",
"weighted": "True",
"custom_seasons": [
"JanFebMar",
"AprMayJun",
"JulAugSep",
"OctNovDec",
],
},
)

assert result.identical(expected)

def test_weighted_monthly_climatology(self):
result = self.ds.temporal.climatology("ts", "month")

Expand Down
4 changes: 0 additions & 4 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,10 +1022,6 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]]
predefined_months = list(MONTH_INT_TO_STR.values())
input_months = list(chain.from_iterable(custom_seasons))

if len(input_months) != len(predefined_months):
raise ValueError(
"Exactly 12 months were not passed in the list of custom seasons."
)
if len(input_months) != len(set(input_months)):
raise ValueError(
"Duplicate month(s) were found in the list of custom seasons."
Expand Down

0 comments on commit 6c98bbd

Please sign in to comment.