From c9d55be737cb9d558128e6386b43c3519f19e24a Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Mon, 15 Apr 2024 11:19:54 +0100 Subject: [PATCH 1/3] extract JITable functions from c-backend --- s2wav/__init__.py | 5 + s2wav/transforms/__init__.py | 1 + s2wav/transforms/wavelet.py | 171 ++------------- s2wav/transforms/wavelet_c.py | 290 +++++++++++++++++++++++++ s2wav/transforms/wavelet_precompute.py | 50 ++--- tests/test_gradients.py | 34 ++- tests/test_wavelets.py | 69 ++---- 7 files changed, 368 insertions(+), 252 deletions(-) create mode 100644 s2wav/transforms/wavelet_c.py diff --git a/s2wav/__init__.py b/s2wav/__init__.py index 716a1b9..8120d5c 100644 --- a/s2wav/__init__.py +++ b/s2wav/__init__.py @@ -8,6 +8,11 @@ # JAX recursive transforms from .transforms.wavelet import analysis, synthesis, flm_to_analysis +# C Backend transforms +from .transforms.wavelet_c import analysis as analysis_c +from .transforms.wavelet_c import synthesis as synthesis_c +from .transforms.wavelet_c import flm_to_analysis as flm_to_analysis_c + # Base transforms from .transforms.base import analysis as analysis_base from .transforms.base import synthesis as synthesis_base diff --git a/s2wav/transforms/__init__.py b/s2wav/transforms/__init__.py index c6708cf..b43d918 100644 --- a/s2wav/transforms/__init__.py +++ b/s2wav/transforms/__init__.py @@ -1,5 +1,6 @@ from . import base from . import construct from . import wavelet +from . import wavelet_c from . import wavelet_precompute from . import wavelet_precompute_torch diff --git a/s2wav/transforms/wavelet.py b/s2wav/transforms/wavelet.py index 58fc45b..9410e82 100644 --- a/s2wav/transforms/wavelet.py +++ b/s2wav/transforms/wavelet.py @@ -1,6 +1,5 @@ from jax import jit import jax.numpy as jnp -import numpy as np from functools import partial from typing import Tuple, List import s2fft @@ -8,6 +7,7 @@ from s2wav.transforms import construct +@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9)) def synthesis( f_wav: jnp.ndarray, f_scal: jnp.ndarray, @@ -21,8 +21,6 @@ def synthesis( reality: bool = False, filters: Tuple[jnp.ndarray] = None, precomps: List[List[jnp.ndarray]] = None, - use_c_backend: bool = False, - _ssht_backend: int = 1, ) -> jnp.ndarray: r"""Computes the synthesis directional wavelet transform [1,2]. Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` @@ -61,17 +59,6 @@ def synthesis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. - use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}. - Defaults to False. - - _ssht_backend (int, optional, experimental): Whether to default to SSHT core - (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental - backend. Use with caution. - - Raises: - AssertionError: Shape of wavelet/scaling coefficients incorrect. - ValueError: If healpix sampling is provided to SSHT C backend. - Returns: jnp.ndarray: Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. @@ -79,58 +66,23 @@ def synthesis( [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013. [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). """ - if precomps == None and not use_c_backend: + if precomps is None: precomps = construct.generate_wigner_precomputes( L, N, J_min, lam, sampling, nside, True, reality ) - if use_c_backend and sampling.lower() == "healpix": - raise ValueError("SSHT C backend does not support healpix sampling.") J = samples.j_max(L, lam) Ls = samples.scal_bandlimit(L, J_min, lam, True) flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) - f_scal_lm = ( - s2fft.forward( - f_scal.real if reality else f_scal, - Ls, - spin, - nside, - sampling, - "jax_ssht", - reality, - _ssht_backend=_ssht_backend, - ) - if use_c_backend - else s2fft.forward_jax(f_scal, Ls, spin, nside, sampling, reality) - ) + f_scal_lm = s2fft.forward_jax(f_scal, Ls, spin, nside, sampling, reality) # Sum the all wavelet wigner coefficients for each lmn # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) - temp = ( - s2fft.wigner.forward( - f_wav[j - J_min], - Lj, - Nj, - nside, - sampling, - "jax_ssht", - reality, - _ssht_backend=_ssht_backend, - ) - if use_c_backend - else s2fft.wigner.forward_jax( - f_wav[j - J_min], - Lj, - Nj, - nside, - sampling, - reality, - precomps[j - J_min], - L_lower=L0j, - ) + temp = s2fft.wigner.forward_jax( + f_wav[j - J_min], Lj, Nj, nside, sampling, reality, precomps[j - J_min], L0j ) flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add( jnp.einsum( @@ -146,22 +98,10 @@ def synthesis( flm = flm.at[:Ls, L - Ls : L - 1 + Ls].add( jnp.einsum("lm,l->lm", f_scal_lm, phi, optimize=True) ) - return ( - s2fft.inverse( - flm, - L, - spin, - nside, - sampling, - "jax_ssht", - reality, - _ssht_backend=_ssht_backend, - ) - if use_c_backend - else s2fft.inverse_jax(flm, L, spin, nside, sampling, reality) - ) + return s2fft.inverse_jax(flm, L, spin, nside, sampling, reality) +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def analysis( f: jnp.ndarray, L: int, @@ -174,8 +114,6 @@ def analysis( reality: bool = False, filters: Tuple[jnp.ndarray] = None, precomps: List[List[jnp.ndarray]] = None, - use_c_backend: bool = False, - _ssht_backend: int = 1, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -206,13 +144,6 @@ def analysis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. - use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}. - Defaults to False. - - _ssht_backend (int, optional, experimental): Whether to default to SSHT core - (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental - backend. Use with caution. - Returns: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. @@ -220,16 +151,12 @@ def analysis( f_scal (jnp.ndarray): Array of scaling pixel-space coefficients with shape :math:`[n_{\theta}, n_{\phi}]`. """ - if precomps == None and not use_c_backend: + if precomps is None: precomps = construct.generate_wigner_precomputes( L, N, J_min, lam, sampling, nside, False, reality ) J = samples.j_max(L, lam) Ls = samples.scal_bandlimit(L, J_min, lam, True) - - f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) - f_wav = samples.construct_f_jax(L, J_min, J, lam) - wav_lm = jnp.einsum( "jln, l->jln", jnp.conj(filters[0]), @@ -237,23 +164,12 @@ def analysis( optimize=True, ) - flm = ( - s2fft.forward( - f, - L, - spin, - nside, - sampling, - "jax_ssht", - reality, - _ssht_backend=_ssht_backend, - ) - if use_c_backend - else s2fft.forward_jax(f, L, spin, nside, sampling, reality) - ) + flm = s2fft.forward_jax(f, L, spin, nside, sampling, reality) # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J + f_wav = [] + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) for j in range(J_min, J + 1): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( @@ -269,19 +185,8 @@ def analysis( ) ) - f_wav[j - J_min] = ( - s2fft.wigner.inverse( - f_wav_lmn[j - J_min], - Lj, - Nj, - nside, - sampling, - "jax_ssht", - reality, - _ssht_backend=_ssht_backend, - ) - if use_c_backend - else s2fft.wigner.inverse_jax( + f_wav.append( + s2fft.wigner.inverse_jax( f_wav_lmn[j - J_min], Lj, Nj, @@ -289,7 +194,6 @@ def analysis( sampling, reality, precomps[j - J_min], - False, L0j, ) ) @@ -302,23 +206,11 @@ def analysis( if Ls == 1: f_scal = temp * jnp.sqrt(1 / (4 * jnp.pi)) else: - f_scal = ( - s2fft.inverse( - temp, - Ls, - spin, - nside, - sampling, - "jax_ssht", - reality, - _ssht_backend=_ssht_backend, - ) - if use_c_backend - else s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality) - ) + f_scal = s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality) return f_wav, f_scal +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def flm_to_analysis( flm: jnp.ndarray, L: int, @@ -331,8 +223,6 @@ def flm_to_analysis( reality: bool = False, filters: Tuple[jnp.ndarray] = None, precomps: List[List[jnp.ndarray]] = None, - use_c_backend: bool = False, - _ssht_backend: int = 1, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -363,27 +253,16 @@ def flm_to_analysis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. - use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}. - Defaults to False. - - _ssht_backend (int, optional, experimental): Whether to default to SSHT core - (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental - backend. Use with caution. - Returns: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. """ - if precomps == None and not use_c_backend: + if precomps is None: precomps = construct.generate_wigner_precomputes( L, N, J_min, lam, sampling, nside, False, reality ) J = J_max if J_max is not None else samples.j_max(L, lam) - - f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) - f_wav = samples.construct_f_jax(L, J_min, J, lam) - wav_lm = jnp.einsum( "jln, l->jln", jnp.conj(filters), @@ -393,6 +272,8 @@ def flm_to_analysis( # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J + f_wav = [] + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) for j in range(J_min, J + 1): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( @@ -408,19 +289,8 @@ def flm_to_analysis( ) ) - f_wav[j - J_min] = jnp.array( - s2fft.wigner.inverse( - jnp.array(f_wav_lmn[j - J_min]), - Lj, - Nj, - nside, - sampling, - "jax_ssht", - reality, - _ssht_backend=_ssht_backend, - ) - if use_c_backend - else s2fft.wigner.inverse_jax( + f_wav.append( + s2fft.wigner.inverse_jax( f_wav_lmn[j - J_min], Lj, Nj, @@ -428,7 +298,6 @@ def flm_to_analysis( sampling, reality, precomps[j - J_min], - False, L0j, ) ) diff --git a/s2wav/transforms/wavelet_c.py b/s2wav/transforms/wavelet_c.py new file mode 100644 index 0000000..ad5238d --- /dev/null +++ b/s2wav/transforms/wavelet_c.py @@ -0,0 +1,290 @@ +from jax import jit +import jax.numpy as jnp +from typing import Tuple, List +from functools import partial +import s2fft +from s2fft.transforms.c_backend_spherical import ssht_forward, ssht_inverse +from s2wav import samples + + +def synthesis( + f_wav: jnp.ndarray, + f_scal: jnp.ndarray, + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + sampling: str = "mw", + reality: bool = False, + filters: Tuple[jnp.ndarray] = None, +) -> jnp.ndarray: + r"""Computes the synthesis directional wavelet transform [1,2]. + Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` + by summing the contributions from wavelet and scaling coefficients in harmonic space, + see equation 27 from `[2] `_. + + Args: + f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + + f_scal (jnp.ndarray): Array of scaling pixel-space coefficients + with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) of input signal. Defaults to 0. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", + "healpix"}. Defaults to "mw". + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. + + Raises: + ValueError: If healpix sampling is provided to SSHT C backend. + + Returns: + jnp.ndarray: Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + Notes: + [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013. + [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). + """ + if sampling.lower() == "healpix": + raise ValueError("SSHT C backend does not support healpix sampling.") + ssht_sampling = ["mw", "mwss", "dh", "gl"].index(sampling.lower()) + + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) + + f_scal_lm = ssht_forward( + f_scal.real if reality else f_scal, Ls, spin, reality, ssht_sampling + ) + f_wav_lmn = [] + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + f_wav_lmn.append( + s2fft.wigner.forward_jax_ssht( + f_wav[j - J_min], Lj, Nj, L0j, sampling, reality + ) + ) + + flm = _sum_over_wavelet_and_scaling( + f_wav_lmn, f_scal_lm, L, N, J_min, J, lam, filters + ) + return ssht_inverse(flm, L, spin, reality, ssht_sampling) + + +def analysis( + f: jnp.ndarray, + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + sampling: str = "mw", + reality: bool = False, + filters: Tuple[jnp.ndarray] = None, +) -> Tuple[jnp.ndarray]: + r"""Wavelet analysis from pixel space to wavelet space for complex signals. + + Args: + f (jnp.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) of input signal. Defaults to 0. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. + + Returns: + f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + + f_scal (jnp.ndarray): Array of scaling pixel-space coefficients + with shape :math:`[n_{\theta}, n_{\phi}]`. + """ + if sampling.lower() == "healpix": + raise ValueError("SSHT C backend does not support healpix sampling.") + ssht_sampling = ["mw", "mwss", "dh", "gl"].index(sampling.lower()) + + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) + + flm = ssht_forward(f, L, spin, reality, ssht_sampling) + f_wav_lmn = _generate_and_apply_wavelets(flm, L, N, J_min, J, lam, filters[0]) + flm_scal = _generate_and_apply_scaling(flm, L, J_min, lam, filters[1]) + + f_wav = [] + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + f_wav.append( + s2fft.wigner.inverse_jax_ssht( + f_wav_lmn[j - J_min], Lj, Nj, L0j, sampling, reality + ) + ) + + if Ls == 1: + f_scal = flm_scal * jnp.sqrt(1 / (4 * jnp.pi)) + else: + f_scal = ssht_inverse(flm_scal, Ls, spin, reality, ssht_sampling) + return f_wav, f_scal + + +def flm_to_analysis( + flm: jnp.ndarray, + L: int, + N: int = 1, + J_min: int = 0, + J_max: int = None, + lam: float = 2.0, + sampling: str = "mw", + reality: bool = False, + filters: Tuple[jnp.ndarray] = None, +) -> Tuple[jnp.ndarray]: + r"""Wavelet analysis from pixel space to wavelet space for complex signals. + + Args: + f (jnp.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + J_max (int, optional): Highest frequency wavelet scale to be used. Defaults to None. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + filters (jnp.ndarray, optional): Precomputed wavelet filters. Defaults to None. + + Returns: + f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + """ + J = J_max if J_max is not None else samples.j_max(L, lam) + f_wav_lmn = _generate_and_apply_wavelets(flm, L, N, J_min, J, lam, filters) + + f_wav = [] + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + f_wav.append( + s2fft.wigner.inverse_jax_ssht( + f_wav_lmn[j - J_min], Lj, Nj, L0j, sampling, reality + ) + ) + + return f_wav + + +@partial(jit, static_argnums=(1, 2, 3, 4, 5)) +def _generate_and_apply_wavelets( + flm: jnp.ndarray, + L: int, + N: int, + J_min: int, + J: int, + lam: float = 2.0, + filters: jnp.ndarray = None, +) -> jnp.ndarray: + """Private internal function which generates and applies wavelet filters.""" + # f_wav = samples.construct_f_jax(L, J_min, J, lam) + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) + + wav_lm = jnp.einsum( + "jln, l->jln", + jnp.conj(filters), + 8 * jnp.pi**2 / (2 * jnp.arange(L) + 1), + optimize=True, + ) + + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + f_wav_lmn[j - J_min] = ( + f_wav_lmn[j - J_min] + .at[::2, L0j:] + .add( + jnp.einsum( + "lm,ln->nlm", + flm[L0j:Lj, L - Lj : L - 1 + Lj], + wav_lm[j, L0j:Lj, L - Nj : L - 1 + Nj : 2], + optimize=True, + ) + ) + ) + return f_wav_lmn + + +@partial(jit, static_argnums=(1, 2, 3)) +def _generate_and_apply_scaling( + flm: jnp.ndarray, + L: int, + J_min: int = 0, + lam: float = 2.0, + filters: jnp.ndarray = None, +) -> jnp.ndarray: + """Private internal function which generates and applies scaling filter.""" + Ls = samples.scal_bandlimit(L, J_min, lam, True) + phi = filters[:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1)) + return jnp.einsum("lm,l->lm", flm[:Ls, L - Ls : L - 1 + Ls], phi, optimize=True) + + +@partial(jit, static_argnums=(2, 3, 4, 5, 6)) +def _sum_over_wavelet_and_scaling( + f_wav_lmn: jnp.ndarray, + f_scal_lm: jnp.ndarray, + L: int, + N: int, + J_min: int, + J: int, + lam: float = 2.0, + filters: Tuple[jnp.ndarray] = None, +) -> jnp.ndarray: + """Private internal function which sums over wavelet and scaling coefficients.""" + Ls = samples.scal_bandlimit(L, J_min, lam, True) + flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add( + jnp.einsum( + "ln,nlm->lm", + filters[0][j, L0j:Lj, L - Nj : L - 1 + Nj : 2], + f_wav_lmn[j - J_min][::2, L0j:, :], + optimize=True, + ) + ) + + phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1)) + flm = flm.at[:Ls, L - Ls : L - 1 + Ls].add( + jnp.einsum("lm,l->lm", f_scal_lm, phi, optimize=True) + ) + return flm diff --git a/s2wav/transforms/wavelet_precompute.py b/s2wav/transforms/wavelet_precompute.py index 072e47e..a9d1487 100644 --- a/s2wav/transforms/wavelet_precompute.py +++ b/s2wav/transforms/wavelet_precompute.py @@ -166,10 +166,6 @@ def analysis( J = samples.j_max(L, lam) Ls = samples.scal_bandlimit(L, J_min, lam, True) - - f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) - f_wav = samples.construct_f_jax(L, J_min, J, lam) - wav_lm = jnp.einsum( "jln, l->jln", jnp.conj(filters[0]), @@ -182,6 +178,8 @@ def analysis( ) # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J + f_wav = [] + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) for j in range(J_min, J + 1): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( @@ -197,14 +195,16 @@ def analysis( ) ) shift = 0 if j < J else -1 - f_wav[j - J_min] = wigner.inverse_transform_jax( - f_wav_lmn[j - J_min], - precomps[2][j - J_min + shift], - Lj, - Nj, - sampling, - reality, - nside, + f_wav.append( + wigner.inverse_transform_jax( + f_wav_lmn[j - J_min], + precomps[2][j - J_min + shift], + Lj, + Nj, + sampling, + reality, + nside, + ) ) # Project all harmonic coefficients for each lm onto scaling coefficients @@ -263,7 +263,7 @@ def flm_to_analysis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. - + _precomp_shift (bool, optional): Whether or not the duplicated highest wavelet scale precomputes are provided or not. @@ -278,10 +278,6 @@ def flm_to_analysis( raise ValueError("Must provide precomputed kernels for this transform!") J = J_max if J_max is not None else samples.j_max(L, lam) - - f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) - f_wav = samples.construct_f_jax(L, J_min, J, lam) - wav_lm = jnp.einsum( "jln, l->jln", jnp.conj(filters), @@ -291,6 +287,8 @@ def flm_to_analysis( # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J + f_wav = [] + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True) for j in range(J_min, J + 1): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( @@ -308,13 +306,15 @@ def flm_to_analysis( shift = 0 if j < J else -1 shift = shift if _precomp_shift else 0 - f_wav[j - J_min] = wigner.inverse_transform_jax( - f_wav_lmn[j - J_min], - precomps[2][j - J_min + shift], - Lj, - Nj, - sampling, - reality, - nside, + f_wav.append( + wigner.inverse_transform_jax( + f_wav_lmn[j - J_min], + precomps[2][j - J_min + shift], + Lj, + Nj, + sampling, + reality, + nside, + ) ) return f_wav diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 7fd860b..882a1bb 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from jax.test_util import check_grads import s2fft -from s2wav.transforms import wavelet, wavelet_precompute, construct +from s2wav.transforms import wavelet, wavelet_c, wavelet_precompute, construct from s2wav import filters, samples L_to_test = [8] @@ -14,7 +14,6 @@ reality = [False, True] recursive_transform = [False, True] using_c_backend = [False, True] -_ssht_backends = [0, 1] @pytest.mark.parametrize("L", L_to_test) @@ -23,7 +22,6 @@ @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_c_backend", using_c_backend) -@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_jax_synthesis_gradients( wavelet_generator, L: int, @@ -32,7 +30,6 @@ def test_jax_synthesis_gradients( reality: bool, recursive: bool, using_c_backend: bool, - _ssht_backend: int, ): J = samples.j_max(L) @@ -60,7 +57,11 @@ def test_jax_synthesis_gradients( else construct.generate_full_precomputes ) ) - synthesis = wavelet.synthesis if recursive else wavelet_precompute.synthesis + synthesis = ( + (wavelet_c.synthesis if using_c_backend else wavelet.synthesis) + if recursive + else wavelet_precompute.synthesis + ) precomps = ( None @@ -68,11 +69,7 @@ def test_jax_synthesis_gradients( else generator(L, N, J_min, 2, forward=True, reality=reality) ) - args = ( - {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} - if using_c_backend - else {} - ) + args = {"precomps": precomps} if not using_c_backend else {} def func(f_wav, f_scal): f = synthesis( @@ -83,7 +80,6 @@ def func(f_wav, f_scal): J_min, reality=reality, filters=filter, - precomps=precomps, **args, ) return jnp.sum(jnp.abs(f) ** 2) @@ -97,7 +93,6 @@ def func(f_wav, f_scal): @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_c_backend", using_c_backend) -@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_jax_analysis_gradients( flm_generator, wavelet_generator, @@ -107,7 +102,6 @@ def test_jax_analysis_gradients( reality: bool, recursive: bool, using_c_backend: bool, - _ssht_backend: int, ): J = samples.j_max(L) if J_min >= J: @@ -137,22 +131,22 @@ def test_jax_analysis_gradients( else construct.generate_full_precomputes ) ) - analysis = wavelet.analysis if recursive else wavelet_precompute.analysis + analysis = ( + (wavelet_c.analysis if using_c_backend else wavelet.analysis) + if recursive + else wavelet_precompute.analysis + ) precomps = ( None if using_c_backend else generator(L, N, J_min, forward=False, reality=reality) ) - args = ( - {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} - if using_c_backend - else {} - ) + args = {"precomps": precomps} if not using_c_backend else {} def func(f): f_wav, f_scal = analysis( - f, L, N, J_min, reality=reality, filters=filter, precomps=precomps, **args + f, L, N, J_min, reality=reality, filters=filter, **args ) loss = jnp.sum(jnp.abs(f_scal - f_scal_target) ** 2) for j in range(J - J_min): diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index fe7121f..ea8725b 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -8,6 +8,7 @@ from s2fft import base_transforms as sht_base from s2wav.transforms import ( wavelet, + wavelet_c, wavelet_precompute, wavelet_precompute_torch, construct, @@ -23,7 +24,6 @@ recursive_transform = [False, True] using_torch_frontend = [False, True] using_c_backend = [False, True] -_ssht_backends = [0, 1] @pytest.mark.parametrize("L", L_to_test) @@ -34,7 +34,6 @@ @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_torch", using_torch_frontend) @pytest.mark.parametrize("using_c_backend", using_c_backend) -@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_synthesis( wavelet_generator, L: int, @@ -45,7 +44,6 @@ def test_synthesis( recursive: bool, using_torch: bool, using_c_backend: bool, - _ssht_backend: int, ): J = samples.j_max(L, lam) @@ -85,7 +83,7 @@ def test_synthesis( ) ) synthesis = ( - wavelet.synthesis + (wavelet_c.synthesis if using_c_backend else wavelet.synthesis) if recursive else ( wavelet_precompute_torch.synthesis @@ -100,23 +98,10 @@ def test_synthesis( L, N, J_min, lam, forward=True, reality=reality, using_torch=using_torch ) ) - args = ( - {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} - if using_c_backend - else {} - ) + args = {"precomps": precomps} if not using_c_backend else {} f_check = synthesis( - f_wav, - f_scal, - L, - N, - J_min, - lam, - reality=reality, - filters=filter, - precomps=precomps, - **args, + f_wav, f_scal, L, N, J_min, lam, reality=reality, filters=filter, **args ) if using_torch: @@ -134,7 +119,6 @@ def test_synthesis( @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_torch", using_torch_frontend) @pytest.mark.parametrize("using_c_backend", using_c_backend) -@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_analysis( flm_generator, f_wav_converter, @@ -146,7 +130,6 @@ def test_analysis( recursive: bool, using_torch: bool, using_c_backend: bool, - _ssht_backend: int, ): J = samples.j_max(L, lam) @@ -177,7 +160,7 @@ def test_analysis( ) ) analysis = ( - wavelet.analysis + (wavelet_c.analysis if using_c_backend else wavelet.analysis) if recursive else ( wavelet_precompute_torch.analysis @@ -193,11 +176,7 @@ def test_analysis( ) ) - args = ( - {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} - if using_c_backend - else {} - ) + args = {"precomps": precomps} if not using_c_backend else {} f_wav_check, f_scal_check = analysis( torch.from_numpy(f) if using_torch else f, @@ -207,7 +186,6 @@ def test_analysis( lam, reality=reality, filters=filter, - precomps=precomps, **args, ) @@ -230,7 +208,6 @@ def test_analysis( @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("sampling", sampling_to_test) @pytest.mark.parametrize("using_c_backend", using_c_backend) -@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_round_trip( flm_generator, L: int, @@ -240,7 +217,6 @@ def test_round_trip( reality: bool, sampling: str, using_c_backend: bool, - _ssht_backend: int, ): J = samples.j_max(L, lam) @@ -252,24 +228,13 @@ def test_round_trip( f = sht_base.spherical.inverse(flm, L, reality=reality, sampling=sampling) filter = filters.filters_directional_vectorised(L, N, J_min, lam) - args = ( - {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} - if using_c_backend - else {} - ) + analysis_func = wavelet_c.analysis if using_c_backend else wavelet.analysis + synthesis_func = wavelet_c.synthesis if using_c_backend else wavelet.synthesis - f_wav, f_scal = wavelet.analysis( - f, - L, - N, - J_min, - lam, - reality=reality, - sampling=sampling, - filters=filter, - **args, + f_wav, f_scal = analysis_func( + f, L, N, J_min, lam, reality=reality, sampling=sampling, filters=filter ) - f_check = wavelet.synthesis( + f_check = synthesis_func( f_wav, f_scal, L, @@ -279,7 +244,6 @@ def test_round_trip( sampling=sampling, reality=reality, filters=filter, - **args, ) np.testing.assert_allclose(f, f_check, atol=1e-14) @@ -293,7 +257,6 @@ def test_round_trip( @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_torch", using_torch_frontend) @pytest.mark.parametrize("using_c_backend", using_c_backend) -@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_flm_to_analysis( flm_generator, f_wav_converter, @@ -305,7 +268,6 @@ def test_flm_to_analysis( recursive: bool, using_torch: bool, using_c_backend: bool, - _ssht_backend: int, ): J = samples.j_max(L, lam) @@ -337,7 +299,7 @@ def test_flm_to_analysis( ) ) analysis = ( - wavelet.flm_to_analysis + (wavelet_c.flm_to_analysis if using_c_backend else wavelet.flm_to_analysis) if recursive else ( wavelet_precompute_torch.flm_to_analysis @@ -353,11 +315,7 @@ def test_flm_to_analysis( ) ) - args = ( - {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} - if using_c_backend - else {} - ) + args = {"precomps": precomps} if not using_c_backend else {} f_wav_check = analysis( torch.from_numpy(flm) if using_torch else flm, @@ -368,7 +326,6 @@ def test_flm_to_analysis( lam, reality=reality, filters=filter, - precomps=precomps, **args, ) From fb8e16ebd3646726e296aceb1fca2ba37060151e Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Mon, 15 Apr 2024 11:26:25 +0100 Subject: [PATCH 2/3] increment version number to v1.0.4 --- docs/conf.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d664689..e6426b1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,9 +25,9 @@ author = "Matthew Price, Jason McEwen, Jessica Whitney, Alicja Polanska" # The short X.Y version -version = "1.0.3" +version = "1.0.4" # The full version, including alpha/beta/rc tags -release = "1.0.3" +release = "1.0.4" # -- General configuration --------------------------------------------------- diff --git a/setup.py b/setup.py index ad0d9e9..c32fb9e 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "Intended Audience :: Science/Research", ], name="s2wav", - version="1.0.3", + version="1.0.4", url="https://github.com/astro-informatics/s2wav", author="Authors & Contributors", license="GNU General Public License v3 (GPLv3)", From dd4337ae44179c15cc2d5283960f7c1f912e2d8f Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Mon, 15 Apr 2024 11:26:36 +0100 Subject: [PATCH 3/3] update s2fft version requirement --- requirements/requirements-core.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-core.txt b/requirements/requirements-core.txt index 7fe0b81..6ab5691 100644 --- a/requirements/requirements-core.txt +++ b/requirements/requirements-core.txt @@ -5,4 +5,4 @@ pyyaml==6.0 scipy # For spherical transforms -s2fft >= 1.1.0 +s2fft >= 1.1.1