From 1a24a43adfd711e4234645cb43bbb2b6c34a3d49 Mon Sep 17 00:00:00 2001 From: unalmis Date: Mon, 26 Aug 2024 17:41:18 -0400 Subject: [PATCH] Fix bug in Fourier bounce with interpolation of b_sup_z --- desc/compute/_bootstrap.py | 2 +- desc/compute/_equil.py | 2 +- desc/compute/_field.py | 2 +- desc/compute/_metric.py | 2 +- desc/compute/_profiles.py | 2 +- desc/compute/_stability.py | 2 +- desc/integrals/basis.py | 168 +++++++++--------- desc/integrals/bounce_integral.py | 52 ++++-- desc/integrals/bounce_utils.py | 4 +- desc/integrals/interp_utils.py | 27 +-- desc/integrals/quad_utils.py | 33 ++-- tests/test_fourier_bounce.py | 206 --------------------- tests/test_integrals.py | 286 ++++++++++++++++++++++++++---- tests/test_interp_utils.py | 127 ++++++------- tests/test_quad_utils.py | 26 ++- 15 files changed, 486 insertions(+), 455 deletions(-) delete mode 100644 tests/test_fourier_bounce.py diff --git a/desc/compute/_bootstrap.py b/desc/compute/_bootstrap.py index 48af83b4e..2329682c0 100644 --- a/desc/compute/_bootstrap.py +++ b/desc/compute/_bootstrap.py @@ -13,7 +13,7 @@ from scipy.special import roots_legendre from ..backend import fori_loop, jnp -from ..integrals import surface_averages_map +from ..integrals.surface_integral import surface_averages_map from .data_index import register_compute_fun diff --git a/desc/compute/_equil.py b/desc/compute/_equil.py index 7cd01491e..2cb7a9360 100644 --- a/desc/compute/_equil.py +++ b/desc/compute/_equil.py @@ -14,7 +14,7 @@ from desc.backend import jnp -from ..integrals import surface_averages +from ..integrals.surface_integral import surface_averages from .data_index import register_compute_fun from .utils import cross, dot, safediv, safenorm diff --git a/desc/compute/_field.py b/desc/compute/_field.py index 31a9d58a1..8af2e8368 100644 --- a/desc/compute/_field.py +++ b/desc/compute/_field.py @@ -13,7 +13,7 @@ from desc.backend import jnp -from ..integrals import ( +from ..integrals.surface_integral import ( surface_averages, surface_integrals_map, surface_max, diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 536bd05bb..ceb670338 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -13,7 +13,7 @@ from desc.backend import jnp -from ..integrals import surface_averages +from ..integrals.surface_integral import surface_averages from .data_index import register_compute_fun from .utils import cross, dot, safediv, safenorm diff --git a/desc/compute/_profiles.py b/desc/compute/_profiles.py index 84de48e57..4a647fdfa 100644 --- a/desc/compute/_profiles.py +++ b/desc/compute/_profiles.py @@ -13,7 +13,7 @@ from desc.backend import cond, jnp -from ..integrals import surface_averages, surface_integrals +from ..integrals.surface_integral import surface_averages, surface_integrals from .data_index import register_compute_fun from .utils import cumtrapz, dot, safediv diff --git a/desc/compute/_stability.py b/desc/compute/_stability.py index 4a985a4dc..3b820f83b 100644 --- a/desc/compute/_stability.py +++ b/desc/compute/_stability.py @@ -13,7 +13,7 @@ from desc.backend import jnp -from ..integrals import surface_integrals_map +from ..integrals.surface_integral import surface_integrals_map from .data_index import register_compute_fun from .utils import dot diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py index 68422ea68..0baa6ae80 100644 --- a/desc/integrals/basis.py +++ b/desc/integrals/basis.py @@ -319,6 +319,86 @@ def N(self): """Chebyshev spectral resolution.""" return self.cheb.shape[-1] + def isomorphism_to_C1(self, y): + """Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ². + + Maps row x of y to z = y + f(x) where f(x) = x * |domain|. + + Parameters + ---------- + y : jnp.ndarray + Shape (..., y.shape[-2], y.shape[-1]). + Second to last axis iterates the rows. + + Returns + ------- + z : jnp.ndarray + Shape y.shape. + Isomorphic coordinates. + + """ + assert y.ndim >= 2 + z_shift = jnp.arange(y.shape[-2]) * (self.domain[-1] - self.domain[0]) + z = y + z_shift[:, jnp.newaxis] + return z + + def isomorphism_to_C2(self, z): + """Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ. + + Returns index x and minimum value y such that + z = f(x) + y where f(x) = x * |domain|. + + Parameters + ---------- + z : jnp.ndarray + Shape z.shape. + + Returns + ------- + x_idx, y_val : (jnp.ndarray, jnp.ndarray) + Shape z.shape. + Isomorphic coordinates. + + """ + x_idx, y_val = jnp.divmod(z - self.domain[0], self.domain[-1] - self.domain[0]) + x_idx = x_idx.astype(int) + y_val += self.domain[0] + return x_idx, y_val + + def eval1d(self, z, cheb=None): + """Evaluate piecewise Chebyshev series at coordinates z. + + Parameters + ---------- + z : jnp.ndarray + Shape (..., *cheb.shape[:-2], z.shape[-1]). + Coordinates in [sef.domain[0], ∞). + The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² where + ``z // domain`` yields the index into the proper Chebyshev series + along the second to last axis of ``cheb`` and ``z % domain`` is + the coordinate value on the domain of that Chebyshev series. + cheb : jnp.ndarray + Shape (..., M, N). + Chebyshev coefficients to use. If not given, uses ``self.cheb``. + + Returns + ------- + f : jnp.ndarray + Shape z.shape. + Chebyshev basis evaluated at z. + + """ + cheb = _chebcast(setdefault(cheb, self.cheb), z) + N = cheb.shape[-1] + x_idx, y = self.isomorphism_to_C2(z) + y = bijection_to_disc(y, self.domain[0], self.domain[1]) + # Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z]) + # are held in cheb with shape (..., num cheb series, N). + cheb = jnp.take_along_axis(cheb, x_idx[..., jnp.newaxis], axis=-2) + f = idct_non_uniform(y, cheb, N) + assert f.shape == z.shape + return f + def intersect2d(self, k=0.0, eps=_eps): """Coordinates yᵢ such that f(x, yᵢ) = k(x). @@ -434,86 +514,6 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0): z2 = jnp.where(mask, z2, pad_value) return z1, z2 - def eval1d(self, z, cheb=None): - """Evaluate piecewise Chebyshev series at coordinates z. - - Parameters - ---------- - z : jnp.ndarray - Shape (..., *cheb.shape[:-2], z.shape[-1]). - Coordinates in [sef.domain[0], ∞). - The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² where - ``z // domain`` yields the index into the proper Chebyshev series - along the second to last axis of ``cheb`` and ``z % domain`` is - the coordinate value on the domain of that Chebyshev series. - cheb : jnp.ndarray - Shape (..., M, N). - Chebyshev coefficients to use. If not given, uses ``self.cheb``. - - Returns - ------- - f : jnp.ndarray - Shape z.shape. - Chebyshev basis evaluated at z. - - """ - cheb = _chebcast(setdefault(cheb, self.cheb), z) - N = cheb.shape[-1] - x_idx, y = self.isomorphism_to_C2(z) - y = bijection_to_disc(y, self.domain[0], self.domain[1]) - # Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z]) - # are held in cheb with shape (..., num cheb series, N). - cheb = jnp.take_along_axis(cheb, x_idx[..., jnp.newaxis], axis=-2) - f = idct_non_uniform(y, cheb, N) - assert f.shape == z.shape - return f - - def isomorphism_to_C1(self, y): - """Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ². - - Maps row x of y to z = y + f(x) where f(x) = x * |domain|. - - Parameters - ---------- - y : jnp.ndarray - Shape (..., y.shape[-2], y.shape[-1]). - Second to last axis iterates the rows. - - Returns - ------- - z : jnp.ndarray - Shape y.shape. - Isomorphic coordinates. - - """ - assert y.ndim >= 2 - z_shift = jnp.arange(y.shape[-2]) * (self.domain[-1] - self.domain[0]) - z = y + z_shift[:, jnp.newaxis] - return z - - def isomorphism_to_C2(self, z): - """Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ. - - Returns index x and minimum value y such that - z = f(x) + y where f(x) = x * |domain|. - - Parameters - ---------- - z : jnp.ndarray - Shape z.shape. - - Returns - ------- - x_idx, y_val : (jnp.ndarray, jnp.ndarray) - Shape z.shape. - Isomorphic coordinates. - - """ - x_idx, y_val = jnp.divmod(z - self.domain[0], self.domain[-1] - self.domain[0]) - x_idx = x_idx.astype(int) - y_val += self.domain[0] - return x_idx, y_val - def _check_shape(self, z1, z2, k): """Return shapes that broadcast with (k.shape[0], *self.cheb.shape[:-2], W).""" # Ensure pitch batch dim exists and add back dim to broadcast with wells. @@ -533,7 +533,9 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): z1, z2 : jnp.ndarray Shape must broadcast with (*self.cheb.shape[:-2], W). ``z1``, ``z2`` holds intersects satisfying ∂f/∂y <= 0, ∂f/∂y >= 0, - respectively. + respectively. The points are grouped and ordered such that the + straight line path between the intersects in ``z1`` and ``z2`` + resides in the epigraph of f. k : jnp.ndarray Shape must broadcast with *self.cheb.shape[:-2]. k such that fₓ(yᵢ) = k. @@ -582,8 +584,8 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): assert not err_1[idx], "Intersects have an inversion.\n" assert not err_2[idx], "Detected discontinuity.\n" assert not err_3[idx], ( - "Detected f > k in well, implying a path between z1 and z2 " - "is in hypograph(f). Increase Chebyshev resolution.\n" + "Detected f > k in well, implying the straight line path between " + "z1 and z2 is in hypograph(f). Increase spectral resolution.\n" f"{f_midpoint[idx][mask[idx]]} > {k[idx] + self._eps}" ) idx = (slice(None), *l) diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index b78af9496..8f7df5024 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -3,7 +3,7 @@ from interpax import CubicHermiteSpline from orthax.legendre import leggauss -from desc.backend import jnp +from desc.backend import jnp, rfft2 from desc.integrals.basis import FourierChebyshevBasis from desc.integrals.bounce_utils import ( _check_bounce_points, @@ -13,12 +13,7 @@ interp_to_argmin_B_soft, plot_ppoly, ) -from desc.integrals.interp_utils import ( - interp_rfft2, - irfft2_non_uniform, - polyder_vec, - transform_to_desc, -) +from desc.integrals.interp_utils import interp_rfft2, irfft2_non_uniform, polyder_vec from desc.integrals.quad_utils import ( automorphism_sin, bijection_from_disc, @@ -70,6 +65,7 @@ def _transform_to_clebsch(grid, desc_from_clebsch, M, N, B): # which is not a tensor product node set in DESC space. xq=desc_from_clebsch[:, 1:].reshape(grid.num_rho, -1, 2), f=grid.meshgrid_reshape(B, order="rtz")[:, jnp.newaxis], + # Real fft over poloidal since usually num theta > num zeta. axes=(-1, -2), ).reshape(grid.num_rho, M, N), domain=Bounce2D.domain, @@ -77,6 +73,30 @@ def _transform_to_clebsch(grid, desc_from_clebsch, M, N, B): return T, B +def _transform_to_desc(grid, f): + """Transform to DESC spectral domain. + + Parameters + ---------- + grid : Grid + Tensor-product grid in (θ, ζ) with uniformly spaced nodes in + (2π × 2π) poloidal and toroidal coordinates. + f : jnp.ndarray + Function evaluated on ``grid``. + + Returns + ------- + a : jnp.ndarray + Shape (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta) + Complex coefficients of 2D real FFT. + + """ + f = grid.meshgrid_reshape(f, order="rtz") + a = rfft2(f, axes=(-1, -2), norm="forward") + assert a.shape == (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta) + return a + + # TODO: # After GitHub issue #1034 is resolved, we should pass in the previous # θ(α) coordinates as an initial guess for the next coordinate mapping. @@ -292,14 +312,20 @@ def __init__( ) self._m = grid.num_theta self._n = grid.num_zeta - self._b_sup_z = jnp.expand_dims( - transform_to_desc(grid, jnp.abs(data["B^zeta"]) / data["|B|"] * Lref), - axis=1, - ) self._x, self._w = get_quadrature(quad, automorphism) # Compute global splines. - T, B = _transform_to_clebsch(grid, desc_from_clebsch, M, N, data["|B|"] / Bref) + self._b_sup_z = _transform_to_desc( + grid, + jnp.abs(data["B^zeta"]) / data["|B|"] * Lref, + )[:, jnp.newaxis] + T, B = _transform_to_clebsch( + grid, + desc_from_clebsch, + M, + N, + data["|B|"] / Bref, + ) # peel off field lines alphas = get_alpha( alpha_0, @@ -535,7 +561,7 @@ def _integrate(self, z1, z2, pitch, integrand, f): pitch=pitch[..., jnp.newaxis, jnp.newaxis], ) / irfft2_non_uniform( - Q, self._b_sup_z, self._m, self._n, axes=(-1, -2) + xq=Q, a=self._b_sup_z, M=self._n, N=self._m, axes=(-1, -2) ).reshape(shape), self._w, ) diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index c7349b7ba..d500f1b33 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -278,8 +278,8 @@ def _check_bounce_points(z1, z2, pitch, knots, B, plot=True, **kwargs): assert not err_2[p, s], "Detected discontinuity.\n" assert not err_3, ( f"Detected |B| = {Bs_midpoint[mask[p, s]]} > {1 / pitch[p, s] + eps} " - "= 1/λ in well, implying the straight line path between bounce points " - "is in hypograph(|B|). Use more knots.\n" + "= 1/λ in well, implying the straight line path between " + "bounce points is in hypograph(|B|). Use more knots.\n" ) if plot: plot_ppoly( diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py index 114d5faf0..dd4b5c646 100644 --- a/desc/integrals/interp_utils.py +++ b/desc/integrals/interp_utils.py @@ -259,7 +259,7 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)): M : int Spectral resolution of ``a`` along ``axes[0]``. N : int - Spectral resolution of ``a`` along ``axes[-1]``. + Spectral resolution of ``a`` along ``axes[1]``. axes : tuple[int, int] Axes along which to transform. @@ -295,31 +295,6 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)): return fq -def transform_to_desc(grid, f): - """Transform to DESC spectral domain. - - Parameters - ---------- - grid : Grid - Tensor-product grid in (θ, ζ) with uniformly spaced nodes in - (2π × 2π) poloidal and toroidal coordinates. - f : jnp.ndarray - Function evaluated on ``grid``. - - Returns - ------- - a : jnp.ndarray - Shape (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta) - Complex coefficients of 2D real FFT. - - """ - f = grid.meshgrid_reshape(f, order="rtz") - # Real fft done over poloidal since num_theta > num_zeta usually. - a = rfft2(f, axes=(-1, -2), norm="forward") - assert a.shape == (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta) - return a - - def cheb_from_dct(a, axis=-1): """Get discrete Chebyshev transform from discrete cosine transform. diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py index b14f691b0..89ad99d82 100644 --- a/desc/integrals/quad_utils.py +++ b/desc/integrals/quad_utils.py @@ -138,7 +138,7 @@ def tanh_sinh(deg, m=10): return x, w -def leggauss_lobatto(deg): +def leggauss_lob(deg, interior_only=False): """Lobatto-Gauss-Legendre quadrature. Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the @@ -147,26 +147,30 @@ def leggauss_lobatto(deg): Parameters ---------- deg : int - Number of (interior) quadrature points to return. + Number of quadrature points. + interior_only : bool + Whether to exclude the points and weights at -1 and 1; + useful if f(-1) = f(1) = 0. If ``True``, then ``deg`` points are still + returned; these are the interior points for lobatto quadrature of ``deg+2``. Returns ------- x, w : (jnp.ndarray, jnp.ndarray) - Quadrature points in (-1, 1) and associated weights. - Excludes points and weights at -1 and 1. + Shape (deg, ). + Quadrature points and weights. """ - # Designate two degrees for endpoints. - deg = int(deg) + 2 + N = deg + 2 * bool(interior_only) + errorif(N < 2) - # Golub-Welsh algorithm for eigenvalues of orthogonal polynomials - n = jnp.arange(2, deg - 1) + # Golub-Welsh algorithm + n = jnp.arange(2, N - 1) x = eigh_tridiagonal( - jnp.zeros(deg - 2), + jnp.zeros(N - 2), jnp.sqrt((n**2 - 1) / (4 * n**2 - 1)), eigvals_only=True, ) - c0 = put(jnp.zeros(deg), -1, 1) + c0 = put(jnp.zeros(N), -1, 1) # improve (single multiplicity) roots by one application of Newton c = legder(c0) @@ -174,7 +178,14 @@ def leggauss_lobatto(deg): df = legval(x=x, c=legder(c)) x -= dy / df - w = 2 / (deg * (deg - 1) * legval(x=x, c=c0) ** 2) + w = 2 / (N * (N - 1) * legval(x=x, c=c0) ** 2) + + if not interior_only: + x = jnp.hstack([-1.0, x, 1.0]) + w_end = 2 / (deg * (deg - 1)) + w = jnp.hstack([w_end, w, w_end]) + + assert x.size == w.size == deg return x, w diff --git a/tests/test_fourier_bounce.py b/tests/test_fourier_bounce.py deleted file mode 100644 index 714afd453..000000000 --- a/tests/test_fourier_bounce.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Test interpolation to Clebsch coordinates and Fourier bounce integration.""" - -import numpy as np -import pytest -from matplotlib import pyplot as plt -from numpy.polynomial.chebyshev import chebinterpolate, chebroots -from tests.test_integrals import TestBounce1D -from tests.test_plotting import tol_1d - -from desc.backend import jnp -from desc.equilibrium import Equilibrium -from desc.equilibrium.coords import get_rtz_grid, map_coordinates -from desc.examples import get -from desc.grid import LinearGrid -from desc.integrals import Bounce2D -from desc.integrals.basis import FourierChebyshevBasis -from desc.integrals.bounce_utils import get_alpha, get_pitch -from desc.integrals.interp_utils import fourier_pts - - -@pytest.mark.unit -@pytest.mark.parametrize( - "alpha_0, iota, num_period, period", - [(0, np.sqrt(2), 1, 2 * np.pi), (0, np.arange(1, 3) * np.sqrt(2), 5, 2 * np.pi)], -) -def test_alpha_sequence(alpha_0, iota, num_period, period): - """Test field line poloidal label tracking.""" - iota = np.atleast_1d(iota) - alphas = get_alpha(alpha_0, iota, num_period, period) - assert alphas.shape == (iota.size, num_period) - for i in range(iota.size): - assert np.unique(alphas[i]).size == num_period, f"{iota} is irrational" - print(alphas) - - -class TestBouncePoints: - """Test that bounce points are computed correctly.""" - - @staticmethod - def _cheb_intersect(cheb, k): - cheb = cheb.copy() - cheb[0] = cheb[0] - k - roots = chebroots(cheb) - intersect = roots[ - np.logical_and(np.isreal(roots), np.abs(roots.real) <= 1) - ].real - return intersect - - @staticmethod - def _periodic_fun(nodes, M, N): - alpha, zeta = nodes.T - f = -2 * np.cos(1 / (0.1 + zeta**2)) + 2 - return f.reshape(M, N) - - @pytest.mark.unit - def test_bp1_first(self): - """Test that bounce points are computed correctly.""" - M, N = 1, 10 - domain = (-1, 1) - nodes = FourierChebyshevBasis.nodes(M, N, domain=domain) - f = self._periodic_fun(nodes, M, N) - fcb = FourierChebyshevBasis(f, domain=domain) - pcb = fcb.compute_cheb(fourier_pts(M)) - pitch = 1 / np.linspace(1, 4, 20) - bp1, bp2 = pcb.intersect1d(pitch) - pcb.check_intersect1d(bp1, bp2, pitch) - bp1, bp2 = TestBouncePoints.filter(bp1, bp2) - - def f(z): - return -2 * np.cos(1 / (0.1 + z**2)) + 2 - - r = self._cheb_intersect(chebinterpolate(f, N), 1 / pitch) - np.testing.assert_allclose(bp1, r[::2], rtol=1e-3) - np.testing.assert_allclose(bp2, r[1::2], rtol=1e-3) - - -@pytest.mark.unit -def test_fourier_chebyshev(rho=1, M=8, N=32, f=lambda B, pitch: B * pitch): - """Test bounce points...""" - eq = get("W7-X") - clebsch = FourierChebyshevBasis.nodes(M, N, L=rho) - desc_from_clebsch = map_coordinates( - eq, - clebsch, - inbasis=("rho", "alpha", "zeta"), - period=(np.inf, 2 * np.pi, np.inf), - ) - grid = LinearGrid( - rho=rho, M=eq.M_grid, N=eq.N_grid, sym=False, NFP=eq.NFP - ) # check if NFP!=1 works - data = eq.compute( - names=Bounce2D.required_names() + ["min_tz |B|", "max_tz |B|"], grid=grid - ) - fb = Bounce2D( - grid, data, M, N, desc_from_clebsch, check=True, warn=False - ) # TODO check true - pitch = get_pitch( - grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 - ) - result = fb.integrate(f, [], pitch) # noqa: F841 - - -@pytest.mark.unit -@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) -def test_drift(): - """Test bounce-averaged drift with analytical expressions.""" - eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5") - psi_boundary = eq.Psi / (2 * np.pi) - psi = 0.25 * psi_boundary - rho = np.sqrt(psi / psi_boundary) - np.testing.assert_allclose(rho, 0.5) - - # Make a set of nodes along a single fieldline. - grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) - data = eq.compute(["iota"], grid=grid_fsa) - iota = grid_fsa.compress(data["iota"]).item() - alpha = 0 - zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) - grid = get_rtz_grid( - eq, - rho, - alpha, - zeta, - coordinates="raz", - period=(np.inf, 2 * np.pi, np.inf), - iota=np.array([iota]), - ) - data = eq.compute( - Bounce2D.required_names() - + [ - "cvdrift", - "gbdrift", - "grad(psi)", - "grad(alpha)", - "shear", - "iota", - "psi", - "a", - ], - grid=grid, - ) - np.testing.assert_allclose(data["psi"], psi) - np.testing.assert_allclose(data["iota"], iota) - assert np.all(data["B^zeta"] > 0) - data["Bref"] = 2 * np.abs(psi_boundary) / data["a"] ** 2 - data["rho"] = rho - data["alpha"] = alpha - data["zeta"] = zeta - data["psi"] = grid.compress(data["psi"]) - data["iota"] = grid.compress(data["iota"]) - data["shear"] = grid.compress(data["shear"]) - - # Compute analytic approximation. - drift_analytic, cvdrift, gbdrift, pitch = TestBounce1D.drift_analytic(data) - # Compute numerical result. - grid = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) - data_2 = eq.compute( - names=Bounce2D.required_names() + ["cvdrift", "gbdrift"], grid=grid - ) - M, N = eq.M_grid, 20 - bounce = Bounce2D( - grid=grid, - data=data_2, - desc_from_clebsch=Bounce2D.desc_from_clebsch(eq, rho, M, N), - M=M, - N=N, - alpha_0=data["alpha"], - num_transit=1, - Bref=data["Bref"], - Lref=data["a"], - check=True, - plot=True, - ) - - def integrand_num(cvdrift, gbdrift, B, pitch): - g = jnp.sqrt(1 - pitch * B) - return (cvdrift * g) - (0.5 * g * gbdrift) + (0.5 * gbdrift / g) - - def integrand_den(B, pitch): - return 1 / jnp.sqrt(1 - pitch * B) - - normalization = -np.sign(data["psi"]) * data["Bref"] * data["a"] ** 2 - drift_numerical_num = bounce.integrate( - pitch=pitch[:, np.newaxis], - integrand=integrand_num, - f=Bounce2D.reshape_data( - grid, data_2["cvdrift"] * normalization, data_2["gbdrift"] * normalization - ), - num_well=1, - ) - drift_numerical_den = bounce.integrate( - pitch=pitch[:, np.newaxis], - integrand=integrand_den, - f=[], - num_well=1, - ) - drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) - msg = "There should be one bounce integral per pitch in this example." - assert drift_numerical.size == drift_analytic.size, msg - np.testing.assert_allclose(drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2) - - fig, ax = plt.subplots() - ax.plot(1 / pitch, drift_analytic) - ax.plot(1 / pitch, drift_numerical) - plt.show() - return fig diff --git a/tests/test_integrals.py b/tests/test_integrals.py index 037984432..e5d600f91 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -6,7 +6,7 @@ import pytest from jax import grad from matplotlib import pyplot as plt -from numpy.polynomial.chebyshev import chebgauss, chebweight +from numpy.polynomial.chebyshev import chebgauss, chebinterpolate, chebroots, chebweight from numpy.polynomial.legendre import leggauss from scipy import integrate from scipy.interpolate import CubicHermiteSpline @@ -17,11 +17,12 @@ from desc.basis import FourierZernikeBasis from desc.compute.utils import dot from desc.equilibrium import Equilibrium -from desc.equilibrium.coords import get_rtz_grid +from desc.equilibrium.coords import get_rtz_grid, map_coordinates from desc.examples import get from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid from desc.integrals import ( Bounce1D, + Bounce2D, DFTInterpolator, FFTInterpolator, line_integrals, @@ -34,20 +35,23 @@ surface_variance, virtual_casing_biot_savart, ) +from desc.integrals.basis import FourierChebyshevBasis from desc.integrals.bounce_utils import ( _get_extrema, bounce_points, + get_alpha, get_pitch, interp_to_argmin_B_hard, interp_to_argmin_B_soft, plot_ppoly, ) +from desc.integrals.interp_utils import fourier_pts from desc.integrals.quad_utils import ( automorphism_sin, bijection_from_disc, grad_automorphism_sin, grad_bijection_from_disc, - leggauss_lobatto, + leggauss_lob, tanh_sinh, ) from desc.integrals.singularities import _get_quadrature_nodes @@ -720,7 +724,7 @@ def test_biest_interpolators(self): np.testing.assert_allclose(g1, ff) -class TestBouncePoints: +class TestBounce1DPoints: """Test that bounce points are computed correctly.""" @staticmethod @@ -739,7 +743,7 @@ def test_z1_first(self): pitch = 2.0 intersect = B.solve(1 / pitch, extrapolate=False) z1, z2 = bounce_points(pitch, knots, B.c, B.derivative().c, check=True) - z1, z2 = TestBouncePoints.filter(z1, z2) + z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size np.testing.assert_allclose(z1, intersect[0::2]) np.testing.assert_allclose(z2, intersect[1::2]) @@ -754,7 +758,7 @@ def test_z2_first(self): pitch = 2.0 intersect = B.solve(1 / pitch, extrapolate=False) z1, z2 = bounce_points(pitch, k, B.c, B.derivative().c, check=True) - z1, z2 = TestBouncePoints.filter(z1, z2) + z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size np.testing.assert_allclose(z1, intersect[1:-1:2]) np.testing.assert_allclose(z2, intersect[0::2][1:]) @@ -771,7 +775,7 @@ def test_z1_before_extrema(self): dB_dz = B.derivative() pitch = 1 / B(dB_dz.roots(extrapolate=False))[3] + 1e-13 z1, z2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) - z1, z2 = TestBouncePoints.filter(z1, z2) + z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size intersect = B.solve(1 / pitch, extrapolate=False) np.testing.assert_allclose(z1[1], 1.982767, rtol=1e-6) @@ -794,7 +798,7 @@ def test_z2_before_extrema(self): dB_dz = B.derivative() pitch = 1 / B(dB_dz.roots(extrapolate=False))[2] z1, z2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) - z1, z2 = TestBouncePoints.filter(z1, z2) + z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size intersect = B.solve(1 / pitch, extrapolate=False) np.testing.assert_allclose(z1, intersect[[0, -2]]) @@ -817,7 +821,7 @@ def test_extrema_first_and_before_z1(self): pitch, k[2:], B.c[:, 2:], dB_dz.c[:, 2:], check=True, plot=False ) plot_ppoly(B, z1=z1, z2=z2, k=1 / pitch, start=k[2]) - z1, z2 = TestBouncePoints.filter(z1, z2) + z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size intersect = B.solve(1 / pitch, extrapolate=False) np.testing.assert_allclose(z1[0], 0.835319, rtol=1e-6) @@ -839,7 +843,7 @@ def test_extrema_first_and_before_z2(self): dB_dz = B.derivative() pitch = 1 / B(dB_dz.roots(extrapolate=False))[1] + 1e-13 z1, z2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) - z1, z2 = TestBouncePoints.filter(z1, z2) + z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size # Our routine correctly detects intersection, while scipy, jnp.root fails. intersect = B.solve(1 / pitch, extrapolate=False) @@ -871,20 +875,20 @@ def test_get_extrema(self): np.testing.assert_allclose(B_ext[idx], B_ext_scipy) -class TestBounceQuadrature: - """Test bounce quadrature accuracy.""" +def _mod_cheb_gauss(deg): + x, w = chebgauss(deg) + w /= chebweight(x) + return x, w - @staticmethod - def _mod_cheb_gauss(deg): - x, w = chebgauss(deg) - w /= chebweight(x) - return x, w - @staticmethod - def _mod_chebu_gauss(deg): - x, w = roots_chebyu(deg) - w *= chebweight(x) - return x, w +def _mod_chebu_gauss(deg): + x, w = roots_chebyu(deg) + w *= chebweight(x) + return x, w + + +class TestBounce1DQuadrature: + """Test bounce quadrature accuracy.""" @pytest.mark.unit @pytest.mark.parametrize( @@ -893,7 +897,7 @@ def _mod_chebu_gauss(deg): (True, tanh_sinh(40), None), (True, leggauss(25), "default"), (False, tanh_sinh(20), None), - (False, leggauss_lobatto(10), "default"), + (False, leggauss_lob(10), "default"), # sin automorphism still helps out chebyshev quadrature (True, _mod_cheb_gauss(30), "default"), (False, _mod_chebu_gauss(10), "default"), @@ -930,7 +934,7 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism): data, quad, check=True, - **kwargs + **kwargs, ) result = bounce.integrate(pitch, integrand, [], check=True) assert np.count_nonzero(result) == 1 @@ -964,14 +968,14 @@ def elliptic_incomplete(k2): # Scipy's elliptic integrals are broken. # https://github.com/scipy/scipy/issues/20525. k = np.sqrt(k2) - K = TestBounceQuadrature._adaptive_elliptic(K_integrand, k) - E = TestBounceQuadrature._adaptive_elliptic(E_integrand, k) + K = TestBounce1DQuadrature._adaptive_elliptic(K_integrand, k) + E = TestBounce1DQuadrature._adaptive_elliptic(E_integrand, k) # Make sure scipy's adaptive quadrature is not broken. np.testing.assert_allclose( - K, TestBounceQuadrature._fixed_elliptic(K_integrand, k, 10) + K, TestBounce1DQuadrature._fixed_elliptic(K_integrand, k, 10) ) np.testing.assert_allclose( - E, TestBounceQuadrature._fixed_elliptic(E_integrand, k, 10) + E, TestBounce1DQuadrature._fixed_elliptic(E_integrand, k, 10) ) I_0 = 4 / k * K @@ -985,32 +989,32 @@ def elliptic_incomplete(k2): # Check for math mistakes. np.testing.assert_allclose( I_2, - TestBounceQuadrature._adaptive_elliptic( + TestBounce1DQuadrature._adaptive_elliptic( lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * Z * np.sin(Z), k ), ) np.testing.assert_allclose( I_3, - TestBounceQuadrature._adaptive_elliptic( + TestBounce1DQuadrature._adaptive_elliptic( lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * Z * np.sin(Z), k ), ) np.testing.assert_allclose( I_4, - TestBounceQuadrature._adaptive_elliptic( + TestBounce1DQuadrature._adaptive_elliptic( lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.sin(Z) ** 2, k ), ) np.testing.assert_allclose( I_5, - TestBounceQuadrature._adaptive_elliptic( + TestBounce1DQuadrature._adaptive_elliptic( lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.sin(Z) ** 2, k ), ) # scipy fails np.testing.assert_allclose( I_6, - TestBounceQuadrature._fixed_elliptic( + TestBounce1DQuadrature._fixed_elliptic( lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.cos(Z), k, deg=10, @@ -1018,7 +1022,7 @@ def elliptic_incomplete(k2): ) np.testing.assert_allclose( I_7, - TestBounceQuadrature._adaptive_elliptic( + TestBounce1DQuadrature._adaptive_elliptic( lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.cos(Z), k ), ) @@ -1026,7 +1030,7 @@ def elliptic_incomplete(k2): class TestBounce1D: - """Test bounce integral methods that use one-dimensional local splines.""" + """Test bounce integration with one-dimensional local spline methods.""" @pytest.mark.unit def test_integrate_checks(self): @@ -1136,7 +1140,19 @@ def dB_dz(z): @staticmethod def drift_analytic(data): - """Compute analytic approximation for bounce-averaged binormal drift.""" + """Compute analytic approximation for bounce-averaged binormal drift. + + Returns + ------- + drift_analytic : jnp.ndarray + Analytic approximation for the true result that the numerical computation + should attempt to match. + cvdrift, gbdrift : jnp.ndarray + Numerically computed ``data["cvdrift"]` and ``data["gbdrift"]`` normalized + by some scale factors for this unit test. These should be fed to the bounce + integration as input. + + """ B = data["|B|"] / data["Bref"] B0 = np.mean(B) # epsilon should be changed to dimensionless, and computed in a way that @@ -1201,7 +1217,7 @@ def drift_analytic(data): pitch = get_pitch(np.min(B), np.max(B), 100)[1:] k2 = 0.5 * ((1 - pitch * B0) / (epsilon * pitch * B0) + 1) I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 = ( - TestBounceQuadrature.elliptic_incomplete(k2) + TestBounce1DQuadrature.elliptic_incomplete(k2) ) y = np.sqrt(2 * epsilon * pitch * B0) I_0, I_2, I_4, I_6 = map(lambda I: I / y, (I_0, I_2, I_4, I_6)) @@ -1348,3 +1364,199 @@ def integrand_grad(*args, **kwargs2): # Make sure bounce points get differentiated too. result = fun2(pitch) assert np.isfinite(result) and not np.isclose(result, truth, rtol=1e-1) + + +class TestBounce2DPoints: + """Test that bounce points are computed correctly.""" + + @staticmethod + def _cheb_intersect(cheb, k): + cheb = cheb.copy() + cheb[0] = cheb[0] - k + roots = chebroots(cheb) + intersect = roots[ + np.logical_and(np.isreal(roots), np.abs(roots.real) <= 1) + ].real + return intersect + + @staticmethod + def _periodic_fun(nodes, M, N): + alpha, zeta = nodes.T + f = -2 * np.cos(1 / (0.1 + zeta**2)) + 2 + return f.reshape(M, N) + + @pytest.mark.unit + def test_bp1_first(self): + """Test that bounce points are computed correctly.""" + M, N = 1, 10 + domain = (-1, 1) + nodes = FourierChebyshevBasis.nodes(M, N, domain=domain) + f = self._periodic_fun(nodes, M, N) + fcb = FourierChebyshevBasis(f, domain=domain) + pcb = fcb.compute_cheb(fourier_pts(M)) + pitch = 1 / np.linspace(1, 4, 20) + bp1, bp2 = pcb.intersect1d(pitch) + pcb.check_intersect1d(bp1, bp2, pitch) + bp1, bp2 = TestBounce1DPoints.filter(bp1, bp2) + + def f(z): + return -2 * np.cos(1 / (0.1 + z**2)) + 2 + + r = self._cheb_intersect(chebinterpolate(f, N), 1 / pitch) + np.testing.assert_allclose(bp1, r[::2], rtol=1e-3) + np.testing.assert_allclose(bp2, r[1::2], rtol=1e-3) + + +class TestBounce2D: + """Test bounce integration with two-dimensional pseudo-spectral methods.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "alpha_0, iota, num_period, period", + [ + (0, np.sqrt(2), 1, 2 * np.pi), + (0, np.arange(1, 3) * np.sqrt(2), 5, 2 * np.pi), + ], + ) + def test_alpha_sequence(self, alpha_0, iota, num_period, period): + """Test field line poloidal label tracking.""" + iota = np.atleast_1d(iota) + alphas = get_alpha(alpha_0, iota, num_period, period) + assert alphas.shape == (iota.size, num_period) + for i in range(iota.size): + assert np.unique(alphas[i]).size == num_period, f"{iota} is irrational" + print(alphas) + + @pytest.mark.unit + def test_fourier_chebyshev(self, rho=1, M=8, N=32, f=lambda B, pitch: B * pitch): + """Test bounce points...""" + eq = get("W7-X") + clebsch = FourierChebyshevBasis.nodes(M, N, L=rho) + desc_from_clebsch = map_coordinates( + eq, + clebsch, + inbasis=("rho", "alpha", "zeta"), + period=(np.inf, 2 * np.pi, np.inf), + ) + grid = LinearGrid( + rho=rho, M=eq.M_grid, N=eq.N_grid, sym=False, NFP=eq.NFP + ) # check if NFP!=1 works + data = eq.compute( + names=Bounce2D.required_names() + ["min_tz |B|", "max_tz |B|"], grid=grid + ) + fb = Bounce2D( + grid, data, M, N, desc_from_clebsch, check=True, warn=False + ) # TODO check true + pitch = get_pitch( + grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 + ) + result = fb.integrate(f, [], pitch) # noqa: F841 + + @pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) + def test_drift(self): + """Test bounce-averaged drift with analytical expressions.""" + eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5") + psi_boundary = eq.Psi / (2 * np.pi) + psi = 0.25 * psi_boundary + rho = np.sqrt(psi / psi_boundary) + np.testing.assert_allclose(rho, 0.5) + + # Make a set of nodes along a single fieldline. + grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) + data = eq.compute(["iota"], grid=grid_fsa) + iota = grid_fsa.compress(data["iota"]).item() + alpha = 0 + zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) + grid = get_rtz_grid( + eq, + rho, + alpha, + zeta, + coordinates="raz", + period=(np.inf, 2 * np.pi, np.inf), + iota=np.array([iota]), + ) + data = eq.compute( + Bounce2D.required_names() + + [ + "cvdrift", + "gbdrift", + "grad(psi)", + "grad(alpha)", + "shear", + "iota", + "psi", + "a", + ], + grid=grid, + ) + np.testing.assert_allclose(data["psi"], psi) + np.testing.assert_allclose(data["iota"], iota) + assert np.all(data["B^zeta"] > 0) + data["Bref"] = 2 * np.abs(psi_boundary) / data["a"] ** 2 + data["rho"] = rho + data["alpha"] = alpha + data["zeta"] = zeta + data["psi"] = grid.compress(data["psi"]) + data["iota"] = grid.compress(data["iota"]) + data["shear"] = grid.compress(data["shear"]) + + # Compute analytic approximation. + drift_analytic, cvdrift, gbdrift, pitch = TestBounce1D.drift_analytic(data) + # Compute numerical result. + grid = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + data_2 = eq.compute( + names=Bounce2D.required_names() + ["cvdrift", "gbdrift"], grid=grid + ) + M, N = eq.M_grid, 20 + bounce = Bounce2D( + grid=grid, + data=data_2, + desc_from_clebsch=Bounce2D.desc_from_clebsch(eq, rho, M, N), + M=M, + N=N, + alpha_0=data["alpha"], + num_transit=1, + Bref=data["Bref"], + Lref=data["a"], + check=True, + plot=True, + ) + + def integrand_num(cvdrift, gbdrift, B, pitch): + g = jnp.sqrt(1 - pitch * B) + return (cvdrift * g) - (0.5 * g * gbdrift) + (0.5 * gbdrift / g) + + def integrand_den(B, pitch): + return 1 / jnp.sqrt(1 - pitch * B) + + normalization = -np.sign(data["psi"]) * data["Bref"] * data["a"] ** 2 + drift_numerical_num = bounce.integrate( + pitch=pitch[:, np.newaxis], + integrand=integrand_num, + f=Bounce2D.reshape_data( + grid, + data_2["cvdrift"] * normalization, + data_2["gbdrift"] * normalization, + ), + num_well=1, + ) + drift_numerical_den = bounce.integrate( + pitch=pitch[:, np.newaxis], + integrand=integrand_den, + f=[], + num_well=1, + ) + drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) + msg = "There should be one bounce integral per pitch in this example." + assert drift_numerical.size == drift_analytic.size, msg + np.testing.assert_allclose( + drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2 + ) + + fig, ax = plt.subplots() + ax.plot(1 / pitch, drift_analytic) + ax.plot(1 / pitch, drift_numerical) + plt.show() + return fig diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py index 78b25f599..f2225c033 100644 --- a/tests/test_interp_utils.py +++ b/tests/test_interp_utils.py @@ -13,7 +13,7 @@ from scipy.fft import dct as sdct from scipy.fft import idct as sidct -from desc.backend import dct, idct, jnp, rfft +from desc.backend import dct, idct, rfft from desc.integrals.interp_utils import ( cheb_from_dct, cheb_pts, @@ -29,13 +29,6 @@ from desc.integrals.quad_utils import bijection_to_disc -def filter_not_nan(a): - """Filter out nan from ``a`` while asserting nan is padded at right.""" - is_nan = jnp.isnan(a) - assert jnp.array_equal(is_nan, jnp.sort(is_nan, axis=-1)) - return a[~is_nan] - - @pytest.mark.unit def test_poly_root(): """Test vectorized computation of cubic polynomial exact roots.""" @@ -70,15 +63,12 @@ def test_poly_root(): root = poly_root(c.T, sort=True, distinct=True) for j in range(c.shape[0]): unique_roots = np.unique(np.roots(c[j])) - root_filter = filter_not_nan(root[j]) - assert root_filter.size == unique_roots.size, j np.testing.assert_allclose( - actual=root_filter, - desired=unique_roots, - err_msg=str(j), + actual=root[j][~np.isnan(root[j])], desired=unique_roots, err_msg=str(j) ) c = np.array([0, 1, -1, -8, 12]) - root = filter_not_nan(poly_root(c, sort=True, distinct=True)) + root = poly_root(c, sort=True, distinct=True) + root = root[~np.isnan(root)] unique_root = np.unique(np.roots(c)) assert root.size == unique_root.size np.testing.assert_allclose(root, unique_root) @@ -102,11 +92,11 @@ def test_polyval_vec(): def test(x, c): val = polyval_vec(x=x, c=c) + c = np.moveaxis(c, 0, -1) + x = x[..., np.newaxis] np.testing.assert_allclose( val, - np.vectorize(np.polyval, signature="(m),(n)->(n)")( - np.moveaxis(c, 0, -1), x[..., np.newaxis] - ).squeeze(axis=-1), + np.vectorize(np.polyval, signature="(m),(n)->(n)")(c, x).squeeze(axis=-1), ) quartic = 5 @@ -125,6 +115,47 @@ def test(x, c): test(x, c) +def _f_1d(x): + """Test function for 1D FFT.""" + return np.cos(7 * x) + np.sin(x) - 33.2 + + +def _f_1d_nyquist_freq(): + return 7 + + +def _f_2d(x, y): + """Test function for 2D FFT.""" + x_freq, y_freq = 3, 5 + return ( + # something that's not separable + np.cos(x_freq * x) * np.sin(2 * x + y) + + np.sin(y_freq * y) * np.cos(x + 3 * y) + # DC terms + - 33.2 + + np.cos(x) + + np.cos(y) + ) + + +def _f_2d_nyquist_freq(): + x_freq_nyquist = 3 + 2 + y_freq_nyquist = 5 + 3 + return x_freq_nyquist, y_freq_nyquist + + +def _identity(x): + return x + + +def _f_non_periodic(z): + return np.sin(np.sqrt(2) * z) * np.cos(1 / (2 + z)) * np.cos(z**2) * z + + +def _f_algebraic(z): + return z**3 - 10 * z**6 - z - np.e + z**4 + + class TestFastInterp: """Test fast interpolation.""" @@ -145,23 +176,6 @@ def test_rfftfreq(self, M): """Make sure numpy uses Nyquist interpolant frequencies.""" np.testing.assert_allclose(np.fft.rfftfreq(M, d=1 / M), np.arange(M // 2 + 1)) - @staticmethod - def _interp_rfft_harmonic(xq, f): - M = f.shape[-1] - fq = jnp.linalg.vecdot( - harmonic_vander(xq, M), harmonic(rfft(f, norm="forward"), M) - ) - return fq - - @staticmethod - def _f_1d(x): - """Test function for 1D FFT.""" - return np.cos(7 * x) + np.sin(x) - 33.2 - - @staticmethod - def _f_1d_nyquist_freq(): - return 7 - @pytest.mark.unit @pytest.mark.parametrize( "func, n", @@ -176,29 +190,15 @@ def test_interp_rfft(self, func, n): x = np.linspace(0, 2 * np.pi, n, endpoint=False) assert not np.any(np.isclose(xq[..., np.newaxis], x)) f, fq = func(x), func(xq) - np.testing.assert_allclose(self._interp_rfft_harmonic(xq, f), fq) np.testing.assert_allclose(interp_rfft(xq, f), fq) - - @staticmethod - def _f_2d(x, y): - """Test function for 2D FFT.""" - x_freq, y_freq = 3, 5 - return ( - # something that's not separable - np.cos(x_freq * x) * np.sin(2 * x + y) - + np.sin(y_freq * y) * np.cos(x + 3 * y) - # DC terms - - 33.2 - + np.cos(x) - + np.cos(y) + M = f.shape[-1] + np.testing.assert_allclose( + np.sum( + harmonic_vander(xq, M) * harmonic(rfft(f, norm="forward"), M), axis=-1 + ), + fq, ) - @staticmethod - def _f_2d_nyquist_freq(): - x_freq_nyquist = 3 + 2 - y_freq_nyquist = 5 + 3 - return x_freq_nyquist, y_freq_nyquist - @pytest.mark.xfail( reason="Numpy, jax, and scipy need to fix bug with 2D FFT (fft2)." ) @@ -228,25 +228,12 @@ def test_interp_rfft2(self, func, m, n): truth, ) - @staticmethod - def _identity(x): - # Identity map known for bad Gibbs; - # only if distribution of spectral coefficients is correct will DCT - # recover Chebyshev interpolation, avoiding Gibbs and Runge. - return x - - @staticmethod - def _f_non_periodic(z): - return np.sin(np.sqrt(2) * z) * np.cos(1 / (2 + z)) * np.cos(z**2) * z - - @staticmethod - def _f_algebraic(z): - return z**3 - 10 * z**6 - z - np.e + z**4 - @pytest.mark.unit @pytest.mark.parametrize( "f, M, lobatto", [ + # Identity map known for bad Gibbs; if discrete Chebyshev transform + # implemented correctly then won't see Gibbs. (_identity, 2, False), (_identity, 3, False), (_identity, 3, True), @@ -316,7 +303,7 @@ def test_interp_dct(self, f, M): z = cheb_pts(M) fz = f(z) np.testing.assert_allclose(c0, cheb_from_dct(dct(fz, 2) / M), atol=1e-13) - if np.allclose(self._f_algebraic(z), fz): + if np.allclose(_f_algebraic(z), fz): np.testing.assert_allclose( cheb2poly(c0), np.array([-np.e, -1, 0, 1, 1, 0, -10]), atol=1e-13 ) diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py index 662e9fcef..a23b81c8d 100644 --- a/tests/test_quad_utils.py +++ b/tests/test_quad_utils.py @@ -12,6 +12,7 @@ grad_automorphism_arcsin, grad_automorphism_sin, grad_bijection_from_disc, + leggauss_lob, tanh_sinh, ) from desc.utils import only1 @@ -66,4 +67,27 @@ def test_automorphism(): @pytest.mark.unit def test_leggauss_lobatto(): - """Test that quadrature points and weights are correct.""" + """Test quadrature points and weights against known values.""" + with pytest.raises(ValueError): + x, w = leggauss_lob(1) + x, w = leggauss_lob(0, True) + assert x.size == w.size == 0 + + x, w = leggauss_lob(2) + np.testing.assert_allclose(x, [-1, 1]) + np.testing.assert_allclose(w, [1, 1]) + + x, w = leggauss_lob(3) + np.testing.assert_allclose(x, [-1, 0, 1]) + np.testing.assert_allclose(w, [1 / 3, 4 / 3, 1 / 3]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + x, w = leggauss_lob(4) + np.testing.assert_allclose(x, [-1, -np.sqrt(1 / 5), np.sqrt(1 / 5), 1]) + np.testing.assert_allclose(w, [1 / 6, 5 / 6, 5 / 6, 1 / 6]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + x, w = leggauss_lob(5) + np.testing.assert_allclose(x, [-1, -np.sqrt(3 / 7), 0, np.sqrt(3 / 7), 1]) + np.testing.assert_allclose(w, [1 / 10, 49 / 90, 32 / 45, 49 / 90, 1 / 10]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1]))