Skip to content

Commit 91d2a96

Browse files
committed
PASSED: test_units.py (still needs more tests)
1 parent ca8bb75 commit 91d2a96

File tree

6 files changed

+153
-44
lines changed

6 files changed

+153
-44
lines changed

src/xsdba/detrending.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .base import Grouper, ParametrizableWithDataset, map_groups, parse_group
1111
from .loess import loess_smoothing
12-
from .units import check_units
12+
from .units import compare_units
1313
from .utils import ADDITIVE, apply_correction, invert
1414

1515

@@ -91,13 +91,13 @@ def retrend(self, da: xr.DataArray):
9191
raise ValueError("You must call fit() before retrending")
9292
return self._retrend(da, self.ds.trend)
9393

94-
@check_units(["da", "trend"])
94+
@compare_units(["da", "trend"])
9595
def _detrend(self, da, trend):
9696
"""Detrend."""
9797
# Remove trend from series
9898
return apply_correction(da, invert(trend, self.kind), self.kind)
9999

100-
@check_units(["da", "trend"])
100+
@compare_units(["da", "trend"])
101101
def _retrend(self, da, trend):
102102
"""Retrend."""
103103
# Add trend to series

src/xsdba/indicator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
import warnings
107107
import weakref
108108
from collections import OrderedDict, defaultdict
109+
from collections.abc import Sequence
109110
from copy import deepcopy
110111
from dataclasses import asdict, dataclass
111112
from functools import reduce
@@ -117,7 +118,6 @@
117118
from pathlib import Path
118119
from types import ModuleType
119120
from typing import Any, Callable, Optional, Union
120-
from collections.abc import Sequence
121121

122122
import numpy as np
123123
import xarray
@@ -156,7 +156,7 @@
156156
OPTIONS,
157157
)
158158
from .typing import InputKind
159-
from .units import check_units, convert_units_to, units
159+
from .units import compare_units, convert_units_to, units
160160
from .utils import load_module
161161

162162
# Indicators registry
@@ -1149,7 +1149,7 @@ def _translate(cf_attrs, names, var_id=None):
11491149
return attrs
11501150

11511151
@classmethod
1152-
def json(self, args=None):
1152+
def json(cls, args=None):
11531153
"""Return a serializable dictionary representation of the class.
11541154
11551155
Parameters

src/xsdba/measures.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# from xclim.core.units import ensure_delta
2222
from .base import Grouper
2323
from .typing import InputKind
24-
from .units import check_units, ensure_delta
24+
from .units import compare_units, ensure_delta
2525
from .utils import _pairwise_spearman
2626

2727

@@ -48,7 +48,7 @@ def _ensure_correct_parameters(cls, parameters):
4848
)
4949
return super()._ensure_correct_parameters(parameters)
5050

51-
@check_units([{"das": "ref"}, {"das": "sim"}])
51+
@compare_units([{"das": "ref"}, {"das": "sim"}])
5252
def _preprocess_and_checks(self, das, params):
5353
"""Perform parent's checks and also check convert units so that sim matches ref."""
5454
das, params = super()._preprocess_and_checks(das, params)
@@ -110,7 +110,7 @@ def _ensure_correct_parameters(cls, parameters):
110110

111111
return super()._ensure_correct_parameters(parameters)
112112

113-
@check_units([{"das": "ref"}, {"das": "sim"}])
113+
@compare_units([{"das": "ref"}, {"das": "sim"}])
114114
def _preprocess_and_checks(self, das, params):
115115
"""Perform parent's checks and also check convert units so that sim matches ref."""
116116
das, params = super()._preprocess_and_checks(das, params)

src/xsdba/processing.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ._processing import _adapt_freq, _normalize, _reordering
2121
from .base import Grouper
2222
from .nbutils import _escore
23-
from .units import check_units, convert_units_to, harmonize_units
23+
from .units import compare_units, convert_units_to, harmonize_units
2424
from .utils import ADDITIVE, copy_all_attrs
2525

2626
# from xclim.core.units import convert_units_to, infer_context, units
@@ -67,7 +67,7 @@ def adapt_freq(
6767
sim : xr.Dataset
6868
Simulated data, with a "time" dimension.
6969
group : str or Grouper
70-
Grouping information, see base.Grouper
70+
Grouping information, see base.Grouper.
7171
thresh : str
7272
Threshold below which values are considered zero, a quantity with units.
7373
@@ -96,7 +96,6 @@ def adapt_freq(
9696
References
9797
----------
9898
:cite:cts:`sdba-themesl_empirical-statistical_2012`
99-
10099
"""
101100
out = _adapt_freq(xr.Dataset(dict(sim=sim, ref=ref)), group=group, thresh=thresh)
102101

@@ -134,7 +133,7 @@ def jitter_under_thresh(x: xr.DataArray, thresh: str) -> xr.DataArray:
134133
135134
Returns
136135
-------
137-
xr.DataArray
136+
xr.DataArray.
138137
139138
Notes
140139
-----
@@ -162,12 +161,11 @@ def jitter_over_thresh(x: xr.DataArray, thresh: str, upper_bnd: str) -> xr.DataA
162161
163162
Returns
164163
-------
165-
xr.DataArray
164+
xr.DataArray.
166165
167166
Notes
168167
-----
169168
If thresh is low, this will change the mean value of x.
170-
171169
"""
172170
j: xr.DataArray = jitter(
173171
x, lower=None, upper=thresh, minimum=None, maximum=upper_bnd
@@ -353,29 +351,28 @@ def unstandardize(da: xr.DataArray, mean: xr.DataArray, std: xr.DataArray):
353351

354352
@update_xsdba_history
355353
def reordering(ref: xr.DataArray, sim: xr.DataArray, group: str = "time") -> xr.Dataset:
356-
"""Reorders data in `sim` following the order of ref.
354+
"""Reorder data in `sim` following the order of ref.
357355
358356
The rank structure of `ref` is used to reorder the elements of `sim` along dimension "time", optionally doing the
359357
operation group-wise.
360358
361359
Parameters
362360
----------
363-
sim : xr.DataArray
364-
Array to reorder.
365361
ref : xr.DataArray
366362
Array whose rank order sim should replicate.
363+
sim : xr.DataArray
364+
Array to reorder.
367365
group : str
368366
Grouping information. See :py:class:`xclim.sdba.base.Grouper` for details.
369367
370368
Returns
371369
-------
372370
xr.Dataset
373-
sim reordered according to ref's rank order.
371+
Sim reordered according to ref's rank order.
374372
375373
References
376374
----------
377-
:cite:cts:`sdba-cannon_multivariate_2018`
378-
375+
:cite:cts:`sdba-cannon_multivariate_2018`.
379376
"""
380377
ds = xr.Dataset({"sim": sim, "ref": ref})
381378
out: xr.Dataset = _reordering(ds, group=group).reordered
@@ -414,7 +411,7 @@ def escore(
414411
Returns
415412
-------
416413
xr.DataArray
417-
e-score with dimensions not in `dims`.
414+
Return e-score with dimensions not in `dims`.
418415
419416
Notes
420417
-----
@@ -441,8 +438,7 @@ def escore(
441438
442439
References
443440
----------
444-
:cite:cts:`sdba-baringhaus_new_2004,sdba-cannon_multivariate_2018,sdba-cannon_mbc_2020,sdba-szekely_testing_2004`
445-
441+
:cite:cts:`sdba-baringhaus_new_2004,sdba-cannon_multivariate_2018,sdba-cannon_mbc_2020,sdba-szekely_testing_2004`.
446442
"""
447443
pts_dim, obs_dim = dims
448444

@@ -563,14 +559,14 @@ def to_additive_space(
563559
564560
See Also
565561
--------
562+
Related functions
566563
from_additive_space : for the inverse transformation.
567564
jitter_under_thresh : Remove values exactly equal to the lower bound.
568565
jitter_over_thresh : Remove values exactly equal to the upper bound.
569566
570567
References
571568
----------
572-
:cite:cts:`sdba-alavoine_distinct_2022`
573-
569+
:cite:cts:`sdba-alavoine_distinct_2022`.
574570
"""
575571
# with units.context(infer_context(data.attrs.get("standard_name"))):
576572
lower_bound_array = np.array(lower_bound).astype(float)
@@ -666,7 +662,7 @@ def from_additive_space(
666662
667663
References
668664
----------
669-
:cite:cts:`sdba-alavoine_distinct_2022`
665+
:cite:cts:`sdba-alavoine_distinct_2022`.
670666
671667
"""
672668
if trans is None and lower_bound is None and units is None:
@@ -748,7 +744,6 @@ def stack_variables(ds: xr.Dataset, rechunk: bool = True, dim: str = "multivar")
748744
`sdba_transform_upper` are also set if the requested bounds are different from the defaults.
749745
750746
Array with variables stacked along `dim` dimension. Units are set to "".
751-
752747
"""
753748
# Store original arrays' attributes
754749
attrs: dict = {}
@@ -825,7 +820,7 @@ def grouped_time_indexes(times, group):
825820
times : xr.DataArray
826821
Time dimension in the dataset of interest.
827822
group : str or Grouper
828-
Grouping information, see base.Grouper
823+
Grouping information, see base.Grouper.
829824
830825
Returns
831826
-------

src/xsdba/units.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
import numpy as np
2525
import xarray as xr
2626

27-
from .calendar import parse_offset
27+
from .calendar import get_calendar, parse_offset
2828
from .typing import Quantified
2929
from .utils import copy_all_attrs
3030

3131
units = pint.get_application_registry()
32-
32+
# Another alias not included by cf_xarray
33+
units.define("@alias percent = pct")
3334

3435
FREQ_UNITS = {
3536
"D": "d",
@@ -66,7 +67,7 @@ def infer_sampling_units(
6667
Returns
6768
-------
6869
int
69-
The magnitude (number of base periods per period)
70+
The magnitude (number of base periods per period).
7071
str
7172
Units as a string, understandable by pint.
7273
"""
@@ -169,7 +170,7 @@ def pint2str(value: units.Quantity | units.Unit) -> str:
169170
Returns
170171
-------
171172
str
172-
Units
173+
Units.
173174
174175
Notes
175176
-----
@@ -213,7 +214,7 @@ def ensure_delta(unit: str) -> str:
213214
Parameters
214215
----------
215216
unit : str
216-
unit to transform in delta (or not)
217+
unit to transform in delta (or not).
217218
"""
218219
u = units2pint(unit)
219220
d = 1 * u
@@ -246,7 +247,7 @@ def extract_units(arg):
246247
return ustr if ustr is None else pint.Quantity(1, ustr).units
247248

248249

249-
def check_units(args_to_check):
250+
def compare_units(args_to_check):
250251
"""Decorator to check that all arguments have the same units (or no units)."""
251252

252253
# if no units are present (DataArray without units attribute or float), then no check is performed
@@ -312,14 +313,6 @@ def convert_units_to( # noqa: C901
312313
The outputted type is always similar to `source` initial type.
313314
Attributes are preserved unless an automatic CF conversion is performed,
314315
in which case only the new `standard_name` appears in the result.
315-
316-
See Also
317-
--------
318-
cf_conversion
319-
amount2rate
320-
rate2amount
321-
amount2lwethickness
322-
lwethickness2amount
323316
"""
324317
# Target units
325318
target_unit = extract_units(target)
@@ -346,6 +339,7 @@ def _add_default_kws(params_dict, params_to_check, func):
346339
return params_dict
347340

348341

342+
# TODO: this changes the type of some variables (e.g. thresh : str -> float). This should probably not be allowed
349343
def harmonize_units(params_to_check):
350344
"""Check that units are compatible with dimensions, otherwise raise a `ValidationError`."""
351345

@@ -460,7 +454,7 @@ def to_agg_units(
460454
461455
>>> degdays = convert_units_to(degdays, "K days")
462456
>>> degdays.units
463-
'K d'
457+
'K d'.
464458
"""
465459
if op in ["amin", "min", "amax", "max", "mean", "sum"]:
466460
out.attrs["units"] = orig.attrs["units"]

0 commit comments

Comments
 (0)