diff --git a/pyproject.toml b/pyproject.toml
index 108a79d..455fdf2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -309,7 +309,10 @@ ignore = [
"N806",
"PTH123",
"S310",
- "PERF401" # don't force list comprehensions
+ "PERF401", # don't force list comprehensions
+ "PERF203", # allow try/except in loop
+ "E501", # line too long
+ "W505" # doc line too long
]
preview = true
select = [
diff --git a/src/xsdba/__init__.py b/src/xsdba/__init__.py
index 9554388..20483b4 100644
--- a/src/xsdba/__init__.py
+++ b/src/xsdba/__init__.py
@@ -20,7 +20,16 @@
from __future__ import annotations
-from . import adjustment, base, detrending, processing, testing, units, utils
+from . import (
+ adjustment,
+ base,
+ detrending,
+ processing,
+ properties,
+ testing,
+ units,
+ utils,
+)
# , adjustment
# from . import adjustment, base, detrending, measures, processing, properties, utils
diff --git a/src/xsdba/base.py b/src/xsdba/base.py
index fce999a..293876f 100644
--- a/src/xsdba/base.py
+++ b/src/xsdba/base.py
@@ -8,7 +8,6 @@
import datetime as pydt
import itertools
from collections.abc import Sequence
-from enum import IntEnum
from inspect import _empty, signature
from typing import Any, Callable, NewType, TypeVar
@@ -19,19 +18,10 @@
import pandas as pd
import xarray as xr
from boltons.funcutils import wraps
-from pint import Quantity
from xsdba.options import OPTIONS, SDBA_ENCODE_CF
-# XC:
-#: Type annotation for strings representing full dates (YYYY-MM-DD), may include time.
-DateStr = NewType("DateStr", str)
-
-#: Type annotation for strings representing dates without a year (MM-DD).
-DayOfYearStr = NewType("DayOfYearStr", str)
-
-#: Type annotation for thresholds and other not-exactly-a-variable quantities
-Quantified = TypeVar("Quantified", xr.DataArray, str, Quantity)
+from .typing import InputKind
# ## Base class for the sdba module
@@ -116,112 +106,6 @@ def set_dataset(self, ds: xr.Dataset) -> None:
self.ds.attrs[self._attribute] = jsonpickle.encode(self)
-# XC
-
-
-class InputKind(IntEnum):
- """Constants for input parameter kinds.
-
- For use by external parses to determine what kind of data the indicator expects.
- On the creation of an indicator, the appropriate constant is stored in
- :py:attr:`xclim.core.indicator.Indicator.parameters`. The integer value is what gets stored in the output
- of :py:meth:`xclim.core.indicator.Indicator.json`.
-
- For developers : for each constant, the docstring specifies the annotation a parameter of an indice function
- should use in order to be picked up by the indicator constructor. Notice that we are using the annotation format
- as described in `PEP 604 `_, i.e. with '|' indicating a union and without import
- objects from `typing`.
- """
-
- VARIABLE = 0
- """A data variable (DataArray or variable name).
-
- Annotation : ``xr.DataArray``.
- """
- OPTIONAL_VARIABLE = 1
- """An optional data variable (DataArray or variable name).
-
- Annotation : ``xr.DataArray | None``. The default should be None.
- """
- QUANTIFIED = 2
- """A quantity with units, either as a string (scalar), a pint.Quantity (scalar) or a DataArray (with units set).
-
- Annotation : ``xclim.core.utils.Quantified`` and an entry in the :py:func:`xclim.core.units.declare_units`
- decorator. "Quantified" translates to ``str | xr.DataArray | pint.util.Quantity``.
- """
- FREQ_STR = 3
- """A string representing an "offset alias", as defined by pandas.
-
- See the Pandas documentation on :ref:`timeseries.offset_aliases` for a list of valid aliases.
-
- Annotation : ``str`` + ``freq`` as the parameter name.
- """
- NUMBER = 4
- """A number.
-
- Annotation : ``int``, ``float`` and unions thereof, potentially optional.
- """
- STRING = 5
- """A simple string.
-
- Annotation : ``str`` or ``str | None``. In most cases, this kind of parameter makes sense
- with choices indicated in the docstring's version of the annotation with curly braces.
- See :ref:`notebooks/extendxclim:Defining new indices`.
- """
- DAY_OF_YEAR = 6
- """A date, but without a year, in the MM-DD format.
-
- Annotation : :py:obj:`xclim.core.utils.DayOfYearStr` (may be optional).
- """
- DATE = 7
- """A date in the YYYY-MM-DD format, may include a time.
-
- Annotation : :py:obj:`xclim.core.utils.DateStr` (may be optional).
- """
- NUMBER_SEQUENCE = 8
- """A sequence of numbers
-
- Annotation : ``Sequence[int]``, ``Sequence[float]`` and unions thereof, may include single ``int`` and ``float``,
- may be optional.
- """
- BOOL = 9
- """A boolean flag.
-
- Annotation : ``bool``, may be optional.
- """
- DICT = 10
- """A dictionary.
-
- Annotation : ``dict`` or ``dict | None``, may be optional.
- """
- KWARGS = 50
- """A mapping from argument name to value.
-
- Developers : maps the ``**kwargs``. Please use as little as possible.
- """
- DATASET = 70
- """An xarray dataset.
-
- Developers : as indices only accept DataArrays, this should only be added on the indicator's constructor.
- """
- OTHER_PARAMETER = 99
- """An object that fits None of the previous kinds.
-
- Developers : This is the fallback kind, it will raise an error in xclim's unit tests if used.
- """
-
-
-# XC
-def copy_all_attrs(ds: xr.Dataset | xr.DataArray, ref: xr.Dataset | xr.DataArray):
- """Copy all attributes of ds to ref, including attributes of shared coordinates, and variables in the case of Datasets."""
- ds.attrs.update(ref.attrs)
- extras = ds.variables if isinstance(ds, xr.Dataset) else ds.coords
- others = ref.variables if isinstance(ref, xr.Dataset) else ref.coords
- for name, var in extras.items():
- if name in others:
- var.attrs.update(ref[name].attrs)
-
-
# XC put here to avoid circular import
def uses_dask(*das: xr.DataArray | xr.Dataset) -> bool:
r"""Evaluate whether dask is installed and array is loaded as a dask array.
@@ -1020,3 +904,65 @@ def _apply_on_group(dsblock, **kwargs):
return wrapper
return _decorator
+
+
+def infer_kind_from_parameter(param) -> InputKind:
+ """Return the appropriate InputKind constant from an ``inspect.Parameter`` object.
+
+ Parameters
+ ----------
+ param : Parameter
+ The parameter to inspect, e.g. taken from ``inspect.signature(func).parameters``.
+
+ Notes
+ -----
+ The correspondence between parameters and kinds is documented in :py:class:`xsdba.typing.InputKind`.
+ """
+ if param.annotation is not _empty:
+ annot = set(
+ param.annotation.replace("xarray.", "").replace("xr.", "").split(" | ")
+ )
+ else:
+ annot = {"no_annotation"}
+
+ if "DataArray" in annot and "None" not in annot and param.default is not None:
+ return InputKind.VARIABLE
+
+ annot = annot - {"None"}
+
+ if "DataArray" in annot:
+ return InputKind.OPTIONAL_VARIABLE
+
+ if param.name == "freq":
+ return InputKind.FREQ_STR
+
+ if param.kind == param.VAR_KEYWORD:
+ return InputKind.KWARGS
+
+ if annot == {"Quantified"}:
+ return InputKind.QUANTIFIED
+
+ if "DayOfYearStr" in annot:
+ return InputKind.DAY_OF_YEAR
+
+ if annot.issubset({"int", "float"}):
+ return InputKind.NUMBER
+
+ if annot.issubset({"int", "float", "Sequence[int]", "Sequence[float]"}):
+ return InputKind.NUMBER_SEQUENCE
+
+ if annot.issuperset({"str"}):
+ return InputKind.STRING
+
+ if annot == {"DateStr"}:
+ return InputKind.DATE
+
+ if annot == {"bool"}:
+ return InputKind.BOOL
+
+ if annot == {"dict"}:
+ return InputKind.DICT
+
+ if annot == {"Dataset"}:
+ return InputKind.DATASET
+
+ return InputKind.OTHER_PARAMETER
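+
+
+# Illustrative sketch (not used by the module): `infer_kind_from_parameter` is meant
+# to be applied to the parameters of an `inspect.Signature`. The function
+# `_example_func` below is hypothetical and only serves as an example.
+#
+#     from inspect import signature
+#
+#     def _example_func(da: xr.DataArray, thresh: Quantified, freq: str = "YS"):
+#         ...
+#
+#     kinds = {
+#         name: infer_kind_from_parameter(p)
+#         for name, p in signature(_example_func).parameters.items()
+#     }
+#     # {"da": VARIABLE, "thresh": QUANTIFIED, "freq": FREQ_STR}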
diff --git a/src/xsdba/calendar.py b/src/xsdba/calendar.py
new file mode 100644
index 0000000..e32cdcd
--- /dev/null
+++ b/src/xsdba/calendar.py
@@ -0,0 +1,1663 @@
+"""
+Calendar Handling Utilities
+===========================
+
+Helper function to handle dates, times and different calendars with xarray.
+"""
+
+from __future__ import annotations
+
+import datetime as pydt
+from collections.abc import Sequence
+from typing import Any, TypeVar
+from warnings import warn
+
+import cftime
+import numpy as np
+import pandas as pd
+import xarray as xr
+from xarray.coding.cftime_offsets import to_cftime_datetime
+from xarray.coding.cftimeindex import CFTimeIndex
+from xarray.core import dtypes
+from xarray.core.resample import DataArrayResample, DatasetResample
+
+from .base import uses_dask
+from .formatting import update_xsdba_history
+from .typing import DayOfYearStr
+
+__all__ = [
+ "DayOfYearStr",
+ "adjust_doy_calendar",
+ "build_climatology_bounds",
+ "climatological_mean_doy",
+ "common_calendar",
+ "compare_offsets",
+ "construct_offset",
+ "convert_calendar",
+ "convert_doy",
+ "date_range",
+ "date_range_like",
+ "datetime_to_decimal_year",
+ "days_in_year",
+ "days_since_to_doy",
+ "doy_from_string",
+ "doy_to_days_since",
+ "ensure_cftime_array",
+ "get_calendar",
+ "interp_calendar",
+ "is_offset_divisor",
+ "max_doy",
+ "parse_offset",
+ "percentile_doy",
+ "resample_doy",
+ "select_time",
+ "stack_periods",
+ "time_bnds",
+ "uniform_calendars",
+ "unstack_periods",
+ "within_bnds_doy",
+]
+
+# Maximum day of year in each calendar.
+max_doy = {
+ "standard": 366,
+ "gregorian": 366,
+ "proleptic_gregorian": 366,
+ "julian": 366,
+ "noleap": 365,
+ "365_day": 365,
+ "all_leap": 366,
+ "366_day": 366,
+ "360_day": 360,
+}
+
+# Some xclim.core.utils functions made accessible here for backwards compatibility reasons.
+datetime_classes = cftime._cftime.DATE_TYPES
+
+# Names of calendars that have the same number of days for all years
+uniform_calendars = ("noleap", "all_leap", "365_day", "366_day", "360_day")
+
+
+DataType = TypeVar("DataType", xr.DataArray, xr.Dataset)
+
+
+def _get_usecf_and_warn(calendar: str, xcfunc: str, xrfunc: str):
+ if calendar == "default":
+ calendar = "standard"
+ use_cftime = False
+ msg = " and use use_cftime=False instead of calendar='default' to get numpy objects."
+ else:
+ use_cftime = None
+ msg = ""
+ warn(
+ f"`xclim` function {xcfunc} is deprecated in favour of {xrfunc} and will be removed in v0.51.0. Please adjust your script{msg}.",
+ FutureWarning,
+ )
+ return calendar, use_cftime
+
+
+def days_in_year(year: int, calendar: str = "proleptic_gregorian") -> int:
+ """Deprecated : use :py:func:`xarray.coding.calendar_ops._days_in_year` instead. Passing use_cftime=False instead of calendar='default'.
+
+ Return the number of days in the input year according to the input calendar.
+ """
+ calendar, usecf = _get_usecf_and_warn(
+ calendar, "days_in_year", "xarray.coding.calendar_ops._days_in_year"
+ )
+ return xr.coding.calendar_ops._days_in_year(year, calendar, use_cftime=usecf)
+
+
+def doy_from_string(doy: DayOfYearStr, year: int, calendar: str) -> int:
+ """Return the day-of-year corresponding to a "MM-DD" string for a given year and calendar."""
+ MM, DD = doy.split("-")
+ return datetime_classes[calendar](year, int(MM), int(DD)).timetuple().tm_yday
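+
+
+# Illustrative sketch: the same "MM-DD" string maps to a different day of year
+# depending on the calendar.
+#
+#     doy_from_string("03-01", 2000, "noleap")    # 60  (31 + 28 + 1)
+#     doy_from_string("03-01", 2000, "standard")  # 61  (2000 is a leap year)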
+
+
+def date_range(*args, **kwargs) -> pd.DatetimeIndex | CFTimeIndex:
+ """Deprecated : use :py:func:`xarray.date_range` instead. Passing use_cftime=False instead of calendar='default'.
+
+ Wrap a Pandas date_range object.
+
+ Uses pd.date_range (if calendar == 'default') or xr.cftime_range (otherwise).
+ """
+ calendar, usecf = _get_usecf_and_warn(
+ kwargs.pop("calendar", "default"), "date_range", "xarray.date_range"
+ )
+ return xr.date_range(*args, calendar=calendar, use_cftime=usecf, **kwargs)
+
+
+def get_calendar(obj: Any, dim: str = "time") -> str:
+ """Return the calendar of an object.
+
+ Parameters
+ ----------
+ obj : Any
+ An object defining some date.
+ If `obj` is an array/dataset with a datetime coordinate, use `dim` to specify its name.
+ Values must have either a datetime64 dtype or a cftime dtype.
+ `obj` can also be a python datetime.datetime, a cftime object or a pandas Timestamp
+ or an iterable of those, in which case the calendar is inferred from the first value.
+ dim : str
+ Name of the coordinate to check (if `obj` is a DataArray or Dataset).
+
+ Raises
+ ------
+ ValueError
+ If no calendar could be inferred.
+
+ Returns
+ -------
+ str
+ The Climate and Forecasting (CF) calendar name.
+ Will always return "standard" instead of "gregorian", following CF conventions 1.9.
+ """
+ if isinstance(obj, (xr.DataArray, xr.Dataset)):
+ return obj[dim].dt.calendar
+ elif isinstance(obj, xr.CFTimeIndex):
+ obj = obj.values[0]
+ else:
+ obj = np.take(obj, 0)
+ # Take zeroth element, overcome cases when arrays or lists are passed.
+ if isinstance(obj, pydt.datetime): # Also covers pandas Timestamp
+ return "standard"
+ if isinstance(obj, cftime.datetime):
+ if obj.calendar == "gregorian":
+ return "standard"
+ return obj.calendar
+
+ raise ValueError(f"Calendar could not be inferred from object of type {type(obj)}.")
+
+
+def common_calendar(calendars: Sequence[str], join="outer") -> str:
+ """Return a calendar common to all calendars from a list.
+
+ Uses the hierarchy: 360_day < noleap < standard < all_leap.
+ Returns "default" only if all calendars are "default."
+
+ Parameters
+ ----------
+ calendars : Sequence of string
+ List of calendar names.
+ join : {'inner', 'outer'}
+ The criterion for the common calendar.
+
+ - 'outer': the common calendar is the smallest calendar (in number of days by year) that will include all the
+ dates of the other calendars.
+ When converting the data to this calendar, no timeseries will lose elements, but some
+ might be missing (gaps or NaNs in the series).
+ - 'inner': the common calendar is the smallest calendar of the list.
+ When converting the data to this calendar, no timeseries will have missing elements (no gaps or NaNs),
+ but some might be dropped.
+
+ Examples
+ --------
+ >>> common_calendar(["360_day", "noleap", "default"], join="outer")
+ 'standard'
+ >>> common_calendar(["360_day", "noleap", "default"], join="inner")
+ '360_day'
+ """
+ if all(cal == "default" for cal in calendars):
+ return "default"
+
+ trans = {
+ "proleptic_gregorian": "standard",
+ "gregorian": "standard",
+ "default": "standard",
+ "366_day": "all_leap",
+ "365_day": "noleap",
+ "julian": "standard",
+ }
+ ranks = {"360_day": 0, "noleap": 1, "standard": 2, "all_leap": 3}
+ calendars = sorted([trans.get(cal, cal) for cal in calendars], key=ranks.get)
+
+ if join == "outer":
+ return calendars[-1]
+ if join == "inner":
+ return calendars[0]
+ raise NotImplementedError(f"Unknown join criterion `{join}`.")
+
+
+def _convert_doy_date(doy: int, year: int, src, tgt):
+ fracpart = doy - int(doy)
+ date = src(year, 1, 1) + pydt.timedelta(days=int(doy - 1))
+ try:
+ same_date = tgt(date.year, date.month, date.day)
+ except ValueError:
+ return np.nan
+ else:
+ if tgt is pydt.datetime:
+ return float(same_date.timetuple().tm_yday) + fracpart
+ return float(same_date.dayofyr) + fracpart
+
+
+def convert_doy(
+ source: xr.DataArray | xr.Dataset,
+ target_cal: str,
+ source_cal: str | None = None,
+ align_on: str = "year",
+ missing: Any = np.nan,
+ dim: str = "time",
+) -> xr.DataArray:
+ """Convert the calendar of day of year (doy) data.
+
+ Parameters
+ ----------
+ source : xr.DataArray or xr.Dataset
+ Day of year data (range [1, 366], max depending on the calendar).
+ If a Dataset, the function is mapped to each variable with the attribute `is_dayofyear == 1`.
+ target_cal : str
+ Name of the calendar to convert to.
+ source_cal : str, optional
+ Calendar the doys are in. If not given, uses the "calendar" attribute of `source` or,
+ if absent, the calendar of its `dim` axis.
+ align_on : {'date', 'year'}
+ If 'year' (default), the doy is seen as a "percentage" of the year and is simply rescaled onto the new doy range.
+ This always results in floating point data, changing the decimal part of the value.
+ If 'date', the doy is seen as a specific date; this never changes the decimal part of the value.
+ missing : Any
+ If `align_on` is "date" and the new doy doesn't exist in the new calendar, this value is used.
+ dim : str
+ Name of the temporal dimension.
+ """
+ if isinstance(source, xr.Dataset):
+ return source.map(
+ lambda da: (
+ da
+ if da.attrs.get("is_dayofyear") != 1
+ else convert_doy(
+ da,
+ target_cal,
+ source_cal=source_cal,
+ align_on=align_on,
+ missing=missing,
+ dim=dim,
+ )
+ )
+ )
+
+ source_cal = source_cal or source.attrs.get("calendar", get_calendar(source[dim]))
+ is_calyear = xr.infer_freq(source[dim]) in ("YS-JAN", "Y-DEC", "YE-DEC")
+
+ if is_calyear: # Fast path
+ year_of_the_doy = source[dim].dt.year
+ else: # Doy might refer to a date from the year after the timestamp.
+ year_of_the_doy = source[dim].dt.year + 1 * (source < source[dim].dt.dayofyear)
+
+ if align_on == "year":
+ if source_cal in ["noleap", "all_leap", "360_day"]:
+ max_doy_src = max_doy[source_cal]
+ else:
+ max_doy_src = xr.apply_ufunc(
+ xr.coding.calendar_ops._days_in_year,
+ year_of_the_doy,
+ vectorize=True,
+ dask="parallelized",
+ kwargs={"calendar": source_cal},
+ )
+ if target_cal in ["noleap", "all_leap", "360_day"]:
+ max_doy_tgt = max_doy[target_cal]
+ else:
+ max_doy_tgt = xr.apply_ufunc(
+ xr.coding.calendar_ops._days_in_year,
+ year_of_the_doy,
+ vectorize=True,
+ dask="parallelized",
+ kwargs={"calendar": target_cal},
+ )
+ new_doy = source.copy(data=source * max_doy_tgt / max_doy_src)
+ elif align_on == "date":
+ new_doy = xr.apply_ufunc(
+ _convert_doy_date,
+ source,
+ year_of_the_doy,
+ vectorize=True,
+ dask="parallelized",
+ kwargs={
+ "src": datetime_classes[source_cal],
+ "tgt": datetime_classes[target_cal],
+ },
+ )
+ else:
+ raise NotImplementedError('"align_on" must be one of "date" or "year".')
+ return new_doy.assign_attrs(is_dayofyear=np.int32(1), calendar=target_cal)
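+
+
+# Illustrative sketch of the two alignment modes, assuming `doys` is a DataArray of
+# day-of-year values on a 360_day time axis with attribute ``is_dayofyear == 1``:
+#
+#     convert_doy(doys, "noleap", align_on="year")
+#     # rescales values, e.g. doy 180 of a 360-day year becomes 180 * 365 / 360 = 182.5
+#     convert_doy(doys, "noleap", align_on="date", missing=np.nan)
+#     # keeps the calendar date, e.g. doy 180 (month 6, day 30) maps to the same
+#     # date in the target calendar, or to `missing` if that date does not exist.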
+
+
+def convert_calendar(
+ source: xr.DataArray | xr.Dataset,
+ target: xr.DataArray | str,
+ align_on: str | None = None,
+ missing: Any | None = None,
+ doy: bool | str = False,
+ dim: str = "time",
+) -> DataType:
+ """Deprecated : use :py:meth:`xarray.Dataset.convert_calendar` or :py:meth:`xarray.DataArray.convert_calendar`
+ or :py:func:`xarray.coding.calendar_ops.convert_calendar` instead. Passing use_cftime=False instead of calendar='default'.
+
+ Convert a DataArray/Dataset to another calendar using the specified method.
+ """
+ if isinstance(target, xr.DataArray):
+ raise NotImplementedError(
+ "In `xclim` v0.50.0, `convert_calendar` is a direct copy of `xarray.coding.calendar_ops.convert_calendar`. "
+ "To retrieve the previous behaviour with target as a DataArray, convert the source first then reindex to the target."
+ )
+ if doy is not False:
+ raise NotImplementedError(
+ "In `xclim` v0.50.0, `convert_calendar` is a direct copy of `xarray.coding.calendar_ops.convert_calendar`. "
+ "To retrieve the previous behaviour of doy=True, do convert_doy(obj, target_cal).convert_cal(target_cal)."
+ )
+ target, _usecf = _get_usecf_and_warn(
+ target,
+ "convert_calendar",
+ "xarray.coding.calendar_ops.convert_calendar or obj.convert_calendar",
+ )
+ return xr.coding.calendar_ops.convert_calendar(
+ source, target, dim=dim, align_on=align_on, missing=missing
+ )
+
+
+def interp_calendar(
+ source: xr.DataArray | xr.Dataset,
+ target: xr.DataArray,
+ dim: str = "time",
+) -> xr.DataArray | xr.Dataset:
+ """Deprecated : use :py:func:`xarray.coding.calendar_ops.interp_calendar` instead.
+
+ Interpolates a DataArray/Dataset to another calendar based on decimal year measure.
+ """
+ _, _ = _get_usecf_and_warn(
+ "standard", "interp_calendar", "xarray.coding.calendar_ops.interp_calendar"
+ )
+ return xr.coding.calendar_ops.interp_calendar(source, target, dim=dim)
+
+
+def ensure_cftime_array(time: Sequence) -> np.ndarray | Sequence[cftime.datetime]:
+ """Convert an input 1D array to a numpy array of cftime objects.
+
+ Python's datetime objects are converted to cftime.DatetimeGregorian ("standard" calendar).
+
+ Parameters
+ ----------
+ time : sequence
+ A 1D array of datetime-like objects.
+
+ Returns
+ -------
+ np.ndarray
+
+ Raises
+ ------
+ ValueError: When unable to cast the input.
+ """
+ if isinstance(time, xr.DataArray):
+ time = time.indexes["time"]
+ elif isinstance(time, np.ndarray):
+ time = pd.DatetimeIndex(time)
+ if isinstance(time, xr.CFTimeIndex):
+ return time.values
+ if isinstance(time[0], cftime.datetime):
+ return time
+ if isinstance(time[0], pydt.datetime):
+ return np.array(
+ [cftime.DatetimeGregorian(*ele.timetuple()[:6]) for ele in time]
+ )
+ raise ValueError("Unable to cast array to cftime dtype")
+
+
+def datetime_to_decimal_year(times: xr.DataArray, calendar: str = "") -> xr.DataArray:
+ """Deprecated : use :py:func:`xarray.coding.calendar_ops_datetime_to_decimal_year` instead.
+
+ Convert a datetime xr.DataArray to decimal years according to its calendar or the given one.
+ """
+ _, _ = _get_usecf_and_warn(
+ "standard",
+ "datetime_to_decimal_year",
+ "xarray.coding.calendar_ops._datetime_to_decimal_year",
+ )
+ return xr.coding.calendar_ops._datetime_to_decimal_year(
+ times, dim="time", calendar=calendar
+ )
+
+
+@update_xsdba_history
+def percentile_doy(
+ arr: xr.DataArray,
+ window: int = 5,
+ per: float | Sequence[float] = 10.0,
+ alpha: float = 1.0 / 3.0,
+ beta: float = 1.0 / 3.0,
+ copy: bool = True,
+) -> xr.DataArray:
+ """Percentile value for each day of the year.
+
+ Return the climatological percentile over a moving window around each day of the year. Different quantile estimators
+ can be used by specifying `alpha` and `beta` according to specifications given by :cite:t:`hyndman_sample_1996`.
+ The default definition corresponds to method 8, which meets multiple desirable statistical properties for sample
+ quantiles. Note that `numpy.percentile` corresponds to method 7, with alpha and beta set to 1.
+
+ Parameters
+ ----------
+ arr : xr.DataArray
+ Input data, a daily frequency (or coarser) is required.
+ window : int
+ Number of time-steps around each day of the year to include in the calculation.
+ per : float or sequence of floats
+ Percentile(s) between [0, 100]
+ alpha : float
+ Plotting position parameter.
+ beta : float
+ Plotting position parameter.
+ copy : bool
+ If True (default) the input array will be deep-copied. It's a necessary step
+ to keep the data integrity, but it can be costly.
+ If False, no copy is made of the input array. It will be mutated and rendered
+ unusable but performances may significantly improve.
+ Put this flag to False only if you understand the consequences.
+
+ Returns
+ -------
+ xr.DataArray
+ The percentiles indexed by the day of the year.
+ For calendars with 366 days, percentiles of doys 1-365 are interpolated to the 1-366 range.
+
+ References
+ ----------
+ :cite:cts:`hyndman_sample_1996`
+ """
+ from .utils import calc_perc # pylint: disable=import-outside-toplevel
+
+ # Ensure arr sampling frequency is daily or coarser
+ # but cowardly escape the non-inferrable case.
+ if compare_offsets(xr.infer_freq(arr.time) or "D", "<", "D"):
+ raise ValueError("input data should have daily or coarser frequency")
+
+ rr = arr.rolling(min_periods=1, center=True, time=window).construct("window")
+
+ crd = xr.Coordinates.from_pandas_multiindex(
+ pd.MultiIndex.from_arrays(
+ (rr.time.dt.year.values, rr.time.dt.dayofyear.values),
+ names=("year", "dayofyear"),
+ ),
+ "time",
+ )
+ rr = rr.drop_vars("time").assign_coords(crd)
+ rrr = rr.unstack("time").stack(stack_dim=("year", "window"))
+
+ if rrr.chunks is not None and len(rrr.chunks[rrr.get_axis_num("stack_dim")]) > 1:
+ # Preserve chunk size
+ time_chunks_count = len(arr.chunks[arr.get_axis_num("time")])
+ doy_chunk_size = np.ceil(len(rrr.dayofyear) / (window * time_chunks_count))
+ rrr = rrr.chunk(dict(stack_dim=-1, dayofyear=doy_chunk_size))
+
+ if np.isscalar(per):
+ per = [per]
+
+ p = xr.apply_ufunc(
+ calc_perc,
+ rrr,
+ input_core_dims=[["stack_dim"]],
+ output_core_dims=[["percentiles"]],
+ keep_attrs=True,
+ kwargs=dict(percentiles=per, alpha=alpha, beta=beta, copy=copy),
+ dask="parallelized",
+ output_dtypes=[rrr.dtype],
+ dask_gufunc_kwargs=dict(output_sizes={"percentiles": len(per)}),
+ )
+ p = p.assign_coords(percentiles=xr.DataArray(per, dims=("percentiles",)))
+
+ # The percentile for the 366th day has a sample size of 1/4 of the other days.
+ # To have the same sample size, we interpolate the percentile from 1-365 doy range to 1-366
+ if p.dayofyear.max() == 366:
+ p = adjust_doy_calendar(p.sel(dayofyear=(p.dayofyear < 366)), arr)
+
+ p.attrs.update(arr.attrs.copy())
+
+ # Saving percentile attributes
+ p.attrs["climatology_bounds"] = build_climatology_bounds(arr)
+ p.attrs["window"] = window
+ p.attrs["alpha"] = alpha
+ p.attrs["beta"] = beta
+ return p.rename("per")
+
+
+def build_climatology_bounds(da: xr.DataArray) -> list[str]:
+ """Build the climatology_bounds property with the start and end dates of input data.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ The input data.
+ Must have a time dimension.
+ """
+ n = len(da.time)
+ return da.time[0 :: n - 1].dt.strftime("%Y-%m-%d").values.tolist()
+
+
+def compare_offsets(freqA: str, op: str, freqB: str) -> bool:
+ """Compare offsets string based on their approximate length, according to a given operator.
+
+ Offset are compared based on their length approximated for a period starting
+ after 1970-01-01 00:00:00. If the offsets are from the same category (same first letter),
+ only the multiplier prefix is compared (QS-DEC == QS-JAN, MS < 2MS).
+ "Business" offsets are not implemented.
+
+ Parameters
+ ----------
+ freqA : str
+ LHS date offset string ('YS', '1D', 'QS-DEC', ...)
+ op : {'<', '<=', '==', '>', '>=', '!='}
+ Operator to use.
+ freqB : str
+ RHS date offset string ('YS', '1D', 'QS-DEC', ...)
+
+ Returns
+ -------
+ bool
+ freqA op freqB
+ """
+ from ..indices.generic import get_op # pylint: disable=import-outside-toplevel
+
+ # Get multiplier and base frequency
+ t_a, b_a, _, _ = parse_offset(freqA)
+ t_b, b_b, _, _ = parse_offset(freqB)
+
+ if b_a != b_b:
+ # Different base freq, compare length of first period after beginning of time.
+ t = pd.date_range("1970-01-01T00:00:00.000", periods=2, freq=freqA)
+ t_a = (t[1] - t[0]).total_seconds()
+ t = pd.date_range("1970-01-01T00:00:00.000", periods=2, freq=freqB)
+ t_b = (t[1] - t[0]).total_seconds()
+ # else Same base freq, compare multiplier only.
+
+ return get_op(op)(t_a, t_b)
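+
+
+# Illustrative sketch:
+#
+#     compare_offsets("2MS", ">", "MS")  # True: same base, multipliers are compared
+#     compare_offsets("D", "<", "MS")    # True: different bases, the lengths of the
+#                                        # first period after 1970-01-01 are compared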
+
+
+def parse_offset(freq: str) -> tuple[int, str, bool, str | None]:
+ """Parse an offset string.
+
+ Parse a frequency offset and, if needed, convert to cftime-compatible components.
+
+ Parameters
+ ----------
+ freq : str
+ Frequency offset.
+
+ Returns
+ -------
+ multiplier : int
+ Multiplier of the base frequency. "[n]W" is always replaced with "[7n]D",
+ as xarray doesn't support "W" for cftime indexes.
+ offset_base : str
+ Base frequency.
+ is_start_anchored : bool
+ Whether coordinates of this frequency should correspond to the beginning of the period (`True`)
+ or its end (`False`). Can only be False when base is Y, Q or M; in other words, xclim assumes frequencies finer
+ than monthly are all start-anchored.
+ anchor : str, optional
+ Anchor date for bases Y or Q. As xarray doesn't support "W",
+ neither does xclim (anchor information is lost when given).
+
+ """
+ # Useful to raise on invalid freqs, convert Y to A and get default anchor (A, Q)
+ offset = pd.tseries.frequencies.to_offset(freq)
+ base, *anchor = offset.name.split("-")
+ anchor = anchor[0] if len(anchor) > 0 else None
+ start = ("S" in base) or (base[0] not in "AYQM")
+ if base.endswith("S") or base.endswith("E"):
+ base = base[:-1]
+ mult = offset.n
+ if base == "W":
+ mult = 7 * mult
+ base = "D"
+ anchor = None
+ return mult, base, start, anchor
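+
+
+# Illustrative sketch:
+#
+#     parse_offset("QS-DEC")  # (1, "Q", True, "DEC")
+#     parse_offset("2W")      # (14, "D", True, None) -- weeks are converted to days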
+
+
+def construct_offset(mult: int, base: str, start_anchored: bool, anchor: str | None):
+ """Reconstruct an offset string from its parts.
+
+ Parameters
+ ----------
+ mult : int
+ The period multiplier (>= 1).
+ base : str
+ The base period string (one char).
+ start_anchored : bool
+ If True and base in [Y, Q, M], adds the "S" flag; if False, adds "E".
+ anchor : str, optional
+ The month anchor of the offset. Defaults to JAN for bases YS and QS and to DEC for bases YE and QE.
+
+ Returns
+ -------
+ str
+ An offset string, conformant to pandas-like naming conventions.
+
+ Notes
+ -----
+ This provides the mirror opposite functionality of :py:func:`parse_offset`.
+ """
+ start = ("S" if start_anchored else "E") if base in "YAQM" else ""
+ if anchor is None and base in "AQY":
+ anchor = "JAN" if start_anchored else "DEC"
+ return (
+ f"{mult if mult > 1 else ''}{base}{start}{'-' if anchor else ''}{anchor or ''}"
+ )
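+
+
+# Illustrative sketch -- mirrors `parse_offset`:
+#
+#     construct_offset(2, "Q", True, "DEC")  # "2QS-DEC"
+#     construct_offset(3, "M", True, None)   # "3MS"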
+
+
+def is_offset_divisor(divisor: str, offset: str):
+ """Check that divisor is a divisor of offset.
+
+ A frequency is a "divisor" of another if a whole number of periods of the
+ former fit within a single period of the latter.
+
+ Parameters
+ ----------
+ divisor : str
+ The divisor frequency.
+ offset : str
+ The large frequency.
+
+ Returns
+ -------
+ bool
+
+ Examples
+ --------
+ >>> is_offset_divisor("QS-Jan", "YS")
+ True
+ >>> is_offset_divisor("QS-DEC", "YS-JUL")
+ False
+ >>> is_offset_divisor("D", "M")
+ True
+ """
+ if compare_offsets(divisor, ">", offset):
+ return False
+ # Reconstruct offsets anchored at the start of the period
+ # to have comparable quantities, also get "offset" objects
+ mA, bA, _sA, aA = parse_offset(divisor)
+ offAs = pd.tseries.frequencies.to_offset(construct_offset(mA, bA, True, aA))
+
+ mB, bB, _sB, aB = parse_offset(offset)
+ offBs = pd.tseries.frequencies.to_offset(construct_offset(mB, bB, True, aB))
+ tB = pd.date_range("1970-01-01T00:00:00", freq=offBs, periods=13)
+
+ if bA in ["W", "D", "h", "min", "s", "ms", "us", "ms"] or bB in [
+ "W",
+ "D",
+ "h",
+ "min",
+ "s",
+ "ms",
+ "us",
+ "ms",
+ ]:
+ # Simple length comparison is sufficient for submonthly freqs
+ # In case one of bA or bB is > W, we test many to be sure.
+ tA = pd.date_range("1970-01-01T00:00:00", freq=offAs, periods=13)
+ return np.all(
+ (np.diff(tB)[:, np.newaxis] / np.diff(tA)[np.newaxis, :]) % 1 == 0
+ )
+
+ # else, we test alignment with some real dates
+ # If both fall on offAs, then it means divisor is aligned with offset at those dates
+ # if N=13 is True, then it is always True
+ # As divisor <= offset, this means divisor is a "divisor" of offset.
+ return all(offAs.is_on_offset(d) for d in tB)
+
+
+def _interpolate_doy_calendar(
+ source: xr.DataArray, doy_max: int, doy_min: int = 1
+) -> xr.DataArray:
+ """Interpolate from one set of dayofyear range to another.
+
+ Interpolate an array defined over a `dayofyear` range (say 1 to 360) to another `dayofyear` range (say 1
+ to 365).
+
+ Parameters
+ ----------
+ source : xr.DataArray
+ Array with `dayofyear` coordinates.
+ doy_max : int
+ The largest day of the year allowed by calendar.
+ doy_min : int
+ The smallest day of the year in the output.
+ This parameter is necessary when the target time series does not span over a full year (e.g. JJA season).
+ Default is 1.
+
+ Returns
+ -------
+ xr.DataArray
+ Interpolated source array over coordinates spanning the target `dayofyear` range.
+ """
+ if "dayofyear" not in source.coords.keys():
+ raise AttributeError("Source should have `dayofyear` coordinates.")
+
+ # Interpolate to fill na values
+ da = source
+ if uses_dask(source):
+ # interpolate_na cannot run on chunked dayofyear.
+ da = source.chunk(dict(dayofyear=-1))
+ filled_na = da.interpolate_na(dim="dayofyear")
+
+ # Interpolate to target dayofyear range
+ filled_na.coords["dayofyear"] = np.linspace(
+ start=doy_min, stop=doy_max, num=len(filled_na.coords["dayofyear"])
+ )
+
+ return filled_na.interp(dayofyear=range(doy_min, doy_max + 1))
+
+
+def adjust_doy_calendar(
+ source: xr.DataArray, target: xr.DataArray | xr.Dataset
+) -> xr.DataArray:
+ """Interpolate from one set of dayofyear range to another calendar.
+
+ Interpolate an array defined over a `dayofyear` range (say 1 to 360) to another `dayofyear` range (say 1 to 365).
+
+ Parameters
+ ----------
+ source : xr.DataArray
+ Array with `dayofyear` coordinate.
+ target : xr.DataArray or xr.Dataset
+ Array with `time` coordinate.
+
+ Returns
+ -------
+ xr.DataArray
+ Interpolated source array over coordinates spanning the target `dayofyear` range.
+ """
+ max_target_doy = int(target.time.dt.dayofyear.max())
+ min_target_doy = int(target.time.dt.dayofyear.min())
+
+ def has_same_calendar():
+ # case of full year (doys between 1 and 360|365|366)
+ return source.dayofyear.max() == max_doy[get_calendar(target)]
+
+ def has_similar_doys():
+ # case of partial year (e.g. JJA, doys between 152|153 and 243|244)
+ return (
+ source.dayofyear.min() == min_target_doy
+ and source.dayofyear.max() == max_target_doy
+ )
+
+ if has_same_calendar() or has_similar_doys():
+ return source
+ return _interpolate_doy_calendar(source, max_target_doy, min_target_doy)
+
+
+def resample_doy(doy: xr.DataArray, arr: xr.DataArray | xr.Dataset) -> xr.DataArray:
+ """Create a temporal DataArray where each day takes the value defined by the day-of-year.
+
+ Parameters
+ ----------
+ doy : xr.DataArray
+ Array with `dayofyear` coordinate.
+ arr : xr.DataArray or xr.Dataset
+ Array with `time` coordinate.
+
+ Returns
+ -------
+ xr.DataArray
+ An array with the same dimensions as `doy`, except for `dayofyear`, which is
+ replaced by the `time` dimension of `arr`. Values are filled according to the
+ day of year value in `doy`.
+ """
+ if "dayofyear" not in doy.coords:
+ raise AttributeError("Source should have `dayofyear` coordinates.")
+
+ # Adjust calendar
+ adoy = adjust_doy_calendar(doy, arr)
+
+ out = adoy.rename(dayofyear="time").reindex(time=arr.time.dt.dayofyear)
+ out["time"] = arr.time
+
+ return out
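+
+
+# Illustrative sketch, assuming `p90` is indexed by `dayofyear` (e.g. the output of
+# `percentile_doy`) and `tas` is a daily series: broadcast the climatological
+# threshold onto the time axis, then compare day by day.
+#
+#     thresh = resample_doy(p90, tas)  # same `time` axis as `tas`
+#     exceedances = tas > thresh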
+
+
+def time_bnds( # noqa: C901
+ time: (
+ xr.DataArray
+ | xr.Dataset
+ | CFTimeIndex
+ | pd.DatetimeIndex
+ | DataArrayResample
+ | DatasetResample
+ ),
+ freq: str | None = None,
+ precision: str | None = None,
+):
+ """Find the time bounds for a datetime index.
+
+ As we are using datetime indices to stand in for period indices, assumptions regarding the period
+ are made based on the given freq.
+
+ Parameters
+ ----------
+ time : DataArray, Dataset, CFTimeIndex, DatetimeIndex, DataArrayResample or DatasetResample
+ Object which contains a time index as a proxy representation for a period index.
+ freq : str, optional
+ String specifying the frequency/offset such as 'MS', '2D', or '3min'
+ If not given, it is inferred from the time index, which means that index must
+ have at least three elements.
+ precision : str, optional
+ A timedelta representation that :py:class:`pandas.Timedelta` understands.
+ The time bounds will be correct up to that precision. If not given,
+ 1 ms ("1U") is used for CFtime indexes and 1 ns ("1N") for numpy datetime64 indexes.
+
+ Returns
+ -------
+ DataArray
+ The time bounds: start and end times of the periods inferred from the time index and a frequency.
+ It has the original time index along its `time` coordinate and a new `bnds` coordinate.
+ The dtype and calendar of the array are the same as the index.
+
+ Notes
+ -----
+ xclim assumes that indexes for greater-than-day frequencies are "floored" down to a daily resolution.
+ For example, the coordinate "2000-01-31 00:00:00" with a "ME" frequency is assumed to mean a period
+ going from "2000-01-01 00:00:00" to "2000-01-31 23:59:59.999999".
+
+ Similarly, it assumes that daily and finer frequencies yield indexes pointing to the period's start.
+ So "2000-01-31 00:00:00" with a "3h" frequency, means a period going from "2000-01-31 00:00:00" to
+ "2000-01-31 02:59:59.999999".
+ """
+ if isinstance(time, (xr.DataArray, xr.Dataset)):
+ time = time.indexes[time.name]
+ elif isinstance(time, (DataArrayResample, DatasetResample)):
+ for grouper in time.groupers:
+ if "time" in grouper.dims:
+ datetime = grouper.unique_coord.data
+ freq = freq or grouper.grouper.freq
+ if datetime.dtype == "O":
+ time = xr.CFTimeIndex(datetime)
+ else:
+ time = pd.DatetimeIndex(datetime)
+ break
+
+ else:
+ raise ValueError(
+ 'Got object resampled along another dimension than "time".'
+ )
+
+ if freq is None and hasattr(time, "freq"):
+ freq = time.freq
+ if freq is None:
+ freq = xr.infer_freq(time)
+ elif hasattr(freq, "freqstr"):
+ # When freq is a Offset
+ freq = freq.freqstr
+
+ freq_base, freq_is_start = parse_offset(freq)[1:3]
+
+ # Normalizing without using `.normalize` because cftime doesn't have it
+ floor = {"hour": 0, "minute": 0, "second": 0, "microsecond": 0, "nanosecond": 0}
+ if freq_base in ["h", "min", "s", "ms", "us", "ns"]:
+ floor.pop("hour")
+ if freq_base in ["min", "s", "ms", "us", "ns"]:
+ floor.pop("minute")
+ if freq_base in ["s", "ms", "us", "ns"]:
+ floor.pop("second")
+ if freq_base in ["us", "ns"]:
+ floor.pop("microsecond")
+ if freq_base == "ns":
+ floor.pop("nanosecond")
+
+ if isinstance(time, xr.CFTimeIndex):
+ period = xr.coding.cftime_offsets.to_offset(freq)
+ is_on_offset = period.onOffset
+ eps = pd.Timedelta(precision or "1us").to_pytimedelta()
+ day = pd.Timedelta("1D").to_pytimedelta()
+ floor.pop("nanosecond") # unsupported by cftime
+ else:
+ period = pd.tseries.frequencies.to_offset(freq)
+ is_on_offset = period.is_on_offset
+ eps = pd.Timedelta(precision or "1ns")
+ day = pd.Timedelta("1D")
+
+ def shift_time(t):
+ if not is_on_offset(t):
+ if freq_is_start:
+ t = period.rollback(t)
+ else:
+ t = period.rollforward(t)
+ return t.replace(**floor)
+
+ time_real = list(map(shift_time, time))
+
+ cls = time.__class__
+ if freq_is_start:
+ tbnds = [cls(time_real), cls([t + period - eps for t in time_real])]
+ else:
+ tbnds = [
+ cls([t - period + day for t in time_real]),
+ cls([t + day - eps for t in time_real]),
+ ]
+ return xr.DataArray(
+ tbnds, dims=("bnds", "time"), coords={"time": time}, name="time_bnds"
+ ).transpose()
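+
+
+# Illustrative sketch, assuming `ds` has a monthly ("MS") time index:
+#
+#     bnds = time_bnds(ds.time, freq="MS")
+#     # One (start, end) pair per timestep, e.g. 2000-01-01 00:00:00 to
+#     # 2000-01-31 23:59:59.999999999 for a numpy datetime64 index.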
+
+
+def climatological_mean_doy(
+ arr: xr.DataArray, window: int = 5
+) -> tuple[xr.DataArray, xr.DataArray]:
+ """Calculate the climatological mean and standard deviation for each day of the year.
+
+ Parameters
+ ----------
+ arr : xarray.DataArray
+ Input array.
+ window : int
+ Window size in days.
+
+ Returns
+ -------
+ xarray.DataArray, xarray.DataArray
+ Mean and standard deviation.
+ """
+ rr = arr.rolling(min_periods=1, center=True, time=window).construct("window")
+
+ # Create empty percentile array
+ g = rr.groupby("time.dayofyear")
+
+ m = g.mean(["time", "window"])
+ s = g.std(["time", "window"])
+
+ return m, s
+
+
+def within_bnds_doy(
+ arr: xr.DataArray, *, low: xr.DataArray, high: xr.DataArray
+) -> xr.DataArray:
+ """Return whether array values are within bounds for each day of the year.
+
+ Parameters
+ ----------
+ arr : xarray.DataArray
+ Input array.
+ low : xarray.DataArray
+ Low bound with dayofyear coordinate.
+ high : xarray.DataArray
+ High bound with dayofyear coordinate.
+
+ Returns
+ -------
+ xarray.DataArray
+ """
+ low = resample_doy(low, arr)
+ high = resample_doy(high, arr)
+ return (low < arr) * (arr < high)
+
+
+def _doy_days_since_doys(
+ base: xr.DataArray, start: DayOfYearStr | None = None
+) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
+ """Calculate dayofyear to days since, or the inverse.
+
+ Parameters
+ ----------
+ base : xr.DataArray
+ 1D time coordinate.
+ start : DayOfYearStr, optional
+ A date to compute the offset relative to. If not given, start_doy is the same as base_doy.
+
+ Returns
+ -------
+ base_doy : xr.DataArray
+ Day of year for each element in base.
+ start_doy : xr.DataArray
+ Day of year of the "start" date.
+ The year used is the one the start date would take as a doy for the corresponding base element.
+ doy_max : xr.DataArray
+ Number of days (maximum doy) for the year of each value in base.
+ """
+ calendar = get_calendar(base)
+
+ base_doy = base.dt.dayofyear
+
+ doy_max = xr.apply_ufunc(
+ xr.coding.calendar_ops._days_in_year,
+ base.dt.year,
+ vectorize=True,
+ kwargs={"calendar": calendar},
+ )
+
+ if start is not None:
+ mm, dd = map(int, start.split("-"))
+ starts = xr.apply_ufunc(
+ lambda y: datetime_classes[calendar](y, mm, dd),
+ base.dt.year,
+ vectorize=True,
+ )
+ start_doy = starts.dt.dayofyear
+ start_doy = start_doy.where(start_doy >= base_doy, start_doy + doy_max)
+ else:
+ start_doy = base_doy
+
+ return base_doy, start_doy, doy_max
+
+
+def doy_to_days_since(
+ da: xr.DataArray,
+ start: DayOfYearStr | None = None,
+ calendar: str | None = None,
+) -> xr.DataArray:
+ """Convert day-of-year data to days since a given date.
+
+ This is useful for computing meaningful statistics on doy data.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Array of "day-of-year", usually int dtype, must have a `time` dimension.
+ Sampling frequency should be finer or similar to yearly and coarser than daily.
+ start : date of year str, optional
+ A date in "MM-DD" format, the base day of the new array. If None (default), the `time` axis is used.
+ Passing `start` only makes sense if `da` has a yearly sampling frequency.
+ calendar : str, optional
+ The calendar to use when computing the new interval.
+ If None (default), the calendar attribute of the data or of its `time` axis is used.
+ All time coordinates of `da` must exist in this calendar.
+ No check is done to ensure doy values exist in this calendar.
+
+ Returns
+ -------
+ xr.DataArray
+ Same shape as `da`, int dtype, day-of-year data translated to a number of days since a given date.
+ If start is not None, there might be negative values.
+
+ Notes
+ -----
+ The time coordinates of `da` are considered as the START of the period. For example, a doy value of
+ 350 with a timestamp of '2020-12-31' is understood as '2021-12-16' (the 350th day of 2021).
+ Passing `start=None`, will use the time coordinate as the base, so in this case the converted value
+ will be 350 "days since time coordinate".
+
+ Examples
+ --------
+ >>> from xarray import DataArray
+ >>> time = date_range("2020-07-01", "2021-07-01", freq="AS-JUL")
+ >>> # July 8th 2020 and Jan 2nd 2022
+ >>> da = DataArray([190, 2], dims=("time",), coords={"time": time})
+ >>> # Convert to days since Oct. 2nd, of the data's year.
+ >>> doy_to_days_since(da, start="10-02").values
+ array([-86, 92])
+ """
+ base_calendar = get_calendar(da)
+ calendar = calendar or da.attrs.get("calendar", base_calendar)
+ dac = da.convert_calendar(calendar)
+
+ base_doy, start_doy, doy_max = _doy_days_since_doys(dac.time, start)
+
+ # Two cases:
+ # val is a day in the same year as its index : da - offset
+ # val is a day in the next year : da + doy_max - offset
+ out = xr.where(dac > base_doy, dac, dac + doy_max) - start_doy
+ out.attrs.update(da.attrs)
+ if start is not None:
+ out.attrs.update(units=f"days after {start}")
+ else:
+ starts = np.unique(out.time.dt.strftime("%m-%d"))
+ if len(starts) == 1:
+ out.attrs.update(units=f"days after {starts[0]}")
+ else:
+ out.attrs.update(units="days after time coordinate")
+
+ out.attrs.pop("is_dayofyear", None)
+ out.attrs.update(calendar=calendar)
+ return out.convert_calendar(base_calendar).rename(da.name)
+
+
+def days_since_to_doy(
+ da: xr.DataArray,
+ start: DayOfYearStr | None = None,
+ calendar: str | None = None,
+) -> xr.DataArray:
+ """Reverse the conversion made by :py:func:`doy_to_days_since`.
+
+ Converts data given in days since a specific date to day-of-year.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ The result of :py:func:`doy_to_days_since`.
+ start : DayOfYearStr, optional
+ `da` is considered as days since that start date (in the year of the time index).
+ If None (default), it is read from the attributes.
+ calendar : str, optional
+ Calendar the "days since" were computed in.
+ If None (default), it is read from the attributes.
+
+ Returns
+ -------
+ xr.DataArray
+ Same shape as `da`, values as `day of year`.
+
+ Examples
+ --------
+ >>> from xarray import DataArray
+ >>> time = date_range("2020-07-01", "2021-07-01", freq="AS-JUL")
+ >>> da = DataArray(
+ ... [-86, 92],
+ ... dims=("time",),
+ ... coords={"time": time},
+ ... attrs={"units": "days since 10-02"},
+ ... )
+ >>> days_since_to_doy(da).values
+ array([190, 2])
+ """
+ if start is None:
+ unitstr = da.attrs.get("units", " time coordinate").split(" ", maxsplit=2)[-1]
+ if unitstr != "time coordinate":
+ start = unitstr
+
+ base_calendar = get_calendar(da)
+ calendar = calendar or da.attrs.get("calendar", base_calendar)
+
+ dac = da.convert_calendar(calendar)
+
+ _, start_doy, doy_max = _doy_days_since_doys(dac.time, start)
+
+ # Two cases:
+ # val is a day in the same year as its index : da + offset
+ # val is a day in the next year : da + offset - doy_max
+ out = dac + start_doy
+ out = xr.where(out > doy_max, out - doy_max, out)
+
+ out.attrs.update(
+ {k: v for k, v in da.attrs.items() if k not in ["units", "calendar"]}
+ )
+ out.attrs.update(calendar=calendar, is_dayofyear=1)
+ return out.convert_calendar(base_calendar).rename(da.name)
+
+
+def date_range_like(source: xr.DataArray, calendar: str) -> xr.DataArray:
+ """Deprecated : use :py:func:`xarray.date_range_like` instead. Passing use_cftime=False instead of calendar='default'.
+
+ Generate a datetime array with the same frequency, start and end as another one, but in a different calendar.
+ """
+ calendar, usecf = _get_usecf_and_warn(
+ calendar, "date_range_like", "xarray.date_range_like"
+ )
+ return xr.coding.calendar_ops.date_range_like(
+ source=source, calendar=calendar, use_cftime=usecf
+ )
+
+
+def select_time(
+ da: xr.DataArray | xr.Dataset,
+ drop: bool = False,
+ season: str | Sequence[str] | None = None,
+ month: int | Sequence[int] | None = None,
+ doy_bounds: tuple[int, int] | None = None,
+ date_bounds: tuple[str, str] | None = None,
+ include_bounds: bool | tuple[bool, bool] = True,
+) -> DataType:
+ """Select entries according to a time period.
+
+ This conveniently improves xarray's :py:meth:`xarray.DataArray.where` and
+ :py:meth:`xarray.DataArray.sel` with fancier ways of indexing over time elements.
+ In addition to the data `da` and argument `drop`, only one of `season`, `month`,
+ `doy_bounds` or `date_bounds` may be passed.
+
+ Parameters
+ ----------
+ da : xr.DataArray or xr.Dataset
+ Input data.
+ drop : bool
+ Whether to drop elements outside the period of interest or to simply mask them (default).
+ season : string or sequence of strings, optional
+ One or more of 'DJF', 'MAM', 'JJA' and 'SON'.
+ month : integer or sequence of integers, optional
+ Sequence of month numbers (January = 1 ... December = 12)
+ doy_bounds : 2-tuple of integers, optional
+ The bounds as (start, end) of the period of interest expressed in day-of-year, integers going from
+ 1 (January 1st) to 365 or 366 (December 31st).
+ If calendar awareness is needed, consider using ``date_bounds`` instead.
+ date_bounds : 2-tuple of strings, optional
+ The bounds as (start, end) of the period of interest expressed as dates in the month-day (%m-%d) format.
+ include_bounds : bool or 2-tuple of booleans
+ Whether the bounds of `doy_bounds` or `date_bounds` should be inclusive or not.
+ Either one value for both or a tuple. Default is True, meaning bounds are inclusive.
+
+ Returns
+ -------
+ xr.DataArray or xr.Dataset
+ Selected input values. If ``drop=False``, this has the same length as ``da`` (along dimension 'time'),
+ but with masked (NaN) values outside the period of interest.
+
+ Examples
+ --------
+ Keep only the values of fall and spring.
+
+ >>> ds = open_dataset("ERA5/daily_surface_cancities_1990-1993.nc")
+ >>> ds.time.size
+ 1461
+ >>> out = select_time(ds, drop=True, season=["MAM", "SON"])
+ >>> out.time.size
+ 732
+
+ Or all values between two dates (included).
+
+ >>> out = select_time(ds, drop=True, date_bounds=("02-29", "03-02"))
+ >>> out.time.values
+ array(['1990-03-01T00:00:00.000000000', '1990-03-02T00:00:00.000000000',
+ '1991-03-01T00:00:00.000000000', '1991-03-02T00:00:00.000000000',
+ '1992-02-29T00:00:00.000000000', '1992-03-01T00:00:00.000000000',
+ '1992-03-02T00:00:00.000000000', '1993-03-01T00:00:00.000000000',
+ '1993-03-02T00:00:00.000000000'], dtype='datetime64[ns]')
+ """
+ N = sum(arg is not None for arg in [season, month, doy_bounds, date_bounds])
+ if N > 1:
+ raise ValueError(f"Only one method of indexing may be given, got {N}.")
+
+ if N == 0:
+ return da
+
+ def _get_doys(_start, _end, _inclusive):
+ if _start <= _end:
+ _doys = np.arange(_start, _end + 1)
+ else:
+ _doys = np.concatenate((np.arange(_start, 367), np.arange(0, _end + 1)))
+ if not _inclusive[0]:
+ _doys = _doys[1:]
+ if not _inclusive[1]:
+ _doys = _doys[:-1]
+ return _doys
+
+ if isinstance(include_bounds, bool):
+ include_bounds = (include_bounds, include_bounds)
+
+ if season is not None:
+ if isinstance(season, str):
+ season = [season]
+ mask = da.time.dt.season.isin(season)
+
+ elif month is not None:
+ if isinstance(month, int):
+ month = [month]
+ mask = da.time.dt.month.isin(month)
+
+ elif doy_bounds is not None:
+ mask = da.time.dt.dayofyear.isin(_get_doys(*doy_bounds, include_bounds))
+
+ elif date_bounds is not None:
+ # This one is a bit trickier.
+ start, end = date_bounds
+ time = da.time
+ calendar = get_calendar(time)
+ if calendar not in uniform_calendars:
+ # For non-uniform calendars, we can't simply convert dates to doys
+ # conversion to all_leap is safe for all non-uniform calendars as it doesn't remove any date.
+ time = time.convert_calendar("all_leap")
+ # values of time are the _old_ calendar
+ # and the new calendar is in the coordinate
+ calendar = "all_leap"
+
+ # Get doy of date, this is now safe because the calendar is uniform.
+ doys = _get_doys(
+ to_cftime_datetime(f"2000-{start}", calendar).dayofyr,
+ to_cftime_datetime(f"2000-{end}", calendar).dayofyr,
+ include_bounds,
+ )
+ mask = time.time.dt.dayofyear.isin(doys)
+ # Needed if we converted calendar, this puts back the correct coord
+ mask["time"] = da.time
+
+ else:
+ raise ValueError(
+ "Must provide either `season`, `month`, `doy_bounds` or `date_bounds`."
+ )
+
+ return da.where(mask, drop=drop)
+
+
+def _month_is_first_period_month(time, freq):
+ """Returns True if the given time is from the first month of freq."""
+ if isinstance(time, cftime.datetime):
+ frq_monthly = xr.coding.cftime_offsets.to_offset("MS")
+ frq = xr.coding.cftime_offsets.to_offset(freq)
+ if frq_monthly.onOffset(time):
+ return frq.onOffset(time)
+ return frq.onOffset(frq_monthly.rollback(time))
+ # Pandas
+ time = pd.Timestamp(time)
+ frq_monthly = pd.tseries.frequencies.to_offset("MS")
+ frq = pd.tseries.frequencies.to_offset(freq)
+ if frq_monthly.is_on_offset(time):
+ return frq.is_on_offset(time)
+ return frq.is_on_offset(frq_monthly.rollback(time))
+
+
+def stack_periods(
+ da: xr.Dataset | xr.DataArray,
+ window: int = 30,
+ stride: int | None = None,
+ min_length: int | None = None,
+ freq: str = "YS",
+ dim: str = "period",
+ start: str = "1970-01-01",
+ align_days: bool = True,
+ pad_value=dtypes.NA,
+):
+ """Construct a multi-period array.
+
+ Stack different equal-length periods of `da` into a new 'period' dimension.
+
+ This is similar to ``da.rolling(time=window).construct(dim, stride=stride)``, but adapted for arguments
+ in terms of a base temporal frequency that might be non-uniform (years, months, etc.).
+ It is reversible for some cases (see `stride`).
+ A rolling-construct method will be much more performant for uniform periods (days, weeks).
+
+ Parameters
+ ----------
+ da : xr.Dataset or xr.DataArray
+ An xarray object with a `time` dimension.
+ Must have a uniform timestep length.
+ Output might be strange if this does not use a uniform calendar (noleap, 360_day, all_leap).
+ window : int
+ The length of the moving window as a multiple of ``freq``.
+ stride : int, optional
+ At which interval to take the windows, as a multiple of ``freq``.
+ For the operation to be reversible with :py:func:`unstack_periods`, it must divide `window` into an odd number of parts.
+ Default is `window` (no overlap between periods).
+ min_length : int, optional
+ Windows shorter than this are not included in the output.
+ Given as a multiple of ``freq``. Default is ``window`` (every window must be complete).
+ Similar to the ``min_periods`` argument of ``da.rolling``.
+ If ``freq`` is annual or quarterly and ``min_length == window``, the first period is considered complete
+ if the first timestep is in the first month of the period.
+ freq : str
+ Units of ``window``, ``stride`` and ``min_length``, as a frequency string.
+ Must be larger or equal to the data's sampling frequency.
+ Note that this function offers an easier interface for non-uniform period (like years or months)
+ but is much slower than a rolling-construct method.
+ dim : str
+ The new dimension name.
+ start : str
+ The `start` argument passed to :py:func:`xarray.date_range` to generate the new placeholder
+ time coordinate.
+ align_days : bool
+ When True (default), an error is raised if the output would have unaligned days across periods.
+ If `freq = 'YS'`, day-of-year alignment is checked and if `freq` is "MS" or "QS", we check day-in-month.
+ Only uniform calendars will pass the test for `freq='YS'`.
+ For other frequencies, only the `360_day` calendar will work.
+ This check is ignored if the sampling rate of the data is coarser than "D".
+ pad_value : Any
+ When some periods are shorter than others, this value is used to pad them at the end.
+ Passed directly as argument ``fill_value`` to :py:func:`xarray.concat`,
+ the default is the same as on that function.
+
+ Returns
+ -------
+ xr.DataArray
+ A DataArray with a new `period` dimension and a `time` dimension with the length of the longest window.
+ The new time coordinate has the same frequency as the input data but is generated using
+ :py:func:`xarray.date_range` with the given `start` value.
+ That coordinate is the same for all periods. Depending on the choice of ``window`` and ``freq``, it might make sense for the data,
+ but for unequal periods or non-uniform calendars, it certainly will not.
+ If ``stride`` is a divisor of ``window``, the correct timeseries can be reconstructed with :py:func:`unstack_periods`.
+ The coordinate of `period` is the first timestep of each window.
+ """
+ from xsdba.units import ( # Import in function to avoid cyclical imports; ensure_cf_units,
+ infer_sampling_units,
+ )
+
+ stride = stride or window
+ min_length = min_length or window
+ if stride > window:
+ raise ValueError(
+ f"Stride must be less than or equal to window. Got {stride} > {window}."
+ )
+
+ srcfreq = xr.infer_freq(da.time)
+ cal = da.time.dt.calendar
+ use_cftime = da.time.dtype == "O"
+
+ if (
+ compare_offsets(srcfreq, "<=", "D")
+ and align_days
+ and (
+ (freq.startswith(("Y", "A")) and cal not in uniform_calendars)
+ or (freq.startswith(("Q", "M")) and window > 1 and cal != "360_day")
+ )
+ ):
+ if freq.startswith(("Y", "A")):
+ u = "year"
+ else:
+ u = "month"
+ raise ValueError(
+ f"Stacking {window}{freq} periods will result in unaligned day-of-{u}. "
+ f"Consider converting the calendar of your data to one with uniform {u} lengths, "
+ "or pass `align_days=False` to disable this check."
+ )
+
+ # Convert integer inputs to freq strings
+ mult, *args = parse_offset(freq)
+ win_frq = construct_offset(mult * window, *args)
+ strd_frq = construct_offset(mult * stride, *args)
+ minl_frq = construct_offset(mult * min_length, *args)
+
+ # The same time coord as da, but with one extra element.
+ # This way, the last window's last index is not returned as None by xarray's grouper.
+ time2 = xr.DataArray(
+ xr.date_range(
+ da.time[0].item(),
+ freq=srcfreq,
+ calendar=cal,
+ periods=da.time.size + 1,
+ use_cftime=use_cftime,
+ ),
+ dims=("time",),
+ name="time",
+ )
+
+ periods = []
+ # longest = 0
+ # Iterate over strides, but recompute the full window for each stride start
+ for strd_slc in da.resample(time=strd_frq).groups.values():
+ win_resamp = time2.isel(time=slice(strd_slc.start, None)).resample(time=win_frq)
+ # Get slice for first group
+ win_slc = win_resamp._group_indices[0]
+ if min_length < window:
+ # If we ask for a min_length period instead is it complete ?
+ min_resamp = time2.isel(time=slice(strd_slc.start, None)).resample(
+ time=minl_frq
+ )
+ min_slc = min_resamp._group_indices[0]
+ open_ended = min_slc.stop is None
+ else:
+ # The end of the group slice is None if no outside-group value was found after the last element
+ # As we added an extra step to time2, we avoid the case where a group ends exactly on the last element of ds
+ open_ended = win_slc.stop is None
+ if open_ended:
+ # Too short, we got to the end
+ break
+ if (
+ strd_slc.start == 0
+ and parse_offset(freq)[1] in "YAQ"
+ and min_length == window
+ and not _month_is_first_period_month(da.time[0].item(), freq)
+ ):
+ # For annual or quarterly frequencies (which can be anchor-based),
+ # if the first time is not in the first month of the first period,
+ # then the first period is incomplete but by a fractional amount.
+ continue
+ periods.append(
+ slice(
+ strd_slc.start + win_slc.start,
+ (
+ (strd_slc.start + win_slc.stop)
+ if win_slc.stop is not None
+ else da.time.size
+ ),
+ )
+ )
+
+ # Make coordinates
+ lengths = xr.DataArray(
+ [slc.stop - slc.start for slc in periods],
+ dims=(dim,),
+ attrs={"long_name": "Length of each period"},
+ )
+ longest = lengths.max().item()
+ # Length as a pint-ready array : with proper units, but values are not usable as indexes anymore
+ m, u = infer_sampling_units(da)
+ lengths = lengths * m
+ # ADAPT: cf-agnostic
+ # lengths.attrs["units"] = ensure_cf_units(u)
+
+ # Start points for each period and remember parameters for unstacking
+ starts = xr.DataArray(
+ [da.time[slc.start].item() for slc in periods],
+ dims=(dim,),
+ attrs={
+ "long_name": "Start of the period",
+ # Save parameters so that we can unstack.
+ "window": window,
+ "stride": stride,
+ "freq": freq,
+ "unequal_lengths": int(len(np.unique(lengths)) > 1),
+ },
+ )
+ # The "fake" axis that all periods share
+ fake_time = xr.date_range(
+ start, periods=longest, freq=srcfreq, calendar=cal, use_cftime=use_cftime
+ )
+ # Slice and concat along new dim. We drop the index and add a new one so that xarray can concat them together.
+ out = xr.concat(
+ [
+ da.isel(time=slc)
+ .drop_vars("time")
+ .assign_coords(time=np.arange(slc.stop - slc.start))
+ for slc in periods
+ ],
+ dim,
+ join="outer",
+ fill_value=pad_value,
+ )
+ out = out.assign_coords(
+ time=(("time",), fake_time, da.time.attrs.copy()),
+ **{f"{dim}_length": lengths, dim: starts},
+ )
+ out.time.attrs.update(long_name="Placeholder time axis")
+ return out
+
+
+def unstack_periods(da: xr.DataArray | xr.Dataset, dim: str = "period"):
+ """Unstack an array constructed with :py:func:`stack_periods`.
+
+ Can only work with periods stacked with a ``stride`` that divides ``window`` into an odd number of sections.
+ When ``stride`` is smaller than ``window``, only the center-most stride of each window is kept,
+ except for the beginning and end which are taken from the first and last windows.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ As constructed by :py:func:`stack_periods`, attributes of the period coordinates must have been preserved.
+ dim : str
+ The period dimension name.
+
+ Notes
+ -----
+ The following table shows which strides are included (``o``) in the unstacked output.
+
+ In this example, ``stride`` was a fifth of ``window`` and ``min_length`` was four (4) times ``stride``.
+ The row index ``i`` is the period index in the stacked dataset,
+ the columns are the stride-long sections of the original timeseries.
+
+ .. table:: Unstacking example with ``stride < window``.
+
+ === === === === === === === ===
+ i 0 1 2 3 4 5 6
+ === === === === === === === ===
+ 3 x x o o
+ 2 x x o x x
+ 1 x x o x x
+ 0 o o o x x
+ === === === === === === === ===
+ """
+ from xsdba.units import infer_sampling_units  # imported here to avoid cyclical imports
+
+ try:
+ starts = da[dim]
+ window = starts.attrs["window"]
+ stride = starts.attrs["stride"]
+ freq = starts.attrs["freq"]
+ unequal_lengths = bool(starts.attrs["unequal_lengths"])
+ except (AttributeError, KeyError) as err:
+ raise ValueError(
+ f"`unstack_periods` can't find the window, stride and freq attributes on the {dim} coordinates."
+ ) from err
+
+ if unequal_lengths:
+ try:
+ lengths = da[f"{dim}_length"]
+ except KeyError as err:
+ raise ValueError(
+ f"`unstack_periods` can't find the `{dim}_length` coordinate."
+ ) from err
+ # Get length as number of points
+ m, _ = infer_sampling_units(da.time)
+ lengths = lengths // m
+ else:
+ # It is acceptable to lose "{dim}_length" if they were all equal
+ lengths = xr.DataArray([da.time.size] * da[dim].size, dims=(dim,))
+
+ # Convert from the fake axis to the real one
+ time_as_delta = da.time - da.time[0]
+ if da.time.dtype == "O":
+ # cftime can't add with np.timedelta64 (restriction comes from numpy which refuses to add O with m8)
+ time_as_delta = pd.TimedeltaIndex(
+ time_as_delta
+ ).to_pytimedelta() # this array is O, numpy complies
+ else:
+ # Xarray will return int when iterating over datetime values, this returns timestamps
+ starts = pd.DatetimeIndex(starts)
+
+ def _reconstruct_time(_time_as_delta, _start):
+ times = _time_as_delta + _start
+ return xr.DataArray(times, dims=("time",), coords={"time": times}, name="time")
+
+ # Easy case:
+ if window == stride:
+ # just concat them all
+ periods = []
+ for i, (start, length) in enumerate(zip(starts.values, lengths.values)):
+ real_time = _reconstruct_time(time_as_delta, start)
+ periods.append(
+ da.isel(**{dim: i}, drop=True)
+ .isel(time=slice(0, length))
+ .assign_coords(time=real_time.isel(time=slice(0, length)))
+ )
+ return xr.concat(periods, "time")
+
+ # Difficult and ambiguous case
+ if (window / stride) % 2 != 1:
+ raise NotImplementedError(
+ "`unstack_periods` can't work with strides that do not divide the window into an odd number of parts."
+ f"Got {window} / {stride} which is not an odd integer."
+ )
+
+ # Non-ambiguous overlapping case
+ Nwin = window // stride
+ mid = (Nwin - 1) // 2 # index of the center window
+
+ mult, *args = parse_offset(freq)
+ strd_frq = construct_offset(mult * stride, *args)
+
+ periods = []
+ for i, (start, length) in enumerate(zip(starts.values, lengths.values)):
+ real_time = _reconstruct_time(time_as_delta, start)
+ slices = real_time.resample(time=strd_frq)._group_indices
+ if i == 0:
+ slc = slice(slices[0].start, min(slices[mid].stop, length))
+ elif i == da.period.size - 1:
+ slc = slice(slices[mid].start, min(slices[Nwin - 1].stop or length, length))
+ else:
+ slc = slice(slices[mid].start, min(slices[mid].stop, length))
+ periods.append(
+ da.isel(**{dim: i}, drop=True)
+ .isel(time=slc)
+ .assign_coords(time=real_time.isel(time=slc))
+ )
+
+ return xr.concat(periods, "time")
diff --git a/src/xsdba/datachecks.py b/src/xsdba/datachecks.py
new file mode 100644
index 0000000..9269046
--- /dev/null
+++ b/src/xsdba/datachecks.py
@@ -0,0 +1,123 @@
+"""
+Data Checks
+===========
+
+Utilities designed to check the validity of data inputs.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import xarray as xr
+
+from .calendar import compare_offsets, parse_offset
+from .logging import ValidationError
+from .options import datacheck
+
+
+@datacheck
+def check_freq(var: xr.DataArray, freq: str | Sequence[str], strict: bool = True):
+ """Raise an error if not series has not the expected temporal frequency or is not monotonically increasing.
+
+ Parameters
+ ----------
+ var : xr.DataArray
+ Input array.
+ freq : str or sequence of str
+ The expected temporal frequencies, using Pandas frequency terminology ({'Y', 'M', 'D', 'h', 'min', 's', 'ms', 'us'})
+ and multiples thereof. To test strictly for 'W', pass '7D' with `strict=True`.
+ This ignores the start/end flag and the anchor (ex: 'YS-JUL' will validate against 'Y').
+ strict : bool
+ Whether multiples of the frequencies are considered invalid or not. With `strict` set to False, a '3h' series
+ will not raise an error if freq is set to 'h'.
+
+ Raises
+ ------
+ ValidationError
+ - If the frequency of `var` is not inferrable.
+ - If the frequency of `var` does not match the requested `freq`.
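+
+ Examples
+ --------
+ A minimal sketch with a synthetic daily series (placeholder values):
+
+ >>> import numpy as np
+ >>> import xarray as xr
+ >>> time = xr.date_range("2000-01-01", periods=365, freq="D")
+ >>> pr = xr.DataArray(np.zeros(time.size), dims=("time",), coords={"time": time})
+ >>> check_freq(pr, "D")  # passes silently for a daily series
+ >>> check_freq(pr, "MS")  # would raise a ValidationError (or log, depending on the data_validation option)  # doctest: +SKIP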
+ """
+ if isinstance(freq, str):
+ freq = [freq]
+ exp_base = [parse_offset(frq)[1] for frq in freq]
+ v_freq = xr.infer_freq(var.time)
+ if v_freq is None:
+ raise ValidationError(
+ "Unable to infer the frequency of the time series. "
+ "To mute this, set xclim's option data_validation='log'."
+ )
+ v_base = parse_offset(v_freq)[1]
+ if v_base not in exp_base or (
+ strict and all(compare_offsets(v_freq, "!=", frq) for frq in freq)
+ ):
+ raise ValidationError(
+ f"Frequency of time series not {'strictly' if strict else ''} in {freq}. "
+ "To mute this, set xclim's option data_validation='log'."
+ )
+
+
+def check_daily(var: xr.DataArray):
+ """Raise an error if not series has a frequency other that daily, or is not monotonically increasing.
+
+ Notes
+ -----
+ This does not check for gaps in series.
+ """
+ return check_freq(var, "D")
+
+
+@datacheck
+def check_common_time(inputs: Sequence[xr.DataArray]):
+ """Raise an error if the list of inputs doesn't have a single common frequency.
+
+ Raises
+ ------
+ ValidationError
+ - if the frequency of any input can't be inferred
+ - if inputs have different frequencies
+ - if inputs have a daily or hourly frequency, but they are not given at the same time of day.
+
+ Parameters
+ ----------
+ inputs : Sequence of xr.DataArray
+ Input arrays.
+ """
+ # Check all have the same freq
+ freqs = [xr.infer_freq(da.time) for da in inputs]
+ if None in freqs:
+ raise ValidationError(
+ "Unable to infer the frequency of the time series. "
+ "To mute this, set xclim's option data_validation='log'."
+ )
+ if len(set(freqs)) != 1:
+ raise ValidationError(
+ f"Inputs have different frequencies. Got : {freqs}."
+ "To mute this, set xclim's option data_validation='log'."
+ )
+
+ # Check if anchor is the same
+ freq = freqs[0]
+ base = parse_offset(freq)[1]
+ fmt = {"h": ":%M", "D": "%H:%M"}
+ if base in fmt:
+ outs = {da.indexes["time"][0].strftime(fmt[base]) for da in inputs}
+ if len(outs) > 1:
+ raise ValidationError(
+ f"All inputs have the same frequency ({freq}), but they are not anchored on the same minutes (got {outs}). "
+ f"xarray's alignment would silently fail. You can try to fix this with `da.resample('{freq}').mean()`."
+ "To mute this, set xclim's option data_validation='log'."
+ )
+
+
+def is_percentile_dataarray(source: xr.DataArray) -> bool:
+ """Evaluate whether a DataArray is a Percentile.
+
+ A percentile DataArray must have a ``climatology_bounds`` attribute and either a
+ ``quantile`` or ``percentiles`` coordinate; the ``window`` attribute is optional.
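+
+ Examples
+ --------
+ A minimal sketch with a hypothetical percentile array:
+
+ >>> import xarray as xr
+ >>> per = xr.DataArray([1.0], dims=("percentiles",), coords={"percentiles": [90]}, attrs={"climatology_bounds": ["1990-01-01", "2019-12-31"]})
+ >>> is_percentile_dataarray(per)
+ True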
+ """
+ return (
+ isinstance(source, xr.DataArray)
+ and source.attrs.get("climatology_bounds", None) is not None
+ and ("quantile" in source.coords or "percentiles" in source.coords)
+ )
diff --git a/src/xsdba/formatting.py b/src/xsdba/formatting.py
index de5bbfb..77c1c08 100644
--- a/src/xsdba/formatting.py
+++ b/src/xsdba/formatting.py
@@ -7,11 +7,342 @@
import datetime as dt
import itertools
-from inspect import signature
+import re
+import string
+import warnings
+from ast import literal_eval
+from collections.abc import Sequence
+from fnmatch import fnmatch
+from inspect import _empty, signature
+from typing import Any, Callable
import xarray as xr
from boltons.funcutils import wraps
+from .typing import KIND_ANNOTATION, InputKind
+
+
+class AttrFormatter(string.Formatter):
+ """A formatter for frequently used attribute values.
+
+ See the doc of format_field() for more details.
+ """
+
+ def __init__(
+ self,
+ mapping: dict[str, Sequence[str]],
+ modifiers: Sequence[str],
+ ) -> None:
+ """Initialize the formatter.
+
+ Parameters
+ ----------
+ mapping : dict[str, Sequence[str]]
+ A mapping from values to their possible variations.
+ modifiers : Sequence[str]
+ The list of modifiers; must be as long as the longest value of `mapping`.
+ Cannot include reserved modifier 'r'.
+ """
+ super().__init__()
+ if "r" in modifiers:
+ raise ValueError("Modifier 'r' is reserved for default raw formatting.")
+ self.modifiers = modifiers
+ self.mapping = mapping
+
+ def format(self, format_string: str, /, *args: Any, **kwargs: Any) -> str:
+ r"""Format a string.
+
+ Parameters
+ ----------
+ format_string: str
+ \*args: Any
+ \*\*kwargs: Any
+
+ Returns
+ -------
+ str
+ """
+ # ADAPT: THIS IS VERY CLIMATE, WILL BE REMOVED
+ # for k, v in DEFAULT_FORMAT_PARAMS.items():
+ # if k not in kwargs:
+ # kwargs.update({k: v})
+ return super().format(format_string, *args, **kwargs)
+
+ def format_field(self, value, format_spec):
+ """Format a value given a formatting spec.
+
+ If `format_spec` is in this Formatter's modifiers, the corresponding variation
+ of value is given. If `format_spec` is 'r' (raw), the value is returned unmodified.
+ If `format_spec` is not specified but `value` is in the mapping, the first variation is returned.
+
+ Examples
+ --------
+ Let's say the string "The dog is {adj1}, the goose is {adj2}" is to be translated
+ to French and that we know that possible values of `adj` are `nice` and `evil`.
+ In French, the gender of the noun changes the adjective (dog = chien is masculine,
+ and goose = oie is feminine) so we initialize the formatter as:
+
+ >>> fmt = AttrFormatter(
+ ... {
+ ... "nice": ["beau", "belle"],
+ ... "evil": ["méchant", "méchante"],
+ ... "smart": ["intelligent", "intelligente"],
+ ... },
+ ... ["m", "f"],
+ ... )
+ >>> fmt.format(
+ ... "Le chien est {adj1:m}, l'oie est {adj2:f}, le gecko est {adj3:r}",
+ ... adj1="nice",
+ ... adj2="evil",
+ ... adj3="smart",
+ ... )
+ "Le chien est beau, l'oie est méchante, le gecko est smart"
+
+ The base values may be given using unix shell-like patterns:
+
+ >>> fmt = AttrFormatter(
+ ... {"YS-*": ["annuel", "annuelle"], "MS": ["mensuel", "mensuelle"]},
+ ... ["m", "f"],
+ ... )
+ >>> fmt.format(
+ ... "La moyenne {freq:f} est faite sur un échantillon {src_timestep:m}",
+ ... freq="YS-JUL",
+ ... src_timestep="MS",
+ ... )
+ 'La moyenne annuelle est faite sur un échantillon mensuel'
+ """
+ baseval = self._match_value(value)
+ if baseval is None: # Not something we know how to translate
+ if format_spec in self.modifiers + [
+ "r"
+ ]: # Woops, however a known format spec was asked
+ warnings.warn(
+ f"Requested formatting `{format_spec}` for unknown string `{value}`."
+ )
+ format_spec = ""
+ return super().format_field(value, format_spec)
+ # Thus, known value
+
+ if not format_spec: # (None or '') No modifiers, return first
+ return self.mapping[baseval][0]
+
+ if format_spec == "r": # Raw modifier
+ return super().format_field(value, "")
+
+ if format_spec in self.modifiers: # Known modifier
+ if len(self.mapping[baseval]) == 1: # But unmodifiable entry
+ return self.mapping[baseval][0]
+ # Known modifier, modifiable entry
+ return self.mapping[baseval][self.modifiers.index(format_spec)]
+ # Known value but unknown modifier, must be a built-in one, only works for the default val...
+ return super().format_field(self.mapping[baseval][0], format_spec)
+
+ def _match_value(self, value):
+ if isinstance(value, str):
+ for mapval in self.mapping.keys():
+ if fnmatch(value, mapval):
+ return mapval
+ return None
+
+
+# Tag mappings between keyword arguments and long-form text.
+default_formatter = AttrFormatter(
+ {
+ # Arguments to "freq"
+ "D": ["daily", "days"],
+ "YS": ["annual", "years"],
+ "YS-*": ["annual", "years"],
+ "MS": ["monthly", "months"],
+ "QS-*": ["seasonal", "seasons"],
+ # Arguments to "indexer"
+ "DJF": ["winter"],
+ "MAM": ["spring"],
+ "JJA": ["summer"],
+ "SON": ["fall"],
+ "norm": ["Normal"],
+ "m1": ["january"],
+ "m2": ["february"],
+ "m3": ["march"],
+ "m4": ["april"],
+ "m5": ["may"],
+ "m6": ["june"],
+ "m7": ["july"],
+ "m8": ["august"],
+ "m9": ["september"],
+ "m10": ["october"],
+ "m11": ["november"],
+ "m12": ["december"],
+ # Arguments to "op / reducer / stat" (for example for generic.stats)
+ "integral": ["integrated", "integral"],
+ "count": ["count"],
+ "doymin": ["day of minimum"],
+ "doymax": ["day of maximum"],
+ "mean": ["average"],
+ "max": ["maximal", "maximum"],
+ "min": ["minimal", "minimum"],
+ "sum": ["total", "sum"],
+ "std": ["standard deviation"],
+ "var": ["variance"],
+ "absamp": ["absolute amplitude"],
+ "relamp": ["relative amplitude"],
+ # For when we are formatting indicator classes with empty options
+ "": [""],
+ },
+ ["adj", "noun"],
+)
+
+
+def parse_doc(doc: str) -> dict[str, str]:
+ """Crude regex parsing reading an indice docstring and extracting information needed in indicator construction.
+
+ The appropriate docstring syntax is detailed in :ref:`notebooks/extendxclim:Defining new indices`.
+
+ Parameters
+ ----------
+ doc : str
+ The docstring of an indice function.
+
+ Returns
+ -------
+ dict
+ A dictionary with all parsed sections.
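+
+ Examples
+ --------
+ A minimal sketch with a hypothetical two-paragraph docstring (only the title and abstract are extracted here):
+
+ >>> doc = '''Frost days.
+ ...
+ ... Number of days below a freezing threshold.
+ ... '''
+ >>> parse_doc(doc)["title"]
+ 'Frost days.'
+ >>> parse_doc(doc)["abstract"]
+ 'Number of days below a freezing threshold.'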
+ """
+ if doc is None:
+ return {}
+
+ out = {}
+
+ sections = re.split(r"(\w+\s?\w+)\n\s+-{3,50}", doc) # obj.__doc__.split('\n\n')
+ intro = sections.pop(0)
+ if intro:
+ intro_content = list(map(str.strip, intro.strip().split("\n\n")))
+ if len(intro_content) == 1:
+ out["title"] = intro_content[0]
+ elif len(intro_content) >= 2:
+ out["title"], abstract = intro_content[:2]
+ out["abstract"] = " ".join(map(str.strip, abstract.splitlines()))
+
+ for i in range(0, len(sections), 2):
+ header, content = sections[i : i + 2]
+
+ if header in ["Notes", "References"]:
+ out[header.lower()] = content.replace("\n ", "\n").strip()
+ elif header == "Parameters":
+ out["parameters"] = _parse_parameters(content)
+ elif header == "Returns":
+ rets = _parse_returns(content)
+ if rets:
+ meta = list(rets.values())[0]
+ if "long_name" in meta:
+ out["long_name"] = meta["long_name"]
+ return out
+
+
+def _parse_parameters(section):
+ """Parse the 'parameters' section of a docstring into a dictionary.
+
+ Works by mapping the parameter name to its description and, potentially, to its set of choices.
+ The type annotations are not parsed, except for fixed sets of values (listed as "{'a', 'b', 'c'}").
+ The annotation parsing only accepts strings, numbers, `None` and `nan` (to represent `numpy.nan`).
+ """
+ curr_key = None
+ params = {}
+ for line in section.split("\n"):
+ if line.startswith(" " * 6): # description
+ s = " " if params[curr_key]["description"] else ""
+ params[curr_key]["description"] += s + line.strip()
+ elif line.startswith(" " * 4) and ":" in line: # param title
+ name, annot = line.split(":", maxsplit=1)
+ curr_key = name.strip()
+ params[curr_key] = {"description": ""}
+ match = re.search(r".*(\{.*\}).*", annot)
+ if match:
+ try:
+ choices = literal_eval(match.groups()[0])
+ params[curr_key]["choices"] = choices
+ except ValueError: # noqa: S110
+ # If the literal_eval fails, we just ignore the choices.
+ pass
+ return params
+
+
+def _parse_returns(section):
+ """Parse the returns section of a docstring into a dictionary mapping the parameter name to its description."""
+ curr_key = None
+ params = {}
+ for line in section.split("\n"):
+ if line.strip():
+ if line.startswith(" " * 6): # long_name
+ s = " " if params[curr_key]["long_name"] else ""
+ params[curr_key]["long_name"] += s + line.strip()
+ elif line.startswith(" " * 4): # param title
+ annot, *name = reversed(line.split(":", maxsplit=1))
+ if name:
+ curr_key = name[0].strip()
+ else:
+ curr_key = None
+ params[curr_key] = {"long_name": ""}
+ annot, *unit = annot.split(",", maxsplit=1)
+ if unit:
+ params[curr_key]["units"] = unit[0].strip()
+ return params
+
+
+# XC
+def prefix_attrs(source: dict, keys: Sequence, prefix: str) -> dict:
+ """Rename some keys of a dictionary by adding a prefix.
+
+ Parameters
+ ----------
+ source : dict
+ Source dictionary, for example data attributes.
+ keys : sequence
+ Names of keys to prefix.
+ prefix : str
+ Prefix to prepend to keys.
+
+ Returns
+ -------
+ dict
+ Dictionary of attributes with some keys prefixed.
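+
+ Examples
+ --------
+ A small example with illustrative attribute values:
+
+ >>> prefix_attrs({"units": "K", "comment": "raw"}, ["units"], "tasmin_")
+ {'tasmin_units': 'K', 'comment': 'raw'}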
+ """
+ out = {}
+ for key, val in source.items():
+ if key in keys:
+ out[f"{prefix}{key}"] = val
+ else:
+ out[key] = val
+ return out
+
+
+# XC
+def unprefix_attrs(source: dict, keys: Sequence, prefix: str) -> dict:
+ """Remove prefix from keys in a dictionary.
+
+ Parameters
+ ----------
+ source : dict
+ Source dictionary, for example data attributes.
+ keys : sequence
+ Names of original keys for which prefix should be removed.
+ prefix : str
+ Prefix to remove from keys.
+
+ Returns
+ -------
+ dict
+ Dictionary of attributes whose keys were prefixed, with prefix removed.
+ """
+ out = {}
+ n = len(prefix)
+ for key, val in source.items():
+ k = key[n:]
+ if (k in keys) and key.startswith(prefix):
+ out[k] = val
+ elif key not in out:
+ out[key] = val
+ return out
+
# XC
def merge_attributes(
@@ -205,3 +536,163 @@ def gen_call_string(
elements.append(rep)
return f"{funcname}({', '.join(elements)})"
+
+
+# XC
+def _gen_parameters_section(
+ parameters: dict[str, dict[str, Any]], allowed_periods: list[str] | None = None
+) -> str:
+ """Generate the "parameters" section of the indicator docstring.
+
+ Parameters
+ ----------
+ parameters : dict
+ Parameters dictionary (`Ind.parameters`).
+ allowed_periods : list of str, optional
+ Restrict parameters to specific periods. Default: None.
+
+ Returns
+ -------
+ str
+ """
+ section = "Parameters\n----------\n"
+ for name, param in parameters.items():
+ desc_str = param.description
+ if param.kind == InputKind.FREQ_STR:
+ desc_str += (
+ " See https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset"
+ "-aliases for available options."
+ )
+ if allowed_periods is not None:
+ desc_str += (
+ f" Restricted to frequencies equivalent to one of {allowed_periods}"
+ )
+ if param.kind == InputKind.VARIABLE:
+ defstr = f"Default : `ds.{param.default}`. "
+ elif param.kind == InputKind.OPTIONAL_VARIABLE:
+ defstr = ""
+ elif param.default is not _empty:
+ defstr = f"Default : {param.default}. "
+ else:
+ defstr = "Required. "
+ if "choices" in param:
+ annotstr = str(param.choices)
+ else:
+ annotstr = KIND_ANNOTATION[param.kind]
+ if "units" in param and param.units is not None:
+ unitstr = f"[Required units : {param.units}]"
+ else:
+ unitstr = ""
+ section += f"{name} {': ' if annotstr else ''}{annotstr}\n {desc_str}\n {defstr}{unitstr}\n"
+ return section
+
+
+def _gen_returns_section(cf_attrs: Sequence[dict[str, Any]]) -> str:
+ """Generate the "Returns" section of an indicator's docstring.
+
+ Parameters
+ ----------
+ cf_attrs : Sequence[Dict[str, Any]]
+ The list of attributes, usually Indicator.cf_attrs.
+
+ Returns
+ -------
+ str
+ """
+ section = "Returns\n-------\n"
+ for attrs in cf_attrs:
+ if not section.endswith("\n"):
+ section += "\n"
+ section += f"{attrs['var_name']} : DataArray\n"
+ section += f" {attrs.get('long_name', '')}"
+ if "standard_name" in attrs:
+ section += f" ({attrs['standard_name']})"
+ if "units" in attrs:
+ section += f" [{attrs['units']}]"
+ added_section = ""
+ for key, attr in attrs.items():
+ if key not in ["long_name", "standard_name", "units", "var_name"]:
+ if callable(attr):
+ attr = ""
+ added_section += f" **{key}**: {attr};"
+ if added_section:
+ section = f"{section}, with additional attributes:{added_section[:-1]}"
+ section += "\n"
+ return section
+
+
+def generate_indicator_docstring(ind) -> str:
+ """Generate an indicator's docstring from keywords.
+
+ Parameters
+ ----------
+ ind
+ Indicator instance
+
+ Returns
+ -------
+ str
+ """
+ header = f"{ind.title} (realm: {ind.realm})\n\n{ind.abstract}\n"
+
+ special = ""
+
+ if hasattr(ind, "missing"): # Only ResamplingIndicators
+ special += f'This indicator will check for missing values according to the method "{ind.missing}".\n'
+ if hasattr(ind.compute, "__module__"):
+ special += f"Based on indice :py:func:`~{ind.compute.__module__}.{ind.compute.__name__}`.\n"
+ if ind.injected_parameters:
+ special += "With injected parameters: "
+ special += ", ".join(
+ [f"{k}={v}" for k, v in ind.injected_parameters.items()]
+ )
+ special += ".\n"
+ if ind.keywords:
+ special += f"Keywords : {ind.keywords}.\n"
+
+ parameters = _gen_parameters_section(
+ ind.parameters, getattr(ind, "allowed_periods", None)
+ )
+
+ returns = _gen_returns_section(ind.cf_attrs)
+
+ extras = ""
+ for section in ["notes", "references"]:
+ if getattr(ind, section):
+ extras += f"{section.capitalize()}\n{'-' * len(section)}\n{getattr(ind, section)}\n\n"
+
+ doc = f"{header}\n{special}\n{parameters}\n{returns}\n{extras}"
+ return doc
+
+
+def get_percentile_metadata(data: xr.DataArray, prefix: str) -> dict[str, str]:
+ """Get the metadata related to percentiles from the given DataArray as a dictionary.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Must be a percentile DataArray, this means the necessary metadata
+ must be available in its attributes and coordinates.
+ prefix : str
+ The prefix to be used in the metadata key.
+ Usually this takes the form of "tasmin_per" or equivalent.
+
+ Returns
+ -------
+ dict
+ A mapping of the configuration used to compute these percentiles.
+ """
+ # handle case where da was created with `quantile()` method
+ if "quantile" in data.coords:
+ percs = data.coords["quantile"].values * 100
+ elif "percentiles" in data.coords:
+ percs = data.coords["percentiles"].values
+ else:
+ percs = ""
+ clim_bounds = data.attrs.get("climatology_bounds", "")
+
+ return {
+ f"{prefix}_thresh": percs,
+ f"{prefix}_window": data.attrs.get("window", ""),
+ f"{prefix}_period": clim_bounds,
+ }
diff --git a/src/xsdba/locales.py b/src/xsdba/locales.py
new file mode 100644
index 0000000..1eba691
--- /dev/null
+++ b/src/xsdba/locales.py
@@ -0,0 +1,331 @@
+"""
+Internationalization
+====================
+
+This module defines methods and objects to help the internationalization of metadata for
+climate indicators computed by xclim. Go to :ref:`notebooks/customize:Adding translated metadata` to see
+how to use this feature.
+
+All the methods and objects in this module use localization data given in JSON files.
+These files are expected to be defined as in this example for French:
+
+.. code-block::
+
+ {
+ "attrs_mapping": {
+ "modifiers": ["", "f", "mpl", "fpl"],
+ "YS": ["annuel", "annuelle", "annuels", "annuelles"],
+ "YS-*": ["annuel", "annuelle", "annuels", "annuelles"],
+ # ... and so on for other frequent parameters translation...
+ },
+ "DTRVAR": {
+ "long_name": "Variabilité de l'amplitude de la température diurne",
+ "description": "Variabilité {freq:f} de l'amplitude de la température diurne (définie comme la moyenne de la variation journalière de l'amplitude de température sur une période donnée)",
+ "title": "Variation quotidienne absolue moyenne de l'amplitude de la température diurne",
+ "comment": "",
+ "abstract": "La valeur absolue de la moyenne de l'amplitude de la température diurne.",
+ },
+ # ... and so on for other indicators...
+ }
+
+Indicators are named by subclass identifier, the same as in the indicator registry (`xclim.core.indicators.registry`),
+which can differ from the callable name. In this case, the indicator is called through
+`atmos.daily_temperature_range_variability`, but its identifier is `DTRVAR`.
+Use the `ind.__class__.__name__` accessor to get its registry name.
+
+Here, the usual parameter passed to the formatting of "description" is "freq" and is usually translated from "YS"
+to "annual". However, in French and in this sentence, the feminine form should be used, so the "f" modifier is added
+by the translator so that the formatting function knows which translation to use. Acceptable entries for the mappings
+are limited to what is already defined in `xclim.core.indicators.utils.default_formatter`.
+
+For user-provided internationalization dictionaries, only the "attrs_mapping" and its "modifiers" key are mandatory;
+all other entries (translations of frequent parameters and all indicator entries) are optional.
+For xclim-provided translations (for now only French), all indicators must have an entry and the "attrs_mapping"
+entries must match exactly the default formatter.
+Those default translations are found in the `xclim/locales` folder.
+"""
+
+from __future__ import annotations
+
+import json
+import warnings
+from collections.abc import Sequence
+from copy import deepcopy
+from pathlib import Path
+
+from .formatting import AttrFormatter, default_formatter
+
+TRANSLATABLE_ATTRS = [
+ "long_name",
+ "description",
+ "comment",
+ "title",
+ "abstract",
+ "keywords",
+]
+"""
+List of attributes to consider translatable when generating locale dictionaries.
+"""
+
+_LOCALES = {}
+
+
+def list_locales():
+ """List of loaded locales. Includes all loaded locales, no matter how complete the translations are."""
+ return list(_LOCALES.keys())
+
+
+def _valid_locales(locales):
+ """Check if the locales are valid."""
+ if isinstance(locales, str):
+ return True
+ return all(
+ [
+ # A locale is valid if it is a string from the list
+ (isinstance(locale, str) and locale in _LOCALES)
+ or (
+ # Or if it is a tuple of a string and either a file or a dict.
+ not isinstance(locale, str)
+ and isinstance(locale[0], str)
+ and (isinstance(locale[1], dict) or Path(locale[1]).is_file())
+ )
+ for locale in locales
+ ]
+ )
+
+
+def get_local_dict(locale: str | Sequence[str] | tuple[str, dict]) -> tuple[str, dict]:
+ """Return all translated metadata for a given locale.
+
+ Parameters
+ ----------
+ locale: str or sequence of str
+ IETF language tag or a tuple of the language tag and a translation dict, or a tuple of the language
+ tag and a path to a json file defining translation of attributes.
+
+ Raises
+ ------
+ UnavailableLocaleError
+ If the given locale is not available.
+
+ Returns
+ -------
+ str
+ The best fitting locale string
+ dict
+ The available translations in this locale.
+ """
+ _valid_locales([locale])
+
+ if isinstance(locale, str):
+ if locale not in _LOCALES:
+ raise UnavailableLocaleError(locale)
+
+ return locale, deepcopy(_LOCALES[locale])
+
+ if isinstance(locale[1], dict):
+ trans = locale[1]
+ else:
+ # Thus, a string pointing to a json file
+ trans = read_locale_file(locale[1])
+
+ if locale[0] in _LOCALES:
+ loaded_trans = deepcopy(_LOCALES[locale[0]])
+ # Passed translations have priority
+ loaded_trans.update(trans)
+ trans = loaded_trans
+ return locale[0], trans
+
+
+def get_local_attrs(
+ indicator: str | Sequence[str],
+ *locales: str | Sequence[str] | tuple[str, dict],
+ names: Sequence[str] | None = None,
+ append_locale_name: bool = True,
+) -> dict:
+ """Get all attributes of an indicator in the requested locales.
+
+ Parameters
+ ----------
+ indicator : str or sequence of strings
+ Indicator's class name, usually the same as in `xc.core.indicator.registry`.
+ If multiple names are passed, the attrs from each indicator are merged,
+ with the highest priority set to the first name.
+ locales : str or tuple of str
+ IETF language tag or a tuple of the language tag and a translation dict, or a tuple of the language tag
+ and a path to a json file defining translation of attributes.
+ names : sequence of str, optional
+ If given, only returns translations of attributes in this list.
+ append_locale_name : bool
+ If True (default), append the language tag (as "{attr_name}_{locale}") to the returned attributes.
+
+ Raises
+ ------
+ ValueError
+ If `append_locale_name` is False and multiple `locales` are requested.
+
+ Returns
+ -------
+ dict
+ All CF attributes available for given indicator and locales.
+ Warns and returns an empty dict if none were available.
+ """
+ if isinstance(indicator, str):
+ indicator = [indicator]
+
+ if not append_locale_name and len(locales) > 1:
+ raise ValueError(
+ "`append_locale_name` cannot be False if multiple locales are requested."
+ )
+
+ attrs = {}
+ for locale in locales:
+ loc_name, loc_dict = get_local_dict(locale)
+ loc_name = f"_{loc_name}" if append_locale_name else ""
+ local_attrs = loc_dict.get(indicator[-1], {})
+ for other_ind in indicator[-2::-1]:
+ local_attrs.update(loc_dict.get(other_ind, {}))
+ if not local_attrs:
+ warnings.warn(
+ f"Attributes of indicator {', '.join(indicator)} in language {locale} "
+ "were requested, but none were found."
+ )
+ else:
+ for name in TRANSLATABLE_ATTRS:
+ if (names is None or name in names) and name in local_attrs:
+ attrs[f"{name}{loc_name}"] = local_attrs[name]
+ return attrs
+
+
+def get_local_formatter(
+ locale: str | Sequence[str] | tuple[str, dict]
+) -> AttrFormatter:
+ """Return an AttrFormatter instance for the given locale.
+
+ Parameters
+ ----------
+ locale: str or tuple of str
+ IETF language tag or a tuple of the language tag and a translation dict, or a tuple of the language tag
+ and a path to a json file defining translation of attributes.
+ """
+ loc_name, loc_dict = get_local_dict(locale)
+ if "attrs_mapping" in loc_dict:
+ attrs_mapping = loc_dict["attrs_mapping"].copy()
+ mods = attrs_mapping.pop("modifiers")
+ return AttrFormatter(attrs_mapping, mods)
+
+ warnings.warn(
+ "No `attrs_mapping` entry found for locale {loc_name}, using default (english) formatter."
+ )
+ return default_formatter
+
+
+class UnavailableLocaleError(ValueError):
+ """Error raised when a locale is requested but doesn't exist."""
+
+ def __init__(self, locale):
+ super().__init__(
+ f"Locale {locale} not available. Use `xclim.core.locales.list_locales()` to see available languages."
+ )
+
+
+def read_locale_file(
+ filename, module: str | None = None, encoding: str = "UTF8"
+) -> dict[str, dict]:
+ """Read a locale file (.json) and return its dictionary.
+
+ Parameters
+ ----------
+ filename : PathLike
+ The file to read.
+ module : str, optional
+ If module is a string, this module name is added to all identifiers translated in this file.
+ Defaults to None, and no module name is added (as if the indicator was an official xclim indicator).
+ encoding : str
+ The encoding to use when reading the file.
+ Defaults to UTF-8, overriding python's default mechanism which is machine dependent.
+ """
+ locdict: dict[str, dict]
+ with open(filename, encoding=encoding) as f:
+ locdict = json.load(f)
+
+ if module is not None:
+ locdict = {
+ (k if k == "attrs_mapping" else f"{module}.{k}"): v
+ for k, v in locdict.items()
+ }
+ return locdict
+
+
+def load_locale(locdata: str | Path | dict[str, dict], locale: str):
+ """Load translations from a json file into xclim.
+
+ Parameters
+ ----------
+ locdata : str or Path or dictionary
+ Either a loaded locale dictionary or a path to a json file.
+ locale : str
+ The locale name (IETF tag).
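+
+ Examples
+ --------
+ A minimal sketch registering an in-memory (hypothetical) French translation and using it:
+
+ >>> load_locale({"attrs_mapping": {"modifiers": ["", "f"], "YS": ["annuel", "annuelle"]}}, "fr")
+ >>> "fr" in list_locales()
+ True
+ >>> get_local_formatter("fr").format("{freq:f}", freq="YS")
+ 'annuelle'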
+ """
+ if isinstance(locdata, (str, Path)):
+ filename = Path(locdata)
+ locdata = read_locale_file(filename)
+
+ if locale in _LOCALES:
+ _LOCALES[locale].update(locdata)
+ else:
+ _LOCALES[locale] = locdata
+
+
+def generate_local_dict(locale: str, init_english: bool = False) -> dict:
+ """Generate a dictionary with keys for each indicator and translatable attributes.
+
+ Parameters
+ ----------
+ locale : str
+ Locale in the IETF format
+ init_english : bool
+ If True, fills the initial dictionary with the english versions of the attributes.
+ Defaults to False.
+ """
+ from xclim.core.indicator import registry  # pylint: disable=import-outside-toplevel
+
+ if locale in _LOCALES:
+ _, attrs = get_local_dict(locale)
+ for ind_name in attrs.copy().keys():
+ if ind_name != "attrs_mapping" and ind_name not in registry:
+ attrs.pop(ind_name)
+ else:
+ attrs = {}
+
+ attrs_mapping = attrs.setdefault("attrs_mapping", {})
+ attrs_mapping.setdefault("modifiers", [""])
+ for key, value in default_formatter.mapping.items():
+ attrs_mapping.setdefault(key, [value[0]])
+
+ eng_attr = ""
+ for ind_name, indicator in registry.items():
+ ind_attrs = attrs.setdefault(ind_name, {})
+ for translatable_attr in set(TRANSLATABLE_ATTRS).difference(
+ set(indicator._cf_names)
+ ):
+ if init_english:
+ eng_attr = getattr(indicator, translatable_attr)
+ if not isinstance(eng_attr, str):
+ eng_attr = ""
+ ind_attrs.setdefault(f"{translatable_attr}", eng_attr)
+
+ for cf_attrs in indicator.cf_attrs:
+ # In the case of single output, put var attrs in main dict
+ if len(indicator.cf_attrs) > 1:
+ ind_attrs = attrs.setdefault(f"{ind_name}.{cf_attrs['var_name']}", {})
+
+ for translatable_attr in set(TRANSLATABLE_ATTRS).intersection(
+ set(indicator._cf_names)
+ ):
+ if init_english:
+ eng_attr = cf_attrs.get(translatable_attr)
+ if not isinstance(eng_attr, str):
+ eng_attr = ""
+ ind_attrs.setdefault(f"{translatable_attr}", eng_attr)
+ return attrs
diff --git a/src/xsdba/logging.py b/src/xsdba/logging.py
index 79ee33c..14ae2c3 100644
--- a/src/xsdba/logging.py
+++ b/src/xsdba/logging.py
@@ -54,3 +54,7 @@ def raise_warn_or_log(
warnings.warn(message, stacklevel=stacklevel + 1)
else: # mode == "raise"
raise err from err_type(message)
+
+
+class MissingVariableError(ValueError):
+ """Error raised when a dataset is passed to an indicator but one of the needed variable is missing."""
diff --git a/src/xsdba/options.py b/src/xsdba/options.py
index ef48f11..cd34814 100644
--- a/src/xsdba/options.py
+++ b/src/xsdba/options.py
@@ -11,14 +11,15 @@
from boltons.funcutils import wraps
-# from .locales import _valid_locales # from XC, not reproduced for now
+from .locales import _valid_locales
from .logging import ValidationError, raise_warn_or_log
-# METADATA_LOCALES = "metadata_locales"
+METADATA_LOCALES = "metadata_locales"
DATA_VALIDATION = "data_validation"
CF_COMPLIANCE = "cf_compliance"
CHECK_MISSING = "check_missing"
MISSING_OPTIONS = "missing_options"
+RUN_LENGTH_UFUNC = "run_length_ufunc"
SDBA_EXTRA_OUTPUT = "sdba_extra_output"
SDBA_ENCODE_CF = "sdba_encode_cf"
KEEP_ATTRS = "keep_attrs"
@@ -27,11 +28,12 @@
MISSING_METHODS: dict[str, Callable] = {}
OPTIONS = {
- # METADATA_LOCALES: [],
+ METADATA_LOCALES: [],
DATA_VALIDATION: "raise",
CF_COMPLIANCE: "warn",
CHECK_MISSING: "any",
MISSING_OPTIONS: {},
+ RUN_LENGTH_UFUNC: "auto",
SDBA_EXTRA_OUTPUT: False,
SDBA_ENCODE_CF: False,
KEEP_ATTRS: "xarray",
@@ -39,6 +41,7 @@
}
_LOUDNESS_OPTIONS = frozenset(["log", "warn", "raise"])
+_RUN_LENGTH_UFUNC_OPTIONS = frozenset(["auto", True, False])
_KEEP_ATTRS_OPTIONS = frozenset(["xarray", True, False])
@@ -57,11 +60,12 @@ def _valid_missing_options(mopts):
_VALIDATORS = {
- # METADATA_LOCALES: _valid_locales,
+ METADATA_LOCALES: _valid_locales,
DATA_VALIDATION: _LOUDNESS_OPTIONS.__contains__,
CF_COMPLIANCE: _LOUDNESS_OPTIONS.__contains__,
CHECK_MISSING: lambda meth: meth != "from_context" and meth in MISSING_METHODS,
MISSING_OPTIONS: _valid_missing_options,
+ RUN_LENGTH_UFUNC: _RUN_LENGTH_UFUNC_OPTIONS.__contains__,
SDBA_EXTRA_OUTPUT: lambda opt: isinstance(opt, bool),
SDBA_ENCODE_CF: lambda opt: isinstance(opt, bool),
KEEP_ATTRS: _KEEP_ATTRS_OPTIONS.__contains__,
@@ -74,16 +78,16 @@ def _set_missing_options(mopts):
OPTIONS[MISSING_OPTIONS][meth].update(opts)
-# def _set_metadata_locales(locales):
-# if isinstance(locales, str):
-# OPTIONS[METADATA_LOCALES] = [locales]
-# else:
-# OPTIONS[METADATA_LOCALES] = locales
+def _set_metadata_locales(locales):
+ if isinstance(locales, str):
+ OPTIONS[METADATA_LOCALES] = [locales]
+ else:
+ OPTIONS[METADATA_LOCALES] = locales
_SETTERS = {
MISSING_OPTIONS: _set_missing_options,
- # METADATA_LOCALES: _set_metadata_locales,
+ METADATA_LOCALES: _set_metadata_locales,
}
diff --git a/src/xsdba/processing.py b/src/xsdba/processing.py
index 992b198..4e37cb3 100644
--- a/src/xsdba/processing.py
+++ b/src/xsdba/processing.py
@@ -14,7 +14,7 @@
import xarray as xr
from xarray.core.utils import get_temp_dimname
-from xsdba.base import get_calendar, max_doy, parse_offset, uses_dask
+from xsdba.calendar import get_calendar, max_doy, parse_offset, uses_dask
from xsdba.formatting import update_xsdba_history
from ._processing import _adapt_freq, _normalize, _reordering
@@ -388,7 +388,7 @@ def escore(
tgt: xr.DataArray,
sim: xr.DataArray,
dims: Sequence[str] = ("variables", "time"),
- N: int = 0, # noqa
+ N: int = 0,
scale: bool = False,
) -> xr.DataArray:
r"""Energy score, or energy dissimilarity metric, based on :cite:t:`sdba-szekely_testing_2004` and :cite:t:`sdba-cannon_multivariate_2018`.
diff --git a/src/xsdba/properties.py b/src/xsdba/properties.py
new file mode 100644
index 0000000..648cc51
--- /dev/null
+++ b/src/xsdba/properties.py
@@ -0,0 +1,1577 @@
+# pylint: disable=missing-kwoa
+"""
+Properties Submodule
+====================
+SDBA diagnostic tests are made up of statistical properties and measures. Properties are calculated on both simulation
+and reference datasets. They collapse the time dimension to one value.
+
+This framework for the diagnostic tests was inspired by the VALUE project.
+Statistical Properties is the xclim term for 'indices' in the VALUE project.
+
+"""
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import numpy as np
+import xarray as xr
+import xclim as xc
+from scipy import stats
+from statsmodels.tsa import stattools
+
+import xsdba.xclim_submodules.run_length as rl
+from xsdba.indicator import Indicator, base_registry
+from xsdba.units import (
+ convert_units_to,
+ ensure_delta,
+ pint2str,
+ to_agg_units,
+ units2pint,
+)
+from xsdba.utils import uses_dask
+from xsdba.xclim_submodules.generic import compare, select_resample_op
+from xsdba.xclim_submodules.stats import fit, parametric_quantile
+
+from .base import Grouper, map_groups, parse_group, parse_offset
+from .nbutils import _pairwise_haversine_and_bins
+from .utils import _pairwise_spearman, copy_all_attrs
+
+
+class StatisticalProperty(Indicator):
+ """Base indicator class for statistical properties used for validating bias-adjusted outputs.
+
+ Statistical properties reduce the time dimension, sometimes adding a grouping dimension
+ according to the passed value of `group` (e.g.: group='time.month' means the loss of the
+ time dimension and the addition of a month one).
+
+ Statistical properties are generally unit-generic. To use those indicators in a workflow, it
+ is recommended to wrap them with a virtual submodule, creating one specific indicator for
+ each variable input (or at least for each possible dimensionality).
+
+ Statistical properties may restrict the sampling frequency of the input; they usually take in a
+ single variable (named "da" in unit-generic instances).
+
+ """
+
+ aspect = None
+ """The aspect the statistical property studies: marginal, temporal, multivariate or spatial."""
+
+ measure = "xclim.sdba.measures.BIAS"
+ """The default measure to use when comparing the properties of two datasets.
+ This gives the registry id. See :py:meth:`get_measure`."""
+
+ allowed_groups = None
+ """A list of allowed groupings. A subset of dayofyear, week, month, season or group.
+ The latter stands for no temporal grouping."""
+
+ realm = "generic"
+
+ @classmethod
+ def _ensure_correct_parameters(cls, parameters):
+ if "group" not in parameters:
+ raise ValueError(
+ f"{cls.__name__} require a 'group' argument, use the base Indicator"
+ " class if your computation doesn't perform any regrouping."
+ )
+ return super()._ensure_correct_parameters(parameters)
+
+ def _preprocess_and_checks(self, das, params):
+ """Check if group is allowed."""
+ # Convert grouping and check if allowed:
+ if isinstance(params["group"], str):
+ params["group"] = Grouper(params["group"])
+
+ if self.allowed_groups is not None:
+ if params["group"].prop not in self.allowed_groups:
+ raise ValueError(
+ f"Grouping period {params['group'].prop_name} is not allowed for property "
+ f"{self.identifier} (needs something in "
+ f"{map(lambda g: '.' + g.replace('group', ''), self.allowed_groups)})."
+ )
+
+ return das, params
+
+ def _postprocess(self, outs, das, params):
+ """Squeeze `group` dim if needed."""
+ outs = super()._postprocess(outs, das, params)
+
+ for i in range(len(outs)):
+ if "group" in outs[i].dims:
+ outs[i] = outs[i].squeeze("group", drop=True)
+
+ return outs
+
+ def get_measure(self):
+ """Get the statistical measure indicator that is best used with this statistical property."""
+ from xclim.core.indicator import registry
+
+ return registry[self.measure].get_instance()
+
+
+base_registry["StatisticalProperty"] = StatisticalProperty
+
+
+@parse_group
+def _mean(da: xr.DataArray, *, group: str | Grouper = "time") -> xr.DataArray:
+ """Mean.
+
+ Mean over all years at the time resolution.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the temporal average is performed separately for each month.
+
+ Returns
+ -------
+ xr.DataArray, [same as input]
+ Mean of the variable.
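+
+ Examples
+ --------
+ A minimal sketch with synthetic daily data; ``mean`` below refers to the StatisticalProperty instance wrapping this function (an assumption about how the wrapper is called):
+
+ >>> import numpy as np
+ >>> import xarray as xr
+ >>> time = xr.date_range("2000-01-01", periods=730, freq="D")
+ >>> tas = xr.DataArray(280 + 10 * np.random.rand(time.size), dims=("time",), coords={"time": time}, attrs={"units": "K"})  # fake temperatures
+ >>> monthly_mean = mean(tas, group="time.month")  # one value per month, averaged over all years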
+ """
+ units = da.units
+ if group.prop != "group":
+ da = da.groupby(group.name)
+ out = da.mean(dim=group.dim)
+ return out.assign_attrs(units=units)
+
+
+mean = StatisticalProperty(
+ identifier="mean",
+ aspect="marginal",
+ cell_methods="time: mean",
+ compute=_mean,
+)
+
+
+@parse_group
+def _var(da: xr.DataArray, *, group: str | Grouper = "time") -> xr.DataArray:
+ """Variance.
+
+ Variance of the variable over all years at the time resolution.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the variance is computed separately for each month.
+
+ Returns
+ -------
+ xr.DataArray, [square of the input units]
+ Variance of the variable.
+ """
+ units = da.units
+ if group.prop != "group":
+ da = da.groupby(group.name)
+ out = da.var(dim=group.dim)
+ u2 = units2pint(units) ** 2
+ out.attrs["units"] = pint2str(u2)
+ return out
+
+
+var = StatisticalProperty(
+ identifier="var",
+ aspect="marginal",
+ cell_methods="time: var",
+ compute=_var,
+ measure="xclim.sdba.measures.RATIO",
+)
+
+
+@parse_group
+def _std(da: xr.DataArray, *, group: str | Grouper = "time") -> xr.DataArray:
+ """Standard Deviation.
+
+ Standard deviation of the variable over all years at the time resolution.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the standard deviation is computed separately for each month.
+
+ Returns
+ -------
+ xr.DataArray, [same units as input]
+ Standard deviation of the variable.
+ """
+ units = da.units
+ if group.prop != "group":
+ da = da.groupby(group.name)
+ out = da.std(dim=group.dim)
+ out.attrs["units"] = units
+ return out
+
+
+std = StatisticalProperty(
+ identifier="std",
+ aspect="marginal",
+ cell_methods="time: std",
+ compute=_std,
+ measure="xclim.sdba.measures.RATIO",
+)
+
+
+@parse_group
+def _skewness(da: xr.DataArray, *, group: str | Grouper = "time") -> xr.DataArray:
+ """Skewness.
+
+ Skewness of the distribution of the variable over all years at the time resolution.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the skewness is computed separately for each month.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ Skewness of the variable.
+
+ See Also
+ --------
+ scipy.stats.skew
+ """
+ if group.prop != "group":
+ da = da.groupby(group.name)
+ out = xr.apply_ufunc(
+ stats.skew,
+ da,
+ input_core_dims=[[group.dim]],
+ vectorize=True,
+ dask="parallelized",
+ )
+ out.attrs["units"] = ""
+ return out
+
+
+skewness = StatisticalProperty(
+ identifier="skewness", aspect="marginal", compute=_skewness, units=""
+)
+
+
+@parse_group
+def _quantile(
+ da: xr.DataArray, *, q: float = 0.98, group: str | Grouper = "time"
+) -> xr.DataArray:
+ """Quantile.
+
+ Returns the quantile q of the distribution of the variable over all years at the time resolution.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ q : float
+ Quantile to be calculated. Should be between 0 and 1.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the quantile is computed separately for each month.
+
+ Returns
+ -------
+ xr.DataArray, [same as input]
+ Quantile {q} of the variable.
+ """
+ units = da.units
+ if group.prop != "group":
+ da = da.groupby(group.name)
+ out = da.quantile(q, dim=group.dim, keep_attrs=True).drop_vars("quantile")
+ return out.assign_attrs(units=units)
+
+
+quantile = StatisticalProperty(
+ identifier="quantile", aspect="marginal", compute=_quantile
+)
+
+
+def _spell_length_distribution(
+ da: xr.DataArray,
+ *,
+ method: str = "amount",
+ op: str = ">=",
+ thresh: str = "1 mm d-1",
+ window: int = 1,
+ stat: str = "mean",
+ stat_resample: str | None = None,
+ group: str | Grouper = "time",
+ resample_before_rl: bool = True,
+) -> xr.DataArray:
+ """Spell length distribution.
+
+ Statistic of spell length distribution when the variable respects a condition (defined by an operation, a method and
+ a threshold).
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ method: {'amount', 'quantile'}
+ Method to choose the threshold.
+ 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}.
+ 'quantile': The threshold is calculated as the quantile {thresh} of the distribution.
+ op : {">", "<", ">=", "<="}
+ Operation to verify the condition for a spell.
+ The condition for a spell is variable {op} threshold.
+ thresh : str or float
+ Threshold on which to evaluate the condition to have a spell.
+ String with units if the method is "amount".
+ Float of the quantile if the method is "quantile".
+ window : int
+ Number of consecutive days respecting the constraint in order to begin a spell.
+ Default is 1, which is equivalent to `_threshold_count`.
+ stat : {'mean', 'sum', 'max','min'}
+ Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010)
+ stat_resample : {'mean', 'sum', 'max','min'}, optional
+ Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980).
+ If `None`, the same method as `stat` will be used.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the spell lengths are computed separately for each month.
+ resample_before_rl : bool
+ Determines if the resampling should take place before or after the run
+ length encoding (or a similar algorithm) is applied to runs.
+
+ Returns
+ -------
+ xr.DataArray, [units of the sampling frequency]
+ {stat} of spell length distribution when the variable is {op} the {method} {thresh} for {window} consecutive day(s).
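+
+ Examples
+ --------
+ A minimal sketch with synthetic daily precipitation; ``spell_length_distribution`` below refers to the StatisticalProperty instance wrapping this function (an assumption about how the wrapper is called):
+
+ >>> import numpy as np
+ >>> import xarray as xr
+ >>> time = xr.date_range("1981-01-01", periods=3650, freq="D")
+ >>> pr = xr.DataArray(np.random.gamma(1.0, 2.0, time.size), dims=("time",), coords={"time": time}, attrs={"units": "mm d-1"})  # fake precipitation
+ >>> wet_spells = spell_length_distribution(pr, method="amount", op=">=", thresh="1 mm d-1", window=3, stat="mean", group="time.season")  # mean wet-spell length per season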
+ """
+ group = group if isinstance(group, Grouper) else Grouper(group)
+
+ ops = {">": np.greater, "<": np.less, ">=": np.greater_equal, "<=": np.less_equal}
+
+ @map_groups(out=[Grouper.PROP], main_only=True)
+ def _spell_stats(
+ ds,
+ *,
+ dim,
+ method,
+ thresh,
+ window,
+ op,
+ freq,
+ resample_before_rl,
+ stat,
+ stat_resample,
+ ):
+ # PB: This prevents an import error in the distributed dask scheduler, but I don't know why.
+ import xarray.core.resample_cftime # noqa: F401, pylint: disable=unused-import
+
+ da = ds.data
+ mask = ~(da.isel({dim: 0}).isnull()).drop_vars(
+ dim
+ ) # mask of the ocean with NaNs
+ if method == "quantile":
+ thresh = da.quantile(thresh, dim=dim).drop_vars("quantile")
+
+ cond = op(da, thresh)
+ out = rl.resample_and_rl(
+ cond,
+ resample_before_rl,
+ rl.rle_statistics,
+ reducer=stat_resample,
+ window=window,
+ dim=dim,
+ freq=freq,
+ )
+ out = getattr(out, stat)(dim=dim)
+ out = out.where(mask)
+ return out.rename("out").to_dataset()
+
+ # threshold is an amount that will be converted to the right units
+ if method == "amount":
+ thresh = convert_units_to(thresh, da) # , context="infer")
+ elif method != "quantile":
+ raise ValueError(
+ f"{method} is not a valid method. Choose 'amount' or 'quantile'."
+ )
+
+ out = _spell_stats(
+ da.rename("data").to_dataset(),
+ group=group,
+ method=method,
+ thresh=thresh,
+ window=window,
+ op=ops[op],
+ freq=group.freq,
+ resample_before_rl=resample_before_rl,
+ stat=stat,
+ stat_resample=stat_resample or stat,
+ ).out
+ return to_agg_units(out, da, op="count")
+
+
+spell_length_distribution = StatisticalProperty(
+ identifier="spell_length_distribution",
+ aspect="temporal",
+ compute=_spell_length_distribution,
+)
+
+
+@parse_group
+def _threshold_count(
+ da: xr.DataArray,
+ *,
+ method: str = "amount",
+ op: str = ">=",
+ thresh: str = "1 mm d-1",
+ stat: str = "mean",
+ stat_resample: str | None = None,
+ group: str | Grouper = "time",
+) -> xr.DataArray:
+ r"""Correlation between two variables.
+
+ Spearman or Pearson correlation coefficient between two variables at the time resolution.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ method : {'amount', 'quantile'}
+ Method to choose the threshold.
+ 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}.
+ 'quantile': The threshold is calculated as the quantile {thresh} of the distribution.
+ op : {">", "<", ">=", "<="}
+ Operation to verify the condition for a spell.
+ The condition for a spell is variable {op} threshold.
+ thresh : str or float
+ Threshold on which to evaluate the condition to have a spell.
+ String with units if the method is "amount".
+ Float of the quantile if the method is "quantile".
+ stat : {'mean', 'sum', 'max','min'}
+ Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010)
+ stat_resample : {'mean', 'sum', 'max','min'}, optional
+ Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980). If `None`, the same method as `stat` will be used.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. For 'time.month', the count would be calculated on each month separately,
+ but with all the years together.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ {stat} number of days when the variable is {op} the {method} {thresh}.
+
+ Notes
+ -----
+ This corresponds to ``xclim.sdba.properties._spell_length_distribution`` with `window=1`.
+ """
+ return _spell_length_distribution(
+ da,
+ method=method,
+ op=op,
+ thresh=thresh,
+ stat=stat,
+ stat_resample=stat_resample,
+ group=group,
+ window=1,
+ )
+
+
+threshold_count = StatisticalProperty(
+ identifier="threshold_count", aspect="temporal", compute=_threshold_count
+)
+
+
+@parse_group
+def _acf(
+ da: xr.DataArray, *, lag: int = 1, group: str | Grouper = "time.season"
+) -> xr.DataArray:
+ """Autocorrelation.
+
+ Autocorrelation with a lag over a time resolution and averaged over all years.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ lag : int
+ Lag.
+ group : {'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the autocorrelation is calculated over each month separately for all years.
+ Then, the autocorrelation for all Jan/Feb/... is averaged over all years, giving 12 outputs for each grid point.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ Lag-{lag} autocorrelation of the variable over a {group.prop} and averaged over all years.
+
+ See Also
+ --------
+ statsmodels.tsa.stattools.acf
+
+ References
+ ----------
+ :cite:cts:`alavoine_distinct_2022`
+ """
+
+ def acf_last(x, nlags):
+ """Statsmodels acf calculates acf for lag 0 to nlags, this return only the last one."""
+ # As we resample + group, timeseries are quite short and fft=False seems more performant
+ out_last = stattools.acf(x, nlags=nlags, fft=False)
+ return out_last[-1]
+
+ @map_groups(out=[Grouper.PROP], main_only=True)
+ def __acf(ds, *, dim, lag, freq):
+ out = xr.apply_ufunc(
+ acf_last,
+ ds.data.resample({dim: freq}),
+ input_core_dims=[[dim]],
+ vectorize=True,
+ kwargs={"nlags": lag},
+ )
+ out = out.mean("__resample_dim__")
+ return out.rename("out").to_dataset()
+
+ out = __acf(
+ da.rename("data").to_dataset(), group=group, lag=lag, freq=group.freq
+ ).out
+ out.attrs["units"] = ""
+ return out
+
+
+acf = StatisticalProperty(
+ identifier="acf",
+ aspect="temporal",
+ allowed_groups=["season", "month"],
+ compute=_acf,
+)
+
+
+# group was kept even though "time" is the only acceptable arg to keep the signature similar to other properties
+# @parse_group doesn't work well here because of `window`
+def _annual_cycle(
+ da: xr.DataArray,
+ *,
+ stat: str = "absamp",
+ window: int = 31,
+ group: str | Grouper = "time",
+) -> xr.DataArray:
+ r"""Annual cycle statistics.
+
+ A daily climatology is calculated and optionally smoothed with a (circular) moving average.
+ The requested statistic is returned.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ stat : {'absamp','relamp', 'phase', 'min', 'max', 'asymmetry'}
+ - 'absamp' is the peak-to-peak amplitude (max - min), in the same units as the input.
+ - 'relamp' is a relative percentage. 100 * (max - min) / mean (Recommended for precipitation). Dimensionless.
+ - 'phase' is the day of year of the maximum.
+ - 'max' is the maximum. Same units as the input.
+ - 'min' is the minimum. Same units as the input.
+ - 'asymmetry' is the length of the period going from the minimum to the maximum, in years (between 0 and 1).
+ window : int
+ Size of the window for the moving average filtering. Deactivate this feature by passing window = 1.
+
+ Returns
+ -------
+ xr.DataArray, [same units as input or dimensionless or time]
+ {stat} of the annual cycle.
+ """
+ group = group if isinstance(group, Grouper) else Grouper(group)
+ units = da.units
+
+ ac = da.groupby("time.dayofyear").mean()
+ if window > 1: # smooth the cycle
+ # We want the rolling mean to be circular. There's no built-in method to do this in xarray,
+ # we'll pad the array and extract the meaningful part.
+ ac = (
+ ac.pad(dayofyear=(window // 2), mode="wrap")
+ .rolling(dayofyear=window, center=True)
+ .mean()
+ .isel(dayofyear=slice(window // 2, -(window // 2)))
+ )
+ # TODO: In April 2024, use a match-case.
+ if stat == "absamp":
+ out = ac.max("dayofyear") - ac.min("dayofyear")
+ out.attrs["units"] = xc.core.units.ensure_delta(units)
+ elif stat == "relamp":
+ out = (ac.max("dayofyear") - ac.min("dayofyear")) * 100 / ac.mean("dayofyear")
+ out.attrs["units"] = "%"
+ elif stat == "phase":
+ out = ac.idxmax("dayofyear")
+ out.attrs.update(units="", is_dayofyear=np.int32(1))
+ elif stat == "min":
+ out = ac.min("dayofyear")
+ out.attrs["units"] = units
+ elif stat == "max":
+ out = ac.max("dayofyear")
+ out.attrs["units"] = units
+ elif stat == "asymmetry":
+ out = (ac.idxmax("dayofyear") - ac.idxmin("dayofyear")) % 365 / 365
+ out.attrs["units"] = "yr"
+ else:
+ raise NotImplementedError(f"{stat} is not a valid annual cycle statistic.")
+ return out
+
+
+annual_cycle_amplitude = StatisticalProperty(
+ identifier="annual_cycle_amplitude",
+ aspect="temporal",
+ compute=_annual_cycle,
+ parameters={"stat": "absamp"},
+ allowed_groups=["group"],
+ cell_methods="time: mean time: range",
+)
+
+relative_annual_cycle_amplitude = StatisticalProperty(
+ identifier="relative_annual_cycle_amplitude",
+ aspect="temporal",
+ compute=_annual_cycle,
+ units="%",
+ parameters={"stat": "relamp"},
+ allowed_groups=["group"],
+ cell_methods="time: mean time: range",
+ measure="xclim.sdba.measures.RATIO",
+)
+
+annual_cycle_phase = StatisticalProperty(
+ identifier="annual_cycle_phase",
+ aspect="temporal",
+ units="",
+ compute=_annual_cycle,
+ parameters={"stat": "phase"},
+ cell_methods="time: range",
+ allowed_groups=["group"],
+ measure="xclim.sdba.measures.CIRCULAR_BIAS",
+)
+
+annual_cycle_asymmetry = StatisticalProperty(
+ identifier="annual_cycle_asymmetry",
+ aspect="temporal",
+ compute=_annual_cycle,
+ parameters={"stat": "asymmetry"},
+ allowed_groups=["group"],
+ units="yr",
+)
+
+annual_cycle_minimum = StatisticalProperty(
+ identifier="annual_cycle_minimum",
+ aspect="temporal",
+ units="",
+ compute=_annual_cycle,
+ parameters={"stat": "min"},
+ cell_methods="time: mean time: min",
+ allowed_groups=["group"],
+)
+
+annual_cycle_maximum = StatisticalProperty(
+ identifier="annual_cycle_maximum",
+ aspect="temporal",
+ compute=_annual_cycle,
+ parameters={"stat": "max"},
+ cell_methods="time: mean time: max",
+ allowed_groups=["group"],
+)
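+
+
+# Usage sketch (illustrative; ``tas`` is a hypothetical daily temperature DataArray):
+# the properties above share ``_annual_cycle`` and differ only by ``stat``, e.g.
+#
+#     amp = annual_cycle_amplitude(tas)  # max - min of the smoothed daily climatology
+#     phase = annual_cycle_phase(tas)  # day of year of the climatological maximum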
+
+
+# @parse_group
+def _annual_statistic(
+ da: xr.DataArray,
+ *,
+ stat: str = "absamp",
+ window: int = 31,
+ group: str | Grouper = "time",
+):
+ """Annual range statistics.
+
+ Compute a statistic on each year of data and return the interannual average. This is similar
+ to the annual cycle, but with the statistic and average operations inverted.
+
+ Parameters
+ ----------
+ da: xr.DataArray
+ Data.
+ stat : {'absamp', 'relamp', 'phase'}
+ The statistic to return.
+ window : int
+ Size of the window for the moving average filtering. Deactivate this feature by passing window = 1.
+
+ Returns
+ -------
+ xr.DataArray, [same units as input or dimensionless]
+ Average annual {stat}.
+ """
+ units = da.units
+
+ if window > 1:
+ da = da.rolling(time=window, center=True).mean()
+
+ yrs = da.resample(time="YS")
+
+ if stat == "absamp":
+ out = yrs.max() - yrs.min()
+ out.attrs["units"] = ensure_delta(units)
+ elif stat == "relamp":
+ out = (yrs.max() - yrs.min()) * 100 / yrs.mean()
+ out.attrs["units"] = "%"
+ elif stat == "phase":
+ out = yrs.map(xr.DataArray.idxmax).dt.dayofyear
+ out.attrs.update(units="", is_dayofyear=np.int32(1))
+ else:
+        raise NotImplementedError(f"{stat} is not a valid annual statistic.")
+ return out.mean("time", keep_attrs=True)
+
+
+mean_annual_range = StatisticalProperty(
+ identifier="mean_annual_range",
+ aspect="temporal",
+ compute=_annual_statistic,
+ parameters={"stat": "absamp"},
+ allowed_groups=["group"],
+)
+
+mean_annual_relative_range = StatisticalProperty(
+ identifier="mean_annual_relative_range",
+ aspect="temporal",
+ compute=_annual_statistic,
+ parameters={"stat": "relamp"},
+ allowed_groups=["group"],
+ units="%",
+ measure="xclim.sdba.measures.RATIO",
+)
+
+mean_annual_phase = StatisticalProperty(
+ identifier="mean_annual_phase",
+ aspect="temporal",
+ compute=_annual_statistic,
+ parameters={"stat": "phase"},
+ allowed_groups=["group"],
+ units="",
+ measure="xclim.sdba.measures.CIRCULAR_BIAS",
+)
+
+
+@parse_group
+def _corr_btw_var(
+ da1: xr.DataArray,
+ da2: xr.DataArray,
+ *,
+ corr_type: str = "Spearman",
+ group: str | Grouper = "time",
+ output: str = "correlation",
+) -> xr.DataArray:
+ r"""Correlation between two variables.
+
+ Spearman or Pearson correlation coefficient between two variables at the time resolution.
+
+ Parameters
+ ----------
+ da1 : xr.DataArray
+ First variable on which to calculate the diagnostic.
+ da2 : xr.DataArray
+ Second variable on which to calculate the diagnostic.
+    corr_type : {'Pearson', 'Spearman'}
+        Type of correlation to calculate.
+    output : {'correlation', 'pvalue'}
+        Whether to return the correlation coefficient or the p-value.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. For 'time.month', the correlation would be calculated on each month separately,
+ but with all the years together.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ {corr_type} correlation coefficient
+ """
+ if corr_type.lower() not in {"pearson", "spearman"}:
+ raise ValueError(
+ f"{corr_type} is not a valid type. Choose 'Pearson' or 'Spearman'."
+ )
+
+ index = {"correlation": 0, "pvalue": 1}[output]
+
+ def _first_output_1d(a, b, index, corr_type):
+ """Only keep the correlation (first output) from the scipy function."""
+ # for points in the water with NaNs
+ if np.isnan(a).all():
+ return np.nan
+ aok = ~np.isnan(a)
+ bok = ~np.isnan(b)
+ if corr_type == "Pearson":
+ return stats.pearsonr(a[aok & bok], b[aok & bok])[index]
+ return stats.spearmanr(a[aok & bok], b[aok & bok])[index]
+
+ @map_groups(out=[Grouper.PROP], main_only=True)
+ def _first_output(ds, *, dim, index, corr_type):
+ out = xr.apply_ufunc(
+ _first_output_1d,
+ ds.a,
+ ds.b,
+ input_core_dims=[[dim], [dim]],
+ vectorize=True,
+ dask="parallelized",
+ kwargs={"index": index, "corr_type": corr_type},
+ )
+ return out.rename("out").to_dataset()
+
+ out = _first_output(
+ xr.Dataset({"a": da1, "b": da2}), group=group, index=index, corr_type=corr_type
+ ).out
+ out.attrs["units"] = ""
+ return out
+
+
+corr_btw_var = StatisticalProperty(
+ identifier="corr_btw_var", aspect="multivariate", compute=_corr_btw_var
+)
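+
+
+# Usage sketch (illustrative; ``tasmax`` and ``pr`` are hypothetical, time-aligned DataArrays):
+#
+#     rho = corr_btw_var(tasmax, pr, corr_type="Spearman", group="time.month")
+#
+# This gives one Spearman correlation per month, pooling all years together.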
+
+
+def _bivariate_spell_length_distribution(
+ da1: xr.DataArray,
+ da2: xr.DataArray,
+ *,
+ method1: str = "amount",
+ method2: str = "amount",
+ op1: str = ">=",
+ op2: str = ">=",
+ thresh1: str = "1 mm d-1",
+ thresh2: str = "1 mm d-1",
+ window: int = 1,
+ stat: str = "mean",
+ stat_resample: str | None = None,
+ group: str | Grouper = "time",
+ resample_before_rl: bool = True,
+) -> xr.DataArray:
+ """Spell length distribution with bivariate condition.
+
+ Statistic of spell length distribution when two variables respect individual conditions (defined by an operation, a method,
+ and a threshold).
+
+ Parameters
+ ----------
+ da1 : xr.DataArray
+ First variable on which to calculate the diagnostic.
+ da2 : xr.DataArray
+ Second variable on which to calculate the diagnostic.
+ method1 : {'amount', 'quantile'}
+ Method to choose the threshold.
+ 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}.
+ 'quantile': The threshold is calculated as the quantile {thresh} of the distribution.
+ method2 : {'amount', 'quantile'}
+ Method to choose the threshold.
+ 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}.
+ 'quantile': The threshold is calculated as the quantile {thresh} of the distribution.
+ op1 : {">", "<", ">=", "<="}
+ Operation to verify the condition for a spell.
+ The condition for a spell is variable {op1} threshold.
+ op2 : {">", "<", ">=", "<="}
+ Operation to verify the condition for a spell.
+ The condition for a spell is variable {op2} threshold.
+ thresh1 : str or float
+ Threshold on which to evaluate the condition to have a spell.
+ String with units if the method is "amount".
+ Float of the quantile if the method is "quantile".
+ thresh2 : str or float
+ Threshold on which to evaluate the condition to have a spell.
+ String with units if the method is "amount".
+ Float of the quantile if the method is "quantile".
+ window : int
+ Number of consecutive days respecting the constraint in order to begin a spell.
+ Default is 1, which is equivalent to `_bivariate_threshold_count`
+ stat : {'mean', 'sum', 'max','min'}
+ Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010)
+ stat_resample : {'mean', 'sum', 'max','min'}, optional
+ Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980). If `None`, the same method as `stat` will be used.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. If 'time.month', the spell lengths are computed separately for each month.
+ resample_before_rl : bool
+ Determines if the resampling should take place before or after the run
+ length encoding (or a similar algorithm) is applied to runs.
+
+ Returns
+ -------
+ xr.DataArray, [units of the sampling frequency]
+ {stat} of spell length distribution when the first variable is {op1} the {method1} {thresh1}
+ and the second variable is {op2} the {method2} {thresh2} for {window} consecutive day(s).
+ """
+ group = group if isinstance(group, Grouper) else Grouper(group)
+ ops = {
+ ">": np.greater,
+ "<": np.less,
+ ">=": np.greater_equal,
+ "<=": np.less_equal,
+ }
+
+ @map_groups(out=[Grouper.PROP], main_only=True)
+ def _bivariate_spell_stats(
+ ds,
+ *,
+ dim,
+ methods,
+ threshs,
+ opss,
+ freq,
+ window,
+ resample_before_rl,
+ stat,
+ stat_resample,
+ ):
+ # PB: This prevents an import error in the distributed dask scheduler, but I don't know why.
+ import xarray.core.resample_cftime # noqa: F401, pylint: disable=unused-import
+
+ conds = []
+ masks = []
+ for da, thresh, op, method in zip([ds.da1, ds.da2], threshs, opss, methods):
+ masks.append(
+ ~(da.isel({dim: 0}).isnull()).drop_vars(dim)
+ ) # mask of the ocean with NaNs
+ if method == "quantile":
+ thresh = da.quantile(thresh, dim=dim).drop_vars("quantile")
+ conds.append(op(da, thresh))
+ mask = masks[0] & masks[1]
+ cond = conds[0] & conds[1]
+ out = rl.resample_and_rl(
+ cond,
+ resample_before_rl,
+ rl.rle_statistics,
+ reducer=stat_resample,
+ window=window,
+ dim=dim,
+ freq=freq,
+ )
+ out = getattr(out, stat)(dim=dim)
+ out = out.where(mask)
+ return out.rename("out").to_dataset()
+
+ # threshold is an amount that will be converted to the right units
+ methods = [method1, method2]
+ threshs = [thresh1, thresh2]
+ for i, da in enumerate([da1, da2]):
+ if methods[i] == "amount":
+ # ADAPT: will this work?
+ threshs[i] = convert_units_to(threshs[i], da) # , context="infer")
+ elif methods[i] != "quantile":
+ raise ValueError(
+ f"{methods[i]} is not a valid method. Choose 'amount' or 'quantile'."
+ )
+
+ out = _bivariate_spell_stats(
+ xr.Dataset({"da1": da1, "da2": da2}),
+ group=group,
+ threshs=threshs,
+ methods=methods,
+ opss=[ops[op1], ops[op2]],
+ window=window,
+ freq=group.freq,
+ resample_before_rl=resample_before_rl,
+ stat=stat,
+ stat_resample=stat_resample or stat,
+ ).out
+ return to_agg_units(out, da1, op="count")
+
+
+bivariate_spell_length_distribution = StatisticalProperty(
+ identifier="bivariate_spell_length_distribution",
+ aspect="temporal",
+ compute=_bivariate_spell_length_distribution,
+)
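+
+
+# Usage sketch (illustrative; ``tasmax`` and ``pr`` are hypothetical DataArrays): mean
+# length of hot-and-dry spells lasting at least 3 days, per season:
+#
+#     hot_dry = bivariate_spell_length_distribution(
+#         tasmax, pr,
+#         op1=">=", thresh1="30 degC",
+#         op2="<", thresh2="1 mm d-1",
+#         window=3, stat="mean", group="time.season",
+#     )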
+
+
+@parse_group
+def _bivariate_threshold_count(
+ da1: xr.DataArray,
+ da2: xr.DataArray,
+ *,
+ method1: str = "amount",
+ method2: str = "amount",
+ op1: str = ">=",
+ op2: str = ">=",
+ thresh1: str = "1 mm d-1",
+ thresh2: str = "1 mm d-1",
+ stat: str = "mean",
+ stat_resample: str | None = None,
+ group: str | Grouper = "time",
+) -> xr.DataArray:
+ """Count the number of time steps where two variables respect given conditions.
+
+ Statistic of number of time steps when two variables respect individual conditions (defined by an operation, a method,
+ and a threshold).
+
+ Parameters
+ ----------
+ da1 : xr.DataArray
+ First variable on which to calculate the diagnostic.
+ da2 : xr.DataArray
+ Second variable on which to calculate the diagnostic.
+ method1 : {'amount', 'quantile'}
+ Method to choose the threshold.
+ 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}.
+ 'quantile': The threshold is calculated as the quantile {thresh} of the distribution.
+ method2 : {'amount', 'quantile'}
+ Method to choose the threshold.
+ 'amount': The threshold is directly the quantity in {thresh}. It needs to have the same units as {da}.
+ 'quantile': The threshold is calculated as the quantile {thresh} of the distribution.
+    op1 : {">", "<", ">=", "<="}
+        Operation to verify the condition for a spell.
+        The condition for a spell is variable {op1} threshold.
+    op2 : {">", "<", ">=", "<="}
+        Operation to verify the condition for a spell.
+        The condition for a spell is variable {op2} threshold.
+ thresh1 : str or float
+ Threshold on which to evaluate the condition to have a spell.
+ String with units if the method is "amount".
+ Float of the quantile if the method is "quantile".
+ thresh2 : str or float
+ Threshold on which to evaluate the condition to have a spell.
+ String with units if the method is "amount".
+ Float of the quantile if the method is "quantile".
+ stat : {'mean', 'sum', 'max','min'}
+ Statistics to apply to the remaining time dimension after resampling (e.g. Jan 1980-2010)
+ stat_resample : {'mean', 'sum', 'max','min'}, optional
+ Statistics to apply to the resampled input at the {group} (e.g. 1-31 Jan 1980).
+ If `None`, the same method as `stat` will be used.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output.
+ e.g. For 'time.month', the correlation would be calculated on each month separately,
+ but with all the years together.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ {stat} number of days when the first variable is {op1} the {method1} {thresh1}
+        and the second variable is {op2} the {method2} {thresh2}.
+
+ Notes
+ -----
+ This corresponds to ``xclim.sdba.properties._bivariate_spell_length_distribution`` with `window=1`.
+ """
+ return _bivariate_spell_length_distribution(
+ da1,
+ da2,
+ method1=method1,
+ method2=method2,
+ op1=op1,
+ op2=op2,
+ thresh1=thresh1,
+ thresh2=thresh2,
+ window=1,
+ stat=stat,
+ stat_resample=stat_resample,
+ group=group,
+ )
+
+
+bivariate_threshold_count = StatisticalProperty(
+ identifier="bivariate_threshold_count",
+ aspect="multivariate",
+ compute=_bivariate_threshold_count,
+)
+
+
+@parse_group
+def _relative_frequency(
+ da: xr.DataArray,
+ *,
+ op: str = ">=",
+ thresh: str = "1 mm d-1",
+ group: str | Grouper = "time",
+) -> xr.DataArray:
+ """Relative Frequency.
+
+ Relative Frequency of days with variable respecting a condition (defined by an operation and a threshold) at the
+ time resolution. The relative frequency is the number of days that satisfy the condition divided by the total number
+ of days.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ op : {">", "<", ">=", "<="}
+ Operation to verify the condition.
+ The condition is variable {op} threshold.
+ thresh : str
+ Threshold on which to evaluate the condition.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping on the output.
+ e.g. For 'time.month', the relative frequency would be calculated on each month, with all years included.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ Relative frequency of values {op} {thresh}.
+ """
+ # mask of the ocean with NaNs
+ mask = ~(da.isel({group.dim: 0}).isnull()).drop_vars(group.dim)
+ ops: dict[str, np.ufunc] = {
+ ">": np.greater,
+ "<": np.less,
+ ">=": np.greater_equal,
+ "<=": np.less_equal,
+ }
+ t = convert_units_to(thresh, da) # , context="infer")
+ length = da.sizes[group.dim]
+ cond = ops[op](da, t)
+ if group.prop != "group": # change the time resolution if necessary
+ cond = cond.groupby(group.name)
+ # length of the groupBy groups
+ length = np.array([len(v) for k, v in cond.groups.items()])
+ for _ in range(da.ndim - 1): # add empty dimension(s) to match input
+ length = np.expand_dims(length, axis=-1)
+ # count days with the condition and divide by total nb of days
+ out = cond.sum(dim=group.dim, skipna=False) / length
+ out = out.where(mask, np.nan)
+ out.attrs["units"] = ""
+ return out
+
+
+relative_frequency = StatisticalProperty(
+ identifier="relative_frequency", aspect="temporal", compute=_relative_frequency
+)
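+
+
+# Usage sketch (illustrative; ``pr`` is a hypothetical daily precipitation DataArray):
+# the wet-day frequency per season, i.e. the fraction of days at or above 1 mm/d:
+#
+#     wetday_freq = relative_frequency(pr, op=">=", thresh="1 mm d-1", group="time.season")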
+
+
+@parse_group
+def _transition_probability(
+ da: xr.DataArray,
+ *,
+ initial_op: str = ">=",
+ final_op: str = ">=",
+ thresh: str = "1 mm d-1",
+ group: str | Grouper = "time",
+) -> xr.DataArray:
+ """Transition probability.
+
+ Probability of transition from the initial state to the final state. The states are
+ booleans comparing the value of the day to the threshold with the operator.
+
+ The transition occurs when consecutive days are both in the given states.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+    initial_op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+        Operation to verify the condition for the initial state.
+        The condition is variable {initial_op} threshold.
+    final_op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+        Operation to verify the condition for the final state.
+        The condition is variable {final_op} threshold.
+ thresh : str
+ Threshold on which to evaluate the condition.
+ group : {"time", "time.season", "time.month"}
+ Grouping on the output.
+ e.g. For "time.month", the transition probability would be calculated on each month, with all years included.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ Transition probability of values {initial_op} {thresh} to values {final_op} {thresh}.
+ """
+ # mask of the ocean with NaNs
+ mask = ~(da.isel({group.dim: 0}).isnull()).drop_vars(group.dim)
+
+ today = da.isel(time=slice(0, -1))
+ tomorrow = da.shift(time=-1).isel(time=slice(0, -1))
+
+ t = convert_units_to(thresh, da) # , context="infer")
+ cond = compare(today, initial_op, t) * compare(tomorrow, final_op, t)
+ out = group.apply("mean", cond)
+ out = out.where(mask, np.nan)
+ out.attrs["units"] = ""
+ return out
+
+
+transition_probability = StatisticalProperty(
+ identifier="transition_probability",
+ aspect="temporal",
+ compute=_transition_probability,
+)
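+
+
+# Usage sketch (illustrative; ``pr`` is a hypothetical daily precipitation DataArray):
+# the probability that a wet day (>= 1 mm/d) is followed by another wet day:
+#
+#     wet_wet = transition_probability(pr, initial_op=">=", final_op=">=", thresh="1 mm d-1")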
+
+
+@parse_group
+def _trend(
+ da: xr.DataArray,
+ *,
+ group: str | Grouper = "time",
+ output: str = "slope",
+) -> xr.DataArray:
+ """Linear Trend.
+
+ The data is averaged over each time resolution and the inter-annual trend is returned.
+ This function will rechunk along the grouping dimension.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ output : {'slope', 'intercept', 'rvalue', 'pvalue', 'stderr', 'intercept_stderr'}
+ The attributes of the linear regression to return, as defined in scipy.stats.linregress:
+ 'slope' is the slope of the regression line.
+ 'intercept' is the intercept of the regression line.
+        'rvalue' is the Pearson correlation coefficient.
+        The square of rvalue is equal to the coefficient of determination.
+        'pvalue' is the p-value for a hypothesis test whose null hypothesis is that the slope is zero,
+        using a Wald test with the t-distribution of the test statistic.
+ 'stderr' is the standard error of the estimated slope (gradient), under the assumption of residual normality.
+ 'intercept_stderr' is the standard error of the estimated intercept, under the assumption of residual normality.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping on the output.
+
+ Returns
+ -------
+ xr.DataArray, [units of input per year or dimensionless]
+ {output} of the interannual linear trend.
+
+ See Also
+ --------
+ scipy.stats.linregress
+
+ numpy.polyfit
+ """
+ units = da.units
+
+ da = da.resample({group.dim: group.freq}) # separate all the {group}
+ da_mean = da.mean(dim=group.dim) # avg over all {group}
+ if uses_dask(da_mean):
+ da_mean = da_mean.chunk({group.dim: -1})
+ if group.prop != "group":
+ da_mean = da_mean.groupby(group.name) # group all month/season together
+
+    def modified_lr(x):
+        # Wrap linregress to fit into apply_ufunc and return only the requested output.
+ return getattr(stats.linregress(list(range(len(x))), x), output)
+
+ out = xr.apply_ufunc(
+ modified_lr,
+ da_mean,
+ input_core_dims=[[group.dim]],
+ vectorize=True,
+ dask="parallelized",
+ )
+ out.attrs["units"] = f"{units}/year"
+ return out
+
+
+trend = StatisticalProperty(identifier="trend", aspect="temporal", compute=_trend)
+
+
+@parse_group
+def _return_value(
+ da: xr.DataArray,
+ *,
+ period: int = 20,
+ op: str = "max",
+ method: str = "ML",
+ group: str | Grouper = "time",
+) -> xr.DataArray:
+ r"""Return value.
+
+    Return the value corresponding to a return period. On average, the return value will be exceeded
+    (or not exceeded, for op='min') every return period (e.g. 20 years). The return value is computed by first extracting
+    the variable annual maxima/minima, fitting a statistical distribution to the maxima/minima,
+    then estimating the percentile associated with the return period (e.g. the 95th percentile, an exceedance
+    probability of 1/20, for a 20-year period).
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Variable on which to calculate the diagnostic.
+ period : int
+ Return period. Number of years over which to check if the value is exceeded (or not for op='min').
+ op : {'max','min'}
+ Whether we are looking for a probability of exceedance ('max', right side of the distribution)
+ or a probability of non-exceedance (min, left side of the distribution).
+ method : {"ML", "PWM"}
+ Fitting method, either maximum likelihood (ML) or probability weighted moments (PWM), also called L-Moments.
+ The PWM method is usually more robust to outliers.
+ group : {'time', 'time.season', 'time.month'}
+ Grouping of the output. A distribution of the extremes is done for each group.
+
+ Returns
+ -------
+ xr.DataArray, [same as input]
+ {period}-{group.prop_name} {op} return level of the variable.
+ """
+
+ @map_groups(out=[Grouper.PROP], main_only=True)
+ def frequency_analysis_method(ds, *, dim, method):
+ sub = select_resample_op(ds.x, op=op)
+ params = fit(sub, dist="genextreme", method=method)
+ out = parametric_quantile(params, q=1 - 1.0 / period)
+ return out.isel(quantile=0, drop=True).rename("out").to_dataset()
+
+ out = frequency_analysis_method(
+ da.rename("x").to_dataset(), method=method, group=group
+ ).out
+ return out.assign_attrs(units=da.units)
+
+
+return_value = StatisticalProperty(
+ identifier="return_value", aspect="temporal", compute=_return_value
+)
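+
+
+# Usage sketch (illustrative; ``pr`` is a hypothetical daily precipitation DataArray):
+# the 20-year return value of the annual maximum, fitted with L-moments:
+#
+#     rv20 = return_value(pr, period=20, op="max", method="PWM")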
+
+
+@parse_group
+def _spatial_correlogram(
+ da: xr.DataArray,
+ *,
+ dims: Sequence[str] | None = None,
+ bins: int = 100,
+ group: str = "time",
+ method: int = 1,
+):
+ """Spatial correlogram.
+
+ Compute the pairwise spatial correlations (Spearman) and averages them based on the pairwise distances.
+ This collapses the spatial and temporal dimensions and returns a distance bins dimension.
+ Needs coordinates for longitude and latitude. This property is heavy to compute, and it will
+ need to create a NxN array in memory (outside of dask), where N is the number of spatial points.
+ There are shortcuts for all-nan time-slices or spatial points, but scipy's nan-omitting algorithm
+ is extremely slow, so the presence of any lone NaN will increase the computation time. Based on an idea
+ from :cite:p:`francois_multivariate_2020`.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Data.
+ dims : sequence of strings, optional
+ Name of the spatial dimensions. Once these are stacked, the longitude and latitude coordinates must be 1D.
+ bins : int
+        Same as argument `bins` from :py:func:`scipy.stats.binned_statistic`.
+        If given as a scalar, the equal-width bin limits are generated here
+        (instead of letting scipy do it) to improve performance.
+ group : str
+ Useless for now.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ Inter-site correlogram as a function of distance.
+ """
+ if dims is None:
+ dims = [d for d in da.dims if d != "time"]
+
+ corr = _pairwise_spearman(da, dims)
+ dists, mn, mx = _pairwise_haversine_and_bins(
+ corr.cf["longitude"].values, corr.cf["latitude"].values
+ )
+ dists = xr.DataArray(dists, dims=corr.dims, coords=corr.coords, name="distance")
+ if np.isscalar(bins):
+ bins = np.linspace(mn * 0.9999, mx * 1.0001, bins + 1)
+ if uses_dask(corr):
+ dists = dists.chunk()
+
+ w = np.diff(bins)
+ centers = xr.DataArray(
+ bins[:-1] + w / 2,
+ dims=("distance_bins",),
+ attrs={
+ "units": "km",
+ "long_name": f"Centers of the intersite distance bins (width of {w[0]:.3f} km)",
+ },
+ )
+
+ dists = dists.where(corr.notnull())
+
+ def _bin_corr(corr, distance):
+ """Bin and mean."""
+ return stats.binned_statistic(
+ distance.flatten(), corr.flatten(), statistic="mean", bins=bins
+ ).statistic
+
+ # (_spatial, _spatial2) -> (_spatial, distance_bins)
+ binned = xr.apply_ufunc(
+ _bin_corr,
+ corr,
+ dists,
+ input_core_dims=[["_spatial", "_spatial2"], ["_spatial", "_spatial2"]],
+ output_core_dims=[["distance_bins"]],
+ dask="parallelized",
+ vectorize=True,
+ output_dtypes=[float],
+ dask_gufunc_kwargs={
+ "allow_rechunk": True,
+ "output_sizes": {"distance_bins": bins},
+ },
+ )
+ binned = (
+ binned.assign_coords(distance_bins=centers)
+ .rename(distance_bins="distance")
+ .assign_attrs(units="")
+ .rename("corr")
+ )
+ return binned
+
+
+spatial_correlogram = StatisticalProperty(
+ identifier="spatial_correlogram",
+ aspect="spatial",
+ compute=_spatial_correlogram,
+ allowed_groups=["group"],
+)
+
+
+def _decorrelation_length(
+ da: xr.DataArray,
+ *,
+ radius: int | float = 300,
+ thresh: float = 0.50,
+ dims: Sequence[str] | None = None,
+ bins: int = 100,
+ group: xr.Coordinate | str | None = "time", # FIXME: this needs to be clarified
+):
+ """Decorrelation length.
+
+ Distance from a grid cell where the correlation with its neighbours goes below the threshold.
+ A correlogram is calculated for each grid cell following the method from
+ ``xclim.sdba.properties.spatial_correlogram``. Then, we find the first bin closest to the correlation threshold.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Data.
+ radius : float
+ Radius (in km) defining the region where correlations will be calculated between a point and its neighbours.
+ thresh : float
+ Threshold correlation defining decorrelation.
+ The decorrelation length is defined as the center of the distance bin that has a correlation closest
+ to this threshold.
+ dims : sequence of strings
+ Name of the spatial dimensions. Once these are stacked, the longitude and latitude coordinates must be 1D.
+ bins : int
+ Same as argument `bins` from :py:meth:`scipy.stats.binned_statistic`.
+ If given as a scalar, the equal-width bin limits from 0 to radius are generated here
+ (instead of letting scipy do it) to improve performance.
+ group : xarray.Coordinate or str, optional
+ Useless for now.
+
+ Returns
+ -------
+ xr.DataArray, [km]
+ Decorrelation length.
+
+ Notes
+ -----
+ Calculating this property requires a lot of memory. It will not work with large datasets.
+ """
+ if dims is None and group is not None:
+ dims = [d for d in da.dims if d != group.dim]
+
+ corr = _pairwise_spearman(da, dims)
+
+ dists, _, _ = _pairwise_haversine_and_bins(
+ corr.cf["longitude"].values, corr.cf["latitude"].values, transpose=True
+ )
+
+ dists = xr.DataArray(dists, dims=corr.dims, coords=corr.coords, name="distance")
+
+ trans_dists = xr.DataArray(
+ dists.T, dims=corr.dims, coords=corr.coords, name="distance"
+ )
+
+ if np.isscalar(bins):
+ bin_array = np.linspace(0, radius, bins + 1)
+ elif isinstance(bins, np.ndarray):
+ bin_array = bins
+ else:
+ raise ValueError("bins must be a scalar or a numpy array.")
+
+ if uses_dask(corr):
+ dists = dists.chunk()
+ trans_dists = trans_dists.chunk()
+
+ w = np.diff(bin_array)
+ centers = xr.DataArray(
+ bin_array[:-1] + w / 2,
+ dims=("distance_bins",),
+ attrs={
+ "units": "km",
+ "long_name": f"Centers of the intersite distance bins (width of {w[0]:.3f} km)",
+ },
+ )
+ ds = xr.Dataset({"corr": corr, "distance": dists, "distance2": trans_dists})
+
+ # only keep points inside the radius
+ ds = ds.where(ds.distance < radius)
+ ds = ds.where(ds.distance2 < radius)
+
+ def _bin_corr(_corr, _distance):
+ """Bin and mean."""
+ mask_nan = ~np.isnan(_corr)
+ binned_corr = stats.binned_statistic(
+ _distance[mask_nan], _corr[mask_nan], statistic="mean", bins=bin_array
+ )
+ stat = binned_corr.statistic
+ return stat
+
+ # (_spatial, _spatial2) -> (_spatial, distance_bins)
+ binned = (
+ xr.apply_ufunc(
+ _bin_corr,
+ ds.corr,
+ ds.distance,
+ input_core_dims=[["_spatial2"], ["_spatial2"]],
+ output_core_dims=[["distance_bins"]],
+ dask="parallelized",
+ vectorize=True,
+ output_dtypes=[float],
+ dask_gufunc_kwargs={
+ "allow_rechunk": True,
+ "output_sizes": {"distance_bins": len(bin_array)},
+ },
+ )
+ .rename("corr")
+ .to_dataset()
+ )
+
+ binned = (
+ binned.assign_coords(distance_bins=centers)
+ .rename(distance_bins="distance")
+ .assign_attrs(units="")
+ )
+
+ closest = abs(binned.corr - thresh).idxmin(dim="distance")
+ binned["decorrelation_length"] = closest
+
+ # get back to 2d lat and lon
+ # if 'lat' in dims and 'lon' in dims:
+ if len(dims) > 1:
+ binned = binned.set_index({"_spatial": dims})
+ out = binned.decorrelation_length.unstack()
+ else:
+ out = binned.swap_dims({"_spatial": dims[0]}).decorrelation_length
+
+ copy_all_attrs(out, da)
+
+ out.attrs["units"] = "km"
+ return out
+
+
+decorrelation_length = StatisticalProperty(
+ identifier="decorrelation_length",
+ aspect="spatial",
+ compute=_decorrelation_length,
+ allowed_groups=["group"],
+)
+
+
+def first_eof():
+ """EOF Statistical Property (function removed).
+
+ Warnings
+ --------
+ Due to a licensing issue, eofs-based functionality has been permanently removed.
+ Please excuse the inconvenience.
+ For more information, see: https://github.com/Ouranosinc/xclim/issues/1620
+ """
+ raise RuntimeError(
+ "Due to a licensing issue, eofs-based functionality has been permanently removed. "
+ "Please excuse the inconvenience. "
+ "For more information, see: https://github.com/Ouranosinc/xclim/issues/1620"
+ )
diff --git a/src/xsdba/typing.py b/src/xsdba/typing.py
new file mode 100644
index 0000000..ac96ad2
--- /dev/null
+++ b/src/xsdba/typing.py
@@ -0,0 +1,133 @@
+"""# noqa: SS01
+Typing Utilities
+===================================
+"""
+
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import NewType, TypeVar
+
+import xarray as xr
+from pint import Quantity
+
+# XC:
+#: Type annotation for strings representing full dates (YYYY-MM-DD), may include time.
+DateStr = NewType("DateStr", str)
+
+#: Type annotation for strings representing dates without a year (MM-DD).
+DayOfYearStr = NewType("DayOfYearStr", str)
+
+#: Type annotation for thresholds and other not-exactly-a-variable quantities
+Quantified = TypeVar("Quantified", xr.DataArray, str, Quantity)
+
+
+# XC
+class InputKind(IntEnum):
+ """Constants for input parameter kinds.
+
+    For use by external parsers to determine what kind of data the indicator expects.
+ On the creation of an indicator, the appropriate constant is stored in
+ :py:attr:`xclim.core.indicator.Indicator.parameters`. The integer value is what gets stored in the output
+ of :py:meth:`xclim.core.indicator.Indicator.json`.
+
+    For developers: for each constant, the docstring specifies the annotation a parameter of an indice function
+    should use in order to be picked up by the indicator constructor. Notice that we are using the annotation format
+    as described in `PEP 604 `_, i.e. with '|' indicating a union and without importing
+    objects from `typing`.
+ """
+
+ VARIABLE = 0
+ """A data variable (DataArray or variable name).
+
+ Annotation : ``xr.DataArray``.
+ """
+ OPTIONAL_VARIABLE = 1
+ """An optional data variable (DataArray or variable name).
+
+ Annotation : ``xr.DataArray | None``. The default should be None.
+ """
+ QUANTIFIED = 2
+ """A quantity with units, either as a string (scalar), a pint.Quantity (scalar) or a DataArray (with units set).
+
+ Annotation : ``xclim.core.utils.Quantified`` and an entry in the :py:func:`xclim.core.units.declare_units`
+ decorator. "Quantified" translates to ``str | xr.DataArray | pint.util.Quantity``.
+ """
+ FREQ_STR = 3
+ """A string representing an "offset alias", as defined by pandas.
+
+ See the Pandas documentation on :ref:`timeseries.offset_aliases` for a list of valid aliases.
+
+ Annotation : ``str`` + ``freq`` as the parameter name.
+ """
+ NUMBER = 4
+ """A number.
+
+ Annotation : ``int``, ``float`` and unions thereof, potentially optional.
+ """
+ STRING = 5
+ """A simple string.
+
+ Annotation : ``str`` or ``str | None``. In most cases, this kind of parameter makes sense
+ with choices indicated in the docstring's version of the annotation with curly braces.
+ See :ref:`notebooks/extendxclim:Defining new indices`.
+ """
+ DAY_OF_YEAR = 6
+ """A date, but without a year, in the MM-DD format.
+
+ Annotation : :py:obj:`xclim.core.utils.DayOfYearStr` (may be optional).
+ """
+ DATE = 7
+ """A date in the YYYY-MM-DD format, may include a time.
+
+ Annotation : :py:obj:`xclim.core.utils.DateStr` (may be optional).
+ """
+ NUMBER_SEQUENCE = 8
+ """A sequence of numbers
+
+ Annotation : ``Sequence[int]``, ``Sequence[float]`` and unions thereof, may include single ``int`` and ``float``,
+ may be optional.
+ """
+ BOOL = 9
+ """A boolean flag.
+
+ Annotation : ``bool``, may be optional.
+ """
+ DICT = 10
+ """A dictionary.
+
+ Annotation : ``dict`` or ``dict | None``, may be optional.
+ """
+ KWARGS = 50
+ """A mapping from argument name to value.
+
+ Developers : maps the ``**kwargs``. Please use as little as possible.
+ """
+ DATASET = 70
+ """An xarray dataset.
+
+ Developers : as indices only accept DataArrays, this should only be added on the indicator's constructor.
+ """
+ OTHER_PARAMETER = 99
+ """An object that fits None of the previous kinds.
+
+ Developers : This is the fallback kind, it will raise an error in xclim's unit tests if used.
+ """
+
+
+KIND_ANNOTATION = {
+ InputKind.VARIABLE: "str or DataArray",
+ InputKind.OPTIONAL_VARIABLE: "str or DataArray, optional",
+ InputKind.QUANTIFIED: "quantity (string or DataArray, with units)",
+ InputKind.FREQ_STR: "offset alias (string)",
+ InputKind.NUMBER: "number",
+ InputKind.NUMBER_SEQUENCE: "number or sequence of numbers",
+ InputKind.STRING: "str",
+ InputKind.DAY_OF_YEAR: "date (string, MM-DD)",
+ InputKind.DATE: "date (string, YYYY-MM-DD)",
+ InputKind.BOOL: "boolean",
+ InputKind.DICT: "dict",
+ InputKind.DATASET: "Dataset, optional",
+ InputKind.KWARGS: "",
+ InputKind.OTHER_PARAMETER: "Any",
+}
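+
+
+# A small illustration (not part of the public API): ``InputKind`` members behave like
+# ints, and ``KIND_ANNOTATION`` maps them to the docstring-style annotation, e.g.
+#
+#     >>> int(InputKind.QUANTIFIED)
+#     2
+#     >>> KIND_ANNOTATION[InputKind.QUANTIFIED]
+#     'quantity (string or DataArray, with units)'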
diff --git a/src/xsdba/units.py b/src/xsdba/units.py
index 634d9a0..484182f 100644
--- a/src/xsdba/units.py
+++ b/src/xsdba/units.py
@@ -19,14 +19,75 @@
except ImportError: # noqa: S110
# cf-xarray is not installed, this will not be used
pass
+import warnings
+
import numpy as np
import xarray as xr
-from .base import Quantified, copy_all_attrs
+from .calendar import parse_offset
+from .typing import Quantified
+from .utils import copy_all_attrs
units = pint.get_application_registry()
+FREQ_UNITS = {
+ "D": "d",
+ "W": "week",
+}
+"""
+Resampling frequency units for :py:func:`xsdba.units.infer_sampling_units`.
+
+Mapping from offset base to CF-compliant unit. Only constant-length frequencies are included.
+"""
+
+
+def infer_sampling_units(
+ da: xr.DataArray,
+ deffreq: str | None = "D",
+ dim: str = "time",
+) -> tuple[int, str]:
+ """Infer a multiplier and the units corresponding to one sampling period.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ A DataArray from which to take coordinate `dim`.
+ deffreq : str, optional
+ If no frequency is inferred from `da[dim]`, take this one.
+ dim : str
+ Dimension from which to infer the frequency.
+
+ Raises
+ ------
+ ValueError
+ If the frequency has no exact corresponding units.
+
+ Returns
+ -------
+ int
+ The magnitude (number of base periods per period)
+ str
+ Units as a string, understandable by pint.
+ """
+ dimmed = getattr(da, dim)
+ freq = xr.infer_freq(dimmed)
+ if freq is None:
+ freq = deffreq
+
+ multi, base, _, _ = parse_offset(freq)
+ try:
+ out = multi, FREQ_UNITS.get(base, base)
+ except KeyError as err:
+ raise ValueError(
+ f"Sampling frequency {freq} has no corresponding units."
+ ) from err
+ if out == (7, "d"):
+ # Special case for weekly frequency. xarray's CFTimeOffsets do not have "W".
+ return 1, "week"
+ return out
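+
+
+# For instance (illustrative): a daily series yields ``(1, "d")`` and a series
+# resampled to "7D" is mapped to ``(1, "week")`` by the special case above.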
+
+
# XC
def units2pint(value: xr.DataArray | str | units.Quantity) -> pint.Unit:
"""Return the pint Unit for the DataArray units.
@@ -97,28 +158,73 @@ def str2pint(val: str) -> pint.Quantity:
return units.Quantity(1, units2pint(val))
-# XC
-# def ensure_delta(unit: str) -> str:
-# """Return delta units for temperature.
-
-# For dimensions where delta exist in pint (Temperature), it replaces the temperature unit by delta_degC or
-# delta_degF based on the input unit. For other dimensionality, it just gives back the input units.
-
-# Parameters
-# ----------
-# unit : str
-# unit to transform in delta (or not)
-# """
-# u = units2pint(unit)
-# d = 1 * u
-# #
-# delta_unit = pint2cfunits(d - d)
-# # replace kelvin/rankine by delta_degC/F
-# if "kelvin" in u._units:
-# delta_unit = pint2cfunits(u / units2pint("K") * units2pint("delta_degC"))
-# if "degree_Rankine" in u._units:
-# delta_unit = pint2cfunits(u / units2pint("°R") * units2pint("delta_degF"))
-# return delta_unit
+def pint2str(value: units.Quantity | units.Unit) -> str:
+ """A unit string from a `pint` unit.
+
+ Parameters
+ ----------
+ value : pint.Unit
+ Input unit.
+
+ Returns
+ -------
+ str
+ Units
+
+ Notes
+ -----
+ If cf-xarray is installed, the units will be converted to cf units.
+ """
+ if isinstance(value, (pint.Quantity, units.Quantity)):
+ value = value.units
+
+ # Issue originally introduced in https://github.com/hgrecco/pint/issues/1486
+ # Should be resolved in pint v0.24. See: https://github.com/hgrecco/pint/issues/1913
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=DeprecationWarning)
+ return f"{value:cf}".replace("dimensionless", "")
+
+
+DELTA_ABSOLUTE_TEMP = {
+ units.delta_degC: units.kelvin,
+ units.delta_degF: units.rankine,
+}
+
+
+def ensure_absolute_temperature(units: str):
+ """Convert temperature units to their absolute counterpart, assuming they represented a difference (delta).
+
+ Celsius becomes Kelvin, Fahrenheit becomes Rankine. Does nothing for other units.
+ """
+ a = str2pint(units)
+ # ensure a delta pint unit
+ a = a - 0 * a
+ if a.units in DELTA_ABSOLUTE_TEMP:
+ return pint2str(DELTA_ABSOLUTE_TEMP[a.units])
+ return units
+
+
+def ensure_delta(unit: str) -> str:
+ """Return delta units for temperature.
+
+ For dimensions where delta exist in pint (Temperature), it replaces the temperature unit by delta_degC or
+ delta_degF based on the input unit. For other dimensionality, it just gives back the input units.
+
+ Parameters
+ ----------
+ unit : str
+        Unit to transform into a delta (or not).
+ """
+ u = units2pint(unit)
+ d = 1 * u
+ #
+ delta_unit = pint2str(d - d)
+ # replace kelvin/rankine by delta_degC/F
+ if "kelvin" in u._units:
+ delta_unit = pint2str(u / units2pint("K") * units2pint("delta_degC"))
+ if "degree_Rankine" in u._units:
+ delta_unit = pint2str(u / units2pint("°R") * units2pint("delta_degF"))
+ return delta_unit
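+
+
+# Illustrative expectations (not exhaustive): ``ensure_delta("K")`` should resolve to a
+# Celsius delta (e.g. "delta_degC"), while non-temperature units such as "mm d-1" pass
+# through essentially unchanged.
+
+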
def extract_units(arg):
@@ -126,16 +232,15 @@ def extract_units(arg):
if not (
isinstance(arg, (str, xr.DataArray, pint.Unit, units.Unit)) or np.isscalar(arg)
):
- print(arg)
raise TypeError(
f"Argument must be a str, DataArray, or scalar. Got {type(arg)}"
)
elif isinstance(arg, xr.DataArray):
ustr = None if "units" not in arg.attrs else arg.attrs["units"]
elif isinstance(arg, pint.Unit | units.Unit):
- ustr = f"{arg:cf}" # XC: from pint2cfunits
+ ustr = pint2str(arg) # XC: from pint2str
elif isinstance(arg, str):
- ustr = str2pint(arg).units
+ ustr = pint2str(str2pint(arg).units)
else: # (scalar case)
ustr = None
return ustr if ustr is None else pint.Quantity(1, ustr).units
@@ -230,75 +335,6 @@ def convert_units_to( # noqa: C901
return out
-def _fill_args_dict(args, kwargs, args_to_check, func):
- """Combine args and kwargs into a dict."""
- args_dict = {}
- signature = inspect.signature(func)
- for ik, (k, v) in enumerate(signature.parameters.items()):
- if ik < len(args):
- value = args[ik]
- if ik >= len(args):
- value = v.default if k not in kwargs else kwargs[k]
- args_dict[k] = value
- return args_dict
-
-
-def _split_args_kwargs(args, func):
- """Assign Keyword only arguments to kwargs."""
- kwargs = {}
- signature = inspect.signature(func)
- indices_to_pop = []
- for ik, (k, v) in enumerate(signature.parameters.items()):
- if v.kind == inspect.Parameter.KEYWORD_ONLY:
- indices_to_pop.append(ik)
- kwargs[k] = v
- indices_to_pop.sort(reverse=True)
- for ind in indices_to_pop:
- args.pop(ind)
- return args, kwargs
-
-
-# TODO: make it work with Dataset for real
-# TODO: add a switch to prevent string from being converted to float?
-def harmonize_units(args_to_check):
- """Check that units are compatible with dimensions, otherwise raise a `ValidationError`."""
-
- # if no units are present (DataArray without units attribute or float), then no check is performed
- # if units are present, then check is performed
- # in mixed cases, an error is raised
- def _decorator(func):
- @wraps(func)
- def _wrapper(*args, **kwargs):
- arg_names = inspect.getfullargspec(func).args
- args_dict = _fill_args_dict(list(args), kwargs, args_to_check, func)
- first_arg_name = args_to_check[0]
- first_arg = args_dict[first_arg_name]
- for arg_name in args_to_check[1:]:
- if isinstance(arg_name, str):
- value = args_dict[arg_name]
- key = arg_name
- if isinstance(
- arg_name, dict
- ): # support for Dataset, or a dict of thresholds
- key, val = list(arg_name.keys())[0], list(arg_name.values())[0]
- value = args_dict[key][val]
- if value is None: # optional argument, should be ignored
- args_to_check.remove(arg_name)
- continue
- if key not in args_dict:
- raise ValueError(
- f"Argument '{arg_name}' not found in function arguments."
- )
- args_dict[key] = convert_units_to(value, first_arg)
- args = list(args_dict.values())
- args, kwargs = _split_args_kwargs(args, kwargs, func)
- return func(*args, **kwargs)
-
- return _wrapper
-
- return _decorator
-
-
def _add_default_kws(params_dict, params_to_check, func):
"""Combine args and kwargs into a dict."""
args_dict = {}
@@ -361,3 +397,107 @@ def _wrapper(*args, **kwargs):
return _wrapper
return _decorator
+
+
+def to_agg_units(
+ out: xr.DataArray, orig: xr.DataArray, op: str, dim: str = "time"
+) -> xr.DataArray:
+ """Set and convert units of an array after an aggregation operation along the sampling dimension (time).
+
+ Parameters
+ ----------
+ out : xr.DataArray
+ The output array of the aggregation operation, no units operation done yet.
+ orig : xr.DataArray
+ The original array before the aggregation operation,
+ used to infer the sampling units and get the variable units.
+ op : {'min', 'max', 'mean', 'std', 'var', 'doymin', 'doymax', 'count', 'integral', 'sum'}
+ The type of aggregation operation performed. "integral" is mathematically equivalent to "sum",
+ but the units are multiplied by the timestep of the data (requires an inferrable frequency).
+ dim : str
+ The time dimension along which the aggregation was performed.
+
+ Returns
+ -------
+ xr.DataArray
+
+ Examples
+ --------
+ Take a daily array of temperature and count number of days above a threshold.
+ `to_agg_units` will infer the units from the sampling rate along "time", so
+ we ensure the final units are correct:
+
+ >>> time = xr.cftime_range("2001-01-01", freq="D", periods=365)
+ >>> tas = xr.DataArray(
+ ... np.arange(365),
+ ... dims=("time",),
+ ... coords={"time": time},
+ ... attrs={"units": "degC"},
+ ... )
+ >>> cond = tas > 100 # Which days are boiling
+ >>> Ndays = cond.sum("time") # Number of boiling days
+    >>> print(Ndays.attrs.get("units"))
+ None
+ >>> Ndays = to_agg_units(Ndays, tas, op="count")
+ >>> Ndays.units
+ 'd'
+
+ Similarly, here we compute the total heating degree-days, but we have weekly data:
+
+ >>> time = xr.cftime_range("2001-01-01", freq="7D", periods=52)
+ >>> tas = xr.DataArray(
+ ... np.arange(52) + 10,
+ ... dims=("time",),
+ ... coords={"time": time},
+ ... )
+ >>> dt = (tas - 16).assign_attrs(units="delta_degC")
+ >>> degdays = dt.clip(0).sum("time") # Integral of temperature above a threshold
+ >>> degdays = to_agg_units(degdays, dt, op="integral")
+ >>> degdays.units
+ 'K week'
+
+ Which we can always convert to the more common "K days":
+
+ >>> degdays = convert_units_to(degdays, "K days")
+ >>> degdays.units
+ 'K d'
+ """
+ if op in ["amin", "min", "amax", "max", "mean", "sum"]:
+ out.attrs["units"] = orig.attrs["units"]
+
+ elif op in ["std"]:
+ out.attrs["units"] = ensure_absolute_temperature(orig.attrs["units"])
+
+ elif op in ["var"]:
+ out.attrs["units"] = pint2str(
+ str2pint(ensure_absolute_temperature(orig.units)) ** 2
+ )
+
+ elif op in ["doymin", "doymax"]:
+ out.attrs.update(
+ units="", is_dayofyear=np.int32(1), calendar=get_calendar(orig)
+ )
+
+ elif op in ["count", "integral"]:
+ m, freq_u_raw = infer_sampling_units(orig[dim])
+ orig_u = str2pint(ensure_absolute_temperature(orig.units))
+ freq_u = str2pint(freq_u_raw)
+ out = out * m
+
+ if op == "count":
+ out.attrs["units"] = freq_u_raw
+ elif op == "integral":
+ if "[time]" in orig_u.dimensionality:
+ # We need to simplify units after multiplication
+ out_units = (orig_u * freq_u).to_reduced_units()
+ out = out * out_units.magnitude
+ out.attrs["units"] = pint2str(out_units)
+ else:
+ out.attrs["units"] = pint2str(orig_u * freq_u)
+ else:
+ raise ValueError(
+ f"Unknown aggregation op {op}. "
+ "Known ops are [min, max, mean, std, var, doymin, doymax, count, integral, sum]."
+ )
+
+ return out
diff --git a/src/xsdba/utils.py b/src/xsdba/utils.py
index b816f80..34dcdfb 100644
--- a/src/xsdba/utils.py
+++ b/src/xsdba/utils.py
@@ -955,16 +955,6 @@ def rand_rot_matrix(
).astype("float32")
-def copy_all_attrs(ds: xr.Dataset | xr.DataArray, ref: xr.Dataset | xr.DataArray):
- """Copy all attributes of ds to ref, including attributes of shared coordinates, and variables in the case of Datasets."""
- ds.attrs.update(ref.attrs)
- extras = ds.variables if isinstance(ds, xr.Dataset) else ds.coords
- others = ref.variables if isinstance(ref, xr.Dataset) else ref.coords
- for name, var in extras.items():
- if name in others:
- var.attrs.update(ref[name].attrs)
-
-
def _pairwise_spearman(da, dims):
"""Area-averaged pairwise temporal correlation.
@@ -1016,3 +1006,47 @@ def _skipna_correlation(data):
"allow_rechunk": True,
},
).rename("correlation")
+
+
+# ADAPT: Maybe this is not the best place
+def copy_all_attrs(ds: xr.Dataset | xr.DataArray, ref: xr.Dataset | xr.DataArray):
+ """Copy all attributes of ds to ref, including attributes of shared coordinates, and variables in the case of Datasets."""
+ ds.attrs.update(ref.attrs)
+ extras = ds.variables if isinstance(ds, xr.Dataset) else ds.coords
+ others = ref.variables if isinstance(ref, xr.Dataset) else ref.coords
+ for name, var in extras.items():
+ if name in others:
+ var.attrs.update(ref[name].attrs)
+
+
+# ADAPT: Maybe this is not the best place
+def load_module(path: os.PathLike, name: str | None = None):
+ """Load a python module from a python file, optionally changing its name.
+
+ Examples
+ --------
+ Given a path to a module file (.py):
+
+ .. code-block:: python
+
+ from pathlib import Path
+ import os
+
+ path = Path("path/to/example.py")
+
+    The following two imports are equivalent; the second uses this method.
+
+ .. code-block:: python
+
+ os.chdir(path.parent)
+ import example as mod1
+
+ os.chdir(previous_working_dir)
+ mod2 = load_module(path)
+ mod1 == mod2
+ """
+ path = Path(path)
+ spec = importlib.util.spec_from_file_location(name or path.stem, path)
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod) # This executes code, effectively loading the module
+ return mod
diff --git a/src/xsdba/xclim_submodules/generic.py b/src/xsdba/xclim_submodules/generic.py
new file mode 100644
index 0000000..91697d3
--- /dev/null
+++ b/src/xsdba/xclim_submodules/generic.py
@@ -0,0 +1,941 @@
+"""
+Generic Indices Submodule
+=========================
+
+Helper functions for common generic actions done in the computation of indices.
+"""
+
+from __future__ import annotations
+
+import warnings
+from collections.abc import Sequence
+from typing import Callable
+
+import cftime
+import numpy as np
+import xarray
+import xarray as xr
+from xarray.coding.cftime_offsets import _MONTH_ABBREVIATIONS
+
+from xsdba.calendar import doy_to_days_since, get_calendar, select_time
+from xsdba.typing import DayOfYearStr, Quantified, Quantity
+from xsdba.units import (
+ convert_units_to,
+ harmonize_units,
+ pint2str,
+ str2pint,
+ to_agg_units,
+)
+
+from . import run_length as rl
+
+__all__ = [
+ "aggregate_between_dates",
+ "binary_ops",
+ "compare",
+ "count_level_crossings",
+ "count_occurrences",
+ "cumulative_difference",
+ "default_freq",
+ "detrend",
+ "diurnal_temperature_range",
+ "domain_count",
+ "doymax",
+ "doymin",
+ "extreme_temperature_range",
+ "first_day_threshold_reached",
+ "first_occurrence",
+ "get_daily_events",
+ "get_op",
+ "get_zones",
+ "interday_diurnal_temperature_range",
+ "last_occurrence",
+ "select_resample_op",
+ "spell_length",
+ "statistics",
+ "temperature_sum",
+ "threshold_count",
+ "thresholded_statistics",
+]
+
+binary_ops = {">": "gt", "<": "lt", ">=": "ge", "<=": "le", "==": "eq", "!=": "ne"}
+
+
+def select_resample_op(
+ da: xr.DataArray, op: str, freq: str = "YS", out_units=None, **indexer
+) -> xr.DataArray:
+ """Apply operation over each period that is part of the index selection.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input data.
+ op : str {'min', 'max', 'mean', 'std', 'var', 'count', 'sum', 'integral', 'argmax', 'argmin'} or func
+ Reduce operation. Can either be a DataArray method or a function that can be applied to a DataArray.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ out_units : str, optional
+        Output units to assign. Only necessary if `op` is a function not supported by :py:func:`xsdba.units.to_agg_units`.
+    indexer : {dim: indexer, }, optional
+        Time attribute and values over which to subset the array. For example, use season='DJF' to select winter values,
+        month=1 to select January, or month=[6,7,8] to select summer months. If no indexer is given, all values are
+        considered.
+
+ Returns
+ -------
+ xr.DataArray
+        The result of the reduce operation for each period.
+ """
+ da = select_time(da, **indexer)
+ r = da.resample(time=freq)
+ if op in _xclim_ops:
+ op = _xclim_ops[op]
+ if isinstance(op, str):
+ out = getattr(r, op.replace("integral", "sum"))(dim="time", keep_attrs=True)
+ else:
+ with xr.set_options(keep_attrs=True):
+ out = r.map(op)
+ op = op.__name__
+ if out_units is not None:
+ return out.assign_attrs(units=out_units)
+ return to_agg_units(out, da, op)
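+
+
+# Usage sketch (illustrative; ``pr`` is a hypothetical daily precipitation DataArray):
+# the winter (DJF) maximum for each year, with units handled by ``to_agg_units``:
+#
+#     djf_max = select_resample_op(pr, op="max", freq="YS", season="DJF")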
+
+
+def select_rolling_resample_op(
+ da: xr.DataArray,
+ op: str,
+ window: int,
+ window_center: bool = True,
+ window_op: str = "mean",
+ freq: str = "YS",
+ out_units=None,
+ **indexer,
+) -> xr.DataArray:
+ """Apply operation over each period that is part of the index selection, using a rolling window before the operation.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input data.
+ op : str {'min', 'max', 'mean', 'std', 'var', 'count', 'sum', 'integral', 'argmax', 'argmin'} or func
+ Reduce operation. Can either be a DataArray method or a function that can be applied to a DataArray.
+ window : int
+ Size of the rolling window (centered).
+ window_center : bool
+ If True, the window is centered on the date. If False, the window is right-aligned.
+ window_op : str {'min', 'max', 'mean', 'std', 'var', 'count', 'sum', 'integral'}
+ Operation to apply to the rolling window. Default: 'mean'.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. Applied after the rolling window.
+ out_units : str, optional
+        Output units to assign. Only necessary if `op` is a function not supported by :py:func:`xsdba.units.to_agg_units`.
+    indexer : {dim: indexer, }, optional
+        Time attribute and values over which to subset the array. For example, use season='DJF' to select winter values,
+        month=1 to select January, or month=[6,7,8] to select summer months. If no indexer is given, all values are
+        considered.
+
+ Returns
+ -------
+ xr.DataArray
+ The array for which the operation has been applied over each period.
+ """
+ rolled = getattr(
+ da.rolling(time=window, center=window_center),
+ window_op.replace("integral", "sum"),
+ )()
+ rolled = to_agg_units(rolled, da, window_op)
+ return select_resample_op(rolled, op=op, freq=freq, out_units=out_units, **indexer)
+
+
+def doymax(da: xr.DataArray) -> xr.DataArray:
+ """Return the day of year of the maximum value."""
+ i = da.argmax(dim="time")
+ out = da.time.dt.dayofyear.isel(time=i, drop=True)
+ return to_agg_units(out, da, "doymax")
+
+
+def doymin(da: xr.DataArray) -> xr.DataArray:
+ """Return the day of year of the minimum value."""
+ i = da.argmin(dim="time")
+ out = da.time.dt.dayofyear.isel(time=i, drop=True)
+ return to_agg_units(out, da, "doymin")
+
+
+_xclim_ops = {"doymin": doymin, "doymax": doymax}
+
+
+def default_freq(**indexer) -> str:
+ """Return the default frequency."""
+ freq = "YS-JAN"
+ if indexer:
+ group, value = indexer.popitem()
+ if group == "season":
+ month = 12 # The "season" scheme is based on YS-DEC
+ elif group == "month":
+ month = np.take(value, 0)
+ elif group == "doy_bounds":
+ month = cftime.num2date(value[0] - 1, "days since 2004-01-01").month
+ elif group == "date_bounds":
+ month = int(value[0][:2])
+ else:
+ raise ValueError(f"Unknown group `{group}`.")
+ freq = "YS-" + _MONTH_ABBREVIATIONS[month]
+ return freq
+
+
+def get_op(op: str, constrain: Sequence[str] | None = None) -> Callable:
+ """Get python's comparing function according to its name of representation and validate allowed usage.
+
+ Accepted op string are keys and values of xclim.indices.generic.binary_ops.
+
+ Parameters
+ ----------
+ op : str
+ Operator.
+ constrain : sequence of str, optional
+ A tuple of allowed operators.
+ """
+ if op == "gteq":
+ warnings.warn(f"`{op}` is being renamed `ge` for compatibility.")
+ op = "ge"
+ if op == "lteq":
+ warnings.warn(f"`{op}` is being renamed `le` for compatibility.")
+ op = "le"
+
+ if op in binary_ops.keys():
+ binary_op = binary_ops[op]
+ elif op in binary_ops.values():
+ binary_op = op
+ else:
+ raise ValueError(f"Operation `{op}` not recognized.")
+
+ constraints = list()
+ if isinstance(constrain, (list, tuple, set)):
+ constraints.extend([binary_ops[c] for c in constrain])
+ constraints.extend(constrain)
+ elif isinstance(constrain, str):
+ constraints.extend([binary_ops[constrain], constrain])
+
+ if constrain:
+ if op not in constraints:
+ raise ValueError(f"Operation `{op}` not permitted for indice.")
+
+ return xr.core.ops.get_op(binary_op)
+
+
+def compare(
+ left: xr.DataArray,
+ op: str,
+ right: float | int | np.ndarray | xr.DataArray,
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+ """Compare a dataArray to a threshold using given operator.
+
+ Parameters
+ ----------
+ left : xr.DataArray
+        A DataArray being evaluated against `right`.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+ right : float, int, np.ndarray, or xr.DataArray
+        A value or array-like being evaluated against `left`.
+ constrain : sequence of str, optional
+ Optionally allowed conditions.
+
+ Returns
+ -------
+ xr.DataArray
+ Boolean mask of the comparison.
+ """
+ return get_op(op, constrain)(left, right)
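+
+
+# For example (illustrative): ``compare(tas, ">", 300)`` returns a boolean DataArray that
+# is True wherever ``tas`` exceeds 300, while ``compare(tas, "gt", 300, constrain=("<", "<="))``
+# raises a ValueError because ">" is not among the allowed operators.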
+
+
+def threshold_count(
+ da: xr.DataArray,
+ op: str,
+ threshold: float | int | xr.DataArray,
+ freq: str,
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+ """Count number of days where value is above or below threshold.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input data.
+ op : {">", "<", ">=", "<=", "gt", "lt", "ge", "le"}
+ Logical operator. e.g. arr > thresh.
+    threshold : float, int or xr.DataArray
+ Threshold value.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ constrain : sequence of str, optional
+ Optionally allowed conditions.
+
+ Returns
+ -------
+ xr.DataArray
+ The number of days meeting the constraints for each period.
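+
+    Examples
+    --------
+    A minimal sketch counting hot days per year in a hypothetical daily series:
+
+    >>> import numpy as np, pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=730, freq="D")
+    >>> tas = xr.DataArray(np.random.uniform(250, 310, time.size), dims="time", coords={"time": time})
+    >>> n_hot = threshold_count(tas, ">", 300.0, freq="YS")  # one count per year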
+ """
+ if constrain is None:
+ constrain = (">", "<", ">=", "<=")
+
+ c = compare(da, op, threshold, constrain) * 1
+ return c.resample(time=freq).sum(dim="time")
+
+
+def domain_count(
+ da: xr.DataArray,
+ low: float | int | xr.DataArray,
+ high: float | int | xr.DataArray,
+ freq: str,
+) -> xr.DataArray:
+ """Count number of days where value is within low and high thresholds.
+
+ A value is counted if it is larger than `low`, and smaller or equal to `high`, i.e. in `]low, high]`.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input data.
+ low : scalar or DataArray
+ Minimum threshold value.
+ high : scalar or DataArray
+ Maximum threshold value.
+ freq : str
+ Resampling frequency defining the periods defined in :ref:`timeseries.resampling`.
+
+ Returns
+ -------
+ xr.DataArray
+        The number of days where the value is within ]low, high] for each period.
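+
+    Examples
+    --------
+    A minimal sketch with hypothetical bounds:
+
+    >>> import numpy as np, pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=365, freq="D")
+    >>> tas = xr.DataArray(np.random.uniform(250, 310, time.size), dims="time", coords={"time": time})
+    >>> n_mild = domain_count(tas, low=275.0, high=290.0, freq="YS")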
+ """
+ c = compare(da, ">", low) * compare(da, "<=", high) * 1
+ return c.resample(time=freq).sum(dim="time")
+
+
+def get_daily_events(
+ da: xr.DataArray,
+ threshold: float | int | xr.DataArray,
+ op: str,
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+    """Return a 0/1 mask indicating where a condition is met.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input data.
+ threshold : float
+ Threshold value.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+ constrain : sequence of str, optional
+ Optionally allowed conditions.
+
+ Notes
+ -----
+ The function returns:
+
+    - ``1`` where operator(da, threshold) is ``True``
+    - ``0`` where operator(da, threshold) is ``False``
+ - ``nan`` where da is ``nan``
+
+ Returns
+ -------
+ xr.DataArray
+ """
+ events = compare(da, op, threshold, constrain) * 1
+ events = events.where(~(np.isnan(da)))
+ events = events.rename("events")
+ return events
+
+
+# CF-INDEX-META Indices
+
+
+@harmonize_units(["low_data", "high_data", "threshold"])
+def count_level_crossings(
+ low_data: xr.DataArray,
+ high_data: xr.DataArray,
+ threshold: Quantified,
+ freq: str,
+ *,
+ op_low: str = "<",
+ op_high: str = ">=",
+) -> xr.DataArray:
+ """Calculate the number of times low_data is below threshold while high_data is above threshold.
+
+ First, the threshold is transformed to the same standard_name and units as the input data,
+ then the thresholding is performed, and finally, the number of occurrences is counted.
+
+ Parameters
+ ----------
+ low_data : xr.DataArray
+ Variable that must be under the threshold.
+ high_data : xr.DataArray
+ Variable that must be above the threshold.
+ threshold : Quantified
+ Threshold.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ op_low : {"<", "<=", "lt", "le"}
+ Comparison operator for low_data. Default: "<".
+ op_high : {">", ">=", "gt", "ge"}
+ Comparison operator for high_data. Default: ">=".
+
+ Returns
+ -------
+ xr.DataArray
+ """
+    # `threshold` has already been harmonized to the units of `low_data` by the decorator
+ lower = compare(low_data, op_low, threshold, constrain=("<", "<="))
+ higher = compare(high_data, op_high, threshold, constrain=(">", ">="))
+
+ out = (lower & higher).resample(time=freq).sum()
+ return to_agg_units(out, low_data, "count", dim="time")
+
+
+@harmonize_units(["data", "threshold"])
+def count_occurrences(
+ data: xr.DataArray,
+ threshold: Quantified,
+ freq: str,
+ op: str,
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+ """Calculate the number of times some condition is met.
+
+ First, the threshold is transformed to the same standard_name and units as the input data.
+ Then the thresholding is performed as condition(data, threshold),
+ i.e. if condition is `<`, then this counts the number of times `data < threshold`.
+ Finally, count the number of occurrences when condition is met.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ An array.
+ threshold : Quantified
+ Threshold.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+ constrain : sequence of str, optional
+ Optionally allowed conditions.
+
+ Returns
+ -------
+ xr.DataArray
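+
+    Examples
+    --------
+    A minimal sketch with hypothetical data; the `units` attribute is needed for unit handling:
+
+    >>> import numpy as np, pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=365, freq="D")
+    >>> tas = xr.DataArray(np.random.uniform(250, 310, time.size), dims="time", coords={"time": time}, attrs={"units": "K"})
+    >>> n_hot = count_occurrences(tas, threshold="300 K", freq="YS", op=">")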
+ """
+ cond = compare(data, op, threshold, constrain)
+
+ out = cond.resample(time=freq).sum()
+ return to_agg_units(out, data, "count", dim="time")
+
+
+@harmonize_units(["data", "threshold"])
+def first_occurrence(
+ data: xr.DataArray,
+ threshold: Quantified,
+ freq: str,
+ op: str,
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+ """Calculate the first time some condition is met.
+
+ First, the threshold is transformed to the same standard_name and units as the input data.
+ Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold.
+ Finally, locate the first occurrence when condition is met.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Input data.
+ threshold : Quantified
+ Threshold.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+ constrain : sequence of str, optional
+ Optionally allowed conditions.
+
+ Returns
+ -------
+ xr.DataArray
+ """
+ cond = compare(data, op, threshold, constrain)
+
+ out = cond.resample(time=freq).map(
+ rl.first_run,
+ window=1,
+ dim="time",
+ coord="dayofyear",
+ )
+ out.attrs["units"] = ""
+ return out
+
+
+@harmonize_units(["data", "threshold"])
+def last_occurrence(
+ data: xr.DataArray,
+ threshold: Quantified,
+ freq: str,
+ op: str,
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+ """Calculate the last time some condition is met.
+
+ First, the threshold is transformed to the same standard_name and units as the input data.
+ Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold.
+ Finally, locate the last occurrence when condition is met.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Input data.
+ threshold : Quantified
+ Threshold.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+ constrain : sequence of str, optional
+ Optionally allowed conditions.
+
+ Returns
+ -------
+ xr.DataArray
+ """
+ cond = compare(data, op, threshold, constrain)
+
+ out = cond.resample(time=freq).map(
+ rl.last_run,
+ window=1,
+ dim="time",
+ coord="dayofyear",
+ )
+ out.attrs["units"] = ""
+ return out
+
+
+@harmonize_units(["data", "threshold"])
+def spell_length(
+ data: xr.DataArray, threshold: Quantified, reducer: str, freq: str, op: str
+) -> xr.DataArray:
+ """Calculate statistics on lengths of spells.
+
+ First, the threshold is transformed to the same standard_name and units as the input data.
+ Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold.
+ Then the spells are determined, and finally the statistics according to the specified reducer are calculated.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Input data.
+ threshold : Quantified
+ Threshold.
+ reducer : {'max', 'min', 'mean', 'sum'}
+ Reducer.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+
+ Returns
+ -------
+ xr.DataArray
+ """
+ cond = compare(data, op, threshold)
+
+ out = cond.resample(time=freq).map(
+ rl.rle_statistics,
+ reducer=reducer,
+ window=1,
+ dim="time",
+ )
+ return to_agg_units(out, data, "count")
+
+
+def statistics(data: xr.DataArray, reducer: str, freq: str) -> xr.DataArray:
+ """Calculate a simple statistic of the data.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Input data.
+ reducer : {'max', 'min', 'mean', 'sum'}
+ Reducer.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+
+ Returns
+ -------
+ xr.DataArray
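+
+    Examples
+    --------
+    A minimal sketch computing yearly means of a hypothetical daily series:
+
+    >>> import numpy as np, pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=365, freq="D")
+    >>> tas = xr.DataArray(np.random.uniform(250, 310, time.size), dims="time", coords={"time": time}, attrs={"units": "K"})
+    >>> yearly_mean = statistics(tas, reducer="mean", freq="YS")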
+ """
+ out = getattr(data.resample(time=freq), reducer)()
+ out.attrs["units"] = data.attrs["units"]
+ return out
+
+
+@harmonize_units(["data", "threshold"])
+def thresholded_statistics(
+ data: xr.DataArray,
+ op: str,
+ threshold: Quantified,
+ reducer: str,
+ freq: str,
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+ """Calculate a simple statistic of the data for which some condition is met.
+
+ First, the threshold is transformed to the same standard_name and units as the input data.
+ Then the thresholding is performed as condition(data, threshold), i.e. if condition is <, data < threshold.
+ Finally, the statistic is calculated for those data values that fulfill the condition.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Input data.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+ threshold : Quantified
+ Threshold.
+ reducer : {'max', 'min', 'mean', 'sum'}
+ Reducer.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ constrain : sequence of str, optional
+ Optionally allowed conditions. Default: None.
+
+ Returns
+ -------
+ xr.DataArray
+ """
+ cond = compare(data, op, threshold, constrain)
+
+ out = getattr(data.where(cond).resample(time=freq), reducer)()
+ out.attrs["units"] = data.attrs["units"]
+ return out
+
+
+def aggregate_between_dates(
+ data: xr.DataArray,
+ start: xr.DataArray | DayOfYearStr,
+ end: xr.DataArray | DayOfYearStr,
+ op: str = "sum",
+ freq: str | None = None,
+) -> xr.DataArray:
+ """Aggregate the data over a period between start and end dates and apply the operator on the aggregated data.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Data to aggregate between start and end dates.
+ start : xr.DataArray or DayOfYearStr
+ Start dates (as day-of-year) for the aggregation periods.
+ end : xr.DataArray or DayOfYearStr
+ End (as day-of-year) dates for the aggregation periods.
+ op : {'min', 'max', 'sum', 'mean', 'std'}
+ Operator.
+ freq : str, optional
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`. Default: `None`.
+
+ Returns
+ -------
+ xr.DataArray, [dimensionless]
+ Aggregated data between the start and end dates. If the end date is before the start date, returns np.nan.
+ If there is no start and/or end date, returns np.nan.
+ """
+
+ def _get_days(_bound, _group, _base_time):
+ """Get bound in number of days since base_time. Bound can be a days_since array or a DayOfYearStr."""
+ if isinstance(_bound, str):
+ b_i = rl.index_of_date(_group.time, _bound, max_idxs=1)
+ if not b_i.size > 0:
+ return None
+ return (_group.time.isel(time=b_i[0]) - _group.time.isel(time=0)).dt.days
+ if _base_time in _bound.time:
+ return _bound.sel(time=_base_time)
+ return None
+
+ if freq is None:
+ frequencies = []
+ for bound in [start, end]:
+ try:
+ frequencies.append(xr.infer_freq(bound.time))
+ except AttributeError:
+ frequencies.append(None)
+
+ good_freq = set(frequencies) - {None}
+
+ if len(good_freq) != 1:
+ raise ValueError(
+ f"Non-inferrable resampling frequency or inconsistent frequencies. Got start, end = {frequencies}."
+ " Please consider providing `freq` manually."
+ )
+ freq = good_freq.pop()
+
+ cal = get_calendar(data, dim="time")
+
+ if not isinstance(start, str):
+ start = start.convert_calendar(cal)
+ start.attrs["calendar"] = cal
+ start = doy_to_days_since(start)
+ if not isinstance(end, str):
+ end = end.convert_calendar(cal)
+ end.attrs["calendar"] = cal
+ end = doy_to_days_since(end)
+
+ out = []
+ for base_time, indexes in data.resample(time=freq).groups.items():
+ # get group slice
+ group = data.isel(time=indexes)
+
+ start_d = _get_days(start, group, base_time)
+ end_d = _get_days(end, group, base_time)
+
+ # convert bounds for this group
+ if start_d is not None and end_d is not None:
+ days = (group.time - base_time).dt.days
+ days[days < 0] = np.nan
+
+ masked = group.where((days >= start_d) & (days <= end_d - 1))
+ res = getattr(masked, op)(dim="time", skipna=True)
+ res = xr.where(
+ ((start_d > end_d) | (start_d.isnull()) | (end_d.isnull())), np.nan, res
+ )
+ # Re-add the time dimension with the period's base time.
+ res = res.expand_dims(time=[base_time])
+ out.append(res)
+ else:
+            # Build an array of the right shape, fill it with NaNs and add the new time coordinate.
+ res = (group.isel(time=0) * np.nan).expand_dims(time=[base_time])
+ out.append(res)
+ continue
+
+ return xr.concat(out, dim="time")
+
+
+@harmonize_units(["data", "threshold"])
+def cumulative_difference(
+ data: xr.DataArray, threshold: Quantified, op: str, freq: str | None = None
+) -> xr.DataArray:
+ """Calculate the cumulative difference below/above a given value threshold.
+
+ Parameters
+ ----------
+ data : xr.DataArray
+ Data for which to determine the cumulative difference.
+ threshold : Quantified
+ The value threshold.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le"}
+ Logical operator. e.g. arr > thresh.
+ freq : str, optional
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ If `None`, no resampling is performed. Default: `None`.
+
+ Returns
+ -------
+ xr.DataArray
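+
+    Examples
+    --------
+    A minimal sketch akin to a growing degree-day sum; values and threshold are illustrative only:
+
+    >>> import numpy as np, pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=365, freq="D")
+    >>> tas = xr.DataArray(np.random.uniform(-10, 30, time.size), dims="time", coords={"time": time}, attrs={"units": "degC"})
+    >>> gdd = cumulative_difference(tas, threshold="5 degC", op=">", freq="YS")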
+ """
+ if op in ["<", "<=", "lt", "le"]:
+ diff = (threshold - data).clip(0)
+ elif op in [">", ">=", "gt", "ge"]:
+ diff = (data - threshold).clip(0)
+ else:
+ raise NotImplementedError(f"Condition not supported: '{op}'.")
+
+ if freq is not None:
+ diff = diff.resample(time=freq).sum(dim="time")
+
+ return to_agg_units(diff, data, op="integral")
+
+
+@harmonize_units(["data", "threshold"])
+def first_day_threshold_reached(
+ data: xr.DataArray,
+ *,
+ threshold: Quantified,
+ op: str,
+ after_date: DayOfYearStr,
+ window: int = 1,
+ freq: str = "YS",
+ constrain: Sequence[str] | None = None,
+) -> xr.DataArray:
+ r"""First day of values exceeding threshold.
+
+ Returns first day of period where values reach or exceed a threshold over a given number of days,
+ limited to a starting calendar date.
+
+ Parameters
+ ----------
+ data : xarray.DataArray
+        Input data being evaluated.
+    threshold : Quantified
+ Threshold on which to base evaluation.
+ op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+ Logical operator. e.g. arr > thresh.
+ after_date : str
+ Date of the year after which to look for the first event. Should have the format '%m-%d'.
+ window : int
+ Minimum number of days with values above threshold needed for evaluation. Default: 1.
+ freq : str
+ Resampling frequency defining the periods as defined in :ref:`timeseries.resampling`.
+ Default: "YS".
+ constrain : sequence of str, optional
+ Optionally allowed conditions.
+
+ Returns
+ -------
+ xarray.DataArray, [dimensionless]
+ Day of the year when value reaches or exceeds a threshold over a given number of days for the first time.
+ If there is no such day, returns np.nan.
+ """
+ cond = compare(data, op, threshold, constrain=constrain)
+
+    out: xr.DataArray = cond.resample(time=freq).map(
+ rl.first_run_after_date,
+ window=window,
+ date=after_date,
+ dim="time",
+ coord="dayofyear",
+ )
+ out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(data))
+ return out
+
+
+def _get_zone_bins(
+ zone_min: Quantity,
+ zone_max: Quantity,
+ zone_step: Quantity,
+):
+ """Bin boundary values as defined by zone parameters.
+
+ Parameters
+ ----------
+    zone_min : Quantity
+        Left boundary of the first zone.
+    zone_max : Quantity
+        Right boundary of the last zone.
+    zone_step : Quantity
+        Size of zones.
+
+ Returns
+ -------
+ xarray.DataArray, [units of `zone_step`]
+ Array of values corresponding to each zone: [zone_min, zone_min+step, ..., zone_max]
+ """
+ units = pint2str(str2pint(zone_step))
+ mn, mx, step = (
+ convert_units_to(str2pint(z), units) for z in [zone_min, zone_max, zone_step]
+ )
+ bins = np.arange(mn, mx + step, step)
+ if (mx - mn) % step != 0:
+ warnings.warn(
+ "`zone_max` - `zone_min` is not an integer multiple of `zone_step`. Last zone will be smaller."
+ )
+ bins[-1] = mx
+ return xr.DataArray(bins, attrs={"units": units})
+
+
+def get_zones(
+ da: xr.DataArray,
+ zone_min: Quantity | None = None,
+ zone_max: Quantity | None = None,
+ zone_step: Quantity | None = None,
+ bins: xr.DataArray | list[Quantity] | None = None,
+ exclude_boundary_zones: bool = True,
+ close_last_zone_right_boundary: bool = True,
+) -> xr.DataArray:
+ r"""Divide data into zones and attribute a zone coordinate to each input value.
+
+ Divide values into zones corresponding to bins of width zone_step beginning at zone_min and ending at zone_max.
+ Bins are inclusive on the left values and exclusive on the right values.
+
+ Parameters
+ ----------
+ da : xarray.DataArray
+ Input data
+    zone_min : Quantity | None
+        Left boundary of the first zone.
+    zone_max : Quantity | None
+        Right boundary of the last zone.
+    zone_step : Quantity | None
+        Size of zones.
+    bins : xr.DataArray | list[Quantity] | None
+        Zones to be used, either as a DataArray with appropriate units or a list of Quantity.
+    exclude_boundary_zones : bool
+        Determines whether a zone value is attributed for values in ]`-np.inf`, `zone_min`[ and [`zone_max`, `np.inf`\ [.
+    close_last_zone_right_boundary : bool
+        Determines if the right boundary of the last zone is closed.
+
+ Returns
+ -------
+ xarray.DataArray, [dimensionless]
+ Zone index for each value in `da`. Zones are returned as an integer range, starting from `0`
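+
+    Examples
+    --------
+    A minimal sketch with hypothetical depths divided into 1 m zones:
+
+    >>> import xarray as xr
+    >>> depth = xr.DataArray([0.2, 1.4, 2.9], dims="x", attrs={"units": "m"})
+    >>> zones = get_zones(depth, zone_min="0 m", zone_max="3 m", zone_step="1 m")  # -> [0, 1, 2]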
+ """
+ # Check compatibility of arguments
+ zone_params = np.array([zone_min, zone_max, zone_step])
+ if bins is None:
+ if (zone_params == [None] * len(zone_params)).any():
+ raise ValueError(
+ "`bins` is `None` as well as some or all of [`zone_min`, `zone_max`, `zone_step`]. "
+ "Expected defined parameters in one of these cases."
+ )
+ elif set(zone_params) != {None}:
+ warnings.warn(
+ "Expected either `bins` or [`zone_min`, `zone_max`, `zone_step`], got both. "
+ "`bins` will be used."
+ )
+
+ # Get zone bins (if necessary)
+ bins = bins if bins is not None else _get_zone_bins(zone_min, zone_max, zone_step)
+ if isinstance(bins, list):
+ bins = sorted([convert_units_to(b, da) for b in bins])
+ else:
+ bins = convert_units_to(bins, da)
+
+ def _get_zone(_da):
+ return np.digitize(_da, bins) - 1
+
+ zones = xr.apply_ufunc(_get_zone, da, dask="parallelized")
+
+ if close_last_zone_right_boundary:
+ zones = zones.where(da != bins[-1], _get_zone(bins[-2]))
+ if exclude_boundary_zones:
+ zones = zones.where(
+ (zones != _get_zone(bins[0] - 1)) & (zones != _get_zone(bins[-1]))
+ )
+
+ return zones
+
+
+def detrend(
+ ds: xr.DataArray | xr.Dataset, dim="time", deg=1
+) -> xr.DataArray | xr.Dataset:
+ """Detrend data along a given dimension computing a polynomial trend of a given order.
+
+ Parameters
+ ----------
+ ds : xr.Dataset or xr.DataArray
+ The data to detrend. If a Dataset, detrending is done on all data variables.
+ dim : str
+ Dimension along which to compute the trend.
+ deg : int
+ Degree of the polynomial to fit.
+
+ Returns
+ -------
+ xr.Dataset or xr.DataArray
+ Same as `ds`, but with its trend removed (subtracted).
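+
+    Examples
+    --------
+    A minimal sketch removing a linear trend from a hypothetical series along a numeric dimension:
+
+    >>> import numpy as np, xarray as xr
+    >>> x = np.arange(120.0)
+    >>> da = xr.DataArray(0.05 * x + np.random.normal(size=x.size), dims="x", coords={"x": x})
+    >>> resid = detrend(da, dim="x", deg=1)  # the fitted linear trend is subtracted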
+ """
+ if isinstance(ds, xr.Dataset):
+ return ds.map(detrend, keep_attrs=False, dim=dim, deg=deg)
+ # is a DataArray
+ # detrend along a single dimension
+ coeff = ds.polyfit(dim=dim, deg=deg)
+ trend = xr.polyval(ds[dim], coeff.polyfit_coefficients)
+ with xr.set_options(keep_attrs=True):
+ return ds - trend
diff --git a/src/xsdba/xclim_submodules/run_length.py b/src/xsdba/xclim_submodules/run_length.py
new file mode 100644
index 0000000..e84704c
--- /dev/null
+++ b/src/xsdba/xclim_submodules/run_length.py
@@ -0,0 +1,1538 @@
+"""
+Run-Length Algorithms Submodule
+===============================
+
+Computation of statistics on runs of True values in boolean arrays.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from datetime import datetime
+from warnings import warn
+
+import numpy as np
+import xarray as xr
+from numba import njit
+from xarray.core.utils import get_temp_dimname
+
+from xsdba.base import uses_dask
+from xsdba.options import OPTIONS, RUN_LENGTH_UFUNC
+from xsdba.typing import DateStr, DayOfYearStr
+
+npts_opt = 9000
+"""
+Arrays with fewer than this number of data points per slice will trigger
+the use of the ufunc version of the run-length algorithms.
+"""
+# XC: all copied from xc
+
+
+def use_ufunc(
+ ufunc_1dim: bool | str,
+ da: xr.DataArray,
+ dim: str = "time",
+ freq: str | None = None,
+ index: str = "first",
+) -> bool:
+ """Return whether the ufunc version of run length algorithms should be used with this DataArray or not.
+
+ If ufunc_1dim is 'from_context', the parameter is read from xclim's global (or context) options.
+ If it is 'auto', this returns False for dask-backed array and for arrays with more than :py:const:`npts_opt`
+ points per slice along `dim`.
+
+ Parameters
+ ----------
+ ufunc_1dim : {'from_context', 'auto', True, False}
+ The method for handling the ufunc parameters.
+ da : xr.DataArray
+ Input array.
+ dim : str
+ The dimension along which to find runs.
+ freq : str
+ Resampling frequency.
+ index : {'first', 'last'}
+ If 'first' (default), the run length is indexed with the first element in the run.
+ If 'last', with the last element in the run.
+
+ Returns
+ -------
+ bool
+        If `ufunc_1dim` is "auto", True is returned when the array is not dask-backed and has fewer than
+        :py:const:`npts_opt` points per slice; otherwise, the resolved value of `ufunc_1dim` is returned.
+        In every case, True also requires `index` to be "first" and `freq` to be None.
+ """
+ if ufunc_1dim is True and freq is not None:
+ raise ValueError(
+ "Resampling after run length operations is not implemented for 1d method"
+ )
+
+ if ufunc_1dim == "from_context":
+ ufunc_1dim = OPTIONS[RUN_LENGTH_UFUNC]
+
+ if ufunc_1dim == "auto":
+ ufunc_1dim = not uses_dask(da) and (da.size // da[dim].size) < npts_opt
+ # If resampling after run length is set up for the computation, the 1d method is not implemented
+ # Unless ufunc_1dim is specifically set to False (in which case we flag an error above),
+ # we simply forbid this possibility.
+ return (index == "first") and (ufunc_1dim) and (freq is None)
+
+
+def resample_and_rl(
+ da: xr.DataArray,
+ resample_before_rl: bool,
+ compute,
+ *args,
+ freq: str,
+ dim: str = "time",
+ **kwargs,
+) -> xr.DataArray:
+ """Wrap run length algorithms to control if resampling occurs before or after the algorithms.
+
+ Parameters
+ ----------
+ da: xr.DataArray
+ N-dimensional array (boolean).
+ resample_before_rl : bool
+        Whether the input array `da` should be split into periods before the run-length
+        algorithms are applied, or only afterwards.
+ compute
+ Run length function to apply
+ args
+ Positional arguments needed in `compute`.
+ dim: str
+ The dimension along which to find runs.
+ freq : str
+ Resampling frequency.
+ kwargs
+ Keyword arguments needed in `compute`.
+
+ Returns
+ -------
+ xr.DataArray
+        Output of `compute`, resampled according to `freq`.
+ """
+ if resample_before_rl:
+ out = da.resample({dim: freq}).map(
+ compute, args=args, freq=None, dim=dim, **kwargs
+ )
+ else:
+ out = compute(da, *args, dim=dim, freq=freq, **kwargs)
+ return out
+
+
+def _cumsum_reset_on_zero(
+ da: xr.DataArray,
+ dim: str = "time",
+ index: str = "last",
+) -> xr.DataArray:
+ """Compute the cumulative sum for each series of numbers separated by zero.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input array.
+ dim : str
+ Dimension name along which the cumulative sum is taken.
+ index : {'first', 'last'}
+ If 'first', the largest value of the cumulative sum is indexed with the first element in the run.
+        If 'last' (default), with the last element in the run.
+
+ Returns
+ -------
+ xr.DataArray
+ An array with cumulative sums.
+ """
+ if index == "first":
+ da = da[{dim: slice(None, None, -1)}]
+
+ # Example: da == 100110111 -> cs_s == 100120123
+ cs = da.cumsum(dim=dim) # cumulative sum e.g. 111233456
+ cs2 = cs.where(da == 0) # keep only numbers at positions of zeroes e.g. N11NN3NNN
+ cs2[{dim: 0}] = 0 # put a zero in front e.g. 011NN3NNN
+ cs2 = cs2.ffill(dim=dim) # e.g. 011113333
+ out = cs - cs2
+
+ if index == "first":
+ out = out[{dim: slice(None, None, -1)}]
+
+ return out
+
+
+# TODO: Check if rle would be more performant with ffill/bfill instead of two times [{dim: slice(None, None, -1)}]
+def rle(
+ da: xr.DataArray,
+ dim: str = "time",
+ index: str = "first",
+) -> xr.DataArray:
+ """Generate basic run length function.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input array.
+ dim : str
+ Dimension name.
+ index : {'first', 'last'}
+ If 'first' (default), the run length is indexed with the first element in the run.
+ If 'last', with the last element in the run.
+
+ Returns
+ -------
+ xr.DataArray
+        Run lengths: the length of each run is given at its first (or last) element, NaN is given elsewhere
+        within a run, and values are 0 where `da` is False (outside of runs).
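+
+    Examples
+    --------
+    A minimal sketch on a small boolean series (values are illustrative):
+
+    >>> import pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=4, freq="D")
+    >>> da = xr.DataArray([True, True, False, True], dims="time", coords={"time": time})
+    >>> lengths = rle(da, dim="time")  # -> [2., nan, 0., 1.]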
+ """
+ da = da.astype(int)
+
+ # "first" case: Algorithm is applied on inverted array and output is inverted back
+ if index == "first":
+ da = da[{dim: slice(None, None, -1)}]
+
+ # Get cumulative sum for each series of 1, e.g. da == 100110111 -> cs_s == 100120123
+ cs_s = _cumsum_reset_on_zero(da, dim)
+
+ # Keep total length of each series (and also keep 0's), e.g. 100120123 -> 100N20NN3
+ # Keep numbers with a 0 to the right and also the last number
+ cs_s = cs_s.where(da.shift({dim: -1}, fill_value=0) == 0)
+ out = cs_s.where(da == 1, 0) # Reinsert 0's at their original place
+
+ # Inverting back if needed e.g. 100N20NN3 -> 3NN02N001. This is the output of
+ # `rle` for 111011001 with index == "first"
+ if index == "first":
+ out = out[{dim: slice(None, None, -1)}]
+
+ return out
+
+
+def rle_statistics(
+ da: xr.DataArray,
+ reducer: str,
+ window: int,
+ dim: str = "time",
+ freq: str | None = None,
+ ufunc_1dim: str | bool = "from_context",
+ index: str = "first",
+) -> xr.DataArray:
+ """Return the length of consecutive run of True values, according to a reducing operator.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ N-dimensional array (boolean).
+ reducer : str
+ Name of the reducing function.
+ window : int
+ Minimal length of consecutive runs to be included in the statistics.
+ dim : str
+ Dimension along which to calculate consecutive run; Default: 'time'.
+ freq : str
+ Resampling frequency.
+ ufunc_1dim : Union[str, bool]
+ Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
+ usage based on number of data points. Using 1D_ufunc=True is typically more efficient
+ for DataArray with a small number of grid points.
+ It can be modified globally through the "run_length_ufunc" global option.
+ index : {'first', 'last'}
+ If 'first' (default), the run length is indexed with the first element in the run.
+ If 'last', with the last element in the run.
+
+ Returns
+ -------
+ xr.DataArray, [int]
+ Length of runs of True values along dimension, according to the reducing function (float)
+ If there are no runs (but the data is valid), returns 0.
+ """
+ ufunc_1dim = use_ufunc(ufunc_1dim, da, dim=dim, index=index, freq=freq)
+ if ufunc_1dim:
+ rl_stat = statistics_run_ufunc(da, reducer, window, dim)
+ else:
+ d = rle(da, dim=dim, index=index)
+
+ def get_rl_stat(d):
+ rl_stat = getattr(d.where(d >= window), reducer)(dim=dim)
+ rl_stat = xr.where((d.isnull() | (d < window)).all(dim=dim), 0, rl_stat)
+ return rl_stat
+
+ if freq is None:
+ rl_stat = get_rl_stat(d)
+ else:
+ rl_stat = d.resample({dim: freq}).map(get_rl_stat)
+
+ return rl_stat
+
+
+def longest_run(
+ da: xr.DataArray,
+ dim: str = "time",
+ freq: str | None = None,
+ ufunc_1dim: str | bool = "from_context",
+ index: str = "first",
+) -> xr.DataArray:
+ """Return the length of the longest consecutive run of True values.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ N-dimensional array (boolean).
+ dim : str
+ Dimension along which to calculate consecutive run; Default: 'time'.
+ freq : str
+ Resampling frequency.
+ ufunc_1dim : Union[str, bool]
+ Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
+ usage based on number of data points. Using 1D_ufunc=True is typically more efficient
+ for DataArray with a small number of grid points.
+ It can be modified globally through the "run_length_ufunc" global option.
+ index : {'first', 'last'}
+ If 'first', the run length is indexed with the first element in the run.
+ If 'last', with the last element in the run.
+
+ Returns
+ -------
+ xr.DataArray, [int]
+ Length of the longest run of True values along dimension (int).
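+
+    Examples
+    --------
+    A minimal sketch on a small boolean series (values are illustrative):
+
+    >>> import pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=5, freq="D")
+    >>> da = xr.DataArray([True, False, True, True, True], dims="time", coords={"time": time})
+    >>> longest = longest_run(da, dim="time")  # -> 3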
+ """
+ return rle_statistics(
+ da,
+ reducer="max",
+ window=1,
+ dim=dim,
+ freq=freq,
+ ufunc_1dim=ufunc_1dim,
+ index=index,
+ )
+
+
+def windowed_run_events(
+ da: xr.DataArray,
+ window: int,
+ dim: str = "time",
+ freq: str | None = None,
+ ufunc_1dim: str | bool = "from_context",
+ index: str = "first",
+) -> xr.DataArray:
+ """Return the number of runs of a minimum length.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum run length.
+ When equal to 1, an optimized version of the algorithm is used.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ freq : str
+ Resampling frequency.
+ ufunc_1dim : Union[str, bool]
+ Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
+ usage based on number of data points. Using 1D_ufunc=True is typically more efficient
+ for DataArray with a small number of grid points.
+ Ignored when `window=1`. It can be modified globally through the "run_length_ufunc" global option.
+ index : {'first', 'last'}
+ If 'first', the run length is indexed with the first element in the run.
+ If 'last', with the last element in the run.
+
+ Returns
+ -------
+ xr.DataArray, [int]
+ Number of distinct runs of a minimum length (int).
+ """
+ ufunc_1dim = use_ufunc(ufunc_1dim, da, dim=dim, index=index, freq=freq)
+
+ if ufunc_1dim:
+ out = windowed_run_events_ufunc(da, window, dim)
+
+ else:
+ if window == 1:
+ shift = 1 * (index == "first") + -1 * (index == "last")
+ d = xr.where(da.shift({dim: shift}, fill_value=0) == 0, 1, 0)
+ d = d.where(da == 1, 0)
+ else:
+ d = rle(da, dim=dim, index=index)
+ d = xr.where(d >= window, 1, 0)
+ if freq is not None:
+ d = d.resample({dim: freq})
+ out = d.sum(dim=dim)
+
+ return out
+
+
+def windowed_run_count(
+ da: xr.DataArray,
+ window: int,
+ dim: str = "time",
+ freq: str | None = None,
+ ufunc_1dim: str | bool = "from_context",
+ index: str = "first",
+) -> xr.DataArray:
+ """Return the number of consecutive true values in array for runs at least as long as given duration.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum run length.
+ When equal to 1, an optimized version of the algorithm is used.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ freq : str
+ Resampling frequency.
+ ufunc_1dim : Union[str, bool]
+ Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
+ usage based on number of data points. Using 1D_ufunc=True is typically more efficient
+ for DataArray with a small number of grid points.
+ Ignored when `window=1`. It can be modified globally through the "run_length_ufunc" global option.
+ index : {'first', 'last'}
+ If 'first', the run length is indexed with the first element in the run.
+ If 'last', with the last element in the run.
+
+ Returns
+ -------
+ xr.DataArray, [int]
+        Total number of `True` values that are part of consecutive runs at least `window` long.
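+
+    Examples
+    --------
+    A minimal sketch on a small boolean series (values are illustrative):
+
+    >>> import pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=5, freq="D")
+    >>> da = xr.DataArray([True, True, True, False, True], dims="time", coords={"time": time})
+    >>> total = windowed_run_count(da, window=2, dim="time")  # -> 3 (only the 3-step run counts)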
+ """
+ ufunc_1dim = use_ufunc(ufunc_1dim, da, dim=dim, index=index, freq=freq)
+
+ if ufunc_1dim:
+ out = windowed_run_count_ufunc(da, window, dim)
+
+ elif window == 1 and freq is None:
+ out = da.sum(dim=dim)
+
+ else:
+ d = rle(da, dim=dim, index=index)
+ d = d.where(d >= window, 0)
+ if freq is not None:
+ d = d.resample({dim: freq})
+ out = d.sum(dim=dim)
+
+ return out
+
+
+def _boundary_run(
+ da: xr.DataArray,
+ window: int,
+ dim: str,
+ freq: str | None,
+ coord: str | bool | None,
+ ufunc_1dim: str | bool,
+ position: str,
+) -> xr.DataArray:
+ """Return the index of the first item of the first or last run of at least a given length.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+ When equal to 1, an optimized version of the algorithm is used.
+ dim : str
+ Dimension along which to calculate consecutive run.
+ freq : str
+ Resampling frequency.
+ coord : Optional[str]
+ If not False, the function returns values along `dim` instead of indexes.
+ If `dim` has a datetime dtype, `coord` can also be a str of the name of the
+ DateTimeAccessor object to use (ex: 'dayofyear').
+ ufunc_1dim : Union[str, bool]
+ Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
+ usage based on number of data points. Using 1D_ufunc=True is typically more efficient
+ for DataArray with a small number of grid points.
+ Ignored when `window=1`. It can be modified globally through the "run_length_ufunc" global option.
+ position : {"first", "last"}
+ Determines if the algorithm finds the "first" or "last" run
+
+ Returns
+ -------
+ xr.DataArray
+ Index (or coordinate if `coord` is not False) of first item in first (last) valid run.
+ Returns np.nan if there are no valid runs.
+ """
+
+ def coord_transform(out, da):
+ """Transforms indexes to coordinates if needed, and drops obsolete dim."""
+ if coord:
+ crd = da[dim]
+ if isinstance(coord, str):
+ crd = getattr(crd.dt, coord)
+
+ out = lazy_indexing(crd, out)
+
+ if dim in out.coords:
+ out = out.drop_vars(dim)
+ return out
+
+ # general method to get indices (or coords) of first run
+ def find_boundary_run(runs, position):
+ if position == "last":
+ runs = runs[{dim: slice(None, None, -1)}]
+ dmax_ind = runs.argmax(dim=dim)
+ # If there are no runs, dmax_ind will be 0: We must replace this with NaN
+ out = dmax_ind.where(dmax_ind != runs.argmin(dim=dim))
+ if position == "last":
+ out = runs[dim].size - out - 1
+ runs = runs[{dim: slice(None, None, -1)}]
+ out = coord_transform(out, runs)
+ return out
+
+ ufunc_1dim = use_ufunc(ufunc_1dim, da, dim=dim, freq=freq)
+
+ da = da.fillna(0) # We expect a boolean array, but there could be NaNs nonetheless
+ if window == 1:
+ if freq is not None:
+ out = da.resample({dim: freq}).map(find_boundary_run, position=position)
+ else:
+ out = find_boundary_run(da, position)
+
+ elif ufunc_1dim:
+ if position == "last":
+ da = da[{dim: slice(None, None, -1)}]
+ out = first_run_ufunc(x=da, window=window, dim=dim)
+ if position == "last" and not coord:
+ out = da[dim].size - out - 1
+ da = da[{dim: slice(None, None, -1)}]
+ out = coord_transform(out, da)
+
+ else:
+        # _cumsum_reset_on_zero() is an intermediate step in rle, which is sufficient here
+ d = _cumsum_reset_on_zero(da, dim=dim, index=position)
+ d = xr.where(d >= window, 1, 0)
+ # for "first" run, return "first" element in the run (and conversely for "last" run)
+ if freq is not None:
+ out = d.resample({dim: freq}).map(find_boundary_run, position=position)
+ else:
+ out = find_boundary_run(d, position)
+
+ return out
+
+
+def first_run(
+ da: xr.DataArray,
+ window: int,
+ dim: str = "time",
+ freq: str | None = None,
+ coord: str | bool | None = False,
+ ufunc_1dim: str | bool = "from_context",
+) -> xr.DataArray:
+ """Return the index of the first item of the first run of at least a given length.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+ When equal to 1, an optimized version of the algorithm is used.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ freq : str
+ Resampling frequency.
+ coord : Optional[str]
+ If not False, the function returns values along `dim` instead of indexes.
+ If `dim` has a datetime dtype, `coord` can also be a str of the name of the
+ DateTimeAccessor object to use (ex: 'dayofyear').
+ ufunc_1dim : Union[str, bool]
+ Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
+ usage based on number of data points. Using 1D_ufunc=True is typically more efficient
+ for DataArray with a small number of grid points.
+ Ignored when `window=1`. It can be modified globally through the "run_length_ufunc" global option.
+
+ Returns
+ -------
+ xr.DataArray
+ Index (or coordinate if `coord` is not False) of first item in first valid run.
+ Returns np.nan if there are no valid runs.
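+
+    Examples
+    --------
+    A minimal sketch on a small boolean series (values are illustrative):
+
+    >>> import pandas as pd, xarray as xr
+    >>> time = pd.date_range("2000-01-01", periods=4, freq="D")
+    >>> da = xr.DataArray([False, True, True, True], dims="time", coords={"time": time})
+    >>> idx = first_run(da, window=2, dim="time")  # -> 1, the index where the first 2-step run starts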
+ """
+ out = _boundary_run(
+ da,
+ window=window,
+ dim=dim,
+ freq=freq,
+ coord=coord,
+ ufunc_1dim=ufunc_1dim,
+ position="first",
+ )
+ return out
+
+
+def last_run(
+ da: xr.DataArray,
+ window: int,
+ dim: str = "time",
+ freq: str | None = None,
+ coord: str | bool | None = False,
+ ufunc_1dim: str | bool = "from_context",
+) -> xr.DataArray:
+ """Return the index of the last item of the last run of at least a given length.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+ When equal to 1, an optimized version of the algorithm is used.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ freq : str
+ Resampling frequency.
+ coord : Optional[str]
+ If not False, the function returns values along `dim` instead of indexes.
+ If `dim` has a datetime dtype, `coord` can also be a str of the name of the
+ DateTimeAccessor object to use (ex: 'dayofyear').
+ ufunc_1dim : Union[str, bool]
+ Use the 1d 'ufunc' version of this function : default (auto) will attempt to select optimal
+ usage based on number of data points. Using `1D_ufunc=True` is typically more efficient
+ for a DataArray with a small number of grid points.
+ Ignored when `window=1`. It can be modified globally through the "run_length_ufunc" global option.
+
+ Returns
+ -------
+ xr.DataArray
+ Index (or coordinate if `coord` is not False) of last item in last valid run.
+ Returns np.nan if there are no valid runs.
+ """
+ out = _boundary_run(
+ da,
+ window=window,
+ dim=dim,
+ freq=freq,
+ coord=coord,
+ ufunc_1dim=ufunc_1dim,
+ position="last",
+ )
+ return out
+
+
+# TODO: Add window arg
+# TODO: Inverse window arg to tolerate holes?
+def run_bounds(mask: xr.DataArray, dim: str = "time", coord: bool | str = True):
+ """Return the start and end dates of boolean runs along a dimension.
+
+ Parameters
+ ----------
+ mask : xr.DataArray
+ Boolean array.
+ dim : str
+ Dimension along which to look for runs.
+ coord : bool or str
+ If `True`, return values of the coordinate, if a string, returns values from `dim.dt.`.
+ If `False`, return indexes.
+
+ Returns
+ -------
+ xr.DataArray
+ With ``dim`` reduced to "events" and "bounds". The events dim is as long as needed, padded with NaN or NaT.
+ """
+ if uses_dask(mask):
+ raise NotImplementedError(
+ "Dask arrays not supported as we can't know the final event number before computing."
+ )
+
+ diff = xr.concat(
+ (mask.isel({dim: [0]}).astype(int), mask.astype(int).diff(dim)), dim
+ )
+
+ nstarts = (diff == 1).sum(dim).max().item()
+
+ def _get_indices(arr, *, N):
+ out = np.full((N,), np.nan, dtype=float)
+ inds = np.where(arr)[0]
+ out[: len(inds)] = inds
+ return out
+
+ starts = xr.apply_ufunc(
+ _get_indices,
+ diff == 1,
+ input_core_dims=[[dim]],
+ output_core_dims=[["events"]],
+ kwargs={"N": nstarts},
+ vectorize=True,
+ )
+
+ ends = xr.apply_ufunc(
+ _get_indices,
+ diff == -1,
+ input_core_dims=[[dim]],
+ output_core_dims=[["events"]],
+ kwargs={"N": nstarts},
+ vectorize=True,
+ )
+
+ if coord:
+ crd = mask[dim]
+ if isinstance(coord, str):
+ crd = getattr(crd.dt, coord)
+
+ starts = lazy_indexing(crd, starts)
+ ends = lazy_indexing(crd, ends)
+ return xr.concat((starts, ends), "bounds")
+
+
+def keep_longest_run(
+ da: xr.DataArray, dim: str = "time", freq: str | None = None
+) -> xr.DataArray:
+ """Keep the longest run along a dimension.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Boolean array.
+ dim : str
+ Dimension along which to check for the longest run.
+ freq : str
+ Resampling frequency.
+
+ Returns
+ -------
+ xr.DataArray, [bool]
+ Boolean array similar to da but with only one run, the (first) longest.
+ """
+ # Get run lengths
+ rls = rle(da, dim)
+
+ def get_out(rls):
+ out = xr.where(
+ # Construct an integer array and find the max
+ rls[dim].copy(data=np.arange(rls[dim].size)) == rls.argmax(dim),
+ rls + 1, # Add one to the First longest run
+ rls,
+ )
+ out = out.ffill(dim) == out.max(dim)
+ return out
+
+ if freq is not None:
+ out = rls.resample({dim: freq}).map(get_out)
+ else:
+ out = get_out(rls)
+
+ return da.copy(data=out.transpose(*da.dims).data)
+
+
+def extract_events(
+ da_start: xr.DataArray,
+ window_start: int,
+ da_stop: xr.DataArray,
+ window_stop: int,
+ dim: str = "time",
+) -> xr.DataArray:
+ """Extract events, i.e. runs whose starting and stopping points are defined through run length conditions.
+
+ Parameters
+ ----------
+    da_start : xr.DataArray
+        Input array in which run sequences are searched to define the start points of the main runs.
+    window_start : int
+        Number of True (1) values needed to start a run in `da_start`.
+    da_stop : xr.DataArray
+        Input array in which run sequences are searched to define the stop points of the main runs.
+    window_stop : int
+        Number of True (1) values needed to start a run in `da_stop`.
+ dim : str
+ Dimension name.
+
+ Returns
+ -------
+ xr.DataArray
+ Output array with 1's when in a run sequence and with 0's elsewhere.
+
+ Notes
+ -----
+ A season (as defined in ``season``) could be considered as an event with `window_stop == window_start` and `da_stop == 1 - da_start`,
+ although it has more constraints on when to start and stop a run through the `date` argument.
+ """
+ da_start = da_start.astype(int).fillna(0)
+ da_stop = da_stop.astype(int).fillna(0)
+
+ start_runs = _cumsum_reset_on_zero(da_start, dim=dim, index="first")
+ stop_runs = _cumsum_reset_on_zero(da_stop, dim=dim, index="first")
+ start_positions = xr.where(start_runs >= window_start, 1, np.NaN)
+ stop_positions = xr.where(stop_runs >= window_stop, 0, np.NaN)
+
+ # start positions (1) are f-filled until a stop position (0) is met
+ runs = stop_positions.combine_first(start_positions).ffill(dim=dim).fillna(0)
+
+ return runs
+
+
+def season(
+ da: xr.DataArray,
+ window: int,
+ date: DayOfYearStr | None = None,
+ dim: str = "time",
+ coord: str | bool | None = False,
+) -> xr.Dataset:
+ """Calculate the bounds of a season along a dimension.
+
+ A "season" is a run of True values that may include breaks under a given length (`window`).
+    The start is computed as the first run of `window` True values, and the end as the first subsequent run
+ of `window` False values. If a date is passed, it must be included in the season.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive values to start and end the season.
+ date : DayOfYearStr, optional
+ The date (in MM-DD format) that a run must include to be considered valid.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ coord : Optional[str]
+ If not False, the function returns values along `dim` instead of indexes.
+ If `dim` has a datetime dtype, `coord` can also be a str of the name of the
+ DateTimeAccessor object to use (ex: 'dayofyear').
+
+ Returns
+ -------
+ xr.Dataset
+        Dataset with variables ``start``, ``end`` and ``length``: the start and end of the season (as indices
+        of ``da[dim]``, or coordinates if requested) and its length in time steps.
+
+ Notes
+ -----
+ The run can include holes of False or NaN values, so long as they do not exceed the window size.
+
+ If a date is given, the season start and end are forced to be on each side of this date. This means that
+ even if the "real" season has been over for a long time, this is the date used in the length calculation.
+    Example : Length of the "warm season", where T > 25°C, with date = 1st August. Let's say the temperature is over
+    25°C for all of June, but July and August have very cold temperatures. Instead of returning 30 days (June only),
+    the function will return 61 days (June + July).
+ """
+ beg = first_run(da, window=window, dim=dim)
+ # Invert the condition and mask all values after beginning
+    # we fillna(0) so as to differentiate series with no runs from all-NaN series
+ not_da = (~da).where(da[dim].copy(data=np.arange(da[dim].size)) >= beg.fillna(0))
+
+ # Mask also values after "date"
+ mid_idx = index_of_date(da[dim], date, max_idxs=1, default=0)
+ if mid_idx.size == 0:
+ # The date is not within the group. Happens at boundaries.
+ base = da.isel({dim: 0}) # To have the proper shape
+ beg = xr.full_like(base, np.nan, float).drop_vars(dim)
+ end = xr.full_like(base, np.nan, float).drop_vars(dim)
+ length = xr.full_like(base, np.nan, float).drop_vars(dim)
+ else:
+ if date is not None:
+ # If the beginning was after the mid date, both bounds are NaT.
+ valid_start = beg < mid_idx.squeeze()
+ else:
+ valid_start = True
+
+ not_da = not_da.where(da[dim] >= da[dim][mid_idx][0])
+ end = first_run(
+ not_da,
+ window=window,
+ dim=dim,
+ )
+ # If there was a beginning but no end, season goes to the end of the array
+ no_end = beg.notnull() & end.isnull()
+
+ # Length
+ length = end - beg
+
+ # No end: length is actually until the end of the array, so it is missing 1
+ length = xr.where(no_end, da[dim].size - beg, length)
+ # Where the beginning was before the mid-date, invalid.
+ length = length.where(valid_start)
+ # Where there were data points, but no season : put 0 length
+ length = xr.where(beg.isnull() & end.notnull(), 0, length)
+
+        # No end: end defaults to the last element (note that this differs from the length computation above)
+ end = xr.where(no_end, da[dim].size - 1, end)
+
+ # Where the beginning was before the mid-date
+ beg = beg.where(valid_start)
+ end = end.where(valid_start)
+
+ if coord:
+ crd = da[dim]
+ if isinstance(coord, str):
+ crd = getattr(crd.dt, coord)
+ coordstr = coord
+ else:
+ coordstr = dim
+ beg = lazy_indexing(crd, beg)
+ end = lazy_indexing(crd, end)
+ else:
+ coordstr = "index"
+
+ out = xr.Dataset({"start": beg, "end": end, "length": length})
+
+ out.start.attrs.update(
+ long_name="Start of the season.",
+ description=f"First {coordstr} of a run of at least {window} steps respecting the condition.",
+ )
+ out.end.attrs.update(
+ long_name="End of the season.",
+ description=f"First {coordstr} of a run of at least {window} "
+ "steps breaking the condition, starting after `start`.",
+ )
+ out.length.attrs.update(
+ long_name="Length of the season.",
+ description="Number of steps of the original series in the season, between 'start' and 'end'.",
+ )
+ return out
+
+
+def season_length(
+ da: xr.DataArray,
+ window: int,
+ date: DayOfYearStr | None = None,
+ dim: str = "time",
+) -> xr.DataArray:
+ """Return the length of the longest semi-consecutive run of True values (optionally including a given date).
+
+ A "season" is a run of True values that may include breaks under a given length (`window`).
+    The start is computed as the first run of `window` True values, and the end as the first subsequent run
+ of `window` False values. If a date is passed, it must be included in the season.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive values to start and end the season.
+ date : DayOfYearStr, optional
+ The date (in MM-DD format) that a run must include to be considered valid.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+
+ Returns
+ -------
+ xr.DataArray, [int]
+ Length of the longest run of True values along a given dimension (inclusive of a given date)
+ without breaks longer than a given length.
+
+ Notes
+ -----
+ The run can include holes of False or NaN values, so long as they do not exceed the window size.
+
+ If a date is given, the season start and end are forced to be on each side of this date. This means that
+ even if the "real" season has been over for a long time, this is the date used in the length calculation.
+    Example : Length of the "warm season", where T > 25°C, with date = 1st August. Let's say the temperature is over
+    25°C for all of June, but July and August have very cold temperatures. Instead of returning 30 days (June only),
+    the function will return 61 days (June + July).
+ """
+ seas = season(da, window, date, dim, coord=False)
+ return seas.length
+
+
+def run_end_after_date(
+ da: xr.DataArray,
+ window: int,
+ date: DayOfYearStr = "07-01",
+ dim: str = "time",
+ coord: bool | str | None = "dayofyear",
+) -> xr.DataArray:
+ """Return the index of the first item after the end of a run after a given date.
+
+ The run must begin before the date.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+ date : str
+ The date after which to look for the end of a run.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ coord : Optional[Union[bool, str]]
+ If not False, the function returns values along `dim` instead of indexes.
+ If `dim` has a datetime dtype, `coord` can also be a str of the name of the
+ DateTimeAccessor object to use (ex: 'dayofyear').
+
+ Returns
+ -------
+ xr.DataArray
+ Index (or coordinate if `coord` is not False) of last item in last valid run.
+ Returns np.nan if there are no valid runs.
+ """
+ mid_idx = index_of_date(da[dim], date, max_idxs=1, default=0)
+ if mid_idx.size == 0: # The date is not within the group. Happens at boundaries.
+ return xr.full_like(da.isel({dim: 0}), np.nan, float).drop_vars(dim)
+
+ end = first_run(
+ (~da).where(da[dim] >= da[dim][mid_idx][0]),
+ window=window,
+ dim=dim,
+ coord=coord,
+ )
+ beg = first_run(da.where(da[dim] < da[dim][mid_idx][0]), window=window, dim=dim)
+
+ if coord:
+ last = da[dim][-1]
+ if isinstance(coord, str):
+ last = getattr(last.dt, coord)
+ else:
+ last = da[dim].size - 1
+
+ end = xr.where(end.isnull() & beg.notnull(), last, end)
+ return end.where(beg.notnull()).drop_vars(dim, errors="ignore")
+
+
+def first_run_after_date(
+ da: xr.DataArray,
+ window: int,
+ date: DayOfYearStr | None = "07-01",
+ dim: str = "time",
+ coord: bool | str | None = "dayofyear",
+) -> xr.DataArray:
+ """Return the index of the first item of the first run after a given date.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+ date : DayOfYearStr
+ The date after which to look for the run.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ coord : Optional[Union[bool, str]]
+ If not False, the function returns values along `dim` instead of indexes.
+ If `dim` has a datetime dtype, `coord` can also be a str of the name of the
+ DateTimeAccessor object to use (ex: 'dayofyear').
+
+ Returns
+ -------
+ xr.DataArray
+ Index (or coordinate if `coord` is not False) of first item in the first valid run.
+ Returns np.nan if there are no valid runs.
+ """
+ mid_idx = index_of_date(da[dim], date, max_idxs=1, default=0)
+ if mid_idx.size == 0: # The date is not within the group. Happens at boundaries.
+ return xr.full_like(da.isel({dim: 0}), np.nan, float).drop_vars(dim)
+
+ return first_run(
+ da.where(da[dim] >= da[dim][mid_idx][0]),
+ window=window,
+ dim=dim,
+ coord=coord,
+ )
+
+
+def last_run_before_date(
+ da: xr.DataArray,
+ window: int,
+ date: DayOfYearStr = "07-01",
+ dim: str = "time",
+ coord: bool | str | None = "dayofyear",
+) -> xr.DataArray:
+ """Return the index of the last item of the last run before a given date.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input N-dimensional DataArray (boolean).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+ date : DayOfYearStr
+ The date before which to look for the last event.
+ dim : str
+ Dimension along which to calculate consecutive run (default: 'time').
+ coord : Optional[Union[bool, str]]
+ If not False, the function returns values along `dim` instead of indexes.
+ If `dim` has a datetime dtype, `coord` can also be a str of the name of the
+ DateTimeAccessor object to use (ex: 'dayofyear').
+
+ Returns
+ -------
+ xr.DataArray
+ Index (or coordinate if `coord` is not False) of last item in last valid run.
+ Returns np.nan if there are no valid runs.
+ """
+ mid_idx = index_of_date(da[dim], date, default=-1)
+
+ if mid_idx.size == 0: # The date is not within the group. Happens at boundaries.
+ return xr.full_like(da.isel({dim: 0}), np.nan, float).drop_vars(dim)
+
+ run = da.where(da[dim] <= da[dim][mid_idx][0])
+ return last_run(run, window=window, dim=dim, coord=coord)
+
+
+@njit
+def _rle_1d(ia):
+ y = ia[1:] != ia[:-1] # pairwise unequal (string safe)
+ i = np.append(np.nonzero(y)[0], ia.size - 1) # must include last element position
+ rl = np.diff(np.append(-1, i)) # run lengths
+ pos = np.cumsum(np.append(0, rl))[:-1] # positions
+ return ia[i], rl, pos
+
+
+def rle_1d(
+ arr: int | float | bool | Sequence[int | float | bool],
+) -> tuple[np.array, np.array, np.array]:
+ """Return the length, starting position and value of consecutive identical values.
+
+ Parameters
+ ----------
+ arr : Sequence[Union[int, float, bool]]
+ Array of values to be parsed.
+
+ Returns
+ -------
+ values : np.array
+ The values taken by arr over each run.
+ run lengths : np.array
+ The length of each run.
+ start position : np.array
+ The starting index of each run.
+
+ Examples
+ --------
+    >>> from xsdba.xclim_submodules.run_length import rle_1d
+ >>> a = [1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]
+ >>> rle_1d(a)
+ (array([1, 2, 3]), array([2, 4, 6]), array([0, 2, 6]))
+ """
+ ia = np.asarray(arr)
+ n = len(ia)
+
+ if n == 0:
+ warn("run length array empty")
+ # Returning None makes some other 1d func below fail.
+ return np.array(np.nan), 0, np.array(np.nan)
+ return _rle_1d(ia)
+
+
+def first_run_1d(arr: Sequence[int | float], window: int) -> int | float:
+ """Return the index of the first item of a run of at least a given length.
+
+ Parameters
+ ----------
+ arr : Sequence[Union[int, float]]
+ Input array.
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+
+ Returns
+ -------
+ int or np.nan
+ Index of first item in first valid run.
+ Returns np.nan if there are no valid runs.
+ """
+ v, rl, pos = rle_1d(arr)
+ ind = np.where(v * rl >= window, pos, np.inf).min()
+
+ if np.isinf(ind):
+ return np.nan
+ return ind
+
+
+def statistics_run_1d(arr: Sequence[bool], reducer: str, window: int) -> int:
+ """Return statistics on lengths of run of identical values.
+
+ Parameters
+ ----------
+ arr : Sequence[bool]
+ Input array (bool)
+ reducer : {'mean', 'sum', 'min', 'max', 'std'}
+ Reducing function name.
+ window : int
+ Minimal length of runs to be included in the statistics
+
+ Returns
+ -------
+ int
+ Statistics on length of runs.
+ """
+ v, rl = rle_1d(arr)[:2]
+ if not np.any(v) or np.all(v * rl < window):
+ return 0
+ func = getattr(np, f"nan{reducer}")
+ return func(np.where(v * rl >= window, rl, np.NaN))
+
+
+def windowed_run_count_1d(arr: Sequence[bool], window: int) -> int:
+ """Return the number of consecutive true values in array for runs at least as long as given duration.
+
+ Parameters
+ ----------
+ arr : Sequence[bool]
+ Input array (bool).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+
+ Returns
+ -------
+ int
+ Total number of true values part of a consecutive run at least `window` long.
+ """
+ v, rl = rle_1d(arr)[:2]
+ return np.where(v * rl >= window, rl, 0).sum()
+
+
+def windowed_run_events_1d(arr: Sequence[bool], window: int) -> int:
+ """Return the number of runs of a minimum length.
+
+ Parameters
+ ----------
+ arr : Sequence[bool]
+ Input array (bool).
+ window : int
+ Minimum run length.
+
+ Returns
+ -------
+    int
+ Number of distinct runs of a minimum length.
+ """
+ v, rl, _ = rle_1d(arr)
+ return (v * rl >= window).sum()
+
+
+def windowed_run_count_ufunc(
+ x: xr.DataArray | Sequence[bool], window: int, dim: str
+) -> xr.DataArray:
+    """Dask-parallel version of windowed_run_count_1d, i.e. the number of consecutive true values in array for runs at least as long as the given duration.
+
+ Parameters
+ ----------
+ x : Sequence[bool]
+ Input array (bool).
+ window : int
+ Minimum duration of consecutive run to accumulate values.
+ dim : str
+ Dimension along which to calculate windowed run.
+
+ Returns
+ -------
+    xr.DataArray
+        Total number of True values part of runs at least `window` long, with `dim` reduced.
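+
+    Examples
+    --------
+    A small eager illustration; the same call stays lazy on dask-backed data:
+
+    >>> import xarray as xr
+    >>> da = xr.DataArray([True, True, False, True, True, True], dims=("time",))
+    >>> int(windowed_run_count_ufunc(da, window=2, dim="time"))
+    5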
+ """
+ return xr.apply_ufunc(
+ windowed_run_count_1d,
+ x,
+ input_core_dims=[[dim]],
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[int],
+ keep_attrs=True,
+ kwargs={"window": window},
+ )
+
+
+def windowed_run_events_ufunc(
+ x: xr.DataArray | Sequence[bool], window: int, dim: str
+) -> xr.DataArray:
+ """Dask-parallel version of windowed_run_events_1d, ie: the number of runs at least as long as given duration.
+
+ Parameters
+ ----------
+ x : Sequence[bool]
+ Input array (bool).
+ window : int
+ Minimum run length.
+ dim : str
+ Dimension along which to calculate windowed run.
+
+ Returns
+ -------
+    xr.DataArray
+        Number of distinct runs at least `window` long, with `dim` reduced.
+ """
+ return xr.apply_ufunc(
+ windowed_run_events_1d,
+ x,
+ input_core_dims=[[dim]],
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[int],
+ keep_attrs=True,
+ kwargs={"window": window},
+ )
+
+
+def statistics_run_ufunc(
+ x: xr.DataArray | Sequence[bool],
+ reducer: str,
+ window: int,
+ dim: str = "time",
+) -> xr.DataArray:
+ """Dask-parallel version of statistics_run_1d, ie: the {reducer} number of consecutive true values in array.
+
+ Parameters
+ ----------
+ x : Sequence[bool]
+ Input array (bool)
+    reducer : {'min', 'max', 'mean', 'sum', 'std'}
+ Reducing function name.
+ window : int
+ Minimal length of runs.
+ dim : str
+ The dimension along which the runs are found.
+
+ Returns
+ -------
+    xr.DataArray
+        Statistic on the lengths of runs at least `window` long, with `dim` reduced.
+ """
+ return xr.apply_ufunc(
+ statistics_run_1d,
+ x,
+ input_core_dims=[[dim]],
+ kwargs={"reducer": reducer, "window": window},
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[float],
+ keep_attrs=True,
+ )
+
+
+def first_run_ufunc(
+ x: xr.DataArray | Sequence[bool],
+ window: int,
+ dim: str,
+) -> xr.DataArray:
+ """Dask-parallel version of first_run_1d, ie: the first entry in array of consecutive true values.
+
+ Parameters
+ ----------
+    x : xr.DataArray or Sequence[bool]
+ Input array (bool).
+ window : int
+ Minimum run length.
+ dim : str
+ The dimension along which the runs are found.
+
+ Returns
+ -------
+    xr.DataArray
+        Index of the first item of the first run at least `window` long, with `dim` reduced.
+ """
+ ind = xr.apply_ufunc(
+ first_run_1d,
+ x,
+ input_core_dims=[[dim]],
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[float],
+ keep_attrs=True,
+ kwargs={"window": window},
+ )
+
+ return ind
+
+
+def lazy_indexing(
+ da: xr.DataArray, index: xr.DataArray, dim: str | None = None
+) -> xr.DataArray:
+ """Get values of `da` at indices `index` in a NaN-aware and lazy manner.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Input array. If not 1D, `dim` must be given and must not appear in index.
+ index : xr.DataArray
+        N-d integer indices. If `da` is not 1D, all dimensions of `index` must also be in `da`.
+    dim : str, optional
+        Dimension along which to index; unused if `da` is 1D and should not be present in `index`.
+
+ Returns
+ -------
+ xr.DataArray
+ Values of `da` at indices `index`.
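+
+    Examples
+    --------
+    A small eager illustration of the 1D case (the same code path stays lazy for dask-backed inputs):
+
+    >>> import xarray as xr
+    >>> da = xr.DataArray([10.0, 20.0, 30.0], dims=("time",), name="tas")
+    >>> idx = xr.DataArray([2, 0], dims=("x",))
+    >>> lazy_indexing(da, idx).values
+    array([30., 10.])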
+ """
+ if da.ndim == 1:
+ # Case where da is 1D and index is N-D
+ # Slightly better performance using map_blocks, over an apply_ufunc
+ def _index_from_1d_array(indices, array):
+ return array[indices]
+
+ idx_ndim = index.ndim
+ if idx_ndim == 0:
+ # The 0-D index case, we add a dummy dimension to help dask
+ dim = get_temp_dimname(da.dims, "x")
+ index = index.expand_dims(dim)
+ # Which indexes to mask.
+ invalid = index.isnull()
+ # NaN-indexing doesn't work, so fill with 0 and cast to int
+ index = index.fillna(0).astype(int)
+
+ # No need for coords, we extract by integer index.
+ # Renaming with no name to fix bug in xr 2024.01.0
+ tmpname = get_temp_dimname(da.dims, "temp")
+ da2 = xr.DataArray(da.data, dims=(tmpname,), name=None)
+ # for each chunk of index, take corresponding values from da
+ out = index.map_blocks(_index_from_1d_array, args=(da2,)).rename(da.name)
+        # Mask where the index was NaN. Drop any auxiliary coords, they are already on `out`.
+        # A chunked aux coord would have the same name on both sides and xarray would want to check
+        # that they are equal, which means loading them and making lazy_indexing not lazy anymore.
+ out = out.where(
+ ~invalid.drop_vars(
+ [crd for crd in invalid.coords if crd not in invalid.dims]
+ )
+ )
+ if idx_ndim == 0:
+ # 0-D case, drop useless coords and dummy dim
+ out = out.drop_vars(da.dims[0], errors="ignore").squeeze()
+ return out.drop_vars(dim or da.dims[0], errors="ignore")
+
+ # Case where index.dims is a subset of da.dims.
+ if dim is None:
+ diff_dims = set(da.dims) - set(index.dims)
+ if len(diff_dims) == 0:
+ raise ValueError(
+ "da must have at least one dimension more than index for lazy_indexing."
+ )
+ if len(diff_dims) > 1:
+ raise ValueError(
+ "If da has more than one dimension more than index, the indexing dim must be given through `dim`"
+ )
+ dim = diff_dims.pop()
+
+ def _index_from_nd_array(array, indices):
+ return np.take_along_axis(array, indices[..., np.newaxis], axis=-1)[..., 0]
+
+ return xr.apply_ufunc(
+ _index_from_nd_array,
+ da,
+ index.astype(int),
+ input_core_dims=[[dim], []],
+ output_core_dims=[[]],
+ dask="parallelized",
+ output_dtypes=[da.dtype],
+ )
+
+
+def index_of_date(
+ time: xr.DataArray,
+ date: DateStr | DayOfYearStr | None,
+ max_idxs: int | None = None,
+ default: int = 0,
+) -> np.ndarray:
+ """Get the index of a date in a time array.
+
+ Parameters
+ ----------
+ time : xr.DataArray
+ An array of datetime values, any calendar.
+ date : DayOfYearStr or DateStr, optional
+ A string in the "yyyy-mm-dd" or "mm-dd" format.
+ If None, returns default.
+ max_idxs : int, optional
+ Maximum number of returned indexes.
+ default : int
+ Index to return if date is None.
+
+ Raises
+ ------
+ ValueError
+        If there are more instances of `date` in `time` than `max_idxs`.
+
+ Returns
+ -------
+ numpy.ndarray
+ 1D array of integers, indexes of `date` in `time`.
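+
+    Examples
+    --------
+    A small, hand-checked illustration (2000-03-01 is the 61st day of the leap year 2000):
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> days = np.arange(np.datetime64("2000-01-01"), np.datetime64("2000-04-01"))
+    >>> time = xr.DataArray(days, dims=("time",))
+    >>> index_of_date(time, "2000-03-01")
+    array([60])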
+ """
+ if date is None:
+ return np.array([default])
+ try:
+ date = datetime.strptime(date, "%Y-%m-%d")
+ year_cond = time.dt.year == date.year
+ except ValueError:
+ date = datetime.strptime(date, "%m-%d")
+ year_cond = True
+
+ idxs = np.where(
+ year_cond & (time.dt.month == date.month) & (time.dt.day == date.day)
+ )[0]
+ if max_idxs is not None and idxs.size > max_idxs:
+ raise ValueError(
+ f"More than {max_idxs} instance of date {date} found in the coordinate array."
+ )
+ return idxs
+
+
+def suspicious_run_1d(
+ arr: np.ndarray,
+ window: int = 10,
+ op: str = ">",
+ thresh: float | None = None,
+) -> np.ndarray:
+ """Return True where the array contains a run of identical values.
+
+ Parameters
+ ----------
+ arr : numpy.ndarray
+ Array of values to be parsed.
+ window : int
+ Minimum run length.
+    op : {">", ">=", "==", "!=", "<", "<=", "eq", "gt", "lt", "gteq", "lteq", "ge", "le", "ne"}
+        Operator for threshold comparison. Defaults to ">".
+    thresh : float, optional
+        Threshold against which run values are compared (using `op`); only runs whose value satisfies the comparison are flagged.
+
+ Returns
+ -------
+ numpy.ndarray
+ Whether or not the data points are part of a run of identical values.
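+
+    Examples
+    --------
+    A small, hand-checked illustration (the value 5.0 is repeated three times):
+
+    >>> suspicious_run_1d([1.0, 5.0, 5.0, 5.0, 2.0], window=3)
+    array([False,  True,  True,  True, False])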
+ """
+ v, rl, pos = rle_1d(arr)
+ sus_runs = rl >= window
+ if thresh is not None:
+ if op in {">", "gt"}:
+ sus_runs = sus_runs & (v > thresh)
+ elif op in {"<", "lt"}:
+ sus_runs = sus_runs & (v < thresh)
+ elif op in {"==", "eq"}:
+ sus_runs = sus_runs & (v == thresh)
+ elif op in {"!=", "ne"}:
+ sus_runs = sus_runs & (v != thresh)
+ elif op in {">=", "gteq", "ge"}:
+ sus_runs = sus_runs & (v >= thresh)
+ elif op in {"<=", "lteq", "le"}:
+ sus_runs = sus_runs & (v <= thresh)
+ else:
+ raise NotImplementedError(f"{op}")
+
+ out = np.zeros_like(arr, dtype=bool)
+ for st, l in zip(pos[sus_runs], rl[sus_runs]): # noqa: E741
+ out[st : st + l] = True # noqa: E741
+ return out
+
+
+def suspicious_run(
+ arr: xr.DataArray,
+ dim: str = "time",
+ window: int = 10,
+ op: str = ">",
+ thresh: float | None = None,
+) -> xr.DataArray:
+ """Return True where the array contains has runs of identical values, vectorized version.
+
+ In opposition to other run length functions, here the output has the same shape as the input.
+
+ Parameters
+ ----------
+ arr : xr.DataArray
+ Array of values to be parsed.
+ dim : str
+ Dimension along which to check for runs (default: "time").
+ window : int
+ Minimum run length.
+    op : {">", ">=", "==", "!=", "<", "<=", "eq", "gt", "lt", "gteq", "lteq", "ge", "le", "ne"}
+        Operator for threshold comparison, defaults to ">".
+    thresh : float, optional
+        Threshold against which run values are compared (using `op`); only runs whose value satisfies the comparison are flagged.
+
+ Returns
+ -------
+    xarray.DataArray
+        Boolean array, True where the data points are part of a run of identical values.
+ """
+ return xr.apply_ufunc(
+ suspicious_run_1d,
+ arr,
+ input_core_dims=[[dim]],
+ output_core_dims=[[dim]],
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[bool],
+ keep_attrs=True,
+ kwargs=dict(window=window, op=op, thresh=thresh),
+ )
diff --git a/src/xsdba/xclim_submodules/stats.py b/src/xsdba/xclim_submodules/stats.py
new file mode 100644
index 0000000..4a26152
--- /dev/null
+++ b/src/xsdba/xclim_submodules/stats.py
@@ -0,0 +1,622 @@
+"""Statistic-related functions. See the `frequency_analysis` notebook for examples."""
+
+from __future__ import annotations
+
+import json
+import warnings
+from collections.abc import Sequence
+from typing import Any
+
+import numpy as np
+import scipy.stats
+import xarray as xr
+
+from xsdba.base import uses_dask
+from xsdba.formatting import prefix_attrs, unprefix_attrs, update_history
+from xsdba.typing import DateStr, Quantified
+from xsdba.units import convert_units_to
+
+from . import generic
+
+__all__ = [
+ "_fit_start",
+ "dist_method",
+ "fa",
+ "fit",
+ "frequency_analysis",
+ "get_dist",
+ "parametric_cdf",
+ "parametric_quantile",
+]
+
+
+# Fit the parameters.
+# This would also be the place to impose constraints on the series minimum length if needed.
+def _fitfunc_1d(arr, *, dist, nparams, method, **fitkwargs):
+ """Fit distribution parameters."""
+ x = np.ma.masked_invalid(arr).compressed() # pylint: disable=no-member
+
+ # Return NaNs if array is empty.
+ if len(x) <= 1:
+ return np.asarray([np.nan] * nparams)
+
+ # Estimate parameters
+ if method in ["ML", "MLE"]:
+ args, kwargs = _fit_start(x, dist.name, **fitkwargs)
+ params = dist.fit(x, *args, method="mle", **kwargs, **fitkwargs)
+ elif method == "MM":
+ params = dist.fit(x, method="mm", **fitkwargs)
+ elif method == "PWM":
+ params = list(dist.lmom_fit(x).values())
+ elif method == "APP":
+ args, kwargs = _fit_start(x, dist.name, **fitkwargs)
+ kwargs.setdefault("loc", 0)
+ params = list(args) + [kwargs["loc"], kwargs["scale"]]
+ else:
+ raise NotImplementedError(f"Unknown method `{method}`.")
+
+ params = np.asarray(params)
+
+ # Fill with NaNs if one of the parameters is NaN
+ if np.isnan(params).any():
+ params[:] = np.nan
+
+ return params
+
+
+def fit(
+ da: xr.DataArray,
+ dist: str | scipy.stats.rv_continuous = "norm",
+ method: str = "ML",
+ dim: str = "time",
+ **fitkwargs: Any,
+) -> xr.DataArray:
+ r"""Fit an array to a univariate distribution along the time dimension.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Time series to be fitted along the time dimension.
+ dist : str or rv_continuous distribution object
+ Name of the univariate distribution, such as beta, expon, genextreme, gamma, gumbel_r, lognorm, norm
+        (see :py:mod:`scipy.stats` for the full list) or the distribution object itself.
+ method : {"ML" or "MLE", "MM", "PWM", "APP"}
+ Fitting method, either maximum likelihood (ML or MLE), method of moments (MM) or approximate method (APP).
+ Can also be the probability weighted moments (PWM), also called L-Moments, if a compatible `dist` object is passed.
+ The PWM method is usually more robust to outliers.
+ dim : str
+ The dimension upon which to perform the indexing (default: "time").
+ \*\*fitkwargs
+        Other arguments passed directly to :py:func:`_fit_start` and to the distribution's `fit`.
+
+ Returns
+ -------
+ xr.DataArray
+ An array of fitted distribution parameters.
+
+ Notes
+ -----
+ Coordinates for which all values are NaNs will be dropped before fitting the distribution. If the array still
+ contains NaNs, the distribution parameters will be returned as NaNs.
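+
+    Examples
+    --------
+    A minimal sketch of the expected workflow; the random sample below is for illustration only:
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> da = xr.DataArray(
+    ...     np.random.default_rng(0).gumbel(loc=20, scale=5, size=100), dims=("time",)
+    ... )
+    >>> params = fit(da, dist="gumbel_r", method="ML")
+    >>> params.dims
+    ('dparams',)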
+ """
+ method = method.upper()
+ method_name = {
+ "ML": "maximum likelihood",
+ "MM": "method of moments",
+ "MLE": "maximum likelihood",
+ "PWM": "probability weighted moments",
+ "APP": "approximative method",
+ }
+ if method not in method_name:
+ raise ValueError(f"Fitting method not recognized: {method}")
+
+ # Get the distribution
+ dist = get_dist(dist)
+
+ if method == "PWM" and not hasattr(dist, "lmom_fit"):
+ raise ValueError(
+ f"The given distribution {dist} does not implement the PWM fitting method. Please pass an instance from the lmoments3 package."
+ )
+
+ shape_params = [] if dist.shapes is None else dist.shapes.split(",")
+ dist_params = shape_params + ["loc", "scale"]
+
+ data = xr.apply_ufunc(
+ _fitfunc_1d,
+ da,
+ input_core_dims=[[dim]],
+ output_core_dims=[["dparams"]],
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[float],
+ keep_attrs=True,
+ kwargs=dict(
+ # Don't know how APP should be included, this works for now
+ dist=dist,
+ nparams=len(dist_params),
+ method=method,
+ **fitkwargs,
+ ),
+ dask_gufunc_kwargs={"output_sizes": {"dparams": len(dist_params)}},
+ )
+
+ # Add coordinates for the distribution parameters and transpose to original shape (with dim -> dparams)
+ dims = [d if d != dim else "dparams" for d in da.dims]
+ out = data.assign_coords(dparams=dist_params).transpose(*dims)
+
+ out.attrs = prefix_attrs(
+ da.attrs, ["standard_name", "long_name", "units", "description"], "original_"
+ )
+ attrs = dict(
+ long_name=f"{dist.name} parameters",
+ description=f"Parameters of the {dist.name} distribution",
+ method=method,
+ estimator=method_name[method].capitalize(),
+ scipy_dist=dist.name,
+ units="",
+ history=update_history(
+ f"Estimate distribution parameters by {method_name[method]} method along dimension {dim}.",
+ new_name="fit",
+ data=da,
+ ),
+ )
+ out.attrs.update(attrs)
+ return out
+
+
+def parametric_quantile(
+ p: xr.DataArray,
+ q: float | Sequence[float],
+ dist: str | scipy.stats.rv_continuous | None = None,
+) -> xr.DataArray:
+ """Return the value corresponding to the given distribution parameters and quantile.
+
+ Parameters
+ ----------
+ p : xr.DataArray
+ Distribution parameters returned by the `fit` function.
+ The array should have dimension `dparams` storing the distribution parameters,
+ and attribute `scipy_dist`, storing the name of the distribution.
+ q : float or Sequence of float
+ Quantile to compute, which must be between `0` and `1`, inclusive.
+    dist : str or rv_continuous, optional
+ The distribution name or instance if the `scipy_dist` attribute is not available on `p`.
+
+ Returns
+ -------
+ xarray.DataArray
+ An array of parametric quantiles estimated from the distribution parameters.
+
+ Notes
+ -----
+ When all quantiles are above 0.5, the `isf` method is used instead of `ppf` because accuracy is sometimes better.
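+
+    Examples
+    --------
+    A minimal sketch, reusing parameters produced by :py:func:`fit` (random data for illustration only):
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> da = xr.DataArray(
+    ...     np.random.default_rng(0).gumbel(loc=20, scale=5, size=100), dims=("time",)
+    ... )
+    >>> params = fit(da, dist="gumbel_r")
+    >>> parametric_quantile(params, q=0.99).dims
+    ('quantile',)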
+ """
+ q = np.atleast_1d(q)
+
+ dist = get_dist(dist or p.attrs["scipy_dist"])
+
+    # Define a helper function to pass the quantiles to dask. There is probably a better way to do this.
+ if np.all(q > 0.5):
+
+ def func(x):
+ return dist.isf(1 - q, *x)
+
+ else:
+
+ def func(x):
+ return dist.ppf(q, *x)
+
+ data = xr.apply_ufunc(
+ func,
+ p,
+ input_core_dims=[["dparams"]],
+ output_core_dims=[["quantile"]],
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[float],
+ keep_attrs=True,
+ dask_gufunc_kwargs={"output_sizes": {"quantile": len(q)}},
+ )
+
+ # Assign quantile coordinates and transpose to preserve original dimension order
+ dims = [d if d != "dparams" else "quantile" for d in p.dims]
+ out = data.assign_coords(quantile=q).transpose(*dims)
+ out.attrs = unprefix_attrs(p.attrs, ["units", "standard_name"], "original_")
+
+ attrs = dict(
+ long_name=f"{dist.name} quantiles",
+ description=f"Quantiles estimated by the {dist.name} distribution",
+ cell_methods="dparams: ppf",
+ history=update_history(
+ "Compute parametric quantiles from distribution parameters",
+ new_name="parametric_quantile",
+ parameters=p,
+ ),
+ )
+ out.attrs.update(attrs)
+ return out
+
+
+def parametric_cdf(
+ p: xr.DataArray,
+ v: float | Sequence[float],
+ dist: str | scipy.stats.rv_continuous | None = None,
+) -> xr.DataArray:
+ """Return the cumulative distribution function corresponding to the given distribution parameters and value.
+
+ Parameters
+ ----------
+ p : xr.DataArray
+ Distribution parameters returned by the `fit` function.
+ The array should have dimension `dparams` storing the distribution parameters,
+ and attribute `scipy_dist`, storing the name of the distribution.
+ v : float or Sequence of float
+        Value(s) at which to compute the CDF.
+    dist : str or rv_continuous, optional
+        The distribution name or instance if the `scipy_dist` attribute is not available on `p`.
+
+ Returns
+ -------
+ xarray.DataArray
+ An array of parametric CDF values estimated from the distribution parameters.
+ """
+ v = np.atleast_1d(v)
+
+ dist = get_dist(dist or p.attrs["scipy_dist"])
+
+    # Define a helper function to pass the values to dask. There is probably a better way to do this.
+ def func(x):
+ return dist.cdf(v, *x)
+
+ data = xr.apply_ufunc(
+ func,
+ p,
+ input_core_dims=[["dparams"]],
+ output_core_dims=[["cdf"]],
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[float],
+ keep_attrs=True,
+ dask_gufunc_kwargs={"output_sizes": {"cdf": len(v)}},
+ )
+
+    # Assign cdf coordinates and transpose to preserve original dimension order
+ dims = [d if d != "dparams" else "cdf" for d in p.dims]
+ out = data.assign_coords(cdf=v).transpose(*dims)
+ out.attrs = unprefix_attrs(p.attrs, ["units", "standard_name"], "original_")
+
+ attrs = dict(
+ long_name=f"{dist.name} cdf",
+ description=f"CDF estimated by the {dist.name} distribution",
+ cell_methods="dparams: cdf",
+ history=update_history(
+ "Compute parametric cdf from distribution parameters",
+ new_name="parametric_cdf",
+ parameters=p,
+ ),
+ )
+ out.attrs.update(attrs)
+ return out
+
+
+def fa(
+ da: xr.DataArray,
+ t: int | Sequence,
+ dist: str | scipy.stats.rv_continuous = "norm",
+ mode: str = "max",
+ method: str = "ML",
+) -> xr.DataArray:
+ """Return the value corresponding to the given return period.
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ Maximized/minimized input data with a `time` dimension.
+ t : int or Sequence of int
+ Return period. The period depends on the resolution of the input data. If the input array's resolution is
+ yearly, then the return period is in years.
+ dist : str or rv_continuous instance
+ Name of the univariate distribution, such as:
+ `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm`
+ Or the distribution instance itself.
+    mode : {'min', 'max'}
+ Whether we are looking for a probability of exceedance (max) or a probability of non-exceedance (min).
+ method : {"ML", "MLE", "MOM", "PWM", "APP"}
+ Fitting method, either maximum likelihood (ML or MLE), method of moments (MOM) or approximate method (APP).
+ Also accepts probability weighted moments (PWM), also called L-Moments, if `dist` is an instance from the lmoments3 library.
+ The PWM method is usually more robust to outliers.
+
+ Returns
+ -------
+ xarray.DataArray
+ An array of values with a 1/t probability of exceedance (if mode=='max').
+
+ See Also
+ --------
+ scipy.stats : For descriptions of univariate distribution types.
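+
+    Examples
+    --------
+    A minimal sketch with synthetic annual maxima (random data for illustration only):
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> am = xr.DataArray(
+    ...     np.random.default_rng(0).gumbel(loc=20, scale=5, size=60), dims=("time",)
+    ... )
+    >>> fa(am, t=20, dist="gumbel_r", mode="max").dims
+    ('return_period',)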
+ """
+ # Fit the parameters of the distribution
+ p = fit(da, dist, method=method)
+ t = np.atleast_1d(t)
+
+ if mode in ["max", "high"]:
+ q = 1 - 1.0 / t
+
+ elif mode in ["min", "low"]:
+ q = 1.0 / t
+
+ else:
+ raise ValueError(f"Mode `{mode}` should be either 'max' or 'min'.")
+
+ # Compute the quantiles
+ out = (
+ parametric_quantile(p, q, dist)
+ .rename({"quantile": "return_period"})
+ .assign_coords(return_period=t)
+ )
+ out.attrs["mode"] = mode
+ return out
+
+
+def frequency_analysis(
+ da: xr.DataArray,
+ mode: str,
+ t: int | Sequence[int],
+ dist: str | scipy.stats.rv_continuous,
+ window: int = 1,
+ freq: str | None = None,
+ method: str = "ML",
+ **indexer: int | float | str,
+) -> xr.DataArray:
+ r"""Return the value corresponding to a return period.
+
+ Parameters
+ ----------
+ da : xarray.DataArray
+ Input data.
+ mode : {'min', 'max'}
+        Whether we are looking for a probability of exceedance (max) or a probability of non-exceedance (min).
+ t : int or sequence
+ Return period. The period depends on the resolution of the input data. If the input array's resolution is
+ yearly, then the return period is in years.
+ dist : str or rv_continuous
+ Name of the univariate distribution, e.g. `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm`.
+ Or an instance of the distribution.
+ window : int
+ Averaging window length (days).
+ freq : str, optional
+ Resampling frequency. If None, the frequency is assumed to be 'YS' unless the indexer is season='DJF',
+ in which case `freq` would be set to `YS-DEC`.
+ method : {"ML" or "MLE", "MOM", "PWM", "APP"}
+ Fitting method, either maximum likelihood (ML or MLE), method of moments (MOM) or approximate method (APP).
+ Also accepts probability weighted moments (PWM), also called L-Moments, if `dist` is an instance from the lmoments3 library.
+ The PWM method is usually more robust to outliers.
+ \*\*indexer
+ Time attribute and values over which to subset the array. For example, use season='DJF' to select winter values,
+ month=1 to select January, or month=[6,7,8] to select summer months. If indexer is not provided, all values are
+ considered.
+
+ Returns
+ -------
+ xarray.DataArray
+        An array of values with a 1/t probability of exceedance or non-exceedance when mode is 'max' or 'min' respectively.
+
+ See Also
+ --------
+ scipy.stats : For descriptions of univariate distribution types.
+ """
+ # Apply rolling average
+ attrs = da.attrs.copy()
+ if window > 1:
+ da = da.rolling(time=window).mean(skipna=False)
+ da.attrs.update(attrs)
+
+ # Assign default resampling frequency if not provided
+ freq = freq or generic.default_freq(**indexer)
+
+ # Extract the time series of min or max over the period
+ sel = generic.select_resample_op(da, op=mode, freq=freq, **indexer)
+
+ if uses_dask(sel):
+ sel = sel.chunk({"time": -1})
+ # Frequency analysis
+ return fa(sel, t, dist=dist, mode=mode, method=method)
+
+
+def get_dist(dist: str | scipy.stats.rv_continuous):
+ """Return a distribution object from `scipy.stats`."""
+ if isinstance(dist, scipy.stats.rv_continuous):
+ return dist
+
+ dc = getattr(scipy.stats, dist, None)
+ if dc is None:
+ e = f"Statistical distribution `{dist}` is not found in scipy.stats."
+ raise ValueError(e)
+ return dc
+
+
+def _fit_start(x, dist: str, **fitkwargs: Any) -> tuple[tuple, dict]:
+ r"""Return initial values for distribution parameters.
+
+ Providing the ML fit method initial values can help the optimizer find the global optimum.
+
+ Parameters
+ ----------
+ x : array-like
+ Input data.
+ dist : str
+        Name of the univariate distribution, e.g. `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm`
+        (see :py:mod:`scipy.stats`). Only the `genextreme`, `genpareto`, `weibull_min`, `gamma` and `fisk` distributions
+        have dedicated starting values; any other distribution yields empty arguments.
+ \*\*fitkwargs
+ Kwargs passed to fit.
+
+ Returns
+ -------
+    tuple, dict
+        Initial positional and keyword arguments to pass to the distribution's `fit` method.
+
+ References
+ ----------
+ :cite:cts:`coles_introduction_2001,cohen_parameter_2019, thom_1958, cooke_1979, muralidhar_1992`
+
+ """
+ x = np.asarray(x)
+ m = x.mean()
+ v = x.var()
+
+ if dist == "genextreme":
+ s = np.sqrt(6 * v) / np.pi
+ return (0.1,), {"loc": m - 0.57722 * s, "scale": s}
+
+ if dist == "genpareto" and "floc" in fitkwargs:
+        # Taken from Julia's Extremes.jl. Case for when "mu/loc" is known.
+ t = fitkwargs["floc"]
+ if not np.isclose(t, 0):
+ m = (x - t).mean()
+ v = (x - t).var()
+
+ c = 0.5 * (1 - m**2 / v)
+ scale = (1 - c) * m
+ return (c,), {"scale": scale}
+
+ if dist in "weibull_min":
+ s = x.std()
+ loc = x.min() - 0.01 * s
+ chat = np.pi / np.sqrt(6) / (np.log(x - loc)).std()
+ scale = ((x - loc) ** chat).mean() ** (1 / chat)
+ return (chat,), {"loc": loc, "scale": scale}
+
+ if dist in ["gamma"]:
+ if "floc" in fitkwargs:
+ loc0 = fitkwargs["floc"]
+ else:
+ xs = sorted(x)
+ x1, x2, xn = xs[0], xs[1], xs[-1]
+ # muralidhar_1992 would suggest the following, but it seems more unstable
+ # using cooke_1979 for now
+ # n = len(x)
+ # cv = x.std() / x.mean()
+ # p = (0.48265 + 0.32967 * cv) * n ** (-0.2984 * cv)
+ # xp = xs[int(p/100*n)]
+ xp = x2
+ loc0 = (x1 * xn - xp**2) / (x1 + xn - 2 * xp)
+ loc0 = loc0 if loc0 < x1 else (0.9999 * x1 if x1 > 0 else 1.0001 * x1)
+ x_pos = x - loc0
+ x_pos = x_pos[x_pos > 0]
+ m = x_pos.mean()
+ log_of_mean = np.log(m)
+ mean_of_logs = np.log(x_pos).mean()
+ A = log_of_mean - mean_of_logs
+ a0 = (1 + np.sqrt(1 + 4 * A / 3)) / (4 * A)
+ scale0 = m / a0
+ kwargs = {"scale": scale0, "loc": loc0}
+ return (a0,), kwargs
+
+ if dist in ["fisk"]:
+ if "floc" in fitkwargs:
+ loc0 = fitkwargs["floc"]
+ else:
+ xs = sorted(x)
+ x1, x2, xn = xs[0], xs[1], xs[-1]
+ loc0 = (x1 * xn - x2**2) / (x1 + xn - 2 * x2)
+ loc0 = loc0 if loc0 < x1 else (0.9999 * x1 if x1 > 0 else 1.0001 * x1)
+ x_pos = x - loc0
+ x_pos = x_pos[x_pos > 0]
+        # Method of moments:
+        # the LHS is computed analytically from the two-parameter log-logistic distribution
+        # and depends on alpha and beta, while the RHS comes from the sample:
+        # <x> = m
+        # <x^2> / <x>^2 = m2 / m**2
+        # Solving these equations yields:
+ m = x_pos.mean()
+ m2 = (x_pos**2).mean()
+ scale0 = 2 * m**3 / (m2 + m**2)
+ c0 = np.pi * m / np.sqrt(3) / np.sqrt(m2 - m**2)
+ kwargs = {"scale": scale0, "loc": loc0}
+ return (c0,), kwargs
+ return (), {}
+
+
+def _dist_method_1D( # noqa: N802
+ *args, dist: str | scipy.stats.rv_continuous, function: str, **kwargs: Any
+) -> xr.DataArray:
+ r"""Statistical function for given argument on given distribution initialized with params.
+
+ See :py:ref:`scipy:scipy.stats.rv_continuous` for all available functions and their arguments.
+ Every method where `"*args"` are the distribution parameters can be wrapped.
+
+ Parameters
+ ----------
+ \*args
+ The arguments for the requested scipy function.
+ dist : str
+ The scipy name of the distribution.
+ function : str
+ The name of the function to call.
+ \*\*kwargs
+ Other parameters to pass to the function call.
+
+ Returns
+ -------
+ array_like
+ """
+ dist = get_dist(dist)
+ return getattr(dist, function)(*args, **kwargs)
+
+
+def dist_method(
+ function: str,
+ fit_params: xr.DataArray,
+ arg: xr.DataArray | None = None,
+ dist: str | scipy.stats.rv_continuous | None = None,
+ **kwargs: Any,
+) -> xr.DataArray:
+ r"""Vectorized statistical function for given argument on given distribution initialized with params.
+
+ Methods where `"*args"` are the distribution parameters can be wrapped, except those that reduce dimensions (
+    e.g. `nnlf`) or create new dimensions (e.g. 'rvs' with size != 1, 'stats' with more than one moment, 'interval',
+ 'support').
+
+ Parameters
+ ----------
+ function : str
+ The name of the function to call.
+ fit_params : xr.DataArray
+ Distribution parameters are along `dparams`, in the same order as given by :py:func:`fit`.
+ arg : array_like, optional
+ The first argument for the requested function if different from `fit_params`.
+    dist : str or rv_continuous, optional
+        The distribution name or instance. Defaults to the `scipy_dist` attribute of `fit_params`.
+ \*\*kwargs
+ Other parameters to pass to the function call.
+
+ Returns
+ -------
+ array_like
+ Same shape as arg.
+
+ See Also
+ --------
+ scipy:scipy.stats.rv_continuous : for all available functions and their arguments.
+ """
+ # Typically the data to be transformed
+ arg = [arg] if arg is not None else []
+ if function == "nnlf":
+ raise ValueError(
+ "This method is not supported because it reduces the dimensionality of the data."
+ )
+
+ # We don't need to set `input_core_dims` because we're explicitly splitting the parameters here.
+ args = arg + [fit_params.sel(dparams=dp) for dp in fit_params.dparams.values]
+
+ return xr.apply_ufunc(
+ _dist_method_1D,
+ *args,
+ kwargs={
+ "dist": dist or fit_params.attrs["scipy_dist"],
+ "function": function,
+ **kwargs,
+ },
+ output_dtypes=[float],
+ dask="parallelized",
+ )
diff --git a/tests/test_properties.py b/tests/test_properties.py
new file mode 100644
index 0000000..5b693bf
--- /dev/null
+++ b/tests/test_properties.py
@@ -0,0 +1,577 @@
+from __future__ import annotations
+
+import numpy as np
+import pandas as pd
+import pytest
+import xarray as xr
+from xarray import set_options
+
+from xsdba import properties
+from xsdba.units import convert_units_to
+
+
+class TestProperties:
+ def test_mean(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1980"), location="Vancouver")
+ .pr
+ ).load()
+
+ out_year = properties.mean(sim)
+ np.testing.assert_array_almost_equal(out_year.values, [3.0016028e-05])
+
+ out_season = properties.mean(sim, group="time.season")
+ np.testing.assert_array_almost_equal(
+ out_season.values,
+ [4.6115547e-05, 1.7220482e-05, 2.8805329e-05, 2.825359e-05],
+ )
+
+ assert out_season.long_name.startswith("Mean")
+
+ def test_var(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1980"), location="Vancouver")
+ .pr
+ ).load()
+
+ out_year = properties.var(sim)
+ np.testing.assert_array_almost_equal(out_year.values, [2.5884779e-09])
+
+ out_season = properties.var(sim, group="time.season")
+ np.testing.assert_array_almost_equal(
+ out_season.values,
+ [3.9270796e-09, 1.2538864e-09, 1.9057025e-09, 2.8776632e-09],
+ )
+ assert out_season.long_name.startswith("Variance")
+ assert out_season.units == "kg2 m-4 s-2"
+
+ def test_std(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1980"), location="Vancouver")
+ .pr
+ ).load()
+
+ out_year = properties.std(sim)
+ np.testing.assert_array_almost_equal(out_year.values, [5.08770208398345e-05])
+
+ out_season = properties.std(sim, group="time.season")
+ np.testing.assert_array_almost_equal(
+ out_season.values,
+ [6.2666411e-05, 3.5410259e-05, 4.3654352e-05, 5.3643853e-05],
+ )
+ assert out_season.long_name.startswith("Standard deviation")
+ assert out_season.units == "kg m-2 s-1"
+
+ def test_skewness(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1980"), location="Vancouver")
+ .pr
+ ).load()
+
+ out_year = properties.skewness(sim)
+ np.testing.assert_array_almost_equal(out_year.values, [2.8497460898513745])
+
+ out_season = properties.skewness(sim, group="time.season")
+ np.testing.assert_array_almost_equal(
+ out_season.values,
+ [
+ 2.036650744163691,
+ 3.7909534745807147,
+ 2.416590445325826,
+ 3.3521301798559566,
+ ],
+ )
+ assert out_season.long_name.startswith("Skewness")
+ assert out_season.units == ""
+
+ def test_quantile(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1980"), location="Vancouver")
+ .pr
+ ).load()
+
+ out_year = properties.quantile(sim, q=0.2)
+ np.testing.assert_array_almost_equal(out_year.values, [2.8109431013945154e-07])
+
+ out_season = properties.quantile(sim, group="time.season", q=0.2)
+ np.testing.assert_array_almost_equal(
+ out_season.values,
+ [
+ 1.5171653330980917e-06,
+ 9.822543773907455e-08,
+ 1.8135805248675763e-07,
+ 4.135342521749408e-07,
+ ],
+ )
+ assert out_season.long_name.startswith("Quantile 0.2")
+
+    # TODO: test threshold_count? It's the same as test_spell_length_distribution.
+ def test_spell_length_distribution(self, open_dataset):
+ ds = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .load()
+ )
+
+ # test pr, with amount method
+ sim = ds.pr
+ kws = {"op": "<", "group": "time.month", "thresh": "1.157e-05 kg/m/m/s"}
+ outd = {
+ stat: properties.spell_length_distribution(da=sim, **kws, stat=stat)
+ .sel(month=1)
+ .values
+ for stat in ["mean", "max", "min"]
+ }
+ np.testing.assert_array_almost_equal(
+ [outd[k] for k in ["mean", "max", "min"]], [2.44127, 10, 1]
+ )
+
+ # test tasmax, with quantile method
+ simt = ds.tasmax
+ kws = {"thresh": 0.9, "op": ">=", "method": "quantile", "group": "time.month"}
+ outd = {
+ stat: properties.spell_length_distribution(da=simt, **kws, stat=stat).sel(
+ month=6
+ )
+ for stat in ["mean", "max", "min"]
+ }
+ np.testing.assert_array_almost_equal(
+ [outd[k].values for k in ["mean", "max", "min"]], [3.0, 6, 1]
+ )
+
+ # test varia
+ with pytest.raises(
+ ValueError,
+ match="percentile is not a valid method. Choose 'amount' or 'quantile'.",
+ ):
+ properties.spell_length_distribution(simt, method="percentile")
+
+ assert (
+ outd["mean"].long_name
+ == "Average of spell length distribution when the variable is >= the quantile 0.9 for 1 consecutive day(s)."
+ )
+
+ def test_spell_length_distribution_mixed_stat(self, open_dataset):
+
+ time = pd.date_range("2000-01-01", periods=2 * 365, freq="D")
+ tas = xr.DataArray(
+ np.array([0] * 365 + [40] * 365),
+ dims=("time"),
+ coords={"time": time},
+ attrs={"units": "degC"},
+ )
+
+ kws_sum = dict(
+ thresh="30 degC", op=">=", stat="sum", stat_resample="sum", group="time"
+ )
+ out_sum = properties.spell_length_distribution(tas, **kws_sum).values
+ kws_mixed = dict(
+ thresh="30 degC", op=">=", stat="mean", stat_resample="sum", group="time"
+ )
+ out_mixed = properties.spell_length_distribution(tas, **kws_mixed).values
+
+ assert out_sum == 365
+ assert out_mixed == 182.5
+
+ @pytest.mark.parametrize(
+ "window,expected_amount,expected_quantile",
+ [
+ (1, [2.333333, 4, 1], [3, 6, 1]),
+ (3, [1.333333, 4, 0], [2, 6, 0]),
+ ],
+ )
+ def test_bivariate_spell_length_distribution(
+ self, open_dataset, window, expected_amount, expected_quantile
+ ):
+ ds = (
+ open_dataset("sdba/CanESM2_1950-2100.nc").sel(
+ time=slice("1950", "1952"), location="Vancouver"
+ )
+ ).load()
+ tx = ds.tasmax
+ with set_options(keep_attrs=True):
+ tn = tx - 5
+
+ # test with amount method
+ kws = {
+ "thresh1": "0 degC",
+ "thresh2": "0 degC",
+ "op1": ">",
+ "op2": "<=",
+ "group": "time.month",
+ "window": window,
+ }
+ outd = {
+ stat: properties.bivariate_spell_length_distribution(
+ da1=tx, da2=tn, **kws, stat=stat
+ )
+ .sel(month=1)
+ .values
+ for stat in ["mean", "max", "min"]
+ }
+ np.testing.assert_array_almost_equal(
+ [outd[k] for k in ["mean", "max", "min"]], expected_amount
+ )
+
+ # test with quantile method
+ kws = {
+ "thresh1": 0.9,
+ "thresh2": 0.9,
+ "op1": ">",
+ "op2": ">",
+ "method1": "quantile",
+ "method2": "quantile",
+ "group": "time.month",
+ "window": window,
+ }
+ outd = {
+ stat: properties.bivariate_spell_length_distribution(
+ da1=tx, da2=tn, **kws, stat=stat
+ )
+ .sel(month=6)
+ .values
+ for stat in ["mean", "max", "min"]
+ }
+ np.testing.assert_array_almost_equal(
+ [outd[k] for k in ["mean", "max", "min"]], expected_quantile
+ )
+
+ def test_acf(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .pr
+ ).load()
+
+ out = properties.acf(sim, lag=1, group="time.month").sel(month=1)
+ np.testing.assert_array_almost_equal(out.values, [0.11242357313756905])
+
+ # FIXME
+ # with pytest.raises(ValueError, match="Grouping period year is not allowed for"):
+ # properties.acf(sim, group="time")
+
+ assert out.long_name.startswith("Lag-1 autocorrelation")
+ assert out.units == ""
+
+ def test_annual_cycle(self, open_dataset):
+ simt = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .tasmax
+ ).load()
+
+ amp = properties.annual_cycle_amplitude(simt)
+ relamp = properties.relative_annual_cycle_amplitude(simt)
+ phase = properties.annual_cycle_phase(simt)
+
+ np.testing.assert_allclose(
+ [amp.values, relamp.values, phase.values],
+ [16.74645996, 5.802083, 167],
+ rtol=1e-6,
+ )
+ # FIXME
+ # with pytest.raises(
+ # ValueError,
+ # match="Grouping period season is not allowed for property",
+ # ):
+ # properties.annual_cycle_amplitude(simt, group="time.season")
+
+ # with pytest.raises(
+ # ValueError,
+ # match="Grouping period month is not allowed for property",
+ # ):
+ # properties.annual_cycle_phase(simt, group="time.month")
+
+ assert amp.long_name.startswith("Absolute amplitude of the annual cycle")
+ assert phase.long_name.startswith("Phase of the annual cycle")
+ assert amp.units == "delta_degC"
+ assert relamp.units == "%"
+ assert phase.units == ""
+
+ def test_annual_range(self, open_dataset):
+ simt = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .tasmax
+ ).load()
+
+        # The initial annual cycle implementation gave these values with window = 1
+ amp = properties.mean_annual_range(simt, window=1)
+ relamp = properties.mean_annual_relative_range(simt, window=1)
+ phase = properties.mean_annual_phase(simt, window=1)
+
+ np.testing.assert_allclose(
+ [amp.values, relamp.values, phase.values],
+ [34.039806, 11.793684020675501, 165.33333333333334],
+ )
+
+ amp = properties.mean_annual_range(simt)
+ relamp = properties.mean_annual_relative_range(simt)
+ phase = properties.mean_annual_phase(simt)
+
+ np.testing.assert_array_almost_equal(
+ [amp.values, relamp.values, phase.values],
+ [18.715261, 6.480101, 181.6666667],
+ )
+ # FIXME
+ # with pytest.raises(
+ # ValueError,
+ # match="Grouping period season is not allowed for property",
+ # ):
+ # properties.mean_annual_range(simt, group="time.season")
+
+ # with pytest.raises(
+ # ValueError,
+ # match="Grouping period month is not allowed for property",
+ # ):
+ # properties.mean_annual_phase(simt, group="time.month")
+
+ assert amp.long_name.startswith("Average annual absolute amplitude")
+ assert phase.long_name.startswith("Average annual phase")
+ assert amp.units == "delta_degC"
+ assert relamp.units == "%"
+ assert phase.units == ""
+
+ def test_corr_btw_var(self, open_dataset):
+ simt = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .tasmax
+ ).load()
+
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .pr
+ ).load()
+
+ pc = properties.corr_btw_var(simt, sim, corr_type="Pearson")
+ pp = properties.corr_btw_var(
+ simt, sim, corr_type="Pearson", output="pvalue"
+ ).values
+ sc = properties.corr_btw_var(simt, sim).values
+ sp = properties.corr_btw_var(simt, sim, output="pvalue").values
+ sc_jan = (
+ properties.corr_btw_var(simt, sim, group="time.month").sel(month=1).values
+ )
+ sim[0] = np.nan
+ pc_nan = properties.corr_btw_var(sim, simt, corr_type="Pearson").values
+
+ np.testing.assert_array_almost_equal(
+ [pc.values, pp, sc, sp, sc_jan, pc_nan],
+ [
+ -0.20849051347480407,
+ 3.2160438749049577e-12,
+ -0.3449358561881698,
+ 5.97619379511559e-32,
+ 0.28329503745038936,
+ -0.2090292,
+ ],
+ )
+ assert pc.long_name == "Pearson correlation coefficient"
+ assert pc.units == ""
+
+ with pytest.raises(
+ ValueError,
+ match="pear is not a valid type. Choose 'Pearson' or 'Spearman'.",
+ ):
+ properties.corr_btw_var(sim, simt, group="time", corr_type="pear")
+
+ def test_relative_frequency(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .pr
+ ).load()
+
+ test = properties.relative_frequency(sim, thresh="2.8925e-04 kg/m^2/s", op=">=")
+ testjan = (
+ properties.relative_frequency(
+ sim, thresh="2.8925e-04 kg/m^2/s", op=">=", group="time.month"
+ )
+ .sel(month=1)
+ .values
+ )
+ np.testing.assert_array_almost_equal(
+ [test.values, testjan], [0.0045662100456621, 0.010752688172043012]
+ )
+ assert test.long_name == "Relative frequency of values >= 2.8925e-04 kg/m^2/s."
+ assert test.units == ""
+
+ def test_transition(self, open_dataset):
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .pr
+ ).load()
+
+ test = properties.transition_probability(
+ da=sim, initial_op="<", final_op=">=", thresh="1.157e-05 kg/m^2/s"
+ )
+
+ np.testing.assert_array_almost_equal([test.values], [0.14076782449725778])
+ assert (
+ test.long_name
+ == "Transition probability of values < 1.157e-05 kg/m^2/s to values >= 1.157e-05 kg/m^2/s."
+ )
+ assert test.units == ""
+
+ def test_trend(self, open_dataset):
+ simt = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "1952"), location="Vancouver")
+ .tasmax
+ ).load()
+
+ slope = properties.trend(simt).values
+ intercept = properties.trend(simt, output="intercept").values
+ rvalue = properties.trend(simt, output="rvalue").values
+ pvalue = properties.trend(simt, output="pvalue").values
+ stderr = properties.trend(simt, output="stderr").values
+ intercept_stderr = properties.trend(simt, output="intercept_stderr").values
+
+ np.testing.assert_array_almost_equal(
+ [slope, intercept, rvalue, pvalue, stderr, intercept_stderr],
+ [
+ -0.133711111111111,
+ 288.762132222222222,
+ -0.9706433333333333,
+ 0.1546344444444444,
+ 0.033135555555555,
+ 0.042776666666666,
+ ],
+ 4,
+ )
+
+ slope = properties.trend(simt, group="time.month").sel(month=1)
+ intercept = (
+ properties.trend(simt, output="intercept", group="time.month")
+ .sel(month=1)
+ .values
+ )
+ rvalue = (
+ properties.trend(simt, output="rvalue", group="time.month")
+ .sel(month=1)
+ .values
+ )
+ pvalue = (
+ properties.trend(simt, output="pvalue", group="time.month")
+ .sel(month=1)
+ .values
+ )
+ stderr = (
+ properties.trend(simt, output="stderr", group="time.month")
+ .sel(month=1)
+ .values
+ )
+ intercept_stderr = (
+ properties.trend(simt, output="intercept_stderr", group="time.month")
+ .sel(month=1)
+ .values
+ )
+
+ np.testing.assert_array_almost_equal(
+ [slope.values, intercept, rvalue, pvalue, stderr, intercept_stderr],
+ [
+ 0.8254511111111111,
+ 281.76353222222222,
+ 0.576843333333333,
+ 0.6085644444444444,
+ 1.1689105555555555,
+ 1.509056666666666,
+ ],
+ 4,
+ )
+
+ assert slope.long_name.startswith("Slope of the interannual linear trend")
+ assert slope.units == "K/year"
+
+ def test_return_value(self, open_dataset):
+ simt = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1950", "2010"), location="Vancouver")
+ .tasmax
+ ).load()
+
+ out_y = properties.return_value(simt)
+
+ out_djf = (
+ properties.return_value(simt, op="min", group="time.season")
+ .sel(season="DJF")
+ .values
+ )
+
+ np.testing.assert_array_almost_equal(
+ [out_y.values, out_djf], [313.154, 278.072], 3
+ )
+ assert out_y.long_name.startswith("20-year maximal return level")
+
+ @pytest.mark.slow
+ def test_spatial_correlogram(self, open_dataset):
+ # This also tests sdba.utils._pairwise_spearman and sdba.nbutils._pairwise_haversine_and_bins
+ # Test 1, does it work with 1D data?
+ sim = (
+ open_dataset("sdba/CanESM2_1950-2100.nc")
+ .sel(time=slice("1981", "2010"))
+ .tasmax
+ ).load()
+
+ out = properties.spatial_correlogram(sim, dims=["location"], bins=3)
+ np.testing.assert_allclose(out, [-1, np.nan, 0], atol=1e-6)
+
+ # Test 2, not very exhaustive, this is more of a detect-if-we-break-it test.
+ sim = open_dataset("NRCANdaily/nrcan_canada_daily_tasmax_1990.nc").tasmax
+ out = properties.spatial_correlogram(
+ sim.isel(lon=slice(0, 50)), dims=["lon", "lat"], bins=20
+ )
+ np.testing.assert_allclose(
+ out[:5],
+ [0.95099902, 0.83028772, 0.66874473, 0.48893958, 0.30915054],
+ )
+ np.testing.assert_allclose(
+ out.distance[:5],
+ [26.543199, 67.716227, 108.889254, 150.062282, 191.23531],
+ rtol=5e-07,
+ )
+
+ @pytest.mark.slow
+ def test_decorrelation_length(self, open_dataset):
+ sim = (
+ open_dataset("NRCANdaily/nrcan_canada_daily_tasmax_1990.nc")
+ .tasmax.isel(lon=slice(0, 5), lat=slice(0, 1))
+ .load()
+ )
+
+ out = properties.decorrelation_length(
+ sim, dims=["lat", "lon"], bins=10, radius=30
+ )
+ np.testing.assert_allclose(
+ out[0],
+ [4.5, 4.5, 4.5, 4.5, 10.5],
+ )
+
+ # ADAPT? The plan was not to allow mm/d -> kg m-2 s-1 in xsdba
+ # def test_get_measure(self, open_dataset):
+ # sim = (
+ # open_dataset("sdba/CanESM2_1950-2100.nc")
+ # .sel(time=slice("1981", "2010"), location="Vancouver")
+ # .pr
+ # ).load()
+
+ # ref = (
+ # open_dataset("sdba/ahccd_1950-2013.nc")
+ # .sel(time=slice("1981", "2010"), location="Vancouver")
+ # .pr
+ # ).load()
+
+ # sim = convert_units_to(sim, ref)
+ # sim_var = properties.var(sim)
+ # ref_var = properties.var(ref)
+
+ # meas = properties.var.get_measure()(sim_var, ref_var)
+ # np.testing.assert_allclose(meas, [0.408327], rtol=1e-3)