diff --git a/pyproject.toml b/pyproject.toml index 108a79d..455fdf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -309,7 +309,10 @@ ignore = [ "N806", "PTH123", "S310", - "PERF401" # don't force list comprehensions + "PERF401", # don't force list comprehensions + "PERF203", # allow try/except in loop + "E501", # line too long + "W505" # doc line too long ] preview = true select = [ diff --git a/src/xsdba/__init__.py b/src/xsdba/__init__.py index 9554388..20483b4 100644 --- a/src/xsdba/__init__.py +++ b/src/xsdba/__init__.py @@ -20,7 +20,16 @@ from __future__ import annotations -from . import adjustment, base, detrending, processing, testing, units, utils +from . import ( + adjustment, + base, + detrending, + processing, + properties, + testing, + units, + utils, +) # , adjustment # from . import adjustment, base, detrending, measures, processing, properties, utils diff --git a/src/xsdba/base.py b/src/xsdba/base.py index fce999a..293876f 100644 --- a/src/xsdba/base.py +++ b/src/xsdba/base.py @@ -8,7 +8,6 @@ import datetime as pydt import itertools from collections.abc import Sequence -from enum import IntEnum from inspect import _empty, signature from typing import Any, Callable, NewType, TypeVar @@ -19,19 +18,10 @@ import pandas as pd import xarray as xr from boltons.funcutils import wraps -from pint import Quantity from xsdba.options import OPTIONS, SDBA_ENCODE_CF -# XC: -#: Type annotation for strings representing full dates (YYYY-MM-DD), may include time. -DateStr = NewType("DateStr", str) - -#: Type annotation for strings representing dates without a year (MM-DD). 
-DayOfYearStr = NewType("DayOfYearStr", str) - -#: Type annotation for thresholds and other not-exactly-a-variable quantities -Quantified = TypeVar("Quantified", xr.DataArray, str, Quantity) +from .typing import InputKind # ## Base class for the sdba module @@ -116,112 +106,6 @@ def set_dataset(self, ds: xr.Dataset) -> None: self.ds.attrs[self._attribute] = jsonpickle.encode(self) -# XC - - -class InputKind(IntEnum): - """Constants for input parameter kinds. - - For use by external parses to determine what kind of data the indicator expects. - On the creation of an indicator, the appropriate constant is stored in - :py:attr:`xclim.core.indicator.Indicator.parameters`. The integer value is what gets stored in the output - of :py:meth:`xclim.core.indicator.Indicator.json`. - - For developers : for each constant, the docstring specifies the annotation a parameter of an indice function - should use in order to be picked up by the indicator constructor. Notice that we are using the annotation format - as described in `PEP 604 `_, i.e. with '|' indicating a union and without import - objects from `typing`. - """ - - VARIABLE = 0 - """A data variable (DataArray or variable name). - - Annotation : ``xr.DataArray``. - """ - OPTIONAL_VARIABLE = 1 - """An optional data variable (DataArray or variable name). - - Annotation : ``xr.DataArray | None``. The default should be None. - """ - QUANTIFIED = 2 - """A quantity with units, either as a string (scalar), a pint.Quantity (scalar) or a DataArray (with units set). - - Annotation : ``xclim.core.utils.Quantified`` and an entry in the :py:func:`xclim.core.units.declare_units` - decorator. "Quantified" translates to ``str | xr.DataArray | pint.util.Quantity``. - """ - FREQ_STR = 3 - """A string representing an "offset alias", as defined by pandas. - - See the Pandas documentation on :ref:`timeseries.offset_aliases` for a list of valid aliases. - - Annotation : ``str`` + ``freq`` as the parameter name. 
- """ - NUMBER = 4 - """A number. - - Annotation : ``int``, ``float`` and unions thereof, potentially optional. - """ - STRING = 5 - """A simple string. - - Annotation : ``str`` or ``str | None``. In most cases, this kind of parameter makes sense - with choices indicated in the docstring's version of the annotation with curly braces. - See :ref:`notebooks/extendxclim:Defining new indices`. - """ - DAY_OF_YEAR = 6 - """A date, but without a year, in the MM-DD format. - - Annotation : :py:obj:`xclim.core.utils.DayOfYearStr` (may be optional). - """ - DATE = 7 - """A date in the YYYY-MM-DD format, may include a time. - - Annotation : :py:obj:`xclim.core.utils.DateStr` (may be optional). - """ - NUMBER_SEQUENCE = 8 - """A sequence of numbers - - Annotation : ``Sequence[int]``, ``Sequence[float]`` and unions thereof, may include single ``int`` and ``float``, - may be optional. - """ - BOOL = 9 - """A boolean flag. - - Annotation : ``bool``, may be optional. - """ - DICT = 10 - """A dictionary. - - Annotation : ``dict`` or ``dict | None``, may be optional. - """ - KWARGS = 50 - """A mapping from argument name to value. - - Developers : maps the ``**kwargs``. Please use as little as possible. - """ - DATASET = 70 - """An xarray dataset. - - Developers : as indices only accept DataArrays, this should only be added on the indicator's constructor. - """ - OTHER_PARAMETER = 99 - """An object that fits None of the previous kinds. - - Developers : This is the fallback kind, it will raise an error in xclim's unit tests if used. 
- """ - - -# XC -def copy_all_attrs(ds: xr.Dataset | xr.DataArray, ref: xr.Dataset | xr.DataArray): - """Copy all attributes of ds to ref, including attributes of shared coordinates, and variables in the case of Datasets.""" - ds.attrs.update(ref.attrs) - extras = ds.variables if isinstance(ds, xr.Dataset) else ds.coords - others = ref.variables if isinstance(ref, xr.Dataset) else ref.coords - for name, var in extras.items(): - if name in others: - var.attrs.update(ref[name].attrs) - - # XC put here to avoid circular import def uses_dask(*das: xr.DataArray | xr.Dataset) -> bool: r"""Evaluate whether dask is installed and array is loaded as a dask array. @@ -1020,3 +904,65 @@ def _apply_on_group(dsblock, **kwargs): return wrapper return _decorator + + +def infer_kind_from_parameter(param) -> InputKind: + """Return the appropriate InputKind constant from an ``inspect.Parameter`` object. + + Parameters + ---------- + param : Parameter + + Notes + ----- + The correspondence between parameters and kinds is documented in :py:class:`xclim.core.utils.InputKind`. 
+ """ + if param.annotation is not _empty: + annot = set( + param.annotation.replace("xarray.", "").replace("xr.", "").split(" | ") + ) + else: + annot = {"no_annotation"} + + if "DataArray" in annot and "None" not in annot and param.default is not None: + return InputKind.VARIABLE + + annot = annot - {"None"} + + if "DataArray" in annot: + return InputKind.OPTIONAL_VARIABLE + + if param.name == "freq": + return InputKind.FREQ_STR + + if param.kind == param.VAR_KEYWORD: + return InputKind.KWARGS + + if annot == {"Quantified"}: + return InputKind.QUANTIFIED + + if "DayOfYearStr" in annot: + return InputKind.DAY_OF_YEAR + + if annot.issubset({"int", "float"}): + return InputKind.NUMBER + + if annot.issubset({"int", "float", "Sequence[int]", "Sequence[float]"}): + return InputKind.NUMBER_SEQUENCE + + if annot.issuperset({"str"}): + return InputKind.STRING + + if annot == {"DateStr"}: + return InputKind.DATE + + if annot == {"bool"}: + return InputKind.BOOL + + if annot == {"dict"}: + return InputKind.DICT + + if annot == {"Dataset"}: + return InputKind.DATASET + + return InputKind.OTHER_PARAMETER diff --git a/src/xsdba/calendar.py b/src/xsdba/calendar.py new file mode 100644 index 0000000..e32cdcd --- /dev/null +++ b/src/xsdba/calendar.py @@ -0,0 +1,1663 @@ +""" +Calendar Handling Utilities +=========================== + +Helper function to handle dates, times and different calendars with xarray. 
+""" + +from __future__ import annotations + +import datetime as pydt +from collections.abc import Sequence +from typing import Any, TypeVar +from warnings import warn + +import cftime +import numpy as np +import pandas as pd +import xarray as xr +from xarray.coding.cftime_offsets import to_cftime_datetime +from xarray.coding.cftimeindex import CFTimeIndex +from xarray.core import dtypes +from xarray.core.resample import DataArrayResample, DatasetResample + +from .base import uses_dask +from .formatting import update_xsdba_history +from .typing import DayOfYearStr + +__all__ = [ + "DayOfYearStr", + "adjust_doy_calendar", + "build_climatology_bounds", + "climatological_mean_doy", + "common_calendar", + "compare_offsets", + "construct_offset", + "convert_calendar", + "convert_doy", + "date_range", + "date_range_like", + "datetime_to_decimal_year", + "days_in_year", + "days_since_to_doy", + "doy_from_string", + "doy_to_days_since", + "ensure_cftime_array", + "get_calendar", + "interp_calendar", + "is_offset_divisor", + "max_doy", + "parse_offset", + "percentile_doy", + "resample_doy", + "select_time", + "stack_periods", + "time_bnds", + "uniform_calendars", + "unstack_periods", + "within_bnds_doy", +] + +# Maximum day of year in each calendar. +max_doy = { + "standard": 366, + "gregorian": 366, + "proleptic_gregorian": 366, + "julian": 366, + "noleap": 365, + "365_day": 365, + "all_leap": 366, + "366_day": 366, + "360_day": 360, +} + +# Some xclim.core.utils functions made accessible here for backwards compatibility reasons. 
+datetime_classes = cftime._cftime.DATE_TYPES + +# Names of calendars that have the same number of days for all years +uniform_calendars = ("noleap", "all_leap", "365_day", "366_day", "360_day") + + +DataType = TypeVar("DataType", xr.DataArray, xr.Dataset) + + +def _get_usecf_and_warn(calendar: str, xcfunc: str, xrfunc: str): + if calendar == "default": + calendar = "standard" + use_cftime = False + msg = " and use use_cftime=False instead of calendar='default' to get numpy objects." + else: + use_cftime = None + msg = "" + warn( + f"`xclim` function {xcfunc} is deprecated in favour of {xrfunc} and will be removed in v0.51.0. Please adjust your script{msg}.", + FutureWarning, + ) + return calendar, use_cftime + + +def days_in_year(year: int, calendar: str = "proleptic_gregorian") -> int: + """Deprecated : use :py:func:`xarray.coding.calendar_ops._days_in_year` instead. Passing use_cftime=False instead of calendar='default'. + + Return the number of days in the input year according to the input calendar. + """ + calendar, usecf = _get_usecf_and_warn( + calendar, "days_in_year", "xarray.coding.calendar_ops._days_in_year" + ) + return xr.coding.calendar_ops._days_in_year(year, calendar, use_cftime=usecf) + + +def doy_from_string(doy: DayOfYearStr, year: int, calendar: str) -> int: + """Return the day-of-year corresponding to a "MM-DD" string for a given year and calendar.""" + MM, DD = doy.split("-") + return datetime_classes[calendar](year, int(MM), int(DD)).timetuple().tm_yday + + +def date_range(*args, **kwargs) -> pd.DatetimeIndex | CFTimeIndex: + """Deprecated : use :py:func:`xarray.date_range` instead. Passing use_cftime=False instead of calendar='default'. + + Wrap a Pandas date_range object. + + Uses pd.date_range (if calendar == 'default') or xr.cftime_range (otherwise). 
+ """ + calendar, usecf = _get_usecf_and_warn( + kwargs.pop("calendar", "default"), "date_range", "xarray.date_range" + ) + return xr.date_range(*args, calendar=calendar, use_cftime=usecf, **kwargs) + + +def get_calendar(obj: Any, dim: str = "time") -> str: + """Return the calendar of an object. + + Parameters + ---------- + obj : Any + An object defining some date. + If `obj` is an array/dataset with a datetime coordinate, use `dim` to specify its name. + Values must have either a datetime64 dtype or a cftime dtype. + `obj` can also be a python datetime.datetime, a cftime object or a pandas Timestamp + or an iterable of those, in which case the calendar is inferred from the first value. + dim : str + Name of the coordinate to check (if `obj` is a DataArray or Dataset). + + Raises + ------ + ValueError + If no calendar could be inferred. + + Returns + ------- + str + The Climate and Forecasting (CF) calendar name. + Will always return "standard" instead of "gregorian", following CF conventions 1.9. + """ + if isinstance(obj, (xr.DataArray, xr.Dataset)): + return obj[dim].dt.calendar + elif isinstance(obj, xr.CFTimeIndex): + obj = obj.values[0] + else: + obj = np.take(obj, 0) + # Take zeroth element, overcome cases when arrays or lists are passed. + if isinstance(obj, pydt.datetime): # Also covers pandas Timestamp + return "standard" + if isinstance(obj, cftime.datetime): + if obj.calendar == "gregorian": + return "standard" + return obj.calendar + + raise ValueError(f"Calendar could not be inferred from object of type {type(obj)}.") + + +def common_calendar(calendars: Sequence[str], join="outer") -> str: + """Return a calendar common to all calendars from a list. + + Uses the hierarchy: 360_day < noleap < standard < all_leap. + Returns "default" only if all calendars are "default." + + Parameters + ---------- + calendars: Sequence of string + List of calendar names. + join : {'inner', 'outer'} + The criterion for the common calendar. 
+ + - 'outer': the common calendar is the smallest calendar (in number of days by year) that will include all the + dates of the other calendars. + When converting the data to this calendar, no timeseries will lose elements, but some + might be missing (gaps or NaNs in the series). + - 'inner': the common calendar is the smallest calendar of the list. + When converting the data to this calendar, no timeseries will have missing elements (no gaps or NaNs), + but some might be dropped. + + Examples + -------- + >>> common_calendar(["360_day", "noleap", "default"], join="outer") + 'standard' + >>> common_calendar(["360_day", "noleap", "default"], join="inner") + '360_day' + """ + if all(cal == "default" for cal in calendars): + return "default" + + trans = { + "proleptic_gregorian": "standard", + "gregorian": "standard", + "default": "standard", + "366_day": "all_leap", + "365_day": "noleap", + "julian": "standard", + } + ranks = {"360_day": 0, "noleap": 1, "standard": 2, "all_leap": 3} + calendars = sorted([trans.get(cal, cal) for cal in calendars], key=ranks.get) + + if join == "outer": + return calendars[-1] + if join == "inner": + return calendars[0] + raise NotImplementedError(f"Unknown join criterion `{join}`.") + + +def _convert_doy_date(doy: int, year: int, src, tgt): + fracpart = doy - int(doy) + date = src(year, 1, 1) + pydt.timedelta(days=int(doy - 1)) + try: + same_date = tgt(date.year, date.month, date.day) + except ValueError: + return np.nan + else: + if tgt is pydt.datetime: + return float(same_date.timetuple().tm_yday) + fracpart + return float(same_date.dayofyr) + fracpart + + +def convert_doy( + source: xr.DataArray | xr.Dataset, + target_cal: str, + source_cal: str | None = None, + align_on: str = "year", + missing: Any = np.nan, + dim: str = "time", +) -> xr.DataArray: + """Convert the calendar of day of year (doy) data. 
+ + Parameters + ---------- + source : xr.DataArray or xr.Dataset + Day of year data (range [1, 366], max depending on the calendar). + If a Dataset, the function is mapped to each variable with attribute `is_dayofyear == 1`. + target_cal : str + Name of the calendar to convert to. + source_cal : str, optional + Calendar the doys are in. If not given, uses the "calendar" attribute of `source` or, + if absent, the calendar of its `dim` axis. + align_on : {'date', 'year'} + If 'year' (default), the doy is seen as a "percentage" of the year and is simply rescaled onto the new doy range. + This always results in floating point data, changing the decimal part of the value. + If 'date', the doy is seen as a specific date. See notes. This never changes the decimal part of the value. + missing : Any + If `align_on` is "date" and the new doy doesn't exist in the new calendar, this value is used. + dim : str + Name of the temporal dimension. + """ + if isinstance(source, xr.Dataset): + return source.map( + lambda da: ( + da + if da.attrs.get("is_dayofyear") != 1 + else convert_doy( + da, + target_cal, + source_cal=source_cal, + align_on=align_on, + missing=missing, + dim=dim, + ) + ) + ) + + source_cal = source_cal or source.attrs.get("calendar", get_calendar(source[dim])) + is_calyear = xr.infer_freq(source[dim]) in ("YS-JAN", "Y-DEC", "YE-DEC") + + if is_calyear: # Fast path + year_of_the_doy = source[dim].dt.year + else: # Doy might refer to a date from the year after the timestamp. 
+ year_of_the_doy = source[dim].dt.year + 1 * (source < source[dim].dt.dayofyear) + + if align_on == "year": + if source_cal in ["noleap", "all_leap", "360_day"]: + max_doy_src = max_doy[source_cal] + else: + max_doy_src = xr.apply_ufunc( + xr.coding.calendar_ops._days_in_year, + year_of_the_doy, + vectorize=True, + dask="parallelized", + kwargs={"calendar": source_cal}, + ) + if target_cal in ["noleap", "all_leap", "360_day"]: + max_doy_tgt = max_doy[target_cal] + else: + max_doy_tgt = xr.apply_ufunc( + xr.coding.calendar_ops._days_in_year, + year_of_the_doy, + vectorize=True, + dask="parallelized", + kwargs={"calendar": target_cal}, + ) + new_doy = source.copy(data=source * max_doy_tgt / max_doy_src) + elif align_on == "date": + new_doy = xr.apply_ufunc( + _convert_doy_date, + source, + year_of_the_doy, + vectorize=True, + dask="parallelized", + kwargs={ + "src": datetime_classes[source_cal], + "tgt": datetime_classes[target_cal], + }, + ) + else: + raise NotImplementedError('"align_on" must be one of "date" or "year".') + return new_doy.assign_attrs(is_dayofyear=np.int32(1), calendar=target_cal) + + +def convert_calendar( + source: xr.DataArray | xr.Dataset, + target: xr.DataArray | str, + align_on: str | None = None, + missing: Any | None = None, + doy: bool | str = False, + dim: str = "time", +) -> DataType: + """Deprecated : use :py:meth:`xarray.Dataset.convert_calendar` or :py:meth:`xarray.DataArray.convert_calendar` + or :py:func:`xarray.coding.calendar_ops.convert_calendar` instead. Passing use_cftime=False instead of calendar='default'. + + Convert a DataArray/Dataset to another calendar using the specified method. + """ + if isinstance(target, xr.DataArray): + raise NotImplementedError( + "In `xclim` v0.50.0, `convert_calendar` is a direct copy of `xarray.coding.calendar_ops.convert_calendar`. " + "To retrieve the previous behaviour with target as a DataArray, convert the source first then reindex to the target." 
+ ) + if doy is not False: + raise NotImplementedError( + "In `xclim` v0.50.0, `convert_calendar` is a direct copy of `xarray.coding.calendar_ops.convert_calendar`. " + "To retrieve the previous behaviour of doy=True, do convert_doy(obj, target_cal).convert_cal(target_cal)." + ) + target, _usecf = _get_usecf_and_warn( + target, + "convert_calendar", + "xarray.coding.calendar_ops.convert_calendar or obj.convert_calendar", + ) + return xr.coding.calendar_ops.convert_calendar( + source, target, dim=dim, align_on=align_on, missing=missing + ) + + +def interp_calendar( + source: xr.DataArray | xr.Dataset, + target: xr.DataArray, + dim: str = "time", +) -> xr.DataArray | xr.Dataset: + """Deprecated : use :py:func:`xarray.coding.calendar_ops.interp_calendar` instead. + + Interpolates a DataArray/Dataset to another calendar based on decimal year measure. + """ + _, _ = _get_usecf_and_warn( + "standard", "interp_calendar", "xarray.coding.calendar_ops.interp_calendar" + ) + return xr.coding.calendar_ops.interp_calendar(source, target, dim=dim) + + +def ensure_cftime_array(time: Sequence) -> np.ndarray | Sequence[cftime.datetime]: + """Convert an input 1D array to a numpy array of cftime objects. + + Python's datetime are converted to cftime.DatetimeGregorian ("standard" calendar). + + Parameters + ---------- + time : sequence + A 1D array of datetime-like objects. + + Returns + ------- + np.ndarray + + Raises + ------ + ValueError: When unable to cast the input. 
+ """ + if isinstance(time, xr.DataArray): + time = time.indexes["time"] + elif isinstance(time, np.ndarray): + time = pd.DatetimeIndex(time) + if isinstance(time, xr.CFTimeIndex): + return time.values + if isinstance(time[0], cftime.datetime): + return time + if isinstance(time[0], pydt.datetime): + return np.array( + [cftime.DatetimeGregorian(*ele.timetuple()[:6]) for ele in time] + ) + raise ValueError("Unable to cast array to cftime dtype") + + +def datetime_to_decimal_year(times: xr.DataArray, calendar: str = "") -> xr.DataArray: + """Deprecated : use :py:func:`xarray.coding.calendar_ops_datetime_to_decimal_year` instead. + + Convert a datetime xr.DataArray to decimal years according to its calendar or the given one. + """ + _, _ = _get_usecf_and_warn( + "standard", + "datetime_to_decimal_year", + "xarray.coding.calendar_ops._datetime_to_decimal_year", + ) + return xr.coding.calendar_ops._datetime_to_decimal_year( + times, dim="time", calendar=calendar + ) + + +@update_xsdba_history +def percentile_doy( + arr: xr.DataArray, + window: int = 5, + per: float | Sequence[float] = 10.0, + alpha: float = 1.0 / 3.0, + beta: float = 1.0 / 3.0, + copy: bool = True, +) -> xr.DataArray: + """Percentile value for each day of the year. + + Return the climatological percentile over a moving window around each day of the year. Different quantile estimators + can be used by specifying `alpha` and `beta` according to specifications given by :cite:t:`hyndman_sample_1996`. + The default definition corresponds to method 8, which meets multiple desirable statistical properties for sample + quantiles. Note that `numpy.percentile` corresponds to method 7, with alpha and beta set to 1. + + Parameters + ---------- + arr : xr.DataArray + Input data, a daily frequency (or coarser) is required. + window : int + Number of time-steps around each day of the year to include in the calculation. 
+ per : float or sequence of floats + Percentile(s) between [0, 100] + alpha : float + Plotting position parameter. + beta : float + Plotting position parameter. + copy : bool + If True (default) the input array will be deep-copied. It's a necessary step + to keep the data integrity, but it can be costly. + If False, no copy is made of the input array. It will be mutated and rendered + unusable but performances may significantly improve. + Put this flag to False only if you understand the consequences. + + Returns + ------- + xr.DataArray + The percentiles indexed by the day of the year. + For calendars with 366 days, percentiles of doys 1-365 are interpolated to the 1-366 range. + + References + ---------- + :cite:cts:`hyndman_sample_1996` + """ + from .utils import calc_perc # pylint: disable=import-outside-toplevel + + # Ensure arr sampling frequency is daily or coarser + # but cowardly escape the non-inferrable case. + if compare_offsets(xr.infer_freq(arr.time) or "D", "<", "D"): + raise ValueError("input data should have daily or coarser frequency") + + rr = arr.rolling(min_periods=1, center=True, time=window).construct("window") + + crd = xr.Coordinates.from_pandas_multiindex( + pd.MultiIndex.from_arrays( + (rr.time.dt.year.values, rr.time.dt.dayofyear.values), + names=("year", "dayofyear"), + ), + "time", + ) + rr = rr.drop_vars("time").assign_coords(crd) + rrr = rr.unstack("time").stack(stack_dim=("year", "window")) + + if rrr.chunks is not None and len(rrr.chunks[rrr.get_axis_num("stack_dim")]) > 1: + # Preserve chunk size + time_chunks_count = len(arr.chunks[arr.get_axis_num("time")]) + doy_chunk_size = np.ceil(len(rrr.dayofyear) / (window * time_chunks_count)) + rrr = rrr.chunk(dict(stack_dim=-1, dayofyear=doy_chunk_size)) + + if np.isscalar(per): + per = [per] + + p = xr.apply_ufunc( + calc_perc, + rrr, + input_core_dims=[["stack_dim"]], + output_core_dims=[["percentiles"]], + keep_attrs=True, + kwargs=dict(percentiles=per, alpha=alpha, beta=beta, 
copy=copy), + dask="parallelized", + output_dtypes=[rrr.dtype], + dask_gufunc_kwargs=dict(output_sizes={"percentiles": len(per)}), + ) + p = p.assign_coords(percentiles=xr.DataArray(per, dims=("percentiles",))) + + # The percentile for the 366th day has a sample size of 1/4 of the other days. + # To have the same sample size, we interpolate the percentile from 1-365 doy range to 1-366 + if p.dayofyear.max() == 366: + p = adjust_doy_calendar(p.sel(dayofyear=(p.dayofyear < 366)), arr) + + p.attrs.update(arr.attrs.copy()) + + # Saving percentile attributes + p.attrs["climatology_bounds"] = build_climatology_bounds(arr) + p.attrs["window"] = window + p.attrs["alpha"] = alpha + p.attrs["beta"] = beta + return p.rename("per") + + +def build_climatology_bounds(da: xr.DataArray) -> list[str]: + """Build the climatology_bounds property with the start and end dates of input data. + + Parameters + ---------- + da : xr.DataArray + The input data. + Must have a time dimension. + """ + n = len(da.time) + return da.time[0 :: n - 1].dt.strftime("%Y-%m-%d").values.tolist() + + +def compare_offsets(freqA: str, op: str, freqB: str) -> bool: + """Compare offsets string based on their approximate length, according to a given operator. + + Offset are compared based on their length approximated for a period starting + after 1970-01-01 00:00:00. If the offsets are from the same category (same first letter), + only the multiplier prefix is compared (QS-DEC == QS-JAN, MS < 2MS). + "Business" offsets are not implemented. + + Parameters + ---------- + freqA : str + LHS Date offset string ('YS', '1D', 'QS-DEC', ...) + op : {'<', '<=', '==', '>', '>=', '!='} + Operator to use. + freqB : str + RHS Date offset string ('YS', '1D', 'QS-DEC', ...) 
+ + Returns + ------- + bool + freqA op freqB + """ + import operator # pylint: disable=import-outside-toplevel + + # Get multiplier and base frequency + t_a, b_a, _, _ = parse_offset(freqA) + t_b, b_b, _, _ = parse_offset(freqB) + + if b_a != b_b: + # Different base freq, compare length of first period after beginning of time. + t = pd.date_range("1970-01-01T00:00:00.000", periods=2, freq=freqA) + t_a = (t[1] - t[0]).total_seconds() + t = pd.date_range("1970-01-01T00:00:00.000", periods=2, freq=freqB) + t_b = (t[1] - t[0]).total_seconds() + # else Same base freq, compare multiplier only. + + return {"<": operator.lt, "<=": operator.le, "==": operator.eq, ">": operator.gt, ">=": operator.ge, "!=": operator.ne}[op](t_a, t_b) + + +def parse_offset(freq: str) -> tuple[int, str, bool, str | None]: + """Parse an offset string. + + Parse a frequency offset and, if needed, convert to cftime-compatible components. + + Parameters + ---------- + freq : str + Frequency offset. + + Returns + ------- + multiplier : int + Multiplier of the base frequency. "[n]W" is always replaced with "[7n]D", + as xarray doesn't support "W" for cftime indexes. + offset_base : str + Base frequency. + is_start_anchored : bool + Whether coordinates of this frequency should correspond to the beginning of the period (`True`) + or its end (`False`). Can only be False when base is Y, Q or M; in other words, xclim assumes frequencies finer + than monthly are all start-anchored. + anchor : str, optional + Anchor date for bases Y or Q. As xarray doesn't support "W", + neither does xclim (anchor information is lost when given). 
+ + """ + # Useful to raise on invalid freqs, convert Y to A and get default anchor (A, Q) + offset = pd.tseries.frequencies.to_offset(freq) + base, *anchor = offset.name.split("-") + anchor = anchor[0] if len(anchor) > 0 else None + start = ("S" in base) or (base[0] not in "AYQM") + if base.endswith("S") or base.endswith("E"): + base = base[:-1] + mult = offset.n + if base == "W": + mult = 7 * mult + base = "D" + anchor = None + return mult, base, start, anchor + + +def construct_offset(mult: int, base: str, start_anchored: bool, anchor: str | None): + """Reconstruct an offset string from its parts. + + Parameters + ---------- + mult : int + The period multiplier (>= 1). + base : str + The base period string (one char). + start_anchored : bool + If True and base in [Y, Q, M], adds the "S" flag, False add "E". + anchor : str, optional + The month anchor of the offset. Defaults to JAN for bases YS and QS and to DEC for bases YE and QE. + + Returns + ------- + str + An offset string, conformant to pandas-like naming conventions. + + Notes + ----- + This provides the mirror opposite functionality of :py:func:`parse_offset`. + """ + start = ("S" if start_anchored else "E") if base in "YAQM" else "" + if anchor is None and base in "AQY": + anchor = "JAN" if start_anchored else "DEC" + return ( + f"{mult if mult > 1 else ''}{base}{start}{'-' if anchor else ''}{anchor or ''}" + ) + + +def is_offset_divisor(divisor: str, offset: str): + """Check that divisor is a divisor of offset. + + A frequency is a "divisor" of another if a whole number of periods of the + former fit within a single period of the latter. + + Parameters + ---------- + divisor : str + The divisor frequency. + offset: str + The large frequency. 
+ + Returns + ------- + bool + + Examples + -------- + >>> is_offset_divisor("QS-Jan", "YS") + True + >>> is_offset_divisor("QS-DEC", "YS-JUL") + False + >>> is_offset_divisor("D", "M") + True + """ + if compare_offsets(divisor, ">", offset): + return False + # Reconstruct offsets anchored at the start of the period + # to have comparable quantities, also get "offset" objects + mA, bA, _sA, aA = parse_offset(divisor) + offAs = pd.tseries.frequencies.to_offset(construct_offset(mA, bA, True, aA)) + + mB, bB, _sB, aB = parse_offset(offset) + offBs = pd.tseries.frequencies.to_offset(construct_offset(mB, bB, True, aB)) + tB = pd.date_range("1970-01-01T00:00:00", freq=offBs, periods=13) + + if bA in ["W", "D", "h", "min", "s", "ms", "us", "ms"] or bB in [ + "W", + "D", + "h", + "min", + "s", + "ms", + "us", + "ms", + ]: + # Simple length comparison is sufficient for submonthly freqs + # In case one of bA or bB is > W, we test many to be sure. + tA = pd.date_range("1970-01-01T00:00:00", freq=offAs, periods=13) + return np.all( + (np.diff(tB)[:, np.newaxis] / np.diff(tA)[np.newaxis, :]) % 1 == 0 + ) + + # else, we test alignment with some real dates + # If both fall on offAs, then is means divisor is aligned with offset at those dates + # if N=13 is True, then it is always True + # As divisor <= offset, this means divisor is a "divisor" of offset. + return all(offAs.is_on_offset(d) for d in tB) + + +def _interpolate_doy_calendar( + source: xr.DataArray, doy_max: int, doy_min: int = 1 +) -> xr.DataArray: + """Interpolate from one set of dayofyear range to another. + + Interpolate an array defined over a `dayofyear` range (say 1 to 360) to another `dayofyear` range (say 1 + to 365). + + Parameters + ---------- + source : xr.DataArray + Array with `dayofyear` coordinates. + doy_max : int + The largest day of the year allowed by calendar. + doy_min : int + The smallest day of the year in the output. 
def adjust_doy_calendar(
    source: xr.DataArray, target: xr.DataArray | xr.Dataset
) -> xr.DataArray:
    """Interpolate from one set of dayofyear range to another calendar.

    Interpolate an array defined over a `dayofyear` range (say 1 to 360) to another
    `dayofyear` range (say 1 to 365). When `source` already covers the day-of-year
    range of `target`, it is returned untouched.

    Parameters
    ----------
    source : xr.DataArray
        Array with `dayofyear` coordinate.
    target : xr.DataArray or xr.Dataset
        Array with `time` coordinate.

    Returns
    -------
    xr.DataArray
        Interpolated source array over coordinates spanning the target `dayofyear` range.
    """
    max_target_doy = int(target.time.dt.dayofyear.max())
    min_target_doy = int(target.time.dt.dayofyear.min())

    def has_same_calendar():
        # case of full year (doys between 1 and 360|365|366)
        return source.dayofyear.max() == max_doy[get_calendar(target)]

    def has_similar_doys():
        # case of partial year (e.g. JJA, doys between 152|153 and 243|244)
        # `min` and `max` must be *called*: comparing the bound methods
        # themselves to an int is always False, which disabled this shortcut.
        return (
            source.dayofyear.min() == min_target_doy
            and source.dayofyear.max() == max_target_doy
        )

    if has_same_calendar() or has_similar_doys():
        return source
    return _interpolate_doy_calendar(source, max_target_doy, min_target_doy)


def resample_doy(doy: xr.DataArray, arr: xr.DataArray | xr.Dataset) -> xr.DataArray:
    """Create a temporal DataArray where each day takes the value defined by the day-of-year.

    Parameters
    ----------
    doy : xr.DataArray
        Array with `dayofyear` coordinate.
    arr : xr.DataArray or xr.Dataset
        Array with `time` coordinate.

    Returns
    -------
    xr.DataArray
        An array with the same dimensions as `doy`, except for `dayofyear`, which is
        replaced by the `time` dimension of `arr`. Values are filled according to the
        day of year value in `doy`.

    Raises
    ------
    AttributeError
        If `doy` has no `dayofyear` coordinate.
    """
    if "dayofyear" not in doy.coords:
        raise AttributeError("Source should have `dayofyear` coordinates.")

    # Bring `doy` onto the day-of-year range covered by `arr`.
    adoy = adjust_doy_calendar(doy, arr)

    # Look up the value of each timestep's day-of-year, then swap in the real time axis.
    out = adoy.rename(dayofyear="time").reindex(time=arr.time.dt.dayofyear)
    out["time"] = arr.time

    return out
+ precision : str, optional + A timedelta representation that :py:class:`pandas.Timedelta` understands. + The time bounds will be correct up to that precision. If not given, + 1 ms ("1U") is used for CFtime indexes and 1 ns ("1N") for numpy datetime64 indexes. + + Returns + ------- + DataArray + The time bounds: start and end times of the periods inferred from the time index and a frequency. + It has the original time index along it's `time` coordinate and a new `bnds` coordinate. + The dtype and calendar of the array are the same as the index. + + Notes + ----- + xclim assumes that indexes for greater-than-day frequencies are "floored" down to a daily resolution. + For example, the coordinate "2000-01-31 00:00:00" with a "ME" frequency is assumed to mean a period + going from "2000-01-01 00:00:00" to "2000-01-31 23:59:59.999999". + + Similarly, it assumes that daily and finer frequencies yield indexes pointing to the period's start. + So "2000-01-31 00:00:00" with a "3h" frequency, means a period going from "2000-01-31 00:00:00" to + "2000-01-31 02:59:59.999999". + """ + if isinstance(time, (xr.DataArray, xr.Dataset)): + time = time.indexes[time.name] + elif isinstance(time, (DataArrayResample, DatasetResample)): + for grouper in time.groupers: + if "time" in grouper.dims: + datetime = grouper.unique_coord.data + freq = freq or grouper.grouper.freq + if datetime.dtype == "O": + time = xr.CFTimeIndex(datetime) + else: + time = pd.DatetimeIndex(datetime) + break + + else: + raise ValueError( + 'Got object resampled along another dimension than "time".' 
def climatological_mean_doy(
    arr: xr.DataArray, window: int = 5
) -> tuple[xr.DataArray, xr.DataArray]:
    """Calculate the climatological mean and standard deviation for each day of the year.

    Parameters
    ----------
    arr : xarray.DataArray
        Input array.
    window : int
        Window size in days.

    Returns
    -------
    xarray.DataArray, xarray.DataArray
        Mean and standard deviation.
    """
    # Attach to every timestep its centred rolling window as a new "window" dimension.
    windowed = arr.rolling(min_periods=1, center=True, time=window).construct("window")

    # Pool all years and all window members sharing the same day of year.
    by_doy = windowed.groupby("time.dayofyear")
    pooled_dims = ["time", "window"]

    return by_doy.mean(pooled_dims), by_doy.std(pooled_dims)


def within_bnds_doy(
    arr: xr.DataArray, *, low: xr.DataArray, high: xr.DataArray
) -> xr.DataArray:
    """Return whether array values are within bounds for each day of the year.

    Both bounds are exclusive.

    Parameters
    ----------
    arr : xarray.DataArray
        Input array.
    low : xarray.DataArray
        Low bound with dayofyear coordinate.
    high : xarray.DataArray
        High bound with dayofyear coordinate.

    Returns
    -------
    xarray.DataArray
    """
    # Broadcast the day-of-year bounds along the time axis of `arr`.
    lower = resample_doy(low, arr)
    upper = resample_doy(high, arr)

    above_low = lower < arr
    below_high = arr < upper
    return above_low * below_high
+ """ + calendar = get_calendar(base) + + base_doy = base.dt.dayofyear + + doy_max = xr.apply_ufunc( + xr.coding.calendar_ops._days_in_year, + base.dt.year, + vectorize=True, + kwargs={"calendar": calendar}, + ) + + if start is not None: + mm, dd = map(int, start.split("-")) + starts = xr.apply_ufunc( + lambda y: datetime_classes[calendar](y, mm, dd), + base.dt.year, + vectorize=True, + ) + start_doy = starts.dt.dayofyear + start_doy = start_doy.where(start_doy >= base_doy, start_doy + doy_max) + else: + start_doy = base_doy + + return base_doy, start_doy, doy_max + + +def doy_to_days_since( + da: xr.DataArray, + start: DayOfYearStr | None = None, + calendar: str | None = None, +) -> xr.DataArray: + """Convert day-of-year data to days since a given date. + + This is useful for computing meaningful statistics on doy data. + + Parameters + ---------- + da : xr.DataArray + Array of "day-of-year", usually int dtype, must have a `time` dimension. + Sampling frequency should be finer or similar to yearly and coarser than daily. + start : date of year str, optional + A date in "MM-DD" format, the base day of the new array. If None (default), the `time` axis is used. + Passing `start` only makes sense if `da` has a yearly sampling frequency. + calendar : str, optional + The calendar to use when computing the new interval. + If None (default), the calendar attribute of the data or of its `time` axis is used. + All time coordinates of `da` must exist in this calendar. + No check is done to ensure doy values exist in this calendar. + + Returns + ------- + xr.DataArray + Same shape as `da`, int dtype, day-of-year data translated to a number of days since a given date. + If start is not None, there might be negative values. + + Notes + ----- + The time coordinates of `da` are considered as the START of the period. For example, a doy value of + 350 with a timestamp of '2020-12-31' is understood as '2021-12-16' (the 350th day of 2021). 
def days_since_to_doy(
    da: xr.DataArray,
    start: DayOfYearStr | None = None,
    calendar: str | None = None,
) -> xr.DataArray:
    """Reverse the conversion made by :py:func:`doy_to_days_since`.

    Converts data given in days since a specific date to day-of-year.

    Parameters
    ----------
    da : xr.DataArray
        The result of :py:func:`doy_to_days_since`.
    start : DayOfYearStr, optional
        `da` is considered as days since that start date (in the year of the time index).
        If None (default), it is read from the third word of the `units` attribute
        (e.g. "days after 10-02"). When that attribute is missing or reads
        "days after time coordinate", the time coordinate itself is the reference.
    calendar : str, optional
        Calendar the "days since" were computed in.
        If None (default), it is read from the attributes.

    Returns
    -------
    xr.DataArray
        Same shape as `da`, values as `day of year`.

    Examples
    --------
    >>> from xarray import DataArray
    >>> time = date_range("2020-07-01", "2021-07-01", freq="AS-JUL")
    >>> da = DataArray(
    ...     [-86, 92],
    ...     dims=("time",),
    ...     coords={"time": time},
    ...     attrs={"units": "days after 10-02"},
    ... )
    >>> days_since_to_doy(da).values
    array([190, 2])
    """
    if start is None:
        # The fallback must have the same three-word shape as real units strings,
        # otherwise the split below yields "coordinate" and a missing `units`
        # attribute would be mistaken for a start date.
        units = da.attrs.get("units", "days after time coordinate")
        reference = units.split(" ", maxsplit=2)[-1]
        if reference != "time coordinate":
            start = reference

    base_calendar = get_calendar(da)
    calendar = calendar or da.attrs.get("calendar", base_calendar)

    dac = da.convert_calendar(calendar)

    _, start_doy, doy_max = _doy_days_since_doys(dac.time, start)

    # 2 cases:
    # val is a day in the same year as its index : da + offset
    # val is a day in the next year : da + offset - doy_max
    out = dac + start_doy
    out = xr.where(out > doy_max, out - doy_max, out)

    # Carry over the attributes, except those rewritten just below.
    out.attrs.update(
        {k: v for k, v in da.attrs.items() if k not in ["units", "calendar"]}
    )
    out.attrs.update(calendar=calendar, is_dayofyear=1)
    return out.convert_calendar(base_calendar).rename(da.name)
def select_time(
    da: xr.DataArray | xr.Dataset,
    drop: bool = False,
    season: str | Sequence[str] | None = None,
    month: int | Sequence[int] | None = None,
    doy_bounds: tuple[int, int] | None = None,
    date_bounds: tuple[str, str] | None = None,
    include_bounds: bool | tuple[bool, bool] = True,
) -> DataType:
    """Select entries according to a time period.

    This conveniently improves xarray's :py:meth:`xarray.DataArray.where` and
    :py:meth:`xarray.DataArray.sel` with fancier ways of indexing over time elements.
    In addition to the data `da` and argument `drop`, only one of `season`, `month`,
    `doy_bounds` or `date_bounds` may be passed.

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset
        Input data.
    drop : bool
        Whether to drop elements outside the period of interest or to simply mask them (default).
    season : string or sequence of strings, optional
        One or more of 'DJF', 'MAM', 'JJA' and 'SON'.
    month : integer or sequence of integers, optional
        Sequence of month numbers (January = 1 ... December = 12)
    doy_bounds : 2-tuple of integers, optional
        The bounds as (start, end) of the period of interest expressed in day-of-year, integers going from
        1 (January 1st) to 365 or 366 (December 31st).
        If calendar awareness is needed, consider using ``date_bounds`` instead.
    date_bounds : 2-tuple of strings, optional
        The bounds as (start, end) of the period of interest expressed as dates in the month-day (%m-%d) format.
    include_bounds : bool or 2-tuple of booleans
        Whether the bounds of `doy_bounds` or `date_bounds` should be inclusive or not.
        Either one value for both or a tuple. Default is True, meaning bounds are inclusive.

    Returns
    -------
    xr.DataArray or xr.Dataset
        Selected input values. If ``drop=False``, this has the same length as ``da`` (along dimension 'time'),
        but with masked (NaN) values outside the period of interest.
        If no selection argument is given, ``da`` itself is returned.

    Raises
    ------
    ValueError
        If more than one of `season`, `month`, `doy_bounds` and `date_bounds` is given.

    Examples
    --------
    Keep only the values of fall and spring.

    >>> ds = open_dataset("ERA5/daily_surface_cancities_1990-1993.nc")
    >>> ds.time.size
    1461
    >>> out = select_time(ds, drop=True, season=["MAM", "SON"])
    >>> out.time.size
    732

    Or all values between two dates (included).

    >>> out = select_time(ds, drop=True, date_bounds=("02-29", "03-02"))
    >>> out.time.values
    array(['1990-03-01T00:00:00.000000000', '1990-03-02T00:00:00.000000000',
           '1991-03-01T00:00:00.000000000', '1991-03-02T00:00:00.000000000',
           '1992-02-29T00:00:00.000000000', '1992-03-01T00:00:00.000000000',
           '1992-03-02T00:00:00.000000000', '1993-03-01T00:00:00.000000000',
           '1993-03-02T00:00:00.000000000'], dtype='datetime64[ns]')
    """
    N = sum(arg is not None for arg in [season, month, doy_bounds, date_bounds])
    if N > 1:
        raise ValueError(f"Only one method of indexing may be given, got {N}.")

    if N == 0:
        return da

    def _get_doys(_start, _end, _inclusive):
        if _start <= _end:
            _doys = np.arange(_start, _end + 1)
        else:
            # Period wraps over the new year: run to the last possible doy (366),
            # then restart at doy 1 (there is no doy 0).
            _doys = np.concatenate((np.arange(_start, 367), np.arange(1, _end + 1)))
        if not _inclusive[0]:
            _doys = _doys[1:]
        if not _inclusive[1]:
            _doys = _doys[:-1]
        return _doys

    if isinstance(include_bounds, bool):
        include_bounds = (include_bounds, include_bounds)

    if season is not None:
        if isinstance(season, str):
            season = [season]
        mask = da.time.dt.season.isin(season)

    elif month is not None:
        if isinstance(month, int):
            month = [month]
        mask = da.time.dt.month.isin(month)

    elif doy_bounds is not None:
        mask = da.time.dt.dayofyear.isin(_get_doys(*doy_bounds, include_bounds))

    else:
        # N == 1 at this point, so `date_bounds` is the argument that was given.
        # This one is a bit trickier.
        start, end = date_bounds
        time = da.time
        calendar = get_calendar(time)
        if calendar not in uniform_calendars:
            # For non-uniform calendars, we can't simply convert dates to doys
            # conversion to all_leap is safe for all non-uniform calendar as it doesn't remove any date.
            time = time.convert_calendar("all_leap")
            # values of time are the _old_ calendar
            # and the new calendar is in the coordinate
            calendar = "all_leap"

        # Get doy of date, this is now safe because the calendar is uniform.
        doys = _get_doys(
            to_cftime_datetime(f"2000-{start}", calendar).dayofyr,
            to_cftime_datetime(f"2000-{end}", calendar).dayofyr,
            include_bounds,
        )
        mask = time.time.dt.dayofyear.isin(doys)
        # Needed if we converted calendar, this puts back the correct coord
        mask["time"] = da.time

    return da.where(mask, drop=drop)


def _month_is_first_period_month(time, freq):
    """Return True if the given time is from the first month of freq.

    Parameters
    ----------
    time : cftime.datetime or anything accepted by pandas.Timestamp
        The date to test.
    freq : str
        Frequency string defining the periods.

    Returns
    -------
    bool
        Whether the start of the month containing `time` falls on a `freq` offset.
    """
    if isinstance(time, cftime.datetime):
        frq_monthly = xr.coding.cftime_offsets.to_offset("MS")
        frq = xr.coding.cftime_offsets.to_offset(freq)
        if frq_monthly.onOffset(time):
            return frq.onOffset(time)
        # Not on a month start: test the start of the month `time` belongs to.
        return frq.onOffset(frq_monthly.rollback(time))
    # Pandas
    time = pd.Timestamp(time)
    frq_monthly = pd.tseries.frequencies.to_offset("MS")
    frq = pd.tseries.frequencies.to_offset(freq)
    if frq_monthly.is_on_offset(time):
        return frq.is_on_offset(time)
    return frq.is_on_offset(frq_monthly.rollback(time))
+ + This is similar to ``da.rolling(time=window).construct(dim, stride=stride)``, but adapted for arguments + in terms of a base temporal frequency that might be non-uniform (years, months, etc.). + It is reversible for some cases (see `stride`). + A rolling-construct method will be much more performant for uniform periods (days, weeks). + + Parameters + ---------- + da : xr.Dataset or xr.DataArray + An xarray object with a `time` dimension. + Must have a uniform timestep length. + Output might be strange if this does not use a uniform calendar (noleap, 360_day, all_leap). + window : int + The length of the moving window as a multiple of ``freq``. + stride : int, optional + At which interval to take the windows, as a multiple of ``freq``. + For the operation to be reversible with :py:func:`unstack_periods`, it must divide `window` into an odd number of parts. + Default is `window` (no overlap between periods). + min_length : int, optional + Windows shorter than this are not included in the output. + Given as a multiple of ``freq``. Default is ``window`` (every window must be complete). + Similar to the ``min_periods`` argument of ``da.rolling``. + If ``freq`` is annual or quarterly and ``min_length == ``window``, the first period is considered complete + if the first timestep is in the first month of the period. + freq : str + Units of ``window``, ``stride`` and ``min_length``, as a frequency string. + Must be larger or equal to the data's sampling frequency. + Note that this function offers an easier interface for non-uniform period (like years or months) + but is much slower than a rolling-construct method. + dim : str + The new dimension name. + start : str + The `start` argument passed to :py:func:`xarray.date_range` to generate the new placeholder + time coordinate. + align_days : bool + When True (default), an error is raised if the output would have unaligned days across periods. 
+ If `freq = 'YS'`, day-of-year alignment is checked and if `freq` is "MS" or "QS", we check day-in-month. + Only uniform-calendar will pass the test for `freq='YS'`. + For other frequencies, only the `360_day` calendar will work. + This check is ignored if the sampling rate of the data is coarser than "D". + pad_value : Any + When some periods are shorter than others, this value is used to pad them at the end. + Passed directly as argument ``fill_value`` to :py:func:`xarray.concat`, + the default is the same as on that function. + + Return + ------ + xr.DataArray + A DataArray with a new `period` dimension and a `time` dimension with the length of the longest window. + The new time coordinate has the same frequency as the input data but is generated using + :py:func:`xarray.date_range` with the given `start` value. + That coordinate is the same for all periods, depending on the choice of ``window`` and ``freq``, it might make sense. + But for unequal periods or non-uniform calendars, it will certainly not. + If ``stride`` is a divisor of ``window``, the correct timeseries can be reconstructed with :py:func:`unstack_periods`. + The coordinate of `period` is the first timestep of each window. + """ + from xsdba.units import ( # Import in function to avoid cyclical imports; ensure_cf_units, + infer_sampling_units, + ) + + stride = stride or window + min_length = min_length or window + if stride > window: + raise ValueError( + f"Stride must be less than or equal to window. Got {stride} > {window}." 
+ ) + + srcfreq = xr.infer_freq(da.time) + cal = da.time.dt.calendar + use_cftime = da.time.dtype == "O" + + if ( + compare_offsets(srcfreq, "<=", "D") + and align_days + and ( + (freq.startswith(("Y", "A")) and cal not in uniform_calendars) + or (freq.startswith(("Q", "M")) and window > 1 and cal != "360_day") + ) + ): + if freq.startswith(("Y", "A")): + u = "year" + else: + u = "month" + raise ValueError( + f"Stacking {window}{freq} periods will result in unaligned day-of-{u}. " + f"Consider converting the calendar of your data to one with uniform {u} lengths, " + "or pass `align_days=False` to disable this check." + ) + + # Convert integer inputs to freq strings + mult, *args = parse_offset(freq) + win_frq = construct_offset(mult * window, *args) + strd_frq = construct_offset(mult * stride, *args) + minl_frq = construct_offset(mult * min_length, *args) + + # The same time coord as da, but with one extra element. + # This way, the last window's last index is not returned as None by xarray's grouper. + time2 = xr.DataArray( + xr.date_range( + da.time[0].item(), + freq=srcfreq, + calendar=cal, + periods=da.time.size + 1, + use_cftime=use_cftime, + ), + dims=("time",), + name="time", + ) + + periods = [] + # longest = 0 + # Iterate over strides, but recompute the full window for each stride start + for strd_slc in da.resample(time=strd_frq).groups.values(): + win_resamp = time2.isel(time=slice(strd_slc.start, None)).resample(time=win_frq) + # Get slice for first group + win_slc = win_resamp._group_indices[0] + if min_length < window: + # If we ask for a min_length period instead is it complete ? 
+ min_resamp = time2.isel(time=slice(strd_slc.start, None)).resample( + time=minl_frq + ) + min_slc = min_resamp._group_indices[0] + open_ended = min_slc.stop is None + else: + # The end of the group slice is None if no outside-group value was found after the last element + # As we added an extra step to time2, we avoid the case where a group ends exactly on the last element of ds + open_ended = win_slc.stop is None + if open_ended: + # Too short, we got to the end + break + if ( + strd_slc.start == 0 + and parse_offset(freq)[1] in "YAQ" + and min_length == window + and not _month_is_first_period_month(da.time[0].item(), freq) + ): + # For annual or quarterly frequencies (which can be anchor-based), + # if the first time is not in the first month of the first period, + # then the first period is incomplete but by a fractional amount. + continue + periods.append( + slice( + strd_slc.start + win_slc.start, + ( + (strd_slc.start + win_slc.stop) + if win_slc.stop is not None + else da.time.size + ), + ) + ) + + # Make coordinates + lengths = xr.DataArray( + [slc.stop - slc.start for slc in periods], + dims=(dim,), + attrs={"long_name": "Length of each period"}, + ) + longest = lengths.max().item() + # Length as a pint-ready array : with proper units, but values are not usable as indexes anymore + m, u = infer_sampling_units(da) + lengths = lengths * m + # ADAPT: cf-agnostic + # lengths.attrs["units"] = ensure_cf_units(u) + + # Start points for each period and remember parameters for unstacking + starts = xr.DataArray( + [da.time[slc.start].item() for slc in periods], + dims=(dim,), + attrs={ + "long_name": "Start of the period", + # Save parameters so that we can unstack. 
def unstack_periods(da: xr.DataArray | xr.Dataset, dim: str = "period"):
    """Unstack an array constructed with :py:func:`stack_periods`.

    Can only work with periods stacked with a ``stride`` that divides ``window`` in an odd number of sections.
    When ``stride`` is smaller than ``window``, only the center-most stride of each window is kept,
    except for the beginning and end which are taken from the first and last windows.

    Parameters
    ----------
    da : xr.DataArray
        As constructed by :py:func:`stack_periods`, attributes of the period coordinates must have been preserved.
    dim : str
        The period dimension name.

    Returns
    -------
    xr.DataArray or xr.Dataset
        The periods concatenated back along a single reconstructed `time` axis.

    Raises
    ------
    ValueError
        If the `window`, `stride`, `freq` or `unequal_lengths` attributes are missing from the
        `dim` coordinate, or if periods have unequal lengths and the `{dim}_length` coordinate is missing.
    NotImplementedError
        If ``window / stride`` is not an odd integer.

    Notes
    -----
    The following table shows which strides are included (``o``) in the unstacked output.

    In this example, ``stride`` was a fifth of ``window`` and ``min_length`` was four (4) times ``stride``.
    The row index ``i`` the period index in the stacked dataset,
    columns are the stride-long section of the original timeseries.

    .. table:: Unstacking example with ``stride < window``.

        === === === === === === === ===
         i   0   1   2   3   4   5   6
        === === === === === === === ===
         3               x   x   o   o
         2           x   x   o   x   x
         1       x   x   o   x   x
         0   o   o   o   x   x
        === === === === === === === ===
    """
    try:
        starts = da[dim]
        window = starts.attrs["window"]
        stride = starts.attrs["stride"]
        freq = starts.attrs["freq"]
        unequal_lengths = bool(starts.attrs["unequal_lengths"])
    except (AttributeError, KeyError) as err:
        raise ValueError(
            f"`unstack_periods` can't find the window, stride and freq attributes on the {dim} coordinates."
        ) from err

    if unequal_lengths:
        try:
            lengths = da[f"{dim}_length"]
        except KeyError as err:
            raise ValueError(
                f"`unstack_periods` can't find the `{dim}_length` coordinate."
            ) from err
        # Import in function to avoid cyclical imports. Same module as in
        # `stack_periods`, so both directions use the same unit inference.
        from xsdba.units import infer_sampling_units

        # Get length as number of points
        m, _ = infer_sampling_units(da.time)
        lengths = lengths // m
    else:
        # It is acceptable to lose "{dim}_length" if they were all equal
        lengths = xr.DataArray([da.time.size] * da[dim].size, dims=(dim,))

    # Convert from the fake axis to the real one
    time_as_delta = da.time - da.time[0]
    if da.time.dtype == "O":
        # cftime can't add with np.timedelta64 (restriction comes from numpy which refuses to add O with m8)
        time_as_delta = pd.TimedeltaIndex(
            time_as_delta
        ).to_pytimedelta()  # this array is O, numpy complies
    else:
        # Xarray will return int when iterating over datetime values, this returns timestamps
        starts = pd.DatetimeIndex(starts)

    def _reconstruct_time(_time_as_delta, _start):
        times = _time_as_delta + _start
        return xr.DataArray(times, dims=("time",), coords={"time": times}, name="time")

    # Easy case:
    if window == stride:
        # just concat them all
        periods = []
        for i, (start, length) in enumerate(zip(starts.values, lengths.values)):
            real_time = _reconstruct_time(time_as_delta, start)
            periods.append(
                da.isel(**{dim: i}, drop=True)
                .isel(time=slice(0, length))
                .assign_coords(time=real_time.isel(time=slice(0, length)))
            )
        return xr.concat(periods, "time")

    # Difficult and ambiguous case
    if (window / stride) % 2 != 1:
        raise NotImplementedError(
            "`unstack_periods` can't work with strides that do not divide the window into an odd number of parts. "
            f"Got {window} / {stride} which is not an odd integer."
        )

    # Non-ambiguous overlapping case
    Nwin = window // stride
    mid = (Nwin - 1) // 2  # index of the center window
    last = da[dim].size - 1  # index of the last period, along the user-given `dim`

    mult, *args = parse_offset(freq)
    strd_frq = construct_offset(mult * stride, *args)

    periods = []
    for i, (start, length) in enumerate(zip(starts.values, lengths.values)):
        real_time = _reconstruct_time(time_as_delta, start)
        slices = real_time.resample(time=strd_frq)._group_indices
        if i == 0:
            # First period: keep everything from its start up to the center stride.
            slc = slice(slices[0].start, min(slices[mid].stop, length))
        elif i == last:
            # Last period: keep from the center stride to the end of the window.
            slc = slice(slices[mid].start, min(slices[Nwin - 1].stop or length, length))
        else:
            # Inner periods: keep only the center stride.
            slc = slice(slices[mid].start, min(slices[mid].stop, length))
        periods.append(
            da.isel(**{dim: i}, drop=True)
            .isel(time=slc)
            .assign_coords(time=real_time.isel(time=slc))
        )

    return xr.concat(periods, "time")
def check_daily(var: xr.DataArray):
    """Raise an error if the series has a frequency other than daily, or is not monotonically increasing.

    Parameters
    ----------
    var : xr.DataArray
        Input array.

    Notes
    -----
    This does not check for gaps in series.
    """
    return check_freq(var, "D")


@datacheck
def check_common_time(inputs: Sequence[xr.DataArray]):
    """Raise an error if the list of inputs doesn't have a single common frequency.

    Parameters
    ----------
    inputs : Sequence of xr.DataArray
        Input arrays.

    Raises
    ------
    ValidationError
        - if the frequency of any input can't be inferred
        - if inputs have different frequencies
        - if inputs have a daily or hourly frequency, but they are not given at the same time of day.
    """
    # Check all have the same freq
    freqs = [xr.infer_freq(da.time) for da in inputs]
    if None in freqs:
        raise ValidationError(
            "Unable to infer the frequency of the time series. "
            "To mute this, set xclim's option data_validation='log'."
        )
    if len(set(freqs)) != 1:
        raise ValidationError(
            f"Inputs have different frequencies. Got : {freqs}. "
            "To mute this, set xclim's option data_validation='log'."
        )

    # Check if anchor is the same
    freq = freqs[0]
    base = parse_offset(freq)[1]
    # Sub-period part of the first timestamp that must match across inputs:
    # minutes for hourly data, hour and minutes for daily data.
    fmt = {"h": ":%M", "D": "%H:%M"}
    if base in fmt:
        outs = {da.indexes["time"][0].strftime(fmt[base]) for da in inputs}
        if len(outs) > 1:
            raise ValidationError(
                f"All inputs have the same frequency ({freq}), but they are not anchored on the same minutes (got {outs}). "
                f"xarray's alignment would silently fail. You can try to fix this with `da.resample('{freq}').mean()`. "
                "To mute this, set xclim's option data_validation='log'."
            )
class AttrFormatter(string.Formatter):
    """A formatter for frequently used attribute values.

    See the doc of format_field() for more details.
    """

    def __init__(
        self,
        mapping: dict[str, Sequence[str]],
        modifiers: Sequence[str],
    ) -> None:
        """Initialize the formatter.

        Parameters
        ----------
        mapping : dict[str, Sequence[str]]
            A mapping from values to their possible variations.
        modifiers : Sequence[str]
            The list of modifiers, must be as long as the longest value of `mapping`.
            Cannot include reserved modifier 'r'. Any sequence (list, tuple, ...) is accepted.

        Raises
        ------
        ValueError
            If the reserved modifier 'r' is part of `modifiers`.
        """
        super().__init__()
        if "r" in modifiers:
            raise ValueError("Modifier 'r' is reserved for default raw formatting.")
        self.modifiers = modifiers
        self.mapping = mapping

    def format(self, format_string: str, /, *args: Any, **kwargs: Any) -> str:
        r"""Format a string.

        Parameters
        ----------
        format_string: str
        \*args: Any
        \*\*kwargs: Any

        Returns
        -------
        str
        """
        # ADAPT: THIS IS VERY CLIMATE, WILL BE REMOVED
        # for k, v in DEFAULT_FORMAT_PARAMS.items():
        #     if k not in kwargs:
        #         kwargs.update({k: v})
        return super().format(format_string, *args, **kwargs)

    def format_field(self, value, format_spec):
        """Format a value given a formatting spec.

        If `format_spec` is in this Formatter's modifiers, the corresponding variation
        of value is given. If `format_spec` is 'r' (raw), the value is returned unmodified.
        If `format_spec` is not specified but `value` is in the mapping, the first variation is returned.

        A warning is emitted when a known modifier (or 'r') is requested for a value that
        matches no key of the mapping; the value is then formatted with the default spec.

        Examples
        --------
        Let's say the string "The dog is {adj1}, the goose is {adj2}" is to be translated
        to French and that we know that possible values of `adj` are `nice` and `evil`.
        In French, the genre of the noun changes the adjective (cat = chat is masculine,
        and goose = oie is feminine) so we initialize the formatter as:

        >>> fmt = AttrFormatter(
        ...     {
        ...         "nice": ["beau", "belle"],
        ...         "evil": ["méchant", "méchante"],
        ...         "smart": ["intelligent", "intelligente"],
        ...     },
        ...     ["m", "f"],
        ... )
        >>> fmt.format(
        ...     "Le chien est {adj1:m}, l'oie est {adj2:f}, le gecko est {adj3:r}",
        ...     adj1="nice",
        ...     adj2="evil",
        ...     adj3="smart",
        ... )
        "Le chien est beau, l'oie est méchante, le gecko est smart"

        The base values may be given using unix shell-like patterns:

        >>> fmt = AttrFormatter(
        ...     {"YS-*": ["annuel", "annuelle"], "MS": ["mensuel", "mensuelle"]},
        ...     ["m", "f"],
        ... )
        >>> fmt.format(
        ...     "La moyenne {freq:f} est faite sur un échantillon {src_timestep:m}",
        ...     freq="YS-JUL",
        ...     src_timestep="MS",
        ... )
        'La moyenne annuelle est faite sur un échantillon mensuel'
        """
        baseval = self._match_value(value)
        if baseval is None:  # Not something we know how to translate
            # Test the two cases separately instead of concatenating with a list:
            # `modifiers` may be a tuple (or any Sequence), for which `+ ["r"]` raises TypeError.
            if format_spec == "r" or format_spec in self.modifiers:
                # Woops, however a known format spec was asked
                warnings.warn(
                    f"Requested formatting `{format_spec}` for unknown string `{value}`."
                )
                format_spec = ""
            return super().format_field(value, format_spec)
        # Thus, known value

        if not format_spec:  # (None or '') No modifiers, return first
            return self.mapping[baseval][0]

        if format_spec == "r":  # Raw modifier
            return super().format_field(value, "")

        if format_spec in self.modifiers:  # Known modifier
            if len(self.mapping[baseval]) == 1:  # But unmodifiable entry
                return self.mapping[baseval][0]
            # Known modifier, modifiable entry
            return self.mapping[baseval][self.modifiers.index(format_spec)]
        # Known value but unknown modifier, must be a built-in one, only works for the default val...
        return super().format_field(self.mapping[baseval][0], format_spec)

    def _match_value(self, value):
        """Return the first mapping key whose shell-like pattern matches `value`, or None.

        Only strings are matched; any other type yields None.
        """
        if isinstance(value, str):
            for mapval in self.mapping:
                if fnmatch(value, mapval):
                    return mapval
        return None
classes with empty options + "": [""], + }, + ["adj", "noun"], +) + + +def parse_doc(doc: str) -> dict[str, str]: + """Crude regex parsing reading an indice docstring and extracting information needed in indicator construction. + + The appropriate docstring syntax is detailed in :ref:`notebooks/extendxclim:Defining new indices`. + + Parameters + ---------- + doc : str + The docstring of an indice function. + + Returns + ------- + dict + A dictionary with all parsed sections. + """ + if doc is None: + return {} + + out = {} + + sections = re.split(r"(\w+\s?\w+)\n\s+-{3,50}", doc) # obj.__doc__.split('\n\n') + intro = sections.pop(0) + if intro: + intro_content = list(map(str.strip, intro.strip().split("\n\n"))) + if len(intro_content) == 1: + out["title"] = intro_content[0] + elif len(intro_content) >= 2: + out["title"], abstract = intro_content[:2] + out["abstract"] = " ".join(map(str.strip, abstract.splitlines())) + + for i in range(0, len(sections), 2): + header, content = sections[i : i + 2] + + if header in ["Notes", "References"]: + out[header.lower()] = content.replace("\n ", "\n").strip() + elif header == "Parameters": + out["parameters"] = _parse_parameters(content) + elif header == "Returns": + rets = _parse_returns(content) + if rets: + meta = list(rets.values())[0] + if "long_name" in meta: + out["long_name"] = meta["long_name"] + return out + + +def _parse_parameters(section): + """Parse the 'parameters' section of a docstring into a dictionary. + + Works by mapping the parameter name to its description and, potentially, to its set of choices. + The type annotation are not parsed, except for fixed sets of values (listed as "{'a', 'b', 'c'}"). + The annotation parsing only accepts strings, numbers, `None` and `nan` (to represent `numpy.nan`). 
+ """ + curr_key = None + params = {} + for line in section.split("\n"): + if line.startswith(" " * 6): # description + s = " " if params[curr_key]["description"] else "" + params[curr_key]["description"] += s + line.strip() + elif line.startswith(" " * 4) and ":" in line: # param title + name, annot = line.split(":", maxsplit=1) + curr_key = name.strip() + params[curr_key] = {"description": ""} + match = re.search(r".*(\{.*\}).*", annot) + if match: + try: + choices = literal_eval(match.groups()[0]) + params[curr_key]["choices"] = choices + except ValueError: # noqa: S110 + # If the literal_eval fails, we just ignore the choices. + pass + return params + + +def _parse_returns(section): + """Parse the returns section of a docstring into a dictionary mapping the parameter name to its description.""" + curr_key = None + params = {} + for line in section.split("\n"): + if line.strip(): + if line.startswith(" " * 6): # long_name + s = " " if params[curr_key]["long_name"] else "" + params[curr_key]["long_name"] += s + line.strip() + elif line.startswith(" " * 4): # param title + annot, *name = reversed(line.split(":", maxsplit=1)) + if name: + curr_key = name[0].strip() + else: + curr_key = None + params[curr_key] = {"long_name": ""} + annot, *unit = annot.split(",", maxsplit=1) + if unit: + params[curr_key]["units"] = unit[0].strip() + return params + + +# XC +def prefix_attrs(source: dict, keys: Sequence, prefix: str) -> dict: + """Rename some keys of a dictionary by adding a prefix. + + Parameters + ---------- + source : dict + Source dictionary, for example data attributes. + keys : sequence + Names of keys to prefix. + prefix : str + Prefix to prepend to keys. + + Returns + ------- + dict + Dictionary of attributes with some keys prefixed. 
# XC
def unprefix_attrs(source: dict, keys: Sequence, prefix: str) -> dict:
    """Strip a prefix from selected keys of a dictionary.

    Parameters
    ----------
    source : dict
        Source dictionary, for example data attributes.
    keys : sequence
        Names of the original (un-prefixed) keys for which the prefix should be removed.
    prefix : str
        Prefix to remove from keys.

    Returns
    -------
    dict
        Copy of `source` where every ``prefix + key`` entry (for `key` in `keys`) is stored
        under ``key``. A prefixed entry takes precedence over a plain entry of the same name.
    """
    stripped = {}
    offset = len(prefix)
    for name, value in source.items():
        bare = name[offset:]
        if name.startswith(prefix) and bare in keys:
            # Prefixed entry: always written, overriding a plain entry seen earlier.
            stripped[bare] = value
            continue
        if name not in stripped:
            # Plain entry: kept only if no prefixed entry already claimed the name.
            stripped[name] = value
    return stripped
" + else: + defstr = "Required. " + if "choices" in param: + annotstr = str(param.choices) + else: + annotstr = KIND_ANNOTATION[param.kind] + if "units" in param and param.units is not None: + unitstr = f"[Required units : {param.units}]" + else: + unitstr = "" + section += f"{name} {': ' if annotstr else ''}{annotstr}\n {desc_str}\n {defstr}{unitstr}\n" + return section + + +def _gen_returns_section(cf_attrs: Sequence[dict[str, Any]]) -> str: + """Generate the "Returns" section of an indicator's docstring. + + Parameters + ---------- + cf_attrs : Sequence[Dict[str, Any]] + The list of attributes, usually Indicator.cf_attrs. + + Returns + ------- + str + """ + section = "Returns\n-------\n" + for attrs in cf_attrs: + if not section.endswith("\n"): + section += "\n" + section += f"{attrs['var_name']} : DataArray\n" + section += f" {attrs.get('long_name', '')}" + if "standard_name" in attrs: + section += f" ({attrs['standard_name']})" + if "units" in attrs: + section += f" [{attrs['units']}]" + added_section = "" + for key, attr in attrs.items(): + if key not in ["long_name", "standard_name", "units", "var_name"]: + if callable(attr): + attr = "" + added_section += f" **{key}**: {attr};" + if added_section: + section = f"{section}, with additional attributes:{added_section[:-1]}" + section += "\n" + return section + + +def generate_indicator_docstring(ind) -> str: + """Generate an indicator's docstring from keywords. 
def get_percentile_metadata(data: xr.DataArray, prefix: str) -> dict[str, str]:
    """Collect the percentile-related metadata of a DataArray into a dictionary.

    Parameters
    ----------
    data : xr.DataArray
        Must be a percentile DataArray, this means the necessary metadata
        must be available in its attributes and coordinates.
    prefix : str
        The prefix to be used in the metadata key.
        Usually this takes the form of "tasmin_per" or equivalent.

    Returns
    -------
    dict
        A mapping with the keys ``{prefix}_thresh``, ``{prefix}_window`` and ``{prefix}_period``.
        Missing information is replaced by an empty string.
    """
    coords = data.coords
    attrs = data.attrs

    if "quantile" in coords:
        # Array was created with the `quantile()` method: values are fractions, scale to percent.
        thresholds = coords["quantile"].values * 100
    elif "percentiles" in coords:
        thresholds = coords["percentiles"].values
    else:
        thresholds = ""

    metadata = {}
    metadata[f"{prefix}_thresh"] = thresholds
    metadata[f"{prefix}_window"] = attrs.get("window", "")
    metadata[f"{prefix}_period"] = attrs.get("climatology_bounds", "")
    return metadata
In this case, the indicator is called through +`atmos.daily_temperature_range_variability`, but its identifier is `DTRVAR`. +Use the `ind.__class__.__name__` accessor to get its registry name. + +Here, the usual parameter passed to the formatting of "description" is "freq" and is usually translated from "YS" +to "annual". However, in French and in this sentence, the feminine form should be used, so the "f" modifier is added +by the translator so that the formatting function knows which translation to use. Acceptable entries for the mappings +are limited to what is already defined in `xclim.core.indicators.utils.default_formatter`. + +For user-provided internationalization dictionaries, only the "attrs_mapping" and its "modifiers" key are mandatory, +all other entries (translations of frequent parameters and all indicator entries) are optional. +For xclim-provided translations (for now only French), all indicators must have an entry and the "attrs_mapping" +entries must match exactly the default formatter. +Those default translations are found in the `xclim/locales` folder. +""" + +from __future__ import annotations + +import json +import warnings +from collections.abc import Sequence +from copy import deepcopy +from pathlib import Path + +from .formatting import AttrFormatter, default_formatter + +TRANSLATABLE_ATTRS = [ + "long_name", + "description", + "comment", + "title", + "abstract", + "keywords", +] +""" +List of attributes to consider translatable when generating locale dictionaries. +""" + +_LOCALES = {} + + +def list_locales(): + """List of loaded locales.
def get_local_dict(locale: str | Sequence[str] | tuple[str, dict]) -> tuple[str, dict]:
    """Return all translated metadata for a given locale.

    Parameters
    ----------
    locale: str or sequence of str
        IETF language tag or a tuple of the language tag and a translation dict, or a tuple of the language
        tag and a path to a json file defining translation of attributes.

    Raises
    ------
    UnavailableLocaleError
        If the given locale is a string that is not among the loaded locales.

    Returns
    -------
    str
        The best fitting locale string
    dict
        The available translations in this locale. This is always an independent copy:
        modifying it affects neither the loaded locales nor a dictionary passed by the caller.
    """
    # NOTE(review): the boolean result is discarded, so this call does not reject anything
    # by itself; invalid inputs only fail in the statements below.
    _valid_locales([locale])

    if isinstance(locale, str):
        if locale not in _LOCALES:
            raise UnavailableLocaleError(locale)

        return locale, deepcopy(_LOCALES[locale])

    if isinstance(locale[1], dict):
        # Copy the caller's dictionary: consumers of the returned mapping update the
        # nested per-indicator dicts in place, which must not leak back to the caller.
        trans = deepcopy(locale[1])
    else:
        # Thus, a string pointing to a json file
        trans = read_locale_file(locale[1])

    if locale[0] in _LOCALES:
        loaded_trans = deepcopy(_LOCALES[locale[0]])
        # Passed translations have priority
        loaded_trans.update(trans)
        trans = loaded_trans
    return locale[0], trans
def get_local_formatter(
    locale: str | Sequence[str] | tuple[str, dict]
) -> AttrFormatter:
    """Return an AttrFormatter instance for the given locale.

    Parameters
    ----------
    locale: str or tuple of str
        IETF language tag or a tuple of the language tag and a translation dict, or a tuple of the language tag
        and a path to a json file defining translation of attributes.

    Returns
    -------
    AttrFormatter
        A formatter built from the locale's "attrs_mapping" entry, or the default
        formatter (with a warning) when the locale has no such entry.
    """
    loc_name, loc_dict = get_local_dict(locale)
    if "attrs_mapping" in loc_dict:
        # Work on a copy so that popping "modifiers" leaves the locale dictionary intact.
        attrs_mapping = loc_dict["attrs_mapping"].copy()
        mods = attrs_mapping.pop("modifiers")
        return AttrFormatter(attrs_mapping, mods)

    # The message must be an f-string, otherwise the placeholder is printed literally.
    warnings.warn(
        f"No `attrs_mapping` entry found for locale {loc_name}, using default (english) formatter."
    )
    return default_formatter
def generate_local_dict(locale: str, init_english: bool = False) -> dict:
    """Generate a dictionary with keys for each indicator and translatable attributes.

    Parameters
    ----------
    locale : str
        Locale in the IETF format
    init_english : bool
        If True, fills the initial dictionary with the english versions of the attributes.
        Defaults to False.

    Returns
    -------
    dict
        The translations already loaded for `locale` (entries of indicators that are not in
        the registry are dropped), completed with an "attrs_mapping" entry and with one entry
        per registered indicator holding every translatable attribute.
    """
    # `..core.indicator` would be a relative import beyond the top-level package from this module.
    # NOTE(review): assumes `registry` is exposed by the package's own `indicator` module,
    # next to `base_registry` -- confirm.
    from .indicator import registry  # pylint: disable=import-outside-toplevel

    if locale in _LOCALES:
        _, attrs = get_local_dict(locale)
        # Iterate over a snapshot of the keys since entries are removed on the fly.
        for ind_name in list(attrs):
            if ind_name != "attrs_mapping" and ind_name not in registry:
                attrs.pop(ind_name)
    else:
        attrs = {}

    attrs_mapping = attrs.setdefault("attrs_mapping", {})
    attrs_mapping.setdefault("modifiers", [""])
    for key, value in default_formatter.mapping.items():
        attrs_mapping.setdefault(key, [value[0]])

    eng_attr = ""
    for ind_name, indicator in registry.items():
        cf_names = set(indicator._cf_names)
        # Walk TRANSLATABLE_ATTRS in its declared order (rather than iterating over a set)
        # so that the key order of the generated dictionary is reproducible between runs.
        indicator_level = [name for name in TRANSLATABLE_ATTRS if name not in cf_names]
        variable_level = [name for name in TRANSLATABLE_ATTRS if name in cf_names]

        ind_attrs = attrs.setdefault(ind_name, {})
        for translatable_attr in indicator_level:
            if init_english:
                eng_attr = getattr(indicator, translatable_attr)
                if not isinstance(eng_attr, str):
                    eng_attr = ""
            ind_attrs.setdefault(translatable_attr, eng_attr)

        for cf_attrs in indicator.cf_attrs:
            # In the case of single output, put var attrs in main dict
            if len(indicator.cf_attrs) > 1:
                ind_attrs = attrs.setdefault(f"{ind_name}.{cf_attrs['var_name']}", {})

            for translatable_attr in variable_level:
                if init_english:
                    eng_attr = cf_attrs.get(translatable_attr)
                    if not isinstance(eng_attr, str):
                        eng_attr = ""
                ind_attrs.setdefault(translatable_attr, eng_attr)
    return attrs
diff --git a/src/xsdba/logging.py b/src/xsdba/logging.py index 79ee33c..14ae2c3 100644 --- a/src/xsdba/logging.py +++ b/src/xsdba/logging.py @@ -54,3 +54,7 @@ def raise_warn_or_log( warnings.warn(message, stacklevel=stacklevel + 1) else: # mode == "raise" raise err from err_type(message) + + +class MissingVariableError(ValueError): + """Error raised when a dataset is passed to an indicator but one of the needed variable is missing.""" diff --git a/src/xsdba/options.py b/src/xsdba/options.py index ef48f11..cd34814 100644 --- a/src/xsdba/options.py +++ b/src/xsdba/options.py @@ -11,14 +11,15 @@ from boltons.funcutils import wraps -# from .locales import _valid_locales # from XC, not reproduced for now +from .locales import _valid_locales from .logging import ValidationError, raise_warn_or_log -# METADATA_LOCALES = "metadata_locales" +METADATA_LOCALES = "metadata_locales" DATA_VALIDATION = "data_validation" CF_COMPLIANCE = "cf_compliance" CHECK_MISSING = "check_missing" MISSING_OPTIONS = "missing_options" +RUN_LENGTH_UFUNC = "run_length_ufunc" SDBA_EXTRA_OUTPUT = "sdba_extra_output" SDBA_ENCODE_CF = "sdba_encode_cf" KEEP_ATTRS = "keep_attrs" @@ -27,11 +28,12 @@ MISSING_METHODS: dict[str, Callable] = {} OPTIONS = { - # METADATA_LOCALES: [], + METADATA_LOCALES: [], DATA_VALIDATION: "raise", CF_COMPLIANCE: "warn", CHECK_MISSING: "any", MISSING_OPTIONS: {}, + RUN_LENGTH_UFUNC: "auto", SDBA_EXTRA_OUTPUT: False, SDBA_ENCODE_CF: False, KEEP_ATTRS: "xarray", @@ -39,6 +41,7 @@ } _LOUDNESS_OPTIONS = frozenset(["log", "warn", "raise"]) +_RUN_LENGTH_UFUNC_OPTIONS = frozenset(["auto", True, False]) _KEEP_ATTRS_OPTIONS = frozenset(["xarray", True, False]) @@ -57,11 +60,12 @@ def _valid_missing_options(mopts): _VALIDATORS = { - # METADATA_LOCALES: _valid_locales, + METADATA_LOCALES: _valid_locales, DATA_VALIDATION: _LOUDNESS_OPTIONS.__contains__, CF_COMPLIANCE: _LOUDNESS_OPTIONS.__contains__, CHECK_MISSING: lambda meth: meth != "from_context" and meth in MISSING_METHODS, 
MISSING_OPTIONS: _valid_missing_options, + RUN_LENGTH_UFUNC: _RUN_LENGTH_UFUNC_OPTIONS.__contains__, SDBA_EXTRA_OUTPUT: lambda opt: isinstance(opt, bool), SDBA_ENCODE_CF: lambda opt: isinstance(opt, bool), KEEP_ATTRS: _KEEP_ATTRS_OPTIONS.__contains__, @@ -74,16 +78,16 @@ def _set_missing_options(mopts): OPTIONS[MISSING_OPTIONS][meth].update(opts) -# def _set_metadata_locales(locales): -# if isinstance(locales, str): -# OPTIONS[METADATA_LOCALES] = [locales] -# else: -# OPTIONS[METADATA_LOCALES] = locales +def _set_metadata_locales(locales): + if isinstance(locales, str): + OPTIONS[METADATA_LOCALES] = [locales] + else: + OPTIONS[METADATA_LOCALES] = locales _SETTERS = { MISSING_OPTIONS: _set_missing_options, - # METADATA_LOCALES: _set_metadata_locales, + METADATA_LOCALES: _set_metadata_locales, } diff --git a/src/xsdba/processing.py b/src/xsdba/processing.py index 992b198..4e37cb3 100644 --- a/src/xsdba/processing.py +++ b/src/xsdba/processing.py @@ -14,7 +14,7 @@ import xarray as xr from xarray.core.utils import get_temp_dimname -from xsdba.base import get_calendar, max_doy, parse_offset, uses_dask +from xsdba.calendar import get_calendar, max_doy, parse_offset, uses_dask from xsdba.formatting import update_xsdba_history from ._processing import _adapt_freq, _normalize, _reordering @@ -388,7 +388,7 @@ def escore( tgt: xr.DataArray, sim: xr.DataArray, dims: Sequence[str] = ("variables", "time"), - N: int = 0, # noqa + N: int = 0, scale: bool = False, ) -> xr.DataArray: r"""Energy score, or energy dissimilarity metric, based on :cite:t:`sdba-szekely_testing_2004` and :cite:t:`sdba-cannon_multivariate_2018`. diff --git a/src/xsdba/properties.py b/src/xsdba/properties.py new file mode 100644 index 0000000..648cc51 --- /dev/null +++ b/src/xsdba/properties.py @@ -0,0 +1,1577 @@ +# pylint: disable=missing-kwoa +""" +Properties Submodule +==================== +SDBA diagnostic tests are made up of statistical properties and measures. 
class StatisticalProperty(Indicator):
    """Base indicator class for statistical properties used for validating bias-adjusted outputs.

    Statistical properties reduce the time dimension, sometimes adding a grouping dimension
    according to the passed value of `group` (e.g.: group='time.month' means the loss of the
    time dimension and the addition of a month one).

    Statistical properties are generally unit-generic. To use these indicators in a workflow, it
    is recommended to wrap them with a virtual submodule, creating one specific indicator for
    each variable input (or at least for each possible dimensionality).

    Statistical properties may restrict the sampling frequency of the input, they usually take in a
    single variable (named "da" in unit-generic instances).

    """

    aspect = None
    """The aspect the statistical property studies: marginal, temporal, multivariate or spatial."""

    measure = "xclim.sdba.measures.BIAS"
    """The default measure to use when comparing the properties of two datasets.
    This gives the registry id. See :py:meth:`get_measure`."""

    allowed_groups = None
    """A list of allowed groupings. A subset of dayofyear, week, month, season or group.
    The latter stands for no temporal grouping."""

    realm = "generic"

    @classmethod
    def _ensure_correct_parameters(cls, parameters):
        """Refuse compute functions lacking a `group` parameter, then defer to the parent check.

        Raises
        ------
        ValueError
            If "group" is not among `parameters`.
        """
        if "group" not in parameters:
            raise ValueError(
                f"{cls.__name__} require a 'group' argument, use the base Indicator"
                " class if your computation doesn't perform any regrouping."
            )
        return super()._ensure_correct_parameters(parameters)

    def _preprocess_and_checks(self, das, params):
        """Convert `group` to a Grouper and check that it is allowed for this property.

        Raises
        ------
        ValueError
            If `allowed_groups` is set and the grouping property is not part of it.
        """
        group = params["group"]
        if isinstance(group, str):
            # Written back into `params` so the compute function receives the Grouper.
            group = params["group"] = Grouper(group)

        if self.allowed_groups is not None and group.prop not in self.allowed_groups:
            # Materialise the list: interpolating a bare `map` object would print
            # "<map object at 0x...>" instead of the allowed groupings.
            allowed = ["." + g.replace("group", "") for g in self.allowed_groups]
            raise ValueError(
                f"Grouping period {group.prop_name} is not allowed for property "
                f"{self.identifier} (needs something in {allowed})."
            )

        return das, params

    def _postprocess(self, outs, das, params):
        """Squeeze `group` dim if needed."""
        outs = super()._postprocess(outs, das, params)

        # Items are replaced in place so the container returned by the parent is preserved.
        for i, out in enumerate(outs):
            if "group" in out.dims:
                outs[i] = out.squeeze("group", drop=True)

        return outs

    def get_measure(self):
        """Get the statistical measure indicator that is best used with this statistical property."""
        # Imported lazily, as in the original; `self.measure` is the registry id.
        from xclim.core.indicator import registry

        return registry[self.measure].get_instance()
@parse_group
def _std(da: xr.DataArray, *, group: str | Grouper = "time") -> xr.DataArray:
    """Standard Deviation.

    Standard deviation of the variable over all years at the time resolution.

    Parameters
    ----------
    da : xr.DataArray
        Variable on which to calculate the diagnostic.
    group : {'time', 'time.season', 'time.month'}
        Grouping of the output.
        e.g. If 'time.month', the standard deviation is performed separately for each month.

    Returns
    -------
    xr.DataArray,
        Standard deviation of the variable.
    """
    # Read the units first: the reduction below is not asked to keep attributes.
    input_units = da.units
    # The "group" property means no sub-grouping: reduce the array directly.
    source = da if group.prop == "group" else da.groupby(group.name)
    result = source.std(dim=group.dim)
    result.attrs["units"] = input_units
    return result
@parse_group
def _quantile(
    da: xr.DataArray, *, q: float = 0.98, group: str | Grouper = "time"
) -> xr.DataArray:
    """Quantile.

    Returns the quantile q of the distribution of the variable over all years at the time resolution.

    Parameters
    ----------
    da : xr.DataArray
        Variable on which to calculate the diagnostic.
    q : float
        Quantile to be calculated. Should be between 0 and 1.
    group : {'time', 'time.season', 'time.month'}
        Grouping of the output.
        e.g. If 'time.month', the quantile is computed separately for each month.

    Returns
    -------
    xr.DataArray, [same as input]
        Quantile {q} of the variable.
    """
    input_units = da.units
    # The "group" property means no sub-grouping: reduce the array directly.
    source = da if group.prop == "group" else da.groupby(group.name)
    reduced = source.quantile(q, dim=group.dim, keep_attrs=True)
    # `q` is scalar here, so the leftover "quantile" coordinate carries no information.
    result = reduced.drop_vars("quantile")
    return result.assign_attrs(units=input_units)
+ 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}. + 'quantile': The threshold is calculated as the quantile {thresh} of the distribution. + op : {">", "<", ">=", "<="} + Operation to verify the condition for a spell. + The condition for a spell is variable {op} threshold. + thresh : str or float + Threshold on which to evaluate the condition to have a spell. + String with units if the method is "amount". + Float of the quantile if the method is "quantile". + window : int + Number of consecutive days respecting the constraint in order to begin a spell. + Default is 1, which is equivalent to `_threshold_count` + stat : {'mean', 'sum', 'max','min'} + Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010) + stat_resample : {'mean', 'sum', 'max','min'}, optional + Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980). + If `None`, the same method as `stat` will be used. + group : {'time', 'time.season', 'time.month'} + Grouping of the output. + e.g. If 'time.month', the spell lengths are computed separately for each month. + resample_before_rl : bool + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. + + Returns + ------- + xr.DataArray, [units of the sampling frequency] + {stat} of spell length distribution when the variable is {op} the {method} {thresh} for {window} consecutive day(s). + """ + group = group if isinstance(group, Grouper) else Grouper(group) + + ops = {">": np.greater, "<": np.less, ">=": np.greater_equal, "<=": np.less_equal} + + @map_groups(out=[Grouper.PROP], main_only=True) + def _spell_stats( + ds, + *, + dim, + method, + thresh, + window, + op, + freq, + resample_before_rl, + stat, + stat_resample, + ): + # PB: This prevents an import error in the distributed dask scheduler, but I don't know why. 
@parse_group
def _threshold_count(
    da: xr.DataArray,
    *,
    method: str = "amount",
    op: str = ">=",
    thresh: str = "1 mm d-1",
    stat: str = "mean",
    stat_resample: str | None = None,
    group: str | Grouper = "time",
) -> xr.DataArray:
    r"""Threshold count.

    Statistic of the number of time steps when the variable respects a condition (defined by an operation, a method and
    a threshold).

    Parameters
    ----------
    da : xr.DataArray
        Variable on which to calculate the diagnostic.
    method : {'amount', 'quantile'}
        Method to choose the threshold.
        'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}.
        'quantile': The threshold is calculated as the quantile {thresh} of the distribution.
    op : {">", "<", ">=", "<="}
        Operation to verify the condition.
        The condition is variable {op} threshold.
    thresh : str or float
        Threshold on which to evaluate the condition.
        String with units if the method is "amount".
        Float of the quantile if the method is "quantile".
    stat : {'mean', 'sum', 'max','min'}
        Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010)
    stat_resample : {'mean', 'sum', 'max','min'}, optional
        Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980). If `None`, the same method as `stat` will be used.
    group : {'time', 'time.season', 'time.month'}
        Grouping of the output.
        e.g. For 'time.month', the count would be calculated on each month separately,
        but with all the years together.

    Returns
    -------
    xr.DataArray, [dimensionless]
        {stat} number of days when the variable is {op} the {method} {thresh}.

    Notes
    -----
    This corresponds to ``_spell_length_distribution`` with `window=1`.
    """
    # Pure delegation: a spell of window 1 starts on every step satisfying the condition.
    return _spell_length_distribution(
        da,
        method=method,
        op=op,
        thresh=thresh,
        stat=stat,
        stat_resample=stat_resample,
        group=group,
        window=1,
    )
def _annual_cycle(
    da: xr.DataArray,
    *,
    stat: str = "absamp",
    window: int = 31,
    group: str | Grouper = "time",
) -> xr.DataArray:
    r"""Annual cycle statistics.

    A daily climatology is calculated and optionally smoothed with a (circular) moving average.
    The requested statistic is returned.

    Parameters
    ----------
    da : xr.DataArray
        Variable on which to calculate the diagnostic.
    stat : {'absamp','relamp', 'phase', 'min', 'max', 'asymmetry'}
        - 'absamp' is the peak-to-peak amplitude. (max - min). In the same units as the input.
        - 'relamp' is a relative percentage. 100 * (max - min) / mean (Recommended for precipitation). Dimensionless.
        - 'phase' is the day of year of the maximum.
        - 'max' is the maximum. Same units as the input.
        - 'min' is the minimum. Same units as the input.
        - 'asymmetry' is the length of the period going from the minimum to the maximum. In years between 0 and 1.
    window : int
        Size of the window for the moving average filtering. Deactivate this feature by passing window = 1.
    group : {'time'}
        Kept only so the signature matches the other properties; the climatology is always
        computed over "time.dayofyear".

    Returns
    -------
    xr.DataArray, [same units as input or dimensionless or time]
        {stat} of the annual cycle.

    Raises
    ------
    NotImplementedError
        If `stat` is not one of the supported statistics.
    """
    # `group` is not used below; the Grouper is still built as in the original
    # (presumably so that an invalid group string is rejected - TODO confirm).
    group = group if isinstance(group, Grouper) else Grouper(group)
    units = da.units

    ac = da.groupby("time.dayofyear").mean()
    if window > 1:  # smooth the cycle
        # We want the rolling mean to be circular. There's no built-in method to do this in xarray,
        # we'll pad the array and extract the meaningful part.
        half = window // 2
        ac = (
            ac.pad(dayofyear=half, mode="wrap")
            .rolling(dayofyear=window, center=True)
            .mean()
            .isel(dayofyear=slice(half, -half))
        )

    if stat == "absamp":
        out = ac.max("dayofyear") - ac.min("dayofyear")
        # Use the module-level `ensure_delta` (from xsdba.units), as `_annual_statistic` does,
        # rather than reaching into `xc.core.units`.
        out.attrs["units"] = ensure_delta(units)
    elif stat == "relamp":
        out = (ac.max("dayofyear") - ac.min("dayofyear")) * 100 / ac.mean("dayofyear")
        out.attrs["units"] = "%"
    elif stat == "phase":
        out = ac.idxmax("dayofyear")
        out.attrs.update(units="", is_dayofyear=np.int32(1))
    elif stat == "min":
        out = ac.min("dayofyear")
        out.attrs["units"] = units
    elif stat == "max":
        out = ac.max("dayofyear")
        out.attrs["units"] = units
    elif stat == "asymmetry":
        # NOTE(review): the year length is hard-coded to 365 regardless of the calendar.
        out = (ac.idxmax("dayofyear") - ac.idxmin("dayofyear")) % 365 / 365
        out.attrs["units"] = "yr"
    else:
        raise NotImplementedError(f"{stat} is not a valid annual cycle statistic.")
    return out
def _annual_statistic(
    da: xr.DataArray,
    *,
    stat: str = "absamp",
    window: int = 31,
    group: str | Grouper = "time",
):
    """Annual range statistics.

    Compute a statistic on each year of data and return the interannual average. This is similar
    to the annual cycle, but with the statistic and average operations inverted.

    Parameters
    ----------
    da: xr.DataArray
        Data.
    stat : {'absamp', 'relamp', 'phase'}
        The statistic to return.
    window : int
        Size of the window for the moving average filtering. Deactivate this feature by passing window = 1.

    Returns
    -------
    xr.DataArray, [same units as input or dimensionless]
        Average annual {stat}.
    """
    input_units = da.units

    # Optional centred moving average applied before the series is cut into years.
    smoothed = da.rolling(time=window, center=True).mean() if window > 1 else da
    yearly = smoothed.resample(time="YS")

    if stat == "phase":
        # Day of year at which each year reaches its maximum.
        result = yearly.map(xr.DataArray.idxmax).dt.dayofyear
        result.attrs.update(units="", is_dayofyear=np.int32(1))
    elif stat == "relamp":
        result = (yearly.max() - yearly.min()) * 100 / yearly.mean()
        result.attrs["units"] = "%"
    elif stat == "absamp":
        result = yearly.max() - yearly.min()
        result.attrs["units"] = ensure_delta(input_units)
    else:
        raise NotImplementedError(f"{stat} is not a valid annual cycle statistic.")

    # Interannual average; the attributes set above are carried over.
    return result.mean("time", keep_attrs=True)
@parse_group
def _corr_btw_var(
    da1: xr.DataArray,
    da2: xr.DataArray,
    *,
    corr_type: str = "Spearman",
    group: str | Grouper = "time",
    output: str = "correlation",
) -> xr.DataArray:
    r"""Correlation between two variables.

    Spearman or Pearson correlation coefficient between two variables at the time resolution.

    Parameters
    ----------
    da1 : xr.DataArray
        First variable on which to calculate the diagnostic.
    da2 : xr.DataArray
        Second variable on which to calculate the diagnostic.
    corr_type: {'Pearson','Spearman'}
        Type of correlation to calculate.
    output: {'correlation', 'pvalue'}
        Whether to return the correlation coefficient or the p-value.
    group : {'time', 'time.season', 'time.month'}
        Grouping of the output.
        e.g. For 'time.month', the correlation would be calculated on each month separately,
        but with all the years together.

    Returns
    -------
    xr.DataArray, [dimensionless]
        {corr_type} correlation coefficient
    """
    # Normalise once: validation is case-insensitive, so the dispatch below must be too.
    # Comparing against the literal "Pearson" made "pearson"/"PEARSON" silently fall
    # through to Spearman.
    corr_kind = corr_type.lower()
    if corr_kind not in {"pearson", "spearman"}:
        raise ValueError(
            f"{corr_type} is not a valid type. Choose 'Pearson' or 'Spearman'."
        )

    # Position of the requested value in the scipy result (statistic, p-value).
    # An unknown `output` raises KeyError here, as in the original.
    index = {"correlation": 0, "pvalue": 1}[output]

    def _first_output_1d(a, b, index, corr_type):
        """Return the element `index` of the scipy correlation result for two 1D series."""
        # for points in the water with NaNs
        if np.isnan(a).all():
            return np.nan
        # Keep only the positions where both series are valid.
        valid = ~np.isnan(a) & ~np.isnan(b)
        corr_func = stats.pearsonr if corr_type == "pearson" else stats.spearmanr
        return corr_func(a[valid], b[valid])[index]

    @map_groups(out=[Grouper.PROP], main_only=True)
    def _first_output(ds, *, dim, index, corr_type):
        out = xr.apply_ufunc(
            _first_output_1d,
            ds.a,
            ds.b,
            input_core_dims=[[dim], [dim]],
            vectorize=True,
            dask="parallelized",
            kwargs={"index": index, "corr_type": corr_type},
        )
        return out.rename("out").to_dataset()

    out = _first_output(
        xr.Dataset({"a": da1, "b": da2}), group=group, index=index, corr_type=corr_kind
    ).out
    out.attrs["units"] = ""
    return out
It needs to have the same units as {da}. + 'quantile': The threshold is calculated as the quantile {thresh} of the distribution. + method2 : {'amount', 'quantile'} + Method to choose the threshold. + 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}. + 'quantile': The threshold is calculated as the quantile {thresh} of the distribution. + op1 : {">", "<", ">=", "<="} + Operation to verify the condition for a spell. + The condition for a spell is variable {op1} threshold. + op2 : {">", "<", ">=", "<="} + Operation to verify the condition for a spell. + The condition for a spell is variable {op2} threshold. + thresh1 : str or float + Threshold on which to evaluate the condition to have a spell. + String with units if the method is "amount". + Float of the quantile if the method is "quantile". + thresh2 : str or float + Threshold on which to evaluate the condition to have a spell. + String with units if the method is "amount". + Float of the quantile if the method is "quantile". + window : int + Number of consecutive days respecting the constraint in order to begin a spell. + Default is 1, which is equivalent to `_bivariate_threshold_count` + stat : {'mean', 'sum', 'max','min'} + Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010) + stat_resample : {'mean', 'sum', 'max','min'}, optional + Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980). If `None`, the same method as `stat` will be used. + group : {'time', 'time.season', 'time.month'} + Grouping of the output. + e.g. If 'time.month', the spell lengths are computed separately for each month. + resample_before_rl : bool + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. 
+ + Returns + ------- + xr.DataArray, [units of the sampling frequency] + {stat} of spell length distribution when the first variable is {op1} the {method1} {thresh1} + and the second variable is {op2} the {method2} {thresh2} for {window} consecutive day(s). + """ + group = group if isinstance(group, Grouper) else Grouper(group) + ops = { + ">": np.greater, + "<": np.less, + ">=": np.greater_equal, + "<=": np.less_equal, + } + + @map_groups(out=[Grouper.PROP], main_only=True) + def _bivariate_spell_stats( + ds, + *, + dim, + methods, + threshs, + opss, + freq, + window, + resample_before_rl, + stat, + stat_resample, + ): + # PB: This prevents an import error in the distributed dask scheduler, but I don't know why. + import xarray.core.resample_cftime # noqa: F401, pylint: disable=unused-import + + conds = [] + masks = [] + for da, thresh, op, method in zip([ds.da1, ds.da2], threshs, opss, methods): + masks.append( + ~(da.isel({dim: 0}).isnull()).drop_vars(dim) + ) # mask of the ocean with NaNs + if method == "quantile": + thresh = da.quantile(thresh, dim=dim).drop_vars("quantile") + conds.append(op(da, thresh)) + mask = masks[0] & masks[1] + cond = conds[0] & conds[1] + out = rl.resample_and_rl( + cond, + resample_before_rl, + rl.rle_statistics, + reducer=stat_resample, + window=window, + dim=dim, + freq=freq, + ) + out = getattr(out, stat)(dim=dim) + out = out.where(mask) + return out.rename("out").to_dataset() + + # threshold is an amount that will be converted to the right units + methods = [method1, method2] + threshs = [thresh1, thresh2] + for i, da in enumerate([da1, da2]): + if methods[i] == "amount": + # ADAPT: will this work? + threshs[i] = convert_units_to(threshs[i], da) # , context="infer") + elif methods[i] != "quantile": + raise ValueError( + f"{methods[i]} is not a valid method. Choose 'amount' or 'quantile'." 
@parse_group
def _bivariate_threshold_count(
    da1: xr.DataArray,
    da2: xr.DataArray,
    *,
    method1: str = "amount",
    method2: str = "amount",
    op1: str = ">=",
    op2: str = ">=",
    thresh1: str = "1 mm d-1",
    thresh2: str = "1 mm d-1",
    stat: str = "mean",
    stat_resample: str | None = None,
    group: str | Grouper = "time",
) -> xr.DataArray:
    """Count the number of time steps where two variables respect given conditions.

    Statistic of number of time steps when two variables respect individual conditions (defined by an operation, a method,
    and a threshold).

    Parameters
    ----------
    da1 : xr.DataArray
        First variable on which to calculate the diagnostic.
    da2 : xr.DataArray
        Second variable on which to calculate the diagnostic.
    method1 : {'amount', 'quantile'}
        Method to choose the threshold of the first variable.
        'amount': The threshold is directly the quantity in {thresh1}. It needs to have the same units as {da1}.
        'quantile': The threshold is calculated as the quantile {thresh1} of the distribution.
    method2 : {'amount', 'quantile'}
        Method to choose the threshold of the second variable.
        'amount': The threshold is directly the quantity in {thresh2}. It needs to have the same units as {da2}.
        'quantile': The threshold is calculated as the quantile {thresh2} of the distribution.
    op1 : {">", "<", ">=", "<="}
        Operation to verify the condition on the first variable.
        The condition is variable {op1} threshold.
    op2 : {">", "<", ">=", "<="}
        Operation to verify the condition on the second variable.
        The condition is variable {op2} threshold.
    thresh1 : str or float
        Threshold on which to evaluate the condition on the first variable.
        String with units if the method is "amount".
        Float of the quantile if the method is "quantile".
    thresh2 : str or float
        Threshold on which to evaluate the condition on the second variable.
        String with units if the method is "amount".
        Float of the quantile if the method is "quantile".
    stat : {'mean', 'sum', 'max','min'}
        Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010)
    stat_resample : {'mean', 'sum', 'max','min'}, optional
        Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980).
        If `None`, the same method as `stat` will be used.
    group : {'time', 'time.season', 'time.month'}
        Grouping of the output.
        e.g. For 'time.month', the count would be calculated on each month separately,
        but with all the years together.

    Returns
    -------
    xr.DataArray, [dimensionless]
        {stat} number of days when the first variable is {op1} the {method1} {thresh1}
        and the second variable is {op2} the {method2} {thresh2}.

    Notes
    -----
    This corresponds to ``_bivariate_spell_length_distribution`` with `window=1`.
    """
    # Pure delegation: with window=1 every step satisfying both conditions counts.
    # `window` is fixed here and is not a parameter of this function, which is why
    # the Returns template above no longer references it.
    return _bivariate_spell_length_distribution(
        da1,
        da2,
        method1=method1,
        method2=method2,
        op1=op1,
        op2=op2,
        thresh1=thresh1,
        thresh2=thresh2,
        window=1,
        stat=stat,
        stat_resample=stat_resample,
        group=group,
    )
+ + Relative Frequency of days with variable respecting a condition (defined by an operation and a threshold) at the + time resolution. The relative frequency is the number of days that satisfy the condition divided by the total number + of days. + + Parameters + ---------- + da : xr.DataArray + Variable on which to calculate the diagnostic. + op : {">", "<", ">=", "<="} + Operation to verify the condition. + The condition is variable {op} threshold. + thresh : str + Threshold on which to evaluate the condition. + group : {'time', 'time.season', 'time.month'} + Grouping on the output. + e.g. For 'time.month', the relative frequency would be calculated on each month, with all years included. + + Returns + ------- + xr.DataArray, [dimensionless] + Relative frequency of values {op} {thresh}. + """ + # mask of the ocean with NaNs + mask = ~(da.isel({group.dim: 0}).isnull()).drop_vars(group.dim) + ops: dict[str, np.ufunc] = { + ">": np.greater, + "<": np.less, + ">=": np.greater_equal, + "<=": np.less_equal, + } + t = convert_units_to(thresh, da) # , context="infer") + length = da.sizes[group.dim] + cond = ops[op](da, t) + if group.prop != "group": # change the time resolution if necessary + cond = cond.groupby(group.name) + # length of the groupBy groups + length = np.array([len(v) for k, v in cond.groups.items()]) + for _ in range(da.ndim - 1): # add empty dimension(s) to match input + length = np.expand_dims(length, axis=-1) + # count days with the condition and divide by total nb of days + out = cond.sum(dim=group.dim, skipna=False) / length + out = out.where(mask, np.nan) + out.attrs["units"] = "" + return out + + +relative_frequency = StatisticalProperty( + identifier="relative_frequency", aspect="temporal", compute=_relative_frequency +) + + +@parse_group +def _transition_probability( + da: xr.DataArray, + *, + initial_op: str = ">=", + final_op: str = ">=", + thresh: str = "1 mm d-1", + group: str | Grouper = "time", +) -> xr.DataArray: + """Transition 
@parse_group
def _trend(
    da: xr.DataArray,
    *,
    group: str | Grouper = "time",
    output: str = "slope",
) -> xr.DataArray:
    """Linear Trend.

    The data is averaged over each time resolution and the inter-annual trend is returned.
    This function will rechunk along the grouping dimension.

    Parameters
    ----------
    da : xr.DataArray
        Variable on which to calculate the diagnostic.
    output : {'slope', 'intercept', 'rvalue', 'pvalue', 'stderr', 'intercept_stderr'}
        The attributes of the linear regression to return, as defined in scipy.stats.linregress:
        'slope' is the slope of the regression line.
        'intercept' is the intercept of the regression line.
        'rvalue' is The Pearson correlation coefficient.
        The square of rvalue is equal to the coefficient of determination.
        'pvalue' is the p-value for a hypothesis test whose null hypothesis is that the slope is zero,
        using Wald Test with t-distribution of the test statistic.
        'stderr' is the standard error of the estimated slope (gradient), under the assumption of residual normality.
        'intercept_stderr' is the standard error of the estimated intercept, under the assumption of residual normality.
    group : {'time', 'time.season', 'time.month'}
        Grouping on the output.

    Returns
    -------
    xr.DataArray, [units of input per year or dimensionless]
        {output} of the interannual linear trend.
        'slope' and 'stderr' are in units of the input per year, 'intercept' and 'intercept_stderr'
        are in units of the input, 'rvalue' and 'pvalue' are dimensionless.

    Raises
    ------
    ValueError
        If `output` is not one of the supported regression attributes.

    See Also
    --------
    scipy.stats.linregress

    numpy.polyfit
    """
    units = da.units

    # Units of each regression attribute: the regressor is the period index (years),
    # so only the slope-like quantities are rates.
    per_year = f"{units}/year"
    output_units = {
        "slope": per_year,
        "stderr": per_year,
        "intercept": units,
        "intercept_stderr": units,
        "rvalue": "",
        "pvalue": "",
    }
    if output not in output_units:
        raise ValueError(
            f"Unknown output '{output}'. Expected one of {sorted(output_units)}."
        )

    resampled = da.resample({group.dim: group.freq})  # separate all the {group}
    da_mean = resampled.mean(dim=group.dim)  # avg over all {group}
    if uses_dask(da_mean):
        # the regression needs the whole series of each point in a single chunk
        da_mean = da_mean.chunk({group.dim: -1})
    if group.prop != "group":
        da_mean = da_mean.groupby(group.name)  # group all month/season together

    def modified_lr(x):
        # modify linregress to fit into apply_ufunc and only return the requested attribute
        return getattr(stats.linregress(np.arange(len(x)), x), output)

    out = xr.apply_ufunc(
        modified_lr,
        da_mean,
        input_core_dims=[[group.dim]],
        vectorize=True,
        dask="parallelized",
    )
    out.attrs["units"] = output_units[output]
    return out


trend = StatisticalProperty(identifier="trend", aspect="temporal", compute=_trend)
@parse_group
def _spatial_correlogram(
    da: xr.DataArray,
    *,
    dims: Sequence[str] | None = None,
    bins: int = 100,
    group: str = "time",
    method: int = 1,
):
    """Spatial correlogram.

    Compute the pairwise spatial correlations (Spearman) and averages them based on the pairwise distances.
    This collapses the spatial and temporal dimensions and returns a distance bins dimension.
    Needs coordinates for longitude and latitude. This property is heavy to compute, and it will
    need to create a NxN array in memory (outside of dask), where N is the number of spatial points.
    There are shortcuts for all-nan time-slices or spatial points, but scipy's nan-omitting algorithm
    is extremely slow, so the presence of any lone NaN will increase the computation time. Based on an idea
    from :cite:p:`francois_multivariate_2020`.

    Parameters
    ----------
    da : xr.DataArray
        Data.
    dims : sequence of strings, optional
        Name of the spatial dimensions. Once these are stacked, the longitude and latitude coordinates must be 1D.
    bins : int
        Same as argument `bins` from :py:meth:`xarray.DataArray.groupby_bins`.
        If given as a scalar, the equal-width bin limits are generated here
        (instead of letting xarray do it) to improve performance.
    group : str
        Useless for now.
    method : int
        Not used by the current implementation.

    Returns
    -------
    xr.DataArray, [dimensionless]
        Inter-site correlogram as a function of distance.
    """
    if dims is None:
        dims = [d for d in da.dims if d != "time"]

    corr = _pairwise_spearman(da, dims)
    dists, mn, mx = _pairwise_haversine_and_bins(
        corr.cf["longitude"].values, corr.cf["latitude"].values
    )
    dists = xr.DataArray(dists, dims=corr.dims, coords=corr.coords, name="distance")
    if np.isscalar(bins):
        # widen the range slightly so the extreme distances fall inside the outer bins
        bin_edges = np.linspace(mn * 0.9999, mx * 1.0001, bins + 1)
    else:
        bin_edges = np.asarray(bins)
    if uses_dask(corr):
        dists = dists.chunk()

    w = np.diff(bin_edges)
    centers = xr.DataArray(
        bin_edges[:-1] + w / 2,
        dims=("distance_bins",),
        attrs={
            "units": "km",
            "long_name": f"Centers of the intersite distance bins (width of {w[0]:.3f} km)",
        },
    )

    dists = dists.where(corr.notnull())

    def _bin_corr(corr, distance):
        """Bin and mean."""
        return stats.binned_statistic(
            distance.flatten(), corr.flatten(), statistic="mean", bins=bin_edges
        ).statistic

    # (_spatial, _spatial2) -> (distance_bins)
    binned = xr.apply_ufunc(
        _bin_corr,
        corr,
        dists,
        input_core_dims=[["_spatial", "_spatial2"], ["_spatial", "_spatial2"]],
        output_core_dims=[["distance_bins"]],
        dask="parallelized",
        vectorize=True,
        output_dtypes=[float],
        dask_gufunc_kwargs={
            "allow_rechunk": True,
            # binned_statistic returns one value per bin, i.e. one less than the edges
            "output_sizes": {"distance_bins": bin_edges.size - 1},
        },
    )
    binned = (
        binned.assign_coords(distance_bins=centers)
        .rename(distance_bins="distance")
        .assign_attrs(units="")
        .rename("corr")
    )
    return binned


spatial_correlogram = StatisticalProperty(
    identifier="spatial_correlogram",
    aspect="spatial",
    compute=_spatial_correlogram,
    allowed_groups=["group"],
)


def _decorrelation_length(
    da: xr.DataArray,
    *,
    radius: int | float = 300,
    thresh: float = 0.50,
    dims: Sequence[str] | None = None,
    bins: int = 100,
    group: xr.Coordinate | str | None = "time",  # FIXME: this needs to be clarified
):
    """Decorrelation length.

    Distance from a grid cell where the correlation with its neighbours goes below the threshold.
    A correlogram is calculated for each grid cell following the method from
    ``xclim.sdba.properties.spatial_correlogram``. Then, we find the first bin closest to the correlation threshold.

    Parameters
    ----------
    da : xr.DataArray
        Data.
    radius : float
        Radius (in km) defining the region where correlations will be calculated between a point and its neighbours.
    thresh : float
        Threshold correlation defining decorrelation.
        The decorrelation length is defined as the center of the distance bin that has a correlation closest
        to this threshold.
    dims : sequence of strings
        Name of the spatial dimensions. Once these are stacked, the longitude and latitude coordinates must be 1D.
        If not given, all dimensions of `da` except the one designated by `group` are used.
    bins : int
        Same as argument `bins` from :py:meth:`scipy.stats.binned_statistic`.
        If given as a scalar, the equal-width bin limits from 0 to radius are generated here
        (instead of letting scipy do it) to improve performance.
    group : xarray.Coordinate or str, optional
        Useless for now.

    Returns
    -------
    xr.DataArray, [km]
        Decorrelation length.

    Raises
    ------
    ValueError
        If `bins` is neither a scalar nor a numpy array, or if the spatial dimensions
        cannot be determined (both `dims` and `group` are None).

    Notes
    -----
    Calculating this property requires a lot of memory. It will not work with large datasets.
    """
    if dims is None:
        if group is None:
            raise ValueError(
                "The spatial dimensions cannot be determined: pass `dims` or `group`."
            )
        # `group` may be a parsed grouper (exposing `.dim`) or a bare dimension name.
        sampling_dim = getattr(group, "dim", group)
        dims = [d for d in da.dims if d != sampling_dim]

    corr = _pairwise_spearman(da, dims)

    dists, _, _ = _pairwise_haversine_and_bins(
        corr.cf["longitude"].values, corr.cf["latitude"].values, transpose=True
    )

    dists = xr.DataArray(dists, dims=corr.dims, coords=corr.coords, name="distance")

    trans_dists = xr.DataArray(
        dists.T, dims=corr.dims, coords=corr.coords, name="distance"
    )

    if np.isscalar(bins):
        bin_array = np.linspace(0, radius, bins + 1)
    elif isinstance(bins, np.ndarray):
        bin_array = bins
    else:
        raise ValueError("bins must be a scalar or a numpy array.")

    if uses_dask(corr):
        dists = dists.chunk()
        trans_dists = trans_dists.chunk()

    w = np.diff(bin_array)
    centers = xr.DataArray(
        bin_array[:-1] + w / 2,
        dims=("distance_bins",),
        attrs={
            "units": "km",
            "long_name": f"Centers of the intersite distance bins (width of {w[0]:.3f} km)",
        },
    )
    ds = xr.Dataset({"corr": corr, "distance": dists, "distance2": trans_dists})

    # only keep points inside the radius
    ds = ds.where(ds.distance < radius)
    ds = ds.where(ds.distance2 < radius)

    def _bin_corr(_corr, _distance):
        """Bin and mean."""
        mask_nan = ~np.isnan(_corr)
        binned_corr = stats.binned_statistic(
            _distance[mask_nan], _corr[mask_nan], statistic="mean", bins=bin_array
        )
        return binned_corr.statistic

    # (_spatial, _spatial2) -> (_spatial, distance_bins)
    binned = (
        xr.apply_ufunc(
            _bin_corr,
            ds.corr,
            ds.distance,
            input_core_dims=[["_spatial2"], ["_spatial2"]],
            output_core_dims=[["distance_bins"]],
            dask="parallelized",
            vectorize=True,
            output_dtypes=[float],
            dask_gufunc_kwargs={
                "allow_rechunk": True,
                # one statistic per bin: there is one bin less than there are edges
                "output_sizes": {"distance_bins": len(bin_array) - 1},
            },
        )
        .rename("corr")
        .to_dataset()
    )

    binned = (
        binned.assign_coords(distance_bins=centers)
        .rename(distance_bins="distance")
        .assign_attrs(units="")
    )

    closest = abs(binned.corr - thresh).idxmin(dim="distance")
    binned["decorrelation_length"] = closest

    # get back to 2d lat and lon
    # if 'lat' in dims and 'lon' in dims:
    if len(dims) > 1:
        binned = binned.set_index({"_spatial": dims})
        out = binned.decorrelation_length.unstack()
    else:
        out = binned.swap_dims({"_spatial": dims[0]}).decorrelation_length

    copy_all_attrs(out, da)

    out.attrs["units"] = "km"
    return out


decorrelation_length = StatisticalProperty(
    identifier="decorrelation_length",
    aspect="spatial",
    compute=_decorrelation_length,
    allowed_groups=["group"],
)


def first_eof():
    """EOF Statistical Property (function removed).

    Warnings
    --------
    Due to a licensing issue, eofs-based functionality has been permanently removed.
    Please excuse the inconvenience.
    For more information, see: https://github.com/Ouranosinc/xclim/issues/1620
    """
    raise RuntimeError(
        "Due to a licensing issue, eofs-based functionality has been permanently removed. "
        "Please excuse the inconvenience. "
        "For more information, see: https://github.com/Ouranosinc/xclim/issues/1620"
    )
+ + For use by external parses to determine what kind of data the indicator expects. + On the creation of an indicator, the appropriate constant is stored in + :py:attr:`xclim.core.indicator.Indicator.parameters`. The integer value is what gets stored in the output + of :py:meth:`xclim.core.indicator.Indicator.json`. + + For developers : for each constant, the docstring specifies the annotation a parameter of an indice function + should use in order to be picked up by the indicator constructor. Notice that we are using the annotation format + as described in `PEP 604 `_, i.e. with '|' indicating a union and without import + objects from `typing`. + """ + + VARIABLE = 0 + """A data variable (DataArray or variable name). + + Annotation : ``xr.DataArray``. + """ + OPTIONAL_VARIABLE = 1 + """An optional data variable (DataArray or variable name). + + Annotation : ``xr.DataArray | None``. The default should be None. + """ + QUANTIFIED = 2 + """A quantity with units, either as a string (scalar), a pint.Quantity (scalar) or a DataArray (with units set). + + Annotation : ``xclim.core.utils.Quantified`` and an entry in the :py:func:`xclim.core.units.declare_units` + decorator. "Quantified" translates to ``str | xr.DataArray | pint.util.Quantity``. + """ + FREQ_STR = 3 + """A string representing an "offset alias", as defined by pandas. + + See the Pandas documentation on :ref:`timeseries.offset_aliases` for a list of valid aliases. + + Annotation : ``str`` + ``freq`` as the parameter name. + """ + NUMBER = 4 + """A number. + + Annotation : ``int``, ``float`` and unions thereof, potentially optional. + """ + STRING = 5 + """A simple string. + + Annotation : ``str`` or ``str | None``. In most cases, this kind of parameter makes sense + with choices indicated in the docstring's version of the annotation with curly braces. + See :ref:`notebooks/extendxclim:Defining new indices`. + """ + DAY_OF_YEAR = 6 + """A date, but without a year, in the MM-DD format. 
+ + Annotation : :py:obj:`xclim.core.utils.DayOfYearStr` (may be optional). + """ + DATE = 7 + """A date in the YYYY-MM-DD format, may include a time. + + Annotation : :py:obj:`xclim.core.utils.DateStr` (may be optional). + """ + NUMBER_SEQUENCE = 8 + """A sequence of numbers + + Annotation : ``Sequence[int]``, ``Sequence[float]`` and unions thereof, may include single ``int`` and ``float``, + may be optional. + """ + BOOL = 9 + """A boolean flag. + + Annotation : ``bool``, may be optional. + """ + DICT = 10 + """A dictionary. + + Annotation : ``dict`` or ``dict | None``, may be optional. + """ + KWARGS = 50 + """A mapping from argument name to value. + + Developers : maps the ``**kwargs``. Please use as little as possible. + """ + DATASET = 70 + """An xarray dataset. + + Developers : as indices only accept DataArrays, this should only be added on the indicator's constructor. + """ + OTHER_PARAMETER = 99 + """An object that fits None of the previous kinds. + + Developers : This is the fallback kind, it will raise an error in xclim's unit tests if used. 
FREQ_UNITS = {
    "D": "d",
    "W": "week",
}
"""
Resampling frequency units for :py:func:`xclim.core.units.infer_sampling_units`.

Mapping from offset base to CF-compliant unit. Only constant-length frequencies are included.
"""


def infer_sampling_units(
    da: xr.DataArray,
    deffreq: str | None = "D",
    dim: str = "time",
) -> tuple[int, str]:
    """Infer a multiplier and the units corresponding to one sampling period.

    Parameters
    ----------
    da : xr.DataArray
        A DataArray from which to take coordinate `dim`.
    deffreq : str, optional
        If no frequency is inferred from `da[dim]`, take this one.
    dim : str
        Dimension from which to infer the frequency.

    Returns
    -------
    int
        The magnitude (number of base periods per period)
    str
        Units as a string, understandable by pint.

    Raises
    ------
    ValueError
        If no frequency can be inferred from `da[dim]` and `deffreq` is None.

    Notes
    -----
    Offset bases listed in ``FREQ_UNITS`` are translated to their unit; any other base
    is returned unchanged as the unit string.
    """
    # Item access always resolves to the coordinate; attribute access could instead pick
    # up a DataArray attribute or method that happens to share the name of `dim`.
    freq = xr.infer_freq(da[dim])
    if freq is None:
        freq = deffreq
    if freq is None:
        raise ValueError(
            f"No sampling frequency could be inferred along '{dim}' and no default was given."
        )

    multi, base, _, _ = parse_offset(freq)
    out = multi, FREQ_UNITS.get(base, base)
    if out == (7, "d"):
        # Special case for weekly frequency. xarray's CFTimeOffsets do not have "W".
        return 1, "week"
    return out
DELTA_ABSOLUTE_TEMP = {
    units.delta_degC: units.kelvin,
    units.delta_degF: units.rankine,
}


def ensure_absolute_temperature(units: str):
    """Convert temperature units to their absolute counterpart, assuming they represented a difference (delta).

    Celsius becomes Kelvin, Fahrenheit becomes Rankine. Does nothing for other units.

    Parameters
    ----------
    units : str
        Units to convert.

    Returns
    -------
    str
        The absolute temperature units if the delta form of `units` is listed in
        ``DELTA_ABSOLUTE_TEMP``, otherwise `units` unchanged.
    """
    quantity = str2pint(units)
    # Subtracting a zero-valued quantity of the same units yields the delta form of the units.
    delta = quantity - 0 * quantity
    absolute = DELTA_ABSOLUTE_TEMP.get(delta.units)
    if absolute is None:
        return units
    return pint2str(absolute)


def ensure_delta(unit: str) -> str:
    """Return delta units for temperature.

    For dimensions where delta exist in pint (Temperature), it replaces the temperature unit by delta_degC or
    delta_degF based on the input unit. For other dimensionality, it just gives back the input units.

    Parameters
    ----------
    unit : str
        unit to transform in delta (or not)

    Returns
    -------
    str
        The delta units.
    """
    parsed = units2pint(unit)
    one = 1 * parsed
    # Default: the units of a difference between two quantities of these units.
    delta_unit = pint2str(one - one)
    # replace kelvin/rankine by delta_degC/F
    replacements = (
        ("kelvin", "K", "delta_degC"),
        ("degree_Rankine", "°R", "delta_degF"),
    )
    for marker, absolute, delta in replacements:
        if marker in parsed._units:
            delta_unit = pint2str(parsed / units2pint(absolute) * units2pint(delta))
    return delta_unit
Got {type(arg)}" ) elif isinstance(arg, xr.DataArray): ustr = None if "units" not in arg.attrs else arg.attrs["units"] elif isinstance(arg, pint.Unit | units.Unit): - ustr = f"{arg:cf}" # XC: from pint2cfunits + ustr = pint2str(arg) # XC: from pint2str elif isinstance(arg, str): - ustr = str2pint(arg).units + ustr = pint2str(str2pint(arg).units) else: # (scalar case) ustr = None return ustr if ustr is None else pint.Quantity(1, ustr).units @@ -230,75 +335,6 @@ def convert_units_to( # noqa: C901 return out -def _fill_args_dict(args, kwargs, args_to_check, func): - """Combine args and kwargs into a dict.""" - args_dict = {} - signature = inspect.signature(func) - for ik, (k, v) in enumerate(signature.parameters.items()): - if ik < len(args): - value = args[ik] - if ik >= len(args): - value = v.default if k not in kwargs else kwargs[k] - args_dict[k] = value - return args_dict - - -def _split_args_kwargs(args, func): - """Assign Keyword only arguments to kwargs.""" - kwargs = {} - signature = inspect.signature(func) - indices_to_pop = [] - for ik, (k, v) in enumerate(signature.parameters.items()): - if v.kind == inspect.Parameter.KEYWORD_ONLY: - indices_to_pop.append(ik) - kwargs[k] = v - indices_to_pop.sort(reverse=True) - for ind in indices_to_pop: - args.pop(ind) - return args, kwargs - - -# TODO: make it work with Dataset for real -# TODO: add a switch to prevent string from being converted to float? 
def to_agg_units(
    out: xr.DataArray, orig: xr.DataArray, op: str, dim: str = "time"
) -> xr.DataArray:
    """Set and convert units of an array after an aggregation operation along the sampling dimension (time).

    Parameters
    ----------
    out : xr.DataArray
        The output array of the aggregation operation, no units operation done yet.
    orig : xr.DataArray
        The original array before the aggregation operation,
        used to infer the sampling units and get the variable units.
    op : {'min', 'max', 'mean', 'std', 'var', 'doymin', 'doymax', 'count', 'integral', 'sum'}
        The type of aggregation operation performed. "integral" is mathematically equivalent to "sum",
        but the units are multiplied by the timestep of the data (requires an inferrable frequency).
    dim : str
        The time dimension along which the aggregation was performed.

    Returns
    -------
    xr.DataArray
        The aggregated array with its ``units`` attribute set. For "count" and "integral", a new
        array scaled by the sampling multiplier is returned; otherwise `out` is updated in place.

    Raises
    ------
    ValueError
        If `op` is not a known aggregation operation.

    Examples
    --------
    Take a daily array of temperature and count number of days above a threshold.
    `to_agg_units` will infer the units from the sampling rate along "time", so
    we ensure the final units are correct:

    >>> time = xr.cftime_range("2001-01-01", freq="D", periods=365)
    >>> tas = xr.DataArray(
    ...     np.arange(365),
    ...     dims=("time",),
    ...     coords={"time": time},
    ...     attrs={"units": "degC"},
    ... )
    >>> cond = tas > 100  # Which days are boiling
    >>> Ndays = cond.sum("time")  # Number of boiling days
    >>> Ndays.attrs.get("units")
    None
    >>> Ndays = to_agg_units(Ndays, tas, op="count")
    >>> Ndays.units
    'd'

    Similarly, here we compute the total heating degree-days, but we have weekly data:

    >>> time = xr.cftime_range("2001-01-01", freq="7D", periods=52)
    >>> tas = xr.DataArray(
    ...     np.arange(52) + 10,
    ...     dims=("time",),
    ...     coords={"time": time},
    ... )
    >>> dt = (tas - 16).assign_attrs(units="delta_degC")
    >>> degdays = dt.clip(0).sum("time")  # Integral of temperature above a threshold
    >>> degdays = to_agg_units(degdays, dt, op="integral")
    >>> degdays.units
    'K week'

    Which we can always convert to the more common "K days":

    >>> degdays = convert_units_to(degdays, "K days")
    >>> degdays.units
    'K d'
    """
    if op in ["amin", "min", "amax", "max", "mean", "sum"]:
        out.attrs["units"] = orig.attrs["units"]

    elif op in ["std"]:
        out.attrs["units"] = ensure_absolute_temperature(orig.attrs["units"])

    elif op in ["var"]:
        out.attrs["units"] = pint2str(
            str2pint(ensure_absolute_temperature(orig.units)) ** 2
        )

    elif op in ["doymin", "doymax"]:
        # Imported here so this branch does not depend on the module-level imports.
        from .calendar import get_calendar

        out.attrs.update(
            units="", is_dayofyear=np.int32(1), calendar=get_calendar(orig)
        )

    elif op in ["count", "integral"]:
        # Forward `dim` so the sampling is inferred along the aggregated dimension,
        # whatever its name.
        m, freq_u_raw = infer_sampling_units(orig[dim], dim=dim)
        orig_u = str2pint(ensure_absolute_temperature(orig.units))
        freq_u = str2pint(freq_u_raw)
        out = out * m

        if op == "count":
            out.attrs["units"] = freq_u_raw
        else:  # op == "integral"
            if "[time]" in orig_u.dimensionality:
                # We need to simplify units after multiplication
                out_units = (orig_u * freq_u).to_reduced_units()
                out = out * out_units.magnitude
                out.attrs["units"] = pint2str(out_units)
            else:
                out.attrs["units"] = pint2str(orig_u * freq_u)
    else:
        raise ValueError(
            f"Unknown aggregation op {op}. "
            "Known ops are [min, max, mean, std, var, doymin, doymax, count, integral, sum]."
        )

    return out
code-block:: python + + os.chdir(path.parent) + import example as mod1 + + os.chdir(previous_working_dir) + mod2 = load_module(path) + mod1 == mod2 + """ + path = Path(path) + spec = importlib.util.spec_from_file_location(name or path.stem, path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) # This executes code, effectively loading the module + return mod diff --git a/src/xsdba/xclim_submodules/generic.py b/src/xsdba/xclim_submodules/generic.py new file mode 100644 index 0000000..91697d3 --- /dev/null +++ b/src/xsdba/xclim_submodules/generic.py @@ -0,0 +1,941 @@ +""" +Generic Indices Submodule +========================= + +Helper functions for common generic actions done in the computation of indices. +""" + +from __future__ import annotations + +import warnings +from collections.abc import Sequence +from typing import Callable + +import cftime +import numpy as np +import xarray +import xarray as xr +from xarray.coding.cftime_offsets import _MONTH_ABBREVIATIONS + +from xsdba.calendar import doy_to_days_since, get_calendar, select_time +from xsdba.typing import DayOfYearStr, Quantified, Quantity +from xsdba.units import ( + convert_units_to, + harmonize_units, + pint2str, + str2pint, + to_agg_units, +) + +from . 
import run_length as rl + +__all__ = [ + "aggregate_between_dates", + "binary_ops", + "compare", + "count_level_crossings", + "count_occurrences", + "cumulative_difference", + "default_freq", + "detrend", + "diurnal_temperature_range", + "domain_count", + "doymax", + "doymin", + "extreme_temperature_range", + "first_day_threshold_reached", + "first_occurrence", + "get_daily_events", + "get_op", + "get_zones", + "interday_diurnal_temperature_range", + "last_occurrence", + "select_resample_op", + "spell_length", + "statistics", + "temperature_sum", + "threshold_count", + "thresholded_statistics", +] + +binary_ops = {">": "gt", "<": "lt", ">=": "ge", "<=": "le", "==": "eq", "!=": "ne"} + + +def select_resample_op( + da: xr.DataArray, op: str, freq: str = "YS", out_units=None, **indexer +) -> xr.DataArray: + """Apply operation over each period that is part of the index selection. + + Parameters + ---------- + da : xr.DataArray + Input data. + op : str {'min', 'max', 'mean', 'std', 'var', 'count', 'sum', 'integral', 'argmax', 'argmin'} or func + Reduce operation. Can either be a DataArray method or a function that can be applied to a DataArray. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + out_units : str, optional + Output units to assign. Only necessary if `op` is a function not supported by :py:func:`xclim.core.units.to_agg_units`. + indexer : {dim: indexer, }, optional + Time attribute and values over which to subset the array. For example, use season='DJF' to select winter values, + month=1 to select January, or month=[6,7,8] to select summer months. If no indexer is given, all values are + considered. + + Returns + ------- + xr.DataArray + The result of applying `op` over each period. 
+ """ + da = select_time(da, **indexer) + r = da.resample(time=freq) + if op in _xclim_ops: + op = _xclim_ops[op] + if isinstance(op, str): + out = getattr(r, op.replace("integral", "sum"))(dim="time", keep_attrs=True) + else: + with xr.set_options(keep_attrs=True): + out = r.map(op) + op = op.__name__ + if out_units is not None: + return out.assign_attrs(units=out_units) + return to_agg_units(out, da, op) + + +def select_rolling_resample_op( + da: xr.DataArray, + op: str, + window: int, + window_center: bool = True, + window_op: str = "mean", + freq: str = "YS", + out_units=None, + **indexer, +) -> xr.DataArray: + """Apply operation over each period that is part of the index selection, using a rolling window before the operation. + + Parameters + ---------- + da : xr.DataArray + Input data. + op : str {'min', 'max', 'mean', 'std', 'var', 'count', 'sum', 'integral', 'argmax', 'argmin'} or func + Reduce operation. Can either be a DataArray method or a function that can be applied to a DataArray. + window : int + Size of the rolling window. + window_center : bool + If True, the window is centered on the date. If False, the window is right-aligned. + window_op : str {'min', 'max', 'mean', 'std', 'var', 'count', 'sum', 'integral'} + Operation to apply to the rolling window. Default: 'mean'. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. Applied after the rolling window. + out_units : str, optional + Output units to assign. Only necessary if `op` is a function not supported by :py:func:`xclim.core.units.to_agg_units`. + indexer : {dim: indexer, }, optional + Time attribute and values over which to subset the array. For example, use season='DJF' to select winter values, + month=1 to select January, or month=[6,7,8] to select summer months. If no indexer is given, all values are + considered. + + Returns + ------- + xr.DataArray + The array for which the operation has been applied over each period. 
+ """ + rolled = getattr( + da.rolling(time=window, center=window_center), + window_op.replace("integral", "sum"), + )() + rolled = to_agg_units(rolled, da, window_op) + return select_resample_op(rolled, op=op, freq=freq, out_units=out_units, **indexer) + + +def doymax(da: xr.DataArray) -> xr.DataArray: + """Return the day of year of the maximum value.""" + i = da.argmax(dim="time") + out = da.time.dt.dayofyear.isel(time=i, drop=True) + return to_agg_units(out, da, "doymax") + + +def doymin(da: xr.DataArray) -> xr.DataArray: + """Return the day of year of the minimum value.""" + i = da.argmin(dim="time") + out = da.time.dt.dayofyear.isel(time=i, drop=True) + return to_agg_units(out, da, "doymin") + + +_xclim_ops = {"doymin": doymin, "doymax": doymax} + + +def default_freq(**indexer) -> str: + """Return the default frequency.""" + freq = "YS-JAN" + if indexer: + group, value = indexer.popitem() + if group == "season": + month = 12 # The "season" scheme is based on YS-DEC + elif group == "month": + month = np.take(value, 0) + elif group == "doy_bounds": + month = cftime.num2date(value[0] - 1, "days since 2004-01-01").month + elif group == "date_bounds": + month = int(value[0][:2]) + else: + raise ValueError(f"Unknown group `{group}`.") + freq = "YS-" + _MONTH_ABBREVIATIONS[month] + return freq + + +def get_op(op: str, constrain: Sequence[str] | None = None) -> Callable: + """Get python's comparing function according to its name of representation and validate allowed usage. + + Accepted op string are keys and values of xclim.indices.generic.binary_ops. + + Parameters + ---------- + op : str + Operator. + constrain : sequence of str, optional + A tuple of allowed operators. 
+ """ + if op == "gteq": + warnings.warn(f"`{op}` is being renamed `ge` for compatibility.") + op = "ge" + if op == "lteq": + warnings.warn(f"`{op}` is being renamed `le` for compatibility.") + op = "le" + + if op in binary_ops.keys(): + binary_op = binary_ops[op] + elif op in binary_ops.values(): + binary_op = op + else: + raise ValueError(f"Operation `{op}` not recognized.") + + constraints = list() + if isinstance(constrain, (list, tuple, set)): + constraints.extend([binary_ops[c] for c in constrain]) + constraints.extend(constrain) + elif isinstance(constrain, str): + constraints.extend([binary_ops[constrain], constrain]) + + if constrain: + if op not in constraints: + raise ValueError(f"Operation `{op}` not permitted for indice.") + + return xr.core.ops.get_op(binary_op) + + +def compare( + left: xr.DataArray, + op: str, + right: float | int | np.ndarray | xr.DataArray, + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + """Compare a DataArray to a threshold using given operator. + + Parameters + ---------- + left : xr.DataArray + A DataArray being evaluated against `right`. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + right : float, int, np.ndarray, or xr.DataArray + A value or array-like being evaluated against `left`. + constrain : sequence of str, optional + Optionally allowed conditions. + + Returns + ------- + xr.DataArray + Boolean mask of the comparison. + """ + return get_op(op, constrain)(left, right) + + +def threshold_count( + da: xr.DataArray, + op: str, + threshold: float | int | xr.DataArray, + freq: str, + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + """Count number of days where value is above or below threshold. + + Parameters + ---------- + da : xr.DataArray + Input data. + op : {">", "<", ">=", "<=", "gt", "lt", "ge", "le"} + Logical operator. e.g. arr > thresh. + threshold : float, int or xr.DataArray + Threshold value. 
+ freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + constrain : sequence of str, optional + Optionally allowed conditions. + + Returns + ------- + xr.DataArray + The number of days meeting the constraints for each period. + """ + if constrain is None: + constrain = (">", "<", ">=", "<=") + + c = compare(da, op, threshold, constrain) * 1 + return c.resample(time=freq).sum(dim="time") + + +def domain_count( + da: xr.DataArray, + low: float | int | xr.DataArray, + high: float | int | xr.DataArray, + freq: str, +) -> xr.DataArray: + """Count number of days where value is within low and high thresholds. + + A value is counted if it is larger than `low`, and smaller or equal to `high`, i.e. in `]low, high]`. + + Parameters + ---------- + da : xr.DataArray + Input data. + low : scalar or DataArray + Minimum threshold value. + high : scalar or DataArray + Maximum threshold value. + freq : str + Resampling frequency defining the periods defined in :ref:`timeseries.resampling`. + + Returns + ------- + xr.DataArray + The number of days where value is within ]low, high] for each period. + """ + c = compare(da, ">", low) * compare(da, "<=", high) * 1 + return c.resample(time=freq).sum(dim="time") + + +def get_daily_events( + da: xr.DataArray, + threshold: float | int | xr.DataArray, + op: str, + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + """Return a 0/1 mask when a condition is True or False. + + Parameters + ---------- + da : xr.DataArray + Input data. + threshold : float, int or xr.DataArray + Threshold value. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + constrain : sequence of str, optional + Optionally allowed conditions. 
+ + Notes + ----- + The function returns: + + - ``1`` where operator(da, da_value) is ``True`` + - ``0`` where operator(da, da_value) is ``False`` + - ``nan`` where da is ``nan`` + + Returns + ------- + xr.DataArray + """ + events = compare(da, op, threshold, constrain) * 1 + events = events.where(~(np.isnan(da))) + events = events.rename("events") + return events + + +# CF-INDEX-META Indices + + +@harmonize_units(["low_data", "high_data", "threshold"]) +def count_level_crossings( + low_data: xr.DataArray, + high_data: xr.DataArray, + threshold: Quantified, + freq: str, + *, + op_low: str = "<", + op_high: str = ">=", +) -> xr.DataArray: + """Calculate the number of times low_data is below threshold while high_data is above threshold. + + First, the threshold is transformed to the same standard_name and units as the input data, + then the thresholding is performed, and finally, the number of occurrences is counted. + + Parameters + ---------- + low_data : xr.DataArray + Variable that must be under the threshold. + high_data : xr.DataArray + Variable that must be above the threshold. + threshold : Quantified + Threshold. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + op_low : {"<", "<=", "lt", "le"} + Comparison operator for low_data. Default: "<". + op_high : {">", ">=", "gt", "ge"} + Comparison operator for high_data. Default: ">=". + + Returns + ------- + xr.DataArray + """ + # Convert units to low_data + lower = compare(low_data, op_low, threshold, constrain=("<", "<=")) + higher = compare(high_data, op_high, threshold, constrain=(">", ">=")) + + out = (lower & higher).resample(time=freq).sum() + return to_agg_units(out, low_data, "count", dim="time") + + +@harmonize_units(["data", "threshold"]) +def count_occurrences( + data: xr.DataArray, + threshold: Quantified, + freq: str, + op: str, + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + """Calculate the number of times some condition is met. 
+ + First, the threshold is transformed to the same standard_name and units as the input data. + Then the thresholding is performed as condition(data, threshold), + i.e. if condition is `<`, then this counts the number of times `data < threshold`. + Finally, count the number of occurrences when condition is met. + + Parameters + ---------- + data : xr.DataArray + An array. + threshold : Quantified + Threshold. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + constrain : sequence of str, optional + Optionally allowed conditions. + + Returns + ------- + xr.DataArray + """ + cond = compare(data, op, threshold, constrain) + + out = cond.resample(time=freq).sum() + return to_agg_units(out, data, "count", dim="time") + + +@harmonize_units(["data", "threshold"]) +def first_occurrence( + data: xr.DataArray, + threshold: Quantified, + freq: str, + op: str, + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + """Calculate the first time some condition is met. + + First, the threshold is transformed to the same standard_name and units as the input data. + Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold. + Finally, locate the first occurrence when condition is met. + + Parameters + ---------- + data : xr.DataArray + Input data. + threshold : Quantified + Threshold. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + constrain : sequence of str, optional + Optionally allowed conditions. 
+ + Returns + ------- + xr.DataArray + """ + cond = compare(data, op, threshold, constrain) + + out = cond.resample(time=freq).map( + rl.first_run, + window=1, + dim="time", + coord="dayofyear", + ) + out.attrs["units"] = "" + return out + + +@harmonize_units(["data", "threshold"]) +def last_occurrence( + data: xr.DataArray, + threshold: Quantified, + freq: str, + op: str, + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + """Calculate the last time some condition is met. + + First, the threshold is transformed to the same standard_name and units as the input data. + Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold. + Finally, locate the last occurrence when condition is met. + + Parameters + ---------- + data : xr.DataArray + Input data. + threshold : Quantified + Threshold. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + constrain : sequence of str, optional + Optionally allowed conditions. + + Returns + ------- + xr.DataArray + """ + cond = compare(data, op, threshold, constrain) + + out = cond.resample(time=freq).map( + rl.last_run, + window=1, + dim="time", + coord="dayofyear", + ) + out.attrs["units"] = "" + return out + + +@harmonize_units(["data", "threshold"]) +def spell_length( + data: xr.DataArray, threshold: Quantified, reducer: str, freq: str, op: str +) -> xr.DataArray: + """Calculate statistics on lengths of spells. + + First, the threshold is transformed to the same standard_name and units as the input data. + Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold. + Then the spells are determined, and finally the statistics according to the specified reducer are calculated. + + Parameters + ---------- + data : xr.DataArray + Input data. 
+ threshold : Quantified + Threshold. + reducer : {'max', 'min', 'mean', 'sum'} + Reducer. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + + Returns + ------- + xr.DataArray + """ + cond = compare(data, op, threshold) + + out = cond.resample(time=freq).map( + rl.rle_statistics, + reducer=reducer, + window=1, + dim="time", + ) + return to_agg_units(out, data, "count") + + +def statistics(data: xr.DataArray, reducer: str, freq: str) -> xr.DataArray: + """Calculate a simple statistic of the data. + + Parameters + ---------- + data : xr.DataArray + Input data. + reducer : {'max', 'min', 'mean', 'sum'} + Reducer. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + + Returns + ------- + xr.DataArray + """ + out = getattr(data.resample(time=freq), reducer)() + out.attrs["units"] = data.attrs["units"] + return out + + +@harmonize_units(["data", "threshold"]) +def thresholded_statistics( + data: xr.DataArray, + op: str, + threshold: Quantified, + reducer: str, + freq: str, + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + """Calculate a simple statistic of the data for which some condition is met. + + First, the threshold is transformed to the same standard_name and units as the input data. + Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold. + Finally, the statistic is calculated for those data values that fulfill the condition. + + Parameters + ---------- + data : xr.DataArray + Input data. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + threshold : Quantified + Threshold. + reducer : {'max', 'min', 'mean', 'sum'} + Reducer. 
+ freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + constrain : sequence of str, optional + Optionally allowed conditions. Default: None. + + Returns + ------- + xr.DataArray + """ + cond = compare(data, op, threshold, constrain) + + out = getattr(data.where(cond).resample(time=freq), reducer)() + out.attrs["units"] = data.attrs["units"] + return out + + +def aggregate_between_dates( + data: xr.DataArray, + start: xr.DataArray | DayOfYearStr, + end: xr.DataArray | DayOfYearStr, + op: str = "sum", + freq: str | None = None, +) -> xr.DataArray: + """Aggregate the data over a period between start and end dates and apply the operator on the aggregated data. + + Parameters + ---------- + data : xr.DataArray + Data to aggregate between start and end dates. + start : xr.DataArray or DayOfYearStr + Start dates (as day-of-year) for the aggregation periods. + end : xr.DataArray or DayOfYearStr + End (as day-of-year) dates for the aggregation periods. + op : {'min', 'max', 'sum', 'mean', 'std'} + Operator. + freq : str, optional + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. Default: `None`. + + Returns + ------- + xr.DataArray, [dimensionless] + Aggregated data between the start and end dates. If the end date is before the start date, returns np.nan. + If there is no start and/or end date, returns np.nan. + """ + + def _get_days(_bound, _group, _base_time): + """Get bound in number of days since base_time. 
Bound can be a days_since array or a DayOfYearStr.""" + if isinstance(_bound, str): + b_i = rl.index_of_date(_group.time, _bound, max_idxs=1) + if not b_i.size > 0: + return None + return (_group.time.isel(time=b_i[0]) - _group.time.isel(time=0)).dt.days + if _base_time in _bound.time: + return _bound.sel(time=_base_time) + return None + + if freq is None: + frequencies = [] + for bound in [start, end]: + try: + frequencies.append(xr.infer_freq(bound.time)) + except AttributeError: + frequencies.append(None) + + good_freq = set(frequencies) - {None} + + if len(good_freq) != 1: + raise ValueError( + f"Non-inferrable resampling frequency or inconsistent frequencies. Got start, end = {frequencies}." + " Please consider providing `freq` manually." + ) + freq = good_freq.pop() + + cal = get_calendar(data, dim="time") + + if not isinstance(start, str): + start = start.convert_calendar(cal) + start.attrs["calendar"] = cal + start = doy_to_days_since(start) + if not isinstance(end, str): + end = end.convert_calendar(cal) + end.attrs["calendar"] = cal + end = doy_to_days_since(end) + + out = [] + for base_time, indexes in data.resample(time=freq).groups.items(): + # get group slice + group = data.isel(time=indexes) + + start_d = _get_days(start, group, base_time) + end_d = _get_days(end, group, base_time) + + # convert bounds for this group + if start_d is not None and end_d is not None: + days = (group.time - base_time).dt.days + days[days < 0] = np.nan + + masked = group.where((days >= start_d) & (days <= end_d - 1)) + res = getattr(masked, op)(dim="time", skipna=True) + res = xr.where( + ((start_d > end_d) | (start_d.isnull()) | (end_d.isnull())), np.nan, res + ) + # Re-add the time dimension with the period's base time. + res = res.expand_dims(time=[base_time]) + out.append(res) + else: + # Get an array with the good shape, put nans and add the new time. 
+ res = (group.isel(time=0) * np.nan).expand_dims(time=[base_time]) + out.append(res) + continue + + return xr.concat(out, dim="time") + + +@harmonize_units(["data", "threshold"]) +def cumulative_difference( + data: xr.DataArray, threshold: Quantified, op: str, freq: str | None = None +) -> xr.DataArray: + """Calculate the cumulative difference below/above a given value threshold. + + Parameters + ---------- + data : xr.DataArray + Data for which to determine the cumulative difference. + threshold : Quantified + The value threshold. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le"} + Logical operator. e.g. arr > thresh. + freq : str, optional + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + If `None`, no resampling is performed. Default: `None`. + + Returns + ------- + xr.DataArray + """ + if op in ["<", "<=", "lt", "le"]: + diff = (threshold - data).clip(0) + elif op in [">", ">=", "gt", "ge"]: + diff = (data - threshold).clip(0) + else: + raise NotImplementedError(f"Condition not supported: '{op}'.") + + if freq is not None: + diff = diff.resample(time=freq).sum(dim="time") + + return to_agg_units(diff, data, op="integral") + + +@harmonize_units(["data", "threshold"]) +def first_day_threshold_reached( + data: xr.DataArray, + *, + threshold: Quantified, + op: str, + after_date: DayOfYearStr, + window: int = 1, + freq: str = "YS", + constrain: Sequence[str] | None = None, +) -> xr.DataArray: + r"""First day of values exceeding threshold. + + Returns first day of period where values reach or exceed a threshold over a given number of days, + limited to a starting calendar date. + + Parameters + ---------- + data : xarray.DataArray + Dataset being evaluated. + threshold : str + Threshold on which to base evaluation. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. e.g. arr > thresh. + after_date : str + Date of the year after which to look for the first event. 
Should have the format '%m-%d'. + window : int + Minimum number of days with values above threshold needed for evaluation. Default: 1. + freq : str + Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. + Default: "YS". + constrain : sequence of str, optional + Optionally allowed conditions. + + Returns + ------- + xarray.DataArray, [dimensionless] + Day of the year when value reaches or exceeds a threshold over a given number of days for the first time. + If there is no such day, returns np.nan. + """ + cond = compare(data, op, threshold, constrain=constrain) + + out: xarray.DataArray = cond.resample(time=freq).map( + rl.first_run_after_date, + window=window, + date=after_date, + dim="time", + coord="dayofyear", + ) + out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(data)) + return out + + +def _get_zone_bins( + zone_min: Quantity, + zone_max: Quantity, + zone_step: Quantity, +): + """Bin boundary values as defined by zone parameters. + + Parameters + ---------- + zone_min : Quantity + Left boundary of the first zone + zone_max : Quantity + Right boundary of the last zone + zone_step: Quantity + Size of zones + + Returns + ------- + xarray.DataArray, [units of `zone_step`] + Array of values corresponding to each zone: [zone_min, zone_min+step, ..., zone_max] + """ + units = pint2str(str2pint(zone_step)) + mn, mx, step = ( + convert_units_to(str2pint(z), units) for z in [zone_min, zone_max, zone_step] + ) + bins = np.arange(mn, mx + step, step) + if (mx - mn) % step != 0: + warnings.warn( + "`zone_max` - `zone_min` is not an integer multiple of `zone_step`. Last zone will be smaller." 
+ ) + bins[-1] = mx + return xr.DataArray(bins, attrs={"units": units}) + + +def get_zones( + da: xr.DataArray, + zone_min: Quantity | None = None, + zone_max: Quantity | None = None, + zone_step: Quantity | None = None, + bins: xr.DataArray | list[Quantity] | None = None, + exclude_boundary_zones: bool = True, + close_last_zone_right_boundary: bool = True, +) -> xr.DataArray: + r"""Divide data into zones and attribute a zone coordinate to each input value. + + Divide values into zones corresponding to bins of width zone_step beginning at zone_min and ending at zone_max. + Bins are inclusive on the left values and exclusive on the right values. + + Parameters + ---------- + da : xarray.DataArray + Input data + zone_min : Quantity | None + Left boundary of the first zone + zone_max : Quantity | None + Right boundary of the last zone + zone_step: Quantity | None + Size of zones + bins : xr.DataArray | list[Quantity] | None + Zones to be used, either as a DataArray with appropriate units or a list of Quantity + exclude_boundary_zones : Bool + Determines whether a zone value is attributed for values in ]`-np.inf`, `zone_min`[ and [`zone_max`, `np.inf`\ [. + close_last_zone_right_boundary : Bool + Determines if the right boundary of the last zone is closed. + + Returns + ------- + xarray.DataArray, [dimensionless] + Zone index for each value in `da`. Zones are returned as an integer range, starting from `0` + """ + # Check compatibility of arguments + zone_params = np.array([zone_min, zone_max, zone_step]) + if bins is None: + if (zone_params == [None] * len(zone_params)).any(): + raise ValueError( + "`bins` is `None` as well as some or all of [`zone_min`, `zone_max`, `zone_step`]. " + "Expected defined parameters in one of these cases." + ) + elif set(zone_params) != {None}: + warnings.warn( + "Expected either `bins` or [`zone_min`, `zone_max`, `zone_step`], got both. " + "`bins` will be used." 
+ ) + + # Get zone bins (if necessary) + bins = bins if bins is not None else _get_zone_bins(zone_min, zone_max, zone_step) + if isinstance(bins, list): + bins = sorted([convert_units_to(b, da) for b in bins]) + else: + bins = convert_units_to(bins, da) + + def _get_zone(_da): + return np.digitize(_da, bins) - 1 + + zones = xr.apply_ufunc(_get_zone, da, dask="parallelized") + + if close_last_zone_right_boundary: + zones = zones.where(da != bins[-1], _get_zone(bins[-2])) + if exclude_boundary_zones: + zones = zones.where( + (zones != _get_zone(bins[0] - 1)) & (zones != _get_zone(bins[-1])) + ) + + return zones + + +def detrend( + ds: xr.DataArray | xr.Dataset, dim="time", deg=1 +) -> xr.DataArray | xr.Dataset: + """Detrend data along a given dimension computing a polynomial trend of a given order. + + Parameters + ---------- + ds : xr.Dataset or xr.DataArray + The data to detrend. If a Dataset, detrending is done on all data variables. + dim : str + Dimension along which to compute the trend. + deg : int + Degree of the polynomial to fit. + + Returns + ------- + xr.Dataset or xr.DataArray + Same as `ds`, but with its trend removed (subtracted). + """ + if isinstance(ds, xr.Dataset): + return ds.map(detrend, keep_attrs=False, dim=dim, deg=deg) + # is a DataArray + # detrend along a single dimension + coeff = ds.polyfit(dim=dim, deg=deg) + trend = xr.polyval(ds[dim], coeff.polyfit_coefficients) + with xr.set_options(keep_attrs=True): + return ds - trend diff --git a/src/xsdba/xclim_submodules/run_length.py b/src/xsdba/xclim_submodules/run_length.py new file mode 100644 index 0000000..e84704c --- /dev/null +++ b/src/xsdba/xclim_submodules/run_length.py @@ -0,0 +1,1538 @@ +""" +Run-Length Algorithms Submodule +=============================== + +Computation of statistics on runs of True values in boolean arrays. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from datetime import datetime +from warnings import warn + +import numpy as np +import xarray as xr +from numba import njit +from xarray.core.utils import get_temp_dimname + +from xsdba.base import uses_dask +from xsdba.options import OPTIONS, RUN_LENGTH_UFUNC +from xsdba.typing import DateStr, DayOfYearStr + +npts_opt = 9000 +""" +Arrays with less than this number of data points per slice will trigger +the use of the ufunc version of run lengths algorithms. +""" +# XC: all copied from xc + + +def use_ufunc( + ufunc_1dim: bool | str, + da: xr.DataArray, + dim: str = "time", + freq: str | None = None, + index: str = "first", +) -> bool: + """Return whether the ufunc version of run length algorithms should be used with this DataArray or not. + + If ufunc_1dim is 'from_context', the parameter is read from xclim's global (or context) options. + If it is 'auto', this returns False for dask-backed arrays and for arrays with more than :py:const:`npts_opt` + points per slice along `dim`. + + Parameters + ---------- + ufunc_1dim : {'from_context', 'auto', True, False} + The method for handling the ufunc parameters. + da : xr.DataArray + Input array. + dim : str + The dimension along which to find runs. + freq : str + Resampling frequency. + index : {'first', 'last'} + If 'first' (default), the run length is indexed with the first element in the run. + If 'last', with the last element in the run. + + Returns + ------- + bool + If ufunc_1dim is "auto", returns True only if the array is not dask-backed and is small enough (see :py:const:`npts_opt`). + Otherwise, returns ufunc_1dim. In all cases, returns False if `index` is not "first" or if `freq` is given. 
+ """ + if ufunc_1dim is True and freq is not None: + raise ValueError( + "Resampling after run length operations is not implemented for 1d method" + ) + + if ufunc_1dim == "from_context": + ufunc_1dim = OPTIONS[RUN_LENGTH_UFUNC] + + if ufunc_1dim == "auto": + ufunc_1dim = not uses_dask(da) and (da.size // da[dim].size) < npts_opt + # If resampling after run length is set up for the computation, the 1d method is not implemented + # Unless ufunc_1dim is specifically set to True (in which case we flag an error above), + # we simply forbid this possibility. + return (index == "first") and (ufunc_1dim) and (freq is None) + + +def resample_and_rl( + da: xr.DataArray, + resample_before_rl: bool, + compute, + *args, + freq: str, + dim: str = "time", + **kwargs, +) -> xr.DataArray: + """Wrap run length algorithms to control if resampling occurs before or after the algorithms. + + Parameters + ---------- + da: xr.DataArray + N-dimensional array (boolean). + resample_before_rl : bool + Determines whether the input array of runs `da` should be separated into periods before + or after the run length algorithms are applied. + compute + Run length function to apply + args + Positional arguments needed in `compute`. + dim: str + The dimension along which to find runs. + freq : str + Resampling frequency. + kwargs + Keyword arguments needed in `compute`. + + Returns + ------- + xr.DataArray + Output of compute resampled according to frequency {freq}. + """ + if resample_before_rl: + out = da.resample({dim: freq}).map( + compute, args=args, freq=None, dim=dim, **kwargs + ) + else: + out = compute(da, *args, dim=dim, freq=freq, **kwargs) + return out + + +def _cumsum_reset_on_zero( + da: xr.DataArray, + dim: str = "time", + index: str = "last", +) -> xr.DataArray: + """Compute the cumulative sum for each series of numbers separated by zero. + + Parameters + ---------- + da : xr.DataArray + Input array. + dim : str + Dimension name along which the cumulative sum is taken. 
def rle(
    da: xr.DataArray,
    dim: str = "time",
    index: str = "first",
) -> xr.DataArray:
    """Compute the length of each run of ones, stored at one end of the run.

    Parameters
    ----------
    da : xr.DataArray
        Input array. It is cast to integers before processing.
    dim : str
        Dimension name.
    index : {'first', 'last'}
        If 'first' (default), the run length is indexed with the first element in the run.
        If 'last', with the last element in the run.

    Returns
    -------
    xr.DataArray
        Values are 0 where `da` is False (out of runs), the run length at the indexed
        element of each run and NaN at the other elements of the run.
    """
    reverse = {dim: slice(None, None, -1)}
    # The algorithm natively indexes runs by their last element; for "first",
    # it is applied on the reversed array and the result is reversed back.
    flip = index == "first"

    ones = da.astype(int)
    if flip:
        ones = ones[reverse]

    # Cumulative sum restarted at each zero, e.g. 100110111 -> 100120123
    run_sums = _cumsum_reset_on_zero(ones, dim)

    # Keep only the total of each run, i.e. values followed by a zero (or by the end
    # of the array), e.g. 100120123 -> 100N20NN3
    followed_by_zero = ones.shift({dim: -1}, fill_value=0) == 0
    run_totals = run_sums.where(followed_by_zero)

    # Put zeros back at the positions that are outside of runs
    out = run_totals.where(ones == 1, 0)

    # e.g. 100N20NN3 -> 3NN02N001, the output for 111011001 with index == "first"
    return out[reverse] if flip else out
def longest_run(
    da: xr.DataArray,
    dim: str = "time",
    freq: str | None = None,
    ufunc_1dim: str | bool = "from_context",
    index: str = "first",
) -> xr.DataArray:
    """Return the length of the longest consecutive run of True values.

    This is `rle_statistics` with the "max" reducer and a minimal run length of 1.

    Parameters
    ----------
    da : xr.DataArray
        N-dimensional array (boolean).
    dim : str
        Dimension along which to calculate consecutive run; Default: 'time'.
    freq : str, optional
        Resampling frequency.
    ufunc_1dim : Union[str, bool]
        Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
        usage based on number of data points. Using 1D_ufunc=True is typically more efficient
        for DataArray with a small number of grid points.
        It can be modified globally through the "run_length_ufunc" global option.
    index : {'first', 'last'}
        If 'first', the run length is indexed with the first element in the run.
        If 'last', with the last element in the run.

    Returns
    -------
    xr.DataArray, [int]
        Length of the longest run of True values along dimension (int).
    """
    options = {"dim": dim, "freq": freq, "ufunc_1dim": ufunc_1dim, "index": index}
    return rle_statistics(da, reducer="max", window=1, **options)
def windowed_run_count(
    da: xr.DataArray,
    window: int,
    dim: str = "time",
    freq: str | None = None,
    ufunc_1dim: str | bool = "from_context",
    index: str = "first",
) -> xr.DataArray:
    """Count the True values that belong to runs of at least `window` consecutive elements.

    Parameters
    ----------
    da : xr.DataArray
        Input N-dimensional DataArray (boolean).
    window : int
        Minimum run length.
        When equal to 1 and no resampling is requested, the result is a plain sum along `dim`.
    dim : str
        Dimension along which to calculate consecutive run (default: 'time').
    freq : str, optional
        Resampling frequency.
    ufunc_1dim : Union[str, bool]
        Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
        usage based on number of data points. Using 1D_ufunc=True is typically more efficient
        for DataArray with a small number of grid points.
        It can be modified globally through the "run_length_ufunc" global option.
    index : {'first', 'last'}
        If 'first', the run length is indexed with the first element in the run.
        If 'last', with the last element in the run.

    Returns
    -------
    xr.DataArray, [int]
        Total number of `True` values part of a consecutive runs of at least `window` long.
    """
    if use_ufunc(ufunc_1dim, da, dim=dim, index=index, freq=freq):
        return windowed_run_count_ufunc(da, window, dim)

    if window == 1 and freq is None:
        # Every True value is a run of sufficient length: simply count them.
        return da.sum(dim=dim)

    run_lengths = rle(da, dim=dim, index=index)
    # Discard the runs that are too short, then add up the lengths of the others.
    long_enough = run_lengths.where(run_lengths >= window, 0)
    if freq is not None:
        long_enough = long_enough.resample({dim: freq})
    return long_enough.sum(dim=dim)
+ ufunc_1dim : Union[str, bool] + Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal + usage based on number of data points. Using 1D_ufunc=True is typically more efficient + for DataArray with a small number of grid points. + Ignored when `window=1`. It can be modified globally through the "run_length_ufunc" global option. + position : {"first", "last"} + Determines if the algorithm finds the "first" or "last" run + + Returns + ------- + xr.DataArray + Index (or coordinate if `coord` is not False) of first item in first (last) valid run. + Returns np.nan if there are no valid runs. + """ + + def coord_transform(out, da): + """Transforms indexes to coordinates if needed, and drops obsolete dim.""" + if coord: + crd = da[dim] + if isinstance(coord, str): + crd = getattr(crd.dt, coord) + + out = lazy_indexing(crd, out) + + if dim in out.coords: + out = out.drop_vars(dim) + return out + + # general method to get indices (or coords) of first run + def find_boundary_run(runs, position): + if position == "last": + runs = runs[{dim: slice(None, None, -1)}] + dmax_ind = runs.argmax(dim=dim) + # If there are no runs, dmax_ind will be 0: We must replace this with NaN + out = dmax_ind.where(dmax_ind != runs.argmin(dim=dim)) + if position == "last": + out = runs[dim].size - out - 1 + runs = runs[{dim: slice(None, None, -1)}] + out = coord_transform(out, runs) + return out + + ufunc_1dim = use_ufunc(ufunc_1dim, da, dim=dim, freq=freq) + + da = da.fillna(0) # We expect a boolean array, but there could be NaNs nonetheless + if window == 1: + if freq is not None: + out = da.resample({dim: freq}).map(find_boundary_run, position=position) + else: + out = find_boundary_run(da, position) + + elif ufunc_1dim: + if position == "last": + da = da[{dim: slice(None, None, -1)}] + out = first_run_ufunc(x=da, window=window, dim=dim) + if position == "last" and not coord: + out = da[dim].size - out - 1 + da = da[{dim: slice(None, None, -1)}] + out = 
def first_run(
    da: xr.DataArray,
    window: int,
    dim: str = "time",
    freq: str | None = None,
    coord: str | bool | None = False,
    ufunc_1dim: str | bool = "from_context",
) -> xr.DataArray:
    """Return the index of the first item of the first run of at least a given length.

    Delegates to `_boundary_run` with ``position="first"``.

    Parameters
    ----------
    da : xr.DataArray
        Input N-dimensional DataArray (boolean).
    window : int
        Minimum duration of consecutive run to accumulate values.
        When equal to 1, an optimized version of the algorithm is used.
    dim : str
        Dimension along which to calculate consecutive run (default: 'time').
    freq : str, optional
        Resampling frequency.
    coord : Optional[str]
        If not False, the function returns values along `dim` instead of indexes.
        If `dim` has a datetime dtype, `coord` can also be a str of the name of the
        DateTimeAccessor object to use (ex: 'dayofyear').
    ufunc_1dim : Union[str, bool]
        Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
        usage based on number of data points. Using 1D_ufunc=True is typically more efficient
        for DataArray with a small number of grid points.
        Ignored when `window=1`. It can be modified globally through the "run_length_ufunc" global option.

    Returns
    -------
    xr.DataArray
        Index (or coordinate if `coord` is not False) of first item in first valid run.
        Returns np.nan if there are no valid runs.
    """
    options = {
        "window": window,
        "dim": dim,
        "freq": freq,
        "coord": coord,
        "ufunc_1dim": ufunc_1dim,
    }
    return _boundary_run(da, position="first", **options)
+ dim : str + Dimension along which to look for runs. + coord : bool or str + If `True`, return values of the coordinate, if a string, returns values from `dim.dt.`. + If `False`, return indexes. + + Returns + ------- + xr.DataArray + With ``dim`` reduced to "events" and "bounds". The events dim is as long as needed, padded with NaN or NaT. + """ + if uses_dask(mask): + raise NotImplementedError( + "Dask arrays not supported as we can't know the final event number before computing." + ) + + diff = xr.concat( + (mask.isel({dim: [0]}).astype(int), mask.astype(int).diff(dim)), dim + ) + + nstarts = (diff == 1).sum(dim).max().item() + + def _get_indices(arr, *, N): + out = np.full((N,), np.nan, dtype=float) + inds = np.where(arr)[0] + out[: len(inds)] = inds + return out + + starts = xr.apply_ufunc( + _get_indices, + diff == 1, + input_core_dims=[[dim]], + output_core_dims=[["events"]], + kwargs={"N": nstarts}, + vectorize=True, + ) + + ends = xr.apply_ufunc( + _get_indices, + diff == -1, + input_core_dims=[[dim]], + output_core_dims=[["events"]], + kwargs={"N": nstarts}, + vectorize=True, + ) + + if coord: + crd = mask[dim] + if isinstance(coord, str): + crd = getattr(crd.dt, coord) + + starts = lazy_indexing(crd, starts) + ends = lazy_indexing(crd, ends) + return xr.concat((starts, ends), "bounds") + + +def keep_longest_run( + da: xr.DataArray, dim: str = "time", freq: str | None = None +) -> xr.DataArray: + """Keep the longest run along a dimension. + + Parameters + ---------- + da : xr.DataArray + Boolean array. + dim : str + Dimension along which to check for the longest run. + freq : str + Resampling frequency. + + Returns + ------- + xr.DataArray, [bool] + Boolean array similar to da but with only one run, the (first) longest. 
def extract_events(
    da_start: xr.DataArray,
    window_start: int,
    da_stop: xr.DataArray,
    window_stop: int,
    dim: str = "time",
) -> xr.DataArray:
    """Extract events, i.e. runs whose starting and stopping points are defined through run length conditions.

    Parameters
    ----------
    da_start : xr.DataArray
        Input array where run sequences are searched to define the start points in the main runs.
    window_start : int
        Number of True (1) values needed to start a run in `da_start`.
    da_stop : xr.DataArray
        Input array where run sequences are searched to define the stop points in the main runs.
    window_stop : int
        Number of True (1) values needed to start a run in `da_stop`.
    dim : str
        Dimension name.

    Returns
    -------
    xr.DataArray
        Output array with 1's when in a run sequence and with 0's elsewhere.

    Notes
    -----
    A season (as defined in ``season``) could be considered as an event with `window_stop == window_start`
    and `da_stop == 1 - da_start`, although it has more constraints on when to start and stop a run
    through the `date` argument.
    """
    # Missing values count as "condition not met". They must be filled *before* the
    # integer cast: casting NaN to int yields an arbitrary integer that `fillna`
    # can no longer detect.
    da_start = da_start.fillna(0).astype(int)
    da_stop = da_stop.fillna(0).astype(int)

    start_runs = _cumsum_reset_on_zero(da_start, dim=dim, index="first")
    stop_runs = _cumsum_reset_on_zero(da_stop, dim=dim, index="first")
    # `np.nan` (not the `np.NaN` alias, which was removed in NumPy 2.0) marks the
    # positions that are neither a start nor a stop.
    start_positions = xr.where(start_runs >= window_start, 1, np.nan)
    stop_positions = xr.where(stop_runs >= window_stop, 0, np.nan)

    # start positions (1) are f-filled until a stop position (0) is met;
    # where both are defined, the stop position takes precedence.
    runs = stop_positions.combine_first(start_positions).ffill(dim=dim).fillna(0)

    return runs
This means that + even if the "real" season has been over for a long time, this is the date used in the length calculation. + Example : Length of the "warm season", where T > 25°C, with date = 1st August. Let's say the temperature is over + 25 for all June, but July and august have very cold temperatures. Instead of returning 30 days (June), the function + will return 61 days (July + June). + """ + beg = first_run(da, window=window, dim=dim) + # Invert the condition and mask all values after beginning + # we fillna(0) as so to differentiate series with no runs and all-nan series + not_da = (~da).where(da[dim].copy(data=np.arange(da[dim].size)) >= beg.fillna(0)) + + # Mask also values after "date" + mid_idx = index_of_date(da[dim], date, max_idxs=1, default=0) + if mid_idx.size == 0: + # The date is not within the group. Happens at boundaries. + base = da.isel({dim: 0}) # To have the proper shape + beg = xr.full_like(base, np.nan, float).drop_vars(dim) + end = xr.full_like(base, np.nan, float).drop_vars(dim) + length = xr.full_like(base, np.nan, float).drop_vars(dim) + else: + if date is not None: + # If the beginning was after the mid date, both bounds are NaT. + valid_start = beg < mid_idx.squeeze() + else: + valid_start = True + + not_da = not_da.where(da[dim] >= da[dim][mid_idx][0]) + end = first_run( + not_da, + window=window, + dim=dim, + ) + # If there was a beginning but no end, season goes to the end of the array + no_end = beg.notnull() & end.isnull() + + # Length + length = end - beg + + # No end: length is actually until the end of the array, so it is missing 1 + length = xr.where(no_end, da[dim].size - beg, length) + # Where the beginning was before the mid-date, invalid. 
def season_length(
    da: xr.DataArray,
    window: int,
    date: DayOfYearStr | None = None,
    dim: str = "time",
) -> xr.DataArray:
    """Return the length of the longest semi-consecutive run of True values (optionally including a given date).

    A "season" is a run of True values that may include breaks under a given length (`window`).
    The start is computed as the first run of `window` True values, then end as the first subsequent run
    of `window` False values. If a date is passed, it must be included in the season.

    This is the ``length`` variable of the dataset returned by `season`, computed with ``coord=False``.

    Parameters
    ----------
    da : xr.DataArray
        Input N-dimensional DataArray (boolean).
    window : int
        Minimum duration of consecutive values to start and end the season.
    date : DayOfYearStr, optional
        The date (in MM-DD format) that a run must include to be considered valid.
    dim : str
        Dimension along which to calculate consecutive run (default: 'time').

    Returns
    -------
    xr.DataArray, [int]
        Length of the longest run of True values along a given dimension (inclusive of a given date)
        without breaks longer than a given length.

    Notes
    -----
    The run can include holes of False or NaN values, so long as they do not exceed the window size.

    If a date is given, the season start and end are forced to be on each side of this date. This means that
    even if the "real" season has been over for a long time, this is the date used in the length calculation.
    Example : Length of the "warm season", where T > 25°C, with date = 1st August. Let's say the temperature is over
    25 for all June, but July and August have very cold temperatures. Instead of returning 30 days (June), the function
    will return 61 days (June + July).
    """
    return season(da, window, date, dim, coord=False).length
def first_run_after_date(
    da: xr.DataArray,
    window: int,
    date: DayOfYearStr | None = "07-01",
    dim: str = "time",
    coord: bool | str | None = "dayofyear",
) -> xr.DataArray:
    """Return the index of the first item of the first run after a given date.

    Parameters
    ----------
    da : xr.DataArray
        Input N-dimensional DataArray (boolean).
    window : int
        Minimum duration of consecutive run to accumulate values.
    date : DayOfYearStr
        The date after which to look for the run.
    dim : str
        Dimension along which to calculate consecutive run (default: 'time').
    coord : Optional[Union[bool, str]]
        If not False, the function returns values along `dim` instead of indexes.
        If `dim` has a datetime dtype, `coord` can also be a str of the name of the
        DateTimeAccessor object to use (ex: 'dayofyear').

    Returns
    -------
    xr.DataArray
        Index (or coordinate if `coord` is not False) of first item in the first valid run.
        Returns np.nan if there are no valid runs, or if `date` is not found along `dim`.
    """
    axis = da[dim]
    mid_idx = index_of_date(axis, date, max_idxs=1, default=0)

    if mid_idx.size == 0:
        # The date is not within the group. Happens at boundaries.
        template = da.isel({dim: 0})
        return xr.full_like(template, np.nan, float).drop_vars(dim)

    # Mask everything before the date, so that only later runs can be found.
    on_or_after_date = da[dim] >= da[dim][mid_idx][0]
    return first_run(da.where(on_or_after_date), window=window, dim=dim, coord=coord)
def statistics_run_1d(arr: Sequence[bool], reducer: str, window: int) -> int:
    """Return statistics on lengths of run of identical values.

    Parameters
    ----------
    arr : Sequence[bool]
        Input array (bool).
    reducer : {'mean', 'sum', 'min', 'max', 'std'}
        Reducing function name. The NaN-aware numpy function ``np.nan<reducer>`` is used.
    window : int
        Minimal length of runs to be included in the statistics.

    Returns
    -------
    int
        Statistics on length of runs. 0 if there is no run of True values of at least `window` elements.
    """
    v, rl = rle_1d(arr)[:2]
    # `v * rl` is the run length for runs of True values and 0 for runs of False values.
    if not np.any(v) or np.all(v * rl < window):
        return 0
    func = getattr(np, f"nan{reducer}")
    # Runs that are too short (or made of False values) are masked with NaN so that
    # the NaN-aware reducer ignores them. `np.nan` is used instead of the `np.NaN`
    # alias, which was removed in NumPy 2.0.
    return func(np.where(v * rl >= window, rl, np.nan))
def windowed_run_events_ufunc(
    x: xr.DataArray | Sequence[bool], window: int, dim: str
) -> xr.DataArray:
    """Dask-parallel version of windowed_run_events_1d, ie: the number of runs at least as long as given duration.

    Parameters
    ----------
    x : Sequence[bool]
        Input array (bool).
    window : int
        Minimum run length.
    dim : str
        Dimension along which to calculate windowed run.

    Returns
    -------
    xr.DataArray
        Result of applying `windowed_run_events_1d` to each 1D series along `dim`,
        which is consumed by the operation.
    """
    ufunc_options = {
        # `dim` is the core dimension handed to the 1D function.
        "input_core_dims": [[dim]],
        "kwargs": {"window": window},
        "vectorize": True,
        "dask": "parallelized",
        "output_dtypes": [int],
        "keep_attrs": True,
    }
    return xr.apply_ufunc(windowed_run_events_1d, x, **ufunc_options)
+ + Parameters + ---------- + x : Union[xr.DataArray, Sequence[bool]] + Input array (bool). + window : int + Minimum run length. + dim : str + The dimension along which the runs are found. + + Returns + ------- + xr.DataArray + A function operating along the time dimension of a dask-array. + """ + ind = xr.apply_ufunc( + first_run_1d, + x, + input_core_dims=[[dim]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + keep_attrs=True, + kwargs={"window": window}, + ) + + return ind + + +def lazy_indexing( + da: xr.DataArray, index: xr.DataArray, dim: str | None = None +) -> xr.DataArray: + """Get values of `da` at indices `index` in a NaN-aware and lazy manner. + + Parameters + ---------- + da : xr.DataArray + Input array. If not 1D, `dim` must be given and must not appear in index. + index : xr.DataArray + N-d integer indices, if da is not 1D, all dimensions of index must be in da + dim : str, optional + Dimension along which to index, unused if `da` is 1D, should not be present in `index`. + + Returns + ------- + xr.DataArray + Values of `da` at indices `index`. + """ + if da.ndim == 1: + # Case where da is 1D and index is N-D + # Slightly better performance using map_blocks, over an apply_ufunc + def _index_from_1d_array(indices, array): + return array[indices] + + idx_ndim = index.ndim + if idx_ndim == 0: + # The 0-D index case, we add a dummy dimension to help dask + dim = get_temp_dimname(da.dims, "x") + index = index.expand_dims(dim) + # Which indexes to mask. + invalid = index.isnull() + # NaN-indexing doesn't work, so fill with 0 and cast to int + index = index.fillna(0).astype(int) + + # No need for coords, we extract by integer index. 
+ # Renaming with no name to fix bug in xr 2024.01.0 + tmpname = get_temp_dimname(da.dims, "temp") + da2 = xr.DataArray(da.data, dims=(tmpname,), name=None) + # for each chunk of index, take corresponding values from da + out = index.map_blocks(_index_from_1d_array, args=(da2,)).rename(da.name) + # mask where index was NaN. Drop any auxiliary coord, they are already on `out`. + # Chunked aux coord would have the same name on both sides and xarray will want to check if they are equal, which means loading them + # making lazy_indexing not lazy. + out = out.where( + ~invalid.drop_vars( + [crd for crd in invalid.coords if crd not in invalid.dims] + ) + ) + if idx_ndim == 0: + # 0-D case, drop useless coords and dummy dim + out = out.drop_vars(da.dims[0], errors="ignore").squeeze() + return out.drop_vars(dim or da.dims[0], errors="ignore") + + # Case where index.dims is a subset of da.dims. + if dim is None: + diff_dims = set(da.dims) - set(index.dims) + if len(diff_dims) == 0: + raise ValueError( + "da must have at least one dimension more than index for lazy_indexing." + ) + if len(diff_dims) > 1: + raise ValueError( + "If da has more than one dimension more than index, the indexing dim must be given through `dim`" + ) + dim = diff_dims.pop() + + def _index_from_nd_array(array, indices): + return np.take_along_axis(array, indices[..., np.newaxis], axis=-1)[..., 0] + + return xr.apply_ufunc( + _index_from_nd_array, + da, + index.astype(int), + input_core_dims=[[dim], []], + output_core_dims=[[]], + dask="parallelized", + output_dtypes=[da.dtype], + ) + + +def index_of_date( + time: xr.DataArray, + date: DateStr | DayOfYearStr | None, + max_idxs: int | None = None, + default: int = 0, +) -> np.ndarray: + """Get the index of a date in a time array. + + Parameters + ---------- + time : xr.DataArray + An array of datetime values, any calendar. + date : DayOfYearStr or DateStr, optional + A string in the "yyyy-mm-dd" or "mm-dd" format. + If None, returns default. 
+ max_idxs : int, optional + Maximum number of returned indexes. + default : int + Index to return if date is None. + + Raises + ------ + ValueError + If there are more instances of `date` in `time` than `max_idxs`. + + Returns + ------- + numpy.ndarray + 1D array of integers, indexes of `date` in `time`. + """ + if date is None: + return np.array([default]) + try: + date = datetime.strptime(date, "%Y-%m-%d") + year_cond = time.dt.year == date.year + except ValueError: + date = datetime.strptime(date, "%m-%d") + year_cond = True + + idxs = np.where( + year_cond & (time.dt.month == date.month) & (time.dt.day == date.day) + )[0] + if max_idxs is not None and idxs.size > max_idxs: + raise ValueError( + f"More than {max_idxs} instance of date {date} found in the coordinate array." + ) + return idxs + + +def suspicious_run_1d( + arr: np.ndarray, + window: int = 10, + op: str = ">", + thresh: float | None = None, +) -> np.ndarray: + """Return True where the array contains a run of identical values. + + Parameters + ---------- + arr : numpy.ndarray + Array of values to be parsed. + window : int + Minimum run length. + op : {">", ">=", "==", "<", "<=", "eq", "gt", "lt", "gteq", "lteq", "ge", "le"} + Operator for threshold comparison. Defaults to ">". + thresh : float, optional + Threshold compared against which values are checked for identical values. + + Returns + ------- + numpy.ndarray + Whether or not the data points are part of a run of identical values.
+ """ + v, rl, pos = rle_1d(arr) + sus_runs = rl >= window + if thresh is not None: + if op in {">", "gt"}: + sus_runs = sus_runs & (v > thresh) + elif op in {"<", "lt"}: + sus_runs = sus_runs & (v < thresh) + elif op in {"==", "eq"}: + sus_runs = sus_runs & (v == thresh) + elif op in {"!=", "ne"}: + sus_runs = sus_runs & (v != thresh) + elif op in {">=", "gteq", "ge"}: + sus_runs = sus_runs & (v >= thresh) + elif op in {"<=", "lteq", "le"}: + sus_runs = sus_runs & (v <= thresh) + else: + raise NotImplementedError(f"{op}") + + out = np.zeros_like(arr, dtype=bool) + for st, l in zip(pos[sus_runs], rl[sus_runs]): # noqa: E741 + out[st : st + l] = True # noqa: E741 + return out + + +def suspicious_run( + arr: xr.DataArray, + dim: str = "time", + window: int = 10, + op: str = ">", + thresh: float | None = None, +) -> xr.DataArray: + """Return True where the array contains has runs of identical values, vectorized version. + + In opposition to other run length functions, here the output has the same shape as the input. + + Parameters + ---------- + arr : xr.DataArray + Array of values to be parsed. + dim : str + Dimension along which to check for runs (default: "time"). + window : int + Minimum run length. + op : {">", ">=", "==", "<", "<=", "eq", "gt", "lt", "gteq", "lteq"} + Operator for threshold comparison, defaults to ">". + thresh : float, optional + Threshold above which values are checked for identical values. + + Returns + ------- + xarray.DataArray + """ + return xr.apply_ufunc( + suspicious_run_1d, + arr, + input_core_dims=[[dim]], + output_core_dims=[[dim]], + vectorize=True, + dask="parallelized", + output_dtypes=[bool], + keep_attrs=True, + kwargs=dict(window=window, op=op, thresh=thresh), + ) diff --git a/src/xsdba/xclim_submodules/stats.py b/src/xsdba/xclim_submodules/stats.py new file mode 100644 index 0000000..4a26152 --- /dev/null +++ b/src/xsdba/xclim_submodules/stats.py @@ -0,0 +1,622 @@ +"""Statistic-related functions. 
See the `frequency_analysis` notebook for examples.""" + +from __future__ import annotations + +import json +import warnings +from collections.abc import Sequence +from typing import Any + +import numpy as np +import scipy.stats +import xarray as xr + +from xsdba.base import uses_dask +from xsdba.formatting import prefix_attrs, unprefix_attrs, update_history +from xsdba.typing import DateStr, Quantified +from xsdba.units import convert_units_to + +from . import generic + +__all__ = [ + "_fit_start", + "dist_method", + "fa", + "fit", + "frequency_analysis", + "get_dist", + "parametric_cdf", + "parametric_quantile", +] + + +# Fit the parameters. +# This would also be the place to impose constraints on the series minimum length if needed. +def _fitfunc_1d(arr, *, dist, nparams, method, **fitkwargs): + """Fit distribution parameters.""" + x = np.ma.masked_invalid(arr).compressed() # pylint: disable=no-member + + # Return NaNs if array is empty. + if len(x) <= 1: + return np.asarray([np.nan] * nparams) + + # Estimate parameters + if method in ["ML", "MLE"]: + args, kwargs = _fit_start(x, dist.name, **fitkwargs) + params = dist.fit(x, *args, method="mle", **kwargs, **fitkwargs) + elif method == "MM": + params = dist.fit(x, method="mm", **fitkwargs) + elif method == "PWM": + params = list(dist.lmom_fit(x).values()) + elif method == "APP": + args, kwargs = _fit_start(x, dist.name, **fitkwargs) + kwargs.setdefault("loc", 0) + params = list(args) + [kwargs["loc"], kwargs["scale"]] + else: + raise NotImplementedError(f"Unknown method `{method}`.") + + params = np.asarray(params) + + # Fill with NaNs if one of the parameters is NaN + if np.isnan(params).any(): + params[:] = np.nan + + return params + + +def fit( + da: xr.DataArray, + dist: str | scipy.stats.rv_continuous = "norm", + method: str = "ML", + dim: str = "time", + **fitkwargs: Any, +) -> xr.DataArray: + r"""Fit an array to a univariate distribution along the time dimension. 
+ + Parameters + ---------- + da : xr.DataArray + Time series to be fitted along the time dimension. + dist : str or rv_continuous distribution object + Name of the univariate distribution, such as beta, expon, genextreme, gamma, gumbel_r, lognorm, norm + (see :py:mod:`scipy.stats` for full list) or the distribution object itself. + method : {"ML" or "MLE", "MM", "PWM", "APP"} + Fitting method, either maximum likelihood (ML or MLE), method of moments (MM) or approximate method (APP). + Can also be the probability weighted moments (PWM), also called L-Moments, if a compatible `dist` object is passed. + The PWM method is usually more robust to outliers. + dim : str + The dimension upon which to perform the indexing (default: "time"). + \*\*fitkwargs + Other arguments passed directly to :py:func:`_fit_start` and to the distribution's `fit`. + + Returns + ------- + xr.DataArray + An array of fitted distribution parameters. + + Notes + ----- + Coordinates for which all values are NaNs will be dropped before fitting the distribution. If the array still + contains NaNs, the distribution parameters will be returned as NaNs. + """ + method = method.upper() + method_name = { + "ML": "maximum likelihood", + "MM": "method of moments", + "MLE": "maximum likelihood", + "PWM": "probability weighted moments", + "APP": "approximative method", + } + if method not in method_name: + raise ValueError(f"Fitting method not recognized: {method}") + + # Get the distribution + dist = get_dist(dist) + + if method == "PWM" and not hasattr(dist, "lmom_fit"): + raise ValueError( + f"The given distribution {dist} does not implement the PWM fitting method. Please pass an instance from the lmoments3 package."
+ ) + + shape_params = [] if dist.shapes is None else dist.shapes.split(",") + dist_params = shape_params + ["loc", "scale"] + + data = xr.apply_ufunc( + _fitfunc_1d, + da, + input_core_dims=[[dim]], + output_core_dims=[["dparams"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + keep_attrs=True, + kwargs=dict( + # Don't know how APP should be included, this works for now + dist=dist, + nparams=len(dist_params), + method=method, + **fitkwargs, + ), + dask_gufunc_kwargs={"output_sizes": {"dparams": len(dist_params)}}, + ) + + # Add coordinates for the distribution parameters and transpose to original shape (with dim -> dparams) + dims = [d if d != dim else "dparams" for d in da.dims] + out = data.assign_coords(dparams=dist_params).transpose(*dims) + + out.attrs = prefix_attrs( + da.attrs, ["standard_name", "long_name", "units", "description"], "original_" + ) + attrs = dict( + long_name=f"{dist.name} parameters", + description=f"Parameters of the {dist.name} distribution", + method=method, + estimator=method_name[method].capitalize(), + scipy_dist=dist.name, + units="", + history=update_history( + f"Estimate distribution parameters by {method_name[method]} method along dimension {dim}.", + new_name="fit", + data=da, + ), + ) + out.attrs.update(attrs) + return out + + +def parametric_quantile( + p: xr.DataArray, + q: float | Sequence[float], + dist: str | scipy.stats.rv_continuous | None = None, +) -> xr.DataArray: + """Return the value corresponding to the given distribution parameters and quantile. + + Parameters + ---------- + p : xr.DataArray + Distribution parameters returned by the `fit` function. + The array should have dimension `dparams` storing the distribution parameters, + and attribute `scipy_dist`, storing the name of the distribution. + q : float or Sequence of float + Quantile to compute, which must be between `0` and `1`, inclusive. 
+ dist: str, rv_continuous instance, optional + The distribution name or instance if the `scipy_dist` attribute is not available on `p`. + + Returns + ------- + xarray.DataArray + An array of parametric quantiles estimated from the distribution parameters. + + Notes + ----- + When all quantiles are above 0.5, the `isf` method is used instead of `ppf` because accuracy is sometimes better. + """ + q = np.atleast_1d(q) + + dist = get_dist(dist or p.attrs["scipy_dist"]) + + # Create a lambda function to facilitate passing arguments to dask. There is probably a better way to do this. + if np.all(q > 0.5): + + def func(x): + return dist.isf(1 - q, *x) + + else: + + def func(x): + return dist.ppf(q, *x) + + data = xr.apply_ufunc( + func, + p, + input_core_dims=[["dparams"]], + output_core_dims=[["quantile"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + keep_attrs=True, + dask_gufunc_kwargs={"output_sizes": {"quantile": len(q)}}, + ) + + # Assign quantile coordinates and transpose to preserve original dimension order + dims = [d if d != "dparams" else "quantile" for d in p.dims] + out = data.assign_coords(quantile=q).transpose(*dims) + out.attrs = unprefix_attrs(p.attrs, ["units", "standard_name"], "original_") + + attrs = dict( + long_name=f"{dist.name} quantiles", + description=f"Quantiles estimated by the {dist.name} distribution", + cell_methods="dparams: ppf", + history=update_history( + "Compute parametric quantiles from distribution parameters", + new_name="parametric_quantile", + parameters=p, + ), + ) + out.attrs.update(attrs) + return out + + +def parametric_cdf( + p: xr.DataArray, + v: float | Sequence[float], + dist: str | scipy.stats.rv_continuous | None = None, +) -> xr.DataArray: + """Return the cumulative distribution function corresponding to the given distribution parameters and value. + + Parameters + ---------- + p : xr.DataArray + Distribution parameters returned by the `fit` function. 
+ The array should have dimension `dparams` storing the distribution parameters, + and attribute `scipy_dist`, storing the name of the distribution. + v : float or Sequence of float + Value to compute the CDF. + dist: str, rv_continuous instance, optional + The distribution name or instance if the `scipy_dist` attribute is not available on `p`. + + Returns + ------- + xarray.DataArray + An array of parametric CDF values estimated from the distribution parameters. + """ + v = np.atleast_1d(v) + + dist = get_dist(dist or p.attrs["scipy_dist"]) + + # Create a lambda function to facilitate passing arguments to dask. There is probably a better way to do this. + def func(x): + return dist.cdf(v, *x) + + data = xr.apply_ufunc( + func, + p, + input_core_dims=[["dparams"]], + output_core_dims=[["cdf"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + keep_attrs=True, + dask_gufunc_kwargs={"output_sizes": {"cdf": len(v)}}, + ) + + # Assign cdf coordinates and transpose to preserve original dimension order + dims = [d if d != "dparams" else "cdf" for d in p.dims] + out = data.assign_coords(cdf=v).transpose(*dims) + out.attrs = unprefix_attrs(p.attrs, ["units", "standard_name"], "original_") + + attrs = dict( + long_name=f"{dist.name} cdf", + description=f"CDF estimated by the {dist.name} distribution", + cell_methods="dparams: cdf", + history=update_history( + "Compute parametric cdf from distribution parameters", + new_name="parametric_cdf", + parameters=p, + ), + ) + out.attrs.update(attrs) + return out + + +def fa( + da: xr.DataArray, + t: int | Sequence, + dist: str | scipy.stats.rv_continuous = "norm", + mode: str = "max", + method: str = "ML", +) -> xr.DataArray: + """Return the value corresponding to the given return period. + + Parameters + ---------- + da : xr.DataArray + Maximized/minimized input data with a `time` dimension. + t : int or Sequence of int + Return period. The period depends on the resolution of the input data.
If the input array's resolution is + yearly, then the return period is in years. + dist : str or rv_continuous instance + Name of the univariate distribution, such as: + `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm` + Or the distribution instance itself. + mode : {'min', 'max'} + Whether we are looking for a probability of exceedance (max) or a probability of non-exceedance (min). + method : {"ML", "MLE", "MM", "PWM", "APP"} + Fitting method, either maximum likelihood (ML or MLE), method of moments (MM) or approximate method (APP). + Also accepts probability weighted moments (PWM), also called L-Moments, if `dist` is an instance from the lmoments3 library. + The PWM method is usually more robust to outliers. + + Returns + ------- + xarray.DataArray + An array of values with a 1/t probability of exceedance (if mode=='max'). + + See Also + -------- + scipy.stats : For descriptions of univariate distribution types. + """ + # Fit the parameters of the distribution + p = fit(da, dist, method=method) + t = np.atleast_1d(t) + + if mode in ["max", "high"]: + q = 1 - 1.0 / t + + elif mode in ["min", "low"]: + q = 1.0 / t + + else: + raise ValueError(f"Mode `{mode}` should be either 'max' or 'min'.") + + # Compute the quantiles + out = ( + parametric_quantile(p, q, dist) + .rename({"quantile": "return_period"}) + .assign_coords(return_period=t) + ) + out.attrs["mode"] = mode + return out + + +def frequency_analysis( + da: xr.DataArray, + mode: str, + t: int | Sequence[int], + dist: str | scipy.stats.rv_continuous, + window: int = 1, + freq: str | None = None, + method: str = "ML", + **indexer: int | float | str, +) -> xr.DataArray: + r"""Return the value corresponding to a return period. + + Parameters + ---------- + da : xarray.DataArray + Input data. + mode : {'min', 'max'} + Whether we are looking for a probability of exceedance (high) or a probability of non-exceedance (low). + t : int or sequence + Return period.
The period depends on the resolution of the input data. If the input array's resolution is + yearly, then the return period is in years. + dist : str or rv_continuous + Name of the univariate distribution, e.g. `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm`. + Or an instance of the distribution. + window : int + Averaging window length (days). + freq : str, optional + Resampling frequency. If None, the frequency is assumed to be 'YS' unless the indexer is season='DJF', + in which case `freq` would be set to `YS-DEC`. + method : {"ML" or "MLE", "MM", "PWM", "APP"} + Fitting method, either maximum likelihood (ML or MLE), method of moments (MM) or approximate method (APP). + Also accepts probability weighted moments (PWM), also called L-Moments, if `dist` is an instance from the lmoments3 library. + The PWM method is usually more robust to outliers. + \*\*indexer + Time attribute and values over which to subset the array. For example, use season='DJF' to select winter values, + month=1 to select January, or month=[6,7,8] to select summer months. If indexer is not provided, all values are + considered. + + Returns + ------- + xarray.DataArray + An array of values with a 1/t probability of exceedance or non-exceedance when mode is high or low respectively. + + See Also + -------- + scipy.stats : For descriptions of univariate distribution types.
+ """ + # Apply rolling average + attrs = da.attrs.copy() + if window > 1: + da = da.rolling(time=window).mean(skipna=False) + da.attrs.update(attrs) + + # Assign default resampling frequency if not provided + freq = freq or generic.default_freq(**indexer) + + # Extract the time series of min or max over the period + sel = generic.select_resample_op(da, op=mode, freq=freq, **indexer) + + if uses_dask(sel): + sel = sel.chunk({"time": -1}) + # Frequency analysis + return fa(sel, t, dist=dist, mode=mode, method=method) + + +def get_dist(dist: str | scipy.stats.rv_continuous): + """Return a distribution object from `scipy.stats`.""" + if isinstance(dist, scipy.stats.rv_continuous): + return dist + + dc = getattr(scipy.stats, dist, None) + if dc is None: + e = f"Statistical distribution `{dist}` is not found in scipy.stats." + raise ValueError(e) + return dc + + +def _fit_start(x, dist: str, **fitkwargs: Any) -> tuple[tuple, dict]: + r"""Return initial values for distribution parameters. + + Providing the ML fit method initial values can help the optimizer find the global optimum. + + Parameters + ---------- + x : array-like + Input data. + dist : str + Name of the univariate distribution, e.g. `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm`. + (see :py:mod:scipy.stats). Only `genextreme` and `weibull_exp` distributions are supported. + \*\*fitkwargs + Kwargs passed to fit. + + Returns + ------- + tuple, dict + + References + ---------- + :cite:cts:`coles_introduction_2001,cohen_parameter_2019, thom_1958, cooke_1979, muralidhar_1992` + + """ + x = np.asarray(x) + m = x.mean() + v = x.var() + + if dist == "genextreme": + s = np.sqrt(6 * v) / np.pi + return (0.1,), {"loc": m - 0.57722 * s, "scale": s} + + if dist == "genpareto" and "floc" in fitkwargs: + # Taken from julia' Extremes. Case for when "mu/loc" is known. 
+ t = fitkwargs["floc"] + if not np.isclose(t, 0): + m = (x - t).mean() + v = (x - t).var() + + c = 0.5 * (1 - m**2 / v) + scale = (1 - c) * m + return (c,), {"scale": scale} + + if dist in "weibull_min": + s = x.std() + loc = x.min() - 0.01 * s + chat = np.pi / np.sqrt(6) / (np.log(x - loc)).std() + scale = ((x - loc) ** chat).mean() ** (1 / chat) + return (chat,), {"loc": loc, "scale": scale} + + if dist in ["gamma"]: + if "floc" in fitkwargs: + loc0 = fitkwargs["floc"] + else: + xs = sorted(x) + x1, x2, xn = xs[0], xs[1], xs[-1] + # muralidhar_1992 would suggest the following, but it seems more unstable + # using cooke_1979 for now + # n = len(x) + # cv = x.std() / x.mean() + # p = (0.48265 + 0.32967 * cv) * n ** (-0.2984 * cv) + # xp = xs[int(p/100*n)] + xp = x2 + loc0 = (x1 * xn - xp**2) / (x1 + xn - 2 * xp) + loc0 = loc0 if loc0 < x1 else (0.9999 * x1 if x1 > 0 else 1.0001 * x1) + x_pos = x - loc0 + x_pos = x_pos[x_pos > 0] + m = x_pos.mean() + log_of_mean = np.log(m) + mean_of_logs = np.log(x_pos).mean() + A = log_of_mean - mean_of_logs + a0 = (1 + np.sqrt(1 + 4 * A / 3)) / (4 * A) + scale0 = m / a0 + kwargs = {"scale": scale0, "loc": loc0} + return (a0,), kwargs + + if dist in ["fisk"]: + if "floc" in fitkwargs: + loc0 = fitkwargs["floc"] + else: + xs = sorted(x) + x1, x2, xn = xs[0], xs[1], xs[-1] + loc0 = (x1 * xn - x2**2) / (x1 + xn - 2 * x2) + loc0 = loc0 if loc0 < x1 else (0.9999 * x1 if x1 > 0 else 1.0001 * x1) + x_pos = x - loc0 + x_pos = x_pos[x_pos > 0] + # method of moments: + # LHS is computed analytically with the two-parameters log-logistic distribution + # and depends on alpha,beta + # RHS is from the sample + # = m + # / ^2 = m2/m**2 + # solving these equations yields + m = x_pos.mean() + m2 = (x_pos**2).mean() + scale0 = 2 * m**3 / (m2 + m**2) + c0 = np.pi * m / np.sqrt(3) / np.sqrt(m2 - m**2) + kwargs = {"scale": scale0, "loc": loc0} + return (c0,), kwargs + return (), {} + + +def _dist_method_1D( # noqa: N802 + *args, dist: str | 
scipy.stats.rv_continuous, function: str, **kwargs: Any +) -> xr.DataArray: + r"""Statistical function for given argument on given distribution initialized with params. + + See :py:ref:`scipy:scipy.stats.rv_continuous` for all available functions and their arguments. + Every method where `"*args"` are the distribution parameters can be wrapped. + + Parameters + ---------- + \*args + The arguments for the requested scipy function. + dist : str + The scipy name of the distribution. + function : str + The name of the function to call. + \*\*kwargs + Other parameters to pass to the function call. + + Returns + ------- + array_like + """ + dist = get_dist(dist) + return getattr(dist, function)(*args, **kwargs) + + +def dist_method( + function: str, + fit_params: xr.DataArray, + arg: xr.DataArray | None = None, + dist: str | scipy.stats.rv_continuous | None = None, + **kwargs: Any, +) -> xr.DataArray: + r"""Vectorized statistical function for given argument on given distribution initialized with params. + + Methods where `"*args"` are the distribution parameters can be wrapped, except those that reduce dimensions ( + e.g. `nnlf`) or create new dimensions (eg: 'rvs' with size != 1, 'stats' with more than one moment, 'interval', + 'support'). + + Parameters + ---------- + function : str + The name of the function to call. + fit_params : xr.DataArray + Distribution parameters are along `dparams`, in the same order as given by :py:func:`fit`. + arg : array_like, optional + The first argument for the requested function if different from `fit_params`. + dist : str or rv_continuous, optional + The distribution name or instance. Defaults to the `scipy_dist` attribute of `fit_params`. + \*\*kwargs + Other parameters to pass to the function call. + + Returns + ------- + array_like + Same shape as arg. + + See Also + -------- + scipy:scipy.stats.rv_continuous : for all available functions and their arguments.
+ """ + # Typically the data to be transformed + arg = [arg] if arg is not None else [] + if function == "nnlf": + raise ValueError( + "This method is not supported because it reduces the dimensionality of the data." + ) + + # We don't need to set `input_core_dims` because we're explicitly splitting the parameters here. + args = arg + [fit_params.sel(dparams=dp) for dp in fit_params.dparams.values] + + return xr.apply_ufunc( + _dist_method_1D, + *args, + kwargs={ + "dist": dist or fit_params.attrs["scipy_dist"], + "function": function, + **kwargs, + }, + output_dtypes=[float], + dask="parallelized", + ) diff --git a/tests/test_properties.py b/tests/test_properties.py new file mode 100644 index 0000000..5b693bf --- /dev/null +++ b/tests/test_properties.py @@ -0,0 +1,577 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from xarray import set_options + +from xsdba import properties +from xsdba.units import convert_units_to + + +class TestProperties: + def test_mean(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1980"), location="Vancouver") + .pr + ).load() + + out_year = properties.mean(sim) + np.testing.assert_array_almost_equal(out_year.values, [3.0016028e-05]) + + out_season = properties.mean(sim, group="time.season") + np.testing.assert_array_almost_equal( + out_season.values, + [4.6115547e-05, 1.7220482e-05, 2.8805329e-05, 2.825359e-05], + ) + + assert out_season.long_name.startswith("Mean") + + def test_var(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1980"), location="Vancouver") + .pr + ).load() + + out_year = properties.var(sim) + np.testing.assert_array_almost_equal(out_year.values, [2.5884779e-09]) + + out_season = properties.var(sim, group="time.season") + np.testing.assert_array_almost_equal( + out_season.values, + [3.9270796e-09, 1.2538864e-09, 1.9057025e-09, 2.8776632e-09], 
+ ) + assert out_season.long_name.startswith("Variance") + assert out_season.units == "kg2 m-4 s-2" + + def test_std(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1980"), location="Vancouver") + .pr + ).load() + + out_year = properties.std(sim) + np.testing.assert_array_almost_equal(out_year.values, [5.08770208398345e-05]) + + out_season = properties.std(sim, group="time.season") + np.testing.assert_array_almost_equal( + out_season.values, + [6.2666411e-05, 3.5410259e-05, 4.3654352e-05, 5.3643853e-05], + ) + assert out_season.long_name.startswith("Standard deviation") + assert out_season.units == "kg m-2 s-1" + + def test_skewness(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1980"), location="Vancouver") + .pr + ).load() + + out_year = properties.skewness(sim) + np.testing.assert_array_almost_equal(out_year.values, [2.8497460898513745]) + + out_season = properties.skewness(sim, group="time.season") + np.testing.assert_array_almost_equal( + out_season.values, + [ + 2.036650744163691, + 3.7909534745807147, + 2.416590445325826, + 3.3521301798559566, + ], + ) + assert out_season.long_name.startswith("Skewness") + assert out_season.units == "" + + def test_quantile(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1980"), location="Vancouver") + .pr + ).load() + + out_year = properties.quantile(sim, q=0.2) + np.testing.assert_array_almost_equal(out_year.values, [2.8109431013945154e-07]) + + out_season = properties.quantile(sim, group="time.season", q=0.2) + np.testing.assert_array_almost_equal( + out_season.values, + [ + 1.5171653330980917e-06, + 9.822543773907455e-08, + 1.8135805248675763e-07, + 4.135342521749408e-07, + ], + ) + assert out_season.long_name.startswith("Quantile 0.2") + + # TODO: test theshold_count? 
it's the same a test_spell_length_distribution + def test_spell_length_distribution(self, open_dataset): + ds = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .load() + ) + + # test pr, with amount method + sim = ds.pr + kws = {"op": "<", "group": "time.month", "thresh": "1.157e-05 kg/m/m/s"} + outd = { + stat: properties.spell_length_distribution(da=sim, **kws, stat=stat) + .sel(month=1) + .values + for stat in ["mean", "max", "min"] + } + np.testing.assert_array_almost_equal( + [outd[k] for k in ["mean", "max", "min"]], [2.44127, 10, 1] + ) + + # test tasmax, with quantile method + simt = ds.tasmax + kws = {"thresh": 0.9, "op": ">=", "method": "quantile", "group": "time.month"} + outd = { + stat: properties.spell_length_distribution(da=simt, **kws, stat=stat).sel( + month=6 + ) + for stat in ["mean", "max", "min"] + } + np.testing.assert_array_almost_equal( + [outd[k].values for k in ["mean", "max", "min"]], [3.0, 6, 1] + ) + + # test varia + with pytest.raises( + ValueError, + match="percentile is not a valid method. Choose 'amount' or 'quantile'.", + ): + properties.spell_length_distribution(simt, method="percentile") + + assert ( + outd["mean"].long_name + == "Average of spell length distribution when the variable is >= the quantile 0.9 for 1 consecutive day(s)." 
+ ) + + def test_spell_length_distribution_mixed_stat(self, open_dataset): + + time = pd.date_range("2000-01-01", periods=2 * 365, freq="D") + tas = xr.DataArray( + np.array([0] * 365 + [40] * 365), + dims=("time"), + coords={"time": time}, + attrs={"units": "degC"}, + ) + + kws_sum = dict( + thresh="30 degC", op=">=", stat="sum", stat_resample="sum", group="time" + ) + out_sum = properties.spell_length_distribution(tas, **kws_sum).values + kws_mixed = dict( + thresh="30 degC", op=">=", stat="mean", stat_resample="sum", group="time" + ) + out_mixed = properties.spell_length_distribution(tas, **kws_mixed).values + + assert out_sum == 365 + assert out_mixed == 182.5 + + @pytest.mark.parametrize( + "window,expected_amount,expected_quantile", + [ + (1, [2.333333, 4, 1], [3, 6, 1]), + (3, [1.333333, 4, 0], [2, 6, 0]), + ], + ) + def test_bivariate_spell_length_distribution( + self, open_dataset, window, expected_amount, expected_quantile + ): + ds = ( + open_dataset("sdba/CanESM2_1950-2100.nc").sel( + time=slice("1950", "1952"), location="Vancouver" + ) + ).load() + tx = ds.tasmax + with set_options(keep_attrs=True): + tn = tx - 5 + + # test with amount method + kws = { + "thresh1": "0 degC", + "thresh2": "0 degC", + "op1": ">", + "op2": "<=", + "group": "time.month", + "window": window, + } + outd = { + stat: properties.bivariate_spell_length_distribution( + da1=tx, da2=tn, **kws, stat=stat + ) + .sel(month=1) + .values + for stat in ["mean", "max", "min"] + } + np.testing.assert_array_almost_equal( + [outd[k] for k in ["mean", "max", "min"]], expected_amount + ) + + # test with quantile method + kws = { + "thresh1": 0.9, + "thresh2": 0.9, + "op1": ">", + "op2": ">", + "method1": "quantile", + "method2": "quantile", + "group": "time.month", + "window": window, + } + outd = { + stat: properties.bivariate_spell_length_distribution( + da1=tx, da2=tn, **kws, stat=stat + ) + .sel(month=6) + .values + for stat in ["mean", "max", "min"] + } + 
np.testing.assert_array_almost_equal( + [outd[k] for k in ["mean", "max", "min"]], expected_quantile + ) + + def test_acf(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .pr + ).load() + + out = properties.acf(sim, lag=1, group="time.month").sel(month=1) + np.testing.assert_array_almost_equal(out.values, [0.11242357313756905]) + + # FIXME + # with pytest.raises(ValueError, match="Grouping period year is not allowed for"): + # properties.acf(sim, group="time") + + assert out.long_name.startswith("Lag-1 autocorrelation") + assert out.units == "" + + def test_annual_cycle(self, open_dataset): + simt = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .tasmax + ).load() + + amp = properties.annual_cycle_amplitude(simt) + relamp = properties.relative_annual_cycle_amplitude(simt) + phase = properties.annual_cycle_phase(simt) + + np.testing.assert_allclose( + [amp.values, relamp.values, phase.values], + [16.74645996, 5.802083, 167], + rtol=1e-6, + ) + # FIXME + # with pytest.raises( + # ValueError, + # match="Grouping period season is not allowed for property", + # ): + # properties.annual_cycle_amplitude(simt, group="time.season") + + # with pytest.raises( + # ValueError, + # match="Grouping period month is not allowed for property", + # ): + # properties.annual_cycle_phase(simt, group="time.month") + + assert amp.long_name.startswith("Absolute amplitude of the annual cycle") + assert phase.long_name.startswith("Phase of the annual cycle") + assert amp.units == "delta_degC" + assert relamp.units == "%" + assert phase.units == "" + + def test_annual_range(self, open_dataset): + simt = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .tasmax + ).load() + + # Initial annual cycle was this with window = 1 + amp = properties.mean_annual_range(simt, window=1) + relamp = 
properties.mean_annual_relative_range(simt, window=1) + phase = properties.mean_annual_phase(simt, window=1) + + np.testing.assert_allclose( + [amp.values, relamp.values, phase.values], + [34.039806, 11.793684020675501, 165.33333333333334], + ) + + amp = properties.mean_annual_range(simt) + relamp = properties.mean_annual_relative_range(simt) + phase = properties.mean_annual_phase(simt) + + np.testing.assert_array_almost_equal( + [amp.values, relamp.values, phase.values], + [18.715261, 6.480101, 181.6666667], + ) + # FIXME + # with pytest.raises( + # ValueError, + # match="Grouping period season is not allowed for property", + # ): + # properties.mean_annual_range(simt, group="time.season") + + # with pytest.raises( + # ValueError, + # match="Grouping period month is not allowed for property", + # ): + # properties.mean_annual_phase(simt, group="time.month") + + assert amp.long_name.startswith("Average annual absolute amplitude") + assert phase.long_name.startswith("Average annual phase") + assert amp.units == "delta_degC" + assert relamp.units == "%" + assert phase.units == "" + + def test_corr_btw_var(self, open_dataset): + simt = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .tasmax + ).load() + + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .pr + ).load() + + pc = properties.corr_btw_var(simt, sim, corr_type="Pearson") + pp = properties.corr_btw_var( + simt, sim, corr_type="Pearson", output="pvalue" + ).values + sc = properties.corr_btw_var(simt, sim).values + sp = properties.corr_btw_var(simt, sim, output="pvalue").values + sc_jan = ( + properties.corr_btw_var(simt, sim, group="time.month").sel(month=1).values + ) + sim[0] = np.nan + pc_nan = properties.corr_btw_var(sim, simt, corr_type="Pearson").values + + np.testing.assert_array_almost_equal( + [pc.values, pp, sc, sp, sc_jan, pc_nan], + [ + -0.20849051347480407, + 
3.2160438749049577e-12, + -0.3449358561881698, + 5.97619379511559e-32, + 0.28329503745038936, + -0.2090292, + ], + ) + assert pc.long_name == "Pearson correlation coefficient" + assert pc.units == "" + + with pytest.raises( + ValueError, + match="pear is not a valid type. Choose 'Pearson' or 'Spearman'.", + ): + properties.corr_btw_var(sim, simt, group="time", corr_type="pear") + + def test_relative_frequency(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .pr + ).load() + + test = properties.relative_frequency(sim, thresh="2.8925e-04 kg/m^2/s", op=">=") + testjan = ( + properties.relative_frequency( + sim, thresh="2.8925e-04 kg/m^2/s", op=">=", group="time.month" + ) + .sel(month=1) + .values + ) + np.testing.assert_array_almost_equal( + [test.values, testjan], [0.0045662100456621, 0.010752688172043012] + ) + assert test.long_name == "Relative frequency of values >= 2.8925e-04 kg/m^2/s." + assert test.units == "" + + def test_transition(self, open_dataset): + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .pr + ).load() + + test = properties.transition_probability( + da=sim, initial_op="<", final_op=">=", thresh="1.157e-05 kg/m^2/s" + ) + + np.testing.assert_array_almost_equal([test.values], [0.14076782449725778]) + assert ( + test.long_name + == "Transition probability of values < 1.157e-05 kg/m^2/s to values >= 1.157e-05 kg/m^2/s." 
+ ) + assert test.units == "" + + def test_trend(self, open_dataset): + simt = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "1952"), location="Vancouver") + .tasmax + ).load() + + slope = properties.trend(simt).values + intercept = properties.trend(simt, output="intercept").values + rvalue = properties.trend(simt, output="rvalue").values + pvalue = properties.trend(simt, output="pvalue").values + stderr = properties.trend(simt, output="stderr").values + intercept_stderr = properties.trend(simt, output="intercept_stderr").values + + np.testing.assert_array_almost_equal( + [slope, intercept, rvalue, pvalue, stderr, intercept_stderr], + [ + -0.133711111111111, + 288.762132222222222, + -0.9706433333333333, + 0.1546344444444444, + 0.033135555555555, + 0.042776666666666, + ], + 4, + ) + + slope = properties.trend(simt, group="time.month").sel(month=1) + intercept = ( + properties.trend(simt, output="intercept", group="time.month") + .sel(month=1) + .values + ) + rvalue = ( + properties.trend(simt, output="rvalue", group="time.month") + .sel(month=1) + .values + ) + pvalue = ( + properties.trend(simt, output="pvalue", group="time.month") + .sel(month=1) + .values + ) + stderr = ( + properties.trend(simt, output="stderr", group="time.month") + .sel(month=1) + .values + ) + intercept_stderr = ( + properties.trend(simt, output="intercept_stderr", group="time.month") + .sel(month=1) + .values + ) + + np.testing.assert_array_almost_equal( + [slope.values, intercept, rvalue, pvalue, stderr, intercept_stderr], + [ + 0.8254511111111111, + 281.76353222222222, + 0.576843333333333, + 0.6085644444444444, + 1.1689105555555555, + 1.509056666666666, + ], + 4, + ) + + assert slope.long_name.startswith("Slope of the interannual linear trend") + assert slope.units == "K/year" + + def test_return_value(self, open_dataset): + simt = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1950", "2010"), location="Vancouver") + .tasmax + ).load() + + out_y 
= properties.return_value(simt) + + out_djf = ( + properties.return_value(simt, op="min", group="time.season") + .sel(season="DJF") + .values + ) + + np.testing.assert_array_almost_equal( + [out_y.values, out_djf], [313.154, 278.072], 3 + ) + assert out_y.long_name.startswith("20-year maximal return level") + + @pytest.mark.slow + def test_spatial_correlogram(self, open_dataset): + # This also tests sdba.utils._pairwise_spearman and sdba.nbutils._pairwise_haversine_and_bins + # Test 1, does it work with 1D data? + sim = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .sel(time=slice("1981", "2010")) + .tasmax + ).load() + + out = properties.spatial_correlogram(sim, dims=["location"], bins=3) + np.testing.assert_allclose(out, [-1, np.nan, 0], atol=1e-6) + + # Test 2, not very exhaustive, this is more of a detect-if-we-break-it test. + sim = open_dataset("NRCANdaily/nrcan_canada_daily_tasmax_1990.nc").tasmax + out = properties.spatial_correlogram( + sim.isel(lon=slice(0, 50)), dims=["lon", "lat"], bins=20 + ) + np.testing.assert_allclose( + out[:5], + [0.95099902, 0.83028772, 0.66874473, 0.48893958, 0.30915054], + ) + np.testing.assert_allclose( + out.distance[:5], + [26.543199, 67.716227, 108.889254, 150.062282, 191.23531], + rtol=5e-07, + ) + + @pytest.mark.slow + def test_decorrelation_length(self, open_dataset): + sim = ( + open_dataset("NRCANdaily/nrcan_canada_daily_tasmax_1990.nc") + .tasmax.isel(lon=slice(0, 5), lat=slice(0, 1)) + .load() + ) + + out = properties.decorrelation_length( + sim, dims=["lat", "lon"], bins=10, radius=30 + ) + np.testing.assert_allclose( + out[0], + [4.5, 4.5, 4.5, 4.5, 10.5], + ) + + # ADAPT? 
The plan was not to allow mm/d -> kg m-2 s-1 in xsdba + # def test_get_measure(self, open_dataset): + # sim = ( + # open_dataset("sdba/CanESM2_1950-2100.nc") + # .sel(time=slice("1981", "2010"), location="Vancouver") + # .pr + # ).load() + + # ref = ( + # open_dataset("sdba/ahccd_1950-2013.nc") + # .sel(time=slice("1981", "2010"), location="Vancouver") + # .pr + # ).load() + + # sim = convert_units_to(sim, ref) + # sim_var = properties.var(sim) + # ref_var = properties.var(ref) + + # meas = properties.var.get_measure()(sim_var, ref_var) + # np.testing.assert_allclose(meas, [0.408327], rtol=1e-3)