diff --git a/xwavelet/tests/test_wavelet.py b/xwavelet/tests/test_wavelet.py index 39482d4..ace3192 100644 --- a/xwavelet/tests/test_wavelet.py +++ b/xwavelet/tests/test_wavelet.py @@ -11,7 +11,33 @@ import xarray.testing as xrt import xrft -from xwavelet.wavelet import dwvlt, wvlt_power_spectrum +from xwavelet.wavelet import dwvlt, cwvlt, cwvlt2, wvlt_power_spectrum + + +@pytest.fixture +def sample_da_1d(): + t = np.linspace(0, 10, 11) + return xr.DataArray(t, dims=["t"], coords={"t": t}) + + +@pytest.fixture +def sample_da_2d(): + x = np.linspace(0, 10, 11) + y = np.linspace(-4, 4, 17) + z = np.arange(11 * 17).reshape(17, 11) + return xr.DataArray(z, dims=["y", "x"], coords={"y": y, "x": x}) + + +def test_dimensions(sample_da_2d, sample_da_1d): + s = xr.DataArray( + np.linspace(0.1, 1.0, 20), + dims=["scale"], + coords={"scale": np.linspace(0.1, 1.0, 20)}, + ) + with pytest.raises(ValueError): + cwvlt(sample_da_2d, s) + with pytest.raises(ValueError): + cwvlt2(sample_da_1d, s) def synthetic_field(N, dL, amp, s): @@ -112,7 +138,7 @@ def synthetic_field_xr( @pytest.mark.parametrize("chunk", [False, True]) -def test_isotropic_ps_slope(chunk, N=128, dL=1.0, amp=1e0, slope=-3.0, xo=50): +def test_isotropic_ps_slope(chunk, N=256, dL=1.0, amp=1e0, slope=-3.0, xo=50): """Test the spectral slope of isotropic power spectrum.""" theta = synthetic_field_xr( @@ -133,17 +159,31 @@ def test_isotropic_ps_slope(chunk, N=128, dL=1.0, amp=1e0, slope=-3.0, xo=50): coords={"scale": np.linspace(0.1, 1.0, 20)}, ) - Wtheta = dwvlt(theta, s, dim=["y", "x"], xo=xo) + kwargs = {"angle": 2} + + Wtheta = dwvlt(theta, s, dim=["y", "x"], xo=xo, **kwargs) iso_ps = (np.abs(Wtheta) ** 2).mean(["d0", "angle"]) * (Wtheta.scale) ** -1 npt.assert_almost_equal(np.ma.masked_invalid(iso_ps).mask.sum(), 0.0) y_fit, a, b = xrft.fit_loglog( - (iso_ps.scale.values[1:-2]) ** -1, iso_ps.values[1:-2] + (iso_ps.scale.values[2:-2]) ** -1, iso_ps.values[2:-2] ) npt.assert_allclose(a, slope, atol=0.3) - iso_ps = wvlt_power_spectrum(theta, s, dim=["y", "x"], xo=xo).mean(["d0", "angle"]) + iso_ps = wvlt_power_spectrum(theta, s, dim=["y", "x"], x0=xo, **kwargs).mean( + ["d0", "angle"] + ) npt.assert_almost_equal(np.ma.masked_invalid(iso_ps).mask.sum(), 0.0) y_fit, a, b = xrft.fit_loglog( - (iso_ps.scale.values[1:-2]) ** -1, iso_ps.values[1:-2] + (iso_ps.scale.values[2:-2]) ** -1, iso_ps.values[2:-2] ) npt.assert_allclose(a, slope, atol=0.3) + + +# +# if chunk: +# iso_ps = wvlt_power_spectrum(theta, s, dim=["y", "x"], xo=xo, **kwargs).mean(["d0", "angle"]) +# npt.assert_almost_equal(np.ma.masked_invalid(iso_ps).mask.sum(), 0.0) +# y_fit, a, b = xrft.fit_loglog( +# (iso_ps.scale.values[1:-2]) ** -1, iso_ps.values[1:-2] +# ) +# npt.assert_allclose(a, slope, atol=0.3) diff --git a/xwavelet/wavelet.py b/xwavelet/wavelet.py index 62c6503..77470ab 100644 --- a/xwavelet/wavelet.py +++ b/xwavelet/wavelet.py @@ -12,6 +12,8 @@ __all__ = [ "dwvlt", + "cwvlt", + "cwvlt2", "wvlt_power_spectrum", "wvlt_cross_spectrum", ] @@ -57,7 +59,35 @@ def _delta(da, dim, spacing_tol): return delta_x -def _morlet(xo, ntheta, a, s, y, x, **kwargs): +def _morlet(t0, a, s, t): + r""" + Define + + .. math:: + \psi = a e^{-2\pi i f_0 t} e^{-\frac{t^2}{2 t_0^2}} + + as the morlet wavelet. Its transform is + + .. math:: + \psi_h = a 2\pi t_0^2 e^{-2 \pi^2 (f-f_0)^2 t_0^2} + + Units of :math:`a` are :math:`T^{-1}`. + :math:`f_0` is defaulted to :math:`1/t_0`. + """ + + f0 = 1.0 / t0 + + # rotated positions + tp = s**-1 * (t - t.mean()) + + arg1 = 2j * np.pi * f0 * tp + arg2 = -((t - t.mean()) ** 2) / 2 / s**2 / t0**2 + m = a * np.exp(arg1) * np.exp(arg2) + + return m, th + + +def _morlet2(x0, ntheta, a, s, y, x, **kwargs): r""" Define @@ -73,7 +103,7 @@ def _morlet(xo, ntheta, a, s, y, x, **kwargs): :math:`k_0` is defaulted to :math:`1/x_0` in the zonal direction. """ - ko = 1.0 / xo + k0 = 1.0 / x0 # compute morlet wavelet th = np.arange(int(ntheta / 2)) * 2.0 * np.pi / ntheta @@ -88,13 +118,16 @@ def _morlet(xo, ntheta, a, s, y, x, **kwargs): yp = np.sin(th) * s**-1 * (y - y.mean()) xp = np.cos(th) * s**-1 * (x - x.mean()) - arg1 = 2j * np.pi * ko * (yp - xp) - arg2 = -((x - x.mean()) ** 2 + (y - y.mean()) ** 2) / 2 / s**2 / xo**2 + arg1 = 2j * np.pi * k0 * (yp - xp) + arg2 = -((x - x.mean()) ** 2 + (y - y.mean()) ** 2) / 2 / s**2 / x0**2 m = a * np.exp(arg1) * np.exp(arg2) return m, th +_xo_warning = "Input argument `xo` will be deprecated in the future versions of xwavelet and be replaced by `x0`" + + def dwvlt( da, s, @@ -105,9 +138,41 @@ def dwvlt( ntheta=16, wtype="morlet", **kwargs +): + """ + Deprecated function. See cwvlt2 doc. + """ + msg = ( + "This function has been renamed and will disappear in the future." + + " Please use `cwvlt2` instead." + ) + warnings.warn(msg, FutureWarning) + + return cwvlt2( + da, + s, + spacing_tol=spacing_tol, + dim=dim, + x0=xo, + a=a, + ntheta=ntheta, + wtype=wtype, + **kwargs + ) + + +def cwvlt( + da, + s, + spacing_tol=1e-3, + dim=None, + t0=5 * 365 * 86400, + a=1.0, + wtype="morlet", + **kwargs ): r""" - Compute discrete wavelet transform of da. Default is the Morlet wavelet. + Compute continuous one-dimensional wavelet transform of da. Default is the Morlet wavelet. Scale :math:`s` is dimensionless. Parameters @@ -123,12 +188,10 @@ def dwvlt( The dimensions along which to take the transformation. If `None`, all dimensions will be transformed. If the inputs are dask arrays, the arrays must not be chunked along these dimensions. - xo : float + t0 : float Length scale. a : float Amplitude of wavelet. - ntheta : int - Number of azimuthal angles the wavelet transform is taken over. wtype : str Type of wavelet. @@ -146,6 +209,91 @@ def dwvlt( if isinstance(dim, str): dim = [dim] + if len(dim) != 1: + raise ValueError("The transformed dimension should be one-dimensional.") + + sdim = s.dims[0] + + # the axes along which to take wavelets + axis_num = [da.get_axis_num(d) for d in dim] + + N = [da.shape[n] for n in axis_num] + + # verify even spacing of input coordinates + delta_t = _delta(da, dim, spacing_tol) + + # grid parameters + if len(dim) == 1: + t = da[da.dims[axis_num[0]]] - da[da.dims[axis_num[0]]].mean() + else: + raise NotImplementedError( + "Only one-dimensional transforms are implemented for now." + ) + + if wtype == "morlet": + wavelet = _morlet(t0, a, s, t) + else: + raise NotImplementedError("Only the Morlet wavelet is implemented for now.") + + dawt = (da * np.conj(wavelet)).sum(dim, skipna=True) * delta_t / np.sqrt(np.abs(s)) + dawt = dawt.drop_vars(sdim) + dawt[sdim] = t0 * s + + return dawt + + +def cwvlt2( + da, + s, + spacing_tol=1e-3, + dim=None, + x0=50e3, + a=1.0, + ntheta=16, + wtype="morlet", + **kwargs +): + r""" + Compute continuous two-dimensional wavelet transform of da. Default is the Morlet wavelet. + Scale :math:`s` is dimensionless. + + Parameters + ---------- + da : `xarray.DataArray` + The data to be transformed. + s : `xarray.DataArray` + One-dimensional array with scaling parameter. + spacing_tol : float, optional + Spacing tolerance. Fourier transform should not be applied to uneven grid but + this restriction can be relaxed with this setting. Use caution. + dim : str or sequence of str, optional + The dimensions along which to take the transformation. If `None`, all + dimensions will be transformed. If the inputs are dask arrays, the + arrays must not be chunked along these dimensions. + x0 : float + Length scale. + a : float + Amplitude of wavelet. + ntheta : int + Number of azimuthal angles the wavelet transform is taken over. + wtype : str + Type of wavelet. + + Returns + ------- + dawt : `xarray.DataArray` + The output of the wavelet transformation, with appropriate dimensions. + """ + + if dim is None: + dim = list(da.dims) + else: + if isinstance(dim, str): + dim = [dim] + + if len(dim) != 2: + raise ValueError("The transformed dimension should be two-dimensional.") + sdim = s.dims[0] # the axes along which to take wavelets @@ -166,13 +314,13 @@ def dwvlt( ) if wtype == "morlet": - wavelet, phi = _morlet(xo, ntheta, a, s, y, x, **kwargs) + wavelet, phi = _morlet2(x0, ntheta, a, s, y, x, **kwargs) else: raise NotImplementedError("Only the Morlet wavelet is implemented for now.") dawt = (da * np.conj(wavelet)).sum(dim, skipna=True) * np.prod(delta_x) / s dawt = dawt.drop_vars(sdim) - dawt[sdim] = xo * s + dawt[sdim] = x0 * s return dawt @@ -182,7 +330,7 @@ def wvlt_power_spectrum( s, spacing_tol=1e-3, dim=None, - xo=50e3, + x0=50e3, a=1.0, ntheta=16, wtype="morlet", @@ -207,7 +355,7 @@ def wvlt_power_spectrum( The dimensions along which to take the transformation. If `None`, all dimensions will be transformed. If the inputs are dask arrays, the arrays must not be chunked along these dimensions. - xo : float + x0 : float Length scale of the mother wavelet. a : float Amplitude of wavelet. @@ -228,28 +376,47 @@ def wvlt_power_spectrum( if isinstance(dim, str): dim = [dim] - dawt = dwvlt( - da, - s, - spacing_tol=spacing_tol, - dim=dim, - xo=xo, - a=a, - ntheta=ntheta, - wtype=wtype, - **kwargs - ) + if "xo" in kwargs: + x0 = kwargs.get("xo") + warnings.warn(_xo_warning, FutureWarning) + + if len(dim) == 1: + dawt = cwvlt( + da, s, spacing_tol=spacing_tol, dim=dim, t0=x0, a=a, wtype=wtype, **kwargs + ) + elif len(dim) == 2: + dawt = cwvlt2( + da, + s, + spacing_tol=spacing_tol, + dim=dim, + x0=x0, + a=a, + ntheta=ntheta, + wtype=wtype, + **kwargs + ) + else: + raise NotImplementedError( + "Transformation for three dimensions and higher is not implemented." + ) if normalize: axis_num = [da.get_axis_num(d) for d in dim] N = [da.shape[n] for n in axis_num] delta_x = _delta(da, dim, spacing_tol) - y = da[da.dims[axis_num[-2]]] - N[-2] / 2.0 * delta_x[-2] - x = da[da.dims[axis_num[-1]]] - N[-1] / 2.0 * delta_x[-1] - if wtype == "morlet": - # mother wavelet - wavelet, phi = _morlet(xo, ntheta, a, 1.0, y, x, **kwargs) + if len(dim) == 1: + t = da[da.dims[axis_num[0]]] - N[0] / 2.0 * delta_x[0] + if wtype == "morlet": + # mother wavelet + wavelet = _morlet(x0, a, 1.0, t) + elif len(dim) == 2: + y = da[da.dims[axis_num[-2]]] - N[-2] / 2.0 * delta_x[-2] + x = da[da.dims[axis_num[-1]]] - N[-1] / 2.0 * delta_x[-1] + if wtype == "morlet": + # mother wavelet + wavelet, phi = _morlet2(x0, ntheta, a, 1.0, y, x, **kwargs) Fdims = [] chunks = dict() @@ -273,7 +440,7 @@ def wvlt_power_spectrum( else: C = 1.0 - return np.abs(dawt) ** 2 * (dawt[s.dims[0]]) ** -1 * xo**2 / C + return np.abs(dawt) ** 2 * (dawt[s.dims[0]]) ** -1 * x0**2 / C def wvlt_cross_spectrum( @@ -282,7 +449,7 @@ def wvlt_cross_spectrum( s, spacing_tol=1e-3, dim=None, - xo=50e3, + x0=50e3, a=1.0, ntheta=16, wtype="morlet", @@ -290,7 +457,7 @@ def wvlt_cross_spectrum( **kwargs ): r""" - Compute discrete wavelet cross spectrum of :math:`da` and :math:`da1`. + Compute continuous wavelet cross spectrum of :math:`da` and :math:`da1`. Scale :math:`s` is dimensionless. Parameters @@ -309,7 +476,7 @@ def wvlt_cross_spectrum( The dimensions along which to take the transformation. If `None`, all dimensions will be transformed. If the inputs are dask arrays, the arrays must not be chunked along these dimensions. - xo : float + x0 : float Length scale of the mother wavelet. a : float Amplitude of wavelet. @@ -330,39 +497,61 @@ def wvlt_cross_spectrum( if isinstance(dim, str): dim = [dim] - dawt = dwvlt( - da, - s, - spacing_tol=spacing_tol, - dim=dim, - xo=xo, - a=a, - ntheta=ntheta, - wtype=wtype, - **kwargs - ) - dawt1 = dwvlt( - da1, - s, - spacing_tol=spacing_tol, - dim=dim, - xo=xo, - a=a, - ntheta=ntheta, - wtype=wtype, - **kwargs - ) + if "xo" in kwargs: + x0 = kwargs.get("xo") + warnings.warn(_xo_warning, FutureWarning) + + if len(dim) == 1: + dawt = cwvlt( + da, s, spacing_tol=spacing_tol, dim=dim, x0=x0, a=a, wtype=wtype, **kwargs + ) + dawt1 = cwvlt( + da1, s, spacing_tol=spacing_tol, dim=dim, x0=x0, a=a, wtype=wtype, **kwargs + ) + elif len(dim) == 2: + dawt = cwvlt2( + da, + s, + spacing_tol=spacing_tol, + dim=dim, + x0=x0, + a=a, + ntheta=ntheta, + wtype=wtype, + **kwargs + ) + dawt1 = cwvlt2( + da1, + s, + spacing_tol=spacing_tol, + dim=dim, + x0=x0, + a=a, + ntheta=ntheta, + wtype=wtype, + **kwargs + ) + else: + raise NotImplementedError( + "Transformation for three dimensions and higher is not implemented." + ) if normalize: axis_num = [da.get_axis_num(d) for d in dim] N = [da.shape[n] for n in axis_num] delta_x = _delta(da, dim, spacing_tol) - y = da[da.dims[axis_num[-2]]] - N[-2] / 2.0 * delta_x[-2] - x = da[da.dims[axis_num[-1]]] - N[-1] / 2.0 * delta_x[-1] - if wtype == "morlet": - # mother wavelet - wavelet, phi = _morlet(xo, ntheta, a, 1.0, y, x, **kwargs) + if len(dim) == 1: + t = da[da.dims[axis_num[0]]] - N[0] / 2.0 * delta_x[0] + if wtype == "morlet": + # mother wavelet + wavelet = _morlet(x0, a, 1.0, t) + elif len(dim) == 2: + y = da[da.dims[axis_num[-2]]] - N[-2] / 2.0 * delta_x[-2] + x = da[da.dims[axis_num[-1]]] - N[-1] / 2.0 * delta_x[-1] + if wtype == "morlet": + # mother wavelet + wavelet, phi = _morlet2(x0, ntheta, a, 1.0, y, x, **kwargs) Fdims = [] chunks = dict() @@ -386,4 +575,4 @@ def wvlt_cross_spectrum( else: C = 1.0 - return (dawt * np.conj(dawt1)).real * (dawt[s.dims[0]]) ** -1 * xo**2 / C + return (dawt * np.conj(dawt1)).real * (dawt[s.dims[0]]) ** -1 * x0**2 / C