From 96e03828614677a2e212cb278facf6e58842006a Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Fri, 8 Mar 2024 14:00:54 +0000 Subject: [PATCH 1/8] restructure code and sunset non-multiresolution support --- requirements/requirements-core.txt | 4 +- s2wav/__init__.py | 4 +- s2wav/filter_factory/__init__.py | 3 - s2wav/filter_factory/filters.py | 332 -------- s2wav/filter_factory/kernels.py | 236 ------ s2wav/filter_factory/tiling.py | 161 ---- s2wav/filters.py | 711 ++++++++++++++++++ s2wav/{utils/shapes.py => samples.py} | 39 +- s2wav/transforms/__init__.py | 6 +- .../transforms/{numpy_wavelets.py => base.py} | 45 +- s2wav/transforms/construct.py | 113 +++ ..._wavelets_precompute.py => pre_wav_jax.py} | 158 +--- .../{jax_wavelets.py => rec_wav_jax.py} | 148 +--- s2wav/utils/__init__.py | 2 - s2wav/utils/math_functions.py | 41 - setup.py | 4 +- tests/conftest.py | 29 +- tests/test_filters.py | 9 +- tests/test_gradients.py | 47 +- tests/test_wavelets.py | 89 +-- tests/test_wavelets_base.py | 44 +- tests/test_wavelets_precompute.py | 149 ---- 22 files changed, 1053 insertions(+), 1321 deletions(-) delete mode 100644 s2wav/filter_factory/__init__.py delete mode 100644 s2wav/filter_factory/filters.py delete mode 100644 s2wav/filter_factory/kernels.py delete mode 100644 s2wav/filter_factory/tiling.py create mode 100644 s2wav/filters.py rename s2wav/{utils/shapes.py => samples.py} (95%) rename s2wav/transforms/{numpy_wavelets.py => base.py} (92%) create mode 100644 s2wav/transforms/construct.py rename s2wav/transforms/{jax_wavelets_precompute.py => pre_wav_jax.py} (64%) rename s2wav/transforms/{jax_wavelets.py => rec_wav_jax.py} (66%) delete mode 100644 s2wav/utils/__init__.py delete mode 100644 s2wav/utils/math_functions.py delete mode 100644 tests/test_wavelets_precompute.py diff --git a/requirements/requirements-core.txt b/requirements/requirements-core.txt index 84b70ac..16485d2 100644 --- a/requirements/requirements-core.txt +++ b/requirements/requirements-core.txt @@ -4,5 +4,5 @@ colorlog pyyaml==6.0 scipy -# Switch to pypi install when available -s2fft @ git+https://github.com/astro-informatics/s2fft.git@main#egg=s2fft +# For spherical transforms +s2fft >= 1.0.2 diff --git a/s2wav/__init__.py b/s2wav/__init__.py index 54ed041..58d40fc 100644 --- a/s2wav/__init__.py +++ b/s2wav/__init__.py @@ -1 +1,3 @@ -from .transforms.jax_wavelets import analysis, synthesis, flm_to_analysis +from . import filters +from . import samples +from .transforms.rec_wav_jax import analysis, synthesis, flm_to_analysis diff --git a/s2wav/filter_factory/__init__.py b/s2wav/filter_factory/__init__.py deleted file mode 100644 index 1a380a6..0000000 --- a/s2wav/filter_factory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import tiling -from . import filters -from . import kernels diff --git a/s2wav/filter_factory/filters.py b/s2wav/filter_factory/filters.py deleted file mode 100644 index 699f897..0000000 --- a/s2wav/filter_factory/filters.py +++ /dev/null @@ -1,332 +0,0 @@ -from jax import jit, config - -config.update("jax_enable_x64", True) - -import jax.numpy as jnp -import numpy as np -from s2wav.filter_factory import tiling, kernels -from s2wav.utils import shapes -from typing import Tuple -from functools import partial - - -def filters_axisym( - L: int, J_min: int = 0, lam: float = 2.0 -) -> Tuple[np.ndarray, np.ndarray]: - r"""Computes wavelet kernels :math:`\Psi^j_{\ell m}` and scaling kernel :math:`\Phi_{\ell m}` in harmonic space. - - Specifically, these kernels are derived in `[1] `_, where the wavelet kernels are defined (15) for scale :math:`j` to be - - .. math:: - - \Psi^j_{\ell m} \equiv \sqrt{\frac{2\ell+1}{4\pi}} \kappa_{\lambda}(\frac{\ell}{\lambda^j})\delta_{m0}, - - where :math:`\kappa_{\lambda} = \sqrt{k_{\lambda}(t/\lambda) - k_{\lambda}(t)}` for :math:`k_{\lambda}` given in :func:`~k_lam`. Similarly, the scaling kernel is defined (16) as - - .. math:: - - \Phi_{\ell m} \equiv \sqrt{\frac{2\ell+1}{4\pi}} \nu_{\lambda} (\frac{\ell}{\lambda^{J_0}})\delta_{m0}, - - where :math:`\nu_{\lambda} = \sqrt{k_{\lambda}(t)}` for :math:`k_{\lambda}` given in :func:`~k_lam`. Notice that :math:`\delta_{m0}` enforces that these kernels are axisymmetric, i.e. coefficients for :math:`m \not = \ell` are zero. In this implementation the normalisation constant has been omitted as it is nulled in subsequent functions. - - Args: - L (int): Harmonic band-limit. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. - Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - - Raises: - ValueError: J_min is negative or greater than J. - - Returns: - Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` with shape :math:`[(J+1)L]`, and scaling kernel :math:`\Phi_{\el m}` with shape :math:`[L]` in harmonic space. - - Note: - [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013. - """ - J = shapes.j_max(L, lam) - - if J_min >= J or J_min < 0: - raise ValueError( - "J_min must be non-negative and less than J= " - + str(J) - + " for given L and lam." - ) - - previoustemp = 0.0 - k = kernels.k_lam(L, lam) - psi = np.zeros((J + 1, L), np.float64) - phi = np.zeros(L, np.float64) - for l in range(L): - phi[l] = np.sqrt(k[J_min, l]) - - for j in range(J_min, J + 1): - for l in range(L): - diff = k[j + 1, l] - k[j, l] - if diff < 0: - psi[j, l] = previoustemp - else: - temp = np.sqrt(diff) - psi[j, l] = temp - previoustemp = temp - - return psi, phi - - -def filters_directional( - L: int, - N: int = 1, - J_min: int = 0, - lam: float = 2.0, - spin: int = 0, - spin0: int = 0, -) -> Tuple[np.ndarray, np.ndarray]: - r"""Generates the harmonic coefficients for the directional tiling wavelets. - - This implementation is based on equation 36 in the wavelet computation paper `[1] `_. - - Args: - L (int): Harmonic band-limit. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. - Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - - spin (int, optional): Spin (integer) to perform the transform. Defaults to 0. - - spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. - - Returns: - Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`) - psi (np.ndarray): Harmonic coefficients of directional wavelets with shape :math:`[L^2(J+1)]`. - - phi (np.ndarray): Harmonic coefficients of scaling function with shape :math:`[L]`. - - Notes: - [1] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). - """ - J = shapes.j_max(L, lam) - el_min = max(abs(spin), abs(spin0)) - - phi = np.zeros(L, dtype=np.float64) - psi = np.zeros((J + 1, L, 2 * L - 1), dtype=np.complex128) - - kappa, kappa0 = filters_axisym(L, J_min, lam) - s_elm = tiling.tiling_direction(L, N) - - for el in range(el_min, L): - if kappa0[el] != 0: - phi[el] = np.sqrt((2 * el + 1) / (4.0 * np.pi)) * kappa0[el] - if spin0 != 0: - phi[el] *= tiling.spin_normalization(el, spin0) * (-1) ** spin0 - - for j in range(J_min, J + 1): - for el in range(el_min, L): - if kappa[j, el] != 0: - for m in range(-el, el + 1): - if s_elm[el, L - 1 + m] != 0: - psi[j, el, L - 1 + m] = ( - np.sqrt((2 * el + 1) / (8.0 * np.pi * np.pi)) - * kappa[j, el] - * s_elm[el, L - 1 + m] - ) - if spin0 != 0: - psi[j, el, L - 1 + m] *= ( - tiling.spin_normalization(el, spin0) * (-1) ** spin0 - ) - - return psi, phi - - -def filters_axisym_vectorised( - L: int, J_min: int = 0, lam: float = 2.0 -) -> Tuple[np.ndarray, np.ndarray]: - r"""Vectorised version of :func:`~filters_axisym`. - - Args: - L (int): Harmonic band-limit. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - Raises: - ValueError: J_min is negative or greater than J. - - Returns: - Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` - with shape :math:`[(J+1)L], and scaling kernel :math:`\Phi_{\ell m}` with shape - :math:`[L]` in harmonic space. - """ - J = shapes.j_max(L, lam) - - if J_min >= J or J_min < 0: - raise ValueError( - "J_min must be non-negative and less than J= " - + str(J) - + " for given L and lam." - ) - - k = kernels.k_lam(L, lam) - diff = (np.roll(k, -1, axis=0) - k)[:-1] - diff[diff < 0] = 0 - return np.sqrt(diff), np.sqrt(k[J_min]) - - -def filters_directional_vectorised( - L: int, - N: int = 1, - J_min: int = 0, - lam: float = 2.0, - spin: int = 0, - spin0: int = 0, -) -> Tuple[np.ndarray, np.ndarray]: - r"""Vectorised version of :func:`~filters_directional`. - - Args: - L (int): Harmonic band-limit. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor between - consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic - wavelets. Defaults to 2. - - spin (int, optional): Spin (integer) to perform the transform. Defaults to 0. - - spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. - - Returns: - Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`) - psi (np.ndarray): Harmonic coefficients of directional wavelets with shape :math:`[L^2(J+1)]`. - - phi (np.ndarray): Harmonic coefficients of scaling function with shape :math:`[L]`. - """ - el_min = max(abs(spin), abs(spin0)) - - spin_norms = ( - (-1) ** spin0 * tiling.spin_normalization_vectorised(np.arange(L), spin0) - if spin0 != 0 - else 1 - ) - - kappa, kappa0 = filters_axisym_vectorised(L, J_min, lam) - s_elm = tiling.tiling_direction(L, N) - - kappa0 *= np.sqrt((2 * np.arange(L) + 1) / (4.0 * np.pi)) - kappa0 = kappa0 * spin_norms if spin0 != 0 else kappa0 - - kappa *= np.sqrt((2 * np.arange(L) + 1) / 8.0) / np.pi - kappa = np.einsum("ij,jk->ijk", kappa, s_elm) - kappa = np.einsum("ijk,j->ijk", kappa, spin_norms) if spin0 != 0 else kappa - - kappa0[:el_min] = 0 - kappa[:, :el_min, :] = 0 - return kappa, kappa0 - - -@partial(jit, static_argnums=(0, 1, 2)) -def filters_axisym_jax( - L: int, J_min: int = 0, lam: float = 2.0 -) -> Tuple[jnp.ndarray, jnp.ndarray]: - r"""JAX version of :func:`~filters_axisym_vectorised`. - - Args: - L (int): Harmonic band-limit. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - Raises: - ValueError: J_min is negative or greater than J. - - Returns: - Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` - with shape :math:`[(J+1)L], and scaling kernel :math:`\Phi_{\ell m}` with shape - :math:`[L]` in harmonic space. - """ - J = shapes.j_max(L, lam) - - if J_min >= J or J_min < 0: - raise ValueError( - "J_min must be non-negative and less than J= " - + str(J) - + " for given L and lam." - ) - - k = kernels.k_lam_jax(L, lam) - diff = (jnp.roll(k, -1, axis=0) - k)[:-1] - diff = jnp.where(diff < 0, jnp.zeros((J + 1, L)), diff) - return jnp.sqrt(diff), jnp.sqrt(k[J_min]) - - -@partial(jit, static_argnums=(0, 1, 2, 3, 4, 5)) -def filters_directional_jax( - L: int, - N: int = 1, - J_min: int = 0, - lam: float = 2.0, - spin: int = 0, - spin0: int = 0, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - r"""JAX version of :func:`~filters_directional`. - - Args: - L (int): Harmonic band-limit. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor between - consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic - wavelets. Defaults to 2. - - spin (int, optional): Spin (integer) to perform the transform. Defaults to 0. - - spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. - - Returns: - Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`) - psi (np.ndarray): Harmonic coefficients of directional wavelets with shape :math:`[L^2(J+1)]`. - - phi (np.ndarray): Harmonic coefficients of scaling function with shape :math:`[L]`. - """ - el_min = max(abs(spin), abs(spin0)) - - spin_norms = ( - (-1) ** spin0 * tiling.spin_normalization_jax(np.arange(L), spin0) - if spin0 != 0 - else 1 - ) - - kappa, kappa0 = filters_axisym_jax(L, J_min, lam) - s_elm = tiling.tiling_direction_jax(L, N) - - kappa0 *= jnp.sqrt((2 * jnp.arange(L) + 1) / (4.0 * jnp.pi)) - kappa0 = kappa0 * spin_norms if spin0 != 0 else kappa0 - - kappa *= jnp.sqrt((2 * jnp.arange(L) + 1) / 8.0) / np.pi - kappa = jnp.einsum("ij,jk->ijk", kappa, s_elm, optimize=True) - kappa = ( - jnp.einsum("ijk,j->ijk", kappa, spin_norms, optimize=True) - if spin0 != 0 - else kappa - ) - - kappa0 = kappa0.at[:el_min].set(0) - kappa = kappa.at[:, :el_min, :].set(0) - - return kappa, kappa0 diff --git a/s2wav/filter_factory/kernels.py b/s2wav/filter_factory/kernels.py deleted file mode 100644 index a38bc8d..0000000 --- a/s2wav/filter_factory/kernels.py +++ /dev/null @@ -1,236 +0,0 @@ -from jax import jit, config - -config.update("jax_enable_x64", True) - -import jax.numpy as jnp -import numpy as np -from s2wav.utils.shapes import j_max -from functools import partial - - -def tiling_integrand(t: float, lam: float = 2.0) -> float: - r"""Tiling integrand for scale-discretised wavelets `[1] `_. - - Intermediate step used to compute the wavelet and scaling function generating - functions. One of the basic mathematical functions needed to carry out the tiling of - the harmonic space. - - Args: - t (float): Real argument over which we integrate. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales.Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - Returns: - float: Value of tiling integrand for given :math:`t` and scaling factor. - - Note: - [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on - the sphere", A&A, vol. 558, p. A128, 2013. - """ - s_arg = (t - (1.0 / lam)) * (2.0 * lam / (lam - 1.0)) - 1.0 - - integrand = np.exp(-2.0 / (1.0 - s_arg**2.0)) / t - - return integrand - - -def part_scaling_fn(a: float, b: float, n: int, lam: float = 2.0) -> float: - r"""Computes integral used to calculate smoothly decreasing function :math:`k_{\lambda}`. - - Intermediate step used to compute the wavelet and scaling function generating - functions. Uses the trapezium method to integrate :func:`~tiling_integrand` in the - limits from :math:`a \rightarrow b` with scaling parameter :math:`\lambda`. One of - the basic mathematical functions needed to carry out the tiling of the harmonic - space. - - Args: - a (float): Lower limit of the numerical integration. - - b (float): Upper limit of the numerical integration. - - n (int): Number of steps to be performed during integration. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales.Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - Returns: - float: Integral of the tiling integrand from :math:`a \rightarrow b`. - """ - sum = 0.0 - h = (b - a) / n - - if a == b: - return 0 - - for i in range(n): - if a + i * h not in [1 / lam, 1.0] and a + (i + 1) * h not in [ - 1 / lam, - 1.0, - ]: - f1 = tiling_integrand(a + i * h, lam) - f2 = tiling_integrand(a + (i + 1) * h, lam) - - sum += ((f1 + f2) * h) / 2 - - return sum - - -def k_lam(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: - r"""Compute function :math:`k_{\lambda}` used as a wavelet generating function. - - Specifically, this function is derived in [1] and is given by - - .. math:: - - k_{\lambda} \equiv \frac{ \int_t^1 \frac{\text{d}t^{\prime}}{t^{\prime}} - s_{\lambda}^2(t^{\prime})}{ \int_{\frac{1}{\lambda}}^1 - \frac{\text{d}t^{\prime}}{t^{\prime}} s_{\lambda}^2(t^{\prime})}, - - where the integrand is defined to be - - .. math:: - - s_{\lambda} \equiv s \Big ( \frac{2\lambda}{\lambda - 1}(t-\frac{1}{\lambda}) - - 1 \Big ), - - for infinitely differentiable Cauchy-Schwartz function :math:`s(t) \in C^{\infty}`. - - Args: - L (int): Harmonic band-limit. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - quad_iters (int, optional): Total number of iterations for quadrature - integration. Defaults to 300. - - Returns: - (np.ndarray): Value of :math:`k_{\lambda}` computed for values between - :math:`\frac{1}{\lambda}` and 1, parametrised by :math:`\ell` as required to - compute the axisymmetric filters in :func:`~tiling_axisym`. - - Note: - [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the - sphere", A&A, vol. 558, p. A128, 2013. - """ - - J = j_max(L, lam) - - normalisation = part_scaling_fn(1.0 / lam, 1.0, quad_iters, lam) - k = np.zeros((J + 2, L)) - - for j in range(J + 2): - for l in range(L): - if l < lam ** (j - 1): - k[j, l] = 1 - elif l > lam**j: - k[j, l] = 0 - else: - k[j, l] = ( - part_scaling_fn(l / lam**j, 1.0, quad_iters, lam) / normalisation - ) - - return k - - -@partial(jit, static_argnums=(2, 3)) # not sure -def part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float: - r"""JAX version of part_scaling_fn. Computes integral used to calculate smoothly decreasing function :math:`k_{\lambda}`. - - Intermediate step used to compute the wavelet and scaling function generating - functions. Uses the trapezium method to integrate :func:`~tiling_integrand` in the - limits from :math:`a \rightarrow b` with scaling parameter :math:`\lambda`. One of - the basic mathematical functions needed to carry out the tiling of the harmonic - space. - - Args: - a (float): Lower limit of the numerical integration. - - b (float): Upper limit of the numerical integration. - - n (int): Number of steps to be performed during integration. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales.Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - Returns: - float: Integral of the tiling integrand from :math:`a \rightarrow b`. - """ - - h = (b - a) / n - - x = jnp.linspace(a, b, num=n + 1) - s_arg = (x - (1.0 / lam)) * (2.0 * lam / (lam - 1.0)) - 1.0 - value = jnp.where( - (x[:-1] == 1.0 / lam) | (x[:-1] == 1.0) | (x[1:] == 1.0 / lam) | (x[1:] == 1.0), - jnp.zeros(n), - (jnp.exp(-2.0 / (1.0 - jnp.square(s_arg))) / x)[:-1] - + (jnp.exp(-2.0 / (1.0 - jnp.square(s_arg))) / x)[1:], - ) - - return jnp.sum(value * h / 2) - - -@partial(jit, static_argnums=(0, 1, 2)) -def k_lam_jax(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: - r"""JAX version of k_lam. Compute function :math:`k_{\lambda}` used as a wavelet generating function. - - Specifically, this function is derived in [1] and is given by - - .. math:: - - k_{\lambda} \equiv \frac{ \int_t^1 \frac{\text{d}t^{\prime}}{t^{\prime}} - s_{\lambda}^2(t^{\prime})}{ \int_{\frac{1}{\lambda}}^1 - \frac{\text{d}t^{\prime}}{t^{\prime}} s_{\lambda}^2(t^{\prime})}, - - where the integrand is defined to be - - .. math:: - - s_{\lambda} \equiv s \Big ( \frac{2\lambda}{\lambda - 1}(t-\frac{1}{\lambda}) - - 1 \Big ), - - for infinitely differentiable Cauchy-Schwartz function :math:`s(t) \in C^{\infty}`. - - Args: - L (int): Harmonic band-limit. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - quad_iters (int, optional): Total number of iterations for quadrature - integration. Defaults to 300. - - Returns: - (np.ndarray): Value of :math:`k_{\lambda}` computed for values between - :math:`\frac{1}{\lambda}` and 1, parametrised by :math:`\ell` as required to - compute the axisymmetric filters in :func:`~tiling_axisym`. - - Note: - [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the - sphere", A&A, vol. 558, p. A128, 2013. - """ - - J = j_max(L, lam) - - normalisation = part_scaling_fn(1.0 / lam, 1.0, quad_iters, lam) - k = jnp.zeros((J + 2, L)) - - for j in range(J + 2): - for l in range(L): - if l < lam ** (j - 1): - k = k.at[j, l].set(1.0) - elif l > lam**j: - k = k.at[j, l].set(0.0) - else: - k = k.at[j, l].set( - part_scaling_fn(l / lam**j, 1.0, quad_iters, lam) / normalisation - ) - - return k diff --git a/s2wav/filter_factory/tiling.py b/s2wav/filter_factory/tiling.py deleted file mode 100644 index c6f780a..0000000 --- a/s2wav/filter_factory/tiling.py +++ /dev/null @@ -1,161 +0,0 @@ -from jax import jit, config - -config.update("jax_enable_x64", True) - -import jax.numpy as jnp -import numpy as np -from s2wav.utils.math_functions import ( - binomial_coefficient, - binomial_coefficient_jax, -) -from functools import partial - - -def tiling_direction(L: int, N: int = 1) -> np.ndarray: - r"""Generates the harmonic coefficients for the directionality component of the - tiling functions. - - Formally, this function implements the follow equation - - .. math:: - - _{s}\eta_{\el m} = \nu \vu \sqrt{\frac{1}{2^{\gamma}} \big ( \binom{\gamma}{ - (\gamma - m)/2} \big )} - - which was first derived in `[1] `_. - - Args: - L (int): Harmonic band-limit. - - N (int, optional): Upper orientational band-limit. Defaults to 1. - - Returns: - np.ndarray: Harmonic coefficients of directionality components - :math:`_{s}\eta_{\el m}`. - - Notes: - [1] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint - arXiv:1509.06749 (2015). - """ - if N % 2: - nu = 1 - else: - nu = 1j - - s_elm = np.zeros((L, 2 * L - 1), dtype=np.complex128) - - for el in range(1, L): - if (N + el) % 2: - gamma = min(N - 1, el) - else: - gamma = min(N - 1, el - 1) - - for m in range(-el, el + 1): - if abs(m) < N and (N + m) % 2: - s_elm[el, L - 1 + m] = nu * np.sqrt( - (binomial_coefficient(gamma, ((gamma - m) / 2))) / (2**gamma) - ) - else: - s_elm[el, L - 1 + m] = 0.0 - - return s_elm - - -def spin_normalization(el: int, spin: int = 0) -> float: - r"""Computes the normalization factor for spin-lowered wavelets, which is - :math:`\sqrt{\frac{(\ell+s)!}{(\ell-s)!}}`. - - Args: - el (int): Harmonic index :math:`\ell`. - - spin (int): Spin of field over which to perform the transform. Defaults to 0. - - Returns: - float: Normalization factor for spin-lowered wavelets. - """ - factor = 1.0 - - for s in range(-abs(spin) + 1, abs(spin) + 1): - factor *= el + s - - if spin > 0: - return np.sqrt(factor) - else: - return np.sqrt(1.0 / factor) - - -def spin_normalization_vectorised(el: np.ndarray, spin: int = 0) -> float: - r"""Vectorised version of :func:`~spin_normalization`. - Args: - el (int): Harmonic index :math:`\ell`. - spin (int): Spin of field over which to perform the transform. Defaults to 0. - Returns: - float: Normalization factor for spin-lowered wavelets. - """ - factor = np.arange(-abs(spin) + 1, abs(spin) + 1).reshape(1, 2 * abs(spin) + 1) - factor = el.reshape(len(el), 1).dot(factor) - return np.sqrt(np.prod(factor, axis=1) ** (np.sign(spin))) - - -@partial(jit, static_argnums=(0, 1)) -def tiling_direction_jax(L: int, N: int = 1) -> np.ndarray: - r"""JAX version of tiling_direction. Generates the harmonic coefficients for the directionality component of the - tiling functions. - - Formally, this function implements the follow equation - - .. math:: - - _{s}\eta_{\ell m} = \nu \vu \sqrt{\frac{1}{2^{\gamma}} \big ( \binom{\gamma}{ - (\gamma - m)/2} \big )} - - which was first derived in `[1] `_. - - Args: - L (int): Harmonic band-limit. - - N (int, optional): Upper orientational band-limit. Defaults to 1. - - Returns: - np.ndarray: Harmonic coefficients of directionality components - :math:`_{s}\eta_{\ell m}`. - - Notes: - [1] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint - arXiv:1509.06749 (2015). - """ - - nu = (N % 2 - 1) ** 2 * 1j + (N % 2) - - s_elm = jnp.zeros((L, 2 * L - 1), dtype=np.complex128) - - for el in range(1, L): - gamma = min(N - 1, el - 1 + (N + el) % 2) - - ms = jnp.arange(-el, el + 1) - val = nu * jnp.sqrt( - (binomial_coefficient_jax(gamma, ((gamma - ms) / 2))) / (2**gamma) - ) - - val = jnp.where( - (ms < N) & (ms > -N) & ((N + ms) % 2 == 1), - val, - jnp.zeros(2 * el + 1), - ) - s_elm = s_elm.at[el, L - 1 - el : L + el].set(val) - - return s_elm - - -@partial(jit, static_argnums=(1)) -def spin_normalization_jax(el: np.ndarray, spin: int = 0) -> float: - r"""JAX version of :func:`~spin_normalization`. - Args: - el (int): Harmonic index :math:`\ell`. - spin (int): Spin of field over which to perform the transform. Defaults to 0. - Returns: - float: Normalization factor for spin-lowered wavelets. - """ - factor = jnp.arange(-abs(spin) + 1, abs(spin) + 1).reshape(1, 2 * abs(spin) + 1) - factor = el.reshape(len(el), 1).dot(factor) - return jnp.sqrt(jnp.prod(factor, axis=1) ** (jnp.sign(spin))) diff --git a/s2wav/filters.py b/s2wav/filters.py new file mode 100644 index 0000000..1462cc9 --- /dev/null +++ b/s2wav/filters.py @@ -0,0 +1,711 @@ +from jax import jit +import jax.numpy as jnp +import numpy as np +from typing import Tuple +from functools import partial +from s2wav import samples + + +def filters_axisym( + L: int, J_min: int = 0, lam: float = 2.0 +) -> Tuple[np.ndarray, np.ndarray]: + r"""Computes wavelet kernels :math:`\Psi^j_{\ell m}` and scaling kernel + :math:`\Phi_{\ell m}` in harmonic space. + + Specifically, these kernels are derived in `[1] `_, + where the wavelet kernels are defined (15) for scale :math:`j` to be + + .. math:: + + \Psi^j_{\ell m} \equiv \sqrt{\frac{2\ell+1}{4\pi}} \kappa_{\lambda}(\frac{\ell}{\lambda^j})\delta_{m0}, + + where :math:`\kappa_{\lambda} = \sqrt{k_{\lambda}(t/\lambda) - k_{\lambda}(t)}` for :math:`k_{\lambda}` + given in :func:`~k_lam`. Similarly, the scaling kernel is defined (16) as + + .. math:: + + \Phi_{\ell m} \equiv \sqrt{\frac{2\ell+1}{4\pi}} \nu_{\lambda} (\frac{\ell}{\lambda^{J_0}})\delta_{m0}, + + where :math:`\nu_{\lambda} = \sqrt{k_{\lambda}(t)}` for :math:`k_{\lambda}` given in :func:`~k_lam`. + Notice that :math:`\delta_{m0}` enforces that these kernels are axisymmetric, i.e. coefficients + for :math:`m \not = \ell` are zero. In this implementation the normalisation constant has been + omitted as it is nulled in subsequent functions. + + Args: + L (int): Harmonic band-limit. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + wavelets. Defaults to 2. + + Raises: + ValueError: J_min is negative or greater than J. + + Returns: + Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` + with shape :math:`[(J+1)L]`, and scaling kernel :math:`\Phi_{\el m}` with shape + :math:`[L]` in harmonic space. + + Note: + [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013. + """ + J = samples.j_max(L, lam) + + if J_min >= J or J_min < 0: + raise ValueError( + "J_min must be non-negative and less than J= " + + str(J) + + " for given L and lam." + ) + + previoustemp = 0.0 + k = k_lam(L, lam) + psi = np.zeros((J + 1, L), np.float64) + phi = np.zeros(L, np.float64) + for l in range(L): + phi[l] = np.sqrt(k[J_min, l]) + + for j in range(J_min, J + 1): + for l in range(L): + diff = k[j + 1, l] - k[j, l] + if diff < 0: + psi[j, l] = previoustemp + else: + temp = np.sqrt(diff) + psi[j, l] = temp + previoustemp = temp + + return psi, phi + + +def filters_directional( + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + spin0: int = 0, +) -> Tuple[np.ndarray, np.ndarray]: + r"""Generates the harmonic coefficients for the directional tiling wavelets. + + This implementation is based on equation 36 in the wavelet computation paper + `[1] `_. + + Args: + L (int): Harmonic band-limit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) to perform the transform. Defaults to 0. + + spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels + (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`) + + Notes: + [1] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). + """ + J = samples.j_max(L, lam) + el_min = max(abs(spin), abs(spin0)) + + phi = np.zeros(L, dtype=np.float64) + psi = np.zeros((J + 1, L, 2 * L - 1), dtype=np.complex128) + + kappa, kappa0 = filters_axisym(L, J_min, lam) + s_elm = tiling_direction(L, N) + + for el in range(el_min, L): + if kappa0[el] != 0: + phi[el] = np.sqrt((2 * el + 1) / (4.0 * np.pi)) * kappa0[el] + if spin0 != 0: + phi[el] *= spin_normalization(el, spin0) * (-1) ** spin0 + + for j in range(J_min, J + 1): + for el in range(el_min, L): + if kappa[j, el] != 0: + for m in range(-el, el + 1): + if s_elm[el, L - 1 + m] != 0: + psi[j, el, L - 1 + m] = ( + np.sqrt((2 * el + 1) / (8.0 * np.pi * np.pi)) + * kappa[j, el] + * s_elm[el, L - 1 + m] + ) + if spin0 != 0: + psi[j, el, L - 1 + m] *= ( + spin_normalization(el, spin0) * (-1) ** spin0 + ) + + return psi, phi + + +def filters_axisym_vectorised( + L: int, J_min: int = 0, lam: float = 2.0 +) -> Tuple[np.ndarray, np.ndarray]: + r"""Vectorised version of :func:`~filters_axisym`. + + Args: + L (int): Harmonic band-limit. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + Raises: + ValueError: J_min is negative or greater than J. + + Returns: + Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` + with shape :math:`[(J+1)L], and scaling kernel :math:`\Phi_{\ell m}` with shape + :math:`[L]` in harmonic space. + """ + J = samples.j_max(L, lam) + + if J_min >= J or J_min < 0: + raise ValueError( + "J_min must be non-negative and less than J= " + + str(J) + + " for given L and lam." + ) + + k = k_lam(L, lam) + diff = (np.roll(k, -1, axis=0) - k)[:-1] + diff[diff < 0] = 0 + return np.sqrt(diff), np.sqrt(k[J_min]) + + +def filters_directional_vectorised( + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + spin0: int = 0, +) -> Tuple[np.ndarray, np.ndarray]: + r"""Vectorised version of :func:`~filters_directional`. + + Args: + L (int): Harmonic band-limit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) to perform the transform. Defaults to 0. + + spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels + (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`). + """ + el_min = max(abs(spin), abs(spin0)) + + spin_norms = ( + (-1) ** spin0 * spin_normalization_vectorised(np.arange(L), spin0) + if spin0 != 0 + else 1 + ) + + kappa, kappa0 = filters_axisym_vectorised(L, J_min, lam) + s_elm = tiling_direction(L, N) + + kappa0 *= np.sqrt((2 * np.arange(L) + 1) / (4.0 * np.pi)) + kappa0 = kappa0 * spin_norms if spin0 != 0 else kappa0 + + kappa *= np.sqrt((2 * np.arange(L) + 1) / 8.0) / np.pi + kappa = np.einsum("ij,jk->ijk", kappa, s_elm) + kappa = np.einsum("ijk,j->ijk", kappa, spin_norms) if spin0 != 0 else kappa + + kappa0[:el_min] = 0 + kappa[:, :el_min, :] = 0 + return kappa, kappa0 + + +@partial(jit, static_argnums=(0, 1, 2)) +def filters_axisym_jax( + L: int, J_min: int = 0, lam: float = 2.0 +) -> Tuple[jnp.ndarray, jnp.ndarray]: + r"""JAX version of :func:`~filters_axisym_vectorised`. + + Args: + L (int): Harmonic band-limit. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + Raises: + ValueError: J_min is negative or greater than J. + + Returns: + Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` + with shape :math:`[(J+1)L], and scaling kernel :math:`\Phi_{\ell m}` with shape + :math:`[L]` in harmonic space. + """ + J = samples.j_max(L, lam) + + if J_min >= J or J_min < 0: + raise ValueError( + "J_min must be non-negative and less than J= " + + str(J) + + " for given L and lam." + ) + + k = k_lam_jax(L, lam) + diff = (jnp.roll(k, -1, axis=0) - k)[:-1] + diff = jnp.where(diff < 0, jnp.zeros((J + 1, L)), diff) + return jnp.sqrt(diff), jnp.sqrt(k[J_min]) + + +@partial(jit, static_argnums=(0, 1, 2, 3, 4, 5)) +def filters_directional_jax( + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + spin0: int = 0, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + r"""JAX version of :func:`~filters_directional`. + + Args: + L (int): Harmonic band-limit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) to perform the transform. Defaults to 0. + + spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels + (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`). + """ + el_min = max(abs(spin), abs(spin0)) + + spin_norms = ( + (-1) ** spin0 * spin_normalization_jax(np.arange(L), spin0) + if spin0 != 0 + else 1 + ) + + kappa, kappa0 = filters_axisym_jax(L, J_min, lam) + s_elm = tiling_direction_jax(L, N) + + kappa0 *= jnp.sqrt((2 * jnp.arange(L) + 1) / (4.0 * jnp.pi)) + kappa0 = kappa0 * spin_norms if spin0 != 0 else kappa0 + + kappa *= jnp.sqrt((2 * jnp.arange(L) + 1) / 8.0) / np.pi + kappa = jnp.einsum("ij,jk->ijk", kappa, s_elm, optimize=True) + kappa = ( + jnp.einsum("ijk,j->ijk", kappa, spin_norms, optimize=True) + if spin0 != 0 + else kappa + ) + + kappa0 = kappa0.at[:el_min].set(0) + kappa = kappa.at[:, :el_min, :].set(0) + + return kappa, kappa0 + +def tiling_integrand(t: float, lam: float = 2.0) -> float: + r"""Tiling integrand for scale-discretised wavelets `[1] `_. + + Intermediate step used to compute the wavelet and scaling function generating + functions. One of the basic mathematical functions needed to carry out the tiling of + the harmonic space. + + Args: + t (float): Real argument over which we integrate. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales.Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + Returns: + float: Value of tiling integrand for given :math:`t` and scaling factor. + + Note: + [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on + the sphere", A&A, vol. 558, p. A128, 2013. + """ + s_arg = (t - (1.0 / lam)) * (2.0 * lam / (lam - 1.0)) - 1.0 + + integrand = np.exp(-2.0 / (1.0 - s_arg**2.0)) / t + + return integrand + + +def part_scaling_fn(a: float, b: float, n: int, lam: float = 2.0) -> float: + r"""Computes integral used to calculate smoothly decreasing function :math:`k_{\lambda}`. + + Intermediate step used to compute the wavelet and scaling function generating + functions. Uses the trapezium method to integrate :func:`~tiling_integrand` in the + limits from :math:`a \rightarrow b` with scaling parameter :math:`\lambda`. One of + the basic mathematical functions needed to carry out the tiling of the harmonic + space. + + Args: + a (float): Lower limit of the numerical integration. + + b (float): Upper limit of the numerical integration. + + n (int): Number of steps to be performed during integration. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales.Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + Returns: + float: Integral of the tiling integrand from :math:`a \rightarrow b`. + """ + sum = 0.0 + h = (b - a) / n + + if a == b: + return 0 + + for i in range(n): + if a + i * h not in [1 / lam, 1.0] and a + (i + 1) * h not in [ + 1 / lam, + 1.0, + ]: + f1 = tiling_integrand(a + i * h, lam) + f2 = tiling_integrand(a + (i + 1) * h, lam) + + sum += ((f1 + f2) * h) / 2 + + return sum + + +def k_lam(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: + r"""Compute function :math:`k_{\lambda}` used as a wavelet generating function. + + Specifically, this function is derived in [1] and is given by + + .. math:: + + k_{\lambda} \equiv \frac{ \int_t^1 \frac{\text{d}t^{\prime}}{t^{\prime}} + s_{\lambda}^2(t^{\prime})}{ \int_{\frac{1}{\lambda}}^1 + \frac{\text{d}t^{\prime}}{t^{\prime}} s_{\lambda}^2(t^{\prime})}, + + where the integrand is defined to be + + .. math:: + + s_{\lambda} \equiv s \Big ( \frac{2\lambda}{\lambda - 1}(t-\frac{1}{\lambda}) + - 1 \Big ), + + for infinitely differentiable Cauchy-Schwartz function :math:`s(t) \in C^{\infty}`. + + Args: + L (int): Harmonic band-limit. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + quad_iters (int, optional): Total number of iterations for quadrature + integration. Defaults to 300. + + Returns: + (np.ndarray): Value of :math:`k_{\lambda}` computed for values between + :math:`\frac{1}{\lambda}` and 1, parametrised by :math:`\ell` as required to + compute the axisymmetric filters in :func:`~tiling_axisym`. + + Note: + [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the + sphere", A&A, vol. 558, p. A128, 2013. + """ + + J = samples.j_max(L, lam) + + normalisation = part_scaling_fn(1.0 / lam, 1.0, quad_iters, lam) + k = np.zeros((J + 2, L)) + + for j in range(J + 2): + for l in range(L): + if l < lam ** (j - 1): + k[j, l] = 1 + elif l > lam**j: + k[j, l] = 0 + else: + k[j, l] = ( + part_scaling_fn(l / lam**j, 1.0, quad_iters, lam) / normalisation + ) + + return k + + +@partial(jit, static_argnums=(2, 3)) # not sure +def part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float: + r"""JAX version of part_scaling_fn. Computes integral used to calculate smoothly + decreasing function :math:`k_{\lambda}`. + + Intermediate step used to compute the wavelet and scaling function generating + functions. Uses the trapezium method to integrate :func:`~tiling_integrand` in the + limits from :math:`a \rightarrow b` with scaling parameter :math:`\lambda`. One of + the basic mathematical functions needed to carry out the tiling of the harmonic + space. + + Args: + a (float): Lower limit of the numerical integration. + + b (float): Upper limit of the numerical integration. + + n (int): Number of steps to be performed during integration. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales.Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + Returns: + float: Integral of the tiling integrand from :math:`a \rightarrow b`. + """ + + h = (b - a) / n + + x = jnp.linspace(a, b, num=n + 1) + s_arg = (x - (1.0 / lam)) * (2.0 * lam / (lam - 1.0)) - 1.0 + value = jnp.where( + (x[:-1] == 1.0 / lam) | (x[:-1] == 1.0) | (x[1:] == 1.0 / lam) | (x[1:] == 1.0), + jnp.zeros(n), + (jnp.exp(-2.0 / (1.0 - jnp.square(s_arg))) / x)[:-1] + + (jnp.exp(-2.0 / (1.0 - jnp.square(s_arg))) / x)[1:], + ) + + return jnp.sum(value * h / 2) + + +@partial(jit, static_argnums=(0, 1, 2)) +def k_lam_jax(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: + r"""JAX version of k_lam. Compute function :math:`k_{\lambda}` used as a wavelet + generating function. + + Specifically, this function is derived in [1] and is given by + + .. math:: + + k_{\lambda} \equiv \frac{ \int_t^1 \frac{\text{d}t^{\prime}}{t^{\prime}} + s_{\lambda}^2(t^{\prime})}{ \int_{\frac{1}{\lambda}}^1 + \frac{\text{d}t^{\prime}}{t^{\prime}} s_{\lambda}^2(t^{\prime})}, + + where the integrand is defined to be + + .. math:: + + s_{\lambda} \equiv s \Big ( \frac{2\lambda}{\lambda - 1}(t-\frac{1}{\lambda}) + - 1 \Big ), + + for infinitely differentiable Cauchy-Schwartz function :math:`s(t) \in C^{\infty}`. + + Args: + L (int): Harmonic band-limit. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + quad_iters (int, optional): Total number of iterations for quadrature + integration. Defaults to 300. + + Returns: + (np.ndarray): Value of :math:`k_{\lambda}` computed for values between + :math:`\frac{1}{\lambda}` and 1, parametrised by :math:`\ell` as required to + compute the axisymmetric filters in :func:`~tiling_axisym`. + + Note: + [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the + sphere", A&A, vol. 558, p. A128, 2013. + """ + + J = samples.j_max(L, lam) + + normalisation = part_scaling_fn(1.0 / lam, 1.0, quad_iters, lam) + k = jnp.zeros((J + 2, L)) + + for j in range(J + 2): + for l in range(L): + if l < lam ** (j - 1): + k = k.at[j, l].set(1.0) + elif l > lam**j: + k = k.at[j, l].set(0.0) + else: + k = k.at[j, l].set( + part_scaling_fn(l / lam**j, 1.0, quad_iters, lam) / normalisation + ) + + return k + +def tiling_direction(L: int, N: int = 1) -> np.ndarray: + r"""Generates the harmonic coefficients for the directionality component of the + tiling functions. + + Formally, this function implements the follow equation + + .. math:: + + _{s}\eta_{\el m} = \nu \vu \sqrt{\frac{1}{2^{\gamma}} \big ( \binom{\gamma}{ + (\gamma - m)/2} \big )} + + which was first derived in `[1] `_. + + Args: + L (int): Harmonic band-limit. + + N (int, optional): Upper orientational band-limit. Defaults to 1. + + Returns: + np.ndarray: Harmonic coefficients of directionality components + :math:`_{s}\eta_{\el m}`. + + Notes: + [1] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint + arXiv:1509.06749 (2015). + """ + if N % 2: + nu = 1 + else: + nu = 1j + + s_elm = np.zeros((L, 2 * L - 1), dtype=np.complex128) + + for el in range(1, L): + if (N + el) % 2: + gamma = min(N - 1, el) + else: + gamma = min(N - 1, el - 1) + + for m in range(-el, el + 1): + if abs(m) < N and (N + m) % 2: + s_elm[el, L - 1 + m] = nu * np.sqrt( + (samples.binomial_coefficient(gamma, ((gamma - m) / 2))) / (2**gamma) + ) + else: + s_elm[el, L - 1 + m] = 0.0 + + return s_elm + + +def spin_normalization(el: int, spin: int = 0) -> float: + r"""Computes the normalization factor for spin-lowered wavelets, which is + :math:`\sqrt{\frac{(\ell+s)!}{(\ell-s)!}}`. + + Args: + el (int): Harmonic index :math:`\ell`. + + spin (int): Spin of field over which to perform the transform. Defaults to 0. + + Returns: + float: Normalization factor for spin-lowered wavelets. + """ + factor = 1.0 + + for s in range(-abs(spin) + 1, abs(spin) + 1): + factor *= el + s + + if spin > 0: + return np.sqrt(factor) + else: + return np.sqrt(1.0 / factor) + + +def spin_normalization_vectorised(el: np.ndarray, spin: int = 0) -> float: + r"""Vectorised version of :func:`~spin_normalization`. + Args: + el (int): Harmonic index :math:`\ell`. + spin (int): Spin of field over which to perform the transform. Defaults to 0. + Returns: + float: Normalization factor for spin-lowered wavelets. + """ + factor = np.arange(-abs(spin) + 1, abs(spin) + 1).reshape(1, 2 * abs(spin) + 1) + factor = el.reshape(len(el), 1).dot(factor) + return np.sqrt(np.prod(factor, axis=1) ** (np.sign(spin))) + + +@partial(jit, static_argnums=(0, 1)) +def tiling_direction_jax(L: int, N: int = 1) -> np.ndarray: + r"""JAX version of tiling_direction. Generates the harmonic coefficients for the + directionality component of the tiling functions. + + Formally, this function implements the follow equation + + .. math:: + + _{s}\eta_{\ell m} = \nu \vu \sqrt{\frac{1}{2^{\gamma}} \big ( \binom{\gamma}{ + (\gamma - m)/2} \big )} + + which was first derived in `[1] `_. + + Args: + L (int): Harmonic band-limit. + + N (int, optional): Upper orientational band-limit. Defaults to 1. + + Returns: + np.ndarray: Harmonic coefficients of directionality components + :math:`_{s}\eta_{\ell m}`. + + Notes: + [1] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint + arXiv:1509.06749 (2015). + """ + + nu = (N % 2 - 1) ** 2 * 1j + (N % 2) + + s_elm = jnp.zeros((L, 2 * L - 1), dtype=np.complex128) + + for el in range(1, L): + gamma = min(N - 1, el - 1 + (N + el) % 2) + + ms = jnp.arange(-el, el + 1) + val = nu * jnp.sqrt( + (samples.binomial_coefficient_jax(gamma, ((gamma - ms) / 2))) / (2**gamma) + ) + + val = jnp.where( + (ms < N) & (ms > -N) & ((N + ms) % 2 == 1), + val, + jnp.zeros(2 * el + 1), + ) + s_elm = s_elm.at[el, L - 1 - el : L + el].set(val) + + return s_elm + + +@partial(jit, static_argnums=(1)) +def spin_normalization_jax(el: np.ndarray, spin: int = 0) -> float: + r"""JAX version of :func:`~spin_normalization`. + Args: + el (int): Harmonic index :math:`\ell`. + spin (int): Spin of field over which to perform the transform. Defaults to 0. + Returns: + float: Normalization factor for spin-lowered wavelets. + """ + factor = jnp.arange(-abs(spin) + 1, abs(spin) + 1).reshape(1, 2 * abs(spin) + 1) + factor = el.reshape(len(el), 1).dot(factor) + return jnp.sqrt(jnp.prod(factor, axis=1) ** (jnp.sign(spin))) \ No newline at end of file diff --git a/s2wav/utils/shapes.py b/s2wav/samples.py similarity index 95% rename from s2wav/utils/shapes.py rename to s2wav/samples.py index 4d76461..0c55264 100644 --- a/s2wav/utils/shapes.py +++ b/s2wav/samples.py @@ -1,14 +1,12 @@ -from jax import jit, config - -config.update("jax_enable_x64", True) - +from jax import jit import jax.numpy as jnp import numpy as np import math from functools import partial from typing import Tuple from s2fft.sampling import s2_samples, so3_samples - +from scipy.special import loggamma +from jax.scipy.special import gammaln as jax_gammaln def f_scal( L: int, @@ -589,3 +587,34 @@ def wavelet_shape_check( assert f_w[j - J_min].shape == f_wav_j( L, j, N, lam, sampling, nside, multiresolution ) + +def binomial_coefficient(n: int, k: int) -> int: + r"""Computes the binomial coefficient :math:`\binom{n}{k}`. + + Args: + n (int): Number of elements to choose from. + + k (int): Number of elements to pick. + + Returns: + (int): Number of possible subsets. + """ + return np.floor( + 0.5 + np.exp(loggamma(n + 1) - loggamma(k + 1) - loggamma(n - k + 1)) + ) + + +def binomial_coefficient_jax(n: int, k: int) -> int: + r"""Computes the binomial coefficient :math:`\binom{n}{k}`. + + Args: + n (int): Number of elements to choose from. + + k (int): Number of elements to pick. + + Returns: + (int): Number of possible subsets. + """ + return jnp.floor( + 0.5 + jnp.exp(jax_gammaln(n + 1) - jax_gammaln(k + 1) - jax_gammaln(n - k + 1)) + ) \ No newline at end of file diff --git a/s2wav/transforms/__init__.py b/s2wav/transforms/__init__.py index 6810dbc..f109544 100644 --- a/s2wav/transforms/__init__.py +++ b/s2wav/transforms/__init__.py @@ -1,3 +1,3 @@ -from . import numpy_wavelets -from . import jax_wavelets -from . import jax_wavelets_precompute +from . import base +from . import construct +from . import rec_wav_jax, pre_wav_jax diff --git a/s2wav/transforms/numpy_wavelets.py b/s2wav/transforms/base.py similarity index 92% rename from s2wav/transforms/numpy_wavelets.py rename to s2wav/transforms/base.py index 76e8818..d8e4849 100644 --- a/s2wav/transforms/numpy_wavelets.py +++ b/s2wav/transforms/base.py @@ -1,8 +1,7 @@ import numpy as np from typing import Tuple -from s2wav.utils import shapes -from s2wav.filter_factory import filters from s2fft import base_transforms as base +from s2wav import samples, filters def synthesis_looped( @@ -49,11 +48,11 @@ def synthesis_looped( [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). """ - shapes.wavelet_shape_check( + samples.wavelet_shape_check( f_wav, f_scal, L, N, J_min, lam, sampling, nside, multiresolution ) - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, multiresolution) flm = np.zeros((L, 2 * L - 1), dtype=np.complex128) f_scal_lm = base.spherical.forward(f_scal, Ls, spin, sampling, nside, reality) @@ -63,7 +62,7 @@ def synthesis_looped( # Sum the all wavelet wigner coefficients for each lmn # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, multiresolution) temp = base.wigner.forward( f_wav[j - J_min], Lj, Nj, L0j, sampling, reality, nside ) @@ -132,12 +131,12 @@ def synthesis( [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). """ - shapes.wavelet_shape_check( + samples.wavelet_shape_check( f_wav, f_scal, L, N, J_min, lam, sampling, nside, multiresolution ) - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, multiresolution) flm = np.zeros((L, 2 * L - 1), dtype=np.complex128) f_scal_lm = base.spherical.forward(f_scal, Ls, spin, sampling, nside, reality) @@ -149,7 +148,7 @@ def synthesis( # Sum the all wavelet wigner coefficients for each lmn # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, multiresolution) temp = base.wigner.forward( f_wav[j - J_min], Lj, Nj, L0j, sampling, reality, nside ) @@ -215,17 +214,17 @@ def analysis_looped( f_scal (np.ndarray): Array of scaling pixel-space coefficients with shape :math:`[n_{\theta}, n_{\phi}]`. """ - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, multiresolution) - f_scal_lm = shapes.construct_flm(L, J_min, lam, multiresolution) - f_wav_lmn = shapes.construct_flmn(L, N, J_min, lam, multiresolution) + f_scal_lm = samples.construct_flm(L, J_min, lam, multiresolution) + f_wav_lmn = samples.construct_flmn(L, N, J_min, lam, multiresolution) wav_lm, scal_l = filters.filters_directional(L, N, J_min, lam, spin, spin0) flm = base.spherical.forward(f, L, spin, sampling, nside, reality) for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, multiresolution) for n in range(-Nj + 1, Nj, 2): for el in range(max(abs(spin), abs(n), L0j), Lj): psi = np.conj(wav_lm[j, el, L - 1 + n]) @@ -259,9 +258,9 @@ def analysis_looped( else: f_scal_lm[el, Ls - 1 - m] = flm[el, L - 1 - m] * phi - f_wav = shapes.construct_f(L, N, J_min, lam, sampling, nside, multiresolution) + f_wav = samples.construct_f(L, N, J_min, lam, sampling, nside, multiresolution) for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, multiresolution) f_wav[j - J_min] = base.wigner.inverse( f_wav_lmn[j - J_min], Lj, Nj, L0j, sampling, reality, nside ) @@ -323,12 +322,12 @@ def analysis( f_scal (np.ndarray): Array of scaling pixel-space coefficients with shape :math:`[n_{\theta}, n_{\phi}]`. """ - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, multiresolution) - f_scal_lm = shapes.construct_flm(L, J_min, lam, multiresolution) - f_wav_lmn = shapes.construct_flmn(L, N, J_min, lam, multiresolution) - f_wav = shapes.construct_f(L, N, J_min, lam, sampling, multiresolution) + f_scal_lm = samples.construct_flm(L, J_min, lam, multiresolution) + f_wav_lmn = samples.construct_flmn(L, N, J_min, lam, multiresolution) + f_wav = samples.construct_f(L, N, J_min, lam, sampling, multiresolution) # Generate the directional wavelet kernels wav_lm, scal_l = filters.filters_directional_vectorised( @@ -342,7 +341,7 @@ def analysis( # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, multiresolution) f_wav_lmn[j - J_min][::2, L0j:] = np.einsum( "lm,ln->nlm", flm[L0j:Lj, L - Lj : L - 1 + Lj], diff --git a/s2wav/transforms/construct.py b/s2wav/transforms/construct.py new file mode 100644 index 0000000..5a5ad01 --- /dev/null +++ b/s2wav/transforms/construct.py @@ -0,0 +1,113 @@ +import jax.numpy as jnp +from typing import List +import s2fft +from s2fft.precompute_transforms.construct import ( + wigner_kernel_jax, + spin_spherical_kernel_jax, +) +from s2wav import samples + +def generate_full_precomputes( + L: int, + N: int, + J_min: int = 0, + lam: float = 2.0, + sampling: str = "mw", + nside: int = None, + forward: bool = False, + reality: bool = False, + nospherical: bool = False, +) -> List[jnp.ndarray]: + r"""Generates a list of precompute arrays associated with the underlying Wigner + transforms. + + Args: + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", + "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults + to None. + + forward (bool, optional): _description_. Defaults to False. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + nospherical (bool, optional): Whether to only compute Wigner precomputes. + Defaults to False. + + Returns: + List[jnp.ndarray]: Precomputed recursion arrays for underlying Wigner transforms. + """ + precomps = [] + J = samples.j_max(L, lam) + for j in range(J_min, J): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + precomps.append(wigner_kernel_jax(Lj, Nj, reality, sampling, nside, forward)) + Ls = samples.scal_bandlimit(L, J_min, lam, True) + if nospherical: + return [], [], precomps + precompute_scaling = spin_spherical_kernel_jax( + Ls, 0, reality, sampling, nside, forward + ) + precompute_full = spin_spherical_kernel_jax( + L, 0, reality, sampling, nside, not forward + ) + return precompute_full, precompute_scaling, precomps + +def generate_wigner_precomputes( + L: int, + N: int, + J_min: int = 0, + lam: float = 2.0, + sampling: str = "mw", + nside: int = None, + forward: bool = False, + reality: bool = False +) -> List[jnp.ndarray]: + r"""Generates a list of precompute arrays associated with the underlying Wigner + transforms. + + Args: + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", + "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults + to None. + + forward (bool, optional): _description_. Defaults to False. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + Returns: + List[jnp.ndarray]: Precomputed recursion arrays for underlying Wigner transforms. + """ + precomps = [] + J = samples.j_max(L, lam) + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + precomps.append( + s2fft.generate_precomputes_wigner_jax( + Lj, Nj, sampling, nside, forward, reality, L0j + ) + ) + return precomps \ No newline at end of file diff --git a/s2wav/transforms/jax_wavelets_precompute.py b/s2wav/transforms/pre_wav_jax.py similarity index 64% rename from s2wav/transforms/jax_wavelets_precompute.py rename to s2wav/transforms/pre_wav_jax.py index 693c303..7750b51 100644 --- a/s2wav/transforms/jax_wavelets_precompute.py +++ b/s2wav/transforms/pre_wav_jax.py @@ -1,78 +1,11 @@ from jax import jit import jax.numpy as jnp -from s2wav.utils import shapes from functools import partial from typing import Tuple, List -from s2fft.precompute_transforms.construct import ( - wigner_kernel_jax, - spin_spherical_kernel_jax, -) from s2fft.precompute_transforms import wigner, spherical +from s2wav import samples - -def generate_precomputes( - L: int, - N: int, - J_min: int = 0, - lam: float = 2.0, - sampling: str = "mw", - nside: int = None, - forward: bool = False, - reality: bool = False, - multiresolution: bool = False, - nospherical: bool = False, -) -> List[jnp.ndarray]: - r"""Generates a list of precompute arrays associated with the underlying Wigner - transforms. - - Args: - L (int): Harmonic bandlimit. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. - - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. - Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", - "healpix"}. Defaults to "mw". - - nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults - to None. - - forward (bool, optional): _description_. Defaults to False. - - reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits - conjugate symmetry of harmonic coefficients. Defaults to False. - - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - - nospherical (bool, optional): Whether to only compute Wigner precomputes. - Defaults to False. - - Returns: - List[jnp.ndarray]: Precomputed recursion arrays for underlying Wigner transforms. - """ - precomps = [] - J = shapes.j_max(L, lam) - for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) - precomps.append(wigner_kernel_jax(Lj, Nj, reality, sampling, nside, forward)) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) - if nospherical: - return [], [], precomps - precompute_scaling = spin_spherical_kernel_jax( - Ls, 0, reality, sampling, nside, forward - ) - precompute_full = spin_spherical_kernel_jax( - L, 0, reality, sampling, nside, not forward - ) - return precompute_full, precompute_scaling, precomps - - -@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13)) +@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9)) def synthesis( f_wav: jnp.ndarray, f_scal: jnp.ndarray, @@ -81,17 +14,16 @@ def synthesis( J_min: int = 0, lam: float = 2.0, spin: int = 0, - spin0: int = 0, sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, filters: Tuple[jnp.ndarray] = None, - spmd: bool = False, precomps: List[List[jnp.ndarray]] = None, ) -> jnp.ndarray: r"""Computes the synthesis directional wavelet transform [1,2]. - Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` by summing the contributions from wavelet and scaling coefficients in harmonic space, see equation 27 from `[2] `_. + Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` + by summing the contributions from wavelet and scaling coefficients in harmonic space, + see equation 27 from `[2] `_. Args: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. @@ -105,13 +37,12 @@ def synthesis( J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. - Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. spin (int, optional): Spin (integer) of input signal. Defaults to 0. - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". @@ -121,15 +52,8 @@ def synthesis( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. @@ -146,20 +70,20 @@ def synthesis( if precomps == None: raise ValueError("Must provide precomputed kernels for this transform!") - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) f_scal_lm = spherical.forward_transform_jax( f_scal, precomps[1], Ls, sampling, reality, spin, nside ) # Sum the all wavelet wigner coefficients for each lmn - # Note that almost the entire compute is concentrated at the highest J + # Note that almost the entire compute is concentrated at the highest two scales. for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) - spmd_iter = spmd if N == Nj else False + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + shift = 0 if j < J else -1 temp = wigner.forward_transform_jax( - f_wav[j - J_min], precomps[2][j - J_min], Lj, Nj, sampling, reality, nside + f_wav[j - J_min], precomps[2][j-J_min+shift], Lj, Nj, sampling, reality, nside ) flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add( jnp.einsum( @@ -180,7 +104,7 @@ def synthesis( ) -@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12)) +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def analysis( f: jnp.ndarray, L: int, @@ -188,13 +112,10 @@ def analysis( J_min: int = 0, lam: float = 2.0, spin: int = 0, - spin0: int = 0, sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, filters: Tuple[jnp.ndarray] = None, - spmd: bool = False, precomps: List[List[jnp.ndarray]] = None, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -213,8 +134,6 @@ def analysis( spin (int, optional): Spin (integer) of input signal. Defaults to 0. - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults @@ -223,15 +142,8 @@ def analysis( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. @@ -245,11 +157,11 @@ def analysis( if precomps == None: raise ValueError("Must provide precomputed kernels for this transform!") - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) - f_wav_lmn = shapes.construct_flmn_jax(L, N, J_min, lam, multiresolution) - f_wav = shapes.construct_f_jax(L, N, J_min, lam, sampling, nside, multiresolution) + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True) + f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True) wav_lm = jnp.einsum( "jln, l->jln", @@ -264,8 +176,7 @@ def analysis( # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) - spmd_iter = spmd if N == Nj else False + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( f_wav_lmn[j - J_min] .at[::2, L0j:] @@ -278,10 +189,10 @@ def analysis( ) ) ) - + shift = 0 if j < J else -1 f_wav[j - J_min] = wigner.inverse_transform_jax( f_wav_lmn[j - J_min], - precomps[2][j - J_min], + precomps[2][j - J_min + shift], Lj, Nj, sampling, @@ -302,7 +213,7 @@ def analysis( return f_wav, f_scal -@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 11)) +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def flm_to_analysis( flm: jnp.ndarray, L: int, @@ -313,9 +224,7 @@ def flm_to_analysis( sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, filters: Tuple[jnp.ndarray] = None, - spmd: bool = False, precomps: List[List[jnp.ndarray]] = None, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -334,8 +243,6 @@ def flm_to_analysis( spin (int, optional): Spin (integer) of input signal. Defaults to 0. - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults @@ -344,15 +251,8 @@ def flm_to_analysis( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. @@ -366,10 +266,10 @@ def flm_to_analysis( if precomps == None: raise ValueError("Must provide precomputed kernels for this transform!") - J = J_max if J_max is not None else shapes.j_max(L, lam) + J = J_max if J_max is not None else samples.j_max(L, lam) - f_wav_lmn = shapes.construct_flmn_jax(L, N, J_min, lam, multiresolution) - f_wav = shapes.construct_f_jax(L, N, J_min, lam, sampling, nside, multiresolution) + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True) + f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True) wav_lm = jnp.einsum( "jln, l->jln", @@ -381,7 +281,7 @@ def flm_to_analysis( # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( f_wav_lmn[j - J_min] .at[::2, L0j:] @@ -394,10 +294,10 @@ def flm_to_analysis( ) ) ) - + shift = 0 if j < J else -1 f_wav[j - J_min] = wigner.inverse_transform_jax( f_wav_lmn[j - J_min], - precomps[2][j - J_min], + precomps[2][j - J_min + shift], Lj, Nj, sampling, diff --git a/s2wav/transforms/jax_wavelets.py b/s2wav/transforms/rec_wav_jax.py similarity index 66% rename from s2wav/transforms/jax_wavelets.py rename to s2wav/transforms/rec_wav_jax.py index 1492dd2..fdcd920 100644 --- a/s2wav/transforms/jax_wavelets.py +++ b/s2wav/transforms/rec_wav_jax.py @@ -1,69 +1,12 @@ -from jax import jit, config - -config.update("jax_enable_x64", True) - +from jax import jit import jax.numpy as jnp -from s2wav.utils import shapes -import s2fft from functools import partial from typing import Tuple, List +import s2fft +from s2wav import samples +from s2wav.transforms import construct - -@partial(jit, static_argnums=(0, 1, 2, 3, 4, 5, 6, 7, 8)) -def generate_wigner_precomputes( - L: int, - N: int, - J_min: int = 0, - lam: float = 2.0, - sampling: str = "mw", - nside: int = None, - forward: bool = False, - reality: bool = False, - multiresolution: bool = False, -) -> List[jnp.ndarray]: - r"""Generates a list of precompute arrays associated with the underlying Wigner - transforms. - - Args: - L (int): Harmonic bandlimit. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. - - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. - Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", - "healpix"}. Defaults to "mw". - - nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults - to None. - - forward (bool, optional): _description_. Defaults to False. - - reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits - conjugate symmetry of harmonic coefficients. Defaults to False. - - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - - Returns: - List[jnp.ndarray]: Precomputed recursion arrays for underlying Wigner transforms. - """ - precomps = [] - J = shapes.j_max(L, lam) - for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) - precomps.append( - s2fft.generate_precomputes_wigner_jax( - Lj, Nj, sampling, nside, forward, reality, L0j - ) - ) - return precomps - - -# @partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13)) +@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9)) def synthesis( f_wav: jnp.ndarray, f_scal: jnp.ndarray, @@ -72,13 +15,10 @@ def synthesis( J_min: int = 0, lam: float = 2.0, spin: int = 0, - spin0: int = 0, sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, filters: Tuple[jnp.ndarray] = None, - spmd: bool = False, precomps: List[List[jnp.ndarray]] = None, ) -> jnp.ndarray: r"""Computes the synthesis directional wavelet transform [1,2]. @@ -101,8 +41,6 @@ def synthesis( spin (int, optional): Spin (integer) of input signal. Defaults to 0. - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". @@ -112,15 +50,8 @@ def synthesis( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. @@ -135,19 +66,18 @@ def synthesis( [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). """ if precomps == None: - precomps = generate_wigner_precomputes( - L, N, J_min, lam, sampling, nside, True, reality, multiresolution + precomps = construct.generate_wigner_precomputes( + L, N, J_min, lam, sampling, nside, True, reality ) - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) f_scal_lm = s2fft.forward_jax(f_scal, Ls, spin, nside, sampling, reality) # Sum the all wavelet wigner coefficients for each lmn # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) - spmd_iter = spmd if N == Nj else False + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) temp = s2fft.wigner.forward_jax( f_wav[j - J_min], Lj, @@ -156,7 +86,6 @@ def synthesis( sampling, reality, precomps[j - J_min], - spmd_iter, L_lower=L0j, ) flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add( @@ -177,7 +106,7 @@ def synthesis( return s2fft.inverse_jax(flm, L, spin, nside, sampling, reality) -# @partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12)) +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def analysis( f: jnp.ndarray, L: int, @@ -185,13 +114,10 @@ def analysis( J_min: int = 0, lam: float = 2.0, spin: int = 0, - spin0: int = 0, sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, filters: Tuple[jnp.ndarray] = None, - spmd: bool = False, precomps: List[List[jnp.ndarray]] = None, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -210,8 +136,6 @@ def analysis( spin (int, optional): Spin (integer) of input signal. Defaults to 0. - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults @@ -220,15 +144,8 @@ def analysis( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. @@ -240,14 +157,14 @@ def analysis( with shape :math:`[n_{\theta}, n_{\phi}]`. """ if precomps == None: - precomps = generate_wigner_precomputes( - L, N, J_min, lam, sampling, nside, False, reality, multiresolution + precomps = construct.generate_wigner_precomputes( + L, N, J_min, lam, sampling, nside, False, reality ) - J = shapes.j_max(L, lam) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) - f_wav_lmn = shapes.construct_flmn_jax(L, N, J_min, lam, multiresolution) - f_wav = shapes.construct_f_jax(L, N, J_min, lam, sampling, nside, multiresolution) + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True) + f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True) wav_lm = jnp.einsum( "jln, l->jln", @@ -261,8 +178,7 @@ def analysis( # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) - spmd_iter = spmd if N == Nj else False + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( f_wav_lmn[j - J_min] .at[::2, L0j:] @@ -284,7 +200,7 @@ def analysis( sampling, reality, precomps[j - J_min], - spmd_iter, + False, L0j, ) @@ -299,7 +215,7 @@ def analysis( return f_wav, f_scal -@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 11)) +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def flm_to_analysis( flm: jnp.ndarray, L: int, @@ -310,9 +226,7 @@ def flm_to_analysis( sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, filters: Tuple[jnp.ndarray] = None, - spmd: bool = False, precomps: List[List[jnp.ndarray]] = None, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -337,15 +251,8 @@ def flm_to_analysis( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - filters (jnp.ndarray, optional): Precomputed wavelet filters. Defaults to None. - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. @@ -354,14 +261,14 @@ def flm_to_analysis( with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. """ if precomps == None: - precomps = generate_wigner_precomputes( - L, N, J_min, lam, sampling, nside, False, reality, multiresolution + precomps = construct.generate_wigner_precomputes( + L, N, J_min, lam, sampling, nside, False, reality ) - J = J_max if J_max is not None else shapes.j_max(L, lam) + J = J_max if J_max is not None else samples.j_max(L, lam) - f_wav_lmn = shapes.construct_flmn_jax(L, N, J_min, lam, multiresolution) - f_wav = shapes.construct_f_jax(L, N, J_min, lam, sampling, nside, multiresolution) + f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True) + f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True) wav_lm = jnp.einsum( "jln, l->jln", @@ -373,8 +280,7 @@ def flm_to_analysis( # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) - spmd_iter = spmd if N == Nj else False + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) f_wav_lmn[j - J_min] = ( f_wav_lmn[j - J_min] .at[::2, L0j:] @@ -396,7 +302,7 @@ def flm_to_analysis( sampling, reality, precomps[j - J_min], - spmd_iter, + False, L0j, ) diff --git a/s2wav/utils/__init__.py b/s2wav/utils/__init__.py deleted file mode 100644 index b3cb75c..0000000 --- a/s2wav/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import math_functions -from . import shapes diff --git a/s2wav/utils/math_functions.py b/s2wav/utils/math_functions.py deleted file mode 100644 index d64c13f..0000000 --- a/s2wav/utils/math_functions.py +++ /dev/null @@ -1,41 +0,0 @@ -from jax import config - -config.update("jax_enable_x64", True) - -import jax.numpy as jnp -import numpy as np -from scipy.special import loggamma -from jax.scipy.special import gammaln as jax_gammaln -from functools import partial - - -def binomial_coefficient(n: int, k: int) -> int: - r"""Computes the binomial coefficient :math:`\binom{n}{k}`. - - Args: - n (int): Number of elements to choose from. - - k (int): Number of elements to pick. - - Returns: - (int): Number of possible subsets. - """ - return np.floor( - 0.5 + np.exp(loggamma(n + 1) - loggamma(k + 1) - loggamma(n - k + 1)) - ) - - -def binomial_coefficient_jax(n: int, k: int) -> int: - r"""Computes the binomial coefficient :math:`\binom{n}{k}`. - - Args: - n (int): Number of elements to choose from. - - k (int): Number of elements to pick. - - Returns: - (int): Number of possible subsets. - """ - return jnp.floor( - 0.5 + jnp.exp(jax_gammaln(n + 1) - jax_gammaln(k + 1) - jax_gammaln(n - k + 1)) - ) diff --git a/setup.py b/setup.py index f61151c..6c84ddb 100644 --- a/setup.py +++ b/setup.py @@ -12,15 +12,15 @@ setup( classifiers=[ - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Operating System :: OS Independent", "Intended Audience :: Developers", "Intended Audience :: Science/Research", ], name="s2wav", - version="0.0.1", + version="0.0.2", url="https://github.com/astro-informatics/s2wav", author="Authors & Contributors", license="GNU General Public License v3 (GPLv3)", diff --git a/tests/conftest.py b/tests/conftest.py index 1763b3f..6d194a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,10 @@ from typing import Tuple import numpy as np import pytest +import s2fft +from s2fft import base_transforms as base +from s2fft.sampling import so3_samples +from s2wav import samples DEFAULT_SEED = 8966433580120847635 @@ -34,17 +38,13 @@ def generate_f_wav_scal( lam: float, sampling: str = "mw", reality: bool = False, - multiresolution: bool = False, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - from s2wav.utils import shapes - from s2fft import base_transforms as base - - J = shapes.j_max(L, lam) - flmn = shapes.construct_flmn(L, N, J_min, lam, multiresolution) + J = samples.j_max(L, lam) + flmn = samples.construct_flmn(L, N, J_min, lam, True) f_wav = [] for j in range(J_min, J + 1): - Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) for n in range(-Nj + 1, Nj, 2): for el in range(max(abs(n), L0j), Lj): @@ -54,7 +54,7 @@ def generate_f_wav_scal( ) f_wav.append(base.wigner.inverse(flmn[j - J_min], Lj, Nj, 0, sampling, reality)) - L_s = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + L_s = samples.scal_bandlimit(L, J_min, lam, True) flm = np.zeros((L_s, 2 * L_s - 1), dtype=np.complex128) for el in range(L_s): for m in range(-el, el + 1): @@ -65,7 +65,7 @@ def generate_f_wav_scal( return ( f_wav, f_scal, - s2wav_to_s2let(f_wav, L, N, J_min, lam, multiresolution), + s2wav_to_s2let(f_wav, L, N, J_min, lam, True), f_scal.flatten("C"), ) @@ -78,9 +78,8 @@ def s2wav_to_s2let( lam: float = 2.0, multiresolution: bool = False, ) -> int: - from s2wav.utils.shapes import j_max - J = j_max(L, lam) + J = samples.j_max(L, lam) f_wav_s2let = np.zeros( n_wav(L, N, J_min, lam, multiresolution), dtype=np.complex128 ) @@ -100,13 +99,11 @@ def n_wav( multiresolution: bool = False, sampling: str = "mw", ) -> int: - from s2wav.utils import shapes - from s2fft.sampling import so3_samples - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) count = 0 for j in range(J_min, J + 1): - Lj = shapes.wav_j_bandlimit(L, j, lam, multiresolution) + Lj = samples.wav_j_bandlimit(L, j, lam, multiresolution) count += np.prod(list(so3_samples.f_shape(Lj, N, sampling))) return count @@ -135,6 +132,4 @@ def wavelet_generator(rng): def flm_generator(rng): # Import s2fft (and indirectly numpy) locally to avoid # `RuntimeWarning: numpy.ndarray size changed` when importing at module level - import s2fft - return partial(s2fft.utils.signal_generator.generate_flm, rng) diff --git a/tests/test_filters.py b/tests/test_filters.py index 18a6aaa..283d801 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,7 +1,6 @@ import pytest import numpy as np -from s2wav.filter_factory import filters, tiling -from s2wav.utils.shapes import j_max +from s2wav import filters, samples L_to_test = [8, 16] N_to_test = [2, 3] @@ -14,7 +13,7 @@ @pytest.mark.parametrize("lam", lam_to_test) def test_axisym_admissibility(L: int, J_min: int, lam: int): Psi, Phi = filters.filters_axisym(L, J_min, lam) - J = j_max(L, lam) + J = samples.j_max(L, lam) Psi_j_sum = np.zeros_like(Phi) for j in range(J_min, J + 1): for el in range(L): @@ -34,7 +33,7 @@ def test_axisym_admissibility(L: int, J_min: int, lam: int): def test_directional_admissibility(L: int, N: int, J_min: int, lam: int): spin = 0 psi, phi = filters.filters_directional(L, N, J_min, lam) - J = j_max(L, lam) + J = samples.j_max(L, lam) ident = np.zeros(L, dtype=np.complex128) @@ -62,7 +61,7 @@ def test_directional_admissibility(L: int, N: int, J_min: int, lam: int): @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("N", N_to_test) def test_directional_tiling(L: int, N: int): - s_elm = tiling.tiling_direction(L, N) + s_elm = filters.tiling_direction(L, N) for el in range(1, L): temp = 0 for m in range(-el, el + 1): diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 8d3e96b..97aa88e 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,12 +1,9 @@ import pytest -import numpy as np - -from s2wav.transforms import jax_wavelets -from s2wav.filter_factory import filters -from s2wav.utils import shapes import jax.numpy as jnp from jax.test_util import check_grads import s2fft +from s2wav.transforms import rec_wav_jax, pre_wav_jax, construct +from s2wav import filters, samples L_to_test = [8] N_to_test = [3] @@ -14,13 +11,14 @@ multiresolution = [False, True] reality = [False, True] sampling_to_test = ["mw", "mwss", "dh"] - +recursive_transform = [False, True] @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) +@pytest.mark.parametrize("recursive", recursive_transform) def test_jax_synthesis_gradients( flm_generator, L: int, @@ -28,18 +26,29 @@ def test_jax_synthesis_gradients( J_min: int, multiresolution: bool, reality: bool, + recursive: bool ): - J = shapes.j_max(L) + J = samples.j_max(L) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") # Generate wavelet filters filter = filters.filters_directional_vectorised(L, N, J_min) + generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes + synthesis = rec_wav_jax.synthesis if recursive else pre_wav_jax.synthesis + precomps = generator( + L, + N, + J_min, + forward=True, + reality=reality, + multiresolution=multiresolution, + ) # Generate random signal flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) f = s2fft.inverse_jax(flm, L) - f_wav, f_scal = jax_wavelets.analysis( + f_wav, f_scal = rec_wav_jax.analysis( f, L, N, @@ -54,7 +63,7 @@ def test_jax_synthesis_gradients( f_target = s2fft.inverse_jax(flm_target, L) def func(f_wav, f_scal): - f = jax_wavelets.synthesis( + f = synthesis( f_wav, f_scal, L, @@ -63,6 +72,7 @@ def func(f_wav, f_scal): multiresolution=multiresolution, reality=reality, filters=filter, + precomps=precomps, ) return jnp.sum(jnp.abs(f - f_target) ** 2) @@ -82,6 +92,7 @@ def func(f_wav, f_scal): @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) +@pytest.mark.parametrize("recursive", recursive_transform) def test_jax_analysis_gradients( flm_generator, L: int, @@ -89,13 +100,24 @@ def test_jax_analysis_gradients( J_min: int, multiresolution: bool, reality: bool, + recursive: bool ): - J = shapes.j_max(L) + J = samples.j_max(L) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") # Generate wavelet filters filter = filters.filters_directional_vectorised(L, N, J_min) + generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes + analysis = rec_wav_jax.analysis if recursive else pre_wav_jax.analysis + precomps = generator( + L, + N, + J_min, + forward=False, + reality=reality, + multiresolution=multiresolution, + ) # Generate random signal flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) @@ -104,7 +126,7 @@ def test_jax_analysis_gradients( # Generate target signal flm_target = flm_generator(L=L, L_lower=0, spin=0, reality=reality) f_target = s2fft.inverse_jax(flm_target, L) - f_wav_target, f_scal_target = jax_wavelets.analysis( + f_wav_target, f_scal_target = rec_wav_jax.analysis( f_target, L, N, @@ -115,7 +137,7 @@ def test_jax_analysis_gradients( ) def func(f): - f_wav, f_scal = jax_wavelets.analysis( + f_wav, f_scal = analysis( f, L, N, @@ -123,6 +145,7 @@ def func(f): multiresolution=multiresolution, reality=reality, filters=filter, + precomps=precomps ) loss = jnp.sum(jnp.abs(f_scal - f_scal_target) ** 2) for j in range(J - J_min): diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index a5caacf..31eb574 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -1,40 +1,34 @@ import pytest import numpy as np import pys2let as s2let - -from s2wav.transforms import jax_wavelets -from s2wav.filter_factory import filters -from s2wav.utils import shapes -from s2fft import base_transforms as base +from s2fft import base_transforms as sht_base +from s2wav.transforms import rec_wav_jax, pre_wav_jax, construct +from s2wav import filters, samples L_to_test = [8] N_to_test = [2, 3] J_min_to_test = [2] lam_to_test = [2, 3] -multiresolution = [False, True] reality = [False, True] -multiple_gpus = [False] sampling_to_test = ["mw", "mwss", "dh"] - +recursive_transform = [False, True] @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) -@pytest.mark.parametrize("spmd", multiple_gpus) +@pytest.mark.parametrize("recursive", recursive_transform) def test_jax_synthesis( wavelet_generator, L: int, N: int, J_min: int, lam: int, - multiresolution: bool, reality: bool, - spmd: bool, + recursive: bool ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") @@ -43,7 +37,6 @@ def test_jax_synthesis( N=N, J_min=J_min, lam=lam, - multiresolution=multiresolution, reality=reality, ) @@ -55,32 +48,31 @@ def test_jax_synthesis( J_min, N, spin=0, - upsample=not multiresolution, + upsample=False, ) - # Precompute some values filter = filters.filters_directional_vectorised(L, N, J_min, lam) - precomps = jax_wavelets.generate_wigner_precomputes( + generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes + synthesis = rec_wav_jax.synthesis if recursive else pre_wav_jax.synthesis + + precomps = generator( L, N, J_min, lam, forward=True, - reality=reality, - multiresolution=multiresolution, - ) - f_check = jax_wavelets.synthesis( + reality=reality + ) + f_check = synthesis( f_wav, f_scal, L, N, J_min, lam, - multiresolution=multiresolution, reality=reality, filters=filter, - precomps=precomps, - spmd=spmd, + precomps=precomps ) f = np.real(f) if reality else f np.testing.assert_allclose(f, f_check.flatten("C"), atol=1e-14) @@ -90,9 +82,8 @@ def test_jax_synthesis( @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) -@pytest.mark.parametrize("spmd", multiple_gpus) +@pytest.mark.parametrize("recursive", recursive_transform) def test_jax_analysis( flm_generator, f_wav_converter, @@ -100,16 +91,15 @@ def test_jax_analysis( N: int, J_min: int, lam: int, - multiresolution: bool, reality: bool, - spmd: bool, + recursive: bool ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = base.spherical.inverse(flm, L, reality=reality) + f = sht_base.spherical.inverse(flm, L, reality=reality) f_wav, f_scal = s2let.analysis_px2wav( f.flatten("C").astype(np.complex128), @@ -118,32 +108,31 @@ def test_jax_analysis( J_min, N, spin=0, - upsample=not multiresolution, + upsample=False ) filter = filters.filters_directional_vectorised(L, N, J_min, lam) - precomps = jax_wavelets.generate_wigner_precomputes( + generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes + analysis = rec_wav_jax.analysis if recursive else pre_wav_jax.analysis + precomps = generator( L, N, J_min, lam, forward=False, - reality=reality, - multiresolution=multiresolution, + reality=reality ) - f_wav_check, f_scal_check = jax_wavelets.analysis( + f_wav_check, f_scal_check = analysis( f, L, N, J_min, lam, - multiresolution=multiresolution, reality=reality, filters=filter, - precomps=precomps, - spmd=spmd, + precomps=precomps ) - f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, multiresolution) + f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, True) np.testing.assert_allclose(f_wav, f_wav_check, atol=1e-14) np.testing.assert_allclose(f_scal, f_scal_check.flatten("C"), atol=1e-14) @@ -153,53 +142,45 @@ def test_jax_analysis( @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("sampling", sampling_to_test) -@pytest.mark.parametrize("spmd", multiple_gpus) def test_jax_round_trip( flm_generator, L: int, N: int, J_min: int, lam: int, - multiresolution: bool, reality: bool, - sampling: str, - spmd: bool, + sampling: str ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = base.spherical.inverse(flm, L, reality=reality, sampling=sampling) + f = sht_base.spherical.inverse(flm, L, reality=reality, sampling=sampling) filter = filters.filters_directional_vectorised(L, N, J_min, lam) - f_wav, f_scal = jax_wavelets.analysis( + f_wav, f_scal = rec_wav_jax.analysis( f, L, N, J_min, lam, - multiresolution=multiresolution, reality=reality, sampling=sampling, - filters=filter, - spmd=spmd, + filters=filter ) - f_check = jax_wavelets.synthesis( + f_check = rec_wav_jax.synthesis( f_wav, f_scal, L, N, J_min, lam, - multiresolution=multiresolution, sampling=sampling, reality=reality, - filters=filter, - spmd=spmd, + filters=filter ) np.testing.assert_allclose(f, f_check, atol=1e-14) diff --git a/tests/test_wavelets_base.py b/tests/test_wavelets_base.py index 94182ca..32896cd 100644 --- a/tests/test_wavelets_base.py +++ b/tests/test_wavelets_base.py @@ -1,11 +1,9 @@ import pytest import numpy as np import pys2let as s2let - -from s2wav.transforms import numpy_wavelets -from s2wav.utils import shapes -from s2fft import base_transforms as base - +from s2fft import base_transforms as sht_base +from s2wav.transforms import base as wav_base +from s2wav import samples L_to_test = [8] N_to_test = [2, 3] @@ -31,7 +29,7 @@ def test_synthesis_looped( multiresolution: bool, reality: bool, ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") @@ -55,7 +53,7 @@ def test_synthesis_looped( upsample=not multiresolution, ) - f_check = numpy_wavelets.synthesis_looped( + f_check = wav_base.synthesis_looped( f_wav, f_scal, L, N, J_min, lam, multiresolution=multiresolution ) @@ -77,7 +75,7 @@ def test_synthesis_vectorised( multiresolution: bool, reality: bool, ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") @@ -101,7 +99,7 @@ def test_synthesis_vectorised( upsample=not multiresolution, ) - f_check = numpy_wavelets.synthesis( + f_check = wav_base.synthesis( f_wav, f_scal, L, @@ -131,12 +129,12 @@ def test_analysis_looped( multiresolution: bool, reality: bool, ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = base.spherical.inverse(flm, L, reality=reality) + f = sht_base.spherical.inverse(flm, L, reality=reality) f_wav, f_scal = s2let.analysis_px2wav( f.flatten("C").astype(np.complex128), @@ -147,7 +145,7 @@ def test_analysis_looped( spin=0, upsample=not multiresolution, ) - f_wav_check, f_scal_check = numpy_wavelets.analysis_looped( + f_wav_check, f_scal_check = wav_base.analysis_looped( f, L, N, J_min, lam, reality=reality, multiresolution=multiresolution ) f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, multiresolution) @@ -171,12 +169,12 @@ def test_analysis_vectorised( multiresolution: bool, reality: bool, ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = base.spherical.inverse(flm, L, reality=reality) + f = sht_base.spherical.inverse(flm, L, reality=reality) f_wav, f_scal = s2let.analysis_px2wav( f.flatten("C").astype(np.complex128), @@ -187,7 +185,7 @@ def test_analysis_vectorised( spin=0, upsample=not multiresolution, ) - f_wav_check, f_scal_check = numpy_wavelets.analysis( + f_wav_check, f_scal_check = wav_base.analysis( f, L, N, J_min, lam, multiresolution=multiresolution, reality=reality ) @@ -213,16 +211,16 @@ def test_looped_round_trip( reality: bool, sampling: str, ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") nside = int(L / 2) flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = base.spherical.inverse(flm, L, reality=reality, sampling=sampling, nside=nside) + f = sht_base.spherical.inverse(flm, L, reality=reality, sampling=sampling, nside=nside) - f_wav, f_scal = numpy_wavelets.analysis_looped( + f_wav, f_scal = wav_base.analysis_looped( f, L, N, @@ -234,7 +232,7 @@ def test_looped_round_trip( nside=nside, ) - f_check = numpy_wavelets.synthesis_looped( + f_check = wav_base.synthesis_looped( f_wav, f_scal, L, @@ -266,14 +264,14 @@ def test_vectorised_round_trip( reality: bool, sampling: str, ): - J = shapes.j_max(L, lam) + J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = base.spherical.inverse(flm, L, reality=reality, sampling=sampling) + f = sht_base.spherical.inverse(flm, L, reality=reality, sampling=sampling) - f_wav, f_scal = numpy_wavelets.analysis( + f_wav, f_scal = wav_base.analysis( f, L, N, @@ -284,7 +282,7 @@ def test_vectorised_round_trip( sampling=sampling, ) - f_check = numpy_wavelets.synthesis( + f_check = wav_base.synthesis( f_wav, f_scal, L, diff --git a/tests/test_wavelets_precompute.py b/tests/test_wavelets_precompute.py deleted file mode 100644 index e649b23..0000000 --- a/tests/test_wavelets_precompute.py +++ /dev/null @@ -1,149 +0,0 @@ -import pytest -import numpy as np -import pys2let as s2let - -from s2wav.transforms import jax_wavelets_precompute as jax_wavelets -from s2wav.filter_factory import filters -from s2wav.utils import shapes -from s2fft import base_transforms as base - -L_to_test = [8] -N_to_test = [2, 3] -J_min_to_test = [2] -lam_to_test = [2, 3] -multiresolution = [False, True] -reality = [False, True] -multiple_gpus = [False] -sampling_to_test = ["mw", "mwss", "dh"] - - -@pytest.mark.parametrize("L", L_to_test) -@pytest.mark.parametrize("N", N_to_test) -@pytest.mark.parametrize("J_min", J_min_to_test) -@pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) -@pytest.mark.parametrize("reality", reality) -@pytest.mark.parametrize("spmd", multiple_gpus) -def test_jax_synthesis( - wavelet_generator, - L: int, - N: int, - J_min: int, - lam: int, - multiresolution: bool, - reality: bool, - spmd: bool, -): - J = shapes.j_max(L, lam) - if J_min >= J: - pytest.skip("J_min larger than J which isn't a valid test case.") - - f_wav, f_scal, f_wav_s2let, f_scal_s2let = wavelet_generator( - L=L, - N=N, - J_min=J_min, - lam=lam, - multiresolution=multiresolution, - reality=reality, - ) - - f = s2let.synthesis_wav2px( - f_wav_s2let, - f_scal_s2let, - lam, - L, - J_min, - N, - spin=0, - upsample=not multiresolution, - ) - - # Precompute some values - filter = filters.filters_directional_vectorised(L, N, J_min, lam) - precomps = jax_wavelets.generate_precomputes( - L, - N, - J_min, - lam, - forward=True, - reality=reality, - multiresolution=multiresolution, - ) - f_check = jax_wavelets.synthesis( - f_wav, - f_scal, - L, - N, - J_min, - lam, - multiresolution=multiresolution, - reality=reality, - filters=filter, - precomps=precomps, - spmd=spmd, - ) - f = np.real(f) if reality else f - np.testing.assert_allclose(f, f_check.flatten("C"), atol=1e-14) - - -@pytest.mark.parametrize("L", L_to_test) -@pytest.mark.parametrize("N", N_to_test) -@pytest.mark.parametrize("J_min", J_min_to_test) -@pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) -@pytest.mark.parametrize("reality", reality) -@pytest.mark.parametrize("spmd", multiple_gpus) -def test_jax_analysis( - flm_generator, - f_wav_converter, - L: int, - N: int, - J_min: int, - lam: int, - multiresolution: bool, - reality: bool, - spmd: bool, -): - J = shapes.j_max(L, lam) - if J_min >= J: - pytest.skip("J_min larger than J which isn't a valid test case.") - - flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = base.spherical.inverse(flm, L, reality=reality) - - f_wav, f_scal = s2let.analysis_px2wav( - f.flatten("C").astype(np.complex128), - lam, - L, - J_min, - N, - spin=0, - upsample=not multiresolution, - ) - filter = filters.filters_directional_vectorised(L, N, J_min, lam) - precomps = jax_wavelets.generate_precomputes( - L, - N, - J_min, - lam, - forward=False, - reality=reality, - multiresolution=multiresolution, - ) - f_wav_check, f_scal_check = jax_wavelets.analysis( - f, - L, - N, - J_min, - lam, - multiresolution=multiresolution, - reality=reality, - filters=filter, - precomps=precomps, - spmd=spmd, - ) - - f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, multiresolution) - - np.testing.assert_allclose(f_wav, f_wav_check, atol=1e-14) - np.testing.assert_allclose(f_scal, f_scal_check.flatten("C"), atol=1e-14) From 068cbb476cde0a5cd2ad4ca682c0ca8063556b48 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Fri, 8 Mar 2024 14:15:28 +0000 Subject: [PATCH 2/8] update autodoc api links to new code structure --- docs/api/filter_factory/filters.rst | 4 +- docs/api/filter_factory/index.rst | 32 +- .../kernels.rst => transforms/base.rst} | 4 +- .../tiling.rst => transforms/construct.rst} | 4 +- docs/api/transforms/index.rst | 48 ++- .../{numpy_wavelets.rst => pre_wav_jax.rst} | 4 +- .../{jax_wavelets.rst => rec_wav_jax.rst} | 4 +- docs/api/utility/index.rst | 27 +- docs/api/utility/{shapes.rst => samples.rst} | 2 +- notebooks/latent_emulation.ipynb | 289 ------------------ notebooks/template_notebook.ipynb | 69 ----- 11 files changed, 73 insertions(+), 414 deletions(-) rename docs/api/{filter_factory/kernels.rst => transforms/base.rst} (63%) rename docs/api/{filter_factory/tiling.rst => transforms/construct.rst} (62%) rename docs/api/transforms/{numpy_wavelets.rst => pre_wav_jax.rst} (59%) rename docs/api/transforms/{jax_wavelets.rst => rec_wav_jax.rst} (59%) rename docs/api/utility/{shapes.rst => samples.rst} (78%) delete mode 100644 notebooks/latent_emulation.ipynb delete mode 100644 notebooks/template_notebook.ipynb diff --git a/docs/api/filter_factory/filters.rst b/docs/api/filter_factory/filters.rst index c5c87e9..5bbc2bd 100644 --- a/docs/api/filter_factory/filters.rst +++ b/docs/api/filter_factory/filters.rst @@ -1,7 +1,7 @@ :html_theme.sidebar_secondary.remove: ************************** -Filter Generator +Filter functions ************************** -.. automodule:: s2wav.filter_factory.filters +.. automodule:: s2wav.filters :members: \ No newline at end of file diff --git a/docs/api/filter_factory/index.rst b/docs/api/filter_factory/index.rst index 7a1733b..da10080 100644 --- a/docs/api/filter_factory/index.rst +++ b/docs/api/filter_factory/index.rst @@ -10,14 +10,14 @@ Filter Factory * - Function Name - Description - * - :func:`~s2wav.filter_factory.filters.filters_axisym` + * - :func:`~s2wav.filters.filters_axisym` - Computes wavelet kernels :math:`\Psi^j_{\ell m}` and scaling kernel :math:`\Phi_{\ell m}` in harmonic space. - * - :func:`~s2wav.filter_factory.filters.filters_directional` + * - :func:`~s2wav.filters.filters_directional` - Generates the harmonic coefficients for the directional tiling wavelets in harmonic space. - * - :func:`~s2wav.filter_factory.filters.filters_axisym_vectorised` - - Vectorised implementation of :func:`~s2wav.filter_factory.filters.filters_directional`. - * - :func:`~s2wav.filter_factory.filters.filters_directional_vectorised` - - Vectorised implementation of :func:`~s2wav.filter_factory.filters.filters_directional`. + * - :func:`~s2wav.filters.filters_axisym_vectorised` + - Vectorised implementation of :func:`~s2wav.filters.filters_directional`. + * - :func:`~s2wav.filters.filters_directional_vectorised` + - Vectorised implementation of :func:`~s2wav.filters.filters_directional`. .. list-table:: Wavelet kernel functions. :widths: 25 25 @@ -25,11 +25,11 @@ Filter Factory * - Function Name - Description - * - :func:`~s2wav.filter_factory.kernels.tiling_integrand` + * - :func:`~s2wav.filters.tiling_integrand` - Tiling integrand for scale-discretised wavelets. - * - :func:`~s2wav.filter_factory.kernels.part_scaling_fn` + * - :func:`~s2wav.filters.part_scaling_fn` - Computes integral used to calculate smoothly decreasing function :math:`k_{\lambda}`. - * - :func:`~s2wav.filter_factory.kernels.k_lam` + * - :func:`~s2wav.filters.k_lam` - Compute function :math:`k_{\lambda}` used as a wavelet generating function. .. list-table:: Wavelet tiling functions. @@ -38,18 +38,16 @@ Filter Factory * - Function Name - Description - * - :func:`~s2wav.filter_factory.tiling.tiling_direction` + * - :func:`~s2wav.filters.tiling_direction` - Generates the harmonic coefficients for the directionality component of the tiling functions. - * - :func:`~s2wav.filter_factory.tiling.spin_normalization` + * - :func:`~s2wav.filters.spin_normalization` - Computes the normalization factor for spin-lowered wavelets, which is :math:`\sqrt{\frac{(l+s)!}{(l-s)!}}`. - * - :func:`~s2wav.filter_factory.tiling.spin_normalization_vectorised` - - Vectorised version of :func:`~s2wav.filter_factory.tiling.spin_normalization`. + * - :func:`~s2wav.filters.spin_normalization_vectorised` + - Vectorised version of :func:`~s2wav.filters.spin_normalization`. .. toctree:: :hidden: :maxdepth: 2 - :caption: Wavelet generators + :caption: Filter functions - filters - tiling - kernels \ No newline at end of file + filters \ No newline at end of file diff --git a/docs/api/filter_factory/kernels.rst b/docs/api/transforms/base.rst similarity index 63% rename from docs/api/filter_factory/kernels.rst rename to docs/api/transforms/base.rst index 782b11d..45f9c96 100644 --- a/docs/api/filter_factory/kernels.rst +++ b/docs/api/transforms/base.rst @@ -1,7 +1,7 @@ :html_theme.sidebar_secondary.remove: ************************** -Wavelet Kernels +Numpy Transforms ************************** -.. automodule:: s2wav.filter_factory.kernels +.. automodule:: s2wav.transforms.base :members: \ No newline at end of file diff --git a/docs/api/filter_factory/tiling.rst b/docs/api/transforms/construct.rst similarity index 62% rename from docs/api/filter_factory/tiling.rst rename to docs/api/transforms/construct.rst index 491435c..72e2726 100644 --- a/docs/api/filter_factory/tiling.rst +++ b/docs/api/transforms/construct.rst @@ -1,7 +1,7 @@ :html_theme.sidebar_secondary.remove: ************************** -Tiling Functions +Matrice Precomputes ************************** -.. automodule:: s2wav.filter_factory.tiling +.. automodule:: s2wav.transforms.construct :members: \ No newline at end of file diff --git a/docs/api/transforms/index.rst b/docs/api/transforms/index.rst index dc1c39d..269d8b7 100644 --- a/docs/api/transforms/index.rst +++ b/docs/api/transforms/index.rst @@ -4,35 +4,55 @@ Wavelet Transforms ************************** -.. list-table:: Wavelet transforms +.. list-table:: Numpy transforms :widths: 25 25 :header-rows: 1 * - Function Name - Description - * - :func:`~s2wav.transforms.numpy_wavelets.synthesis_looped` + * - :func:`~s2wav.transforms.base.synthesis_looped` - Loopy implementation of mapping from wavelet to pixel space. - * - :func:`~s2wav.transforms.numpy_wavelets.synthesis` + * - :func:`~s2wav.transforms.base.synthesis` - Vectorised implementation of mapping from wavelet to pixel space. - * - :func:`~s2wav.transforms.numpy_wavelets.analysis_looped` + * - :func:`~s2wav.transforms.base.analysis_looped` - Loopy implementation of mapping from pixel to wavelet space. - * - :func:`~s2wav.transforms.numpy_wavelets.analysis` + * - :func:`~s2wav.transforms.base.analysis` - Vectorised implementation of mapping from pixel to wavelet space. - * - :func:`~s2wav.transforms.jax_wavelets.synthesis` - - JAX implementation of mapping from wavelet to pixel space. - * - :func:`~s2wav.transforms.jax_wavelets.analysis` - - JAX implementation of mapping from pixel to wavelet space. - * - :func:`~s2wav.transforms.jax_wavelets.flm_to_analysis` - - JAX implementation of mapping from harmonic to wavelet space. - * - :func:`~s2wav.transforms.jax_wavelets.generate_wigner_precomputes` +.. list-table:: JAX transforms + :widths: 25 25 + :header-rows: 1 + + * - :func:`~s2wav.transforms.rec_wav_jax.synthesis` + - JAX implementation of mapping from wavelet to pixel space (Recursive). + * - :func:`~s2wav.transforms.rec_wav_jax.analysis` + - JAX implementation of mapping from pixel to wavelet space (Recursive). + * - :func:`~s2wav.transforms.rec_wav_jax.flm_to_analysis` + - JAX implementation of mapping from harmonic to wavelet space (Recursive). + + * - :func:`~s2wav.transforms.pre_wav_jax.synthesis` + - JAX implementation of mapping from wavelet to pixel space (fully precompute). + * - :func:`~s2wav.transforms.pre_wav_jax.analysis` + - JAX implementation of mapping from pixel to wavelet space (fully precompute). + * - :func:`~s2wav.transforms.pre_wav_jax.flm_to_analysis` + - JAX implementation of mapping from harmonic to wavelet space (fully precompute). + + .. list-table:: Matrices precomputations + :widths: 25 25 + :header-rows: 1 + + * - :func:`~s2wav.transforms.construct.generate_wigner_precomputes` - JAX function to generate precompute arrays for underlying Wigner transforms. + * - :func:`~s2wav.transforms.construct.generate_full_precomputes` + - JAX function to generate precompute arrays for fully precompute transforms. .. toctree:: :hidden: :maxdepth: 2 :caption: Wavelet transform - numpy_wavelets - jax_wavelets + base + construct + rec_wav_jax + pre_wav_jax \ No newline at end of file diff --git a/docs/api/transforms/numpy_wavelets.rst b/docs/api/transforms/pre_wav_jax.rst similarity index 59% rename from docs/api/transforms/numpy_wavelets.rst rename to docs/api/transforms/pre_wav_jax.rst index affd9c0..8ba6fd9 100644 --- a/docs/api/transforms/numpy_wavelets.rst +++ b/docs/api/transforms/pre_wav_jax.rst @@ -1,7 +1,7 @@ :html_theme.sidebar_secondary.remove: ************************** -Transforms (numpy) +JAX Transforms (Precompute) ************************** -.. automodule:: s2wav.transforms.numpy_wavelets +.. automodule:: s2wav.transforms.pre_wav_jax :members: \ No newline at end of file diff --git a/docs/api/transforms/jax_wavelets.rst b/docs/api/transforms/rec_wav_jax.rst similarity index 59% rename from docs/api/transforms/jax_wavelets.rst rename to docs/api/transforms/rec_wav_jax.rst index 8d70883..ad496bb 100644 --- a/docs/api/transforms/jax_wavelets.rst +++ b/docs/api/transforms/rec_wav_jax.rst @@ -1,7 +1,7 @@ :html_theme.sidebar_secondary.remove: ************************** -Transforms (JAX) +JAX Transforms (Recursive) ************************** -.. automodule:: s2wav.transforms.jax_wavelets +.. automodule:: s2wav.transforms.rec_wav_jax :members: \ No newline at end of file diff --git a/docs/api/utility/index.rst b/docs/api/utility/index.rst index 917902f..619164a 100644 --- a/docs/api/utility/index.rst +++ b/docs/api/utility/index.rst @@ -10,17 +10,17 @@ Utility Functions * - Function Name - Description - * - :func:`~s2wav.utils.shapes.L0_j` + * - :func:`~s2wav.samples.L0_j` - Computes the minimum harmonic index supported by the given wavelet scale :math:`j`. - * - :func:`~s2wav.utils.shapes.n_wav_scales` + * - :func:`~s2wav.samples.n_wav_scales` - Evalutes the total number of wavelet scales. - * - :func:`~s2wav.utils.shapes.j_max` + * - :func:`~s2wav.samples.j_max` - Computes maximum wavelet scale required to ensure exact reconstruction. - * - :func:`~s2wav.utils.shapes.LN_j` + * - :func:`~s2wav.samples.LN_j` - Computes the harmonic bandlimit and directionality for scale :math:`j`. - * - :func:`~s2wav.utils.shapes.scal_bandlimit` + * - :func:`~s2wav.samples.scal_bandlimit` - Returns the harmominc bandlimit of the scaling coefficients. - * - :func:`~s2wav.utils.shapes.wav_j_bandlimit` + * - :func:`~s2wav.samples.wav_j_bandlimit` - Returns the harmominc bandlimit of the scaling coefficients. .. list-table:: Shape functions. @@ -29,11 +29,11 @@ Utility Functions * - Function Name - Description - * - :func:`~s2wav.utils.shapes.f_scal` + * - :func:`~s2wav.samples.f_scal` - Computes the shape of scaling coefficients in pixel-space. - * - :func:`~s2wav.utils.shapes.f_wav_j` + * - :func:`~s2wav.samples.f_wav_j` - Computes the shape of wavelet coefficients :math:`f^j` in pixel-space. - * - :func:`~s2wav.utils.shapes.flmn_wav_j` + * - :func:`~s2wav.samples.flmn_wav_j` - Returns the shape of wavelet coefficients :math:`f^j_{\ell m n}` in Wigner space. .. list-table:: Array constructing and shape checking functions. @@ -42,13 +42,13 @@ Utility Functions * - Function Name - Description - * - :func:`~s2wav.utils.shapes.construct_f` + * - :func:`~s2wav.samples.construct_f` - Defines a list of arrays corresponding to f_wav. - * - :func:`~s2wav.utils.shapes.construct_flm` + * - :func:`~s2wav.samples.construct_flm` - Returns the shape of scaling coefficients in harmonic space. - * - :func:`~s2wav.utils.shapes.construct_flmn` + * - :func:`~s2wav.samples.construct_flmn` - Defines a list of arrays corresponding to flmn. - * - :func:`~s2wav.utils.shapes.wavelet_shape_check` + * - :func:`~s2wav.samples.wavelet_shape_check` - Checks the shape of wavelet coefficients are correct. .. toctree:: @@ -56,5 +56,4 @@ Utility Functions :maxdepth: 3 :caption: Utilities - math_functions shapes \ No newline at end of file diff --git a/docs/api/utility/shapes.rst b/docs/api/utility/samples.rst similarity index 78% rename from docs/api/utility/shapes.rst rename to docs/api/utility/samples.rst index c708425..f888a4e 100644 --- a/docs/api/utility/shapes.rst +++ b/docs/api/utility/samples.rst @@ -3,5 +3,5 @@ ************************** Array Shape Functions ************************** -.. automodule:: s2wav.utils.shapes +.. automodule:: s2wav.samples :members: \ No newline at end of file diff --git a/notebooks/latent_emulation.ipynb b/notebooks/latent_emulation.ipynb deleted file mode 100644 index 14bca39..0000000 --- a/notebooks/latent_emulation.ipynb +++ /dev/null @@ -1,289 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Config for 64bit precision\n", - "from jax.config import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "\n", - "# Check we're running on GPU\n", - "from jax.lib import xla_bridge\n", - "print(xla_bridge.get_backend().platform)\n", - "\n", - "import numpy as np \n", - "import jax.numpy as jnp\n", - "from jax import device_put, local_device_count\n", - "from s2wav.transforms import jax_scattering, jax_wavelets\n", - "from s2wav.filter_factory import filters as filter_generator\n", - "import s2fft\n", - "import pyssht as ssht" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "L = 16\n", - "N = 1\n", - "nlayers = 2\n", - "J_min = 0 \n", - "lam = 2.0\n", - "reality = True\n", - "sampling = \"mw\"\n", - "multiresolution = True\n", - "spmd=False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "f = ssht.inverse(ssht.forward(np.random.randn(L, 2*L-1), L, Reality=reality), L, Reality=reality)\n", - "f -= np.nanmean(f)\n", - "f /= np.nanmax(abs(f))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "filters = filter_generator.filters_directional_vectorised(L, N, J_min, lam)\n", - "precomps = jax_wavelets.generate_wigner_precomputes(L, N, J_min, lam, sampling, None, False, reality, multiresolution)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "coeffs = jax_scattering.scatter(\n", - " f=jnp.array(f),\n", - " L=L,\n", - " N=N,\n", - " J_min=J_min,\n", - " lam=lam,\n", - " nlayers=nlayers,\n", - " reality=reality,\n", - " multiresolution=multiresolution,\n", - " filters=filters,\n", - " spmd=spmd,\n", - " )\n", - "\n", - "print(coeffs[:,0,0])\n", - "print(coeffs.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from jax import grad\n", - "def mse_loss(y):\n", - " return (1./y.size)*jnp.sum((jnp.abs(y-coeffs))**2)\n", - "\n", - "def power_spectrum(flm, L):\n", - " ps = np.zeros(L, dtype=np.float64)\n", - " flm = np.abs(flm)**2\n", - " return np.sum(flm, axis=-1)\n", - "\n", - "# ps_true = power_spectrum(s2fft.forward(np.array(f), L, 0, reality=reality), L)\n", - "# def ps_loss(x):\n", - "# z = s2fft.forward_jax(x, L, 0,reality=reality)\n", - "# ps = jnp.sum(jnp.abs(z)**2, axis=-1)\n", - "# return (1./ps.size)*jnp.sum((ps-ps_true)**2)\n", - " \n", - "def scattering_func(x):\n", - " y = jax_scattering.scatter(\n", - " jnp.array(x),\n", - " L=L,\n", - " N=N,\n", - " J_min=J_min,\n", - " lam=lam,\n", - " nlayers=nlayers,\n", - " reality=reality,\n", - " multiresolution=multiresolution,\n", - " filters=filters,\n", - " spmd=spmd,\n", - " )\n", - " return mse_loss(y)\n", - " # return mse_loss(y) + ps_loss(x)\n", - "\n", - "grad_func = grad(scattering_func)\n", - "f_temp = np.random.randn(L, 2*L-1)\n", - "print(grad_func(f_temp))\n", - "f_start = np.copy(f_temp)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "momentum = 100\n", - "E0 = scattering_func(f_start)\n", - "for i in range(1000000):\n", - " f_temp -= momentum*grad_func(f_temp)\n", - " if i % 10 == 0: \n", - " print(f\"Iteration {i}: Energy/E0 = {scattering_func(f_temp)}/{E0}, Momentum = {momentum}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "start_coeffs = jax_scattering.scatter(\n", - " f=f_start,\n", - " L=L,\n", - " N=N,\n", - " J_min=J_min,\n", - " lam=lam,\n", - " nlayers=nlayers,\n", - " reality=reality,\n", - " multiresolution=multiresolution,\n", - " filters=filters,\n", - " spmd=spmd,\n", - " )\n", - "optimised_coeffs = jax_scattering.scatter(\n", - " f=f_temp,\n", - " L=L,\n", - " N=N,\n", - " J_min=J_min,\n", - " lam=lam,\n", - " nlayers=nlayers,\n", - " reality=reality,\n", - " multiresolution=multiresolution,\n", - " filters=filters,\n", - " spmd=spmd,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "c1 = coeffs[:,0,0]\n", - "c2 = start_coeffs[:,0,0]\n", - "c3 = optimised_coeffs[:,0,0]\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in range(len(c1)):\n", - " print(c1[i], c2[i], c3[i])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt \n", - "f_temp = ssht.inverse(ssht.forward(np.array(f_temp), L, Reality=reality), L, Reality=reality)\n", - "f_start2 = ssht.inverse(ssht.forward(np.array(f_start), L, Reality=reality), L, Reality=reality)\n", - "mx, mn = np.nanmax(f), np.nanmin(f)\n", - "fig, (ax1,ax2, ax3) = plt.subplots(1,3, figsize=(20,10))\n", - "ax1.imshow(f, vmax=mx, vmin=mn, cmap='magma')\n", - "ax2.imshow(f_start2, vmax=mx, vmin=mn, cmap='magma')\n", - "ax3.imshow(f_temp, vmax=mx, vmin=mn, cmap='magma')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "flm = s2fft.sampling.s2_samples.flm_1d_to_2d(ssht.forward(f, L, Reality=reality), L)\n", - "flm_temp = s2fft.sampling.s2_samples.flm_1d_to_2d(ssht.forward(np.array(f_temp), L, Reality=reality), L)\n", - "flm_start = s2fft.sampling.s2_samples.flm_1d_to_2d(ssht.forward(np.array(f_start2), L, Reality=reality), L)\n", - "\n", - "flm=np.real(flm)\n", - "flm_temp=np.real(flm_temp)\n", - "flm_start=np.real(flm_start)\n", - "\n", - "from matplotlib import pyplot as plt \n", - "mx, mn = np.nanmax(flm), np.nanmin(flm)\n", - "fig, (ax1,ax2, ax3) = plt.subplots(1,3, figsize=(20,10))\n", - "ax1.imshow(flm, vmax=mx, vmin=mn, cmap='magma')\n", - "ax2.imshow(flm_start, vmax=mx, vmin=mn, cmap='magma')\n", - "ax3.imshow(flm_temp, vmax=mx, vmin=mn, cmap='magma')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ps = power_spectrum(flm, L)\n", - "ps_temp = power_spectrum(flm_temp, L)\n", - "ps_start = power_spectrum(flm_start, L)\n", - "\n", - "plt.plot(ps, label=\"input\")\n", - "plt.plot(ps_temp, label=\"converged\")\n", - "plt.plot(ps_start, label=\"initial\")\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.15 ('s2wav')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "2eaa51c34c6264c479aef01ba42a63404a2d0b54fbb558b3097eeea4996caab5" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/template_notebook.ipynb b/notebooks/template_notebook.ipynb deleted file mode 100644 index cb36cb5..0000000 --- a/notebooks/template_notebook.ipynb +++ /dev/null @@ -1,69 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Basically just do whatever you like in a notebook such as this, just try and keep it neat (-ish).\n", - "\n", - "Once you are happy with the notebook, copy it into " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cp my_new_notebook.ipynb ../docs/assets/static_notebooks/" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then add an `nblink' file to docs/tutorials that looks like " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "{\n", - " \"path\": \"../assets/static_notebooks/my_new_notebook.ipynb\"\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "and then add this to the main documentation by editing the top level index.rst by adding the following" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - ".. toctree::\n", - " :hidden:\n", - " :maxdepth: 2\n", - " :caption: Tutorials\n", - " \n", - " tutorials/my_new_notebook.nblink" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 54c21823b8fd695dcc2a4d1afd45fd08eed2226b Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Fri, 8 Mar 2024 16:24:43 +0000 Subject: [PATCH 3/8] add pytorch precompute support --- s2wav/filters.py | 64 ++++--- s2wav/samples.py | 116 +++++++++++- s2wav/transforms/base.py | 37 +--- s2wav/transforms/construct.py | 35 +++- s2wav/transforms/pre_wav_torch.py | 294 ++++++++++++++++++++++++++++++ tests/conftest.py | 34 +++- tests/test_gradients.py | 80 +++----- tests/test_wavelets.py | 125 +++++++------ tests/test_wavelets_base.py | 166 +++-------------- 9 files changed, 623 insertions(+), 328 deletions(-) create mode 100644 s2wav/transforms/pre_wav_torch.py diff --git a/s2wav/filters.py b/s2wav/filters.py index 1462cc9..ba81f49 100644 --- a/s2wav/filters.py +++ b/s2wav/filters.py @@ -1,5 +1,6 @@ from jax import jit import jax.numpy as jnp +import torch import numpy as np from typing import Tuple from functools import partial @@ -9,26 +10,26 @@ def filters_axisym( L: int, J_min: int = 0, lam: float = 2.0 ) -> Tuple[np.ndarray, np.ndarray]: - r"""Computes wavelet kernels :math:`\Psi^j_{\ell m}` and scaling kernel + r"""Computes wavelet kernels :math:`\Psi^j_{\ell m}` and scaling kernel :math:`\Phi_{\ell m}` in harmonic space. - Specifically, these kernels are derived in `[1] `_, + Specifically, these kernels are derived in `[1] `_, where the wavelet kernels are defined (15) for scale :math:`j` to be .. math:: \Psi^j_{\ell m} \equiv \sqrt{\frac{2\ell+1}{4\pi}} \kappa_{\lambda}(\frac{\ell}{\lambda^j})\delta_{m0}, - where :math:`\kappa_{\lambda} = \sqrt{k_{\lambda}(t/\lambda) - k_{\lambda}(t)}` for :math:`k_{\lambda}` + where :math:`\kappa_{\lambda} = \sqrt{k_{\lambda}(t/\lambda) - k_{\lambda}(t)}` for :math:`k_{\lambda}` given in :func:`~k_lam`. Similarly, the scaling kernel is defined (16) as .. math:: \Phi_{\ell m} \equiv \sqrt{\frac{2\ell+1}{4\pi}} \nu_{\lambda} (\frac{\ell}{\lambda^{J_0}})\delta_{m0}, - where :math:`\nu_{\lambda} = \sqrt{k_{\lambda}(t)}` for :math:`k_{\lambda}` given in :func:`~k_lam`. - Notice that :math:`\delta_{m0}` enforces that these kernels are axisymmetric, i.e. coefficients - for :math:`m \not = \ell` are zero. In this implementation the normalisation constant has been + where :math:`\nu_{\lambda} = \sqrt{k_{\lambda}(t)}` for :math:`k_{\lambda}` given in :func:`~k_lam`. + Notice that :math:`\delta_{m0}` enforces that these kernels are axisymmetric, i.e. coefficients + for :math:`m \not = \ell` are zero. In this implementation the normalisation constant has been omitted as it is nulled in subsequent functions. Args: @@ -36,16 +37,16 @@ def filters_axisym( J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - lam (float, optional): Wavelet parameter which determines the scale factor between - consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. Raises: ValueError: J_min is negative or greater than J. Returns: - Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` - with shape :math:`[(J+1)L]`, and scaling kernel :math:`\Phi_{\el m}` with shape + Tuple[np.ndarray, np.ndarray]: Unnormalised wavelet kernels :math:`\Psi^j_{\ell m}` + with shape :math:`[(J+1)L]`, and scaling kernel :math:`\Phi_{\el m}` with shape :math:`[L]` in harmonic space. Note: @@ -87,10 +88,11 @@ def filters_directional( lam: float = 2.0, spin: int = 0, spin0: int = 0, + using_torch: bool = False, ) -> Tuple[np.ndarray, np.ndarray]: r"""Generates the harmonic coefficients for the directional tiling wavelets. - This implementation is based on equation 36 in the wavelet computation paper + This implementation is based on equation 36 in the wavelet computation paper `[1] `_. Args: @@ -100,16 +102,18 @@ def filters_directional( J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - lam (float, optional): Wavelet parameter which determines the scale factor between - consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. spin (int, optional): Spin (integer) to perform the transform. Defaults to 0. spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. + using_torch (bool, optional): Desired frontend functionality. Defaults to False. + Returns: - Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels + Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`) Notes: @@ -144,6 +148,9 @@ def filters_directional( psi[j, el, L - 1 + m] *= ( spin_normalization(el, spin0) * (-1) ** spin0 ) + if using_torch: + psi = torch.from_numpy(psi) + phi = torch.from_numpy(phi) return psi, phi @@ -192,6 +199,7 @@ def filters_directional_vectorised( lam: float = 2.0, spin: int = 0, spin0: int = 0, + using_torch: bool = False, ) -> Tuple[np.ndarray, np.ndarray]: r"""Vectorised version of :func:`~filters_directional`. @@ -210,8 +218,10 @@ def filters_directional_vectorised( spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. + using_torch (bool, optional): Desired frontend functionality. Defaults to False. + Returns: - Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels + Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`). """ el_min = max(abs(spin), abs(spin0)) @@ -234,6 +244,11 @@ def filters_directional_vectorised( kappa0[:el_min] = 0 kappa[:, :el_min, :] = 0 + + if using_torch: + kappa0 = torch.from_numpy(kappa0) + kappa = torch.from_numpy(kappa) + return kappa, kappa0 @@ -302,15 +317,13 @@ def filters_directional_jax( spin0 (int, optional): Spin number the wavelet was lowered from. Defaults to 0. Returns: - Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels + Tuple[np.ndarray, np.ndarray]: Tuple of wavelet and scaling kernels (:math:`\Psi^j_{\ell n}`, :math:`\Phi_{\ell m}`). """ el_min = max(abs(spin), abs(spin0)) spin_norms = ( - (-1) ** spin0 * spin_normalization_jax(np.arange(L), spin0) - if spin0 != 0 - else 1 + (-1) ** spin0 * spin_normalization_jax(np.arange(L), spin0) if spin0 != 0 else 1 ) kappa, kappa0 = filters_axisym_jax(L, J_min, lam) @@ -332,6 +345,7 @@ def filters_directional_jax( return kappa, kappa0 + def tiling_integrand(t: float, lam: float = 2.0) -> float: r"""Tiling integrand for scale-discretised wavelets `[1] `_. @@ -463,7 +477,7 @@ def k_lam(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: @partial(jit, static_argnums=(2, 3)) # not sure def part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float: - r"""JAX version of part_scaling_fn. Computes integral used to calculate smoothly + r"""JAX version of part_scaling_fn. Computes integral used to calculate smoothly decreasing function :math:`k_{\lambda}`. Intermediate step used to compute the wavelet and scaling function generating @@ -503,7 +517,7 @@ def part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float: @partial(jit, static_argnums=(0, 1, 2)) def k_lam_jax(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: - r"""JAX version of k_lam. Compute function :math:`k_{\lambda}` used as a wavelet + r"""JAX version of k_lam. Compute function :math:`k_{\lambda}` used as a wavelet generating function. Specifically, this function is derived in [1] and is given by @@ -561,6 +575,7 @@ def k_lam_jax(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: return k + def tiling_direction(L: int, N: int = 1) -> np.ndarray: r"""Generates the harmonic coefficients for the directionality component of the tiling functions. @@ -603,7 +618,8 @@ def tiling_direction(L: int, N: int = 1) -> np.ndarray: for m in range(-el, el + 1): if abs(m) < N and (N + m) % 2: s_elm[el, L - 1 + m] = nu * np.sqrt( - (samples.binomial_coefficient(gamma, ((gamma - m) / 2))) / (2**gamma) + (samples.binomial_coefficient(gamma, ((gamma - m) / 2))) + / (2**gamma) ) else: s_elm[el, L - 1 + m] = 0.0 @@ -649,7 +665,7 @@ def spin_normalization_vectorised(el: np.ndarray, spin: int = 0) -> float: @partial(jit, static_argnums=(0, 1)) def tiling_direction_jax(L: int, N: int = 1) -> np.ndarray: - r"""JAX version of tiling_direction. Generates the harmonic coefficients for the + r"""JAX version of tiling_direction. Generates the harmonic coefficients for the directionality component of the tiling functions. Formally, this function implements the follow equation @@ -708,4 +724,4 @@ def spin_normalization_jax(el: np.ndarray, spin: int = 0) -> float: """ factor = jnp.arange(-abs(spin) + 1, abs(spin) + 1).reshape(1, 2 * abs(spin) + 1) factor = el.reshape(len(el), 1).dot(factor) - return jnp.sqrt(jnp.prod(factor, axis=1) ** (jnp.sign(spin))) \ No newline at end of file + return jnp.sqrt(jnp.prod(factor, axis=1) ** (jnp.sign(spin))) diff --git a/s2wav/samples.py b/s2wav/samples.py index 0c55264..d6548fe 100644 --- a/s2wav/samples.py +++ b/s2wav/samples.py @@ -1,6 +1,7 @@ from jax import jit import jax.numpy as jnp import numpy as np +import torch import math from functools import partial from typing import Tuple @@ -8,6 +9,7 @@ from scipy.special import loggamma from jax.scipy.special import gammaln as jax_gammaln + def f_scal( L: int, J_min: int = 0, @@ -210,7 +212,7 @@ def construct_f( optimise for memory. Defaults to False. Returns: - Tuple[int, int, int, int]: Wavelet coefficients shape :math:`[n_{J}, L, 2L-1, n_{N}]`. + np.ndarray: Empty array (or list of empty arrays) in which to write data. """ J = j_max(L, lam) if scattering: @@ -267,7 +269,7 @@ def construct_f_jax( optimise for memory. Defaults to False. Returns: - Tuple[int, int, int, int]: Wavelet coefficients shape :math:`[n_{J}, L, 2L-1, n_{N}]`. + jnp.ndarray: Empty array (or list of empty arrays) in which to write data. """ J = j_max(L, lam) if scattering: @@ -287,6 +289,62 @@ def construct_f_jax( return f +def construct_f_torch( + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + sampling: str = "mw", + nside: int = None, + multiresolution: bool = False, + scattering: bool = False, +) -> torch.tensor: + """Defines a list of tensors corresponding to f_wav. + + Args: + L (int): Harmonic bandlimit. + + N (int, optional): Upper orientational band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + wavelets. Defaults to 2. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. + Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if + sampling="healpix". Defaults to None. + + multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` + resolution or its own resolution. Defaults to False. + + scattering (bool, optional): Whether to create minimal arrays for scattering transform to + optimise for memory. Defaults to False. + + Returns: + torch.tensor: Empty tensor (or list of empty tensors) in which to write data. + """ + J = j_max(L, lam) + if scattering: + f = torch.zeros( + f_wav_j(L, J - 1, N, lam, sampling, nside, multiresolution), + dtype=torch.complex128, + ) + else: + f = [] + for j in range(J_min, J + 1): + f.append( + torch.zeros( + f_wav_j(L, j, N, lam, sampling, nside, multiresolution), + dtype=torch.complex128, + ) + ) + return f + + def construct_flm( L: int, J_min: int = 0, lam: float = 2.0, multiresolution: bool = False ) -> Tuple[int, int]: @@ -455,7 +513,7 @@ def construct_flmn( optimise for memory. Defaults to False. Returns: - Tuple[int, int, int, int]: Wavelet coefficients shape :math:`[n_{J}, L, 2L-1, n_{N}]`. + np.ndarray: Empty array (or list of empty arrays) in which to write data. """ J = j_max(L, lam) if scattering: @@ -503,7 +561,7 @@ def construct_flmn_jax( optimise for memory. Defaults to False. Returns: - Tuple[int, int, int, int]: Wavelet coefficients shape :math:`[n_{J}, L, 2L-1, n_{N}]`. + jnp.ndarray: Empty array (or list of empty arrays) in which to write data. """ J = j_max(L, lam) if scattering: @@ -522,6 +580,53 @@ def construct_flmn_jax( return flmn +def construct_flmn_torch( + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + multiresolution: bool = False, + scattering: bool = False, +) -> torch.tensor: + """Defines a list of tensors corresponding to flmn. + + Args: + L (int): Harmonic bandlimit. + + N (int, optional): Upper orientational band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between + consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic + wavelets. Defaults to 2. + + multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` + resolution or its own resolution. Defaults to False. + + scattering (bool, optional): Whether to create minimal arrays for scattering transform to + optimise for memory. Defaults to False. + + Returns: + torch.tensor: Empty tensor (or list of empty tensors) in which to write data. + """ + J = j_max(L, lam) + if scattering: + flmn = torch.zeros( + flmn_wav_j(L, J - 1, N, lam, multiresolution), dtype=torch.complex128 + ) + else: + flmn = [] + for j in range(J_min, J + 1): + flmn.append( + torch.zeros( + flmn_wav_j(L, j, N, lam, multiresolution), + dtype=torch.complex128, + ) + ) + return flmn + + def j_max(L: int, lam: float = 2.0) -> int: r"""Computes needlet maximum level required to ensure exact reconstruction. @@ -588,6 +693,7 @@ def wavelet_shape_check( L, j, N, lam, sampling, nside, multiresolution ) + def binomial_coefficient(n: int, k: int) -> int: r"""Computes the binomial coefficient :math:`\binom{n}{k}`. @@ -617,4 +723,4 @@ def binomial_coefficient_jax(n: int, k: int) -> int: """ return jnp.floor( 0.5 + jnp.exp(jax_gammaln(n + 1) - jax_gammaln(k + 1) - jax_gammaln(n - k + 1)) - ) \ No newline at end of file + ) diff --git a/s2wav/transforms/base.py b/s2wav/transforms/base.py index d8e4849..931c101 100644 --- a/s2wav/transforms/base.py +++ b/s2wav/transforms/base.py @@ -16,7 +16,7 @@ def synthesis_looped( sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, + multiresolution: bool = True, ) -> np.ndarray: r"""Computes the synthesis directional wavelet transform [1,2]. Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` by summing the contributions from wavelet and scaling coefficients in harmonic space, see equation 27 from `[2] `_. @@ -38,7 +38,7 @@ def synthesis_looped( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. + resolution or its own resolution. Defaults to True. Raises: AssertionError: Shape of wavelet/scaling coefficients incorrect. Returns: @@ -99,7 +99,7 @@ def synthesis( sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, + multiresolution: bool = True, ) -> np.ndarray: r"""Computes the synthesis directional wavelet transform [1,2]. Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` by summing the contributions from wavelet and scaling coefficients in harmonic space, see equation 27 from `[2] `_. @@ -121,7 +121,7 @@ def synthesis( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. + resolution or its own resolution. Defaults to True. Raises: AssertionError: Shape of wavelet/scaling coefficients incorrect. Returns: @@ -176,36 +176,26 @@ def analysis_looped( sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, + multiresolution: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. Args: f (np.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. - L (int): Harmonic bandlimit. - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - spin (int, optional): Spin (integer) of input signal. Defaults to 0. - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". - nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. - reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. + resolution or its own resolution. Defaults to True. Returns: f_wav (np.ndarray): Array of wavelet pixel-space coefficients @@ -280,38 +270,27 @@ def analysis( sampling: str = "mw", nside: int = None, reality: bool = False, - multiresolution: bool = False, + multiresolution: bool = True, scattering: bool = False, ) -> Tuple[np.ndarray, np.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. Args: f (np.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. - L (int): Harmonic bandlimit. - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - spin (int, optional): Spin (integer) of input signal. Defaults to 0. - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". - nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. - reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - + resolution or its own resolution. Defaults to True. scattering (bool, optional): If using for scattering transform return absolute value of scattering coefficients. diff --git a/s2wav/transforms/construct.py b/s2wav/transforms/construct.py index 5a5ad01..42a2497 100644 --- a/s2wav/transforms/construct.py +++ b/s2wav/transforms/construct.py @@ -1,5 +1,7 @@ +import torch +import numpy as np import jax.numpy as jnp -from typing import List +from typing import List import s2fft from s2fft.precompute_transforms.construct import ( wigner_kernel_jax, @@ -7,6 +9,7 @@ ) from s2wav import samples + def generate_full_precomputes( L: int, N: int, @@ -17,6 +20,7 @@ def generate_full_precomputes( forward: bool = False, reality: bool = False, nospherical: bool = False, + using_torch: bool = False, ) -> List[jnp.ndarray]: r"""Generates a list of precompute arrays associated with the underlying Wigner transforms. @@ -45,6 +49,8 @@ def generate_full_precomputes( nospherical (bool, optional): Whether to only compute Wigner precomputes. Defaults to False. + using_torch (bool, optional): Desired frontend functionality. Defaults to False. + Returns: List[jnp.ndarray]: Precomputed recursion arrays for underlying Wigner transforms. """ @@ -52,7 +58,10 @@ def generate_full_precomputes( J = samples.j_max(L, lam) for j in range(J_min, J): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) - precomps.append(wigner_kernel_jax(Lj, Nj, reality, sampling, nside, forward)) + kernel = wigner_kernel_jax(Lj, Nj, reality, sampling, nside, forward) + precomps.append( + torch.from_numpy(np.array(kernel)) + ) if using_torch else precomps.append(kernel) Ls = samples.scal_bandlimit(L, J_min, lam, True) if nospherical: return [], [], precomps @@ -62,8 +71,14 @@ def generate_full_precomputes( precompute_full = spin_spherical_kernel_jax( L, 0, reality, sampling, nside, not forward ) + + if using_torch: + precompute_full = torch.from_numpy(np.array(precompute_full)) + precompute_scaling = torch.from_numpy(np.array(precompute_scaling)) + return precompute_full, precompute_scaling, precomps + def generate_wigner_precomputes( L: int, N: int, @@ -72,7 +87,8 @@ def generate_wigner_precomputes( sampling: str = "mw", nside: int = None, forward: bool = False, - reality: bool = False + reality: bool = False, + using_torch: bool = False, ) -> List[jnp.ndarray]: r"""Generates a list of precompute arrays associated with the underlying Wigner transforms. @@ -98,6 +114,8 @@ def generate_wigner_precomputes( reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits conjugate symmetry of harmonic coefficients. Defaults to False. + using_torch (bool, optional): Desired frontend functionality. Defaults to False. + Returns: List[jnp.ndarray]: Precomputed recursion arrays for underlying Wigner transforms. """ @@ -105,9 +123,10 @@ def generate_wigner_precomputes( J = samples.j_max(L, lam) for j in range(J_min, J + 1): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) - precomps.append( - s2fft.generate_precomputes_wigner_jax( - Lj, Nj, sampling, nside, forward, reality, L0j - ) + kernel = s2fft.generate_precomputes_wigner_jax( + Lj, Nj, sampling, nside, forward, reality, L0j ) - return precomps \ No newline at end of file + precomps.append( + torch.from_numpy(np.array(kernel)) + ) if using_torch else precomps.append(kernel) + return precomps diff --git a/s2wav/transforms/pre_wav_torch.py b/s2wav/transforms/pre_wav_torch.py new file mode 100644 index 0000000..53f8ae2 --- /dev/null +++ b/s2wav/transforms/pre_wav_torch.py @@ -0,0 +1,294 @@ +import torch +from typing import Tuple, List +from s2fft.precompute_transforms import wigner, spherical +from s2wav import samples + + +def synthesis( + f_wav: torch.tensor, + f_scal: torch.tensor, + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + sampling: str = "mw", + nside: int = None, + reality: bool = False, + filters: Tuple[torch.tensor] = None, + precomps: List[List[torch.tensor]] = None, +) -> torch.tensor: + r"""Computes the synthesis directional wavelet transform [1,2]. + Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` + by summing the contributions from wavelet and scaling coefficients in harmonic space, + see equation 27 from `[2] `_. + Args: + f_wav (torch.tensor): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + + f_scal (torch.tensor): Array of scaling pixel-space coefficients + with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. + + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates + dyadic wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) of input signal. Defaults to 0. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", + "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if + sampling="healpix". Defaults to None. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + filters (Tuple[torch.tensor], optional): Precomputed wavelet filters. Defaults to None. + + precomps (List[torch.tensor]): Precomputed list of recursion coefficients. At most + of length :math:`L^2`, which is a minimal memory overhead. + + Raises: + AssertionError: Shape of wavelet/scaling coefficients incorrect. + + Returns: + torch.tensor: Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + Notes: + [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013. + [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). + """ + if precomps == None: + raise ValueError("Must provide precomputed kernels for this transform!") + + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) + flm = torch.zeros((L, 2 * L - 1), dtype=torch.complex128) + f_scal_lm = spherical.forward_transform_torch( + f_scal, precomps[1], Ls, sampling, reality, spin, nside + ) + + # Sum the all wavelet wigner coefficients for each lmn + # Note that almost the entire compute is concentrated at the highest two scales. + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + shift = 0 if j < J else -1 + temp = wigner.forward_transform_torch( + f_wav[j - J_min], + precomps[2][j - J_min + shift], + Lj, + Nj, + sampling, + reality, + nside, + ) + flm[L0j:Lj, L - Lj : L - 1 + Lj] += torch.einsum( + "ln,nlm->lm", + filters[0][j, L0j:Lj, L - Nj : L - 1 + Nj : 2], + temp[::2, L0j:, :], + ) + + # Sum the all scaling harmonic coefficients for each lm + phi = filters[1][:Ls] * torch.sqrt( + 4 * torch.pi / (2 * torch.arange(Ls, dtype=torch.float64) + 1) + ) + flm[:Ls, L - Ls : L - 1 + Ls] += torch.einsum("lm,l->lm", f_scal_lm, phi) + return spherical.inverse_transform_torch( + flm, precomps[0], L, sampling, reality, spin, nside + ) + + +def analysis( + f: torch.tensor, + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + sampling: str = "mw", + nside: int = None, + reality: bool = False, + filters: Tuple[torch.tensor] = None, + precomps: List[List[torch.tensor]] = None, +) -> Tuple[torch.tensor]: + r"""Wavelet analysis from pixel space to wavelet space for complex signals. + + Args: + f (torch.tensor): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) of input signal. Defaults to 0. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults + to None. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + filters (Tuple[torch.tensor], optional): Precomputed wavelet filters. Defaults to None. + + precomps (List[torch.tensor]): Precomputed list of recursion coefficients. At most + of length :math:`L^2`, which is a minimal memory overhead. + + Returns: + f_wav (torch.tensor): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + + f_scal (torch.tensor): Array of scaling pixel-space coefficients + with shape :math:`[n_{\theta}, n_{\phi}]`. + """ + if precomps == None: + raise ValueError("Must provide precomputed kernels for this transform!") + + J = samples.j_max(L, lam) + Ls = samples.scal_bandlimit(L, J_min, lam, True) + + f_wav_lmn = samples.construct_flmn_torch(L, N, J_min, lam, True) + f_wav = samples.construct_f_torch(L, N, J_min, lam, sampling, nside, True) + + wav_lm = torch.einsum( + "jln, l->jln", + torch.conj(filters[0]), + 8 * torch.pi**2 / (2 * torch.arange(L, dtype=torch.float64) + 1), + ) + + flm = spherical.forward_transform_torch( + f, precomps[0], L, sampling, reality, spin, nside + ) + # Project all wigner coefficients for each lmn onto wavelet coefficients + # Note that almost the entire compute is concentrated at the highest J + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + f_wav_lmn[j - J_min][::2, L0j:] += torch.einsum( + "lm,ln->nlm", + flm[L0j:Lj, L - Lj : L - 1 + Lj], + wav_lm[j, L0j:Lj, L - Nj : L - 1 + Nj : 2], + ) + + shift = 0 if j < J else -1 + f_wav[j - J_min] = wigner.inverse_transform_torch( + f_wav_lmn[j - J_min], + precomps[2][j - J_min + shift], + Lj, + Nj, + sampling, + reality, + nside, + ) + + # Project all harmonic coefficients for each lm onto scaling coefficients + phi = filters[1][:Ls] * torch.sqrt( + 4 * torch.pi / (2 * torch.arange(Ls, dtype=torch.float64) + 1) + ) + temp = torch.einsum("lm,l->lm", flm[:Ls, L - Ls : L - 1 + Ls], phi) + # Handle edge case + if Ls == 1: + f_scal = temp * torch.sqrt(1 / (4 * torch.pi)) + else: + f_scal = spherical.inverse_transform_torch( + temp, precomps[1], Ls, sampling, reality, spin, nside + ) + return f_wav, f_scal + + +def flm_to_analysis( + flm: torch.tensor, + L: int, + N: int = 1, + J_min: int = 0, + J_max: int = None, + lam: float = 2.0, + sampling: str = "mw", + nside: int = None, + reality: bool = False, + filters: Tuple[torch.tensor] = None, + precomps: List[List[torch.tensor]] = None, +) -> Tuple[torch.tensor]: + r"""Wavelet analysis from pixel space to wavelet space for complex signals. + + Args: + f (torch.tensor): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) of input signal. Defaults to 0. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults + to None. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + filters (Tuple[torch.tensor], optional): Precomputed wavelet filters. Defaults to None. + + precomps (List[torch.tensor]): Precomputed list of recursion coefficients. At most + of length :math:`L^2`, which is a minimal memory overhead. + + Returns: + f_wav (torch.tensor): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + + f_scal (torch.tensor): Array of scaling pixel-space coefficients + with shape :math:`[n_{\theta}, n_{\phi}]`. + """ + if precomps == None: + raise ValueError("Must provide precomputed kernels for this transform!") + + J = J_max if J_max is not None else samples.j_max(L, lam) + + f_wav_lmn = samples.construct_flmn_torch(L, N, J_min, lam, True) + f_wav = samples.construct_f_torch(L, N, J_min, lam, sampling, nside, True) + + wav_lm = torch.einsum( + "jln, l->jln", + torch.conj(filters), + 8 * torch.pi**2 / (2 * torch.arange(L, dtype=torch.float64) + 1), + ) + + # Project all wigner coefficients for each lmn onto wavelet coefficients + # Note that almost the entire compute is concentrated at the highest J + for j in range(J_min, J + 1): + Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) + f_wav_lmn[j - J_min][::2, L0j:] += torch.einsum( + "lm,ln->nlm", + flm[L0j:Lj, L - Lj : L - 1 + Lj], + wav_lm[j, L0j:Lj, L - Nj : L - 1 + Nj : 2], + ) + shift = 0 if j < J else -1 + f_wav[j - J_min] = wigner.inverse_transform_torch( + f_wav_lmn[j - J_min], + precomps[2][j - J_min + shift], + Lj, + Nj, + sampling, + reality, + nside, + ) + + return f_wav diff --git a/tests/conftest.py b/tests/conftest.py index 6d194a2..75de5bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from functools import partial from typing import Tuple import numpy as np +import torch import pytest import s2fft from s2fft import base_transforms as base @@ -38,9 +39,12 @@ def generate_f_wav_scal( lam: float, sampling: str = "mw", reality: bool = False, + using_torch: bool = False, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: J = samples.j_max(L, lam) flmn = samples.construct_flmn(L, N, J_min, lam, True) + f_wav_s2let = np.zeros(n_wav(L, N, J_min, lam, True), dtype=np.complex128) + offset = 0 f_wav = [] for j in range(J_min, J + 1): @@ -52,7 +56,17 @@ def generate_f_wav_scal( flmn[j - J_min][Nj - 1 + n, el, Lj - 1 + m] = ( rng.uniform() + 1j * rng.uniform() ) - f_wav.append(base.wigner.inverse(flmn[j - J_min], Lj, Nj, 0, sampling, reality)) + temp = base.wigner.inverse(flmn[j - J_min], Lj, Nj, 0, sampling, reality) + + # Pys2let data entries + entries = temp.flatten("C") + f_wav_s2let[offset : offset + len(entries)] = entries + offset += len(entries) + + # S2wav data entries + if using_torch: + temp = torch.from_numpy(temp) + f_wav.append(temp) L_s = samples.scal_bandlimit(L, J_min, lam, True) flm = np.zeros((L_s, 2 * L_s - 1), dtype=np.complex128) @@ -64,8 +78,8 @@ def generate_f_wav_scal( return ( f_wav, - f_scal, - s2wav_to_s2let(f_wav, L, N, J_min, lam, True), + torch.from_numpy(f_scal) if using_torch else f_scal, + f_wav_s2let, f_scal.flatten("C"), ) @@ -76,16 +90,18 @@ def s2wav_to_s2let( N: int = 1, J_min: int = 0, lam: float = 2.0, - multiresolution: bool = False, -) -> int: + using_torch: bool = False, +) -> np.ndarray: J = samples.j_max(L, lam) - f_wav_s2let = np.zeros( - n_wav(L, N, J_min, lam, multiresolution), dtype=np.complex128 - ) + f_wav_s2let = np.zeros(n_wav(L, N, J_min, lam, True), dtype=np.complex128) offset = 0 for j in range(J_min, J + 1): - entries = f_wav[j - J_min].flatten("C") + entries = ( + f_wav[j - J_min].numpy().flatten("C") + if using_torch + else f_wav[j - J_min].flatten("C") + ) f_wav_s2let[offset : offset + len(entries)] = entries offset += len(entries) return f_wav_s2let diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 97aa88e..6783fdf 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -8,60 +8,45 @@ L_to_test = [8] N_to_test = [3] J_min_to_test = [2] -multiresolution = [False, True] reality = [False, True] sampling_to_test = ["mw", "mwss", "dh"] recursive_transform = [False, True] + @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) def test_jax_synthesis_gradients( - flm_generator, + wavelet_generator, L: int, N: int, J_min: int, - multiresolution: bool, reality: bool, - recursive: bool + recursive: bool, ): J = samples.j_max(L) + + # Exceptions if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") # Generate wavelet filters filter = filters.filters_directional_vectorised(L, N, J_min) - generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes - synthesis = rec_wav_jax.synthesis if recursive else pre_wav_jax.synthesis - precomps = generator( - L, - N, - J_min, - forward=True, - reality=reality, - multiresolution=multiresolution, + generator = ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes ) + synthesis = rec_wav_jax.synthesis if recursive else pre_wav_jax.synthesis + precomps = generator(L, N, J_min, forward=True, reality=reality) # Generate random signal - flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = s2fft.inverse_jax(flm, L) - f_wav, f_scal = rec_wav_jax.analysis( - f, - L, - N, - J_min, - multiresolution=multiresolution, - reality=reality, - filters=filter, + f_wav, f_scal, _, _ = wavelet_generator( + L=L, N=N, J_min=J_min, lam=2.0, reality=reality ) - # Generate target signal - flm_target = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f_target = s2fft.inverse_jax(flm_target, L) - def func(f_wav, f_scal): f = synthesis( f_wav, @@ -69,12 +54,11 @@ def func(f_wav, f_scal): L, N, J_min, - multiresolution=multiresolution, reality=reality, filters=filter, precomps=precomps, ) - return jnp.sum(jnp.abs(f - f_target) ** 2) + return jnp.sum(jnp.abs(f) ** 2) check_grads( func, @@ -90,7 +74,6 @@ def func(f_wav, f_scal): @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) def test_jax_analysis_gradients( @@ -98,9 +81,8 @@ def test_jax_analysis_gradients( L: int, N: int, J_min: int, - multiresolution: bool, reality: bool, - recursive: bool + recursive: bool, ): J = samples.j_max(L) if J_min >= J: @@ -108,16 +90,13 @@ def test_jax_analysis_gradients( # Generate wavelet filters filter = filters.filters_directional_vectorised(L, N, J_min) - generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes - analysis = rec_wav_jax.analysis if recursive else pre_wav_jax.analysis - precomps = generator( - L, - N, - J_min, - forward=False, - reality=reality, - multiresolution=multiresolution, + generator = ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes ) + analysis = rec_wav_jax.analysis if recursive else pre_wav_jax.analysis + precomps = generator(L, N, J_min, forward=False, reality=reality) # Generate random signal flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) @@ -127,25 +106,12 @@ def test_jax_analysis_gradients( flm_target = flm_generator(L=L, L_lower=0, spin=0, reality=reality) f_target = s2fft.inverse_jax(flm_target, L) f_wav_target, f_scal_target = rec_wav_jax.analysis( - f_target, - L, - N, - J_min, - multiresolution=multiresolution, - reality=reality, - filters=filter, + f_target, L, N, J_min, reality=reality, filters=filter ) def func(f): f_wav, f_scal = analysis( - f, - L, - N, - J_min, - multiresolution=multiresolution, - reality=reality, - filters=filter, - precomps=precomps + f, L, N, J_min, reality=reality, filters=filter, precomps=precomps ) loss = jnp.sum(jnp.abs(f_scal - f_scal_target) ** 2) for j in range(J - J_min): diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index 31eb574..2be1759 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -1,8 +1,9 @@ import pytest import numpy as np +import torch import pys2let as s2let from s2fft import base_transforms as sht_base -from s2wav.transforms import rec_wav_jax, pre_wav_jax, construct +from s2wav.transforms import rec_wav_jax, pre_wav_jax, pre_wav_torch, construct from s2wav import filters, samples L_to_test = [8] @@ -12,6 +13,8 @@ reality = [False, True] sampling_to_test = ["mw", "mwss", "dh"] recursive_transform = [False, True] +using_torch_frontend = [False, True] + @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("N", N_to_test) @@ -19,6 +22,7 @@ @pytest.mark.parametrize("lam", lam_to_test) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) +@pytest.mark.parametrize("using_torch", using_torch_frontend) def test_jax_synthesis( wavelet_generator, L: int, @@ -26,18 +30,19 @@ def test_jax_synthesis( J_min: int, lam: int, reality: bool, - recursive: bool + recursive: bool, + using_torch: bool, ): J = samples.j_max(L, lam) + + # Exceptions if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") + if recursive and using_torch: + pytest.skip("Recursive transform not yet available for torch frontend") f_wav, f_scal, f_wav_s2let, f_scal_s2let = wavelet_generator( - L=L, - N=N, - J_min=J_min, - lam=lam, - reality=reality, + L=L, N=N, J_min=J_min, lam=lam, reality=reality, using_torch=using_torch ) f = s2let.synthesis_wav2px( @@ -51,18 +56,24 @@ def test_jax_synthesis( upsample=False, ) - filter = filters.filters_directional_vectorised(L, N, J_min, lam) - generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes - synthesis = rec_wav_jax.synthesis if recursive else pre_wav_jax.synthesis + filter = filters.filters_directional_vectorised( + L, N, J_min, lam, using_torch=using_torch + ) + generator = ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes + ) + synthesis = ( + rec_wav_jax.synthesis + if recursive + else (pre_wav_torch.synthesis if using_torch else pre_wav_jax.synthesis) + ) precomps = generator( - L, - N, - J_min, - lam, - forward=True, - reality=reality - ) + L, N, J_min, lam, forward=True, reality=reality, using_torch=using_torch + ) + f_check = synthesis( f_wav, f_scal, @@ -72,8 +83,12 @@ def test_jax_synthesis( lam, reality=reality, filters=filter, - precomps=precomps + precomps=precomps, ) + + if using_torch: + f_check = f_check.resolve_conj().numpy() + f = np.real(f) if reality else f np.testing.assert_allclose(f, f_check.flatten("C"), atol=1e-14) @@ -84,6 +99,7 @@ def test_jax_synthesis( @pytest.mark.parametrize("lam", lam_to_test) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) +@pytest.mark.parametrize("using_torch", using_torch_frontend) def test_jax_analysis( flm_generator, f_wav_converter, @@ -92,50 +108,60 @@ def test_jax_analysis( J_min: int, lam: int, reality: bool, - recursive: bool + recursive: bool, + using_torch: bool, ): J = samples.j_max(L, lam) + + # Exceptions if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") + if recursive and using_torch: + pytest.skip("Recursive transform not yet available for torch frontend") flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) f = sht_base.spherical.inverse(flm, L, reality=reality) f_wav, f_scal = s2let.analysis_px2wav( - f.flatten("C").astype(np.complex128), - lam, - L, - J_min, - N, - spin=0, - upsample=False + f.flatten("C").astype(np.complex128), lam, L, J_min, N, spin=0, upsample=False + ) + filter = filters.filters_directional_vectorised( + L, N, J_min, lam, using_torch=using_torch + ) + generator = ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes + ) + analysis = ( + rec_wav_jax.analysis + if recursive + else (pre_wav_torch.analysis if using_torch else pre_wav_jax.analysis) ) - filter = filters.filters_directional_vectorised(L, N, J_min, lam) - generator = construct.generate_wigner_precomputes if recursive else construct.generate_full_precomputes - analysis = rec_wav_jax.analysis if recursive else pre_wav_jax.analysis precomps = generator( - L, - N, - J_min, - lam, - forward=False, - reality=reality + L, N, J_min, lam, forward=False, reality=reality, using_torch=using_torch ) f_wav_check, f_scal_check = analysis( - f, + torch.from_numpy(f) if using_torch else f, L, N, J_min, lam, reality=reality, filters=filter, - precomps=precomps + precomps=precomps, ) - f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, True) + f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, using_torch) np.testing.assert_allclose(f_wav, f_wav_check, atol=1e-14) - np.testing.assert_allclose(f_scal, f_scal_check.flatten("C"), atol=1e-14) + np.testing.assert_allclose( + f_scal, + f_scal_check.resolve_conj().numpy().flatten("C") + if using_torch + else f_scal_check.flatten("C"), + atol=1e-14, + ) @pytest.mark.parametrize("L", L_to_test) @@ -145,15 +171,11 @@ def test_jax_analysis( @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("sampling", sampling_to_test) def test_jax_round_trip( - flm_generator, - L: int, - N: int, - J_min: int, - lam: int, - reality: bool, - sampling: str + flm_generator, L: int, N: int, J_min: int, lam: int, reality: bool, sampling: str ): J = samples.j_max(L, lam) + + # Exceptions if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") @@ -162,14 +184,7 @@ def test_jax_round_trip( filter = filters.filters_directional_vectorised(L, N, J_min, lam) f_wav, f_scal = rec_wav_jax.analysis( - f, - L, - N, - J_min, - lam, - reality=reality, - sampling=sampling, - filters=filter + f, L, N, J_min, lam, reality=reality, sampling=sampling, filters=filter ) f_check = rec_wav_jax.synthesis( f_wav, @@ -180,7 +195,7 @@ def test_jax_round_trip( lam, sampling=sampling, reality=reality, - filters=filter + filters=filter, ) np.testing.assert_allclose(f, f_check, atol=1e-14) diff --git a/tests/test_wavelets_base.py b/tests/test_wavelets_base.py index 32896cd..4300caf 100644 --- a/tests/test_wavelets_base.py +++ b/tests/test_wavelets_base.py @@ -9,7 +9,6 @@ N_to_test = [2, 3] J_min_to_test = [2] lam_to_test = [2, 3] -multiresolution = [False, True] reality = [False, True] sampling_to_test = ["mw", "mwss", "dh"] @@ -18,45 +17,26 @@ @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) def test_synthesis_looped( - wavelet_generator, - L: int, - N: int, - J_min: int, - lam: int, - multiresolution: bool, - reality: bool, + wavelet_generator, L: int, N: int, J_min: int, lam: int, reality: bool ): J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") f_wav, f_scal, f_wav_s2let, f_scal_s2let = wavelet_generator( - L=L, - N=N, - J_min=J_min, - lam=lam, - multiresolution=multiresolution, - reality=reality, + L=L, N=N, J_min=J_min, lam=lam, reality=reality ) f = s2let.synthesis_wav2px( - f_wav_s2let, - f_scal_s2let, - lam, - L, - J_min, - N, - spin=0, - upsample=not multiresolution, + f_wav_s2let, f_scal_s2let, lam, L, J_min, N, spin=0, upsample=False ) f_check = wav_base.synthesis_looped( - f_wav, f_scal, L, N, J_min, lam, multiresolution=multiresolution + f_wav, f_scal, L, N, J_min, lam, reality=reality ) - + f = np.real(f) if reality else f np.testing.assert_allclose(f, f_check.flatten("C"), atol=1e-14) @@ -64,51 +44,23 @@ def test_synthesis_looped( @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) def test_synthesis_vectorised( - wavelet_generator, - L: int, - N: int, - J_min: int, - lam: int, - multiresolution: bool, - reality: bool, + wavelet_generator, L: int, N: int, J_min: int, lam: int, reality: bool ): J = samples.j_max(L, lam) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") f_wav, f_scal, f_wav_s2let, f_scal_s2let = wavelet_generator( - L=L, - N=N, - J_min=J_min, - lam=lam, - multiresolution=multiresolution, - reality=reality, + L=L, N=N, J_min=J_min, lam=lam, reality=reality ) f = s2let.synthesis_wav2px( - f_wav_s2let, - f_scal_s2let, - lam, - L, - J_min, - N, - spin=0, - upsample=not multiresolution, + f_wav_s2let, f_scal_s2let, lam, L, J_min, N, spin=0, upsample=False ) - f_check = wav_base.synthesis( - f_wav, - f_scal, - L, - N, - J_min, - lam, - multiresolution=multiresolution, - reality=reality, - ) + f_check = wav_base.synthesis(f_wav, f_scal, L, N, J_min, lam, reality=reality) f = np.real(f) if reality else f np.testing.assert_allclose(f, f_check.flatten("C"), atol=1e-14) @@ -117,7 +69,6 @@ def test_synthesis_vectorised( @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) def test_analysis_looped( flm_generator, @@ -126,7 +77,6 @@ def test_analysis_looped( N: int, J_min: int, lam: int, - multiresolution: bool, reality: bool, ): J = samples.j_max(L, lam) @@ -137,18 +87,12 @@ def test_analysis_looped( f = sht_base.spherical.inverse(flm, L, reality=reality) f_wav, f_scal = s2let.analysis_px2wav( - f.flatten("C").astype(np.complex128), - lam, - L, - J_min, - N, - spin=0, - upsample=not multiresolution, + f.flatten("C").astype(np.complex128), lam, L, J_min, N, spin=0, upsample=False ) f_wav_check, f_scal_check = wav_base.analysis_looped( - f, L, N, J_min, lam, reality=reality, multiresolution=multiresolution + f, L, N, J_min, lam, reality=reality ) - f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, multiresolution) + f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam) np.testing.assert_allclose(f_wav, f_wav_check, atol=1e-14) np.testing.assert_allclose(f_scal, f_scal_check.flatten("C"), atol=1e-14) @@ -157,17 +101,9 @@ def test_analysis_looped( @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) def test_analysis_vectorised( - flm_generator, - f_wav_converter, - L: int, - N: int, - J_min: int, - lam: int, - multiresolution: bool, - reality: bool, + flm_generator, f_wav_converter, L: int, N: int, J_min: int, lam: int, reality: bool ): J = samples.j_max(L, lam) if J_min >= J: @@ -177,19 +113,11 @@ def test_analysis_vectorised( f = sht_base.spherical.inverse(flm, L, reality=reality) f_wav, f_scal = s2let.analysis_px2wav( - f.flatten("C").astype(np.complex128), - lam, - L, - J_min, - N, - spin=0, - upsample=not multiresolution, - ) - f_wav_check, f_scal_check = wav_base.analysis( - f, L, N, J_min, lam, multiresolution=multiresolution, reality=reality + f.flatten("C").astype(np.complex128), lam, L, J_min, N, spin=0, upsample=False ) + f_wav_check, f_scal_check = wav_base.analysis(f, L, N, J_min, lam, reality=reality) - f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, multiresolution) + f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam) np.testing.assert_allclose(f_wav, f_wav_check.flatten("C"), atol=1e-14) np.testing.assert_allclose(f_scal, f_scal_check.flatten("C"), atol=1e-14) @@ -198,18 +126,10 @@ def test_analysis_vectorised( @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("sampling", sampling_to_test) def test_looped_round_trip( - flm_generator, - L: int, - N: int, - J_min: int, - lam: int, - multiresolution: bool, - reality: bool, - sampling: str, + flm_generator, L: int, N: int, J_min: int, lam: int, reality: bool, sampling: str ): J = samples.j_max(L, lam) if J_min >= J: @@ -218,30 +138,16 @@ def test_looped_round_trip( nside = int(L / 2) flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f = sht_base.spherical.inverse(flm, L, reality=reality, sampling=sampling, nside=nside) + f = sht_base.spherical.inverse( + flm, L, reality=reality, sampling=sampling, nside=nside + ) f_wav, f_scal = wav_base.analysis_looped( - f, - L, - N, - J_min, - lam, - multiresolution=multiresolution, - reality=reality, - sampling=sampling, - nside=nside, + f, L, N, J_min, lam, reality=reality, sampling=sampling, nside=nside ) f_check = wav_base.synthesis_looped( - f_wav, - f_scal, - L, - N, - J_min, - lam, - multiresolution=multiresolution, - sampling=sampling, - nside=nside, + f_wav, f_scal, L, N, J_min, lam, sampling=sampling, reality=reality, nside=nside ) np.testing.assert_allclose(f, f_check, atol=1e-14) @@ -251,18 +157,10 @@ def test_looped_round_trip( @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) -@pytest.mark.parametrize("multiresolution", multiresolution) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("sampling", sampling_to_test) def test_vectorised_round_trip( - flm_generator, - L: int, - N: int, - J_min: int, - lam: int, - multiresolution: bool, - reality: bool, - sampling: str, + flm_generator, L: int, N: int, J_min: int, lam: int, reality: bool, sampling: str ): J = samples.j_max(L, lam) if J_min >= J: @@ -272,25 +170,11 @@ def test_vectorised_round_trip( f = sht_base.spherical.inverse(flm, L, reality=reality, sampling=sampling) f_wav, f_scal = wav_base.analysis( - f, - L, - N, - J_min, - lam, - multiresolution=multiresolution, - reality=reality, - sampling=sampling, + f, L, N, J_min, lam, reality=reality, sampling=sampling ) f_check = wav_base.synthesis( - f_wav, - f_scal, - L, - N, - J_min, - lam, - multiresolution=multiresolution, - sampling=sampling, + f_wav, f_scal, L, N, J_min, lam, reality=reality, sampling=sampling ) np.testing.assert_allclose(f, f_check, atol=1e-14) From 479b451ce2486b1e32e8771db42532ad6798848d Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Fri, 8 Mar 2024 16:37:10 +0000 Subject: [PATCH 4/8] add GL sampling --- s2wav/samples.py | 28 ++++++++++++++-------------- tests/test_gradients.py | 1 - tests/test_wavelets.py | 2 +- tests/test_wavelets_base.py | 2 +- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/s2wav/samples.py b/s2wav/samples.py index d6548fe..d3971f2 100644 --- a/s2wav/samples.py +++ b/s2wav/samples.py @@ -29,8 +29,8 @@ def f_scal( consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. - Defaults to "mw". + sampling (str, optional): Spherical sampling scheme from + {"mw","mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -62,7 +62,7 @@ def n_wav_scales(L: int, N: int = 1, J_min: int = 0, lam: float = 2.0) -> int: consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. + sampling (str, optional): Spherical sampling scheme from {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". Returns: @@ -110,7 +110,7 @@ def LN_j( consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. + sampling (str, optional): Spherical sampling scheme from {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` @@ -152,8 +152,8 @@ def f_wav_j( consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. - Defaults to "mw". + sampling (str, optional): Spherical sampling scheme from + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -199,8 +199,8 @@ def construct_f( consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. - Defaults to "mw". + sampling (str, optional): Spherical sampling scheme from + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -256,8 +256,8 @@ def construct_f_jax( consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. - Defaults to "mw". + sampling (str, optional): Spherical sampling scheme from + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -312,8 +312,8 @@ def construct_f_torch( consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. - Defaults to "mw". + sampling (str, optional): Spherical sampling scheme from + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -675,8 +675,8 @@ def wavelet_shape_check( consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. - Defaults to "mw". + sampling (str, optional): Spherical sampling scheme from + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 6783fdf..ac1192a 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -9,7 +9,6 @@ N_to_test = [3] J_min_to_test = [2] reality = [False, True] -sampling_to_test = ["mw", "mwss", "dh"] recursive_transform = [False, True] diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index 2be1759..3c99ff5 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -11,7 +11,7 @@ J_min_to_test = [2] lam_to_test = [2, 3] reality = [False, True] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] recursive_transform = [False, True] using_torch_frontend = [False, True] diff --git a/tests/test_wavelets_base.py b/tests/test_wavelets_base.py index 4300caf..00807a6 100644 --- a/tests/test_wavelets_base.py +++ b/tests/test_wavelets_base.py @@ -10,7 +10,7 @@ J_min_to_test = [2] lam_to_test = [2, 3] reality = [False, True] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] @pytest.mark.parametrize("L", L_to_test) From e48424864e9c9b1631ce8832c0793597bdc3b778 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Fri, 8 Mar 2024 16:48:36 +0000 Subject: [PATCH 5/8] update readme pip_readme and docs top level index --- .pip_readme.rst | 71 ++++++++++++++++++++++++++++++------------------- README.md | 36 +++++++++++++++++++++---- docs/index.rst | 24 ++++++++++------- 3 files changed, 88 insertions(+), 43 deletions(-) diff --git a/.pip_readme.rst b/.pip_readme.rst index 1035e08..8613cc9 100644 --- a/.pip_readme.rst +++ b/.pip_readme.rst @@ -1,34 +1,38 @@ -.. image:: https://img.shields.io/badge/GitHub-s2wav-brightgreen.svg?style=flat - :target: https://github.com/astro-informatics/s2wav .. image:: https://github.com/astro-informatics/s2wav/actions/workflows/tests.yml/badge.svg?branch=main :target: https://github.com/astro-informatics/s2wav/actions/workflows/tests.yml -.. image:: https://readthedocs.org/projects/ansicolortags/badge/?version=latest - :target: https://astro-informatics.github.io/s2wav -.. image:: https://codecov.io/gh/astro-informatics/s2wav/branch/main/graph/badge.svg?token=ZES6J4K3KZ +.. image:: https://codecov.io/gh/astro-informatics/s2wav/branch/main/graph/badge.svg?token=ZES6J4K3KZ :target: https://codecov.io/gh/astro-informatics/s2wav -.. image:: https://img.shields.io/badge/License-GPL-blue.svg - :target: http://perso.crans.org/besson/LICENSE.html +.. image:: https://img.shields.io/badge/License-MIT-yellow.svg + :target: https://opensource.org/licenses/MIT .. image:: http://img.shields.io/badge/arXiv-2402.01282-orange.svg?style=flat :target: https://arxiv.org/abs/2402.01282 +.. image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/psf/black +.. image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://colab.research.google.com/drive/15E64EAQ7TIp2a3cCoXtnNgf7Ud9MYjVq?usp=sharing -s2wav +Differentiable and accelerated wavelet transform on the sphere ================================================================================================================= -Add some basic discussion about ``s2wav`` here. +`S2WAV` is a python package for computing wavelet transforms on the sphere +and rotation group, both in JAX and PyTorch. It leverages autodiff to provide differentiable +transforms, which are also deployable on modern hardware accelerators +(e.g. GPUs and TPUs), and can be mapped across multiple accelerators. -Installation -============ +More specifically, `S2WAV` provides support for scale-discretised +wavelet transforms on the sphere and rotation group (for both real and +complex signals), with support for adjoints where needed, and comes with +a variety of different optimisations (e.g. precompute or not, +multi-resolution algorithms) that one may select depending on available +resources and desired angular resolution $L$. `S2WAV` is a sister package of +`_, both of which are part of the `SAX` +project, which aims to provide comprehensive support for differentiable transforms on the +sphere and rotation group. -Add some basic installation instructions here. - Documentation ============= -Link to the full documentation (when deployed). - -Contributors -============ -Author names & Contributors +Read the full documentation `_. Attribution =========== @@ -43,18 +47,29 @@ A BibTeX entry for s2wav is: eprint = "arXiv:2402.01282" } -License -======= +we also request that you cite the following paper -s2wav is released under the GPL-3 license (see `LICENSE.txt `_), +``` +@article{price:s2fft, + author = "Matthew A. Price and Jason D. McEwen", + title = "Differentiable and accelerated spherical harmonic and Wigner transforms", + journal = "Journal of Computational Physics, submitted", + year = "2023", + eprint = "arXiv:2311.14670" +} +``` -.. code-block:: +in which the core underlying algorithms for the spherical harmonic and Wigner transforms +are developed. + +License +======= - s2wav - Copyright (C) 2022 Author names & contributors +We provide this code under an MIT open-source licence with the hope that +it will be of use to a wider community. - This program is released under the GPL-3 license (see LICENSE.txt). +Copyright 2024 Matthew Price, Jessica Whtiney, Alicja Polanska, Jason +McEwen and contributors. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +`S2WAV` is free software made available under the MIT License. For +details see the LICENSE file. diff --git a/README.md b/README.md index b31a142..a23716d 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,8 @@ # Differentiable and accelerated wavelet transform on the sphere -`S2WAV` is a JAX package for computing wavelet transforms on the sphere -and rotation group. It leverages autodiff to provide differentiable +`S2WAV` is a python package for computing wavelet transforms on the sphere +and rotation group, both in JAX and PyTorch. It leverages autodiff to provide differentiable transforms, which are also deployable on modern hardware accelerators (e.g. GPUs and TPUs), and can be mapped across multiple accelerators. @@ -20,10 +20,20 @@ wavelet transforms on the sphere and rotation group (for both real and complex signals), with support for adjoints where needed, and comes with a variety of different optimisations (e.g. precompute or not, multi-resolution algorithms) that one may select depending on available -resources and desired angular resolution $L$. `S2WAV` is a sister package of [`S2FFT`](https://github.com/astro-informatics/s2fft), both of which are part of the `SAX` project, which aims to provide comprehensive support for differentiable transforms on the sphere and rotation group. +resources and desired angular resolution $L$. `S2WAV` is a sister package of +[`S2FFT`](https://github.com/astro-informatics/s2fft), both of which are part of the `SAX` +project, which aims to provide comprehensive support for differentiable transforms on the +sphere and rotation group. ## Wavelet Transform :zap: -`S2WAV` is an updated implementation of the scale-discretised wavelet transform on the sphere, which builds upon the papers of [Leistedt et al 2013](https://arxiv.org/abs/1211.1680) and [McEwen et al 2017](https://arxiv.org/abs/1509.06749). This wavelet transform is designed to have excellent localisation and uncorrelation properties, and has been successfully adopted for various applications e.g. scattering transforms on the sphere [McEwen et al 2022](https://arxiv.org/pdf/2102.02828.pdf). The wavelet dictionary is constructed by tiling the harmonic line with infinitely differentiable Cauchy-Schwartz functions, which can straightforwardly be performed in an efficient multiresolution manner, as in the Euclidean case. This is what the directional wavelet filters look like in pixel space. +`S2WAV` is an updated implementation of the scale-discretised wavelet transform on the +sphere, which builds upon the papers of [Leistedt et al 2013](https://arxiv.org/abs/1211.1680) +and [McEwen et al 2017](https://arxiv.org/abs/1509.06749). This wavelet transform is designed to +have excellent localisation and uncorrelation properties, and has been successfully adopted for +various applications e.g. scattering transforms on the sphere [McEwen et al 2022](https://arxiv.org/pdf/2102.02828.pdf). +The wavelet dictionary is constructed by tiling the harmonic line with infinitely differentiable +Cauchy-Schwartz functions, which can straightforwardly be performed in an efficient multiresolution +manner, as in the Euclidean case. This is what the directional wavelet filters look like in pixel space.

@@ -64,7 +74,8 @@ f_wav, f_scal = s2wav.analysis(f, L, N) # Map back to signal on the sphere f = s2wav.synthesis(f_wav, f_scal, L, N) ``` -however we strongly recommend that the multiresolution argument is set to true, as this will accelerate the transform by a factor of the total number of wavelet scales, which can be around an order of magnitude. +> [!NOTE] +> However we strongly recommend that the multiresolution argument is set to true, as this will accelerate the transform by a factor of the total number of wavelet scales, which can be around an order of magnitude. ## Contributors ✨ We strongly encourage contributions from any interested developers; a @@ -105,6 +116,21 @@ A BibTeX entry for `S2WAV` is: } ``` +we also request that you cite the following paper + +``` +@article{price:s2fft, + author = "Matthew A. Price and Jason D. McEwen", + title = "Differentiable and accelerated spherical harmonic and Wigner transforms", + journal = "Journal of Computational Physics, submitted", + year = "2023", + eprint = "arXiv:2311.14670" +} +``` + +in which the core underlying algorithms for the spherical harmonic and Wigner transforms +are developed. + ## License :memo: Copyright 2024 Matthew Price, Jessica Whtiney, Alicja Polanska, Jason diff --git a/docs/index.rst b/docs/index.rst index 1598320..b3f3700 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,16 +1,20 @@ Differentiable and accelerated spherical wavelets =================================================== -``S2WAV`` is a JAX package for computing wavelet transforms on the sphere and rotation -group. It leverages autodiff to provide differentiable transforms, which are also -deployable on modern hardware accelerators (e.g. GPUs and TPUs), and can be mapped -across multiple accelerators. - -More specifically, ``S2WAV`` provides support for scale-discretised wavelet transforms -on the sphere and rotation group (for both real and complex signals), with support for -adjoints where needed, and comes with a variety of different optimisations (e.g. precompute -or not, multi-resolution algorithms) that one may select depending on available resources -and desired angular resolution :math:`L`. +`S2WAV` is a python package for computing wavelet transforms on the sphere +and rotation group, both in JAX and PyTorch. It leverages autodiff to provide differentiable +transforms, which are also deployable on modern hardware accelerators +(e.g. GPUs and TPUs), and can be mapped across multiple accelerators. + +More specifically, `S2WAV` provides support for scale-discretised +wavelet transforms on the sphere and rotation group (for both real and +complex signals), with support for adjoints where needed, and comes with +a variety of different optimisations (e.g. precompute or not, +multi-resolution algorithms) that one may select depending on available +resources and desired angular resolution $L$. `S2WAV` is a sister package of +`_, both of which are part of the `SAX` +project, which aims to provide comprehensive support for differentiable transforms on the +sphere and rotation group. Wavelet Transform |:zap:| -------------------------- From 4baa256e8478bb0a079c4c645a38ebda5881b628 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Fri, 8 Mar 2024 18:00:28 +0000 Subject: [PATCH 6/8] add torch to docs and update notebooks --- .pip_readme.rst | 20 +++++++++---------- docs/api/transforms/index.rst | 16 +++++++++++++-- docs/api/transforms/pre_wav_torch.rst | 7 +++++++ docs/index.rst | 17 +++++++++++++++- docs/tutorials/index.rst | 9 +++------ .../jax_transform/jax_transforms.nblink | 3 +++ .../numpy_transform/numpy_transforms.nblink | 3 +++ .../wavelet_transforms_precompute.nblink | 3 --- .../torch_transform/torch_transforms.nblink | 3 +++ .../transforms/wavelet_transforms.nblink | 3 --- s2wav/__init__.py | 20 +++++++++++++++++++ s2wav/transforms/pre_wav_torch.py | 5 ++++- tests/test_wavelets.py | 6 +++--- 13 files changed, 86 insertions(+), 29 deletions(-) create mode 100644 docs/api/transforms/pre_wav_torch.rst create mode 100644 docs/tutorials/jax_transform/jax_transforms.nblink create mode 100644 docs/tutorials/numpy_transform/numpy_transforms.nblink delete mode 100644 docs/tutorials/precompute_transforms/wavelet_transforms_precompute.nblink create mode 100644 docs/tutorials/torch_transform/torch_transforms.nblink delete mode 100644 docs/tutorials/transforms/wavelet_transforms.nblink diff --git a/.pip_readme.rst b/.pip_readme.rst index 8613cc9..3da1f9c 100644 --- a/.pip_readme.rst +++ b/.pip_readme.rst @@ -25,7 +25,7 @@ complex signals), with support for adjoints where needed, and comes with a variety of different optimisations (e.g. precompute or not, multi-resolution algorithms) that one may select depending on available resources and desired angular resolution $L$. `S2WAV` is a sister package of -`_, both of which are part of the `SAX` +`S2FFT `_, both of which are part of the `SAX` project, which aims to provide comprehensive support for differentiable transforms on the sphere and rotation group. @@ -49,15 +49,15 @@ A BibTeX entry for s2wav is: we also request that you cite the following paper -``` -@article{price:s2fft, - author = "Matthew A. Price and Jason D. McEwen", - title = "Differentiable and accelerated spherical harmonic and Wigner transforms", - journal = "Journal of Computational Physics, submitted", - year = "2023", - eprint = "arXiv:2311.14670" -} -``` +.. code-block:: + + @article{price:s2fft, + author = "Matthew A. Price and Jason D. McEwen", + title = "Differentiable and accelerated spherical harmonic and Wigner transforms", + journal = "Journal of Computational Physics, submitted", + year = "2023", + eprint = "arXiv:2311.14670" + } in which the core underlying algorithms for the spherical harmonic and Wigner transforms are developed. diff --git a/docs/api/transforms/index.rst b/docs/api/transforms/index.rst index 269d8b7..15ccc2c 100644 --- a/docs/api/transforms/index.rst +++ b/docs/api/transforms/index.rst @@ -36,15 +36,26 @@ Wavelet Transforms - JAX implementation of mapping from pixel to wavelet space (fully precompute). * - :func:`~s2wav.transforms.pre_wav_jax.flm_to_analysis` - JAX implementation of mapping from harmonic to wavelet space (fully precompute). + + .. list-table:: PyTorch transforms + :widths: 25 25 + :header-rows: 1 + + * - :func:`~s2wav.transforms.pre_wav_torch.synthesis` + - PyTorch implementation of mapping from wavelet to pixel space (fully precompute). + * - :func:`~s2wav.transforms.pre_wav_torch.analysis` + - PyTorch implementation of mapping from pixel to wavelet space (fully precompute). + * - :func:`~s2wav.transforms.pre_wav_torch.flm_to_analysis` + - PyTorch implementation of mapping from harmonic to wavelet space (fully precompute). .. list-table:: Matrices precomputations :widths: 25 25 :header-rows: 1 * - :func:`~s2wav.transforms.construct.generate_wigner_precomputes` - - JAX function to generate precompute arrays for underlying Wigner transforms. + - JAX/PyTorch function to generate precompute arrays for underlying Wigner transforms. * - :func:`~s2wav.transforms.construct.generate_full_precomputes` - - JAX function to generate precompute arrays for fully precompute transforms. + - JAX/PyTorch function to generate precompute arrays for fully precompute transforms. .. toctree:: :hidden: @@ -55,4 +66,5 @@ Wavelet Transforms construct rec_wav_jax pre_wav_jax + pre_wav_torch \ No newline at end of file diff --git a/docs/api/transforms/pre_wav_torch.rst b/docs/api/transforms/pre_wav_torch.rst new file mode 100644 index 0000000..bbcd73c --- /dev/null +++ b/docs/api/transforms/pre_wav_torch.rst @@ -0,0 +1,7 @@ +:html_theme.sidebar_secondary.remove: + +************************** +PyTorch Transforms (Precompute) +************************** +.. automodule:: s2wav.transforms.pre_wav_torch + :members: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index b3f3700..8db712f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,7 +12,7 @@ complex signals), with support for adjoints where needed, and comes with a variety of different optimisations (e.g. precompute or not, multi-resolution algorithms) that one may select depending on available resources and desired angular resolution $L$. `S2WAV` is a sister package of -`_, both of which are part of the `SAX` +`S2FFT `_, both of which are part of the `SAX` project, which aims to provide comprehensive support for differentiable transforms on the sphere and rotation group. @@ -71,6 +71,21 @@ A BibTeX entry for ``S2WAV`` is: eprint = "arXiv:2402.01282" } +we also request that you cite the following paper + +.. code-block:: + + @article{price:s2fft, + author = "Matthew A. Price and Jason D. McEwen", + title = "Differentiable and accelerated spherical harmonic and Wigner transforms", + journal = "Journal of Computational Physics, submitted", + year = "2023", + eprint = "arXiv:2311.14670" + } + +in which the core underlying algorithms for the spherical harmonic and Wigner transforms +are developed. + License |:memo:| ----------------- diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 724b32c..4cc8cfa 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -23,14 +23,11 @@ To import and use ``S2WAV`` is as simple follows: # Map back to signal on the sphere f = s2wav.synthesis(f_wav, f_scal, L, N) -Benchmarking |:hourglass_flowing_sand:| -------------------------------------- -TODO: Add table here when results available - .. toctree:: :hidden: :maxdepth: 3 :caption: Jupyter Notebooks - transforms/wavelet_transforms.nblink - precompute_transforms/wavelet_transforms_precompute.nblink + numpy_transform/numpy_transforms.nblink + jax_transform/jax_transforms.nblink + torch_transform/torch_transforms.nblink diff --git a/docs/tutorials/jax_transform/jax_transforms.nblink b/docs/tutorials/jax_transform/jax_transforms.nblink new file mode 100644 index 0000000..879e792 --- /dev/null +++ b/docs/tutorials/jax_transform/jax_transforms.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../notebooks/jax_transform.ipynb" +} \ No newline at end of file diff --git a/docs/tutorials/numpy_transform/numpy_transforms.nblink b/docs/tutorials/numpy_transform/numpy_transforms.nblink new file mode 100644 index 0000000..c7e583c --- /dev/null +++ b/docs/tutorials/numpy_transform/numpy_transforms.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../notebooks/numpy_transform.ipynb" +} \ No newline at end of file diff --git a/docs/tutorials/precompute_transforms/wavelet_transforms_precompute.nblink b/docs/tutorials/precompute_transforms/wavelet_transforms_precompute.nblink deleted file mode 100644 index 2415cd6..0000000 --- a/docs/tutorials/precompute_transforms/wavelet_transforms_precompute.nblink +++ /dev/null @@ -1,3 +0,0 @@ -{ - "path": "../../assets/static_notebooks/example_notebook.ipynb" -} \ No newline at end of file diff --git a/docs/tutorials/torch_transform/torch_transforms.nblink b/docs/tutorials/torch_transform/torch_transforms.nblink new file mode 100644 index 0000000..4f42b3e --- /dev/null +++ b/docs/tutorials/torch_transform/torch_transforms.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../notebooks/torch_transform.ipynb" +} \ No newline at end of file diff --git a/docs/tutorials/transforms/wavelet_transforms.nblink b/docs/tutorials/transforms/wavelet_transforms.nblink deleted file mode 100644 index 2415cd6..0000000 --- a/docs/tutorials/transforms/wavelet_transforms.nblink +++ /dev/null @@ -1,3 +0,0 @@ -{ - "path": "../../assets/static_notebooks/example_notebook.ipynb" -} \ No newline at end of file diff --git a/s2wav/__init__.py b/s2wav/__init__.py index 58d40fc..442123c 100644 --- a/s2wav/__init__.py +++ b/s2wav/__init__.py @@ -1,3 +1,23 @@ +# ~~ Core ~~ from . import filters from . import samples + +# ~~ Aliases ~~ + +# JAX recursive transforms from .transforms.rec_wav_jax import analysis, synthesis, flm_to_analysis + +# Base transforms +from .transforms.base import analysis as analysis_base +from .transforms.base import synthesis as synthesis_base + +# JAX precompute transforms +from .transforms.pre_wav_jax import analysis as analysis_precomp_jax +from .transforms.pre_wav_jax import synthesis as synthesis_precomp_jax + +# PyTorch precompute transforms +from .transforms.pre_wav_torch import analysis as analysis_precomp_torch +from .transforms.pre_wav_torch import synthesis as synthesis_precomp_torch + +# Martix precompute functions +from .transforms import construct diff --git a/s2wav/transforms/pre_wav_torch.py b/s2wav/transforms/pre_wav_torch.py index 53f8ae2..ac6c67b 100644 --- a/s2wav/transforms/pre_wav_torch.py +++ b/s2wav/transforms/pre_wav_torch.py @@ -198,9 +198,12 @@ def analysis( 4 * torch.pi / (2 * torch.arange(Ls, dtype=torch.float64) + 1) ) temp = torch.einsum("lm,l->lm", flm[:Ls, L - Ls : L - 1 + Ls], phi) + # Handle edge case if Ls == 1: - f_scal = temp * torch.sqrt(1 / (4 * torch.pi)) + f_scal = temp * torch.sqrt( + torch.tensor(1 / (4 * torch.pi), dtype=torch.float64) + ) else: f_scal = spherical.inverse_transform_torch( temp, precomps[1], Ls, sampling, reality, spin, nside diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index 3c99ff5..68d60ff 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_torch", using_torch_frontend) -def test_jax_synthesis( +def test_synthesis( wavelet_generator, L: int, N: int, @@ -100,7 +100,7 @@ def test_jax_synthesis( @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_torch", using_torch_frontend) -def test_jax_analysis( +def test_analysis( flm_generator, f_wav_converter, L: int, @@ -170,7 +170,7 @@ def test_jax_analysis( @pytest.mark.parametrize("lam", lam_to_test) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("sampling", sampling_to_test) -def test_jax_round_trip( +def test_round_trip( flm_generator, L: int, N: int, J_min: int, lam: int, reality: bool, sampling: str ): J = samples.j_max(L, lam) From d7f3633f502b3d377d3fa58968bdbe0a1ecaaa8c Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Tue, 9 Apr 2024 13:48:54 +0100 Subject: [PATCH 7/8] overhaul code and add torch + C backend support --- .pip_readme.rst | 6 + README.md | 68 +++++- docs/api/index.rst | 2 +- docs/api/transforms/index.rst | 45 ++-- .../{rec_wav_jax.rst => wavelet.rst} | 2 +- ...pre_wav_jax.rst => wavelet_precompute.rst} | 2 +- ...torch.rst => wavelet_precompute_torch.rst} | 2 +- docs/api/utility/index.rst | 2 +- docs/conf.py | 2 +- docs/index.rst | 9 +- docs/tutorials/index.rst | 20 +- .../jax_ssht_transform/jax_transforms.nblink | 3 + docs/user_guide/install.rst | 3 +- notebooks/jax_ssht_transform.ipynb | 153 +++++++++++++ notebooks/jax_transform.ipynb | 152 +++++++++++++ notebooks/numpy_transform.ipynb | 146 ++++++++++++ notebooks/torch_transform.ipynb | 183 ++++++++++++++++ requirements/requirements-core.txt | 2 +- requirements/requirements-tests.txt | 4 - s2wav/__init__.py | 28 +-- s2wav/transforms/__init__.py | 4 +- .../transforms/{rec_wav_jax.py => wavelet.py} | 207 ++++++++++++++---- .../{pre_wav_jax.py => wavelet_precompute.py} | 17 +- ...v_torch.py => wavelet_precompute_torch.py} | 0 setup.py | 2 +- tests/test_gradients.py | 109 ++++++--- tests/test_wavelets.py | 123 +++++++++-- 27 files changed, 1133 insertions(+), 163 deletions(-) rename docs/api/transforms/{rec_wav_jax.rst => wavelet.rst} (74%) rename docs/api/transforms/{pre_wav_jax.rst => wavelet_precompute.rst} (72%) rename docs/api/transforms/{pre_wav_torch.rst => wavelet_precompute_torch.rst} (70%) create mode 100644 docs/tutorials/jax_ssht_transform/jax_transforms.nblink create mode 100644 notebooks/jax_ssht_transform.ipynb create mode 100644 notebooks/jax_transform.ipynb create mode 100644 notebooks/numpy_transform.ipynb create mode 100644 notebooks/torch_transform.ipynb rename s2wav/transforms/{rec_wav_jax.py => wavelet.py} (67%) rename s2wav/transforms/{pre_wav_jax.py => wavelet_precompute.py} (97%) rename s2wav/transforms/{pre_wav_torch.py => wavelet_precompute_torch.py} (100%) diff --git a/.pip_readme.rst b/.pip_readme.rst index 3da1f9c..602a25e 100644 --- a/.pip_readme.rst +++ b/.pip_readme.rst @@ -29,6 +29,12 @@ resources and desired angular resolution $L$. `S2WAV` is a sister package of project, which aims to provide comprehensive support for differentiable transforms on the sphere and rotation group. +As of version 1.0.0 `S2WAV` also provides partial frontend support for PyTorch. In future +this will be expanded to full support. Also note that this release also provides JAX support +for existing C spherical harmonic libraries, specifically `SSHT`. This works be wrapping +python bindings with custom JAX frontends. Note that currently this C to JAX interoperability +is limited to CPU. + Documentation ============= diff --git a/README.md b/README.md index a23716d..67d7560 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,13 @@ resources and desired angular resolution $L$. `S2WAV` is a sister package of project, which aims to provide comprehensive support for differentiable transforms on the sphere and rotation group. +> [!TIP] +> As of version 1.0.0 `S2WAV` also provides partial frontend support for PyTorch. In future +> this will be expanded to full support. Also note that this release also provides JAX support +> for existing C spherical harmonic libraries, specifically `SSHT`. This works be wrapping +> python bindings with custom JAX frontends. Note that currently this C to JAX interoperability +> is limited to CPU. + ## Wavelet Transform :zap: `S2WAV` is an updated implementation of the scale-discretised wavelet transform on the sphere, which builds upon the papers of [Leistedt et al 2013](https://arxiv.org/abs/1211.1680) @@ -46,28 +53,42 @@ The Python dependencies for the `S2WAV` package are listed in the file into the active python environment by [pip](https://pypi.org) when running ``` bash -pip install . +pip install s2wav ``` +This will install the core functionality which includes JAX support (including PyTorch support). -from the root directory of the repository. Unit tests can then be -executed to ensure the installation was successful by running +Alternatively, the `S2WAV` package may be installed directly from GitHub by cloning this +repository and then running ``` bash -pytest tests/ +pip install . ``` -In the near future one will be able to install `S2WAV` directly from -[PyPi](https://pypi.org) by `pip install s2wav` but this is not yet supported. -Note that to run `JAX` on NVIDIA GPUs you will need to follow the -[guide](https://github.com/google/jax#installation) outlined by Google. +from the root directory. + +Unit tests can then be executed to ensure the installation was successful by first +installing the test requirements and then running pytest + +``` bash +pip install -r requirements/requirements-tests.txt +pytest tests/ +``` + +Documentation for the released version is available [here](https://astro-informatics.github.io/s2wav/). +To build the documentation locally run + +``` bash +pip install -r requirements/requirements-docs.txt +cd docs +make html +open _build/html/index.html +``` ## Usage :rocket: To import and use `S2WAV` is as simple follows: ``` python -import s2wav - # Compute wavelet coefficients f_wav, f_scal = s2wav.analysis(f, L, N) @@ -75,7 +96,32 @@ f_wav, f_scal = s2wav.analysis(f, L, N) f = s2wav.synthesis(f_wav, f_scal, L, N) ``` > [!NOTE] -> However we strongly recommend that the multiresolution argument is set to true, as this will accelerate the transform by a factor of the total number of wavelet scales, which can be around an order of magnitude. +> However we strongly recommend that the multiresolution argument is set to true, as this +> will accelerate the transform by a factor of the total number of wavelet scales, which +> can be around an order of magnitude. + +## C JAX Frontends for SSHT :bulb: + +`S2WAV` also provides JAX support for SSHT, which is a highly optimised C library which +implements the underlying spherical harmonic transforms. This works by wrapping python +bindings with custom JAX frontends. Note that this C to JAX interoperability is currently +limited to CPU. + +For example, one may call these alternate backends for the spherical wavelet transform by: + +``` python +# Compute wavelet coefficients using SSHT C library backend +f_wav, f_scal = s2wav.analysis(f, L, N, use_c_backend=True) + +# Map back to signal on the sphere using SSHT C library backend +f = s2wav.synthesis(f_wav, f_scal, L, N, use_c_backend=True) +``` +These JAX frontends supports out of the box reverse mode automatic differentiation, +and under the hood is simply linking to the C packages you are familiar with. In this +way S2fft enhances existing packages with gradient functionality for modern scientific +computing or machine learning applications! + +For further details on usage see the associated [notebooks](https://astro-informatics.github.io/s2wav/tutorials/index.html). ## Contributors ✨ We strongly encourage contributions from any interested developers; a diff --git a/docs/api/index.rst b/docs/api/index.rst index f9cd230..17bf748 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -9,7 +9,7 @@ directory structure for the software. .. toctree:: :hidden: - :maxdepth: 1 + :maxdepth: 2 :caption: Namespaces transforms/index diff --git a/docs/api/transforms/index.rst b/docs/api/transforms/index.rst index 15ccc2c..7a1fce5 100644 --- a/docs/api/transforms/index.rst +++ b/docs/api/transforms/index.rst @@ -23,35 +23,40 @@ Wavelet Transforms :widths: 25 25 :header-rows: 1 - * - :func:`~s2wav.transforms.rec_wav_jax.synthesis` + * - Function Name + - Description + * - :func:`~s2wav.transforms.wavelet.synthesis` - JAX implementation of mapping from wavelet to pixel space (Recursive). - * - :func:`~s2wav.transforms.rec_wav_jax.analysis` + * - :func:`~s2wav.transforms.wavelet.analysis` - JAX implementation of mapping from pixel to wavelet space (Recursive). - * - :func:`~s2wav.transforms.rec_wav_jax.flm_to_analysis` - - JAX implementation of mapping from harmonic to wavelet space (Recursive). - - * - :func:`~s2wav.transforms.pre_wav_jax.synthesis` + * - :func:`~s2wav.transforms.wavelet.flm_to_analysis` + - JAX implementation of mapping from harmonic to wavelet coefficients only (Recursive). + * - :func:`~s2wav.transforms.wavelet_precompute.synthesis` - JAX implementation of mapping from wavelet to pixel space (fully precompute). - * - :func:`~s2wav.transforms.pre_wav_jax.analysis` + * - :func:`~s2wav.transforms.wavelet_precompute.analysis` - JAX implementation of mapping from pixel to wavelet space (fully precompute). - * - :func:`~s2wav.transforms.pre_wav_jax.flm_to_analysis` - - JAX implementation of mapping from harmonic to wavelet space (fully precompute). - - .. list-table:: PyTorch transforms + * - :func:`~s2wav.transforms.wavelet_precompute.flm_to_analysis` + - JAX implementation of mapping from harmonic to wavelet coefficients only (fully precompute). + +.. list-table:: PyTorch transforms :widths: 25 25 :header-rows: 1 - * - :func:`~s2wav.transforms.pre_wav_torch.synthesis` + * - Function Name + - Description + * - :func:`~s2wav.transforms.wavelet_precompute_torch.synthesis` - PyTorch implementation of mapping from wavelet to pixel space (fully precompute). - * - :func:`~s2wav.transforms.pre_wav_torch.analysis` + * - :func:`~s2wav.transforms.wavelet_precompute_torch.analysis` - PyTorch implementation of mapping from pixel to wavelet space (fully precompute). - * - :func:`~s2wav.transforms.pre_wav_torch.flm_to_analysis` - - PyTorch implementation of mapping from harmonic to wavelet space (fully precompute). - - .. list-table:: Matrices precomputations + * - :func:`~s2wav.transforms.wavelet_precompute_torch.flm_to_analysis` + - PyTorch implementation of mapping from harmonic to wavelet coefficients only (fully precompute). + +.. list-table:: Matrices precomputations :widths: 25 25 :header-rows: 1 + * - Function Name + - Description * - :func:`~s2wav.transforms.construct.generate_wigner_precomputes` - JAX/PyTorch function to generate precompute arrays for underlying Wigner transforms. * - :func:`~s2wav.transforms.construct.generate_full_precomputes` @@ -64,7 +69,7 @@ Wavelet Transforms base construct - rec_wav_jax - pre_wav_jax - pre_wav_torch + wavelet + wavelet_precompute + wavelet_precompute_torch \ No newline at end of file diff --git a/docs/api/transforms/rec_wav_jax.rst b/docs/api/transforms/wavelet.rst similarity index 74% rename from docs/api/transforms/rec_wav_jax.rst rename to docs/api/transforms/wavelet.rst index ad496bb..50c2f16 100644 --- a/docs/api/transforms/rec_wav_jax.rst +++ b/docs/api/transforms/wavelet.rst @@ -3,5 +3,5 @@ ************************** JAX Transforms (Recursive) ************************** -.. automodule:: s2wav.transforms.rec_wav_jax +.. automodule:: s2wav.transforms.wavelet :members: \ No newline at end of file diff --git a/docs/api/transforms/pre_wav_jax.rst b/docs/api/transforms/wavelet_precompute.rst similarity index 72% rename from docs/api/transforms/pre_wav_jax.rst rename to docs/api/transforms/wavelet_precompute.rst index 8ba6fd9..a2cf5f0 100644 --- a/docs/api/transforms/pre_wav_jax.rst +++ b/docs/api/transforms/wavelet_precompute.rst @@ -3,5 +3,5 @@ ************************** JAX Transforms (Precompute) ************************** -.. automodule:: s2wav.transforms.pre_wav_jax +.. automodule:: s2wav.transforms.wavelet_precompute :members: \ No newline at end of file diff --git a/docs/api/transforms/pre_wav_torch.rst b/docs/api/transforms/wavelet_precompute_torch.rst similarity index 70% rename from docs/api/transforms/pre_wav_torch.rst rename to docs/api/transforms/wavelet_precompute_torch.rst index bbcd73c..c05d32f 100644 --- a/docs/api/transforms/pre_wav_torch.rst +++ b/docs/api/transforms/wavelet_precompute_torch.rst @@ -3,5 +3,5 @@ ************************** PyTorch Transforms (Precompute) ************************** -.. automodule:: s2wav.transforms.pre_wav_torch +.. automodule:: s2wav.transforms.wavelet_precompute_torch :members: \ No newline at end of file diff --git a/docs/api/utility/index.rst b/docs/api/utility/index.rst index 619164a..61969a4 100644 --- a/docs/api/utility/index.rst +++ b/docs/api/utility/index.rst @@ -56,4 +56,4 @@ Utility Functions :maxdepth: 3 :caption: Utilities - shapes \ No newline at end of file + samples \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 85f9c48..cc3b662 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -94,7 +94,7 @@ "logo_only": True, "display_version": False, "navbar_align": "left", - "announcement": "s2wav is currently in an open alpha, please provide feedback on GitHub", + "announcement": "s2wav is currently in an open beta, please provide feedback on GitHub", "show_toc_level": 2, "show_nav_level": 1, "header_links_before_dropdown": 5, diff --git a/docs/index.rst b/docs/index.rst index 8db712f..c06627a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,11 +11,18 @@ wavelet transforms on the sphere and rotation group (for both real and complex signals), with support for adjoints where needed, and comes with a variety of different optimisations (e.g. precompute or not, multi-resolution algorithms) that one may select depending on available -resources and desired angular resolution $L$. `S2WAV` is a sister package of +resources and desired angular resolution :math:`L`. `S2WAV` is a sister package of `S2FFT `_, both of which are part of the `SAX` project, which aims to provide comprehensive support for differentiable transforms on the sphere and rotation group. +.. tip:: + As of version 1.0.0 `S2WAV` also provides partial frontend support for PyTorch. In future + this will be expanded to full support. Also note that this release also provides JAX support + for existing C spherical harmonic libraries, specifically `SSHT`. This works be wrapping + python bindings with custom JAX frontends. Note that currently this C to JAX interoperability + is limited to CPU. + Wavelet Transform |:zap:| -------------------------- ``S2WAV`` is an updated implementation of the scale-discretised wavelet transform on the diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 4cc8cfa..1f545de 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -9,7 +9,7 @@ in the time being feel free to contact contributors for advice! At a high-level ``S2WAV`` package is structured such that the 2 primary transforms, the analysis and synthesis directional wavelet transforms, can easily be accessed. -Usage |:rocket:| +Core usage |:rocket:| ----------------- To import and use ``S2WAV`` is as simple follows: @@ -23,6 +23,23 @@ To import and use ``S2WAV`` is as simple follows: # Map back to signal on the sphere f = s2wav.synthesis(f_wav, f_scal, L, N) + +C backend library support |:bulb:| +---------------------------------- +``S2WAV`` also supports JAX frontend wrappers for the existing `SSHT `_ +spherical harmonic and Wigner transform C libraries which, though limited to CPU compute, are nevertheless very fast +and memory efficient when e.g. GPU compute is not available. To call this operating mode simply run + +.. code-block:: Python + + import s2wav + + # Compute wavelet coefficients + f_wav, f_scal = s2wav.analysis(f, L, N, use_c_backend=True) + + # Map back to signal on the sphere + f = s2wav.synthesis(f_wav, f_scal, L, N, use_c_backend=True) + .. toctree:: :hidden: :maxdepth: 3 @@ -30,4 +47,5 @@ To import and use ``S2WAV`` is as simple follows: numpy_transform/numpy_transforms.nblink jax_transform/jax_transforms.nblink + jax_ssht_transform/jax_transforms.nblink torch_transform/torch_transforms.nblink diff --git a/docs/tutorials/jax_ssht_transform/jax_transforms.nblink b/docs/tutorials/jax_ssht_transform/jax_transforms.nblink new file mode 100644 index 0000000..87f0817 --- /dev/null +++ b/docs/tutorials/jax_ssht_transform/jax_transforms.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../notebooks/jax_ssht_transform.ipynb" +} \ No newline at end of file diff --git a/docs/user_guide/install.rst b/docs/user_guide/install.rst index 8cb3647..df63dd7 100644 --- a/docs/user_guide/install.rst +++ b/docs/user_guide/install.rst @@ -42,7 +42,8 @@ from the root directory of the repository. Unit tests can then be executed to en installation was successful by running .. code-block:: bash - + + pip install -r requirements/requirements-tests.txt pytest tests/ Installing JAX for NVIDIA GPUs diff --git a/notebooks/jax_ssht_transform.ipynb b/notebooks/jax_ssht_transform.ipynb new file mode 100644 index 0000000..f4d25ae --- /dev/null +++ b/notebooks/jax_ssht_transform.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Wavelet transform (JAX-SSHT)\n", + "[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2wav/blob/main/notebooks/jax_ssht_transform.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install s2wav\n", + "!pip install s2wav &> /dev/null" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets start by importing some packages which we'll be using in this notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Make sure we configure 64 bit precision. \n", + "# 32 bit can be faster but you will be (potentially much) less precise.\n", + "import jax\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "import s2wav # Wavelet transforms on the sphere and rotation group\n", + "import s2fft # Spherical harmonic and Wigner transforms\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we'll define the constraints of the problem and generated some random data just for this example" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "L = 16 # Spherical harmonic bandlimit\n", + "N = 3 # Azimuthal (directional) bandlimit\n", + "sampling = \"mw\" # Sampling scheme\n", + "use_c_backend = True # Switches backend JAX harmonic and Wigner transforms to call underlying SSHT C libraries.\n", + "\n", + "# Generate a random bandlimited signal to work with\n", + "rng = np.random.default_rng(12346161)\n", + "flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n", + "f = s2fft.inverse(flm, L)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can calculate the wavelet and scaling coefficients by first building a bank of wavelet filters and the running the analysis transform" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "filter_bank = s2wav.filters.filters_directional_vectorised(L, N)\n", + "wavelet_coeffs, scaling_coeffs = s2wav.analysis(f, L, N, filters=filter_bank, use_c_backend=use_c_backend)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You'll notice that this first pass is very slow. That's because it is JIT compiling the function, so future calls to `s2wav.analysis` will be much fater! When an exact sampling theorem is chosen we can recover the original signal to machine precision by running" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "f_check = s2wav.synthesis(wavelet_coeffs, scaling_coeffs, L, N, filters=filter_bank, use_c_backend=use_c_backend)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again this first call is quite slow, but subsequent calls should be much faster. Lets double check that we actually got machine precision!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean absolute error = 1.1009715781712507e-14\n" + ] + } + ], + "source": [ + "print(f\"Mean absolute error = {np.nanmean(np.abs(f_check - f))}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.16 ('s2wav')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "2eaa51c34c6264c479aef01ba42a63404a2d0b54fbb558b3097eeea4996caab5" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/jax_transform.ipynb b/notebooks/jax_transform.ipynb new file mode 100644 index 0000000..056963e --- /dev/null +++ b/notebooks/jax_transform.ipynb @@ -0,0 +1,152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Wavelet transform (JAX)\n", + "[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2wav/blob/main/notebooks/jax_transform.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install s2wav\n", + "!pip install s2wav &> /dev/null" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets start by importing some packages which we'll be using in this notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Make sure we configure 64 bit precision. \n", + "# 32 bit can be faster but you will be (potentially much) less precise.\n", + "import jax\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "import s2wav # Wavelet transforms on the sphere and rotation group\n", + "import s2fft # Spherical harmonic and Wigner transforms\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we'll define the constraints of the problem and generated some random data just for this example" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "L = 16 # Spherical harmonic bandlimit\n", + "N = 3 # Azimuthal (directional) bandlimit\n", + "sampling = \"mw\" # Sampling scheme\n", + "\n", + "# Generate a random bandlimited signal to work with\n", + "rng = np.random.default_rng(12346161)\n", + "flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n", + "f = s2fft.inverse(flm, L)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can calculate the wavelet and scaling coefficients by first building a bank of wavelet filters and the running the analysis transform" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "filter_bank = s2wav.filters.filters_directional_vectorised(L, N)\n", + "wavelet_coeffs, scaling_coeffs = s2wav.analysis(f, L, N, filters=filter_bank)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You'll notice that this first pass is very slow. That's because it is JIT compiling the function, so future calls to `s2wav.analysis` will be much fater! When an exact sampling theorem is chosen we can recover the original signal to machine precision by running" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "f_check = s2wav.synthesis(wavelet_coeffs, scaling_coeffs, L, N, filters=filter_bank)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again this first call is quite slow, but subsequent calls should be much faster. Lets double check that we actually got machine precision!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean absolute error = 2.068390707329961e-14\n" + ] + } + ], + "source": [ + "print(f\"Mean absolute error = {np.nanmean(np.abs(f_check - f))}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.16 ('s2wav')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "2eaa51c34c6264c479aef01ba42a63404a2d0b54fbb558b3097eeea4996caab5" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/numpy_transform.ipynb b/notebooks/numpy_transform.ipynb new file mode 100644 index 0000000..e155134 --- /dev/null +++ b/notebooks/numpy_transform.ipynb @@ -0,0 +1,146 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Wavelet transform (Numpy)\n", + "[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2wav/blob/main/notebooks/numpy_transform.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install s2wav\n", + "!pip install s2wav &> /dev/null" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets start by importing some packages which we'll be using in this notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import s2wav # Wavelet transforms on the sphere and rotation group\n", + "import s2fft # Spherical harmonic and Wigner transforms\n", + "import numpy as np " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we'll define the constraints of the problem and generated some random data just for this example" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "L = 16 # Spherical harmonic bandlimit\n", + "N = 3 # Azimuthal (directional) bandlimit\n", + "sampling = \"mw\" # Sampling scheme\n", + "\n", + "# Generate a random bandlimited signal to work with\n", + "rng = np.random.default_rng(12346161)\n", + "flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n", + "f = s2fft.inverse(flm, L)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can calculate the wavelet and scaling coefficients by running" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "wavelet_coeffs, scaling_coeffs = s2wav.analysis_base(f, L, N)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "when an exact sampling theorem is chosen we can recover the original signal to machine precision by running" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "f_check = s2wav.synthesis_base(wavelet_coeffs, scaling_coeffs, L, N)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets double check that we actually got machine precision!" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean absolute error = 2.056856753687673e-14\n" + ] + } + ], + "source": [ + "print(f\"Mean absolute error = {np.nanmean(np.abs(f_check - f))}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.16 ('s2wav')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "2eaa51c34c6264c479aef01ba42a63404a2d0b54fbb558b3097eeea4996caab5" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/torch_transform.ipynb b/notebooks/torch_transform.ipynb new file mode 100644 index 0000000..096e5dc --- /dev/null +++ b/notebooks/torch_transform.ipynb @@ -0,0 +1,183 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Wavelet transform (PyTorch)\n", + "[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2wav/blob/main/notebooks/torch_transform.ipynb)\n", + "\n", + "Note that currently we only provide precompute support for PyTorch, so these transforms will only work up until around a bandlimit of $L\\sim1024$. Support for recursive, or so called *on-the-fly*, algorithms is already provided in JAX and should reach PyTorch soon.\n", + "\n", + "Lets start by importing some packages which we'll be using in this notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install s2wav\n", + "!pip install s2wav &> /dev/null" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets start by importing some packages which we'll be using in this notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.\n" + ] + } + ], + "source": [ + "import torch # Differentiable programming ecosystem\n", + "import s2wav # Wavelet transforms on the sphere and rotation group\n", + "import s2fft # Spherical harmonic and Wigner transforms\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we'll define the constraints of the problem and generated some random data just for this example" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "L = 16 # Spherical harmonic bandlimit\n", + "N = 3 # Azimuthal (directional) bandlimit\n", + "\n", + "# Generate a random bandlimited signal to work with\n", + "rng = np.random.default_rng(12346161)\n", + "flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n", + "f = s2fft.inverse(flm, L)\n", + "\n", + "# We'll need to convert this numpy array into a torch.tensor\n", + "f_torch = torch.from_numpy(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can calculate the wavelet and scaling coefficients by first building a bank of wavelet filters and precomputing and caching all matrices involved in the core transforms" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "filter_bank = s2wav.filters.filters_directional_vectorised(L, N, using_torch=True)\n", + "analysis_matrices = s2wav.construct.generate_full_precomputes(L, N, using_torch=True, forward=False)\n", + "synthesis_matrices = s2wav.construct.generate_full_precomputes(L, N, using_torch=True, forward=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can run the transforms, which are straightforwared linear algebra, by running" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "wavelet_coeffs, scaling_coeffs = s2wav.analysis_precomp_torch(\n", + " f_torch, L, N, filters=filter_bank, precomps=analysis_matrices\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When an exact sampling theorem is chosen we can recover the original signal to machine precision by running" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "f_check = s2wav.synthesis_precomp_torch(\n", + " wavelet_coeffs, scaling_coeffs, L, N, filters=filter_bank, precomps=synthesis_matrices\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again this first call is quite slow, but subsequent calls should be much faster. Lets double check that we actually got machine precision!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean absolute error = 2.0514116979479282e-14\n" + ] + } + ], + "source": [ + "print(f\"Mean absolute error = {np.nanmean(np.abs(f_check.resolve_conj().numpy() - f))}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.16 ('s2wav')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "2eaa51c34c6264c479aef01ba42a63404a2d0b54fbb558b3097eeea4996caab5" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements/requirements-core.txt b/requirements/requirements-core.txt index 16485d2..7fe0b81 100644 --- a/requirements/requirements-core.txt +++ b/requirements/requirements-core.txt @@ -5,4 +5,4 @@ pyyaml==6.0 scipy # For spherical transforms -s2fft >= 1.0.2 +s2fft >= 1.1.0 diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index c3b7e64..67937f9 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -13,8 +13,4 @@ jupyter==1.0.0 # For code regression pys2let -pyssht so3 - -#JAX dependencies -jaxlib diff --git a/s2wav/__init__.py b/s2wav/__init__.py index 442123c..a5737cc 100644 --- a/s2wav/__init__.py +++ b/s2wav/__init__.py @@ -2,22 +2,22 @@ from . import filters from . import samples -# ~~ Aliases ~~ +# # ~~ Aliases ~~ -# JAX recursive transforms -from .transforms.rec_wav_jax import analysis, synthesis, flm_to_analysis +# # JAX recursive transforms +# from .transforms.wavelet import analysis, synthesis, flm_to_analysis -# Base transforms -from .transforms.base import analysis as analysis_base -from .transforms.base import synthesis as synthesis_base +# # Base transforms +# from .transforms.base import analysis as analysis_base +# from .transforms.base import synthesis as synthesis_base -# JAX precompute transforms -from .transforms.pre_wav_jax import analysis as analysis_precomp_jax -from .transforms.pre_wav_jax import synthesis as synthesis_precomp_jax +# # JAX precompute transforms +# from .transforms.wavelet_precompute import analysis as analysis_precomp_jax +# from .transforms.wavelet_precompute import synthesis as synthesis_precomp_jax -# PyTorch precompute transforms -from .transforms.pre_wav_torch import analysis as analysis_precomp_torch -from .transforms.pre_wav_torch import synthesis as synthesis_precomp_torch +# # PyTorch precompute transforms +# from .transforms.wavelet_precompute_torch import analysis as analysis_precomp_torch +# from .transforms.wavelet_precompute_torch import synthesis as synthesis_precomp_torch -# Martix precompute functions -from .transforms import construct +# # Martix precompute functions +# from .transforms import construct diff --git a/s2wav/transforms/__init__.py b/s2wav/transforms/__init__.py index f109544..c6708cf 100644 --- a/s2wav/transforms/__init__.py +++ b/s2wav/transforms/__init__.py @@ -1,3 +1,5 @@ from . import base from . import construct -from . import rec_wav_jax, pre_wav_jax +from . import wavelet +from . import wavelet_precompute +from . import wavelet_precompute_torch diff --git a/s2wav/transforms/rec_wav_jax.py b/s2wav/transforms/wavelet.py similarity index 67% rename from s2wav/transforms/rec_wav_jax.py rename to s2wav/transforms/wavelet.py index fdcd920..0bae776 100644 --- a/s2wav/transforms/rec_wav_jax.py +++ b/s2wav/transforms/wavelet.py @@ -1,12 +1,13 @@ from jax import jit import jax.numpy as jnp +import numpy as np from functools import partial from typing import Tuple, List import s2fft from s2wav import samples from s2wav.transforms import construct -@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9)) + def synthesis( f_wav: jnp.ndarray, f_scal: jnp.ndarray, @@ -20,9 +21,14 @@ def synthesis( reality: bool = False, filters: Tuple[jnp.ndarray] = None, precomps: List[List[jnp.ndarray]] = None, + use_c_backend: bool = False, + _ssht_backend: int = 1, ) -> jnp.ndarray: r"""Computes the synthesis directional wavelet transform [1,2]. - Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` by summing the contributions from wavelet and scaling coefficients in harmonic space, see equation 27 from `[2] `_. + Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` + by summing the contributions from wavelet and scaling coefficients in harmonic space, + see equation 27 from `[2] `_. + Args: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. @@ -55,8 +61,16 @@ def synthesis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. + use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}. + Defaults to False. + + _ssht_backend (int, optional, experimental): Whether to default to SSHT core + (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental + backend. Use with caution. + Raises: AssertionError: Shape of wavelet/scaling coefficients incorrect. + ValueError: If healpix sampling is provided to SSHT C backend. Returns: jnp.ndarray: Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. @@ -65,28 +79,58 @@ def synthesis( [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013. [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). """ - if precomps == None: + if precomps == None and not use_c_backend: precomps = construct.generate_wigner_precomputes( L, N, J_min, lam, sampling, nside, True, reality ) + if use_c_backend and sampling.lower() == "healpix": + raise ValueError("SSHT C backend does not support healpix sampling.") + J = samples.j_max(L, lam) Ls = samples.scal_bandlimit(L, J_min, lam, True) flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) - f_scal_lm = s2fft.forward_jax(f_scal, Ls, spin, nside, sampling, reality) + + f_scal_lm = ( + s2fft.forward( + f_scal.real if reality else f_scal, + Ls, + spin, + nside, + sampling, + "jax_ssht", + reality, + _ssht_backend=_ssht_backend, + ) + if use_c_backend + else s2fft.forward_jax(f_scal, Ls, spin, nside, sampling, reality) + ) # Sum the all wavelet wigner coefficients for each lmn # Note that almost the entire compute is concentrated at the highest J for j in range(J_min, J + 1): Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) - temp = s2fft.wigner.forward_jax( - f_wav[j - J_min], - Lj, - Nj, - nside, - sampling, - reality, - precomps[j - J_min], - L_lower=L0j, + temp = ( + s2fft.wigner.forward( + f_wav[j - J_min], + Lj, + Nj, + nside, + sampling, + "jax_ssht", + reality, + _ssht_backend=_ssht_backend, + ) + if use_c_backend + else s2fft.wigner.forward_jax( + f_wav[j - J_min], + Lj, + Nj, + nside, + sampling, + reality, + precomps[j - J_min], + L_lower=L0j, + ) ) flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add( jnp.einsum( @@ -102,11 +146,22 @@ def synthesis( flm = flm.at[:Ls, L - Ls : L - 1 + Ls].add( jnp.einsum("lm,l->lm", f_scal_lm, phi, optimize=True) ) - - return s2fft.inverse_jax(flm, L, spin, nside, sampling, reality) + return ( + s2fft.inverse( + flm, + L, + spin, + nside, + sampling, + "jax_ssht", + reality, + _ssht_backend=_ssht_backend, + ) + if use_c_backend + else s2fft.inverse_jax(flm, L, spin, nside, sampling, reality) + ) -@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def analysis( f: jnp.ndarray, L: int, @@ -119,6 +174,8 @@ def analysis( reality: bool = False, filters: Tuple[jnp.ndarray] = None, precomps: List[List[jnp.ndarray]] = None, + use_c_backend: bool = False, + _ssht_backend: int = 1, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -149,6 +206,13 @@ def analysis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. + use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}. + Defaults to False. + + _ssht_backend (int, optional, experimental): Whether to default to SSHT core + (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental + backend. Use with caution. + Returns: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. @@ -156,7 +220,7 @@ def analysis( f_scal (jnp.ndarray): Array of scaling pixel-space coefficients with shape :math:`[n_{\theta}, n_{\phi}]`. """ - if precomps == None: + if precomps == None and not use_c_backend: precomps = construct.generate_wigner_precomputes( L, N, J_min, lam, sampling, nside, False, reality ) @@ -173,7 +237,20 @@ def analysis( optimize=True, ) - flm = s2fft.forward_jax(f, L, spin, nside, sampling, reality) + flm = ( + s2fft.forward( + f, + L, + spin, + nside, + sampling, + "jax_ssht", + reality, + _ssht_backend=_ssht_backend, + ) + if use_c_backend + else s2fft.forward_jax(f, L, spin, nside, sampling, reality) + ) # Project all wigner coefficients for each lmn onto wavelet coefficients # Note that almost the entire compute is concentrated at the highest J @@ -192,30 +269,56 @@ def analysis( ) ) - f_wav[j - J_min] = s2fft.wigner.inverse_jax( - f_wav_lmn[j - J_min], - Lj, - Nj, - nside, - sampling, - reality, - precomps[j - J_min], - False, - L0j, + f_wav[j - J_min] = ( + s2fft.wigner.inverse( + f_wav_lmn[j - J_min], + Lj, + Nj, + nside, + sampling, + "jax_ssht", + reality, + _ssht_backend=_ssht_backend, + ) + if use_c_backend + else s2fft.wigner.inverse_jax( + f_wav_lmn[j - J_min], + Lj, + Nj, + nside, + sampling, + reality, + precomps[j - J_min], + False, + L0j, + ) ) # Project all harmonic coefficients for each lm onto scaling coefficients phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1)) temp = jnp.einsum("lm,l->lm", flm[:Ls, L - Ls : L - 1 + Ls], phi, optimize=True) + # Handle edge case if Ls == 1: f_scal = temp * jnp.sqrt(1 / (4 * jnp.pi)) else: - f_scal = s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality) + f_scal = ( + s2fft.inverse( + temp, + Ls, + spin, + nside, + sampling, + "jax_ssht", + reality, + _ssht_backend=_ssht_backend, + ) + if use_c_backend + else s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality) + ) return f_wav, f_scal -@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8)) def flm_to_analysis( flm: jnp.ndarray, L: int, @@ -228,6 +331,8 @@ def flm_to_analysis( reality: bool = False, filters: Tuple[jnp.ndarray] = None, precomps: List[List[jnp.ndarray]] = None, + use_c_backend: bool = False, + _ssht_backend: int = 1, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -256,11 +361,18 @@ def flm_to_analysis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. + use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}. + Defaults to False. + + _ssht_backend (int, optional, experimental): Whether to default to SSHT core + (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental + backend. Use with caution. + Returns: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. """ - if precomps == None: + if precomps == None and not use_c_backend: precomps = construct.generate_wigner_precomputes( L, N, J_min, lam, sampling, nside, False, reality ) @@ -294,16 +406,29 @@ def flm_to_analysis( ) ) - f_wav[j - J_min] = s2fft.wigner.inverse_jax( - f_wav_lmn[j - J_min], - Lj, - Nj, - nside, - sampling, - reality, - precomps[j - J_min], - False, - L0j, + f_wav[j - J_min] = ( + s2fft.wigner.inverse( + f_wav_lmn[j - J_min], + Lj, + Nj, + nside, + sampling, + "jax_ssht", + reality, + _ssht_backend=_ssht_backend, + ) + if use_c_backend + else s2fft.wigner.inverse_jax( + f_wav_lmn[j - J_min], + Lj, + Nj, + nside, + sampling, + reality, + precomps[j - J_min], + False, + L0j, + ) ) return f_wav diff --git a/s2wav/transforms/pre_wav_jax.py b/s2wav/transforms/wavelet_precompute.py similarity index 97% rename from s2wav/transforms/pre_wav_jax.py rename to s2wav/transforms/wavelet_precompute.py index 7750b51..ad85be7 100644 --- a/s2wav/transforms/pre_wav_jax.py +++ b/s2wav/transforms/wavelet_precompute.py @@ -5,6 +5,7 @@ from s2fft.precompute_transforms import wigner, spherical from s2wav import samples + @partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9)) def synthesis( f_wav: jnp.ndarray, @@ -21,8 +22,8 @@ def synthesis( precomps: List[List[jnp.ndarray]] = None, ) -> jnp.ndarray: r"""Computes the synthesis directional wavelet transform [1,2]. - Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` - by summing the contributions from wavelet and scaling coefficients in harmonic space, + Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` + by summing the contributions from wavelet and scaling coefficients in harmonic space, see equation 27 from `[2] `_. Args: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients @@ -37,8 +38,8 @@ def synthesis( J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates + lam (float, optional): Wavelet parameter which determines the scale factor + between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. spin (int, optional): Spin (integer) of input signal. Defaults to 0. @@ -83,7 +84,13 @@ def synthesis( Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True) shift = 0 if j < J else -1 temp = wigner.forward_transform_jax( - f_wav[j - J_min], precomps[2][j-J_min+shift], Lj, Nj, sampling, reality, nside + f_wav[j - J_min], + precomps[2][j - J_min + shift], + Lj, + Nj, + sampling, + reality, + nside, ) flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add( jnp.einsum( diff --git a/s2wav/transforms/pre_wav_torch.py b/s2wav/transforms/wavelet_precompute_torch.py similarity index 100% rename from s2wav/transforms/pre_wav_torch.py rename to s2wav/transforms/wavelet_precompute_torch.py diff --git a/setup.py b/setup.py index 6c84ddb..e57ca91 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "Intended Audience :: Science/Research", ], name="s2wav", - version="0.0.2", + version="1.0.0", url="https://github.com/astro-informatics/s2wav", author="Authors & Contributors", license="GNU General Public License v3 (GPLv3)", diff --git a/tests/test_gradients.py b/tests/test_gradients.py index ac1192a..ac78a19 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from jax.test_util import check_grads import s2fft -from s2wav.transforms import rec_wav_jax, pre_wav_jax, construct +from s2wav.transforms import wavelet, wavelet_precompute, construct from s2wav import filters, samples L_to_test = [8] @@ -10,6 +10,8 @@ J_min_to_test = [2] reality = [False, True] recursive_transform = [False, True] +using_c_backend = [False, True] +_ssht_backends = [0, 1] @pytest.mark.parametrize("L", L_to_test) @@ -17,6 +19,8 @@ @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) +@pytest.mark.parametrize("using_c_backend", using_c_backend) +@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_jax_synthesis_gradients( wavelet_generator, L: int, @@ -24,26 +28,47 @@ def test_jax_synthesis_gradients( J_min: int, reality: bool, recursive: bool, + using_c_backend: bool, + _ssht_backend: int, ): J = samples.j_max(L) # Exceptions if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") + if not recursive and using_c_backend: + pytest.skip("Precompute transform not supported from C backend libraries.") + if reality and using_c_backend: + pytest.skip("Hermitian symmetry for C backend gradients currently conflicts.") + + # Generate random signal + f_wav, f_scal, _, _ = wavelet_generator( + L=L, N=N, J_min=J_min, lam=2, reality=reality + ) # Generate wavelet filters - filter = filters.filters_directional_vectorised(L, N, J_min) + filter = filters.filters_directional_vectorised(L, N, J_min, 2) generator = ( - construct.generate_wigner_precomputes - if recursive - else construct.generate_full_precomputes + None + if using_c_backend + else ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes + ) ) - synthesis = rec_wav_jax.synthesis if recursive else pre_wav_jax.synthesis - precomps = generator(L, N, J_min, forward=True, reality=reality) + synthesis = wavelet.synthesis if recursive else wavelet_precompute.synthesis - # Generate random signal - f_wav, f_scal, _, _ = wavelet_generator( - L=L, N=N, J_min=J_min, lam=2.0, reality=reality + precomps = ( + None + if using_c_backend + else generator(L, N, J_min, 2, forward=True, reality=reality) + ) + + args = ( + {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} + if using_c_backend + else {} ) def func(f_wav, f_scal): @@ -56,18 +81,11 @@ def func(f_wav, f_scal): reality=reality, filters=filter, precomps=precomps, + **args, ) return jnp.sum(jnp.abs(f) ** 2) - check_grads( - func, - ( - f_wav, - f_scal, - ), - order=1, - modes=("rev"), - ) + check_grads(func, (f_wav, f_scal), order=1, modes=("rev")) @pytest.mark.parametrize("L", L_to_test) @@ -75,46 +93,67 @@ def func(f_wav, f_scal): @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) +@pytest.mark.parametrize("using_c_backend", using_c_backend) +@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_jax_analysis_gradients( flm_generator, + wavelet_generator, L: int, N: int, J_min: int, reality: bool, recursive: bool, + using_c_backend: bool, + _ssht_backend: int, ): J = samples.j_max(L) if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") - - # Generate wavelet filters - filter = filters.filters_directional_vectorised(L, N, J_min) - generator = ( - construct.generate_wigner_precomputes - if recursive - else construct.generate_full_precomputes - ) - analysis = rec_wav_jax.analysis if recursive else pre_wav_jax.analysis - precomps = generator(L, N, J_min, forward=False, reality=reality) + if not recursive and using_c_backend: + pytest.skip("Precompute transform not supported from C backend libraries.") + if reality and using_c_backend: + pytest.skip("Hermitian symmetry for C backend gradients currently conflicts.") # Generate random signal flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) f = s2fft.inverse_jax(flm, L) # Generate target signal - flm_target = flm_generator(L=L, L_lower=0, spin=0, reality=reality) - f_target = s2fft.inverse_jax(flm_target, L) - f_wav_target, f_scal_target = rec_wav_jax.analysis( - f_target, L, N, J_min, reality=reality, filters=filter + f_wav_target, f_scal_target, _, _ = wavelet_generator( + L=L, N=N, J_min=J_min, lam=2, reality=reality + ) + + # Generate wavelet filters + filter = filters.filters_directional_vectorised(L, N, J_min) + generator = ( + None + if using_c_backend + else ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes + ) + ) + analysis = wavelet.analysis if recursive else wavelet_precompute.analysis + precomps = ( + None + if using_c_backend + else generator(L, N, J_min, forward=False, reality=reality) + ) + + args = ( + {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} + if using_c_backend + else {} ) def func(f): f_wav, f_scal = analysis( - f, L, N, J_min, reality=reality, filters=filter, precomps=precomps + f, L, N, J_min, reality=reality, filters=filter, precomps=precomps, **args ) loss = jnp.sum(jnp.abs(f_scal - f_scal_target) ** 2) for j in range(J - J_min): loss += jnp.sum(jnp.abs(f_wav[j - J_min] - f_wav_target[j - J_min]) ** 2) return loss - check_grads(func, (f,), order=1, modes=("rev")) + check_grads(func, (f.real if reality else f,), order=1, modes=("rev")) diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index 68d60ff..9f3f822 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -3,7 +3,12 @@ import torch import pys2let as s2let from s2fft import base_transforms as sht_base -from s2wav.transforms import rec_wav_jax, pre_wav_jax, pre_wav_torch, construct +from s2wav.transforms import ( + wavelet, + wavelet_precompute, + wavelet_precompute_torch, + construct, +) from s2wav import filters, samples L_to_test = [8] @@ -14,6 +19,8 @@ sampling_to_test = ["mw", "mwss", "dh", "gl"] recursive_transform = [False, True] using_torch_frontend = [False, True] +using_c_backend = [False, True] +_ssht_backends = [0, 1] @pytest.mark.parametrize("L", L_to_test) @@ -23,6 +30,8 @@ @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_torch", using_torch_frontend) +@pytest.mark.parametrize("using_c_backend", using_c_backend) +@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_synthesis( wavelet_generator, L: int, @@ -32,6 +41,8 @@ def test_synthesis( reality: bool, recursive: bool, using_torch: bool, + using_c_backend: bool, + _ssht_backend: int, ): J = samples.j_max(L, lam) @@ -39,7 +50,9 @@ def test_synthesis( if J_min >= J: pytest.skip("J_min larger than J which isn't a valid test case.") if recursive and using_torch: - pytest.skip("Recursive transform not yet available for torch frontend") + pytest.skip("Recursive transform not yet available for torch frontend.") + if not recursive and using_c_backend: + pytest.skip("Precompute transform not supported from C backend libraries.") f_wav, f_scal, f_wav_s2let, f_scal_s2let = wavelet_generator( L=L, N=N, J_min=J_min, lam=lam, reality=reality, using_torch=using_torch @@ -60,18 +73,34 @@ def test_synthesis( L, N, J_min, lam, using_torch=using_torch ) generator = ( - construct.generate_wigner_precomputes - if recursive - else construct.generate_full_precomputes + None + if using_c_backend + else ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes + ) ) synthesis = ( - rec_wav_jax.synthesis + wavelet.synthesis if recursive - else (pre_wav_torch.synthesis if using_torch else pre_wav_jax.synthesis) + else ( + wavelet_precompute_torch.synthesis + if using_torch + else wavelet_precompute.synthesis + ) ) - - precomps = generator( - L, N, J_min, lam, forward=True, reality=reality, using_torch=using_torch + precomps = ( + None + if using_c_backend + else generator( + L, N, J_min, lam, forward=True, reality=reality, using_torch=using_torch + ) + ) + args = ( + {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} + if using_c_backend + else {} ) f_check = synthesis( @@ -84,6 +113,7 @@ def test_synthesis( reality=reality, filters=filter, precomps=precomps, + **args, ) if using_torch: @@ -100,6 +130,8 @@ def test_synthesis( @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("recursive", recursive_transform) @pytest.mark.parametrize("using_torch", using_torch_frontend) +@pytest.mark.parametrize("using_c_backend", using_c_backend) +@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_analysis( flm_generator, f_wav_converter, @@ -110,6 +142,8 @@ def test_analysis( reality: bool, recursive: bool, using_torch: bool, + using_c_backend: bool, + _ssht_backend: int, ): J = samples.j_max(L, lam) @@ -118,6 +152,8 @@ def test_analysis( pytest.skip("J_min larger than J which isn't a valid test case.") if recursive and using_torch: pytest.skip("Recursive transform not yet available for torch frontend") + if not recursive and using_c_backend: + pytest.skip("Precompute transform not supported from C backend libraries.") flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) f = sht_base.spherical.inverse(flm, L, reality=reality) @@ -129,18 +165,37 @@ def test_analysis( L, N, J_min, lam, using_torch=using_torch ) generator = ( - construct.generate_wigner_precomputes - if recursive - else construct.generate_full_precomputes + None + if using_c_backend + else ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes + ) ) analysis = ( - rec_wav_jax.analysis + wavelet.analysis if recursive - else (pre_wav_torch.analysis if using_torch else pre_wav_jax.analysis) + else ( + wavelet_precompute_torch.analysis + if using_torch + else wavelet_precompute.analysis + ) + ) + precomps = ( + None + if using_c_backend + else generator( + L, N, J_min, lam, forward=False, reality=reality, using_torch=using_torch + ) ) - precomps = generator( - L, N, J_min, lam, forward=False, reality=reality, using_torch=using_torch + + args = ( + {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} + if using_c_backend + else {} ) + f_wav_check, f_scal_check = analysis( torch.from_numpy(f) if using_torch else f, L, @@ -150,6 +205,7 @@ def test_analysis( reality=reality, filters=filter, precomps=precomps, + **args, ) f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, using_torch) @@ -170,8 +226,18 @@ def test_analysis( @pytest.mark.parametrize("lam", lam_to_test) @pytest.mark.parametrize("reality", reality) @pytest.mark.parametrize("sampling", sampling_to_test) +@pytest.mark.parametrize("using_c_backend", using_c_backend) +@pytest.mark.parametrize("_ssht_backend", _ssht_backends) def test_round_trip( - flm_generator, L: int, N: int, J_min: int, lam: int, reality: bool, sampling: str + flm_generator, + L: int, + N: int, + J_min: int, + lam: int, + reality: bool, + sampling: str, + using_c_backend: bool, + _ssht_backend: int, ): J = samples.j_max(L, lam) @@ -183,10 +249,24 @@ def test_round_trip( f = sht_base.spherical.inverse(flm, L, reality=reality, sampling=sampling) filter = filters.filters_directional_vectorised(L, N, J_min, lam) - f_wav, f_scal = rec_wav_jax.analysis( - f, L, N, J_min, lam, reality=reality, sampling=sampling, filters=filter + args = ( + {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} + if using_c_backend + else {} + ) + + f_wav, f_scal = wavelet.analysis( + f, + L, + N, + J_min, + lam, + reality=reality, + sampling=sampling, + filters=filter, + **args, ) - f_check = rec_wav_jax.synthesis( + f_check = wavelet.synthesis( f_wav, f_scal, L, @@ -196,6 +276,7 @@ def test_round_trip( sampling=sampling, reality=reality, filters=filter, + **args, ) np.testing.assert_allclose(f, f_check, atol=1e-14) From 6394602306a68100659aff7936db15b4c85530a3 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Tue, 9 Apr 2024 14:29:02 +0100 Subject: [PATCH 8/8] increase codecov by catching edge cases --- s2wav/filters.py | 22 +++++---- tests/test_filters.py | 27 +++++++++++ tests/test_gradients.py | 3 ++ tests/test_wavelets.py | 95 +++++++++++++++++++++++++++++++++++++ tests/test_wavelets_base.py | 3 ++ 5 files changed, 140 insertions(+), 10 deletions(-) diff --git a/s2wav/filters.py b/s2wav/filters.py index ba81f49..fe8c457 100644 --- a/s2wav/filters.py +++ b/s2wav/filters.py @@ -132,7 +132,7 @@ def filters_directional( if kappa0[el] != 0: phi[el] = np.sqrt((2 * el + 1) / (4.0 * np.pi)) * kappa0[el] if spin0 != 0: - phi[el] *= spin_normalization(el, spin0) * (-1) ** spin0 + phi[el] *= _spin_normalization(el, spin0) * (-1) ** spin0 for j in range(J_min, J + 1): for el in range(el_min, L): @@ -146,7 +146,7 @@ def filters_directional( ) if spin0 != 0: psi[j, el, L - 1 + m] *= ( - spin_normalization(el, spin0) * (-1) ** spin0 + _spin_normalization(el, spin0) * (-1) ** spin0 ) if using_torch: psi = torch.from_numpy(psi) @@ -227,7 +227,7 @@ def filters_directional_vectorised( el_min = max(abs(spin), abs(spin0)) spin_norms = ( - (-1) ** spin0 * spin_normalization_vectorised(np.arange(L), spin0) + (-1) ** spin0 * _spin_normalization_vectorised(np.arange(L), spin0) if spin0 != 0 else 1 ) @@ -323,7 +323,9 @@ def filters_directional_jax( el_min = max(abs(spin), abs(spin0)) spin_norms = ( - (-1) ** spin0 * spin_normalization_jax(np.arange(L), spin0) if spin0 != 0 else 1 + (-1) ** spin0 * _spin_normalization_jax(np.arange(L), spin0) + if spin0 != 0 + else 1 ) kappa, kappa0 = filters_axisym_jax(L, J_min, lam) @@ -476,7 +478,7 @@ def k_lam(L: int, lam: float = 2.0, quad_iters: int = 300) -> float: @partial(jit, static_argnums=(2, 3)) # not sure -def part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float: +def _part_scaling_fn_jax(a: float, b: float, n: int, lam: float = 2.0) -> float: r"""JAX version of part_scaling_fn. Computes integral used to calculate smoothly decreasing function :math:`k_{\lambda}`. @@ -627,7 +629,7 @@ def tiling_direction(L: int, N: int = 1) -> np.ndarray: return s_elm -def spin_normalization(el: int, spin: int = 0) -> float: +def _spin_normalization(el: int, spin: int = 0) -> float: r"""Computes the normalization factor for spin-lowered wavelets, which is :math:`\sqrt{\frac{(\ell+s)!}{(\ell-s)!}}`. @@ -650,8 +652,8 @@ def spin_normalization(el: int, spin: int = 0) -> float: return np.sqrt(1.0 / factor) -def spin_normalization_vectorised(el: np.ndarray, spin: int = 0) -> float: - r"""Vectorised version of :func:`~spin_normalization`. +def _spin_normalization_vectorised(el: np.ndarray, spin: int = 0) -> float: + r"""Vectorised version of :func:`~_spin_normalization`. Args: el (int): Harmonic index :math:`\ell`. spin (int): Spin of field over which to perform the transform. Defaults to 0. @@ -714,8 +716,8 @@ def tiling_direction_jax(L: int, N: int = 1) -> np.ndarray: @partial(jit, static_argnums=(1)) -def spin_normalization_jax(el: np.ndarray, spin: int = 0) -> float: - r"""JAX version of :func:`~spin_normalization`. +def _spin_normalization_jax(el: np.ndarray, spin: int = 0) -> float: + r"""JAX version of :func:`~_spin_normalization`. Args: el (int): Harmonic index :math:`\ell`. spin (int): Spin of field over which to perform the transform. Defaults to 0. diff --git a/tests/test_filters.py b/tests/test_filters.py index 283d801..13e8e59 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,3 +1,6 @@ +import jax + +jax.config.update("jax_enable_x64", True) import pytest import numpy as np from s2wav import filters, samples @@ -94,6 +97,18 @@ def test_directional_vectorised(L: int, N: int, J_min: int, lam: int): np.testing.assert_allclose(f[i], f_vect[i], rtol=1e-14) +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("J_min", J_min_to_test) +@pytest.mark.parametrize("lam", lam_to_test) +def test_directional_torch(L: int, N: int, J_min: int, lam: int): + f = filters.filters_directional(L, N, J_min, lam) + f_vect = filters.filters_directional_vectorised(L, N, J_min, lam, using_torch=True) + + for i in range(2): + np.testing.assert_allclose(f[i], f_vect[i], rtol=1e-14) + + @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("J_min", J_min_to_test) @pytest.mark.parametrize("lam", lam_to_test) @@ -115,3 +130,15 @@ def test_directional_jax(L: int, N: int, J_min: int, lam: int): for i in range(2): np.testing.assert_allclose(f[i], f_jax[i], rtol=1e-13, atol=1e-13) + + +def test_filter_exceptions(): + L = 8 + with pytest.raises(ValueError) as e: + filters.filters_axisym(L, 10) + + with pytest.raises(ValueError) as e: + filters.filters_axisym_vectorised(L, 10) + + with pytest.raises(ValueError) as e: + filters.filters_axisym_jax(L, 10) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index ac78a19..7fd860b 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,3 +1,6 @@ +import jax + +jax.config.update("jax_enable_x64", True) import pytest import jax.numpy as jnp from jax.test_util import check_grads diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index 9f3f822..fe7121f 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -1,3 +1,6 @@ +import jax + +jax.config.update("jax_enable_x64", True) import pytest import numpy as np import torch @@ -280,3 +283,95 @@ def test_round_trip( ) np.testing.assert_allclose(f, f_check, atol=1e-14) + + +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("J_min", J_min_to_test) +@pytest.mark.parametrize("lam", lam_to_test) +@pytest.mark.parametrize("reality", reality) +@pytest.mark.parametrize("recursive", recursive_transform) +@pytest.mark.parametrize("using_torch", using_torch_frontend) +@pytest.mark.parametrize("using_c_backend", using_c_backend) +@pytest.mark.parametrize("_ssht_backend", _ssht_backends) +def test_flm_to_analysis( + flm_generator, + f_wav_converter, + L: int, + N: int, + J_min: int, + lam: int, + reality: bool, + recursive: bool, + using_torch: bool, + using_c_backend: bool, + _ssht_backend: int, +): + J = samples.j_max(L, lam) + + # Exceptions + if J_min >= J: + pytest.skip("J_min larger than J which isn't a valid test case.") + if recursive and using_torch: + pytest.skip("Recursive transform not yet available for torch frontend") + if not recursive and using_c_backend: + pytest.skip("Precompute transform not supported from C backend libraries.") + + flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) + f = sht_base.spherical.inverse(flm, L, reality=reality) + + f_wav, _ = s2let.analysis_px2wav( + f.flatten("C").astype(np.complex128), lam, L, J_min, N, spin=0, upsample=False + ) + filter = filters.filters_directional_vectorised( + L, N, J_min, lam, using_torch=using_torch + )[0] + + generator = ( + None + if using_c_backend + else ( + construct.generate_wigner_precomputes + if recursive + else construct.generate_full_precomputes + ) + ) + analysis = ( + wavelet.flm_to_analysis + if recursive + else ( + wavelet_precompute_torch.flm_to_analysis + if using_torch + else wavelet_precompute.flm_to_analysis + ) + ) + precomps = ( + None + if using_c_backend + else generator( + L, N, J_min, lam, forward=False, reality=reality, using_torch=using_torch + ) + ) + + args = ( + {"use_c_backend": using_c_backend, "_ssht_backend": _ssht_backend} + if using_c_backend + else {} + ) + + f_wav_check = analysis( + torch.from_numpy(flm) if using_torch else flm, + L, + N, + J_min, + None, + lam, + reality=reality, + filters=filter, + precomps=precomps, + **args, + ) + + f_wav_check = f_wav_converter(f_wav_check, L, N, J_min, lam, using_torch) + + np.testing.assert_allclose(f_wav, f_wav_check, atol=1e-14) diff --git a/tests/test_wavelets_base.py b/tests/test_wavelets_base.py index 00807a6..381cbaa 100644 --- a/tests/test_wavelets_base.py +++ b/tests/test_wavelets_base.py @@ -1,3 +1,6 @@ +import jax + +jax.config.update("jax_enable_x64", True) import pytest import numpy as np import pys2let as s2let