Skip to content

Commit

Permalink
Merging changes from downstream branch #1290
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Oct 20, 2024
1 parent ed7833f commit 927ab54
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 26 deletions.
61 changes: 39 additions & 22 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def check_points(self, points, pitch_inv, *, plot=True, **kwargs):
"""Check that bounce points are computed correctly."""

@abstractmethod
def integrate(self, integrand, pitch_inv, f=None, weight=None, points=None):
def integrate(
self, integrand, pitch_inv, f=None, weight=None, points=None, *, quad=None
):
"""Bounce integrate ∫ f(λ, ℓ) dℓ."""


Expand Down Expand Up @@ -82,8 +84,9 @@ class Bounce2D(Bounce):
the particle's guiding center trajectory traveling in the direction of increasing
field-line-following coordinate ζ.
Brief description of algorithm
------------------------------
Overview
--------
Magnetic field line with label α, defined by B = ∇ρ × ∇α, is determined from
α : ρ, θ, ζ ↦ θ + λ(ρ,θ,ζ) − ι(ρ) [ζ + ω(ρ,θ,ζ)]
Interpolate Fourier-Chebyshev series to DESC poloidal coordinate.
Expand All @@ -102,8 +105,8 @@ class Bounce2D(Bounce):
In that case, supply the single valued parts, which will be interpolated
with FFTs, and use the provided coordinates θ,ζ ∈ ℝ to compose G.
Notes for developers
--------------------
Notes
-----
For applications which reduce to computing a nonlinear function of distance
along field lines between bounce points, it is required to identify these
points with field-line-following coordinates. (In the special case of a linear
Expand Down Expand Up @@ -234,8 +237,8 @@ class Bounce2D(Bounce):
* 2D interpolation enables tracing the field line for many toroidal transits.
* The drawback is that evaluating a Fourier series with resolution F at Q
non-uniform quadrature points takes 𝒪([F+Q] log[F] log[1/ε]) time
whereas cubic splines take 𝒪(C Q) time. Still, F decreases as
NFP increases whereas C increases, and Q >> F and C.
whereas cubic splines take 𝒪(C Q) time. However, as NFP increases,
F decreases whereas C increases. Also, Q >> F and Q >> C.
Attributes
----------
Expand Down Expand Up @@ -288,7 +291,7 @@ def __init__(
theta : jnp.ndarray
Shape (num rho, X, Y).
DESC coordinates θ sourced from the Clebsch coordinates
``FourierChebyshevSeries.nodes(M,N,L,domain=(0,2*jnp.pi))``.
``FourierChebyshevSeries.nodes(X,Y,rho,domain=(0,2*jnp.pi))``.
Y_B : int
Desired Chebyshev spectral resolution for |B|.
Default is to double the resolution of ``theta``.
Expand Down Expand Up @@ -325,8 +328,11 @@ def __init__(
Flag for debugging. Must be false for JAX transformations.
spline : bool
Whether to use cubic splines to compute bounce points.
Default is true. This is useful since the efficient root-finding
on Chebyshev series algorithm is not yet implemented.
Default is true, because the algorithm for efficient root-finding on
Chebyshev series algorithm is not yet implemented.
When using splines, it is recommended to reduce the ``num_well``
parameter in the ``points`` method from ``3*Y_B*num_transit`` to
``Y_B*num_transit``.
"""
errorif(grid.sym, NotImplementedError, msg="Need grid that works with FFTs.")
Expand Down Expand Up @@ -413,9 +419,8 @@ def compute_theta(eq, X=16, Y=32, rho=1.0, clebsch=None, **kwargs):
rho : float or jnp.ndarray
Flux surfaces labels in [0, 1] on which to compute.
clebsch : jnp.ndarray
Optional, Clebsch coordinate tensor-product grid (ρ, α, ζ).
``FourierChebyshevSeries.nodes(M,N,L,domain=(0,2*jnp.pi))``.
If given, ``rho`` is ignored.
Optional, precomputed Clebsch coordinate tensor-product grid (ρ, α, ζ).
``FourierChebyshevSeries.nodes(X,Y,rho,domain=(0,2*jnp.pi))``.
kwargs
Additional parameters to supply to the coordinate mapping function.
See ``desc.equilibrium.Equilibrium.map_coordinates``.
Expand Down Expand Up @@ -559,6 +564,7 @@ def integrate(
*,
check=False,
plot=False,
quad=None,
):
"""Bounce integrate ∫ f(λ, ℓ) dℓ.
Expand Down Expand Up @@ -607,6 +613,9 @@ def integrate(
plot : bool
Whether to plot the quantities in the integrand interpolated to the
quadrature points of each integral. Ignored if ``check`` is false.
quad : tuple[jnp.ndarray]
Optional quadrature points and weights. If given this overrides
the quadrature chosen when this object was made.
Returns
-------
Expand All @@ -626,7 +635,15 @@ def integrate(
pitch_inv = atleast_nd(self._c["T(z)"].cheb.ndim - 1, pitch_inv).T

result = self._integrate(
integrand, pitch_inv, setdefault(f, []), z1, z2, check, plot
self._x if quad is None else quad[0],
self._w if quad is None else quad[1],
integrand,
pitch_inv,
setdefault(f, []),
z1,
z2,
check,
plot,
)
if weight is not None:
errorif(

Check warning on line 649 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L649

Added line #L649 was not covered by tests
Expand All @@ -645,7 +662,7 @@ def integrate(
)
return _swap_pl(result)

def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot):
def _integrate(self, x, w, integrand, pitch_inv, f, z1, z2, check, plot):
"""Bounce integrate ∫ f(λ, ℓ) dℓ.
Parameters
Expand All @@ -665,10 +682,10 @@ def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot):
"""
if not isinstance(f, (list, tuple)):
f = [f]
shape = [*z1.shape, self._x.size] # num pitch, num rho, num well, num quad
shape = [*z1.shape, x.size] # num pitch, num rho, num well, num quad
# ζ ∈ ℝ and θ ∈ ℝ coordinates of quadrature points
zeta = flatten_matrix(
bijection_from_disc(self._x, z1[..., jnp.newaxis], z2[..., jnp.newaxis])
bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis])
)
theta = self._c["T(z)"].eval1d(zeta)

Expand Down Expand Up @@ -708,7 +725,7 @@ def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot):
integrand(*f, B=B, pitch=1 / pitch_inv[..., jnp.newaxis], zeta=zeta)
* B
/ B_sup_z
).reshape(shape).dot(self._w) * grad_bijection_from_disc(z1, z2)
).reshape(shape).dot(w) * grad_bijection_from_disc(z1, z2)

if check:
shape[-3], shape[0] = shape[0], shape[-3]
Expand All @@ -722,15 +739,15 @@ def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot):

return result

def compute_length(self, quad=None):
def compute_fieldline_length(self, quad=None):
"""Compute the proper length of the field line ∫ dℓ / |B|.
Parameters
----------
quad : tuple[jnp.ndarray]
Quadrature points xₖ and weights wₖ for the approximate evaluation
of the integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ).
Number of points equal to half the Chebyshev resolution of |B| works well.
Resolution equal to half the Chebyshev resolution of |B| works well.
Returns
-------
Expand Down Expand Up @@ -876,8 +893,8 @@ class Bounce1D(Bounce):
the particle's guiding center trajectory traveling in the direction of increasing
field-line-following coordinate ζ.
Notes for developers
--------------------
Notes
-----
For applications which reduce to computing a nonlinear function of distance
along field lines between bounce points, it is required to identify these
points with field-line-following coordinates. (In the special case of a linear
Expand Down
6 changes: 2 additions & 4 deletions tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,6 @@ def test_quad_compare(self, is_strong, B):
"""Compare quadratures in W-shaped wells."""
x = np.linspace(-1, 1, 1000)
plt.plot(x, B(x))
plt.show()

def func(x):
w1 = jnp.sqrt(jnp.clip(2 - B(x), 0, jnp.inf))
Expand All @@ -1069,7 +1068,6 @@ def func(x):
return w1

plt.plot(x, func(x))
plt.show()

truth, info = quadax.quadts(func, interval=(-1, 1))
print("\n" + 50 * "---" + f"\nTrue value: {truth}, neval: {info[1]}")
Expand Down Expand Up @@ -1760,8 +1758,8 @@ def test_bounce2d_checks(self):
print("ρ:", rho[l])

np.testing.assert_allclose(
bounce.compute_length(),
# Computed data below through <L|r,a> with Simpson's rule at 800 nodes.
bounce.compute_fieldline_length(),
# Computed below through "fieldline length" with Simpson's rule 800 points.
# The difference is likely due to interpolation and floating point error.
# (On the version of JAX on which rtol was set, there is a bug with DCT
# and FFT that limit the accuracy to something comparable to 32 bit).
Expand Down

0 comments on commit 927ab54

Please sign in to comment.