Skip to content

Commit 9c4aa48

Browse files
committed
Add skipna arg to all temporal APIs
- Add unit tests
1 parent f74699c commit 9c4aa48

File tree

3 files changed

+220
-12
lines changed

3 files changed

+220
-12
lines changed

tests/test_spatial.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,28 @@ def test_spatial_average_for_lat_region(self):
180180

181181
assert result.identical(expected)
182182

183+
def test_spatial_average_for_lat_region_and_skipna(self):
184+
ds = self.ds.copy(deep=True)
185+
ds.ts[0] = np.nan
186+
187+
# Specifying axis as a str instead of list of str.
188+
result = ds.spatial.average("ts", axis=["Y"], lat_bounds=(-5.0, 5), skipna=True)
189+
190+
expected = self.ds.copy()
191+
expected["ts"] = xr.DataArray(
192+
data=np.array(
193+
[
194+
[np.nan, np.nan, np.nan, np.nan],
195+
[1.0, 1.0, 1.0, 1.0],
196+
[1.0, 1.0, 1.0, 1.0],
197+
]
198+
),
199+
coords={"time": expected.time, "lon": expected.lon},
200+
dims=["time", "lon"],
201+
)
202+
203+
assert result.identical(expected)
204+
183205
def test_spatial_average_for_domain_wrapping_p_meridian_non_cf_conventions(
184206
self,
185207
):

tests/test_temporal.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,57 @@ def test_weighted_annual_averages(self):
520520
assert result.ts.attrs == expected.ts.attrs
521521
assert result.time.attrs == expected.time.attrs
522522

523+
def test_weighted_annual_averages_and_skipna(self):
524+
ds = self.ds.copy(deep=True)
525+
ds.ts[0] = np.nan
526+
527+
result = ds.temporal.group_average("ts", "year", skipna=True)
528+
expected = ds.copy()
529+
expected = expected.drop_dims("time")
530+
expected["ts"] = xr.DataArray(
531+
name="ts",
532+
data=np.array([[[1]], [[2.0]]]),
533+
coords={
534+
"lat": expected.lat,
535+
"lon": expected.lon,
536+
"time": xr.DataArray(
537+
data=np.array(
538+
[
539+
cftime.DatetimeGregorian(2000, 1, 1),
540+
cftime.DatetimeGregorian(2001, 1, 1),
541+
],
542+
),
543+
coords={
544+
"time": np.array(
545+
[
546+
cftime.DatetimeGregorian(2000, 1, 1),
547+
cftime.DatetimeGregorian(2001, 1, 1),
548+
],
549+
)
550+
},
551+
dims=["time"],
552+
attrs={
553+
"axis": "T",
554+
"long_name": "time",
555+
"standard_name": "time",
556+
"bounds": "time_bnds",
557+
},
558+
),
559+
},
560+
dims=["time", "lat", "lon"],
561+
attrs={
562+
"test_attr": "test",
563+
"operation": "temporal_avg",
564+
"mode": "group_average",
565+
"freq": "year",
566+
"weighted": "True",
567+
},
568+
)
569+
570+
xr.testing.assert_allclose(result, expected)
571+
assert result.ts.attrs == expected.ts.attrs
572+
assert result.time.attrs == expected.time.attrs
573+
523574
@requires_dask
524575
def test_weighted_annual_averages_with_chunking(self):
525576
ds = self.ds.copy().chunk({"time": 2})
@@ -1161,6 +1212,68 @@ def test_weighted_seasonal_climatology_with_DJF(self):
11611212

11621213
xr.testing.assert_identical(result, expected)
11631214

1215+
def test_weighted_seasonal_climatology_with_DJF_and_skipna(self):
1216+
ds = self.ds.copy(deep=True)
1217+
1218+
# Replace all MAM values with np.nan.
1219+
djf_months = [3, 4, 5]
1220+
for mon in djf_months:
1221+
ds["ts"] = ds.ts.where(ds.ts.time.dt.month != mon, np.nan)
1222+
1223+
result = ds.temporal.climatology(
1224+
"ts",
1225+
"season",
1226+
season_config={"dec_mode": "DJF", "drop_incomplete_djf": True},
1227+
skipna=True,
1228+
)
1229+
1230+
expected = ds.copy()
1231+
expected = expected.drop_dims("time")
1232+
expected_time = xr.DataArray(
1233+
data=np.array(
1234+
[
1235+
cftime.DatetimeGregorian(1, 1, 1),
1236+
cftime.DatetimeGregorian(1, 4, 1),
1237+
cftime.DatetimeGregorian(1, 7, 1),
1238+
cftime.DatetimeGregorian(1, 10, 1),
1239+
],
1240+
),
1241+
coords={
1242+
"time": np.array(
1243+
[
1244+
cftime.DatetimeGregorian(1, 1, 1),
1245+
cftime.DatetimeGregorian(1, 4, 1),
1246+
cftime.DatetimeGregorian(1, 7, 1),
1247+
cftime.DatetimeGregorian(1, 10, 1),
1248+
],
1249+
),
1250+
},
1251+
attrs={
1252+
"axis": "T",
1253+
"long_name": "time",
1254+
"standard_name": "time",
1255+
"bounds": "time_bnds",
1256+
},
1257+
)
1258+
expected["ts"] = xr.DataArray(
1259+
name="ts",
1260+
data=np.ones((4, 4, 4)),
1261+
coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time},
1262+
dims=["time", "lat", "lon"],
1263+
attrs={
1264+
"operation": "temporal_avg",
1265+
"mode": "climatology",
1266+
"freq": "season",
1267+
"weighted": "True",
1268+
"dec_mode": "DJF",
1269+
"drop_incomplete_djf": "True",
1270+
},
1271+
)
1272+
expected.ts[1] = np.nan
1273+
1274+
# MAM should be np.nan
1275+
assert result.identical(expected)
1276+
11641277
@requires_dask
11651278
def test_chunked_weighted_seasonal_climatology_with_DJF(self):
11661279
ds = self.ds.copy().chunk({"time": 2})
@@ -1947,6 +2060,62 @@ def test_weighted_seasonal_departures_with_DJF(self):
19472060

19482061
xr.testing.assert_identical(result, expected)
19492062

2063+
def test_weighted_seasonal_departures_with_DJF_and_skipna(self):
2064+
ds = self.ds.copy(deep=True)
2065+
2066+
# Replace all MAM values with np.nan.
2067+
djf_months = [3, 4, 5]
2068+
for mon in djf_months:
2069+
ds["ts"] = ds.ts.where(ds.ts.time.dt.month != mon, np.nan)
2070+
2071+
result = ds.temporal.departures(
2072+
"ts",
2073+
"season",
2074+
weighted=True,
2075+
season_config={"dec_mode": "DJF", "drop_incomplete_djf": True},
2076+
skipna=True,
2077+
)
2078+
2079+
expected = ds.copy()
2080+
expected = expected.drop_dims("time")
2081+
expected["ts"] = xr.DataArray(
2082+
name="ts",
2083+
data=np.array([[[np.nan]], [[0.0]], [[0.0]], [[0.0]]]),
2084+
coords={
2085+
"lat": expected.lat,
2086+
"lon": expected.lon,
2087+
"time": xr.DataArray(
2088+
data=np.array(
2089+
[
2090+
cftime.DatetimeGregorian(2000, 4, 1),
2091+
cftime.DatetimeGregorian(2000, 7, 1),
2092+
cftime.DatetimeGregorian(2000, 10, 1),
2093+
cftime.DatetimeGregorian(2001, 1, 1),
2094+
],
2095+
),
2096+
dims=["time"],
2097+
attrs={
2098+
"axis": "T",
2099+
"long_name": "time",
2100+
"standard_name": "time",
2101+
"bounds": "time_bnds",
2102+
},
2103+
),
2104+
},
2105+
dims=["time", "lat", "lon"],
2106+
attrs={
2107+
"test_attr": "test",
2108+
"operation": "temporal_avg",
2109+
"mode": "departures",
2110+
"freq": "season",
2111+
"weighted": "True",
2112+
"dec_mode": "DJF",
2113+
"drop_incomplete_djf": "True",
2114+
},
2115+
)
2116+
2117+
assert result.identical(expected)
2118+
19502119
def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self):
19512120
ds = self.ds.copy()
19522121

xcdat/temporal.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def average(
160160
data_var: str,
161161
weighted: bool = True,
162162
keep_weights: bool = False,
163-
skipna: Union[bool, None] = None,
163+
skipna: bool | None = None,
164164
):
165165
"""
166166
Returns a Dataset with the average of a data variable and the time
@@ -202,7 +202,7 @@ def average(
202202
keep_weights : bool, optional
203203
If calculating averages using weights, keep the weights in the
204204
final dataset output, by default False.
205-
skipna : bool or None, optional
205+
skipna : bool | None, optional
206206
If True, skip missing values (as marked by NaN). By default, only
207207
skips missing values for float dtypes; other dtypes either do not
208208
have a sentinel missing value (int) or ``skipna=True`` has not been
@@ -257,6 +257,7 @@ def group_average(
257257
weighted: bool = True,
258258
keep_weights: bool = False,
259259
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
260+
skipna: bool | None = None,
260261
):
261262
"""Returns a Dataset with average of a data variable by time group.
262263
@@ -335,6 +336,11 @@ def group_average(
335336
>>> ["Jul", "Aug", "Sep"], # "JulAugSep"
336337
>>> ["Oct", "Nov", "Dec"], # "OctNovDec"
337338
>>> ]
339+
skipna : bool | None, optional
340+
If True, skip missing values (as marked by NaN). By default, only
341+
skips missing values for float dtypes; other dtypes either do not
342+
have a sentinel missing value (int) or ``skipna=True`` has not been
343+
implemented (object, datetime64 or timedelta64).
338344
339345
Returns
340346
-------
@@ -413,6 +419,7 @@ def group_average(
413419
weighted=weighted,
414420
keep_weights=keep_weights,
415421
season_config=season_config,
422+
skipna=skipna,
416423
)
417424

418425
def climatology(
@@ -423,6 +430,7 @@ def climatology(
423430
keep_weights: bool = False,
424431
reference_period: Optional[Tuple[str, str]] = None,
425432
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
433+
skipna: bool | None = None,
426434
):
427435
"""Returns a Dataset with the climatology of a data variable.
428436
@@ -510,6 +518,11 @@ def climatology(
510518
>>> ["Jul", "Aug", "Sep"], # "JulAugSep"
511519
>>> ["Oct", "Nov", "Dec"], # "OctNovDec"
512520
>>> ]
521+
skipna : bool | None, optional
522+
If True, skip missing values (as marked by NaN). By default, only
523+
skips missing values for float dtypes; other dtypes either do not
524+
have a sentinel missing value (int) or ``skipna=True`` has not been
525+
implemented (object, datetime64 or timedelta64).
513526
514527
Returns
515528
-------
@@ -593,6 +606,7 @@ def climatology(
593606
keep_weights,
594607
reference_period,
595608
season_config,
609+
skipna,
596610
)
597611

598612
def departures(
@@ -603,6 +617,7 @@ def departures(
603617
keep_weights: bool = False,
604618
reference_period: Optional[Tuple[str, str]] = None,
605619
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
620+
skipna: bool | None = None,
606621
) -> xr.Dataset:
607622
"""
608623
Returns a Dataset with the climatological departures (anomalies) for a
@@ -697,6 +712,11 @@ def departures(
697712
>>> ["Jul", "Aug", "Sep"], # "JulAugSep"
698713
>>> ["Oct", "Nov", "Dec"], # "OctNovDec"
699714
>>> ]
715+
skipna : bool | None, optional
716+
If True, skip missing values (as marked by NaN). By default, only
717+
skips missing values for float dtypes; other dtypes either do not
718+
have a sentinel missing value (int) or ``skipna=True`` has not been
719+
implemented (object, datetime64 or timedelta64).
700720
701721
Returns
702722
-------
@@ -777,11 +797,7 @@ def departures(
777797
inferred_freq = _infer_freq(ds[self.dim])
778798
if inferred_freq != freq:
779799
ds_obs = ds_obs.temporal.group_average(
780-
data_var,
781-
freq,
782-
weighted,
783-
keep_weights,
784-
season_config,
800+
data_var, freq, weighted, keep_weights, season_config, skipna
785801
)
786802

787803
# 4. Calculate the climatology of the data variable.
@@ -794,6 +810,7 @@ def departures(
794810
keep_weights,
795811
reference_period,
796812
season_config,
813+
skipna,
797814
)
798815

799816
# 5. Calculate the departures for the data variable.
@@ -815,7 +832,7 @@ def _averager(
815832
keep_weights: bool = False,
816833
reference_period: Optional[Tuple[str, str]] = None,
817834
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
818-
skipna: Union[bool, None] = None,
835+
skipna: bool | None = None,
819836
) -> xr.Dataset:
820837
"""Averages a data variable based on the averaging mode and frequency."""
821838
ds = self._dataset.copy()
@@ -1141,7 +1158,7 @@ def _drop_leap_days(self, ds: xr.Dataset):
11411158
return ds
11421159

11431160
def _average(
1144-
self, ds: xr.Dataset, data_var: str, skipna: Union[bool, None] = None
1161+
self, ds: xr.Dataset, data_var: str, skipna: bool | None = None
11451162
) -> xr.DataArray:
11461163
"""Averages a data variable with the time dimension removed.
11471164
@@ -1151,7 +1168,7 @@ def _average(
11511168
The dataset.
11521169
data_var : str
11531170
The key of the data variable.
1154-
skipna : bool or None, optional
1171+
skipna : bool | None, optional
11551172
If True, skip missing values (as marked by NaN). By default, only
11561173
skips missing values for float dtypes; other dtypes either do not
11571174
have a sentinel missing value (int) or ``skipna=True`` has not been
@@ -1178,7 +1195,7 @@ def _average(
11781195
return dv
11791196

11801197
def _group_average(
1181-
self, ds: xr.Dataset, data_var: str, skipna: Union[bool, None] = None
1198+
self, ds: xr.Dataset, data_var: str, skipna: bool | None = None
11821199
) -> xr.DataArray:
11831200
"""Averages a data variable by time group.
11841201
@@ -1188,7 +1205,7 @@ def _group_average(
11881205
The dataset.
11891206
data_var : str
11901207
The key of the data variable.
1191-
skipna : bool or None, optional
1208+
skipna : bool | None, optional
11921209
If True, skip missing values (as marked by NaN). By default, only
11931210
skips missing values for float dtypes; other dtypes either do not
11941211
have a sentinel missing value (int) or ``skipna=True`` has not been

0 commit comments

Comments
 (0)