Skip to content

Commit

Permalink
add MBCn tests
Browse files Browse the repository at this point in the history
  • Loading branch information
coxipi committed Aug 1, 2024
1 parent 52546bc commit 45ec1b6
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 263 deletions.
1 change: 1 addition & 0 deletions src/xsdba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .adjustment import *
from .base import Grouper
from .options import set_options
from .processing import stack_variables, unstack_variables

# from .processing import stack_variables, unstack_variables

Expand Down
38 changes: 37 additions & 1 deletion src/xsdba/calendar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from boltons.funcutils import wraps
from xarray.coding.cftime_offsets import to_cftime_datetime
from xarray.coding.cftimeindex import CFTimeIndex
from xarray.core import dtypes
Expand Down Expand Up @@ -43,6 +44,7 @@
"doy_from_string",
"doy_to_days_since",
"ensure_cftime_array",
"ensure_longest_doy",
"get_calendar",
"interp_calendar",
"is_offset_divisor",
Expand Down Expand Up @@ -554,7 +556,9 @@ def compare_offsets(freqA: str, op: str, freqB: str) -> bool:
bool
freqA op freqB
"""
from ..indices.generic import get_op # pylint: disable=import-outside-toplevel
from .xclim_submodules.generic import ( # pylint: disable=import-outside-toplevel
get_op,
)

# Get multiplier and base frequency
t_a, b_a, _, _ = parse_offset(freqA)
Expand Down Expand Up @@ -704,6 +708,38 @@ def is_offset_divisor(divisor: str, offset: str):
return all(offAs.is_on_offset(d) for d in tB)


def ensure_longest_doy(func: Callable) -> Callable:
"""Ensure that selected day is the longest day of year for x and y dims."""

@wraps(func)
def _ensure_longest_doy(x, y, *args, **kwargs):
if (
hasattr(x, "dims")
and hasattr(y, "dims")
and "dayofyear" in x.dims
and "dayofyear" in y.dims
and x.dayofyear.max() != y.dayofyear.max()
):
warn(
(
"get_correction received inputs defined on different dayofyear ranges. "
"Interpolating to the longest range. Results could be strange."
),
stacklevel=4,
)
if x.dayofyear.max() < y.dayofyear.max():
x = _interpolate_doy_calendar(
x, int(y.dayofyear.max()), int(y.dayofyear.min())
)
else:
y = _interpolate_doy_calendar(
y, int(x.dayofyear.max()), int(x.dayofyear.min())
)
return func(x, y, *args, **kwargs)

return _ensure_longest_doy


def _interpolate_doy_calendar(
source: xr.DataArray, doy_max: int, doy_min: int = 1
) -> xr.DataArray:
Expand Down
2 changes: 1 addition & 1 deletion src/xsdba/locales.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def generate_local_dict(locale: str, init_english: bool = False) -> dict:
If True, fills the initial dictionary with the english versions of the attributes.
Defaults to False.
"""
from ..core.indicator import registry # pylint: disable=import-outside-toplevel
from .indicator import registry # pylint: disable=import-outside-toplevel

if locale in _LOCALES:
_, attrs = get_local_dict(locale)
Expand Down
16 changes: 9 additions & 7 deletions src/xsdba/nbutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
nogil=True,
cache=False,
)
def _get_indexes( # noqa: PR07
def _get_indexes(
arr: np.array, virtual_indexes: np.array, valid_values_count: np.array
) -> tuple[np.array, np.array]:
"""
Get the valid indexes of arr neighbouring virtual_indexes.
"""Get the valid indexes of arr neighbouring virtual_indexes.
Parameters
----------
Expand All @@ -41,7 +40,7 @@ def _get_indexes( # noqa: PR07
Returns
-------
array-like, array-like
A tuple of virtual_indexes neighbouring indexes (previous and next).
A tuple of virtual_indexes neighbouring indexes (previous and next)
Notes
-----
Expand Down Expand Up @@ -210,7 +209,8 @@ def _wrapper_quantile1d(arr, q):
return out


def _quantile(arr, q, nreduce):
def _quantile(arr, q, nreduce=None):
nreduce = nreduce or arr.ndim
if arr.ndim == nreduce:
out = _nan_quantile_1d(arr.flatten(), q)
else:
Expand Down Expand Up @@ -277,7 +277,7 @@ def quantile(da: DataArray, q: np.ndarray, dim: str | Sequence[Hashable]) -> Dat
nogil=True,
cache=False,
)
def remove_NaNs(x): # noqa: N802
def remove_NaNs(x): # noqa
"""Remove NaN values from series."""
remove = np.zeros_like(x[0, :], dtype=boolean)
for i in range(x.shape[0]):
Expand Down Expand Up @@ -386,7 +386,9 @@ def _first_and_last_nonnull(arr):
nogil=True,
cache=False,
)
def _extrapolate_on_quantiles(interp, oldx, oldg, oldy, newx, newg, method="constant"):
def _extrapolate_on_quantiles(
interp, oldx, oldg, oldy, newx, newg, method="constant"
): # noqa
"""Apply extrapolation to the output of interpolation on quantiles with a given grouping.
Arguments are the same as _interp_on_quantiles_2D.
Expand Down
1 change: 1 addition & 0 deletions src/xsdba/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
from copy import deepcopy
from functools import wraps
from typing import Any

import pint

Expand Down
Loading

0 comments on commit 45ec1b6

Please sign in to comment.