diff --git a/xwavelet/wavelet.py b/xwavelet/wavelet.py index 4b9b819..b99007f 100644 --- a/xwavelet/wavelet.py +++ b/xwavelet/wavelet.py @@ -57,7 +57,7 @@ def _delta(da, dim, spacing_tol): return delta_x -def _morlet(xo, ntheta, a, s, y, x, dim): +def _morlet(xo, ntheta, a, s, y, x, **kwargs): r""" Define @@ -72,11 +72,17 @@ def _morlet(xo, ntheta, a, s, y, x, dim): Units of :math:`a` are :math:`L^{-2}`. :math:`k_0` is defaulted to :math:`1/x_0` in the zonal direction. """ + ko = 1.0 / xo # compute morlet wavelet th = np.arange(int(ntheta / 2)) * 2.0 * np.pi / ntheta - th = xr.DataArray(th, dims=["angle"], coords={"angle": th}).chunk({"angle": 1}) + th = xr.DataArray(th, dims=["angle"], coords={"angle": th}) + + if "angle" in kwargs: + for k, v in kwargs.items(): + chunk = {k: v} + th = th.chunk(chunk) # rotated positions yp = np.sin(th) * s**-1 * (y - y.mean()) @@ -89,7 +95,17 @@ def _morlet(xo, ntheta, a, s, y, x, dim): return m, th -def dwvlt(da, s, spacing_tol=1e-3, dim=None, xo=50e3, a=1.0, ntheta=16, wtype="morlet"): +def dwvlt( + da, + s, + spacing_tol=1e-3, + dim=None, + xo=50e3, + a=1.0, + ntheta=16, + wtype="morlet", + **kwargs +): r""" Compute discrete wavelet transform of da. Default is the Morlet wavelet. Scale :math:`s` is dimensionless. @@ -150,7 +166,7 @@ def dwvlt(da, s, spacing_tol=1e-3, dim=None, xo=50e3, a=1.0, ntheta=16, wtype="m ) if wtype == "morlet": - wavelet, phi = _morlet(xo, ntheta, a, s, y, x, dim) + wavelet, phi = _morlet(xo, ntheta, a, s, y, x, **kwargs) else: raise NotImplementedError("Only the Morlet wavelet is implemented for now.") @@ -171,6 +187,7 @@ def wvlt_power_spectrum( ntheta=16, wtype="morlet", normalize=True, + **kwargs ): r""" Compute discrete wavelet power spectrum of :math:`da`. @@ -212,7 +229,15 @@ def wvlt_power_spectrum( dim = [dim] dawt = dwvlt( - da, s, spacing_tol=spacing_tol, dim=dim, xo=xo, a=a, ntheta=ntheta, wtype=wtype + da, + s, + spacing_tol=spacing_tol, + dim=dim, + xo=xo, + a=a, + ntheta=ntheta, + wtype=wtype, + **kwargs ) if normalize: @@ -262,6 +287,7 @@ def wvlt_cross_spectrum( ntheta=16, wtype="morlet", normalize=True, + **kwargs ): r""" Compute discrete wavelet cross spectrum of :math:`da` and :math:`da1`. @@ -305,10 +331,26 @@ def wvlt_cross_spectrum( dim = [dim] dawt = dwvlt( - da, s, spacing_tol=spacing_tol, dim=dim, xo=xo, a=a, ntheta=ntheta, wtype=wtype + da, + s, + spacing_tol=spacing_tol, + dim=dim, + xo=xo, + a=a, + ntheta=ntheta, + wtype=wtype, + **kwargs ) dawt1 = dwvlt( - da1, s, spacing_tol=spacing_tol, dim=dim, xo=xo, a=a, ntheta=ntheta, wtype=wtype + da1, + s, + spacing_tol=spacing_tol, + dim=dim, + xo=xo, + a=a, + ntheta=ntheta, + wtype=wtype, + **kwargs ) if normalize: