From 45ec1b6ee6f7b8695c0df921b037bcbf1739ffeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Thu, 1 Aug 2024 17:18:04 -0400 Subject: [PATCH] add MBCn tests --- src/xsdba/__init__.py | 1 + src/xsdba/calendar.py | 38 +++++- src/xsdba/locales.py | 2 +- src/xsdba/nbutils.py | 16 +-- src/xsdba/units.py | 1 + src/xsdba/utils.py | 256 +++++++++++++-------------------------- tests/test_adjustment.py | 215 +++++++++++++++++++------------- 7 files changed, 266 insertions(+), 263 deletions(-) diff --git a/src/xsdba/__init__.py b/src/xsdba/__init__.py index c1c626e..81088dc 100644 --- a/src/xsdba/__init__.py +++ b/src/xsdba/__init__.py @@ -36,6 +36,7 @@ from .adjustment import * from .base import Grouper from .options import set_options +from .processing import stack_variables, unstack_variables # from .processing import stack_variables, unstack_variables diff --git a/src/xsdba/calendar.py b/src/xsdba/calendar.py index e32cdcd..c3909d5 100644 --- a/src/xsdba/calendar.py +++ b/src/xsdba/calendar.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd import xarray as xr +from boltons.funcutils import wraps from xarray.coding.cftime_offsets import to_cftime_datetime from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import dtypes @@ -43,6 +44,7 @@ "doy_from_string", "doy_to_days_since", "ensure_cftime_array", + "ensure_longest_doy", "get_calendar", "interp_calendar", "is_offset_divisor", @@ -554,7 +556,9 @@ def compare_offsets(freqA: str, op: str, freqB: str) -> bool: bool freqA op freqB """ - from ..indices.generic import get_op # pylint: disable=import-outside-toplevel + from .xclim_submodules.generic import ( # pylint: disable=import-outside-toplevel + get_op, + ) # Get multiplier and base frequency t_a, b_a, _, _ = parse_offset(freqA) @@ -704,6 +708,38 @@ def is_offset_divisor(divisor: str, offset: str): return all(offAs.is_on_offset(d) for d in tB) +def ensure_longest_doy(func: Callable) -> Callable: + """Ensure that selected day is the longest day of year for x and y dims.""" + + @wraps(func) + def _ensure_longest_doy(x, y, *args, **kwargs): + if ( + hasattr(x, "dims") + and hasattr(y, "dims") + and "dayofyear" in x.dims + and "dayofyear" in y.dims + and x.dayofyear.max() != y.dayofyear.max() + ): + warn( + ( + "get_correction received inputs defined on different dayofyear ranges. " + "Interpolating to the longest range. Results could be strange." + ), + stacklevel=4, + ) + if x.dayofyear.max() < y.dayofyear.max(): + x = _interpolate_doy_calendar( + x, int(y.dayofyear.max()), int(y.dayofyear.min()) + ) + else: + y = _interpolate_doy_calendar( + y, int(x.dayofyear.max()), int(x.dayofyear.min()) + ) + return func(x, y, *args, **kwargs) + + return _ensure_longest_doy + + def _interpolate_doy_calendar( source: xr.DataArray, doy_max: int, doy_min: int = 1 ) -> xr.DataArray: diff --git a/src/xsdba/locales.py b/src/xsdba/locales.py index 8ff876e..e3947b1 100644 --- a/src/xsdba/locales.py +++ b/src/xsdba/locales.py @@ -288,7 +288,7 @@ def generate_local_dict(locale: str, init_english: bool = False) -> dict: If True, fills the initial dictionary with the english versions of the attributes. Defaults to False. 
""" - from ..core.indicator import registry # pylint: disable=import-outside-toplevel + from .indicator import registry # pylint: disable=import-outside-toplevel if locale in _LOCALES: _, attrs = get_local_dict(locale) diff --git a/src/xsdba/nbutils.py b/src/xsdba/nbutils.py index d5b2864..88fc3d4 100644 --- a/src/xsdba/nbutils.py +++ b/src/xsdba/nbutils.py @@ -26,11 +26,10 @@ nogil=True, cache=False, ) -def _get_indexes( # noqa: PR07 +def _get_indexes( arr: np.array, virtual_indexes: np.array, valid_values_count: np.array ) -> tuple[np.array, np.array]: - """ - Get the valid indexes of arr neighbouring virtual_indexes. + """Get the valid indexes of arr neighbouring virtual_indexes. Parameters ---------- @@ -41,7 +40,7 @@ def _get_indexes( # noqa: PR07 Returns ------- array-like, array-like - A tuple of virtual_indexes neighbouring indexes (previous and next). + A tuple of virtual_indexes neighbouring indexes (previous and next) Notes ----- @@ -210,7 +209,8 @@ def _wrapper_quantile1d(arr, q): return out -def _quantile(arr, q, nreduce): +def _quantile(arr, q, nreduce=None): + nreduce = nreduce or arr.ndim if arr.ndim == nreduce: out = _nan_quantile_1d(arr.flatten(), q) else: @@ -277,7 +277,7 @@ def quantile(da: DataArray, q: np.ndarray, dim: str | Sequence[Hashable]) -> Dat nogil=True, cache=False, ) -def remove_NaNs(x): # noqa: N802 +def remove_NaNs(x): # noqa """Remove NaN values from series.""" remove = np.zeros_like(x[0, :], dtype=boolean) for i in range(x.shape[0]): @@ -386,7 +386,9 @@ def _first_and_last_nonnull(arr): nogil=True, cache=False, ) -def _extrapolate_on_quantiles(interp, oldx, oldg, oldy, newx, newg, method="constant"): +def _extrapolate_on_quantiles( + interp, oldx, oldg, oldy, newx, newg, method="constant" +): # noqa """Apply extrapolation to the output of interpolation on quantiles with a given grouping. Arguments are the same as _interp_on_quantiles_2D. diff --git a/src/xsdba/units.py b/src/xsdba/units.py index 2342ad4..1a0bac4 100644 --- a/src/xsdba/units.py +++ b/src/xsdba/units.py @@ -6,6 +6,7 @@ import inspect from copy import deepcopy from functools import wraps +from typing import Any import pint diff --git a/src/xsdba/utils.py b/src/xsdba/utils.py index 34dcdfb..5c6264c 100644 --- a/src/xsdba/utils.py +++ b/src/xsdba/utils.py @@ -9,15 +9,16 @@ from typing import Callable from warnings import warn +import bottleneck as bn import numpy as np import xarray as xr -from boltons.funcutils import wraps from dask import array as dsk from scipy.interpolate import griddata, interp1d from scipy.stats import spearmanr from xarray.core.utils import get_temp_dimname from .base import Grouper, parse_group, uses_dask +from .calendar import ensure_longest_doy from .nbutils import _extrapolate_on_quantiles MULTIPLICATIVE = "*" @@ -49,18 +50,17 @@ def map_cdf( Parameters ---------- ds : xr.Dataset - Variables: - x : Values from which to pick. - y : Reference values giving the ranking. + Variables: x, Values from which to pick, + y, Reference values giving the ranking y_value : float, array - Value within the support of `y`. + Value within the support of `y`. dim : str - Dimension along which to compute quantile. + Dimension along which to compute quantile. Returns ------- array - Quantile of `x` with the same CDF as `y_value` in `y`. + Quantile of `x` with the same CDF as `y_value` in `y`. 
""" return xr.apply_ufunc( map_cdf_1d, @@ -95,135 +95,6 @@ def ecdf(x: xr.DataArray, value: float, dim: str = "time") -> xr.DataArray: return (x <= value).sum(dim) / x.notnull().sum(dim) -# XC -def ensure_chunk_size(da: xr.DataArray, **minchunks: int) -> xr.DataArray: - r"""Ensure that the input DataArray has chunks of at least the given size. - - If only one chunk is too small, it is merged with an adjacent chunk. - If many chunks are too small, they are grouped together by merging adjacent chunks. - - Parameters - ---------- - da : xr.DataArray - The input DataArray, with or without the dask backend. Does nothing when passed a non-dask array. - \*\*minchunks : dict[str, int] - A kwarg mapping from dimension name to minimum chunk size. - Pass -1 to force a single chunk along that dimension. - - Returns - ------- - xr.DataArray - """ - if not uses_dask(da): - return da - - all_chunks = dict(zip(da.dims, da.chunks)) - chunking = {} - for dim, minchunk in minchunks.items(): - chunks = all_chunks[dim] - if minchunk == -1 and len(chunks) > 1: - # Rechunk to single chunk only if it's not already one - chunking[dim] = -1 - - toosmall = np.array(chunks) < minchunk # Chunks that are too small - if toosmall.sum() > 1: - # Many chunks are too small, merge them by groups - fac = np.ceil(minchunk / min(chunks)).astype(int) - chunking[dim] = tuple( - sum(chunks[i : i + fac]) for i in range(0, len(chunks), fac) - ) - # Reset counter is case the last chunks are still too small - chunks = chunking[dim] - toosmall = np.array(chunks) < minchunk - if toosmall.sum() == 1: - # Only one, merge it with adjacent chunk - ind = np.where(toosmall)[0][0] - new_chunks = list(chunks) - sml = new_chunks.pop(ind) - new_chunks[max(ind - 1, 0)] += sml - chunking[dim] = tuple(new_chunks) - - if chunking: - return da.chunk(chunks=chunking) - return da - - -# XC -def _interpolate_doy_calendar( - source: xr.DataArray, doy_max: int, doy_min: int = 1 -) -> xr.DataArray: - """Interpolate from one set of dayofyear range to another. - - Interpolate an array defined over a `dayofyear` range (say 1 to 360) to another `dayofyear` range (say 1 - to 365). - - Parameters - ---------- - source : xr.DataArray - Array with `dayofyear` coordinates. - doy_max : int - The largest day of the year allowed by calendar. - doy_min : int - The smallest day of the year in the output. - This parameter is necessary when the target time series does not span over a full year (e.g. JJA season). - Default is 1. - - Returns - ------- - xr.DataArray - Interpolated source array over coordinates spanning the target `dayofyear` range. - """ - if "dayofyear" not in source.coords.keys(): - raise AttributeError("Source should have `dayofyear` coordinates.") - - # Interpolate to fill na values - da = source - if uses_dask(source): - # interpolate_na cannot run on chunked dayofyear. 
- da = source.chunk(dict(dayofyear=-1)) - filled_na = da.interpolate_na(dim="dayofyear") - - # Interpolate to target dayofyear range - filled_na.coords["dayofyear"] = np.linspace( - start=doy_min, stop=doy_max, num=len(filled_na.coords["dayofyear"]) - ) - - return filled_na.interp(dayofyear=range(doy_min, doy_max + 1)) - - -# XC -def ensure_longest_doy(func: Callable) -> Callable: - """Ensure that selected day is the longest day of year for x and y dims.""" - - @wraps(func) - def _ensure_longest_doy(x, y, *args, **kwargs): - if ( - hasattr(x, "dims") - and hasattr(y, "dims") - and "dayofyear" in x.dims - and "dayofyear" in y.dims - and x.dayofyear.max() != y.dayofyear.max() - ): - warn( - ( - "get_correction received inputs defined on different dayofyear ranges. " - "Interpolating to the longest range. Results could be strange." - ), - stacklevel=4, - ) - if x.dayofyear.max() < y.dayofyear.max(): - x = _interpolate_doy_calendar( - x, int(y.dayofyear.max()), int(y.dayofyear.min()) - ) - else: - y = _interpolate_doy_calendar( - y, int(x.dayofyear.max()), int(x.dayofyear.min()) - ) - return func(x, y, *args, **kwargs) - - return _ensure_longest_doy - - @ensure_longest_doy def get_correction(x: xr.DataArray, y: xr.DataArray, kind: str) -> xr.DataArray: """Return the additive or multiplicative correction/adjustment factors.""" @@ -405,6 +276,39 @@ def add_cyclic_bounds( return ensure_chunk_size(qmf, **{att: -1}) +def _interp_on_quantiles_1D_multi(newxs, oldx, oldy, method, extrap): # noqa: N802 + # Perform multiple interpolations with a single call of interp1d. + # This should be used when `oldx` is common for many data arrays (`newxs`) + # that we want to interpolate on. For instance, with QuantileDeltaMapping, we simply + # interpolate on quantiles that always remain the same. + if len(newxs.shape) == 1: + return _interp_on_quantiles_1D(newxs, oldx, oldy, method, extrap) + mask_old = np.isnan(oldy) | np.isnan(oldx) + if extrap == "constant": + fill_value = ( + oldy[~np.isnan(oldy)][0], + oldy[~np.isnan(oldy)][-1], + ) + else: # extrap == 'nan' + fill_value = np.NaN + + finterp1d = interp1d( + oldx[~mask_old], + oldy[~mask_old], + kind=method, + bounds_error=False, + fill_value=fill_value, + ) + + out = np.zeros_like(newxs) + for ii in range(newxs.shape[0]): + mask_new = np.isnan(newxs[ii, :]) + y1 = newxs[ii, :].copy() * np.NaN + y1[~mask_new] = finterp1d(newxs[ii, ~mask_new]) + out[ii, :] = y1.flatten() + return out + + def _interp_on_quantiles_1D(newx, oldx, oldy, method, extrap): # noqa: N802 mask_new = np.isnan(newx) mask_old = np.isnan(oldy) | np.isnan(oldx) @@ -627,6 +531,14 @@ def rank( return rnk +def _rank_bn(arr, axis=None): + """Ranking on a specific axis""" + rnk = bn.nanrankdata(arr, axis=axis) + rnk = rnk / np.nanmax(rnk, axis=axis, keepdims=True) + mx, mn = 1, np.nanmin(rnk, axis=axis, keepdims=True) + return mx * (rnk - mn) / (mx - mn) + + def pc_matrix(arr: np.ndarray | dsk.Array) -> np.ndarray | dsk.Array: """Construct a Principal Component matrix. @@ -678,17 +590,17 @@ def best_pc_orientation_simple( Parameters ---------- R : np.ndarray - MxM Matrix defining the final transformation. + MxM Matrix defining the final transformation. Hinv : np.ndarray - MxM Matrix defining the (inverse) first transformation. + MxM Matrix defining the (inverse) first transformation. val : float - The coordinate of the test point (same for all axes). It should be much - greater than the largest furthest point in the array used to define B. + The coordinate of the test point (same for all axes). 
It should be much + greater than the largest furthest point in the array used to define B. Returns ------- np.ndarray - Mx1 vector of orientation correction (1 or -1). + Mx1 vector of orientation correction (1 or -1). See Also -------- @@ -728,20 +640,20 @@ def best_pc_orientation_full( Parameters ---------- R : np.ndarray - MxM Matrix defining the final transformation. + MxM Matrix defining the final transformation. Hinv : np.ndarray - MxM Matrix defining the (inverse) first transformation. + MxM Matrix defining the (inverse) first transformation. Rmean : np.ndarray - M vector defining the target distribution center point. + M vector defining the target distribution center point. Hmean : np.ndarray - M vector defining the original distribution center point. + M vector defining the original distribution center point. hist : np.ndarray - MxN matrix of all training observations of the M variables/sites. + MxN matrix of all training observations of the M variables/sites. Returns ------- np.ndarray - M vector of orientation correction (1 or -1). + M vector of orientation correction (1 or -1). References ---------- @@ -830,27 +742,27 @@ def get_clusters(data: xr.DataArray, u1, u2, dim: str = "time") -> xr.Dataset: Parameters ---------- - data : 1D ndarray - Values to get clusters from. + data: 1D ndarray + Values to get clusters from. u1 : float - Extreme value threshold, at least one value in the cluster must exceed this. + Extreme value threshold, at least one value in the cluster must exceed this. u2 : float - Cluster threshold, values above this can be part of a cluster. + Cluster threshold, values above this can be part of a cluster. dim : str - Dimension name. + Dimension name. Returns ------- xr.Dataset - With variables, - - `nclusters` : Number of clusters for each point (with `dim` reduced), int - - `start` : First index in the cluster (`dim` reduced, new `cluster`), int - - `end` : Last index in the cluster, inclusive (`dim` reduced, new `cluster`), int - - `maxpos` : Index of the maximal value within the cluster (`dim` reduced, new `cluster`), int - - `maximum` : Maximal value within the cluster (`dim` reduced, new `cluster`), same dtype as data. - - For `start`, `end` and `maxpos`, -1 means NaN and should always correspond to a `NaN` in `maximum`. - The length along `cluster` is half the size of "dim", the maximal theoretical number of clusters. + With variables, + - `nclusters` : Number of clusters for each point (with `dim` reduced), int + - `start` : First index in the cluster (`dim` reduced, new `cluster`), int + - `end` : Last index in the cluster, inclusive (`dim` reduced, new `cluster`), int + - `maxpos` : Index of the maximal value within the cluster (`dim` reduced, new `cluster`), int + - `maximum` : Maximal value within the cluster (`dim` reduced, new `cluster`), same dtype as data. + + For `start`, `end` and `maxpos`, -1 means NaN and should always correspond to a `NaN` in `maximum`. + The length along `cluster` is half the size of "dim", the maximal theoretical number of clusters. """ def _get_clusters(arr, u1, u2, N): @@ -914,19 +826,19 @@ def rand_rot_matrix( Parameters ---------- - crd : xr.DataArray - 1D coordinate DataArray along which the rotation occurs. - The output will be square with the same coordinate replicated, - the second renamed to `new_dim`. + crd: xr.DataArray + 1D coordinate DataArray along which the rotation occurs. + The output will be square with the same coordinate replicated, + the second renamed to `new_dim`. 
num : int - If larger than 1 (default), the number of matrices to generate, stacked along a "matrices" dimension. + If larger than 1 (default), the number of matrices to generate, stacked along a "matrices" dimension. new_dim : str - Name of the new "prime" dimension, defaults to the same name as `crd` + "_prime". + Name of the new "prime" dimension, defaults to the same name as `crd` + "_prime". Returns ------- xr.DataArray - Data of type float, NxN if num = 1, numxNxN otherwise, where N is the length of crd. + float, NxN if num = 1, numxNxN otherwise, where N is the length of crd. References ---------- @@ -950,9 +862,11 @@ def rand_rot_matrix( num = np.diag(R) denum = np.abs(num) lam = np.diag(num / denum) # "lambda" - return xr.DataArray( - Q @ lam, dims=(dim, new_dim), coords={dim: crd, new_dim: crd2} - ).astype("float32") + return ( + xr.DataArray(Q @ lam, dims=(dim, new_dim), coords={dim: crd, new_dim: crd2}) + .astype("float32") + .assign_attrs({"crd_dim": dim, "new_dim": new_dim}) + ) def _pairwise_spearman(da, dims): diff --git a/tests/test_adjustment.py b/tests/test_adjustment.py index b2aaf1e..cf4e5ab 100644 --- a/tests/test_adjustment.py +++ b/tests/test_adjustment.py @@ -12,11 +12,13 @@ DetrendedQuantileMapping, EmpiricalQuantileMapping, ExtremeValues, + MBCn, PrincipalComponents, QuantileDeltaMapping, Scaling, ) from xsdba.base import Grouper +from xsdba.calendar import stack_periods from xsdba.options import set_options from xsdba.processing import ( jitter_under_thresh, @@ -582,44 +584,44 @@ def test_mon_u( # Test predict np.testing.assert_array_almost_equal(p, ref, 2) - # @pytest.mark.parametrize("use_dask", [True, False]) - # @pytest.mark.filterwarnings("ignore::RuntimeWarning") - # def test_add_dims(self, use_dask, open_dataset): - # with set_options(sdba_encode_cf=use_dask): - # if use_dask: - # chunks = {"location": -1} - # else: - # chunks = None - # ref = ( - # open_dataset( - # "sdba/ahccd_1950-2013.nc", - # chunks=chunks, - # drop_variables=["lat", "lon"], - # ) - # .sel(time=slice("1981", "2010")) - # .tasmax - # ) - # ref = convert_units_to(ref, "K") - # ref = ref.isel(location=1, drop=True).expand_dims(location=["Amos"]) - - # dsim = open_dataset( - # "sdba/CanESM2_1950-2100.nc", - # chunks=chunks, - # drop_variables=["lat", "lon"], - # ).tasmax - # hist = dsim.sel(time=slice("1981", "2010")) - # sim = dsim.sel(time=slice("2041", "2070")) - - # # With add_dims, "does it run" test - # group = Grouper("time.dayofyear", window=5, add_dims=["location"]) - # EQM = EmpiricalQuantileMapping.train(ref, hist, group=group) - # EQM.adjust(sim).load() - - # # Without, sanity test. 
- # group = Grouper("time.dayofyear", window=5) - # EQM2 = EmpiricalQuantileMapping.train(ref, hist, group=group) - # scen2 = EQM2.adjust(sim).load() - # assert scen2.sel(location=["Kugluktuk", "Vancouver"]).isnull().all() + @pytest.mark.parametrize("use_dask", [True, False]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_add_dims(self, use_dask, open_dataset): + with set_options(sdba_encode_cf=use_dask): + if use_dask: + chunks = {"location": -1} + else: + chunks = None + ref = ( + open_dataset( + "sdba/ahccd_1950-2013.nc", + chunks=chunks, + drop_variables=["lat", "lon"], + ) + .sel(time=slice("1981", "2010")) + .tasmax + ) + ref = convert_units_to(ref, "K") + ref = ref.isel(location=1, drop=True).expand_dims(location=["Amos"]) + + dsim = open_dataset( + "sdba/CanESM2_1950-2100.nc", + chunks=chunks, + drop_variables=["lat", "lon"], + ).tasmax + hist = dsim.sel(time=slice("1981", "2010")) + sim = dsim.sel(time=slice("2041", "2070")) + + # With add_dims, "does it run" test + group = Grouper("time.dayofyear", window=5, add_dims=["location"]) + EQM = EmpiricalQuantileMapping.train(ref, hist, group=group) + EQM.adjust(sim).load() + + # Without, sanity test. + group = Grouper("time.dayofyear", window=5) + EQM2 = EmpiricalQuantileMapping.train(ref, hist, group=group) + scen2 = EQM2.adjust(sim).load() + assert scen2.sel(location=["Kugluktuk", "Vancouver"]).isnull().all() class TestPrincipalComponents: @@ -666,49 +668,49 @@ def _group_assert(ds, dim): group.apply(_group_assert, {"ref": ref, "sim": sim, "scen": scen}) - # @pytest.mark.parametrize("use_dask", [True, False]) - # @pytest.mark.parametrize("pcorient", ["full", "simple"]) - # def test_real_data(self, atmosds, use_dask, pcorient): - # ref = stack_variables( - # xr.Dataset( - # {"tasmax": atmosds.tasmax, "tasmin": atmosds.tasmin, "tas": atmosds.tas} - # ) - # ).isel(location=3) - # hist = stack_variables( - # xr.Dataset( - # { - # "tasmax": 1.001 * atmosds.tasmax, - # "tasmin": atmosds.tasmin - 0.25, - # "tas": atmosds.tas + 1, - # } - # ) - # ).isel(location=3) - # with xr.set_options(keep_attrs=True): - # sim = hist + 5 - # sim["time"] = sim.time + np.timedelta64(10, "Y").astype(" kg m-2 d-1 + ref["pr"] = pint_multiply(ref["pr"], "1000 kg/m^3") + dsim = dsim.sel(time=slice("1981", None)) + sim = (stack_periods(dsim).isel(period=slice(1, 2))).isel( + time=slice(365 * 4) + ) + + ref, hist, sim = (stack_variables(ds) for ds in [ref, hist, sim]) + + MBCN = MBCn.train( + ref, + hist, + base_kws=dict(nquantiles=50, group=Grouper(group, window)), + adj_kws=dict(interp="linear"), + ) + p = MBCN.adjust(sim=sim, ref=ref, hist=hist, period_dim=period_dim) + # 'does it run' test + p.load() + + # class TestSBCKutils: # @pytest.mark.slow # @pytest.mark.parametrize(
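
Note on the API exercised by the new test: `TestMBCn.test_simple` trains with `MBCn.train(ref, hist, base_kws=..., adj_kws=...)` and then calls `MBCN.adjust(sim=sim, ref=ref, hist=hist, period_dim=period_dim)`, with `ref`, `hist` and `sim` stacked into a single "multivar" array via `stack_variables` (newly re-exported from the package root in this patch). Below is a minimal, self-contained sketch of that call pattern on synthetic data. It is illustrative only: the `_make_ds` helper and its random data are placeholders (not part of the patch), and it assumes `MBCn` is importable from the top-level `xsdba` namespace through the existing `from .adjustment import *`.

    import numpy as np
    import pandas as pd
    import xarray as xr

    from xsdba import MBCn, stack_variables  # MBCn assumed re-exported at package root
    from xsdba.base import Grouper

    rng = np.random.default_rng(0)
    time = pd.date_range("1981-01-01", periods=4 * 365, freq="D")


    def _make_ds(shift=0.0):
        """Hypothetical helper: two loosely correlated variables with units attributes."""
        doy = np.arange(time.size)
        tas = 283.0 + 10.0 * np.sin(2 * np.pi * doy / 365.25) + rng.normal(0, 2, time.size)
        pr = np.clip(rng.gamma(2.0, 1.5, time.size) + 0.1 * (tas - 283.0), 0.0, None)
        return xr.Dataset(
            {
                "tasmax": ("time", tas + shift, {"units": "K"}),
                "pr": ("time", pr, {"units": "mm/d"}),
            },
            coords={"time": time},
        )


    # Stack the variables of each dataset along a common "multivar" dimension.
    ref = stack_variables(_make_ds())
    hist = stack_variables(_make_ds(shift=1.0))
    sim = stack_variables(_make_ds(shift=2.0))

    # Train and adjust, mirroring the keyword arguments used in the new test.
    MBCN = MBCn.train(
        ref,
        hist,
        base_kws=dict(nquantiles=50, group=Grouper("time.dayofyear", 31)),
        adj_kws=dict(interp="linear"),
    )
    scen = MBCN.adjust(sim=sim, ref=ref, hist=hist)
    scen.load()

The test itself goes one step further: it builds `sim` from periods stacked with `stack_periods` and forwards `period_dim` to `adjust`, so several stacked periods are corrected in a single call; the sketch above omits that step.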