Skip to content

Commit

Permalink
Adding 1D capability
Browse files Browse the repository at this point in the history
  • Loading branch information
Takaya Uchida committed Aug 25, 2023
1 parent aea2b61 commit c2546f7
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 66 deletions.
52 changes: 46 additions & 6 deletions xwavelet/tests/test_wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,33 @@
import xarray.testing as xrt

import xrft
from xwavelet.wavelet import dwvlt, wvlt_power_spectrum
from xwavelet.wavelet import dwvlt, cwvlt, cwvlt2, wvlt_power_spectrum


@pytest.fixture
def sample_da_1d():
t = np.linspace(0, 10, 11)
return xr.DataArray(t, dims=["t"], coords={"t": t})


@pytest.fixture
def sample_da_2d():
x = np.linspace(0, 10, 11)
y = np.linspace(-4, 4, 17)
z = np.arange(11 * 17).reshape(17, 11)
return xr.DataArray(z, dims=["y", "x"], coords={"y": y, "x": x})


def test_dimensions(sample_da_2d, sample_da_1d):
s = xr.DataArray(
np.linspace(0.1, 1.0, 20),
dims=["scale"],
coords={"scale": np.linspace(0.1, 1.0, 20)},
)
with pytest.raises(ValueError):
cwvlt(sample_da_2d, s)
with pytest.raises(ValueError):
cwvlt2(sample_da_1d, s)


def synthetic_field(N, dL, amp, s):
Expand Down Expand Up @@ -112,7 +138,7 @@ def synthetic_field_xr(


@pytest.mark.parametrize("chunk", [False, True])
def test_isotropic_ps_slope(chunk, N=128, dL=1.0, amp=1e0, slope=-3.0, xo=50):
def test_isotropic_ps_slope(chunk, N=256, dL=1.0, amp=1e0, slope=-3.0, xo=50):
"""Test the spectral slope of isotropic power spectrum."""

theta = synthetic_field_xr(
Expand All @@ -133,17 +159,31 @@ def test_isotropic_ps_slope(chunk, N=128, dL=1.0, amp=1e0, slope=-3.0, xo=50):
coords={"scale": np.linspace(0.1, 1.0, 20)},
)

Wtheta = dwvlt(theta, s, dim=["y", "x"], xo=xo)
kwargs = {"angle": 2}

Wtheta = dwvlt(theta, s, dim=["y", "x"], xo=xo, **kwargs)
iso_ps = (np.abs(Wtheta) ** 2).mean(["d0", "angle"]) * (Wtheta.scale) ** -1
npt.assert_almost_equal(np.ma.masked_invalid(iso_ps).mask.sum(), 0.0)
y_fit, a, b = xrft.fit_loglog(
(iso_ps.scale.values[1:-2]) ** -1, iso_ps.values[1:-2]
(iso_ps.scale.values[2:-2]) ** -1, iso_ps.values[2:-2]
)
npt.assert_allclose(a, slope, atol=0.3)

iso_ps = wvlt_power_spectrum(theta, s, dim=["y", "x"], xo=xo).mean(["d0", "angle"])
iso_ps = wvlt_power_spectrum(theta, s, dim=["y", "x"], x0=xo, **kwargs).mean(
["d0", "angle"]
)
npt.assert_almost_equal(np.ma.masked_invalid(iso_ps).mask.sum(), 0.0)
y_fit, a, b = xrft.fit_loglog(
(iso_ps.scale.values[1:-2]) ** -1, iso_ps.values[1:-2]
(iso_ps.scale.values[2:-2]) ** -1, iso_ps.values[2:-2]
)
npt.assert_allclose(a, slope, atol=0.3)


#
# if chunk:
# iso_ps = wvlt_power_spectrum(theta, s, dim=["y", "x"], xo=xo, **kwargs).mean(["d0", "angle"])
# npt.assert_almost_equal(np.ma.masked_invalid(iso_ps).mask.sum(), 0.0)
# y_fit, a, b = xrft.fit_loglog(
# (iso_ps.scale.values[1:-2]) ** -1, iso_ps.values[1:-2]
# )
# npt.assert_allclose(a, slope, atol=0.3)
Loading

0 comments on commit c2546f7

Please sign in to comment.