Skip to content

Commit

Permalink
increase codecov by catching edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
CosmoMatt committed Apr 9, 2024
1 parent d7f3633 commit 6394602
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 10 deletions.
22 changes: 12 additions & 10 deletions s2wav/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def filters_directional(
if kappa0[el] != 0:
phi[el] = np.sqrt((2 * el + 1) / (4.0 * np.pi)) * kappa0[el]
if spin0 != 0:
phi[el] *= spin_normalization(el, spin0) * (-1) ** spin0
phi[el] *= _spin_normalization(el, spin0) * (-1) ** spin0

Check warning on line 135 in s2wav/filters.py

View check run for this annotation

Codecov / codecov/patch

s2wav/filters.py#L135

Added line #L135 was not covered by tests

for j in range(J_min, J + 1):
for el in range(el_min, L):
Expand All @@ -146,7 +146,7 @@ def filters_directional(
)
if spin0 != 0:
psi[j, el, L - 1 + m] *= (

Check warning on line 148 in s2wav/filters.py

View check run for this annotation

Codecov / codecov/patch

s2wav/filters.py#L148

Added line #L148 was not covered by tests
spin_normalization(el, spin0) * (-1) ** spin0
_spin_normalization(el, spin0) * (-1) ** spin0
)
if using_torch:
psi = torch.from_numpy(psi)
Expand Down Expand Up @@ -227,7 +227,7 @@ def filters_directional_vectorised(
el_min = max(abs(spin), abs(spin0))

spin_norms = (
(-1) ** spin0 * spin_normalization_vectorised(np.arange(L), spin0)
(-1) ** spin0 * _spin_normalization_vectorised(np.arange(L), spin0)
if spin0 != 0
else 1
)
Expand Down Expand Up @@ -323,7 +323,9 @@ def filters_directional_jax(
el_min = max(abs(spin), abs(spin0))

spin_norms = (
(-1) ** spin0 * spin_normalization_jax(np.arange(L), spin0) if spin0 != 0 else 1
(-1) ** spin0 * _spin_normalization_jax(np.arange(L), spin0)
if spin0 != 0
else 1
)

kappa, kappa0 = filters_axisym_jax(L, J_min, lam)
Expand Down Expand Up @@ -476,7 +478,7 @@ def k_lam(L: int, lam: float = 2.0, quad_iters: int = 300) -> float:


@partial(jit, static_argnums=(2, 3)) # not sure
def part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float:
def _part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float:
r"""JAX version of part_scaling_fn. Computes integral used to calculate smoothly
decreasing function :math:`k_{\lambda}`.
Expand Down Expand Up @@ -627,7 +629,7 @@ def tiling_direction(L: int, N: int = 1) -> np.ndarray:
return s_elm


def spin_normalization(el: int, spin: int = 0) -> float:
def _spin_normalization(el: int, spin: int = 0) -> float:
r"""Computes the normalization factor for spin-lowered wavelets, which is
:math:`\sqrt{\frac{(\ell+s)!}{(\ell-s)!}}`.
Expand All @@ -650,8 +652,8 @@ def spin_normalization(el: int, spin: int = 0) -> float:
return np.sqrt(1.0 / factor)

Check warning on line 652 in s2wav/filters.py

View check run for this annotation

Codecov / codecov/patch

s2wav/filters.py#L652

Added line #L652 was not covered by tests


def spin_normalization_vectorised(el: np.ndarray, spin: int = 0) -> float:
r"""Vectorised version of :func:`~spin_normalization`.
def _spin_normalization_vectorised(el: np.ndarray, spin: int = 0) -> float:
r"""Vectorised version of :func:`~_spin_normalization`.
Args:
el (int): Harmonic index :math:`\ell`.
spin (int): Spin of field over which to perform the transform. Defaults to 0.
Expand Down Expand Up @@ -714,8 +716,8 @@ def tiling_direction_jax(L: int, N: int = 1) -> np.ndarray:


@partial(jit, static_argnums=(1))
def spin_normalization_jax(el: np.ndarray, spin: int = 0) -> float:
r"""JAX version of :func:`~spin_normalization`.
def _spin_normalization_jax(el: np.ndarray, spin: int = 0) -> float:
r"""JAX version of :func:`~_spin_normalization`.
Args:
el (int): Harmonic index :math:`\ell`.
spin (int): Spin of field over which to perform the transform. Defaults to 0.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import jax

jax.config.update("jax_enable_x64", True)
import pytest
import numpy as np
from s2wav import filters, samples
Expand Down Expand Up @@ -94,6 +97,18 @@ def test_directional_vectorised(L: int, N: int, J_min: int, lam: int):
np.testing.assert_allclose(f[i], f_vect[i], rtol=1e-14)


@pytest.mark.parametrize("L", L_to_test)
@pytest.mark.parametrize("N", N_to_test)
@pytest.mark.parametrize("J_min", J_min_to_test)
@pytest.mark.parametrize("lam", lam_to_test)
def test_directional_torch(L: int, N: int, J_min: int, lam: int):
f = filters.filters_directional(L, N, J_min, lam)
f_vect = filters.filters_directional_vectorised(L, N, J_min, lam, using_torch=True)

for i in range(2):
np.testing.assert_allclose(f[i], f_vect[i], rtol=1e-14)


@pytest.mark.parametrize("L", L_to_test)
@pytest.mark.parametrize("J_min", J_min_to_test)
@pytest.mark.parametrize("lam", lam_to_test)
Expand All @@ -115,3 +130,15 @@ def test_directional_jax(L: int, N: int, J_min: int, lam: int):

for i in range(2):
np.testing.assert_allclose(f[i], f_jax[i], rtol=1e-13, atol=1e-13)


def test_filter_exceptions():
L = 8
with pytest.raises(ValueError) as e:
filters.filters_axisym(L, 10)

with pytest.raises(ValueError) as e:
filters.filters_axisym_vectorised(L, 10)

with pytest.raises(ValueError) as e:
filters.filters_axisym_jax(L, 10)
3 changes: 3 additions & 0 deletions tests/test_gradients.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import jax

jax.config.update("jax_enable_x64", True)
import pytest
import jax.numpy as jnp
from jax.test_util import check_grads
Expand Down
95 changes: 95 additions & 0 deletions tests/test_wavelets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import jax

jax.config.update("jax_enable_x64", True)
import pytest
import numpy as np
import torch
Expand Down Expand Up @@ -280,3 +283,95 @@ def test_round_trip(
)

np.testing.assert_allclose(f, f_check, atol=1e-14)


@pytest.mark.parametrize("L", L_to_test)
@pytest.mark.parametrize("N", N_to_test)
@pytest.mark.parametrize("J_min", J_min_to_test)
@pytest.mark.parametrize("lam", lam_to_test)
@pytest.mark.parametrize("reality", reality)
@pytest.mark.parametrize("recursive", recursive_transform)
@pytest.mark.parametrize("using_torch", using_torch_frontend)
@pytest.mark.parametrize("using_c_backend", using_c_backend)
@pytest.mark.parametrize("_ssht_backend", _ssht_backends)
def test_flm_to_analysis(
flm_generator,
f_wav_converter,
L: int,
N: int,
J_min: int,
lam: int,
reality: bool,
recursive: bool,
using_torch: bool,
using_c_backend: bool,
_ssht_backend: int,
):
J = samples.j_max(L, lam)

# Exceptions
if J_min >= J:
pytest.skip("J_min larger than J which isn't a valid test case.")
if recursive and using_torch:
pytest.skip("Recursive transform not yet available for torch frontend")
if not recursive and using_c_backend:
pytest.skip("Precompute transform not supported from C backend libraries.")

flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality)
f = sht_base.spherical.inverse(flm, L, reality=reality)

f_wav, _ = s2let.analysis_px2wav(
f.flatten("C").astype(np.complex128), lam, L, J_min, N, spin=0, upsample=False
)
filter = filters.filters_directional_vectorised(
L, N, J_min, lam, using_torch=using_torch
)[0]

generator = (
None
if using_c_backend
else (
construct.generate_wigner_precomputes
if recursive
else construct.generate_full_precomputes
)
)
analysis = (
wavelet.flm_to_analysis
if recursive
else (
wavelet_precompute_torch.flm_to_analysis
if using_torch
else wavelet_precompute.flm_to_analysis
)
)
precomps = (
None
if using_c_backend
else generator(
L, N, J_min, lam, forward=False, reality=reality, using_torch=using_torch
)
)

args = (
{"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend}
if using_c_backend
else {}
)

f_wav_check = analysis(
torch.from_numpy(flm) if using_torch else flm,
L,
N,
J_min,
None,
lam,
reality=reality,
filters=filter,
precomps=precomps,
**args,
)

f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, using_torch)

np.testing.assert_allclose(f_wav, f_wav_check, atol=1e-14)
3 changes: 3 additions & 0 deletions tests/test_wavelets_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import jax

jax.config.update("jax_enable_x64", True)
import pytest
import numpy as np
import pys2let as s2let
Expand Down

0 comments on commit 6394602

Please sign in to comment.