Skip to content

Commit

Permalink
Address review comments and fix regression in batch argument from rec…
Browse files Browse the repository at this point in the history
…ent commit

Ensure batch=False bounce integration is done in test to catch future regressions
  • Loading branch information
unalmis committed Aug 30, 2024
1 parent 75c13fd commit 446c0b7
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 42 deletions.
5 changes: 3 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
imap = jax.lax.map
from jax.experimental.ode import odeint
from jax.lax import cond, fori_loop, scan, switch, while_loop
from jax.nn import softmax
from jax.nn import softmax as softargmax
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
Expand Down Expand Up @@ -422,7 +422,8 @@ def tangent_solve(g, y):
qr,
solve_triangular,
)
from scipy.special import gammaln, logsumexp, softmax # noqa: F401
from scipy.special import gammaln, logsumexp # noqa: F401
from scipy.special import softmax as softargmax # noqa: F401

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

Expand Down
18 changes: 9 additions & 9 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def __init__(
source=(0, 1),
destination=(-1, -2),
)
assert self.B.shape == (grid.num_alpha, grid.num_rho, grid.num_zeta - 1, 4)
self._dB_dz = polyder_vec(self.B)
assert self.B.shape == (grid.num_alpha, grid.num_rho, grid.num_zeta - 1, 4)
assert self._dB_dz.shape == (grid.num_alpha, grid.num_rho, grid.num_zeta - 1, 3)

@staticmethod
Expand Down Expand Up @@ -210,9 +210,9 @@ def points(self, pitch_inv, num_well=None):
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to compute the bounce points at each field line. 1/λ(ρ,α) is
1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
are interpreted as the indices that correspond to that field line.
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
pitch along each field line. This is useful if ``num_well`` tightly
Expand All @@ -232,7 +232,7 @@ def points(self, pitch_inv, num_well=None):
that the straight line path between ``z1`` and ``z2`` resides in the
epigraph of |B|.
If there were less than ``num_wells`` wells detected along a field line,
If there were less than ``num_well`` wells detected along a field line,
then the last axis, which enumerates bounce points for a particular field
line and pitch, is padded with zero.
Expand All @@ -251,9 +251,9 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs):
epigraph of |B|.
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to compute the bounce points at each field line. 1/λ(ρ,α) is
1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
are interpreted as the indices that correspond to that field line.
plot : bool
Whether to plot stuff.
kwargs
Expand Down Expand Up @@ -298,9 +298,9 @@ def integrate(
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to compute the bounce integrals of each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by
``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
integrand : callable
The composition operator on the set of functions in ``f`` that maps the
functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the
Expand Down
53 changes: 28 additions & 25 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from interpax import PPoly
from matplotlib import pyplot as plt

from desc.backend import imap, jnp
from desc.backend import softmax as softargmax
from desc.backend import imap, jnp, softargmax
from desc.integrals.basis import _add2legend, _in_epigraph_and, _plot_intersect
from desc.integrals.interp_utils import (
interp1d_Hermite_vec,
Expand Down Expand Up @@ -85,9 +84,9 @@ def _check_spline_shape(knots, g, dg_dz, pitch_inv=None):
last axis enumerates the polynomials that compose a particular spline.
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to compute the bounce integrals of each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
1/λ values. 1/λ(α,ρ) is specified by ``pitch_inv[...,α,ρ]`` where in
the latter the labels are interpreted as the indices that correspond
to that field line.
"""
errorif(knots.ndim != 1, msg=f"knots should be 1d; got shape {knots.shape}.")
Expand Down Expand Up @@ -127,9 +126,9 @@ def bounce_points(
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to compute the bounce points at each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
1/λ values to compute the bounce points. 1/λ(α,ρ) is specified by
``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
knots : jnp.ndarray
Shape (N, ).
ζ coordinates of spline knots. Must be strictly increasing.
Expand Down Expand Up @@ -168,7 +167,7 @@ def bounce_points(
that the straight line path between ``z1`` and ``z2`` resides in the
epigraph of |B|.
If there were less than ``num_wells`` wells detected along a field line,
If there were less than ``num_well`` wells detected along a field line,
then the last axis, which enumerates bounce points for a particular field
line and pitch, is padded with zero.
Expand Down Expand Up @@ -321,9 +320,9 @@ def bounce_quadrature(
epigraph of |B|.
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to evaluate the bounce integrals of each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by
``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
integrand : callable
The composition operator on the set of functions in ``f`` that maps the
functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the
Expand Down Expand Up @@ -357,16 +356,15 @@ def bounce_quadrature(
-------
result : jnp.ndarray
Shape (P, M, L, num_well).
First axis enumerates pitch values. Second axis enumerates the field lines.
Third axis enumerates the flux surfaces. Last axis enumerates the bounce
integrals.
Last axis enumerates the bounce integrals for a given pitch, field line,
and flux surface.
"""
errorif(x.ndim != 1 or x.shape != w.shape)
errorif(z1.ndim != 4 or z1.shape != z2.shape)
errorif(pitch_inv.ndim != 3)
if not isinstance(f, (list, tuple)):
f = list(f)
f = [f] if isinstance(f, (jnp.ndarray, np.ndarray)) else list(f)

Check warning on line 367 in desc/integrals/bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_utils.py#L367

Added line #L367 was not covered by tests

# Integrate and complete the change of variable.
if batch:
Expand Down Expand Up @@ -441,17 +439,15 @@ def _interpolate_and_integrate(
-------
result : jnp.ndarray
Shape Q.shape[:-1].
Quadrature for every pitch.
Quadrature result.
"""
assert w.ndim == 1
assert 3 < Q.ndim < 6 and Q.shape[0] == pitch_inv.shape[0] and Q.shape[-1] == w.size
assert data["|B|"].shape[-1] == knots.size

if Q.ndim == 5:
pitch_inv = pitch_inv[..., jnp.newaxis]
shape = Q.shape
Q = flatten_matrix(Q)
Q = Q.reshape(*Q.shape[:3], -1)
b_sup_z = interp1d_Hermite_vec(
Q,
knots,
Expand All @@ -464,7 +460,14 @@ def _interpolate_and_integrate(
# that do not preserve smoothness can be captured.
f = [interp1d_vec(Q, knots, f_i, method=method) for f_i in f]
result = jnp.dot(
(integrand(*f, B=B, pitch=1 / pitch_inv) / b_sup_z).reshape(shape),
(
integrand(
*f,
B=B,
pitch=1 / pitch_inv[..., jnp.newaxis],
)
/ b_sup_z
).reshape(shape),
w,
)
if check:
Expand Down Expand Up @@ -529,7 +532,7 @@ def _plot_check_interp(Q, V, name=""):
doing debugging, so we don't include an option to plot these
in the public API of Bounce1D.
"""
for idx in np.ndindex(Q.shape[:-2]):
for idx in np.ndindex(Q.shape[:3]):
marked = jnp.nonzero(jnp.any(Q[idx] != 0.0, axis=-1))[0]
if marked.size == 0:
continue
Expand Down Expand Up @@ -663,9 +666,9 @@ def interp_to_argmin(
z1 = atleast_nd(4, z1)
z2 = atleast_nd(4, z2)
ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0)
# JAX softmax(x) does the proper shift to compute softmax(x - max(x)), but it's
# still not a good idea to compute over a large length scale, so we warn in
# docstring to choose upper sentinel properly.
# Our softargmax(x) does the proper shift to compute softargmax(x - max(x)),
# but it's still not a good idea to compute over a large length scale, so we
# warn in docstring to choose upper sentinel properly.
argmin = softargmax(
beta * _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel),
axis=-1,
Expand Down
8 changes: 4 additions & 4 deletions desc/integrals/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Notes
-----
These polynomial utilities are chosen for performance on gpu when among
methods that have the best (asymptotic) algorithmic complexity. For example,
we prefer not to use Horner's method.
These polynomial utilities are chosen for performance on gpu among
methods that have the best (asymptotic) algorithmic complexity.
For example, we prefer to not use Horner's method.
"""

from functools import partial
Expand Down Expand Up @@ -159,7 +159,7 @@ def polyroot_vec(
-------
r : jnp.ndarray
Shape (..., *c.shape[:-1], c.shape[-1] - 1).
The roots of the polynomial, iterated over the last axis.First
The roots of the polynomial, iterated over the last axis.
"""
get_only_real_roots = not (a_min is None and a_max is None)
Expand Down
1 change: 0 additions & 1 deletion desc/integrals/quad_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def get_quadrature(quad, automorphism):
x, w = quad
assert x.ndim == w.ndim == 1
if automorphism is not None:
# Apply automorphisms to supress singularities.
auto, grad_auto = automorphism
w = w * grad_auto(x)
# Recall bijection_from_disc(auto(x), ζ₁, ζ₂) = ζ.
Expand Down
3 changes: 2 additions & 1 deletion tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ def test_z1_before_extrema(self):
@pytest.mark.unit
def test_z2_before_extrema(self):
"""Case where local minimum is the shared intersect between two wells."""
# To make sure both regions in hypgraph left and right of extrema are not
# To make sure both regions in hypograph left and right of extrema are not
# integrated over.
start = -1.2 * np.pi
end = -2 * start
Expand Down Expand Up @@ -1097,6 +1097,7 @@ def test_bounce1d_checks(self):
pitch_inv,
integrand=TestBounce1D._example_denominator,
check=True,
batch=False,
)
avg = safediv(num, den)
assert np.isfinite(avg).all() and np.count_nonzero(avg)
Expand Down

0 comments on commit 446c0b7

Please sign in to comment.