diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9a5f05e..4f0f97c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,5 +1,8 @@ name: xsdba Testing Suite +env: + XCLIM_TESTDATA_BRANCH: v2023.12.14 + on: push: branches: diff --git a/pyproject.toml b/pyproject.toml index c7f89f3..108a79d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ dev = [ "coverage >=7.5.0", "coveralls >=4.0.0", "mypy", + "netcdf4", + "h5", "numpydoc >=1.7.0", "pytest >=8.2.2", "pytest-cov >=5.0.0", @@ -239,8 +241,11 @@ checks = [ "GL01", "GL08", "PR01", + "PR02", # + "PR04", # "PR07", "PR08", + "PR10", # "RT01", "RT03", "SA01", @@ -303,7 +308,8 @@ ignore = [ "N803", "N806", "PTH123", - "S310" + "S310", + "PERF401" # don't force list comprehensions ] preview = true select = [ diff --git a/src/xsdba/__init__.py b/src/xsdba/__init__.py index 108a53d..9554388 100644 --- a/src/xsdba/__init__.py +++ b/src/xsdba/__init__.py @@ -20,7 +20,7 @@ from __future__ import annotations -from . import base, detrending, processing, units, utils +from . import adjustment, base, detrending, processing, testing, units, utils # , adjustment # from . import adjustment, base, detrending, measures, processing, properties, utils diff --git a/src/xsdba/_adjustment.py b/src/xsdba/_adjustment.py new file mode 100644 index 0000000..a038845 --- /dev/null +++ b/src/xsdba/_adjustment.py @@ -0,0 +1,948 @@ +# pylint: disable=no-value-for-parameter +"""# noqa: SS01 +Adjustment Algorithms +===================== + +This file defines the different steps, to be wrapped into the Adjustment objects. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Callable + +import numpy as np +import xarray as xr + +from . import nbutils as nbu +from . import utils as u +from ._processing import _adapt_freq +from .base import Grouper, map_blocks, map_groups +from .detrending import PolyDetrend +from .options import set_options +from .processing import escore, jitter_under_thresh, reordering, standardize +from .units import convert_units_to, units + +# from xclim.indices.stats import _fitfunc_1d + + +def _adapt_freq_hist(ds: xr.Dataset, adapt_freq_thresh: str): + """Adapt frequency of null values of `hist` in order to match `ref`.""" + # ADAPT: Drop context altogether? + # with units.context(infer_context(ds.ref.attrs.get("standard_name"))): + thresh = convert_units_to(adapt_freq_thresh, ds.ref) + dim = ["time"] + ["window"] * ("window" in ds.hist.dims) + return _adapt_freq.func( + xr.Dataset(dict(sim=ds.hist, ref=ds.ref)), thresh=thresh, dim=dim + ).sim_ad + + +@map_groups( + af=[Grouper.PROP, "quantiles"], + hist_q=[Grouper.PROP, "quantiles"], + scaling=[Grouper.PROP], +) +def dqm_train( + ds: xr.Dataset, + *, + dim: str, + kind: str, + quantiles: np.ndarray, + adapt_freq_thresh: str | None = None, + jitter_under_thresh_value: str | None = None, +) -> xr.Dataset: + """Train step on one group. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data + dim : str + The dimension along which to compute the quantiles. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + quantiles : array-like + The quantiles to compute. + adapt_freq_thresh : str, optional + Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. + Default is None, meaning that frequency adaptation is not performed. 
+ jitter_under_thresh_value : str, optional + Threshold under which to add uniform random noise to values, a quantity with units. + Default is None, meaning that jitter under thresh is not performed. + + Returns + ------- + xr.Dataset + The dataset containing the adjustment factors, the quantiles over the training data, and the scaling factor. + """ + ds["hist"] = ( + jitter_under_thresh(ds.hist, jitter_under_thresh_value) + if jitter_under_thresh_value + else ds.hist + ) + ds["hist"] = ( + _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist + ) + + refn = u.apply_correction(ds.ref, u.invert(ds.ref.mean(dim), kind), kind) + histn = u.apply_correction(ds.hist, u.invert(ds.hist.mean(dim), kind), kind) + + ref_q = nbu.quantile(refn, quantiles, dim) + hist_q = nbu.quantile(histn, quantiles, dim) + + af = u.get_correction(hist_q, ref_q, kind) + mu_ref = ds.ref.mean(dim) + mu_hist = ds.hist.mean(dim) + scaling = u.get_correction(mu_hist, mu_ref, kind=kind) + + return xr.Dataset(data_vars=dict(af=af, hist_q=hist_q, scaling=scaling)) + + +@map_groups( + af=[Grouper.PROP, "quantiles"], + hist_q=[Grouper.PROP, "quantiles"], +) +def eqm_train( + ds: xr.Dataset, + *, + dim: str, + kind: str, + quantiles: np.ndarray, + adapt_freq_thresh: str | None = None, + jitter_under_thresh_value: str | None = None, +) -> xr.Dataset: + """EQM: Train step on one group. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data + dim : str + The dimension along which to compute the quantiles. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + quantiles : array-like + The quantiles to compute. + adapt_freq_thresh : str, optional + Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. + Default is None, meaning that frequency adaptation is not performed. + jitter_under_thresh_value : str, optional + Threshold under which to add uniform random noise to values, a quantity with units. + Default is None, meaning that jitter under thresh is not performed. + + Returns + ------- + xr.Dataset + The dataset containing the adjustment factors and the quantiles over the training data. + """ + ds["hist"] = ( + jitter_under_thresh(ds.hist, jitter_under_thresh_value) + if jitter_under_thresh_value + else ds.hist + ) + ds["hist"] = ( + _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist + ) + ref_q = nbu.quantile(ds.ref, quantiles, dim) + hist_q = nbu.quantile(ds.hist, quantiles, dim) + + af = u.get_correction(hist_q, ref_q, kind) + + return xr.Dataset(data_vars=dict(af=af, hist_q=hist_q)) + + +def _npdft_train(ref, hist, rots, quantiles, method, extrap, n_escore, standardize): + r"""Npdf transform to correct a source `hist` into target `ref`. + + Perform a rotation, bias correct `hist` into `ref` with QuantileDeltaMapping, and rotate back. + Do this iteratively over all rotations `rots` and conserve adjustment factors `af_q` in each iteration. + + Notes + ----- + This function expects numpy inputs. The input arrays `ref,hist` are expected to be 2-dimensional arrays with shape: + `(len(nfeature), len(time))`, where `nfeature` is the dimension which is mixed by the multivariate bias adjustment + (e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. `rots` are rotation matrices with shape + `(len(iterations), len(nfeature), len(nfeature))`. 
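+
+    Note that each iteration composes its rotation with the previous one: the data are left in the
+    frame of rotation ``ii - 1``, so the update applies ``rots[ii] @ rots[ii - 1].T`` (undo the
+    previous rotation, then apply the new one), and the final ``rots[-1].T`` maps ``hist`` back to
+    the original, unrotated space.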
+ """ + if standardize: + ref = (ref - np.nanmean(ref, axis=-1, keepdims=True)) / ( + np.nanstd(ref, axis=-1, keepdims=True) + ) + hist = (hist - np.nanmean(hist, axis=-1, keepdims=True)) / ( + np.nanstd(hist, axis=-1, keepdims=True) + ) + af_q = np.zeros((len(rots), ref.shape[0], len(quantiles))) + escores = np.zeros(len(rots)) * np.NaN + if n_escore > 0: + ref_step, hist_step = ( + int(np.ceil(arr.shape[1] / n_escore)) for arr in [ref, hist] + ) + for ii in range(len(rots)): + rot = rots[0] if ii == 0 else rots[ii] @ rots[ii - 1].T + ref, hist = rot @ ref, rot @ hist + # loop over variables + for iv in range(ref.shape[0]): + ref_q, hist_q = nbu._quantile(ref[iv], quantiles), nbu._quantile( + hist[iv], quantiles + ) + af_q[ii, iv] = ref_q - hist_q + af = u._interp_on_quantiles_1D( + u._rank_bn(hist[iv]), + quantiles, + af_q[ii, iv], + method=method, + extrap=extrap, + ) + hist[iv] = hist[iv] + af + if n_escore > 0: + escores[ii] = nbu._escore(ref[:, ::ref_step], hist[:, ::hist_step]) + hist = rots[-1].T @ hist + return af_q, escores + + +def mbcn_train( + ds: xr.Dataset, + rot_matrices: xr.DataArray, + pts_dims: Sequence[str], + quantiles: np.ndarray, + gw_idxs: xr.DataArray, + interp: str, + extrapolation: str, + n_escore: int, +) -> xr.Dataset: + """Npdf transform training. + + Adjusting factors obtained for each rotation in the npdf transform and conserved to be applied in + the adjusting step in :py:func:`mcbn_adjust`. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data + rot_matrices : xr.DataArray + The rotation matrices as a 3D array ('iterations', , ), with shape (n_iter, , ). + pts_dims : sequence of str + The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which + is the normal case when using :py:func:`xclim.sdba.base.stack_variables`, and "multivar_prime". + quantiles : array-like + The quantiles to compute. + gw_idxs : xr.DataArray + Indices of the times in each windowed time group. + interp : str + The interpolation method to use. + extrapolation : str + The extrapolation method to use. + n_escore : int + Number of elements to include in the e_score test (0 for all, < 0 to skip). + + Returns + ------- + xr.Dataset + The dataset containing the adjustment factors and the quantiles over the training data + (only the npdf transform of mbcn). 
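+        Its core dimensions are ``(<group_dim>, iterations, <pts_dims[1]>, quantiles)`` for ``af_q``
+        and ``(<group_dim>, iterations)`` for ``escores``, where ``<group_dim>`` is read from
+        ``gw_idxs.attrs["group_dim"]``.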
+ """ + # unpack data + ref = ds.ref + hist = ds.hist + gr_dim = gw_idxs.attrs["group_dim"] + + # npdf training core + af_q_l = [] + escores_l = [] + + # loop over time blocks + for ib in range(gw_idxs[gr_dim].size): + # indices in a given time block + indices = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values + ind = indices[indices >= 0] + + # npdft training : multiple rotations on standardized datasets + # keep track of adjustment factors in each rotation for later use + af_q, escores = xr.apply_ufunc( + _npdft_train, + ref[{"time": ind}], + hist[{"time": ind}], + rot_matrices, + quantiles, + input_core_dims=[ + [pts_dims[0], "time"], + [pts_dims[0], "time"], + ["iterations", pts_dims[1], pts_dims[0]], + ["quantiles"], + ], + output_core_dims=[ + ["iterations", pts_dims[1], "quantiles"], + ["iterations"], + ], + dask="parallelized", + output_dtypes=[hist.dtype, hist.dtype], + kwargs={ + "method": interp, + "extrap": extrapolation, + "n_escore": n_escore, + "standardize": True, + }, + vectorize=True, + ) + af_q_l.append(af_q.expand_dims({gr_dim: [ib]})) + escores_l.append(escores.expand_dims({gr_dim: [ib]})) + af_q = xr.concat(af_q_l, dim=gr_dim) + escores = xr.concat(escores_l, dim=gr_dim) + out = xr.Dataset(dict(af_q=af_q, escores=escores)).assign_coords( + {"quantiles": quantiles, gr_dim: gw_idxs[gr_dim].values} + ) + return out + + +def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap): + """Npdf transform adjusting. + + Adjusting factors `af_q` obtained in the training step are applied on the simulated data `sim` at each iterated + rotation, see :py:func:`_npdft_train`. + + This function expects numpy inputs. `sim` can be a 2-d array with shape: `(len(nfeature), len(time))`, or + a 3-d array with shape: `(len(period), len(nfeature), len(time))`, allowing to adjust multiple climatological periods + all at once. `nfeature` is the dimension which is mixed by the multivariate bias adjustment + (e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. `rots` are rotation matrices with shape + `(len(iterations), len(nfeature), len(nfeature))`. + """ + # add dummy dim if period_dim absent to uniformize the function below + # This could be done at higher level, not sure where is best + if dummy_dim_added := (len(sim.shape) == 2): + sim = sim[:, np.newaxis, :] + + # adjust npdft + for ii in range(len(rots)): + rot = rots[0] if ii == 0 else rots[ii] @ rots[ii - 1].T + sim = np.einsum("ij,j...->i...", rot, sim) + # loop over variables + for iv in range(sim.shape[0]): + af = u._interp_on_quantiles_1D_multi( + u._rank_bn(sim[iv], axis=-1), + quantiles, + af_q[ii, iv], + method=method, + extrap=extrap, + ) + sim[iv] = sim[iv] + af + + rot = rots[-1].T + sim = np.einsum("ij,j...->i...", rot, sim) + if dummy_dim_added: + sim = sim[:, 0, :] + + return sim + + +def mbcn_adjust( + ref: xr.Dataset, + hist: xr.Dataset, + sim: xr.Dataset, + ds: xr.Dataset, + pts_dims: Sequence[str], + interp: str, + extrapolation: str, + base: Callable, + base_kws_vars: dict, + adj_kws: dict, + period_dim: str | None, +) -> xr.DataArray: + """Perform the adjustment portion MBCn multivariate bias correction technique. + + The function :py:func:`mbcn_train` pre-computes the adjustment factors for each rotation + in the npdf portion of the MBCn algorithm. The rest of adjustment is performed here + in `mbcn_adjust``. + + Parameters + ---------- + ref : xr.DataArray + training target. + hist : xr.DataArray + training data. 
+ sim : xr.DataArray + data to adjust (stacked with multivariate dimension). + ds : xr.Dataset + Dataset variables: + rot_matrices : Rotation matrices used in the training step. + af_q : Adjustment factors obtained in the training step for the npdf transform + g_idxs : Indices of the times in each time group + gw_idxs: Indices of the times in each windowed time group + pts_dims : [str, str] + The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which + is the normal case when using :py:func:`xclim.sdba.base.stack_variables`, and "multivar_prime". + interp : str + Interpolation method for the npdf transform (same as in the training step). + extrapolation : str + Extrapolation method for the npdf transform (same as in the training step). + base : BaseAdjustment + Bias-adjustment class used for the univariate bias correction. + base_kws_vars : Dict + Options for univariate training for the scenario that is reordered with the output of npdf transform. + The arguments are those expected by TrainAdjust classes along with + - kinds : Dict of correction kinds for each variable (e.g. {"pr":"*", "tasmax":"+"}). + adj_kws : Dict + Options for univariate adjust for the scenario that is reordered with the output of npdf transform. + period_dim : str, optional + Name of the period dimension used when stacking time periods of `sim` using :py:func:`xclim.core.calendar.stack_periods`. + If specified, the interpolation of the npdf transform is performed only once and applied on all periods simultaneously. + This should be more performant, but also more memory intensive. Defaults to `None`: No optimization will be attempted. + + Returns + ------- + xr.Dataset + The adjusted data. + """ + # unpacking training parameters + rot_matrices = ds.rot_matrices + af_q = ds.af_q + quantiles = af_q.quantiles + g_idxs = ds.g_idxs + gw_idxs = ds.gw_idxs + gr_dim = gw_idxs.attrs["group_dim"] + win = gw_idxs.attrs["group"][1] + + # this way of handling was letting open the possibility to perform + # interpolation for multiple periods in the simulation all at once + # in principle, avoiding redundancy. Need to test this on small data + # to confirm it works, and on big data to check performance. + dims = ["time"] if period_dim is None else [period_dim, "time"] + + # mbcn core + scen_mbcn = xr.zeros_like(sim) + for ib in range(gw_idxs[gr_dim].size): + # indices in a given time block (with and without the window) + indices_gw = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values + ind_gw = indices_gw[indices_gw >= 0] + indices_g = g_idxs[{gr_dim: ib}].fillna(-1).astype(int).values + ind_g = indices_g[indices_g >= 0] + + # 1. univariate adjustment of sim -> scen + # the kind may differ depending on the variables + scen_block = xr.zeros_like(sim[{"time": ind_gw}]) + for iv, v in enumerate(sim[pts_dims[0]].values): + sl = {"time": ind_gw, pts_dims[0]: iv} + with set_options(sdba_extra_output=False): + ADJ = base.train( + ref[sl], hist[sl], **base_kws_vars[v], skip_input_checks=True + ) + scen_block[{pts_dims[0]: iv}] = ADJ.adjust( + sim[sl], **adj_kws, skip_input_checks=True + ) + + # 2. 
npdft adjustment of sim + npdft_block = xr.apply_ufunc( + _npdft_adjust, + standardize(sim[{"time": ind_gw}].copy(), dim="time")[0], + af_q[{gr_dim: ib}], + rot_matrices, + quantiles, + input_core_dims=[ + [pts_dims[0]] + dims, + ["iterations", pts_dims[1], "quantiles"], + ["iterations", pts_dims[1], pts_dims[0]], + ["quantiles"], + ], + output_core_dims=[ + [pts_dims[0]] + dims, + ], + dask="parallelized", + output_dtypes=[sim.dtype], + kwargs={"method": interp, "extrap": extrapolation}, + vectorize=True, + ) + + # 3. reorder scen according to npdft results + reordered = reordering(ref=npdft_block, sim=scen_block) + if win > 1: + # keep central value of window (intersecting indices in gw_idxs and g_idxs) + scen_mbcn[{"time": ind_g}] = reordered[{"time": np.in1d(ind_gw, ind_g)}] + else: + scen_mbcn[{"time": ind_g}] = reordered + + return scen_mbcn.to_dataset(name="scen") + + +@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[]) +def qm_adjust( + ds: xr.Dataset, *, group: Grouper, interp: str, extrapolation: str, kind: str +) -> xr.Dataset: + """QM (DQM and EQM): Adjust step on one block. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + af : Adjustment factors + hist_q : Quantiles over the training data + sim : Data to adjust. + group : Grouper + The grouper object. + interp : str + The interpolation method to use. + extrapolation : str + The extrapolation method to use. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + + Returns + ------- + xr.Dataset + The adjusted data. + """ + af = u.interp_on_quantiles( + ds.sim, + ds.hist_q, + ds.af, + group=group, + method=interp, + extrapolation=extrapolation, + ) + + scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen") + out = scen.to_dataset() + return out + + +@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], trend=[]) +def dqm_adjust( + ds: xr.Dataset, + *, + group: Grouper, + interp: str, + kind: str, + extrapolation: str, + detrend: int | PolyDetrend, +) -> xr.Dataset: + """DQM adjustment on one block. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + scaling : Scaling factor between ref and hist + af : Adjustment factors + hist_q : Quantiles over the training data + sim : Data to adjust + group : Grouper + The grouper object. + interp : str + The interpolation method to use. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + extrapolation : str + The extrapolation method to use. + detrend : int | PolyDetrend + The degree of the polynomial detrending to apply. If 0, no detrending is applied. + + Returns + ------- + xr.Dataset + The adjusted data and the trend. 
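+
+    Notes
+    -----
+    The adjustment proceeds in four steps: ``sim`` is scaled by the group-wise ``scaling`` factor,
+    a trend is fitted and removed (a ``PolyDetrend`` of the given degree when ``detrend`` is an
+    int), the detrended series is adjusted with :py:func:`qm_adjust`, and the trend is added back.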
+ """ + scaled_sim = u.apply_correction( + ds.sim, + u.broadcast( + ds.scaling, + ds.sim, + group=group, + interp=interp if group.prop != "dayofyear" else "nearest", + ), + kind, + ).assign_attrs({"units": ds.sim.units}) + + if isinstance(detrend, int): + detrending = PolyDetrend(degree=detrend, kind=kind, group=group) + else: + detrending = detrend + + detrending = detrending.fit(scaled_sim) + ds["sim"] = detrending.detrend(scaled_sim) + scen = qm_adjust.func( + ds, + group=group, + interp=interp, + extrapolation=extrapolation, + kind=kind, + ).scen + scen = detrending.retrend(scen) + + out = xr.Dataset({"scen": scen, "trend": detrending.ds.trend}) + return out + + +@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], sim_q=[]) +def qdm_adjust(ds: xr.Dataset, *, group, interp, extrapolation, kind) -> xr.Dataset: + """QDM: Adjust process on one block. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + af : Adjustment factors + hist_q : Quantiles over the training data + sim : Data to adjust. + """ + sim_q = group.apply(u.rank, ds.sim, main_only=True, pct=True) + af = u.interp_on_quantiles( + sim_q, + ds.quantiles, + ds.af, + group=group, + method=interp, + extrapolation=extrapolation, + ) + scen = u.apply_correction(ds.sim, af, kind) + return xr.Dataset(dict(scen=scen, sim_q=sim_q)) + + +@map_blocks( + reduces=[Grouper.ADD_DIMS, Grouper.DIM], + af=[Grouper.PROP], + hist_thresh=[Grouper.PROP], +) +def loci_train(ds: xr.Dataset, *, group, thresh) -> xr.Dataset: + """LOCI: Train on one block. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data + """ + s_thresh = group.apply( + u.map_cdf, ds.rename(hist="x", ref="y"), y_value=thresh + ).isel(x=0) + sth = u.broadcast(s_thresh, ds.hist, group=group) + ws = xr.where(ds.hist >= sth, ds.hist, np.nan) + wo = xr.where(ds.ref >= thresh, ds.ref, np.nan) + + ms = group.apply("mean", ws, skipna=True) + mo = group.apply("mean", wo, skipna=True) + + # Adjustment factor + af = u.get_correction(ms - s_thresh, mo - thresh, u.MULTIPLICATIVE) + return xr.Dataset({"af": af, "hist_thresh": s_thresh}) + + +@map_blocks(reduces=[Grouper.PROP], scen=[]) +def loci_adjust(ds: xr.Dataset, *, group, thresh, interp) -> xr.Dataset: + """LOCI: Adjust on one block. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + hist_thresh : Hist's equivalent thresh from ref + sim : Data to adjust + """ + sth = u.broadcast(ds.hist_thresh, ds.sim, group=group, interp=interp) + factor = u.broadcast(ds.af, ds.sim, group=group, interp=interp) + with xr.set_options(keep_attrs=True): + scen: xr.DataArray = ( + (factor * (ds.sim - sth) + thresh).clip(min=0).rename("scen") + ) + out = scen.to_dataset() + return out + + +@map_groups(af=[Grouper.PROP]) +def scaling_train(ds: xr.Dataset, *, dim, kind) -> xr.Dataset: + """Scaling: Train on one group. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data + """ + mhist = ds.hist.mean(dim) + mref = ds.ref.mean(dim) + af: xr.DataArray = u.get_correction(mhist, mref, kind).rename("af") + out = af.to_dataset() + return out + + +@map_blocks(reduces=[Grouper.PROP], scen=[]) +def scaling_adjust(ds: xr.Dataset, *, group, interp, kind) -> xr.Dataset: + """Scaling: Adjust on one block. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + af : Adjustment factors. + sim : Data to adjust. 
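+
+    Notes
+    -----
+    The correction is ``scen = sim + af`` for the additive kind or ``scen = sim * af`` for the
+    multiplicative kind, with ``af`` broadcast from the group level back onto the time axis of
+    ``sim``.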
+ """ + af = u.broadcast(ds.af, ds.sim, group=group, interp=interp) + scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen") + out = scen.to_dataset() + return out + + +def npdf_transform(ds: xr.Dataset, **kwargs) -> xr.Dataset: + r"""N-pdf transform : Iterative univariate adjustment in random rotated spaces. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : Reference multivariate timeseries + hist : simulated timeseries on the reference period + sim : Simulated timeseries on the projected period. + rot_matrices : Random rotation matrices. + \*\*kwargs + pts_dim : multivariate dimension name + base : Adjustment class + base_kws : Kwargs for initialising the adjustment object + adj_kws : Kwargs of the `adjust` call + n_escore : Number of elements to include in the e_score test (0 for all, < 0 to skip). + + Returns + ------- + xr.Dataset + Dataset with `scenh`, `scens` and `escores` DataArrays, where `scenh` and `scens` are `hist` and `sim` + respectively after adjustment according to `ref`. If `n_escore` is negative, `escores` will be filled with NaNs. + """ + ref = ds.ref.rename(time_hist="time") + hist = ds.hist.rename(time_hist="time") + sim = ds.sim + dim = kwargs["pts_dim"] + + escores = [] + for i, R in enumerate(ds.rot_matrices.transpose("iterations", ...)): + # @ operator stands for matrix multiplication (along named dimensions): x@R = R@x + # @R rotates an array defined over dimension x unto new dimension x'. x@R = x' + refp = ref @ R + histp = hist @ R + simp = sim @ R + + # Perform univariate adjustment in rotated space (x') + ADJ = kwargs["base"].train( + refp, histp, **kwargs["base_kws"], skip_input_checks=True + ) + scenhp = ADJ.adjust(histp, **kwargs["adj_kws"], skip_input_checks=True) + scensp = ADJ.adjust(simp, **kwargs["adj_kws"], skip_input_checks=True) + + # Rotate back to original dimension x'@R = x + # Note that x'@R is a back rotation because the matrix multiplication is now done along x' due to xarray + # operating along named dimensions. + # In normal linear algebra, this is equivalent to taking @R.T, the back rotation. + hist = scenhp @ R + sim = scensp @ R + + # Compute score + if kwargs["n_escore"] >= 0: + escores.append( + escore( + ref, + hist, + dims=(dim, "time"), + N=kwargs["n_escore"], + scale=True, + ).expand_dims(iterations=[i]) + ) + + if kwargs["n_escore"] >= 0: + escores = xr.concat(escores, "iterations") + else: + # All NaN, but with the proper shape. + escores = ( + ref.isel({dim: 0, "time": 0}) * hist.isel({dim: 0, "time": 0}) + ).expand_dims(iterations=ds.iterations) * np.NaN + + return xr.Dataset( + data_vars={ + "scenh": hist.rename(time="time_hist").transpose(*ds.hist.dims), + "scen": sim.transpose(*ds.sim.dims), + "escores": escores, + } + ) + + +# TODO: incorporate xclim.stats +# def _fit_on_cluster(data, thresh, dist, cluster_thresh): +# """Extract clusters on 1D data and fit "dist" on the maximums.""" +# _, _, _, maximums = u.get_clusters_1d(data, thresh, cluster_thresh) +# params = list( +# _fitfunc_1d(maximums - thresh, dist=dist, floc=0, nparams=3, method="ML") +# ) +# # We forced 0, put back thresh. 
+# params[-2] = thresh +# return params + + +# def _extremes_train_1d(ref, hist, ref_params, *, q_thresh, cluster_thresh, dist, N): +# """Train for method ExtremeValues, only for 1D input along time.""" +# # Find quantile q_thresh +# thresh = ( +# np.quantile(ref[ref >= cluster_thresh], q_thresh) +# + np.quantile(hist[hist >= cluster_thresh], q_thresh) +# ) / 2 + +# # Fit genpareto on cluster maximums on ref (if needed) and hist. +# if np.isnan(ref_params).all(): +# ref_params = _fit_on_cluster(ref, thresh, dist, cluster_thresh) + +# hist_params = _fit_on_cluster(hist, thresh, dist, cluster_thresh) + +# # Find probabilities of extremes according to fitted dist +# Px_ref = dist.cdf(ref[ref >= thresh], *ref_params) +# hist = hist[hist >= thresh] +# Px_hist = dist.cdf(hist, *hist_params) + +# # Find common probabilities range. +# Pmax = min(Px_ref.max(), Px_hist.max()) +# Pmin = max(Px_ref.min(), Px_hist.min()) +# Pcommon = (Px_hist <= Pmax) & (Px_hist >= Pmin) +# Px_hist = Px_hist[Pcommon] + +# # Find values of hist extremes if they followed ref's distribution. +# hist_in_ref = dist.ppf(Px_hist, *ref_params) + +# # Adjustment factors, unsorted +# af = hist_in_ref / hist[Pcommon] +# # sort them in Px order, and pad to have N values. +# order = np.argsort(Px_hist) +# px_hist = np.pad(Px_hist[order], ((0, N - af.size),), constant_values=np.NaN) +# af = np.pad(af[order], ((0, N - af.size),), constant_values=np.NaN) + +# return px_hist, af, thresh + + +# @map_blocks( +# reduces=["time"], px_hist=["quantiles"], af=["quantiles"], thresh=[Grouper.PROP] +# ) +# def extremes_train( +# ds: xr.Dataset, +# *, +# group: Grouper, +# q_thresh: float, +# cluster_thresh: float, +# dist, +# quantiles: np.ndarray, +# ) -> xr.Dataset: +# """Train extremes for a given variable series. + +# Parameters +# ---------- +# ds : xr.Dataset +# Dataset containing the reference and historical data. +# group : Grouper +# The grouper object. +# q_thresh : float +# The quantile threshold to use. +# cluster_thresh : float +# The threshold for clustering. +# dist : Any +# The distribution to fit. +# quantiles : array-like +# The quantiles to compute. + +# Returns +# ------- +# xr.Dataset +# The dataset containing the quantiles, the adjustment factors, and the threshold. +# """ +# px_hist, af, thresh = xr.apply_ufunc( +# _extremes_train_1d, +# ds.ref, +# ds.hist, +# ds.ref_params or np.NaN, +# input_core_dims=[("time",), ("time",), ()], +# output_core_dims=[("quantiles",), ("quantiles",), ()], +# vectorize=True, +# kwargs={ +# "q_thresh": q_thresh, +# "cluster_thresh": cluster_thresh, +# "dist": dist, +# "N": len(quantiles), +# }, +# ) +# # Outputs of map_blocks must have dimensions. +# if not isinstance(thresh, xr.DataArray): +# thresh = xr.DataArray(thresh) +# thresh = thresh.expand_dims(group=[1]) +# return xr.Dataset( +# {"px_hist": px_hist, "af": af, "thresh": thresh}, +# coords={"quantiles": quantiles}, +# ) + + +# def _fit_cluster_and_cdf(data, thresh, dist, cluster_thresh): +# """Fit 1D cluster maximums and immediately compute CDF.""" +# fut_params = _fit_on_cluster(data, thresh, dist, cluster_thresh) +# return dist.cdf(data, *fut_params) + + +# @map_blocks(reduces=["quantiles", Grouper.PROP], scen=[]) +# def extremes_adjust( +# ds: xr.Dataset, +# *, +# group: Grouper, +# frac: float, +# power: float, +# dist, +# interp: str, +# extrapolation: str, +# cluster_thresh: float, +# ) -> xr.Dataset: +# """Adjust extremes to reflect many distribution factors. 
+ +# Parameters +# ---------- +# ds : xr.Dataset +# Dataset containing the reference and historical data. +# group : Grouper +# The grouper object. +# frac : float +# The fraction of the transition function. +# power : float +# The power of the transition function. +# dist : Any +# The distribution to fit. +# interp : str +# The interpolation method to use. +# extrapolation : str +# The extrapolation method to use. +# cluster_thresh : float +# The threshold for clustering. + +# Returns +# ------- +# xr.Dataset +# The dataset containing the adjusted data. +# """ +# # Find probabilities of extremes of fut according to its own cluster-fitted dist. +# px_fut = xr.apply_ufunc( +# _fit_cluster_and_cdf, +# ds.sim, +# ds.thresh, +# input_core_dims=[["time"], []], +# output_core_dims=[["time"]], +# kwargs={"dist": dist, "cluster_thresh": cluster_thresh}, +# vectorize=True, +# ) + +# # Find factors by interpolating from hist probs to fut probs. apply them. +# af = u.interp_on_quantiles( +# px_fut, ds.px_hist, ds.af, method=interp, extrapolation=extrapolation +# ) +# scen = u.apply_correction(ds.sim, af, "*") + +# # Smooth transition function between simulation and scenario. +# transition = ( +# ((ds.sim - ds.thresh) / ((ds.sim.max("time")) - ds.thresh)) / frac +# ) ** power +# transition = transition.clip(0, 1) + +# adjusted: xr.DataArray = (transition * scen) + ((1 - transition) * ds.scen) +# out = adjusted.rename("scen").squeeze("group", drop=True).to_dataset() +# return out diff --git a/src/xsdba/adjustment.py b/src/xsdba/adjustment.py new file mode 100644 index 0000000..f39daca --- /dev/null +++ b/src/xsdba/adjustment.py @@ -0,0 +1,1642 @@ +# pylint: disable=missing-kwoa +"""# noqa: SS01 +Adjustment Methods +================== +""" +from __future__ import annotations + +from inspect import signature +from typing import Any +from warnings import warn + +import numpy as np +import xarray as xr +from xarray.core.dataarray import DataArray + +from xsdba.base import get_calendar +from xsdba.formatting import gen_call_string, update_history +from xsdba.options import OPTIONS, SDBA_EXTRA_OUTPUT, set_options +from xsdba.units import convert_units_to +from xsdba.utils import uses_dask + +from ._adjustment import ( # extremes_adjust,; extremes_train, + dqm_adjust, + dqm_train, + eqm_train, + loci_adjust, + loci_train, + mbcn_adjust, + mbcn_train, + npdf_transform, + qdm_adjust, + qm_adjust, + scaling_adjust, + scaling_train, +) +from .base import Grouper, ParametrizableWithDataset, parse_group +from .processing import grouped_time_indexes +from .utils import ( + ADDITIVE, + best_pc_orientation_full, + best_pc_orientation_simple, + equally_spaced_nodes, + pc_matrix, + rand_rot_matrix, +) + +# from xclim.indices import stats + + +__all__ = [ + "LOCI", + "BaseAdjustment", + "DetrendedQuantileMapping", + "EmpiricalQuantileMapping", + # "ExtremeValues", + "MBCn", + "NpdfTransform", + "PrincipalComponents", + "QuantileDeltaMapping", + "Scaling", +] + + +class BaseAdjustment(ParametrizableWithDataset): + """Base class for adjustment objects. + + Children classes should implement the `train` and / or the `adjust` method. + + This base class defined the basic input and output checks. It should only be used for a real adjustment + if neither `TrainAdjust` nor `Adjust` fit the algorithm. 
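+
+    Two schemes derive from it: :py:class:`TrainAdjust`, where a ``train`` step produces a trained
+    object whose ``adjust`` method is then applied to ``sim``, and :py:class:`Adjust`, a single
+    ``adjust`` call taking ``ref``, ``hist`` and ``sim`` together.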
+ """ + + _allow_diff_calendars = True + _attribute = "_xclim_adjustment" + + def __init__(self, *args, _trained=False, **kwargs): + if _trained: + super().__init__(*args, **kwargs) + else: + raise ValueError( + "As of xclim 0.29, Adjustment object should be initialized through their `train` or `adjust` methods." + ) + + @classmethod + def _check_inputs(cls, *inputs, group): + """Raise an error if there are chunks along the main dimension. + + Also raises if :py:attr:`BaseAdjustment._allow_diff_calendars` is False and calendars differ. + """ + for inda in inputs: + if uses_dask(inda) and len(inda.chunks[inda.get_axis_num(group.dim)]) > 1: + raise ValueError( + f"Multiple chunks along the main adjustment dimension {group.dim} is not supported." + ) + + # All calendars used by the inputs + calendars = {get_calendar(inda, group.dim) for inda in inputs} + if not cls._allow_diff_calendars and len(calendars) > 1: + raise ValueError( + "Inputs are defined on different calendars," + f" this is not supported for {cls.__name__} adjustment." + ) + + # Check multivariate dimensions + mvcrds = [] + for inda in inputs: + for crd in inda.coords.values(): + if crd.attrs.get("is_variables", False): + mvcrds.append(crd) + if mvcrds and ( + not all(mvcrds[0].equals(mv) for mv in mvcrds[1:]) + or len(mvcrds) != len(inputs) + ): + coords = {mv.name for mv in mvcrds} + raise ValueError( + f"Inputs have different multivariate coordinates: {', '.join(coords)}." + ) + + if group.prop == "dayofyear" and ( + "default" in calendars or "standard" in calendars + ): + warn( + "Strange results could be returned when using `dayofyear` grouping " + "on data defined in the 'proleptic_gregorian' calendar." + ) + + @classmethod + def _harmonize_units(cls, *inputs, target: dict[str] | str | None = None): + """Convert all inputs to the same units. + + If the target unit is not given, the units of the first input are used. + + Returns the converted inputs and the target units. + """ + + def _harmonize_units_multivariate( + *inputs, dim, target: dict[str] | None = None + ): + def _convert_units_to(inda, dim, target): + varss = inda[dim].values + input_units = { + v: inda[dim].attrs["_units"][iv] for iv, v in enumerate(varss) + } + if input_units == target: + return inda + input_standard_names = { + v: inda[dim].attrs["_standard_name"][iv] + for iv, v in enumerate(varss) + } + for iv, v in enumerate(varss): + inda.attrs["units"] = input_units[v] + inda.attrs["standard_name"] = input_standard_names[v] + inda[{dim: iv}] = convert_units_to(inda[{dim: iv}], target[v]) + inda[dim].attrs["_units"][iv] = target[v] + inda.attrs["units"] = "" + inda.attrs.pop("standard_name") + return inda + + if target is None: + if "_units" not in inputs[0][dim].attrs or any( + [u is None for u in inputs[0][dim].attrs["_units"]] + ): + error_msg = ( + "Units are missing in some or all of the stacked variables." + "The dataset stacked with `stack_variables` given as input should include units for every variable." 
+ ) + raise ValueError(error_msg) + + target = { + v: inputs[0][dim].attrs["_units"][iv] + for iv, v in enumerate(inputs[0][dim].values) + } + return ( + _convert_units_to(inda, dim=dim, target=target) for inda in inputs + ), target + + for _dim, _crd in inputs[0].coords.items(): + if _crd.attrs.get("is_variables"): + return _harmonize_units_multivariate(*inputs, dim=_dim, target=target) + + if target is None: + target = inputs[0].units + + return (convert_units_to(inda, target) for inda in inputs), target + + @classmethod + def _train(cls, ref, hist, **kwargs): + raise NotImplementedError() + + def _adjust(self, sim, *args, **kwargs): + raise NotImplementedError() + + +class TrainAdjust(BaseAdjustment): + """Base class for adjustment objects obeying the train-adjust scheme. + + Children classes should implement these methods: + + - ``_train(ref, hist, **kwargs)``, classmethod receiving the training target and data, + returning a training dataset and parameters to store in the object. + + - ``_adjust(sim, **kwargs)``, receiving the projected data and some arguments, + returning the `scen` DataArray. + """ + + _allow_diff_calendars = True + _attribute = "_xclim_adjustment" + _repr_hide_params = ["hist_calendar", "train_units"] + + @classmethod + def train(cls, ref: DataArray, hist: DataArray, **kwargs) -> TrainAdjust: + r"""Train the adjustment object. + + Refer to the class documentation for the algorithm details. + + Parameters + ---------- + ref : DataArray + Training target, usually a reference time series drawn from observations. + hist : DataArray + Training data, usually a model output whose biases are to be adjusted. + \*\*kwargs + Algorithm-specific keyword arguments, see class doc. + """ + kwargs = parse_group(cls._train, kwargs) + skip_checks = kwargs.pop("skip_input_checks", False) + + if not skip_checks: + if "group" in kwargs: + cls._check_inputs(ref, hist, group=kwargs["group"]) + + (ref, hist), train_units = cls._harmonize_units(ref, hist) + else: + train_units = "" + + ds, params = cls._train(ref, hist, **kwargs) + obj = cls( + _trained=True, + hist_calendar=get_calendar(hist), + train_units=train_units, + **params, + ) + obj.set_dataset(ds) + return obj + + def adjust(self, sim: DataArray, *args, **kwargs): + r"""Return bias-adjusted data. + + Refer to the class documentation for the algorithm details. + + Parameters + ---------- + sim : DataArray + Time series to be bias-adjusted, usually a model output. + \*args : xr.DataArray + Other DataArrays needed for the adjustment (usually none). + \*\*kwargs + Algorithm-specific keyword arguments, see class doc. 
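+
+        A minimal sketch of the intended call pattern, using :py:class:`EmpiricalQuantileMapping`
+        as an example (illustrative only; ``ref``, ``hist`` and ``sim`` are assumed to be
+        unit-annotated DataArrays sharing a ``time`` dimension)::
+
+            ADJ = EmpiricalQuantileMapping.train(ref, hist, nquantiles=50, group="time")
+            scen = ADJ.adjust(sim, interp="linear", extrapolation="constant")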
+ """ + skip_checks = kwargs.pop("skip_input_checks", False) + if not skip_checks: + if "group" in self: + self._check_inputs(sim, *args, group=self.group) + + (sim, *args), _ = self._harmonize_units(sim, *args, target=self.train_units) + + out = self._adjust(sim, *args, **kwargs) + + if isinstance(out, xr.DataArray): + out = out.rename("scen").to_dataset() + + scen = out.scen + + # Keep attrs + scen.attrs.update(sim.attrs) + for name, crd in sim.coords.items(): + if name in scen.coords: + scen[name].attrs.update(crd.attrs) + params = gen_call_string("", **kwargs)[1:-1] # indexing to remove added ( ) + infostr = f"{self!s}.adjust(sim, {params})" + scen.attrs["history"] = update_history(f"Bias-adjusted with {infostr}", sim) + scen.attrs["bias_adjustment"] = infostr + + _is_multivariate = any( + [_crd.attrs.get("is_variables") for _crd in sim.coords.values()] + ) + if _is_multivariate is False: + scen.attrs["units"] = self.train_units + + if OPTIONS[SDBA_EXTRA_OUTPUT]: + return out + return scen + + def set_dataset(self, ds: xr.Dataset): + """Store an xarray dataset in the `ds` attribute. + + Useful with custom object initialization or if some external processing was performed. + """ + super().set_dataset(ds) + self.ds.attrs["adj_params"] = str(self) + + @classmethod + def _train(cls, ref: DataArray, hist: DataArray, *kwargs): + raise NotImplementedError() + + def _adjust(self, sim, **kwargs): + raise NotImplementedError() + + +class Adjust(BaseAdjustment): + """Adjustment with no intermediate trained object. + + Children classes should implement a `_adjust` classmethod taking as input the three DataArrays + and returning the scen dataset/array. + """ + + @classmethod + def adjust( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + sim: xr.DataArray, + **kwargs, + ) -> xr.Dataset: + r"""Return bias-adjusted data. Refer to the class documentation for the algorithm details. + + Parameters + ---------- + ref : DataArray + Training target, usually a reference time series drawn from observations. + hist : DataArray + Training data, usually a model output whose biases are to be adjusted. + sim : DataArray + Time series to be bias-adjusted, usually a model output. + \*\*kwargs + Algorithm-specific keyword arguments, see class doc. + + Returns + ------- + xr.Dataset + The bias-adjusted Dataset. + """ + kwargs = parse_group(cls._adjust, kwargs) + skip_checks = kwargs.pop("skip_input_checks", False) + + if not skip_checks: + if "group" in kwargs: + cls._check_inputs(ref, hist, sim, group=kwargs["group"]) + + (ref, hist, sim), _ = cls._harmonize_units(ref, hist, sim) + + out: xr.Dataset | xr.DataArray = cls._adjust(ref, hist, sim, **kwargs) + + if isinstance(out, xr.DataArray): + out = out.rename("scen").to_dataset() + + scen = out.scen + + params = ", ".join([f"{k}={v!r}" for k, v in kwargs.items()]) + infostr = f"{cls.__name__}.adjust(ref, hist, sim, {params})" + scen.attrs["history"] = update_history(f"Bias-adjusted with {infostr}", sim) + scen.attrs["bias_adjustment"] = infostr + + _is_multivariate = any( + [_crd.attrs.get("is_variables") for _crd in sim.coords.values()] + ) + if _is_multivariate is False: + scen.attrs["units"] = ref.units + + if OPTIONS[SDBA_EXTRA_OUTPUT]: + return out + return scen + + +class EmpiricalQuantileMapping(TrainAdjust): + """Empirical Quantile Mapping bias-adjustment. + + Adjustment factors are computed between the quantiles of `ref` and `sim`. + Values of `sim` are matched to the corresponding quantiles of `hist` and corrected accordingly. + + .. 
math:: + + F^{-1}_{ref} (F_{hist}(sim)) + + where :math:`F` is the cumulative distribution function (CDF) and `mod` stands for model data. + + Attributes + ---------- + Train step + + nquantiles : int or 1d array of floats + The number of quantiles to use. Two endpoints at 1e-6 and 1 - 1e-6 will be added. + An array of quantiles [0, 1] can also be passed. Defaults to 20 quantiles. + kind : {'+', '*'} + The adjustment kind, either additive or multiplicative. Defaults to "+". + group : Union[str, Grouper] + The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. + Default is "time", meaning an single adjustment group along dimension "time". + adapt_freq_thresh : str | None + Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. + Default is None, meaning that frequency adaptation is not performed. + + Adjust step: + + interp : {'nearest', 'linear', 'cubic'} + The interpolation method to use when interpolating the adjustment factors. Defaults to "nearest". + extrapolation : {'constant', 'nan'} + The type of extrapolation to use. See :py:func:`xclim.sdba.utils.extrapolate_qm` for details. Defaults to "constant". + + References + ---------- + :cite:cts:`sdba-deque_frequency_2007` + """ + + _allow_diff_calendars = False + + @classmethod + def _train( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + *, + nquantiles: int | np.ndarray = 20, + kind: str = ADDITIVE, + group: str | Grouper = "time", + adapt_freq_thresh: str | None = None, + jitter_under_thresh_value: str | None = None, + ) -> tuple[xr.Dataset, dict[str, Any]]: + if np.isscalar(nquantiles): + quantiles = equally_spaced_nodes(nquantiles).astype(ref.dtype) + else: + quantiles = nquantiles.astype(ref.dtype) + + ds = eqm_train( + xr.Dataset({"ref": ref, "hist": hist}), + group=group, + kind=kind, + quantiles=quantiles, + adapt_freq_thresh=adapt_freq_thresh, + jitter_under_thresh_value=jitter_under_thresh_value, + ) + + ds.af.attrs.update( + standard_name="Adjustment factors", + long_name="Quantile mapping adjustment factors", + ) + ds.hist_q.attrs.update( + standard_name="Model quantiles", + long_name="Quantiles of model on the reference period", + ) + return ds, {"group": group, "kind": kind} + + def _adjust(self, sim, interp="nearest", extrapolation="constant"): + return qm_adjust( + xr.Dataset({"af": self.ds.af, "hist_q": self.ds.hist_q, "sim": sim}), + group=self.group, + interp=interp, + extrapolation=extrapolation, + kind=self.kind, + ).scen + + +class DetrendedQuantileMapping(TrainAdjust): + r"""Detrended Quantile Mapping bias-adjustment. + + The algorithm follows these steps, 1-3 being the 'train' and 4-6, the 'adjust' steps. + + 1. A scaling factor that would make the mean of `hist` match the mean of `ref` is computed. + 2. `ref` and `hist` are normalized by removing the "dayofyear" mean. + 3. Adjustment factors are computed between the quantiles of the normalized `ref` and `hist`. + 4. `sim` is corrected by the scaling factor, and either normalized by "dayofyear" and detrended group-wise + or directly detrended per "dayofyear", using a linear fit (modifiable). + 5. Values of detrended `sim` are matched to the corresponding quantiles of normalized `hist` and corrected accordingly. + 6. The trend is put back on the result. + + .. 
math:: + + F^{-1}_{ref}\left\{F_{hist}\left[\frac{\overline{hist}\cdot sim}{\overline{sim}}\right]\right\}\frac{\overline{sim}}{\overline{hist}} + + where :math:`F` is the cumulative distribution function (CDF) and :math:`\overline{xyz}` is the linear trend of the data. + This equation is valid for multiplicative adjustment. Based on the DQM method of :cite:p:`sdba-cannon_bias_2015`. + + Parameters + ---------- + Train step: + + nquantiles : int or 1d array of floats + The number of quantiles to use. See :py:func:`~xclim.sdba.utils.equally_spaced_nodes`. + An array of quantiles [0, 1] can also be passed. Defaults to 20 quantiles. + kind : {'+', '*'} + The adjustment kind, either additive or multiplicative. Defaults to "+". + group : Union[str, Grouper] + The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. + Default is "time", meaning a single adjustment group along dimension "time". + adapt_freq_thresh : str | None + Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. + Default is None, meaning that frequency adaptation is not performed. + + Adjust step: + + interp : {'nearest', 'linear', 'cubic'} + The interpolation method to use when interpolating the adjustment factors. Defaults to "nearest". + detrend : int or BaseDetrend instance + The method to use when detrending. If an int is passed, it is understood as a PolyDetrend (polynomial detrending) degree. + Defaults to 1 (linear detrending). + extrapolation : {'constant', 'nan'} + The type of extrapolation to use. See :py:func:`xclim.sdba.utils.extrapolate_qm` for details. Defaults to "constant". + + References + ---------- + :cite:cts:`sdba-cannon_bias_2015` + """ + + _allow_diff_calendars = False + + @classmethod + def _train( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + *, + nquantiles: int | np.ndarray = 20, + kind: str = ADDITIVE, + group: str | Grouper = "time", + adapt_freq_thresh: str | None = None, + jitter_under_thresh_value: str | None = None, + ): + if group.prop not in ["group", "dayofyear"]: + warn( + f"Using DQM with a grouping other than 'dayofyear' is not recommended (received {group.name})." + ) + + if np.isscalar(nquantiles): + quantiles = equally_spaced_nodes(nquantiles).astype(ref.dtype) + else: + quantiles = nquantiles.astype(ref.dtype) + + ds = dqm_train( + xr.Dataset({"ref": ref, "hist": hist}), + group=group, + quantiles=quantiles, + kind=kind, + adapt_freq_thresh=adapt_freq_thresh, + jitter_under_thresh_value=jitter_under_thresh_value, + ) + + ds.af.attrs.update( + standard_name="Adjustment factors", + long_name="Quantile mapping adjustment factors", + ) + ds.hist_q.attrs.update( + standard_name="Model quantiles", + long_name="Quantiles of model on the reference period", + ) + ds.scaling.attrs.update( + standard_name="Scaling factor", + description="Scaling factor making the mean of hist match the one of hist.", + ) + return ds, {"group": group, "kind": kind} + + def _adjust( + self, + sim, + interp="nearest", + extrapolation="constant", + detrend=1, + ): + scen = dqm_adjust( + self.ds.assign(sim=sim), + interp=interp, + extrapolation=extrapolation, + detrend=detrend, + group=self.group, + kind=self.kind, + ).scen + # Detrending needs units. + scen.attrs["units"] = sim.units + return scen + + +class QuantileDeltaMapping(EmpiricalQuantileMapping): + r"""Quantile Delta Mapping bias-adjustment. + + Adjustment factors are computed between the quantiles of `ref` and `hist`. 
+ Quantiles of `sim` are matched to the corresponding quantiles of `hist` and corrected accordingly. + + .. math:: + + sim\frac{F^{-1}_{ref}\left[F_{sim}(sim)\right]}{F^{-1}_{hist}\left[F_{sim}(sim)\right]} + + where :math:`F` is the cumulative distribution function (CDF). This equation is valid for multiplicative adjustment. + The algorithm is based on the "QDM" method of :cite:p:`sdba-cannon_bias_2015`. + + Parameters + ---------- + Train step: + + nquantiles : int or 1d array of floats + The number of quantiles to use. See :py:func:`~xclim.sdba.utils.equally_spaced_nodes`. + An array of quantiles [0, 1] can also be passed. Defaults to 20 quantiles. + kind : {'+', '*'} + The adjustment kind, either additive or multiplicative. Defaults to "+". + group : Union[str, Grouper] + The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. + Default is "time", meaning a single adjustment group along dimension "time". + + Adjust step: + + interp : {'nearest', 'linear', 'cubic'} + The interpolation method to use when interpolating the adjustment factors. Defaults to "nearest". + extrapolation : {'constant', 'nan'} + The type of extrapolation to use. See :py:func:`xclim.sdba.utils.extrapolate_qm` for details. Defaults to "constant". + + Extra diagnostics + ----------------- + In adjustment: + + quantiles : The quantile of each value of `sim`. The adjustment factor is interpolated using this as the "quantile" axis on `ds.af`. + + References + ---------- + :cite:cts:`sdba-cannon_bias_2015` + """ + + def _adjust(self, sim, interp="nearest", extrapolation="constant"): + out = qdm_adjust( + xr.Dataset({"sim": sim, "af": self.ds.af, "hist_q": self.ds.hist_q}), + group=self.group, + interp=interp, + extrapolation=extrapolation, + kind=self.kind, + ) + if OPTIONS[SDBA_EXTRA_OUTPUT]: + out.sim_q.attrs.update(long_name="Group-wise quantiles of `sim`.") + return out + return out.scen + + +# class ExtremeValues(TrainAdjust): +# r"""Adjustment correction for extreme values. + +# The tail of the distribution of adjusted data is corrected according to the bias between the parametric Generalized +# Pareto distributions of the simulated and reference data :cite:p:`sdba-roy_extremeprecip_2023`. The distributions are composed of the +# maximal values of clusters of "large" values. With "large" values being those above `cluster_thresh`. Only extreme +# values, whose quantile within the pool of large values are above `q_thresh`, are re-adjusted. See `Notes`. + +# This adjustment method should be considered experimental and used with care. + +# Parameters +# ---------- +# Train step : + +# cluster_thresh : Quantity (str with units) +# The threshold value for defining clusters. +# q_thresh : float +# The quantile of "extreme" values, [0, 1[. Defaults to 0.95. +# ref_params : xr.DataArray, optional +# Distribution parameters to use instead of fitting a GenPareto distribution on `ref`. + +# Adjust step: + +# scen : DataArray +# This is a second-order adjustment, so the adjust method needs the first-order +# adjusted timeseries in addition to the raw "sim". +# interp : {'nearest', 'linear', 'cubic'} +# The interpolation method to use when interpolating the adjustment factors. Defaults to "linear". +# extrapolation : {'constant', 'nan'} +# The type of extrapolation to use. See :py:func:`~xclim.sdba.utils.extrapolate_qm` for details. Defaults to "constant". +# frac : float +# Fraction where the cutoff happens between the original scen and the corrected one. +# See Notes, ]0, 1]. Defaults to 0.25. 
+# power : float +# Shape of the correction strength, see Notes. Defaults to 1.0. + +# Notes +# ----- +# Extreme values are extracted from `ref`, `hist` and `sim` by finding all "clusters", i.e. runs of consecutive values +# above `cluster_thresh`. The `q_thresh`th percentile of these values is taken on `ref` and `hist` and becomes +# `thresh`, the extreme value threshold. The maximal value of each cluster, if it exceeds that new threshold, is taken +# and Generalized Pareto distributions are fitted to them, for both `ref` and `hist`. The probabilities associated +# with each of these extremes in `hist` is used to find the corresponding value according to `ref`'s distribution. +# Adjustment factors are computed as the bias between those new extremes and the original ones. + +# In the adjust step, a Generalized Pareto distributions is fitted on the cluster-maximums of `sim` and it is used to +# associate a probability to each extreme, values over the `thresh` compute in the training, without the clustering. +# The adjustment factors are computed by interpolating the trained ones using these probabilities and the +# probabilities computed from `hist`. + +# Finally, the adjusted values (:math:`C_i`) are mixed with the pre-adjusted ones (`scen`, :math:`D_i`) using the +# following transition function: + +# .. math:: + +# V_i = C_i * \tau + D_i * (1 - \tau) + +# Where :math:`\tau` is a function of sim's extreme values (unadjusted, :math:`S_i`) +# and of arguments ``frac`` (:math:`f`) and ``power`` (:math:`p`): + +# .. math:: + +# \tau = \left(\frac{1}{f}\frac{S - min(S)}{max(S) - min(S)}\right)^p + +# Code based on an internal Matlab source and partly ib the `biascorrect_extremes` function of the julia package +# "ClimateTools.jl" :cite:p:`sdba-roy_juliaclimateclimatetoolsjl_2021`. + +# Because of limitations imposed by the lazy computing nature of the dask backend, it +# is not possible to know the number of cluster extremes in `ref` and `hist` at the +# moment the output data structure is created. This is why the code tries to estimate +# that number and usually overestimates it. In the training dataset, this translated +# into a `quantile` dimension that is too large and variables `af` and `px_hist` are +# assigned NaNs on extra elements. This has no incidence on the calculations +# themselves but requires more memory than is useful. + +# References +# ---------- +# :cite:cts:`sdba-roy_juliaclimateclimatetoolsjl_2021` +# :cite:cts:`sdba-roy_extremeprecip_2023` +# """ + +# @classmethod +# def _train( +# cls, +# ref: xr.DataArray, +# hist: xr.DataArray, +# *, +# cluster_thresh: str, +# ref_params: xr.Dataset | None = None, +# q_thresh: float = 0.95, +# ): +# cluster_thresh = convert_units_to(cluster_thresh, ref, context="infer") + +# # Approximation of how many "quantiles" values we will get: +# N = (1 - q_thresh) * ref.time.size * 1.05 # extra padding for safety + +# # ref_params: cast nan to f32 not to interfere with map_blocks dtype parsing +# # ref and hist are f32, we want to have f32 in the output. 
+# ds = extremes_train( +# xr.Dataset( +# { +# "ref": ref, +# "hist": hist, +# "ref_params": ref_params or np.float32(np.NaN), +# } +# ), +# q_thresh=q_thresh, +# cluster_thresh=cluster_thresh, +# dist=stats.get_dist("genpareto"), +# quantiles=np.arange(int(N)), +# group="time", +# ) + +# ds.px_hist.attrs.update( +# long_name="Probability of extremes in hist", +# description="Parametric probabilities of extremes in the common domain of hist and ref.", +# ) +# ds.af.attrs.update( +# long_name="Extremes adjustment factor", +# description="Multiplicative adjustment factor of extremes from hist to ref.", +# ) +# ds.thresh.attrs.update( +# long_name=f"{q_thresh * 100}th percentile extreme value threshold", +# description=f"Mean of the {q_thresh * 100}th percentile of large values (x > {cluster_thresh}) of ref and hist.", +# ) + +# return ds.drop_vars(["quantiles"]), {"cluster_thresh": cluster_thresh} + +# def _adjust( +# self, +# sim: xr.DataArray, +# scen: xr.DataArray, +# *, +# frac: float = 0.25, +# power: float = 1.0, +# interp: str = "linear", +# extrapolation: str = "constant", +# ): +# # Quantiles coord : cheat and assign 0 - 1, so we can use `extrapolate_qm`. +# ds = self.ds.assign( +# quantiles=(np.arange(self.ds.quantiles.size) + 1) +# / (self.ds.quantiles.size + 1) +# ) + +# scen = extremes_adjust( +# ds.assign(sim=sim, scen=scen), +# cluster_thresh=self.cluster_thresh, +# dist=stats.get_dist("genpareto"), +# frac=frac, +# power=power, +# interp=interp, +# extrapolation=extrapolation, +# group="time", +# ) + +# return scen + + +class LOCI(TrainAdjust): + r"""Local Intensity Scaling (LOCI) bias-adjustment. + + This bias adjustment method is designed to correct daily precipitation time series by considering wet and dry days + separately :cite:p:`sdba-schmidli_downscaling_2006`. + + Multiplicative adjustment factors are computed such that the mean of `hist` matches the mean of `ref` for values + above a threshold. + + The threshold on the training target `ref` is first mapped to `hist` by finding the quantile in `hist` having the same + exceedance probability as thresh in `ref`. The adjustment factor is then given by + + .. math:: + + s = \frac{\left \langle ref: ref \geq t_{ref} \right\rangle - t_{ref}}{\left \langle hist : hist \geq t_{hist} \right\rangle - t_{hist}} + + In the case of precipitations, the adjustment factor is the ratio of wet-days intensity. + + For an adjustment factor `s`, the bias-adjustment of `sim` is: + + .. math:: + + sim(t) = \max\left(t_{ref} + s \cdot (hist(t) - t_{hist}), 0\right) + + Attributes + ---------- + Train step: + + group : Union[str, Grouper] + The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. + Default is "time", meaning a single adjustment group along dimension "time". + thresh : str + The threshold in `ref` above which the values are scaled. + + Adjust step: + + interp : {'nearest', 'linear', 'cubic'} + The interpolation method to use then interpolating the adjustment factors. Defaults to "linear". 
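+
+    For instance, with :math:`s = 1.2`, :math:`t_{ref} = 1` mm/d and :math:`t_{hist} = 0.8` mm/d
+    (illustrative numbers), a simulated value of 3 mm/d is mapped to
+    :math:`\max(1 + 1.2 \cdot (3 - 0.8), 0) = 3.64` mm/d.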
+ + References + ---------- + :cite:cts:`sdba-schmidli_downscaling_2006` + """ + + _allow_diff_calendars = False + + @classmethod + def _train( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + *, + thresh: str, + group: str | Grouper = "time", + ): + thresh = convert_units_to(thresh, ref) + ds = loci_train( + xr.Dataset({"ref": ref, "hist": hist}), group=group, thresh=thresh + ) + ds.af.attrs.update(long_name="LOCI adjustment factors") + ds.hist_thresh.attrs.update(long_name="Threshold over modeled data") + return ds, {"group": group, "thresh": thresh} + + def _adjust(self, sim, interp="linear"): + return loci_adjust( + xr.Dataset( + {"hist_thresh": self.ds.hist_thresh, "af": self.ds.af, "sim": sim} + ), + group=self.group, + thresh=self.thresh, + interp=interp, + ).scen + + +class Scaling(TrainAdjust): + """Scaling bias-adjustment. + + Simple bias-adjustment method scaling variables by an additive or multiplicative factor so that the mean of `hist` + matches the mean of `ref`. + + Parameters + ---------- + Train step: + + group : Union[str, Grouper] + The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. + Default is "time", meaning an single adjustment group along dimension "time". + kind : {'+', '*'} + The adjustment kind, either additive or multiplicative. Defaults to "+". + + Adjust step: + + interp : {'nearest', 'linear', 'cubic'} + The interpolation method to use then interpolating the adjustment factors. Defaults to "nearest". + """ + + _allow_diff_calendars = False + + @classmethod + def _train( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + *, + group: str | Grouper = "time", + kind: str = ADDITIVE, + ): + ds = scaling_train( + xr.Dataset({"ref": ref, "hist": hist}), group=group, kind=kind + ) + ds.af.attrs.update(long_name="Scaling adjustment factors") + return ds, {"group": group, "kind": kind} + + def _adjust(self, sim, interp="nearest"): + return scaling_adjust( + xr.Dataset({"sim": sim, "af": self.ds.af}), + group=self.group, + interp=interp, + kind=self.kind, + ).scen + + +class PrincipalComponents(TrainAdjust): + r"""Principal component adjustment. + + This bias-correction method maps model simulation values to the observation space through principal components + :cite:p:`sdba-hnilica_multisite_2017`. Values in the simulation space (multiple variables, or multiple sites) can be + thought of as coordinate along axes, such as variable, temperature, etc. Principal components (PC) are a + linear combinations of the original variables where the coefficients are the eigenvectors of the covariance matrix. + Values can then be expressed as coordinates along the PC axes. The method makes the assumption that bias-corrected + values have the same coordinates along the PC axes of the observations. By converting from the observation PC space + to the original space, we get bias corrected values. See `Notes` for a mathematical explanation. + + Warnings + -------- + Be aware that *principal components* is meant here as the algebraic operation defining a coordinate system + based on the eigenvectors, not statistical principal component analysis. + + Attributes + ---------- + group : Union[str, Grouper] + The main dimension and grouping information. See Notes. + See :py:class:`xclim.sdba.base.Grouper` for details. + The adjustment will be performed on each group independently. + Default is "time", meaning a single adjustment group along dimension "time". 
+ best_orientation : {'simple', 'full'} + Which method to use when searching for the best principal component orientation. + See :py:func:`~xclim.sdba.utils.best_pc_orientation_simple` and + :py:func:`~xclim.sdba.utils.best_pc_orientation_full`. + "full" is more precise, but it is much slower. + crd_dim : str + The data dimension along which the multiple simulation space dimensions are taken. + For a multivariate adjustment, this usually is "multivar", as returned by `sdba.stack_variables`. + For a multisite adjustment, this should be the spatial dimension. + The training algorithm currently doesn't support any chunking + along either `crd_dim`. `group.dim` and `group.add_dims`. + + Notes + ----- + The input data is understood as a set of N points in a :math:`M`-dimensional space. + + - :math:`M` is taken along `crd_dim`. + + - :math:`N` is taken along the dimensions given through `group` : (the main `dim` but also, if requested, the `add_dims` and `window`). + + The principal components (PC) of `hist` and `ref` are used to defined new coordinate systems, centered on their + respective means. The training step creates a matrix defining the transformation from `hist` to `ref`: + + .. math:: + + scen = e_{R} + \mathrm{\mathbf{T}}(sim - e_{H}) + + Where: + + .. math:: + + \mathrm{\mathbf{T}} = \mathrm{\mathbf{R}}\mathrm{\mathbf{H}}^{-1} + + :math:`\mathrm{\mathbf{R}}` is the matrix transforming from the PC coordinates computed on `ref` to the data + coordinates. Similarly, :math:`\mathrm{\mathbf{H}}` is transform from the `hist` PC to the data coordinates + (:math:`\mathrm{\mathbf{H}}` is the inverse transformation). :math:`e_R` and :math:`e_H` are the centroids of the + `ref` and `hist` distributions respectively. Upon running the `adjust` step, one may decide to use :math:`e_S`, + the centroid of the `sim` distribution, instead of :math:`e_H`. + + References + ---------- + :cite:cts:`sdba-hnilica_multisite_2017,sdba-alavoine_distinct_2022` + """ + + @classmethod + def _train( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + *, + crd_dim: str, + best_orientation: str = "simple", + group: str | Grouper = "time", + ): + all_dims = set(ref.dims + hist.dims) + + # Dimension name for the "points" + lblP = xr.core.utils.get_temp_dimname(all_dims, "points") + + # Rename coord on ref, multiindex do not like conflicting coordinates names + lblM = crd_dim + lblR = xr.core.utils.get_temp_dimname(ref.dims, lblM + "_out") + ref = ref.rename({lblM: lblR}) + + # The real thing, acting on 2D numpy arrays + def _compute_transform_matrix(reference, historical): + """Return the transformation matrix converting simulation coordinates to observation coordinates.""" + # Get transformation matrix from PC coords to ref, dropping points with a NaN coord. + ref_na = np.isnan(reference).any(axis=0) + R = pc_matrix(reference[:, ~ref_na]) + # Get transformation matrix from PC coords to hist, dropping points with a NaN coord. + hist_na = np.isnan(historical).any(axis=0) + H = pc_matrix(historical[:, ~hist_na]) + # This step needs vectorize with dask, but vectorize doesn't work with dask, argh. + # Invert to get transformation matrix from hist to PC coords. + Hinv = np.linalg.inv(H) + # Fancy tricks to choose the best orientation on each axis. 
+ # (using eigenvectors, the output axes orientation is undefined) + if best_orientation == "simple": + orient = best_pc_orientation_simple(R, Hinv) + elif best_orientation == "full": + orient = best_pc_orientation_full( + R, Hinv, reference.mean(axis=1), historical.mean(axis=1), historical + ) + else: + raise ValueError( + f"Unknown `best_orientation` method: {best_orientation}." + ) + # Get transformation matrix + return (R * orient) @ Hinv + + # The group wrapper + def _compute_transform_matrices(ds, dim): + """Apply `_compute_transform_matrix` along dimensions other than time and the variables to map.""" + # The multiple PC-space dimensions are along lblR and lblM + # Matrix multiplication in xarray behaves as a dot product across + # same-name dimensions, instead of reducing according to the dimension order, + # as in numpy or normal maths. + if len(dim) > 1: + reference = ds.ref.stack({lblP: dim}) + historical = ds.hist.stack({lblP: dim}) + else: + reference = ds.ref.rename({dim[0]: lblP}) + historical = ds.hist.rename({dim[0]: lblP}) + transformation = xr.apply_ufunc( + _compute_transform_matrix, + reference, + historical, + input_core_dims=[[lblR, lblP], [lblM, lblP]], + output_core_dims=[[lblR, lblM]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + ) + return transformation + + # Transformation matrix, from model coords to ref coords. + trans = group.apply(_compute_transform_matrices, {"ref": ref, "hist": hist}) + trans.attrs.update(long_name="Transformation from training to target spaces.") + + ref_mean = group.apply("mean", ref) # Centroids of ref + ref_mean.attrs.update(long_name="Centroid point of target.") + + hist_mean = group.apply("mean", hist) # Centroids of hist + hist_mean.attrs.update(long_name="Centroid point of training.") + + ds = xr.Dataset(dict(trans=trans, ref_mean=ref_mean, hist_mean=hist_mean)) + + ds.attrs["_reference_coord"] = lblR + ds.attrs["_model_coord"] = lblM + return ds, {"group": group} + + def _adjust(self, sim): + lblR = self.ds.attrs["_reference_coord"] + lblM = self.ds.attrs["_model_coord"] + + vmean = self.group.apply("mean", sim) + + def _compute_adjust(ds, dim): + """Apply the mapping transformation.""" + scenario = ds.ref_mean + ds.trans.dot((ds.sim - ds.vmean), [lblM]) + return scenario + + scen = ( + self.group.apply( + _compute_adjust, + { + "ref_mean": self.ds.ref_mean, + "trans": self.ds.trans, + "sim": sim, + "vmean": vmean, + }, + main_only=True, + ) + .rename({lblR: lblM}) + .rename("scen") + ) + return scen + + +class NpdfTransform(Adjust): + r"""N-dimensional probability density function transform. + + This adjustment object combines both training and adjust steps in the `adjust` class method. + + A multivariate bias-adjustment algorithm described by :cite:t:`sdba-cannon_multivariate_2018`, as part of the MBCn + algorithm, based on a color-correction algorithm described by :cite:t:`sdba-pitie_n-dimensional_2005`. + + This algorithm in itself, when used with QuantileDeltaMapping, is NOT trend-preserving. + The full MBCn algorithm includes a reordering step provided here by :py:func:`xclim.sdba.processing.reordering`. + + See notes for an explanation of the algorithm. + + Parameters + ---------- + base : BaseAdjustment + An univariate bias-adjustment class. This is untested for anything else than QuantileDeltaMapping. + base_kws : dict, optional + Arguments passed to the training of the univariate adjustment. + n_escore : int + The number of elements to send to the escore function. 
The default, 0, means all elements are included.
+        Pass -1 to skip computing the escore completely.
+        Small numbers result in less significant scores, but the execution time goes up quickly with large values.
+    n_iter : int
+        The number of iterations to perform. Defaults to 20.
+    pts_dim : str
+        The name of the "multivariate" dimension. Defaults to "multivar", which is the
+        normal case when using :py:func:`xclim.sdba.base.stack_variables`.
+    adj_kws : dict, optional
+        Dictionary of arguments to pass to the adjust method of the univariate adjustment.
+    rot_matrices : xr.DataArray, optional
+        The rotation matrices as a 3D array with dimensions ('iterations', <pts_dim>, <anything>), i.e. with shape
+        (n_iter, N, N), where N is the size of the ``pts_dim`` dimension.
+        If left empty, random rotation matrices will be automatically generated.
+
+    Notes
+    -----
+    The historical reference (:math:`T`, for "target"), simulated historical (:math:`H`) and simulated projected (:math:`S`)
+    datasets are constructed by stacking the timeseries of N variables together. The algorithm is broken into the
+    following steps:
+
+    1. Rotate the datasets in the N-dimensional variable space with :math:`\mathbf{R}`, a random NxN rotation matrix.
+
+    .. math::
+
+        \tilde{\mathbf{T}} = \mathbf{T}\mathbf{R} \\
+        \tilde{\mathbf{H}} = \mathbf{H}\mathbf{R} \\
+        \tilde{\mathbf{S}} = \mathbf{S}\mathbf{R}
+
+    2. A univariate bias-adjustment :math:`\mathcal{F}` is used on the rotated datasets.
+       The adjustments are made in additive mode, for each variable :math:`i`.
+
+    .. math::
+
+        \hat{\mathbf{H}}_i, \hat{\mathbf{S}}_i = \mathcal{F}\left(\tilde{\mathbf{T}}_i, \tilde{\mathbf{H}}_i, \tilde{\mathbf{S}}_i\right)
+
+    3. The bias-adjusted datasets are rotated back.
+
+    .. math::
+
+        \mathbf{H}' = \hat{\mathbf{H}}\mathbf{R} \\
+        \mathbf{S}' = \hat{\mathbf{S}}\mathbf{R}
+
+    These three steps are repeated a certain number of times, prescribed by the argument ``n_iter``. At each
+    iteration, a new random rotation matrix is generated.
+
+    The original algorithm :cite:p:`sdba-pitie_n-dimensional_2005` stops the iteration when some distance score converges.
+    Following :cite:t:`sdba-cannon_multivariate_2018` and the MBCn implementation in :cite:t:`sdba-cannon_mbc_2020`, we
+    instead fix the number of iterations.
+
+    As done by :cite:t:`sdba-cannon_multivariate_2018`, the distance score chosen is the "Energy distance" from
+    :cite:t:`sdba-szekely_testing_2004` (see :py:func:`xclim.sdba.processing.escore`).
+
+    The random matrices are generated following a method laid out by :cite:t:`sdba-mezzadri_how_2007`.
+
+    This is only part of the full MBCn algorithm; see :ref:`notebooks/sdba:Statistical Downscaling and Bias-Adjustment`
+    for an example on how to replicate the full method with xclim. This includes a standardization of the simulated data
+    beforehand, an initial univariate adjustment and the reordering of those adjusted series according to the rank
+    structure of the output of this algorithm.
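+
+    A schematic, pure-numpy sketch of a single iteration, with data stored as
+    ``(variables, time)`` arrays; this is an illustration only, not the xarray/dask
+    implementation used by this class::
+
+        import numpy as np
+
+        rng = np.random.default_rng(42)
+        nvar, ntime = 2, 1000
+        T = rng.normal(size=(nvar, ntime))            # target (ref)
+        H = rng.normal(1.0, 2.0, size=(nvar, ntime))  # hist
+
+        # Random orthonormal matrix from the QR decomposition of a Gaussian matrix
+        # (the real code follows Mezzadri (2007), which adds a sign correction).
+        R, _ = np.linalg.qr(rng.normal(size=(nvar, nvar)))
+
+        Tr, Hr = R @ T, R @ H                         # 1. rotate
+        q = np.linspace(0.01, 0.99, 50)
+        for i in range(nvar):                         # 2. additive quantile mapping per variable
+            af = np.quantile(Tr[i], q) - np.quantile(Hr[i], q)
+            cdf = np.searchsorted(np.sort(Hr[i]), Hr[i]) / ntime
+            Hr[i] = Hr[i] + np.interp(cdf, q, af)
+        H_adj = R.T @ Hr                              # 3. rotate back (R is orthonormal)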
+ + References + ---------- + :cite:cts:`sdba-cannon_multivariate_2018,sdba-cannon_mbc_2020,sdba-pitie_n-dimensional_2005,sdba-mezzadri_how_2007,sdba-szekely_testing_2004` + """ + + @classmethod + def _adjust( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + sim: xr.DataArray, + *, + base: TrainAdjust = QuantileDeltaMapping, + base_kws: dict[str, Any] | None = None, + n_escore: int = 0, + n_iter: int = 20, + pts_dim: str = "multivar", + adj_kws: dict[str, Any] | None = None, + rot_matrices: xr.DataArray | None = None, + ) -> xr.Dataset: + if base_kws is None: + base_kws = {} + if "kind" in base_kws: + warn( + f'The adjustment kind cannot be controlled when using {cls.__name__}, it defaults to "+".' + ) + base_kws.setdefault("kind", "+") + + # Assuming sim has the same coords as hist + # We get the safest new name of the rotated dim. + rot_dim = xr.core.utils.get_temp_dimname( + set(ref.dims).union(hist.dims).union(sim.dims), pts_dim + "_prime" + ) + + # Get the rotation matrices + rot_matrices = rot_matrices or rand_rot_matrix( + ref[pts_dim], num=n_iter, new_dim=rot_dim + ).rename(matrices="iterations") + + # Call a map_blocks on the iterative function + # Sadly, this is a bit too complicated for map_blocks, we'll do it by hand. + escores_tmpl = xr.broadcast( + ref.isel({pts_dim: 0, "time": 0}), + hist.isel({pts_dim: 0, "time": 0}), + )[0].expand_dims(iterations=rot_matrices.iterations) + + template = xr.Dataset( + data_vars={ + "scenh": xr.full_like(hist, np.NaN).rename(time="time_hist"), + "scen": xr.full_like(sim, np.NaN), + "escores": escores_tmpl, + } + ) + + # Input data, rename time dim on sim since it can't be aligned with ref or hist. + ds = xr.Dataset( + data_vars={ + "ref": ref.rename(time="time_hist"), + "hist": hist.rename(time="time_hist"), + "sim": sim, + "rot_matrices": rot_matrices, + } + ) + + kwargs = { + "base": base, + "base_kws": base_kws, + "n_escore": n_escore, + "n_iter": n_iter, + "pts_dim": pts_dim, + "adj_kws": adj_kws or {}, + } + + with set_options(sdba_extra_output=False): + out = ds.map_blocks(npdf_transform, template=template, kwargs=kwargs) + + out = out.assign(rotation_matrices=rot_matrices) + out.scenh.attrs["units"] = hist.units + return out + + +class MBCn(TrainAdjust): + r"""Multivariate bias correction function using the N-dimensional probability density function transform. + + A multivariate bias-adjustment algorithm described by :cite:t:`sdba-cannon_multivariate_2018` + based on a color-correction algorithm described by :cite:t:`sdba-pitie_n-dimensional_2005`. + + This algorithm in itself, when used with QuantileDeltaMapping, is NOT trend-preserving. + The full MBCn algorithm includes a reordering step provided here by :py:func:`xclim.sdba.processing.reordering`. + + See notes for an explanation of the algorithm. + + Attributes + ---------- + Train step + + ref : xr.DataArray + Reference dataset. + hist : xr.DataArray + Historical dataset. + base_kws : dict, optional + Arguments passed to the training in the npdf transform. + adj_kws : dict, optional + Arguments passed to the adjusting in the npdf transform. + n_escore : int + The number of elements to send to the escore function. The default, 0, means all elements are included. + Pass -1 to skip computing the escore completely. + Small numbers result in less significant scores, but the execution time goes up quickly with large values. + n_iter : int + The number of iterations to perform. Defaults to 20. + pts_dim : str + The name of the "multivariate" dimension. 
Defaults to "multivar", which is the
+        normal case when using :py:func:`xclim.sdba.base.stack_variables`.
+    rot_matrices : xr.DataArray, optional
+        The rotation matrices as a 3D array with dimensions ('iterations', <pts_dim>, <anything>), i.e. with shape
+        (n_iter, N, N), where N is the size of the ``pts_dim`` dimension.
+        If left empty, random rotation matrices will be automatically generated.
+
+    Adjust step
+
+    ref : xr.DataArray
+        Target reference dataset, also needed for the univariate bias correction preceding the npdf transform.
+    hist : xr.DataArray
+        Source dataset, also needed for the univariate bias correction preceding the npdf transform.
+    sim : xr.DataArray
+        Source dataset to adjust.
+    base : BaseAdjustment
+        Bias-adjustment class used for the univariate bias correction.
+    base_kws : dict, optional
+        Arguments passed to the training in the univariate bias correction.
+    adj_kws : dict, optional
+        Arguments passed to the adjusting in the univariate bias correction.
+    period_dim : str, optional
+        Name of the period dimension used when stacking time periods of `sim` using :py:func:`xclim.core.calendar.stack_periods`.
+        If specified, the interpolation of the npdf transform is performed only once and applied on all periods simultaneously.
+        This should be more performant, but also more memory intensive.
+
+    Training (only the npdf transform training)
+
+    1. Standardize `ref` and `hist` (see ``xclim.sdba.processing.standardize``).
+
+    2. Rotate the datasets in the N-dimensional variable space with :math:`\mathbf{R}`, a random NxN rotation matrix.
+
+    .. math::
+
+        \tilde{\mathbf{T}} = \mathbf{T}\mathbf{R} \\
+        \tilde{\mathbf{H}} = \mathbf{H}\mathbf{R}
+
+    3. QuantileDeltaMapping is used to perform bias adjustment :math:`\mathcal{F}` on the rotated datasets.
+       The adjustment factors are conserved for later use in the adjusting step. The adjustments are made in additive mode,
+       for each variable :math:`i`.
+
+    .. math::
+
+        \hat{\mathbf{H}}_i, \hat{\mathbf{S}}_i = \mathcal{F}\left(\tilde{\mathbf{T}}_i, \tilde{\mathbf{H}}_i, \tilde{\mathbf{S}}_i\right)
+
+    4. The bias-adjusted datasets are rotated back.
+
+    .. math::
+
+        \mathbf{H}' = \hat{\mathbf{H}}\mathbf{R} \\
+        \mathbf{S}' = \hat{\mathbf{S}}\mathbf{R}
+
+    5. Repeat steps 2, 3 and 4 ``n_iter`` times, i.e. the number of randomly generated rotation matrices.
+
+    Adjusting
+
+    1. Perform the same steps as in training, with `ref`, `hist` replaced by `sim`. Step 3 of the training is modified:
+       here we simply reuse the adjustment factors previously found in the training step to bias correct the standardized `sim` directly.
+
+    2. Using the original (unstandardized) `ref`, `hist` and `sim`, perform a univariate bias adjustment with the ``base`` class
+       on `sim`.
+
+    3. Reorder the dataset found in step 2 according to the ranks of the dataset found in step 1.
+
+    The original algorithm :cite:p:`sdba-pitie_n-dimensional_2005` stops the iteration when some distance score converges.
+    Following :cite:t:`sdba-cannon_multivariate_2018` and the MBCn implementation in :cite:t:`sdba-cannon_mbc_2020`, we
+    instead fix the number of iterations.
+
+    As done by :cite:t:`sdba-cannon_multivariate_2018`, the distance score chosen is the "Energy distance" from
+    :cite:t:`sdba-szekely_testing_2004` (see :py:func:`xclim.sdba.processing.escore`).
+
+    The random matrices are generated following a method laid out by :cite:t:`sdba-mezzadri_how_2007`.
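+
+    A minimal usage sketch (``ds_ref``, ``ds_hist`` and ``ds_sim`` are hypothetical
+    Datasets holding the same variables; it assumes the public ``train``/``adjust``
+    wrappers forward to the ``_train``/``_adjust`` signatures below)::
+
+        from xsdba.adjustment import MBCn
+        from xsdba.processing import stack_variables, unstack_variables
+
+        ref = stack_variables(ds_ref)
+        hist = stack_variables(ds_hist)
+        sim = stack_variables(ds_sim)
+
+        mbcn = MBCn.train(ref, hist, base_kws={"nquantiles": 50, "group": "time"}, n_iter=20)
+        scen = unstack_variables(mbcn.adjust(sim, ref, hist))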
+
+    References
+    ----------
+    :cite:cts:`sdba-cannon_multivariate_2018,sdba-cannon_mbc_2020,sdba-pitie_n-dimensional_2005,sdba-mezzadri_how_2007,sdba-szekely_testing_2004`
+
+    Notes
+    -----
+    * Only "time" and "time.dayofyear" (with a suitable window) are implemented as possible values for `group`.
+    * The historical reference (:math:`T`, for "target"), simulated historical (:math:`H`) and simulated projected (:math:`S`)
+      datasets are constructed by stacking the timeseries of N variables together using ``xsdba.base.stack_variables``.
+    """
+
+    @classmethod
+    def _train(
+        cls,
+        ref: xr.DataArray,
+        hist: xr.DataArray,
+        *,
+        base_kws: dict[str, Any] | None = None,
+        adj_kws: dict[str, Any] | None = None,
+        n_escore: int = -1,
+        n_iter: int = 20,
+        pts_dim: str = "multivar",
+        rot_matrices: xr.DataArray | None = None,
+    ):
+        # set default values for non-specified parameters
+        base_kws = base_kws if base_kws is not None else {}
+        adj_kws = adj_kws if adj_kws is not None else {}
+        base_kws.setdefault("nquantiles", 20)
+        base_kws.setdefault("group", Grouper("time", 1))
+        adj_kws.setdefault("interp", "nearest")
+        adj_kws.setdefault("extrapolation", "constant")
+
+        if np.isscalar(base_kws["nquantiles"]):
+            base_kws["nquantiles"] = equally_spaced_nodes(base_kws["nquantiles"])
+        if isinstance(base_kws["group"], str):
+            base_kws["group"] = Grouper(base_kws["group"], 1)
+        if base_kws["group"].name == "time.month":
+            raise NotImplementedError(
+                "Received `group==time.month` in `base_kws`. Monthly grouping is not currently supported in the MBCn class."
+            )
+        # stack variables and prepare rotations
+        if rot_matrices is not None:
+            if pts_dim != rot_matrices.attrs["crd_dim"]:
+                raise ValueError(
+                    f"`crd_dim` attribute of `rot_matrices` ({rot_matrices.attrs['crd_dim']}) does not correspond to `pts_dim` ({pts_dim})."
+ ) + else: + rot_dim = xr.core.utils.get_temp_dimname( + set(ref.dims).union(hist.dims), pts_dim + "_prime" + ) + rot_matrices = rand_rot_matrix( + ref[pts_dim], num=n_iter, new_dim=rot_dim + ).rename(matrices="iterations") + pts_dims = [rot_matrices.attrs[d] for d in ["crd_dim", "new_dim"]] + + # time indices corresponding to group and windowed group + # used to divide datasets as map_blocks or groupby would do + _, gw_idxs = grouped_time_indexes(ref.time, base_kws["group"]) + + # training, obtain adjustment factors of the npdf transform + ds = xr.Dataset(dict(ref=ref, hist=hist)) + params = { + "quantiles": base_kws["nquantiles"], + "interp": adj_kws["interp"], + "extrapolation": adj_kws["extrapolation"], + "pts_dims": pts_dims, + "n_escore": n_escore, + } + out = mbcn_train(ds, rot_matrices=rot_matrices, gw_idxs=gw_idxs, **params) + params["group"] = base_kws["group"] + + # postprocess + out["rot_matrices"] = rot_matrices + + out.af_q.attrs.update( + standard_name="Adjustment factors", + long_name="Quantile mapping adjustment factors", + ) + return out, params + + def _adjust( + self, + sim: xr.DataArray, + ref: xr.DataArray, + hist: xr.DataArray, + *, + base: TrainAdjust = QuantileDeltaMapping, + base_kws_vars: dict[str, Any] | None = None, + adj_kws: dict[str, Any] | None = None, + period_dim=None, + ): + # set default values for non-specified parameters + base_kws_vars = base_kws_vars or {} + pts_dim = self.pts_dims[0] + for v in sim[pts_dim].values: + base_kws_vars.setdefault(v, {}) + base_kws_vars[v].setdefault("group", self.group) + if isinstance(base_kws_vars[v]["group"], str): + base_kws_vars[v]["group"] = Grouper(base_kws_vars[v]["group"], 1) + if base_kws_vars[v]["group"] != self.group: + raise ValueError( + f"`group` input in _train and _adjust must be the same." + f"Got {self.group} and {base_kws_vars[v]['group']}" + ) + base_kws_vars[v].pop("group") + + base_kws_vars[v].setdefault("nquantiles", self.ds.af_q.quantiles.values) + if np.isscalar(base_kws_vars[v]["nquantiles"]): + base_kws_vars[v]["nquantiles"] = equally_spaced_nodes( + base_kws_vars[v]["nquantiles"] + ) + if "is_variables" in sim[pts_dim].attrs: + if self.train_units == "": + _, units = self._harmonize_units(sim) + else: + units = self.train_units + + if "jitter_under_thresh_value" in base_kws_vars[v]: + base_kws_vars[v]["jitter_under_thresh_value"] = str( + convert_units_to( + base_kws_vars[v]["jitter_under_thresh_value"], + units[v], + ) + ) + if "adapt_freq_thresh" in base_kws_vars[v]: + base_kws_vars[v]["adapt_freq_thresh"] = str( + convert_units_to( + base_kws_vars[v]["adapt_freq_thresh"], + units[v], + ) + ) + + adj_kws = adj_kws or {} + adj_kws.setdefault("interp", self.interp) + adj_kws.setdefault("extrapolation", self.extrapolation) + + g_idxs, gw_idxs = grouped_time_indexes(ref.time, self.group) + ds = self.ds.copy() + ds["g_idxs"] = g_idxs + ds["gw_idxs"] = gw_idxs + + # adjust (adjust for npft transform, train/adjust for univariate bias correction) + out = mbcn_adjust( + ref=ref, + hist=hist, + sim=sim, + ds=ds, + pts_dims=self.pts_dims, + interp=self.interp, + extrapolation=self.extrapolation, + base=base, + base_kws_vars=base_kws_vars, + adj_kws=adj_kws, + period_dim=period_dim, + ) + + return out + + +try: + import SBCK +except ImportError: # noqa: S110 + # SBCK is not installed, we will not generate the SBCK classes. 
+ pass +else: + + class _SBCKAdjust(Adjust): + sbck = None # The method + + @classmethod + def _adjust(cls, ref, hist, sim, *, multi_dim=None, **kwargs): + # Check inputs + fit_needs_sim = "X1" in signature(cls.sbck.fit).parameters + for k, v in signature(cls.sbck.__init__).parameters.items(): + if ( + v.default == v.empty + and v.kind != v.VAR_KEYWORD + and k != "self" + and k not in kwargs + ): + raise ValueError( + f"Argument {k} is not optional for SBCK method {cls.sbck.__name__}." + ) + + ref = ref.rename(time="time_cal") + hist = hist.rename(time="time_cal") + sim = sim.rename(time="time_tgt") + + if multi_dim: + input_core_dims = [ + ("time_cal", multi_dim), + ("time_cal", multi_dim), + ("time_tgt", multi_dim), + ] + else: + input_core_dims = [("time_cal",), ("time_cal",), ("time_tgt",)] + + return xr.apply_ufunc( + cls._apply_sbck, + ref, + hist, + sim, + input_core_dims=input_core_dims, + kwargs={"method": cls.sbck, "fit_needs_sim": fit_needs_sim, **kwargs}, + vectorize=True, + keep_attrs=True, + dask="parallelized", + output_core_dims=[input_core_dims[-1]], + output_dtypes=[sim.dtype], + ).rename(time_tgt="time") + + @staticmethod + def _apply_sbck(ref, hist, sim, method, fit_needs_sim, **kwargs): + obj = method(**kwargs) + if fit_needs_sim: + obj.fit(ref, hist, sim) + else: + obj.fit(ref, hist) + scen = obj.predict(sim) + if sim.ndim == 1: + return scen[:, 0] + return scen + + def _parse_sbck_doc(cls): + def _parse(s): + s = s.replace("\t", " ") + n = min(len(line) - len(line.lstrip()) for line in s.split("\n") if line) + lines = [] + for line in s.split("\n"): + line = line[n:] if line else line + if set(line).issubset({"=", " "}): + line = line.replace("=", "-") + elif set(line).issubset({"-", " "}): + line = line.replace("-", "~") + lines.append(line) + return lines + + return "\n".join( + [ + f"SBCK_{cls.__name__}", + "=" * (5 + len(cls.__name__)), + ( + f"This Adjustment object was auto-generated from the {cls.__name__} " + " object of package SBCK. See :ref:`Experimental wrap of SBCK`." + ), + "", + ( + "The adjust method accepts ref, hist, sim and all arguments listed " + 'below in "Parameters". It also accepts a `multi_dim` argument ' + "specifying the dimension across which to take the 'features' and " + "is valid for multivariate methods only. See :py:func:`xclim.sdba.stack_variables`." + "In the description below, `n_features` is the size of the `multi_dim` " + "dimension. There is no way of specifying parameters across other " + "dimensions for the moment." + ), + "", + *_parse(cls.__doc__), + *_parse(cls.__init__.__doc__), + " Copyright(c) 2021 Yoann Robin.", + ] + ) + + def _generate_SBCK_classes(): # noqa: N802 + classes = [] + for clsname in dir(SBCK): + cls = getattr(SBCK, clsname) + if ( + not clsname.startswith("_") + and isinstance(cls, type) + and hasattr(cls, "fit") + and hasattr(cls, "predict") + ): + doc = _parse_sbck_doc(cls) + classes.append( + type( + f"SBCK_{clsname}", (_SBCKAdjust,), {"sbck": cls, "__doc__": doc} + ) + ) + return classes diff --git a/src/xsdba/base.py b/src/xsdba/base.py index 9734d22..fce999a 100644 --- a/src/xsdba/base.py +++ b/src/xsdba/base.py @@ -8,6 +8,7 @@ import datetime as pydt import itertools from collections.abc import Sequence +from enum import IntEnum from inspect import _empty, signature from typing import Any, Callable, NewType, TypeVar @@ -118,6 +119,99 @@ def set_dataset(self, ds: xr.Dataset) -> None: # XC +class InputKind(IntEnum): + """Constants for input parameter kinds. 
+ + For use by external parses to determine what kind of data the indicator expects. + On the creation of an indicator, the appropriate constant is stored in + :py:attr:`xclim.core.indicator.Indicator.parameters`. The integer value is what gets stored in the output + of :py:meth:`xclim.core.indicator.Indicator.json`. + + For developers : for each constant, the docstring specifies the annotation a parameter of an indice function + should use in order to be picked up by the indicator constructor. Notice that we are using the annotation format + as described in `PEP 604 `_, i.e. with '|' indicating a union and without import + objects from `typing`. + """ + + VARIABLE = 0 + """A data variable (DataArray or variable name). + + Annotation : ``xr.DataArray``. + """ + OPTIONAL_VARIABLE = 1 + """An optional data variable (DataArray or variable name). + + Annotation : ``xr.DataArray | None``. The default should be None. + """ + QUANTIFIED = 2 + """A quantity with units, either as a string (scalar), a pint.Quantity (scalar) or a DataArray (with units set). + + Annotation : ``xclim.core.utils.Quantified`` and an entry in the :py:func:`xclim.core.units.declare_units` + decorator. "Quantified" translates to ``str | xr.DataArray | pint.util.Quantity``. + """ + FREQ_STR = 3 + """A string representing an "offset alias", as defined by pandas. + + See the Pandas documentation on :ref:`timeseries.offset_aliases` for a list of valid aliases. + + Annotation : ``str`` + ``freq`` as the parameter name. + """ + NUMBER = 4 + """A number. + + Annotation : ``int``, ``float`` and unions thereof, potentially optional. + """ + STRING = 5 + """A simple string. + + Annotation : ``str`` or ``str | None``. In most cases, this kind of parameter makes sense + with choices indicated in the docstring's version of the annotation with curly braces. + See :ref:`notebooks/extendxclim:Defining new indices`. + """ + DAY_OF_YEAR = 6 + """A date, but without a year, in the MM-DD format. + + Annotation : :py:obj:`xclim.core.utils.DayOfYearStr` (may be optional). + """ + DATE = 7 + """A date in the YYYY-MM-DD format, may include a time. + + Annotation : :py:obj:`xclim.core.utils.DateStr` (may be optional). + """ + NUMBER_SEQUENCE = 8 + """A sequence of numbers + + Annotation : ``Sequence[int]``, ``Sequence[float]`` and unions thereof, may include single ``int`` and ``float``, + may be optional. + """ + BOOL = 9 + """A boolean flag. + + Annotation : ``bool``, may be optional. + """ + DICT = 10 + """A dictionary. + + Annotation : ``dict`` or ``dict | None``, may be optional. + """ + KWARGS = 50 + """A mapping from argument name to value. + + Developers : maps the ``**kwargs``. Please use as little as possible. + """ + DATASET = 70 + """An xarray dataset. + + Developers : as indices only accept DataArrays, this should only be added on the indicator's constructor. + """ + OTHER_PARAMETER = 99 + """An object that fits None of the previous kinds. + + Developers : This is the fallback kind, it will raise an error in xclim's unit tests if used. + """ + + +# XC def copy_all_attrs(ds: xr.Dataset | xr.DataArray, ref: xr.Dataset | xr.DataArray): """Copy all attributes of ds to ref, including attributes of shared coordinates, and variables in the case of Datasets.""" ds.attrs.update(ref.attrs) @@ -194,7 +288,6 @@ def parse_offset(freq: str) -> tuple[int, str, bool, str | None]: anchor : str, optional Anchor date for bases Y or Q. As xarray doesn't support "W", neither does xclim (anchor information is lost when given). 
- """ # Useful to raise on invalid freqs, convert Y to A and get default anchor (A, Q) offset = pd.tseries.frequencies.to_offset(freq) @@ -254,49 +347,6 @@ def get_calendar(obj: Any, dim: str = "time") -> str: raise ValueError(f"Calendar could not be inferred from object of type {type(obj)}.") -# XC -def gen_call_string(funcname: str, *args, **kwargs) -> str: - r"""Generate a signature string for use in the history attribute. - - DataArrays and Dataset are replaced with their name, while Nones, floats, ints and strings are printed directly. - All other objects have their type printed between < >. - - Arguments given through positional arguments are printed positionnally and those - given through keywords are printed prefixed by their name. - - Parameters - ---------- - funcname : str - Name of the function - \*args, \*\*kwargs - Arguments given to the function. - - Example - ------- - >>> A = xr.DataArray([1], dims=("x",), name="A") - >>> gen_call_string("func", A, b=2.0, c="3", d=[10] * 100) - "func(A, b=2.0, c='3', d=)" - """ - elements = [] - chain = itertools.chain(zip([None] * len(args), args), kwargs.items()) - for name, val in chain: - if isinstance(val, xr.DataArray): - rep = val.name or "" - elif isinstance(val, (int, float, str, bool)) or val is None: - rep = repr(val) - else: - rep = repr(val) - if len(rep) > 50: - rep = f"<{type(val).__name__}>" - - if name is not None: - rep = f"{name}={rep}" - - elements.append(rep) - - return f"{funcname}({', '.join(elements)})" - - class Grouper(Parametrizable): """Grouper inherited class for parameterizable classes.""" diff --git a/src/xsdba/formatting.py b/src/xsdba/formatting.py index 89f3bbe..de5bbfb 100644 --- a/src/xsdba/formatting.py +++ b/src/xsdba/formatting.py @@ -1,4 +1,4 @@ -""" +"""# noqa: SS01 Formatting Utilities =================================== """ @@ -29,9 +29,9 @@ def merge_attributes( ---------- attribute : str The attribute to merge. - inputs_list : xr.DataArray or xr.Dataset + \*inputs_list : xr.DataArray or xr.Dataset The datasets or variables that were used to produce the new object. - Inputs given that way will be prefixed by their `name` attribute if available. + Inputs given that way will be prefixed by their "name" attribute if available. new_line : str The character to put between each instance of the attributes. Usually, in CF-conventions, the history attributes uses '\\n' while cell_methods uses ' '. @@ -47,9 +47,7 @@ def merge_attributes( str The new attribute made from the combination of the ones from all the inputs. """ - inputs = [] - for in_ds in inputs_list: - inputs.append((getattr(in_ds, "name", None), in_ds)) + inputs = [getattr(in_ds, "name", None) for in_ds in inputs_list] inputs += list(inputs_kws.items()) merged_attr = "" @@ -165,7 +163,11 @@ def _call_and_add_history(*args, **kwargs): # XC -def gen_call_string(funcname: str, *args, **kwargs) -> str: +def gen_call_string( + funcname: str, + *args, + **kwargs, +) -> str: r"""Generate a signature string for use in the history attribute. DataArrays and Dataset are replaced with their name, while Nones, floats, ints and strings are printed directly. @@ -177,9 +179,7 @@ def gen_call_string(funcname: str, *args, **kwargs) -> str: Parameters ---------- funcname : str - Name of the function - \*args, \*\*kwargs - Arguments given to the function. + Name of the function. 
Example ------- diff --git a/src/xsdba/testing.py b/src/xsdba/testing.py index e96eda3..18fc108 100644 --- a/src/xsdba/testing.py +++ b/src/xsdba/testing.py @@ -16,8 +16,11 @@ import pandas as pd import xarray as xr from platformdirs import user_cache_dir +from scipy.stats import gamma from xarray import open_dataset as _open_dataset +from xsdba.utils import equally_spaced_nodes + __all__ = ["test_timelonlatseries", "test_timeseries"] # keeping xclim-testdata for now, since it's still this on gitHub @@ -51,6 +54,32 @@ SocketBlockedError = None +def test_cannon_2015_dist(): # noqa: D103 + # ref ~ gamma(k=4, theta=7.5) mu: 30, sigma: 15 + ref = gamma(4, scale=7.5) + + # hist ~ gamma(k=8.15, theta=3.68) mu: 30, sigma: 10.5 + hist = gamma(8.15, scale=3.68) + + # sim ~ gamma(k=16, theta=2.63) mu: 42, sigma: 10.5 + sim = gamma(16, scale=2.63) + + return ref, hist, sim + + +def test_cannon_2015_rvs(n, random=True): # noqa: D103 + # Frozen distributions + fd = test_cannon_2015_dist() + + if random: + r = [d.rvs(n) for d in fd] + else: + u = equally_spaced_nodes(n, None) + r = [d.ppf(u) for d in fd] + + return map(lambda x: test_timelonlatseries(x, attrs={"units": "kg/m/m/s"}), r) + + def test_timelonlatseries(values, attrs=None, start="2000-01-01"): """Create a DataArray with time, lon and lat dimensions.""" attrs = {} if attrs is None else attrs diff --git a/src/xsdba/units.py b/src/xsdba/units.py index 08c303e..634d9a0 100644 --- a/src/xsdba/units.py +++ b/src/xsdba/units.py @@ -7,21 +7,24 @@ from copy import deepcopy from functools import wraps +import pint + # this dependency is "necessary" for convert_units_to # if we only do checks, we could get rid of it -import cf_xarray.units + + +try: + # allows to use cf units + import cf_xarray.units +except ImportError: # noqa: S110 + # cf-xarray is not installed, this will not be used + pass import numpy as np -import pint import xarray as xr from .base import Quantified, copy_all_attrs -# shamelessly adapted from `cf-xarray` (which adopted it from MetPy and xclim itself) -units = deepcopy(cf_xarray.units.units) -# Switch this flag back to False. Not sure what that implies, but it breaks some tests. -units.force_ndarray_like = False # noqa: F841 -# Another alias not included by cf_xarray -units.define("@alias percent = pct") +units = pint.get_application_registry() # XC @@ -120,13 +123,17 @@ def str2pint(val: str) -> pint.Quantity: def extract_units(arg): """Extract units from a string, DataArray, or scalar.""" - if not (isinstance(arg, (str, xr.DataArray)) or np.isscalar(arg)): + if not ( + isinstance(arg, (str, xr.DataArray, pint.Unit, units.Unit)) or np.isscalar(arg) + ): print(arg) raise TypeError( f"Argument must be a str, DataArray, or scalar. 
Got {type(arg)}" ) elif isinstance(arg, xr.DataArray): ustr = None if "units" not in arg.attrs else arg.attrs["units"] + elif isinstance(arg, pint.Unit | units.Unit): + ustr = f"{arg:cf}" # XC: from pint2cfunits elif isinstance(arg, str): ustr = str2pint(arg).units else: # (scalar case) @@ -219,7 +226,7 @@ def convert_units_to( # noqa: C901 out = source.copy(data=units.convert(source.data, source_unit, target_unit)) out = out.assign_attrs(units=target_unit) else: - out = str2pint(source).to(target_unit) + out = str2pint(source).to(target_unit).m return out diff --git a/tests/conftest.py b/tests/conftest.py index 4be3140..2897a0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,12 @@ from xsdba.testing import TESTDATA_BRANCH from xsdba.testing import open_dataset as _open_dataset -from xsdba.testing import test_timelonlatseries, test_timeseries +from xsdba.testing import ( + test_cannon_2015_dist, + test_cannon_2015_rvs, + test_timelonlatseries, + test_timeseries, +) from xsdba.utils import apply_correction, equally_spaced_nodes # import xclim @@ -65,6 +70,49 @@ # ) +@pytest.fixture +def cannon_2015_rvs(): + return test_cannon_2015_rvs + + +@pytest.fixture +def cannon_2015_dist(): + return test_cannon_2015_dist + + +# @pytest.fixture +# def ref_hist_sim_tuto(socket_enabled): # noqa: F841 +# """Return ref, hist, sim time series of air temperature. + +# socket_enabled is a fixture that enables the use of the internet to download the tutorial dataset while the +# `--disable-socket` flag has been called. This fixture will crash if the `air_temperature` tutorial file is +# not on disk while the internet is unavailable. +# """ + +# def _ref_hist_sim_tuto(sim_offset=3, delta=0.1, smth_win=3, trend=True): +# ds = xr.tutorial.open_dataset("air_temperature") +# ref = ds.air.resample(time="D").mean(keep_attrs=True) +# hist = ref.rolling(time=smth_win, min_periods=1).mean(keep_attrs=True) + delta +# hist.attrs["units"] = ref.attrs["units"] +# sim_time = hist.time + np.timedelta64(730 + sim_offset * 365, "D").astype( +# " np.random.Generator: return np.random.default_rng(seed=list(map(ord, "𝕽𝔞𝖓𝔡𝖔𝔪"))) @@ -102,17 +150,17 @@ def mon_triangular(): # XC (name changed) @pytest.fixture -def mon_timelonlatseries(series, mon_triangular): - def _mon_timelonlatseries(values, name): +def mon_timelonlatseries(timelonlatseries, mon_triangular): + def _mon_timelonlatseries(values, attrs): """Random time series whose mean varies over a monthly cycle.""" - x = timelonlatseries(values, name) + x = timelonlatseries(values, attrs) m = mon_triangular - factor = timelonlatseriesseries(m[x.time.dt.month - 1], name) + factor = timelonlatseries(m[x.time.dt.month - 1], attrs) with xr.set_options(keep_attrs=True): return apply_correction(x, factor, x.kind) - return _mon_series + return _mon_timelonlatseries @pytest.fixture @@ -230,14 +278,14 @@ def _is_matplotlib_installed(): # ADAPT or REMOVE? 
-# @pytest.fixture(scope="function") -# def atmosds(threadsafe_data_dir) -> xr.Dataset: -# return _open_dataset( -# threadsafe_data_dir.joinpath("atmosds.nc"), -# cache_dir=threadsafe_data_dir, -# branch=helpers.TESTDATA_BRANCH, -# engine="h5netcdf", -# ).load() +@pytest.fixture(scope="function") +def atmosds(threadsafe_data_dir) -> xr.Dataset: + return _open_dataset( + threadsafe_data_dir.joinpath("atmosds.nc"), + cache_dir=threadsafe_data_dir, + branch=TESTDATA_BRANCH, + engine="h5netcdf", + ).load() # @pytest.fixture(scope="function") diff --git a/tests/test_adjustment.py b/tests/test_adjustment.py new file mode 100644 index 0000000..ff9479b --- /dev/null +++ b/tests/test_adjustment.py @@ -0,0 +1,884 @@ +# pylint: disable=no-member +from __future__ import annotations + +import numpy as np +import pytest +import xarray as xr +from scipy.stats import genpareto, norm, uniform + +from xsdba import adjustment +from xsdba.adjustment import ( # ExtremeValues, + LOCI, + DetrendedQuantileMapping, + EmpiricalQuantileMapping, + PrincipalComponents, + QuantileDeltaMapping, + Scaling, +) +from xsdba.base import Grouper +from xsdba.options import set_options +from xsdba.processing import ( + jitter_under_thresh, + stack_variables, + uniform_noise_like, + unstack_variables, +) +from xsdba.testing import nancov +from xsdba.units import convert_units_to +from xsdba.utils import ( + ADDITIVE, + MULTIPLICATIVE, + apply_correction, + get_correction, + invert, +) + + +class TestLoci: + @pytest.mark.parametrize("group,dec", (["time", 2], ["time.month", 1])) + def test_time_and_from_ds(self, timelonlatseries, group, dec, tmp_path, random): + n = 10000 + u = random.random(n) + + xd = uniform(loc=0, scale=3) + x = xd.ppf(u) + + attrs = {"units": "kg m-2 s-1", "kind": MULTIPLICATIVE} + + hist = sim = timelonlatseries(x, attrs={"units": "kg m-2 s-1"}) + y = x * 2 + thresh = 2 + ref_fit = timelonlatseries(y, attrs={"units": "kg m-2 s-1"}).where( + y > thresh, 0.1 + ) + ref = timelonlatseries(y, attrs={"units": "kg m-2 s-1"}) + + loci = LOCI.train(ref_fit, hist, group=group, thresh=f"{thresh} kg m-2 s-1") + np.testing.assert_array_almost_equal(loci.ds.hist_thresh, 1, dec) + np.testing.assert_array_almost_equal(loci.ds.af, 2, dec) + + p = loci.adjust(sim) + np.testing.assert_array_almost_equal(p, ref, dec) + + assert "history" in p.attrs + assert "Bias-adjusted with LOCI(" in p.attrs["history"] + + file = tmp_path / "test_loci.nc" + loci.ds.to_netcdf(file) + + ds = xr.open_dataset(file) + loci2 = LOCI.from_dataset(ds) + + xr.testing.assert_equal(loci.ds, loci2.ds) + + p2 = loci2.adjust(sim) + np.testing.assert_array_equal(p, p2) + + # @pytest.mark.requires_internet + # def test_reduce_dims(self, ref_hist_sim_tuto): + # ref, hist, _sim = ref_hist_sim_tuto() + # hist = hist.expand_dims(member=[0, 1]) + # ref = ref.expand_dims(member=hist.member) + # LOCI.train(ref, hist, group="time", thresh="283 K", add_dims=["member"]) + + +@pytest.mark.slow +class TestScaling: + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + def test_time(self, kind, units, timelonlatseries, random): + n = 10000 + u = random.random(n) + + xd = uniform(loc=2, scale=1) + x = xd.ppf(u) + + attrs = {"units": units, "kind": kind} + + hist = sim = timelonlatseries(x, attrs=attrs) + ref = timelonlatseries(apply_correction(x, 2, kind), attrs=attrs) + if kind == ADDITIVE: + ref = convert_units_to(ref, "degC") + + scaling = Scaling.train(ref, hist, group="time", kind=kind) + 
np.testing.assert_array_almost_equal(scaling.ds.af, 2) + + p = scaling.adjust(sim) + np.testing.assert_array_almost_equal(p, ref) + + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + def test_mon_u( + self, + mon_timelonlatseries, + timelonlatseries, + mon_triangular, + kind, + units, + random, + ): + n = 10000 + u = random.random(n) + + xd = uniform(loc=2, scale=1) + x = xd.ppf(u) + + attrs = {"units": units, "kind": kind} + + hist = sim = timelonlatseries(x, attrs=attrs) + ref = mon_timelonlatseries(apply_correction(x, 2, kind), attrs=attrs) + + # Test train + scaling = Scaling.train(ref, hist, group="time.month", kind=kind) + expected = apply_correction(mon_triangular, 2, kind) + np.testing.assert_array_almost_equal(scaling.ds.af, expected) + + # Test predict + p = scaling.adjust(sim) + np.testing.assert_array_almost_equal(p, ref) + + def test_add_dim(self, timelonlatseries, mon_timelonlatseries, random): + n = 10000 + u = random.random((n, 4)) + + xd = uniform(loc=2, scale=1) + x = xd.ppf(u) + units, kind = "K", ADDITIVE + attrs = {"units": units, "kind": kind} + + hist = sim = timelonlatseries(x, attrs=attrs) + ref = mon_timelonlatseries(apply_correction(x, 2, "+"), attrs=attrs) + + group = Grouper("time.month", add_dims=["lon"]) + + scaling = Scaling.train(ref, hist, group=group, kind="+") + assert "lon" not in scaling.ds + p = scaling.adjust(sim) + assert "lon" in p.dims + np.testing.assert_array_almost_equal(p.transpose(*ref.dims), ref) + + +@pytest.mark.slow +class TestDQM: + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + def test_quantiles(self, timelonlatseries, kind, units, random): + """Train on + hist: U + ref: Normal + + Predict on hist to get ref + """ + ns = 10000 + u = random.random(ns) + + # Define distributions + xd = uniform(loc=10, scale=1) + yd = norm(loc=12, scale=1) + + # Generate random numbers with u so we get exact results for comparison + x = xd.ppf(u) + y = yd.ppf(u) + + # Test train + attrs = {"units": units, "kind": kind} + + hist = sim = timelonlatseries(x, attrs=attrs) + ref = timelonlatseries(y, attrs=attrs) + + DQM = DetrendedQuantileMapping.train( + ref, + hist, + kind=kind, + group="time", + nquantiles=50, + ) + p = DQM.adjust(sim, interp="linear") + + q = DQM.ds.quantiles + ex = apply_correction(xd.ppf(q), invert(xd.mean(), kind), kind) + ey = apply_correction(yd.ppf(q), invert(yd.mean(), kind), kind) + expected = get_correction(ex, ey, kind) + + # Results are not so good at the endpoints + np.testing.assert_array_almost_equal( + DQM.ds.af[:, 2:-2], expected[np.newaxis, 2:-2], 1 + ) + + # Test predict + # Accept discrepancies near extremes + middle = (x > 1e-2) * (x < 0.99) + np.testing.assert_array_almost_equal(p[middle], ref[middle], 1) + + # PB 13-01-21 : This seems the same as the next test. 
+ # Test with sim not equal to hist + # ff = series(np.ones(ns) * 1.1, name) + # sim2 = apply_correction(sim, ff, kind) + # ref2 = apply_correction(ref, ff, kind) + + # p2 = DQM.adjust(sim2, interp="linear") + + # np.testing.assert_array_almost_equal(p2[middle], ref2[middle], 1) + + # Test with actual trend in sim + attrs = {"units": units, "kind": kind} + + trend = timelonlatseries( + np.linspace(-0.2, 0.2, ns) + (1 if kind == MULTIPLICATIVE else 0), + attrs=attrs, + ) + sim3 = apply_correction(sim, trend, kind) + ref3 = apply_correction(ref, trend, kind) + p3 = DQM.adjust(sim3, interp="linear") + np.testing.assert_array_almost_equal(p3[middle], ref3[middle], 1) + + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + @pytest.mark.parametrize("add_dims", [True, False]) + def test_mon_u( + self, mon_timelonlatseries, timelonlatseries, kind, units, add_dims, random + ): + """ + Train on + hist: U + ref: U + monthly cycle + + Predict on hist to get ref + """ + n = 5000 + u = random.random(n) + + # Define distributions + xd = uniform(loc=2, scale=0.1) + yd = uniform(loc=4, scale=0.1) + noise = uniform(loc=0, scale=1e-7) + + # Generate random numbers + x = xd.ppf(u) + y = yd.ppf(u) + noise.ppf(u) + attrs = {"units": units, "kind": kind} + # Test train + hist, ref = timelonlatseries(x, attrs=attrs), mon_timelonlatseries( + y, attrs=attrs + ) + + trend = np.linspace(-0.2, 0.2, n) + int(kind == MULTIPLICATIVE) + ref_t = mon_timelonlatseries(apply_correction(y, trend, kind), attrs=attrs) + sim = timelonlatseries(apply_correction(x, trend, kind), attrs=attrs) + + if add_dims: + ref = ref.expand_dims(lat=[0, 1, 2]).chunk({"lat": 1}) + hist = hist.expand_dims(lat=[0, 1, 2]).chunk({"lat": 1}) + sim = sim.expand_dims(lat=[0, 1, 2]).chunk({"lat": 1}) + ref_t = ref_t.expand_dims(lat=[0, 1, 2]) + + DQM = DetrendedQuantileMapping.train( + ref, hist, kind=kind, group="time.month", nquantiles=5 + ) + mqm = DQM.ds.af.mean(dim="quantiles") + p = DQM.adjust(sim) + + if add_dims: + mqm = mqm.isel(lat=0) + np.testing.assert_array_almost_equal(mqm, int(kind == MULTIPLICATIVE), 1) + np.testing.assert_allclose(p.transpose(..., "time"), ref_t, rtol=0.1, atol=0.5) + + # def test_cannon_and_from_ds(self, cannon_2015_rvs, tmp_path, random): + # ref, hist, sim = cannon_2015_rvs(15000, random=random) + + # DQM = DetrendedQuantileMapping.train(ref, hist, kind="*", group="time") + # p = DQM.adjust(sim) + + # np.testing.assert_almost_equal(p.mean(), 41.6, 0) + # np.testing.assert_almost_equal(p.std(), 15.0, 0) + + # file = tmp_path / "test_dqm.nc" + # DQM.ds.to_netcdf(file) + + # ds = xr.open_dataset(file) + # DQM2 = DetrendedQuantileMapping.from_dataset(ds) + + # xr.testing.assert_equal(DQM.ds, DQM2.ds) + + # p2 = DQM2.adjust(sim) + # np.testing.assert_array_equal(p, p2) + + +@pytest.mark.slow +class TestQDM: + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + def test_quantiles(self, timelonlatseries, kind, units, random): + """Train on + x : U(1,1) + y : U(1,2) + + """ + u = random.random(10000) + + # Define distributions + xd = uniform(loc=1, scale=1) + yd = uniform(loc=2, scale=4) + + # Generate random numbers with u so we get exact results for comparison + x = xd.ppf(u) + y = yd.ppf(u) + + # Test train + attrs = {"units": units, "kind": kind} + hist = sim = timelonlatseries(x, attrs=attrs) + ref = timelonlatseries(y, attrs=attrs) + + QDM = QuantileDeltaMapping.train( + ref.astype("float32"), + hist.astype("float32"), + 
kind=kind, + group="time", + nquantiles=10, + ) + p = QDM.adjust(sim.astype("float32"), interp="linear") + + q = QDM.ds.coords["quantiles"] + expected = get_correction(xd.ppf(q), yd.ppf(q), kind)[np.newaxis, :] + + # Results are not so good at the endpoints + np.testing.assert_array_almost_equal(QDM.ds.af, expected, 1) + + # Test predict + # Accept discrepancies near extremes + middle = (u > 1e-2) * (u < 0.99) + np.testing.assert_array_almost_equal(p[middle], ref[middle], 1) + + # Test dtype control of map_blocks + assert QDM.ds.af.dtype == "float32" + assert p.dtype == "float32" + + @pytest.mark.parametrize("use_dask", [True, False]) + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + @pytest.mark.parametrize("add_dims", [True, False]) + def test_mon_u( + self, + mon_timelonlatseries, + timelonlatseries, + mon_triangular, + add_dims, + kind, + units, + use_dask, + random, + ): + """ + Train on + hist: U + ref: U + monthly cycle + + Predict on hist to get ref + """ + u = random.random(10000) + + # Define distributions + xd = uniform(loc=1, scale=1) + yd = uniform(loc=2, scale=2) + noise = uniform(loc=0, scale=1e-7) + + # Generate random numbers + x = xd.ppf(u) + y = yd.ppf(u) + noise.ppf(u) + + # Test train + attrs = {"units": units, "kind": kind} + + ref = mon_timelonlatseries(y, attrs=attrs) + hist = sim = timelonlatseries(x, attrs=attrs) + if use_dask: + ref = ref.chunk({"time": -1}) + hist = hist.chunk({"time": -1}) + sim = sim.chunk({"time": -1}) + if add_dims: + ref = ref.expand_dims(site=[0, 1, 2, 3, 4]).drop_vars("site") + hist = hist.expand_dims(site=[0, 1, 2, 3, 4]).drop_vars("site") + sim = sim.expand_dims(site=[0, 1, 2, 3, 4]).drop_vars("site") + sel = {"site": 0} + else: + sel = {} + + QDM = QuantileDeltaMapping.train( + ref, hist, kind=kind, group="time.month", nquantiles=40 + ) + p = QDM.adjust(sim, interp="linear" if kind == "+" else "nearest") + + q = QDM.ds.coords["quantiles"] + expected = get_correction(xd.ppf(q), yd.ppf(q), kind) + + expected = apply_correction( + mon_triangular[:, np.newaxis], expected[np.newaxis, :], kind + ) + np.testing.assert_array_almost_equal( + QDM.ds.af.sel(quantiles=q, **sel), expected, 1 + ) + + # Test predict + np.testing.assert_allclose(p, ref.transpose(*p.dims), rtol=0.1, atol=0.2) + + def test_seasonal(self, timelonlatseries, random): + u = random.random(10000) + kind = "+" + units = "K" + # Define distributions + xd = uniform(loc=1, scale=1) + yd = uniform(loc=2, scale=4) + + # Generate random numbers with u so we get exact results for comparison + x = xd.ppf(u) + y = yd.ppf(u) + + # Test train + attrs = {"units": units, "kind": kind} + + hist = sim = timelonlatseries(x, attrs=attrs) + ref = timelonlatseries(y, attrs=attrs) + + QDM = QuantileDeltaMapping.train( + ref.astype("float32"), + hist.astype("float32"), + kind=kind, + group="time.season", + nquantiles=10, + ) + p = QDM.adjust(sim.astype("float32"), interp="linear") + + # Test predict + # Accept discrepancies near extremes + middle = (u > 1e-2) * (u < 0.99) + np.testing.assert_array_almost_equal(p[middle], ref[middle], 1) + + def test_cannon_and_diagnostics(self, cannon_2015_dist, cannon_2015_rvs): + ref, hist, sim = cannon_2015_rvs(15000, random=False) + + # Quantile mapping + with set_options(sdba_extra_output=True): + QDM = QuantileDeltaMapping.train( + ref, hist, kind="*", group="time", nquantiles=50 + ) + scends = QDM.adjust(sim) + + assert isinstance(scends, xr.Dataset) + + # Theoretical results + # ref, hist, sim = 
cannon_2015_dist + # u1 = equally_spaced_nodes(1001, None) + # u = np.convolve(u1, [0.5, 0.5], mode="valid") + # pu = ref.ppf(u) * sim.ppf(u) / hist.ppf(u) + # pu1 = ref.ppf(u1) * sim.ppf(u1) / hist.ppf(u1) + # pdf = np.diff(u1) / np.diff(pu1) + + # mean = np.trapz(pdf * pu, pu) + # mom2 = np.trapz(pdf * pu ** 2, pu) + # std = np.sqrt(mom2 - mean ** 2) + bc_sim = scends.scen + np.testing.assert_almost_equal(bc_sim.mean(), 41.5, 1) + np.testing.assert_almost_equal(bc_sim.std(), 16.7, 0) + + +@pytest.mark.slow +class TestQM: + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + def test_quantiles(self, timelonlatseries, kind, units, random): + """Train on + hist: U + ref: Normal + + Predict on hist to get ref + """ + u = random.random(10000) + + # Define distributions + xd = uniform(loc=10, scale=1) + yd = norm(loc=12, scale=1) + + # Generate random numbers with u so we get exact results for comparison + x = xd.ppf(u) + y = yd.ppf(u) + + # Test train + attrs = {"units": units, "kind": kind} + + hist = sim = timelonlatseries(x, attrs={"units": units}) + ref = timelonlatseries(y, attrs={"units": units}) + + QM = EmpiricalQuantileMapping.train( + ref, + hist, + kind=kind, + group="time", + nquantiles=50, + ) + p = QM.adjust(sim, interp="linear") + + q = QM.ds.coords["quantiles"] + expected = get_correction(xd.ppf(q), yd.ppf(q), kind)[np.newaxis, :] + # Results are not so good at the endpoints + np.testing.assert_array_almost_equal(QM.ds.af[:, 2:-2], expected[:, 2:-2], 1) + + # Test predict + # Accept discrepancies near extremes + middle = (x > 1e-2) * (x < 0.99) + np.testing.assert_array_almost_equal(p[middle], ref[middle], 1) + + @pytest.mark.parametrize( + "kind,units", [(ADDITIVE, "K"), (MULTIPLICATIVE, "kg m-2 s-1")] + ) + def test_mon_u( + self, + mon_timelonlatseries, + timelonlatseries, + mon_triangular, + kind, + units, + random, + ): + """ + Train on + hist: U + ref: U + monthly cycle + + Predict on hist to get ref + """ + u = random.random(10000) + + # Define distributions + xd = uniform(loc=2, scale=0.1) + yd = uniform(loc=4, scale=0.1) + noise = uniform(loc=0, scale=1e-7) + + # Generate random numbers + x = xd.ppf(u) + y = yd.ppf(u) + noise.ppf(u) + + # Test train + attrs = {"units": units, "kind": kind} + + hist = sim = timelonlatseries(x, attrs=attrs) + ref = mon_timelonlatseries(y, attrs=attrs) + + QM = EmpiricalQuantileMapping.train( + ref, hist, kind=kind, group="time.month", nquantiles=5 + ) + p = QM.adjust(sim) + mqm = QM.ds.af.mean(dim="quantiles") + expected = apply_correction(mon_triangular, 2, kind) + np.testing.assert_array_almost_equal(mqm, expected, 1) + + # Test predict + np.testing.assert_array_almost_equal(p, ref, 2) + + # @pytest.mark.parametrize("use_dask", [True, False]) + # @pytest.mark.filterwarnings("ignore::RuntimeWarning") + # def test_add_dims(self, use_dask, open_dataset): + # with set_options(sdba_encode_cf=use_dask): + # if use_dask: + # chunks = {"location": -1} + # else: + # chunks = None + # ref = ( + # open_dataset( + # "sdba/ahccd_1950-2013.nc", + # chunks=chunks, + # drop_variables=["lat", "lon"], + # ) + # .sel(time=slice("1981", "2010")) + # .tasmax + # ) + # ref = convert_units_to(ref, "K") + # ref = ref.isel(location=1, drop=True).expand_dims(location=["Amos"]) + + # dsim = open_dataset( + # "sdba/CanESM2_1950-2100.nc", + # chunks=chunks, + # drop_variables=["lat", "lon"], + # ).tasmax + # hist = dsim.sel(time=slice("1981", "2010")) + # sim = dsim.sel(time=slice("2041", "2070")) + + # # With 
+    # @pytest.mark.parametrize("use_dask", [True, False])
+    # @pytest.mark.filterwarnings("ignore::RuntimeWarning")
+    # def test_add_dims(self, use_dask, open_dataset):
+    #     with set_options(sdba_encode_cf=use_dask):
+    #         if use_dask:
+    #             chunks = {"location": -1}
+    #         else:
+    #             chunks = None
+    #         ref = (
+    #             open_dataset(
+    #                 "sdba/ahccd_1950-2013.nc",
+    #                 chunks=chunks,
+    #                 drop_variables=["lat", "lon"],
+    #             )
+    #             .sel(time=slice("1981", "2010"))
+    #             .tasmax
+    #         )
+    #         ref = convert_units_to(ref, "K")
+    #         ref = ref.isel(location=1, drop=True).expand_dims(location=["Amos"])
+
+    #         dsim = open_dataset(
+    #             "sdba/CanESM2_1950-2100.nc",
+    #             chunks=chunks,
+    #             drop_variables=["lat", "lon"],
+    #         ).tasmax
+    #         hist = dsim.sel(time=slice("1981", "2010"))
+    #         sim = dsim.sel(time=slice("2041", "2070"))
+
+    #         # With add_dims, "does it run" test
+    #         group = Grouper("time.dayofyear", window=5, add_dims=["location"])
+    #         EQM = EmpiricalQuantileMapping.train(ref, hist, group=group)
+    #         EQM.adjust(sim).load()
+
+    #         # Without, sanity test.
+    #         group = Grouper("time.dayofyear", window=5)
+    #         EQM2 = EmpiricalQuantileMapping.train(ref, hist, group=group)
+    #         scen2 = EQM2.adjust(sim).load()
+    #         assert scen2.sel(location=["Kugluktuk", "Vancouver"]).isnull().all()
+
+
+class TestPrincipalComponents:
+    @pytest.mark.parametrize(
+        "group", (Grouper("time.month"), Grouper("time", add_dims=["lon"]))
+    )
+    def test_simple(self, group, random):
+        n = 15 * 365
+        m = 2  # A dummy dimension to test vectorizing.
+        ref_y = norm.rvs(loc=10, scale=1, size=(m, n), random_state=random)
+        ref_x = norm.rvs(loc=3, scale=2, size=(m, n), random_state=random)
+        sim_x = norm.rvs(loc=4, scale=2, size=(m, n), random_state=random)
+        sim_y = sim_x + norm.rvs(loc=1, scale=1, size=(m, n), random_state=random)
+
+        ref = xr.DataArray(
+            [ref_x, ref_y], dims=("lat", "lon", "time"), attrs={"units": "degC"}
+        )
+        ref["time"] = xr.cftime_range("1990-01-01", periods=n, calendar="noleap")
+        sim = xr.DataArray(
+            [sim_x, sim_y], dims=("lat", "lon", "time"), attrs={"units": "degC"}
+        )
+        sim["time"] = ref["time"]
+
+        PCA = PrincipalComponents.train(ref, sim, group=group, crd_dim="lat")
+        scen = PCA.adjust(sim)
+
+        def _assert(ds):
+            cov_ref = nancov(ds.ref.transpose("lat", "pt"))
+            cov_sim = nancov(ds.sim.transpose("lat", "pt"))
+            cov_scen = nancov(ds.scen.transpose("lat", "pt"))
+
+            # PC adjustment makes the covariance of scen match the one of ref.
+            np.testing.assert_allclose(cov_ref - cov_scen, 0, atol=1e-6)
+            with pytest.raises(AssertionError):
+                np.testing.assert_allclose(cov_ref - cov_sim, 0, atol=1e-6)
+
+        def _group_assert(ds, dim):
+            if "lon" not in dim:
+                for lon in ds.lon:
+                    _assert(ds.sel(lon=lon).stack(pt=dim))
+            else:
+                _assert(ds.stack(pt=dim))
+            return ds
+
+        group.apply(_group_assert, {"ref": ref, "sim": sim, "scen": scen})
+
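The `_assert` helper above checks the defining property of the principal-component adjustment: after adjustment, the covariance of `scen` across `crd_dim` matches that of `ref`. A minimal numpy illustration of why a linear map built from the two covariance factorizations has this effect (a sketch of the property being tested, not of the PrincipalComponents implementation, which works in the eigenvector basis rather than with the Cholesky factors used here):

# Map sim anomalies through a matrix A such that cov(A @ sim) == cov(ref);
# this is the property asserted by `cov_ref - cov_scen ~ 0` above.
import numpy as np

rng = np.random.default_rng(0)
ref = rng.multivariate_normal([10, 3], [[1.0, 0.5], [0.5, 2.0]], size=5000).T  # (lat, time)
sim = rng.multivariate_normal([4, 1], [[2.0, -0.3], [-0.3, 1.0]], size=5000).T

cov_ref, cov_sim = np.cov(ref), np.cov(sim)
A = np.linalg.cholesky(cov_ref) @ np.linalg.inv(np.linalg.cholesky(cov_sim))
scen = A @ (sim - sim.mean(axis=1, keepdims=True)) + ref.mean(axis=1, keepdims=True)

# cov(scen) reproduces cov_ref exactly (up to floating point), because A is
# built from the same sample covariance estimates.
np.testing.assert_allclose(np.cov(scen), cov_ref, atol=1e-8)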
+    # @pytest.mark.parametrize("use_dask", [True, False])
+    # @pytest.mark.parametrize("pcorient", ["full", "simple"])
+    # def test_real_data(self, atmosds, use_dask, pcorient):
+    #     ref = stack_variables(
+    #         xr.Dataset(
+    #             {"tasmax": atmosds.tasmax, "tasmin": atmosds.tasmin, "tas": atmosds.tas}
+    #         )
+    #     ).isel(location=3)
+    #     hist = stack_variables(
+    #         xr.Dataset(
+    #             {
+    #                 "tasmax": 1.001 * atmosds.tasmax,
+    #                 "tasmin": atmosds.tasmin - 0.25,
+    #                 "tas": atmosds.tas + 1,
+    #             }
+    #         )
+    #     ).isel(location=3)
+    #     with xr.set_options(keep_attrs=True):
+    #         sim = hist + 5
+    #         sim["time"] = sim.time + np.timedelta64(10, "Y").astype("<m8[ns]")
+
+#             qv = np.quantile(base[base > 1], q_thresh)
+#             base[base > qv] = genpareto.rvs(
+#                 c, loc=qv, scale=s, size=base[base > qv].shape, random_state=random
+#             )
+#             return xr.DataArray(
+#                 base,
+#                 dims=("time",),
+#                 coords={
+#                     "time": xr.cftime_range("1990-01-01", periods=n, calendar="noleap")
+#                 },
+#                 attrs={"units": "mm/day", "thresh": qv},
+#             )
+
+#         ref = jitter_under_thresh(gen_testdata(-0.1, 2), "1e-3 mm/d")
+#         hist = jitter_under_thresh(gen_testdata(-0.1, 2), "1e-3 mm/d")
+#         sim = gen_testdata(-0.15, 2.5)
+
+#         EQM = EmpiricalQuantileMapping.train(
+#             ref, hist, group="time.dayofyear", nquantiles=15, kind="*"
+#         )
+
+#         scen = EQM.adjust(sim)
+
+#         EX = ExtremeValues.train(ref, hist, cluster_thresh=c_thresh, q_thresh=q_thresh)
+#         qv = (ref.thresh + hist.thresh) / 2
+#         np.testing.assert_allclose(EX.ds.thresh, qv, atol=0.15, rtol=0.01)
+
+#         scen2 = EX.adjust(scen, sim, frac=frac, power=power)
+
+#         # What to test???
+#         # Test if extreme values of sim are still extreme
+#         exval = sim > EX.ds.thresh
+#         assert (scen2.where(exval) > EX.ds.thresh).sum() > (
+#             scen.where(exval) > EX.ds.thresh
+#         ).sum()
+
+#     @pytest.mark.slow
+#     def test_real_data(self, open_dataset):
+#         dsim = open_dataset("sdba/CanESM2_1950-2100.nc").chunk()
+#         dref = open_dataset("sdba/ahccd_1950-2013.nc").chunk()
+
+#         ref = convert_units_to(
+#             dref.sel(time=slice("1950", "2009")).pr, "mm/d", context="hydro"
+#         )
+#         hist = convert_units_to(
+#             dsim.sel(time=slice("1950", "2009")).pr, "mm/d", context="hydro"
+#         )
+
+#         quantiles = np.linspace(0.01, 0.99, num=50)
+
+#         with xr.set_options(keep_attrs=True):
+#             ref = ref + uniform_noise_like(ref, low=1e-6, high=1e-3)
+#             hist = hist + uniform_noise_like(hist, low=1e-6, high=1e-3)
+
+#         EQM = EmpiricalQuantileMapping.train(
+#             ref, hist, group=Grouper("time.dayofyear", window=31), nquantiles=quantiles
+#         )
+
+#         scen = EQM.adjust(hist, interp="linear", extrapolation="constant")
+
+#         EX = ExtremeValues.train(ref, hist, cluster_thresh="1 mm/day", q_thresh=0.97)
+#         new_scen = EX.adjust(scen, hist, frac=0.000000001)
+#         new_scen.load()
+
+
+def test_raise_on_multiple_chunks(timelonlatseries):
+    ref = timelonlatseries(np.arange(730).astype(float)).chunk({"time": 365})
+    with pytest.raises(ValueError):
+        EmpiricalQuantileMapping.train(ref, ref, group=Grouper("time.month"))
+
+
+def test_default_grouper_understood(timelonlatseries):
+    attrs = {"units": "K", "kind": ADDITIVE}
+
+    ref = timelonlatseries(np.arange(730).astype(float), attrs={"units": "K"})
+
+    EQM = EmpiricalQuantileMapping.train(ref, ref)
+    EQM.adjust(ref)
+    assert EQM.group.dim == "time"
+
+
+class TestSBCKutils:
+    @pytest.mark.slow
+    @pytest.mark.parametrize(
+        "method",
+        [m for m in dir(adjustment) if m.startswith("SBCK_")],
+    )
+    @pytest.mark.parametrize("use_dask", [True])  # do we gain testing both?
+    def test_sbck(self, method, use_dask, random):
+        SBCK = pytest.importorskip("SBCK", minversion="0.4.0")
+
+        n = 10 * 365
+        m = 2  # A dummy dimension to test vectorization.
+        ref_y = norm.rvs(loc=10, scale=1, size=(m, n), random_state=random)
+        ref_x = norm.rvs(loc=3, scale=2, size=(m, n), random_state=random)
+        hist_x = norm.rvs(loc=11, scale=1.2, size=(m, n), random_state=random)
+        hist_y = norm.rvs(loc=4, scale=2.2, size=(m, n), random_state=random)
+        sim_x = norm.rvs(loc=12, scale=2, size=(m, n), random_state=random)
+        sim_y = norm.rvs(loc=3, scale=1.8, size=(m, n), random_state=random)
+
+        ref = xr.Dataset(
+            {
+                "tasmin": xr.DataArray(
+                    ref_x, dims=("lon", "time"), attrs={"units": "degC"}
+                ),
+                "tasmax": xr.DataArray(
+                    ref_y, dims=("lon", "time"), attrs={"units": "degC"}
+                ),
+            }
+        )
+        ref["time"] = xr.cftime_range("1990-01-01", periods=n, calendar="noleap")
+
+        hist = xr.Dataset(
+            {
+                "tasmin": xr.DataArray(
+                    hist_x, dims=("lon", "time"), attrs={"units": "degC"}
+                ),
+                "tasmax": xr.DataArray(
+                    hist_y, dims=("lon", "time"), attrs={"units": "degC"}
+                ),
+            }
+        )
+        hist["time"] = ref["time"]
+
+        sim = xr.Dataset(
+            {
+                "tasmin": xr.DataArray(
+                    sim_x, dims=("lon", "time"), attrs={"units": "degC"}
+                ),
+                "tasmax": xr.DataArray(
+                    sim_y, dims=("lon", "time"), attrs={"units": "degC"}
+                ),
+            }
+        )
+        sim["time"] = xr.cftime_range("2090-01-01", periods=n, calendar="noleap")
+
+        if use_dask:
+            ref = ref.chunk({"lon": 1})
+            hist = hist.chunk({"lon": 1})
+            sim = sim.chunk({"lon": 1})
+
+        if "TSMBC" in method:
+            kws = {"lag": 1}
+        elif "MBCn" in method:
+            kws = {"metric": SBCK.metrics.energy}
+        else:
+            kws = {}
+
+        scen = getattr(adjustment, method).adjust(
+            stack_variables(ref),
+            stack_variables(hist),
+            stack_variables(sim),
+            multi_dim="multivar",
+            **kws,
+        )
+        unstack_variables(scen).load()
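The SBCK wrappers operate on a single DataArray with a "multivar" dimension rather than on a Dataset, which is why test_sbck packs everything through `stack_variables`/`unstack_variables` before calling `adjust`. A short usage sketch of that round-trip (import path assumed to mirror xclim.sdba, i.e. `xsdba.processing`; adjust if these helpers are exposed elsewhere):

# Hypothetical minimal example of the stacking round-trip used by test_sbck.
import numpy as np
import xarray as xr
from xsdba.processing import stack_variables, unstack_variables  # assumed location

time = xr.cftime_range("1990-01-01", periods=365, calendar="noleap")
ds = xr.Dataset(
    {
        "tasmin": xr.DataArray(np.random.rand(365), dims="time", attrs={"units": "degC"}),
        "tasmax": xr.DataArray(np.random.rand(365), dims="time", attrs={"units": "degC"}),
    },
    coords={"time": time},
)

da = stack_variables(ds)            # DataArray with a "multivar" dimension
assert "multivar" in da.dims        # this is what multi_dim="multivar" refers to
round_trip = unstack_variables(da)  # back to a Dataset with tasmin/tasmax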