Skip to content

Commit

Permalink
Merge pull request #27 from roxyboy/tolerance
Browse files Browse the repository at this point in the history
Remove spacing_tol input and added more test
  • Loading branch information
roxyboy authored Aug 31, 2023
2 parents 1cacbb0 + 18ea50b commit 04b5ab2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 112 deletions.
54 changes: 47 additions & 7 deletions xwavelet/tests/test_wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@

@pytest.fixture
def sample_da_1d():
t = np.linspace(0, 10, 11)
return xr.DataArray(t, dims=["t"], coords={"t": t})
time = np.arange(0, 360 * 15 * 86400, 5 * 86400)
freq0 = (180 * 86400) ** -1
da = xr.DataArray(
np.sin(2 * np.pi * freq0 * time), dims=["time"], coords={"time": time}
)
return da


@pytest.fixture
Expand All @@ -50,7 +54,7 @@ def test_dimensions(sample_da_3d, sample_da_2d, sample_da_1d, x0=1.0):
coords={"scale": np.linspace(0.1, 1.0, 20)},
)
with pytest.raises(ValueError):
cwvlt(sample_da_2d, s, t0=x0)
cwvlt(sample_da_2d, s, t0=(180 * 86400))
with pytest.raises(ValueError):
cwvlt2(sample_da_1d, s, x0=x0)
with pytest.raises(NotImplementedError):
Expand All @@ -66,14 +70,50 @@ def test_convergence(sample_da_2d, sample_da_1d, x0=1.0):
coords={"scale": np.linspace(0.1, 1.0, 20)},
)

npt.assert_allclose(
npt.assert_almost_equal(
wvlt_power_spectrum(sample_da_2d, s, x0=x0).values,
wvlt_cross_spectrum(sample_da_2d, sample_da_2d, s, x0=x0).values,
)

npt.assert_allclose(
wvlt_power_spectrum(sample_da_1d, s, x0=x0).values,
wvlt_cross_spectrum(sample_da_1d, sample_da_1d, s, x0=x0).values,
npt.assert_almost_equal(
wvlt_power_spectrum(sample_da_1d, s, x0=(180 * 86400)).values,
wvlt_cross_spectrum(sample_da_1d, sample_da_1d, s, x0=(180 * 86400)).values,
)


def test_wtype(sample_da_2d, sample_da_1d, x0=1.0):
s = xr.DataArray(
np.linspace(0.1, 1.0, 20),
dims=["scale"],
coords={"scale": np.linspace(0.1, 1.0, 20)},
)
with pytest.raises(NotImplementedError):
cwvlt2(sample_da_2d, s, x0=x0, wtype=None)
cwvlt2(sample_da_2d, s, x0=x0, wtype="boxcar")
with pytest.raises(NotImplementedError):
cwvlt(sample_da_1d, s, t0=(180 * 86400) ** -1, wtype=None)
cwvlt(sample_da_1d, s, t0=(180 * 86400) ** -1, wtype="boxcar")


def test_frequency(sample_da_1d, t0=(180 * 86400)):
fda = xrft.power_spectrum(
sample_da_1d,
)
s = (
xr.DataArray(
fda.freq_time[len(sample_da_1d.time) // 2 + 1 :].data ** -1,
dims=["scale"],
coords={
"scale": fda.freq_time[len(sample_da_1d.time) // 2 + 1 :].data ** -1
},
)
/ t0
)
wda = wvlt_power_spectrum(sample_da_1d, s, x0=t0)

npt.assert_equal(
np.sort(wda.values.argsort()[-3:]),
np.sort(fda.values[len(sample_da_1d.time) // 2 + 1 :].argsort()[-3:]),
)


Expand Down
120 changes: 15 additions & 105 deletions xwavelet/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,14 @@ def _diff_coord(coord):
return np.diff(coord)


def _delta(da, dim, spacing_tol):
def _delta(da, dim):
"""Returns the grid spacing"""

delta_x = []
for d in dim:
diff = _diff_coord(da[d])
delta = np.abs(diff[0])
if not np.allclose(diff, diff[0], rtol=spacing_tol):
raise ValueError(
"Can't take wavelet transform because "
"coodinate %s is not evenly spaced" % d
)

if delta == 0.0:
raise ValueError(
"Can't take wavelet transform because spacing in coordinate %s is zero"
Expand Down Expand Up @@ -128,17 +124,7 @@ def _morlet2(x0, ntheta, a, s, y, x, **kwargs):
_xo_warning = "Input argument `xo` will be deprecated in the future versions of xwavelet and be replaced by `x0`"


def dwvlt(
da,
s,
spacing_tol=1e-3,
dim=None,
xo=50e3,
a=1.0,
ntheta=16,
wtype="morlet",
**kwargs
):
def dwvlt(da, s, dim=None, xo=50e3, a=1.0, ntheta=16, wtype="morlet", **kwargs):
"""
Deprecated function. See cwvlt2 doc.
"""
Expand All @@ -148,23 +134,12 @@ def dwvlt(
)
warnings.warn(msg, FutureWarning)

return cwvlt2(
da,
s,
spacing_tol=spacing_tol,
dim=dim,
x0=xo,
a=a,
ntheta=ntheta,
wtype=wtype,
**kwargs
)
return cwvlt2(da, s, dim=dim, x0=xo, a=a, ntheta=ntheta, wtype=wtype, **kwargs)


def cwvlt(
da,
s,
spacing_tol=1e-3,
dim=None,
t0=5 * 365 * 86400,
a=1.0,
Expand All @@ -180,9 +155,6 @@ def cwvlt(
The data to be transformed.
s : `xarray.DataArray`
One-dimensional array with scaling parameter.
spacing_tol : float, optional
Spacing tolerance. Fourier transform should not be applied to uneven grid but
this restriction can be relaxed with this setting. Use caution.
dim : str or sequence of str, optional
The dimensions along which to take the transformation. If `None`, all
dimensions will be transformed. If the inputs are dask arrays, the
Expand Down Expand Up @@ -218,8 +190,7 @@ def cwvlt(

N = [da.shape[n] for n in axis_num]

# verify even spacing of input coordinates
delta_t = _delta(da, dim, spacing_tol)
delta_t = _delta(da, dim)

# grid parameters
if len(dim) == 1:
Expand All @@ -241,17 +212,7 @@ def cwvlt(
return dawt


def cwvlt2(
da,
s,
spacing_tol=1e-3,
dim=None,
x0=50e3,
a=1.0,
ntheta=16,
wtype="morlet",
**kwargs
):
def cwvlt2(da, s, dim=None, x0=50e3, a=1.0, ntheta=16, wtype="morlet", **kwargs):
r"""
Compute continuous two-dimensional wavelet transform of da. Default is the Morlet wavelet.
Scale :math:`s` is dimensionless.
Expand All @@ -262,9 +223,6 @@ def cwvlt2(
The data to be transformed.
s : `xarray.DataArray`
One-dimensional array with scaling parameter.
spacing_tol : float, optional
Spacing tolerance. Fourier transform should not be applied to uneven grid but
this restriction can be relaxed with this setting. Use caution.
dim : str or sequence of str, optional
The dimensions along which to take the transformation. If `None`, all
dimensions will be transformed. If the inputs are dask arrays, the
Expand Down Expand Up @@ -300,8 +258,7 @@ def cwvlt2(

N = [da.shape[n] for n in axis_num]

# verify even spacing of input coordinates
delta_x = _delta(da, dim, spacing_tol)
delta_x = _delta(da, dim)

# grid parameters
if len(dim) == 2:
Expand All @@ -325,16 +282,7 @@ def cwvlt2(


def wvlt_power_spectrum(
da,
s,
spacing_tol=1e-3,
dim=None,
x0=50e3,
a=1.0,
ntheta=16,
wtype="morlet",
normalize=True,
**kwargs
da, s, dim=None, x0=50e3, a=1.0, ntheta=16, wtype="morlet", normalize=True, **kwargs
):
r"""
Compute discrete wavelet power spectrum of :math:`da`.
Expand All @@ -346,10 +294,7 @@ def wvlt_power_spectrum(
The data to have the spectral estimate.
s : `xarray.DataArray`
Non-dimensional scaling parameter. The dimensionalized length scales are
:math:`xo\times s`.
spacing_tol : float, optional
Spacing tolerance. Fourier transform should not be applied to uneven grid but
this restriction can be relaxed with this setting. Use caution.
:math:`x0\times s`.
dim : str or sequence of str, optional
The dimensions along which to take the transformation. If `None`, all
dimensions will be transformed. If the inputs are dask arrays, the
Expand Down Expand Up @@ -383,24 +328,13 @@ def wvlt_power_spectrum(
dawt = cwvlt(
da,
s,
spacing_tol=spacing_tol,
dim=dim,
t0=x0,
a=a,
wtype=wtype,
)
elif len(dim) == 2:
dawt = cwvlt2(
da,
s,
spacing_tol=spacing_tol,
dim=dim,
x0=x0,
a=a,
ntheta=ntheta,
wtype=wtype,
**kwargs
)
dawt = cwvlt2(da, s, dim=dim, x0=x0, a=a, ntheta=ntheta, wtype=wtype, **kwargs)
else:
raise NotImplementedError(
"Transformation for three dimensions and higher is not implemented."
Expand All @@ -409,7 +343,7 @@ def wvlt_power_spectrum(
if normalize:
axis_num = [da.get_axis_num(d) for d in dim]
N = [da.shape[n] for n in axis_num]
delta_x = _delta(da, dim, spacing_tol)
delta_x = _delta(da, dim)

Fdims = []
chunks = dict()
Expand Down Expand Up @@ -460,7 +394,6 @@ def wvlt_cross_spectrum(
da,
da1,
s,
spacing_tol=1e-3,
dim=None,
x0=50e3,
a=1.0,
Expand All @@ -481,10 +414,7 @@ def wvlt_cross_spectrum(
The data to have the cross spectral estimate.
s : `xarray.DataArray`
Non-dimensional scaling parameter. The dimensionalized length scales are
:math:`xo\times s`.
spacing_tol : float, optional
Spacing tolerance. Fourier transform should not be applied to uneven grid but
this restriction can be relaxed with this setting. Use caution.
:math:`x0\times s`.
dim : str or sequence of str, optional
The dimensions along which to take the transformation. If `None`, all
dimensions will be transformed. If the inputs are dask arrays, the
Expand Down Expand Up @@ -518,7 +448,6 @@ def wvlt_cross_spectrum(
dawt = cwvlt(
da,
s,
spacing_tol=spacing_tol,
dim=dim,
t0=x0,
a=a,
Expand All @@ -527,34 +456,15 @@ def wvlt_cross_spectrum(
dawt1 = cwvlt(
da1,
s,
spacing_tol=spacing_tol,
dim=dim,
t0=x0,
a=a,
wtype=wtype,
)
elif len(dim) == 2:
dawt = cwvlt2(
da,
s,
spacing_tol=spacing_tol,
dim=dim,
x0=x0,
a=a,
ntheta=ntheta,
wtype=wtype,
**kwargs
)
dawt = cwvlt2(da, s, dim=dim, x0=x0, a=a, ntheta=ntheta, wtype=wtype, **kwargs)
dawt1 = cwvlt2(
da1,
s,
spacing_tol=spacing_tol,
dim=dim,
x0=x0,
a=a,
ntheta=ntheta,
wtype=wtype,
**kwargs
da1, s, dim=dim, x0=x0, a=a, ntheta=ntheta, wtype=wtype, **kwargs
)
else:
raise NotImplementedError(
Expand All @@ -564,7 +474,7 @@ def wvlt_cross_spectrum(
if normalize:
axis_num = [da.get_axis_num(d) for d in dim]
N = [da.shape[n] for n in axis_num]
delta_x = _delta(da, dim, spacing_tol)
delta_x = _delta(da, dim)

Fdims = []
chunks = dict()
Expand Down

0 comments on commit 04b5ab2

Please sign in to comment.