From 2ff2ba7fcf7e8cd7f048b56a27fcf653b734fc38 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Thu, 24 Oct 2024 13:31:42 -0700 Subject: [PATCH] Refactor logic for shifting months to use Xarray instead of Pandas - Months are also shifted in the `_preprocess_dataset()` method now. Before months were being shifted twice, once when dropping incomplete seasons or DJF, and a second time when labeling time coordinates. --- xcdat/temporal.py | 299 ++++++++++++++++++++++------------------------ 1 file changed, 145 insertions(+), 154 deletions(-) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 0163dac1..810e1c63 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1101,9 +1101,12 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]] def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """Preprocess the dataset based on averaging settings. - Preprocessing operations include: - - Drop incomplete DJF seasons (leading/trailing) - - Drop leap days + Operations include: + 1. Drop leap days for daily climatologies. + 2. Subset the dataset based on the reference period. + 3. Shift years for custom seasons spanning the calendar year. + 4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons. + 5. Drop incomplete seasons if specified. Parameters ---------- @@ -1114,6 +1117,18 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: ------- xr.Dataset """ + if ( + self._freq == "day" + and self._mode in ["climatology", "departures"] + and self.calendar in ["gregorian", "proleptic_gregorian", "standard"] + ): + ds = self._drop_leap_days(ds) + + if self._mode == "climatology" and self._reference_period is not None: + ds = ds.sel( + {self.dim: slice(self._reference_period[0], self._reference_period[1])} + ) + if ( self._freq == "season" and self._season_config.get("custom_seasons") is not None @@ -1129,12 +1144,17 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: if len(months) != 12: ds = self._subset_coords_for_custom_seasons(ds, months) - if ( - self._freq == "day" - and self._mode in ["climatology", "departures"] - and self.calendar in ["gregorian", "proleptic_gregorian", "standard"] - ): - ds = self._drop_leap_days(ds) + ds = self._shift_custom_season_years(ds) + + if self._freq == "season" and self._season_config.get("dec_mode") == "DJF": + ds = self._shift_djf_decembers(ds) + + # TODO: Deprecate incomplete_djf. + if ( + self._season_config.get("drop_incomplete_djf") is True + and self._season_config.get("drop_incomplete_seasons") is False + ): + ds = self._drop_incomplete_djf(ds) if ( self._freq == "season" @@ -1142,21 +1162,6 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: ): ds = self._drop_incomplete_seasons(ds) - # TODO: Deprecate incomplete_djf. Only run this is drop_incomplete_seasons - # is False and drop_incomplete_djf is True. - if ( - self._freq == "season" - and self._season_config.get("dec_mode") == "DJF" - and self._season_config.get("drop_incomplete_djf") is True - and self._season_config.get("drop_incomplete_seasons") is False - ): - ds = self._drop_incomplete_djf(ds) - - if self._mode == "climatology" and self._reference_period is not None: - ds = ds.sel( - {self.dim: slice(self._reference_period[0], self._reference_period[1])} - ) - return ds def _subset_coords_for_custom_seasons( @@ -1191,6 +1196,119 @@ def _subset_coords_for_custom_seasons( return ds_new + def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset: + """Shifts the year for custom seasons spanning the calendar year. + + A season spans the calendar year if it includes "Jan" and "Jan" is not + the first month. For example, for + ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]``: + - ["Nov", "Dec"] are from the previous year. + - ["Jan", "Feb", "Mar"] are from the current year. + + Therefore, ["Nov", "Dec"] need to be shifted a year forward for correct + grouping. + + Parameters + ---------- + ds : xr.Dataset + The Dataset with time coordinates. + + Returns + ------- + xr.Dataset + The Dataset with shifted time coordinates. + + Examples + -------- + + Before and after shifting months for "NDJFM" seasons: + + >>> # Before shifting months + >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + + >>> # After shifting months + >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + """ + ds_new = ds.copy() + custom_seasons = self._season_config["custom_seasons"] + + span_months: List[int] = [] + + # Identify the months that span across years in custom seasons. + # This is done by checking if "Jan" is not the first month in the + # custom season and getting all months before "Jan". + for months in custom_seasons.values(): # type: ignore + month_nums = [MONTH_STR_TO_INT[month] for month in months] + if 1 in month_nums: + jan_index = month_nums.index(1) + + if jan_index != 0: + span_months.extend(month_nums[:jan_index]) + break + + if span_months: + time_coords = ds_new[self.dim].copy() + idxs = np.where(time_coords.dt.month.isin(span_months))[0] + + if isinstance(time_coords.values[0], cftime.datetime): + for idx in idxs: + time_coords.values[idx] = time_coords.values[idx].replace( + year=time_coords.values[idx].year + 1 + ) + else: + for idx in idxs: + time_coords.values[idx] = pd.Timestamp( + time_coords.values[idx] + ) + pd.DateOffset(years=1) + + ds_new = ds_new.assign_coords({self.dim: time_coords}) + + return ds_new + + def _shift_djf_decembers(self, ds: xr.Dataset) -> xr.Dataset: + """Shifts Decembers to the next year for "DJF" seasons. + + This ensures correct grouping for "DJF" seasons by shifting Decembers + to the next year. Without this, grouping defaults to "JFD", which + is the native Xarray behavior. + + Parameters + ---------- + ds : xr.Dataset + The Dataset with time coordinates. + + Returns + ------- + xr.Dataset + The Dataset with shifted time coordinates. + + Examples + -------- + + Comparison of "JFD" and "DJF" seasons: + + >>> # "JFD" (native xarray behavior) + >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12), + >>> (2001, "DJF", 1), (2001, "DJF", 2)] + + >>> # "DJF" (shifted Decembers) + >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), + >>> (2001, "DJF", 1), (2001, "DJF", 2)] + """ + ds_new = ds.copy() + time_coords = ds_new[self.dim].copy() + dec_indexes = time_coords.dt.month == 12 + + time_coords.values[dec_indexes] = [ + time.replace(year=time.year + 1) for time in time_coords.values[dec_indexes] + ] + + ds_new = ds_new.assign_coords({self.dim: time_coords}) + + return ds_new + def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: """Drops incomplete DJF seasons within a continuous time series. @@ -1612,7 +1730,9 @@ def _get_df_dt_components( elif self._mode == "group_average": df["month"] = time_coords[f"{self.dim}.month"].values - df = self._process_season_df(df) + custom_seasons = self._season_config.get("custom_seasons") + if custom_seasons is not None: + df = self._map_months_to_custom_seasons(df) if drop_obsolete_cols: df = self._drop_obsolete_columns(df) @@ -1620,33 +1740,6 @@ def _get_df_dt_components( return df - def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Processes a DataFrame of datetime components for the season frequency. - - Parameters - ---------- - df : xr.DataArray - A DataFrame of seasonal datetime components. - - Returns - ------- - pd.DataFrame - A DataFrame of seasonal datetime components. - """ - df_new = df.copy() - custom_seasons = self._season_config.get("custom_seasons") - dec_mode = self._season_config.get("dec_mode") - - if custom_seasons is not None: - df_new = self._map_months_to_custom_seasons(df_new) - df_new = self._shift_spanning_months(df_new) - else: - if dec_mode == "DJF": - df_new = self._shift_decembers(df_new) - - return df_new - def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: """Maps the month column in the DataFrame to a custom season. @@ -1681,108 +1774,6 @@ def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: return df_new - def _shift_spanning_months(self, df: pd.DataFrame) -> pd.DataFrame: - """Shifts months in seasons spanning the previous year to the next year. - - A season spans the previous year if it includes the month of "Jan" and - "Jan" is not the first month of the season. For example, let's say we - define ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]`` to - represent the southern hemisphere growing seasons, "NDJFM". - - ["Nov", "Dec"] are from the previous year since they are listed - before "Jan". - - ["Jan", "Feb", "Mar"] are from the current year. - - Therefore, we need to shift ["Nov", "Dec"] a year forward in order for - xarray to group seasons correctly. Refer to the examples section below - for a visual demonstration. - - Parameters - ---------- - df : pd.Dataframe - The DataFrame of xarray datetime components produced using the - "season" frequency". - - Returns - ------- - pd.DataFrame - The DataFrame of xarray dataetime copmonents with months spanning - previous year shifted over to the next year. - - Examples - -------- - - Before and after shifting months for "NDJFM" seasons: - - >>> # Before shifting months - >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1), - >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] - - >>> # After shifting months - >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1), - >>> (2001, "NDJFM", 1), (2001, "NDJFM", 2)] - """ - df_new = df.copy() - custom_seasons = self._season_config["custom_seasons"] - - span_months: List[int] = [] - - # Loop over the custom seasons and get the list of months for the - # current season. Convert those months to their integer representations. - # If 1 ("Jan") is in the list of months and it is NOT the first element, - # then get all elements before it (aka the spanning months). - for months in custom_seasons.values(): # type: ignore - month_nums = [MONTH_STR_TO_INT[month] for month in months] - try: - jan_index = month_nums.index(1) - if jan_index != 0: - span_months = span_months + month_nums[:jan_index] - break - except ValueError: - continue - - if len(span_months) > 0: - df_new.loc[df_new["month"].isin(span_months), "year"] = df_new["year"] + 1 - - return df_new - - def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame: - """Shifts Decembers over to the next year for "DJF" seasons in-place. - - For "DJF" seasons, Decembers must be shifted over to the next year in - order for the xarray groupby operation to correctly label and group the - corresponding time coordinates. If the aren't shifted over, grouping is - incorrectly performed with the native xarray "DJF" season (which is - actually "JFD"). - - Parameters - ---------- - df_season : pd.DataFrame - The DataFrame of xarray datetime components produced using the - "season" frequency. - - Returns - ------- - pd.DataFrame - The DataFrame of xarray datetime components with Decembers shifted - over to the next year. - - Examples - -------- - - Comparison of "JFD" and "DJF" seasons: - - >>> # "JFD" (native xarray behavior) - >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12), - >>> (2001, "DJF", 1), (2001, "DJF", 2)] - - >>> # "DJF" (shifted Decembers) - >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), - >>> (2001, "DJF", 1), (2001, "DJF", 2)] - """ - df_season.loc[df_season["month"] == 12, "year"] = df_season["year"] + 1 - - return df_season - def _map_seasons_to_mid_months(self, df: pd.DataFrame) -> pd.DataFrame: """Maps the season column values to the integer of its middle month.