-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change field line integration to use diffrax #610
Changes from 29 commits
a15ad3e
6896171
c3fde2c
6ca2471
0bc5340
8b70958
2900dfa
ab4b1d4
98512da
140633c
c6118c1
8107318
ee285b4
1705508
b9cee23
9709281
60cde73
062fc63
e81e557
5a38899
ee6fa85
59f9175
2dab0a3
4871332
534576a
f286045
b9eb9ec
7fe6771
3c2a82d
129e08d
b2adabe
acf1c98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,14 +1,23 @@ | ||||||||||||||
"""Classes for magnetic fields.""" | ||||||||||||||
|
||||||||||||||
import warnings | ||||||||||||||
from abc import ABC, abstractmethod | ||||||||||||||
from collections.abc import MutableSequence | ||||||||||||||
|
||||||||||||||
import numpy as np | ||||||||||||||
import scipy.linalg | ||||||||||||||
from diffrax import ( | ||||||||||||||
DiscreteTerminatingEvent, | ||||||||||||||
ODETerm, | ||||||||||||||
PIDController, | ||||||||||||||
SaveAt, | ||||||||||||||
Tsit5, | ||||||||||||||
diffeqsolve, | ||||||||||||||
) | ||||||||||||||
from interpax import approx_df, interp1d, interp2d, interp3d | ||||||||||||||
from netCDF4 import Dataset, chartostring, stringtochar | ||||||||||||||
|
||||||||||||||
from desc.backend import fori_loop, jit, jnp, odeint, sign | ||||||||||||||
from desc.backend import fori_loop, jit, jnp, sign | ||||||||||||||
from desc.basis import ( | ||||||||||||||
ChebyshevDoubleFourierBasis, | ||||||||||||||
ChebyshevPolynomial, | ||||||||||||||
|
@@ -2175,11 +2184,13 @@ | |||||||||||||
rtol=1e-8, | ||||||||||||||
atol=1e-8, | ||||||||||||||
maxstep=1000, | ||||||||||||||
min_step_size=1e-8, | ||||||||||||||
solver=Tsit5(), | ||||||||||||||
bounds_R=(0, np.inf), | ||||||||||||||
bounds_Z=(-np.inf, np.inf), | ||||||||||||||
decay_accel=1e6, | ||||||||||||||
**kwargs, | ||||||||||||||
): | ||||||||||||||
"""Trace field lines by integration. | ||||||||||||||
"""Trace field lines by integration, using diffrax package. | ||||||||||||||
|
||||||||||||||
Parameters | ||||||||||||||
---------- | ||||||||||||||
|
@@ -2191,34 +2202,29 @@ | |||||||||||||
and the negative toroidal angle for negative Bphi | ||||||||||||||
field : MagneticField | ||||||||||||||
source of magnetic field to integrate | ||||||||||||||
params: dict | ||||||||||||||
params: dict, optional | ||||||||||||||
parameters passed to field | ||||||||||||||
source_grid : Grid, optional | ||||||||||||||
Collocation points used to discretize source field. | ||||||||||||||
rtol, atol : float | ||||||||||||||
relative and absolute tolerances for ode integration | ||||||||||||||
maxstep : int | ||||||||||||||
maximum number of steps between different phis | ||||||||||||||
min_step_size: float | ||||||||||||||
minimum step size (in phi) that the integration can take. default is 1e-8 | ||||||||||||||
solver: diffrax.Solver | ||||||||||||||
diffrax Solver object to use in integration, | ||||||||||||||
defaults to Tsit5(), a RK45 explicit solver | ||||||||||||||
bounds_R : tuple of (float,float), optional | ||||||||||||||
R bounds for field line integration bounding box. | ||||||||||||||
If supplied, the RHS of the field line equations will be | ||||||||||||||
multiplied by exp(-r) where r is the distance to the bounding box, | ||||||||||||||
this is meant to prevent the field lines which escape to infinity from | ||||||||||||||
slowing the integration down by being traced to infinity. | ||||||||||||||
defaults to (0,np.inf) | ||||||||||||||
R bounds for field line integration bounding box. Trajectories that leave this | ||||||||||||||
box will be stopped, and NaN returned for points outside the box. | ||||||||||||||
Defaults to (0,np.inf) | ||||||||||||||
bounds_Z : tuple of (float,float), optional | ||||||||||||||
Z bounds for field line integration bounding box. | ||||||||||||||
If supplied, the RHS of the field line equations will be | ||||||||||||||
multiplied by exp(-r) where r is the distance to the bounding box, | ||||||||||||||
this is meant to prevent the field lines which escape to infinity from | ||||||||||||||
slowing the integration down by being traced to infinity. | ||||||||||||||
Z bounds for field line integration bounding box. Trajectories that leave this | ||||||||||||||
box will be stopped, and NaN returned for points outside the box. | ||||||||||||||
Defaults to (-np.inf,np.inf) | ||||||||||||||
decay_accel : float, optional | ||||||||||||||
An extra factor to the exponential that decays the RHS, i.e. | ||||||||||||||
the RHS is multiplied by exp(-r * decay_accel), this is to | ||||||||||||||
accelerate the decay of the RHS and stop the integration sooner | ||||||||||||||
after exiting the bounds. Defaults to 1e6 | ||||||||||||||
|
||||||||||||||
kwargs: dict | ||||||||||||||
keyword arguments to be passed into the ``diffrax.diffeqsolve`` | ||||||||||||||
|
||||||||||||||
Returns | ||||||||||||||
------- | ||||||||||||||
|
@@ -2228,60 +2234,64 @@ | |||||||||||||
""" | ||||||||||||||
r0, z0, phis = map(jnp.asarray, (r0, z0, phis)) | ||||||||||||||
assert r0.shape == z0.shape, "r0 and z0 must have the same shape" | ||||||||||||||
assert decay_accel > 0, "decay_accel must be positive" | ||||||||||||||
rshape = r0.shape | ||||||||||||||
r0 = r0.flatten() | ||||||||||||||
z0 = z0.flatten() | ||||||||||||||
x0 = jnp.array([r0, phis[0] * jnp.ones_like(r0), z0]).T | ||||||||||||||
|
||||||||||||||
@jit | ||||||||||||||
def odefun(rpz, s): | ||||||||||||||
def odefun(s, rpz, args): | ||||||||||||||
rpz = rpz.reshape((3, -1)).T | ||||||||||||||
r = rpz[:, 0] | ||||||||||||||
z = rpz[:, 2] | ||||||||||||||
# if bounds are given, will decay the magnetic field line eqn | ||||||||||||||
# RHS if the trajectory is outside of bounds to avoid | ||||||||||||||
# integrating the field line to infinity, which is costly | ||||||||||||||
# and not useful in most cases | ||||||||||||||
decay_factor = jnp.where( | ||||||||||||||
jnp.array( | ||||||||||||||
[ | ||||||||||||||
jnp.less(r, bounds_R[0]), | ||||||||||||||
jnp.greater(r, bounds_R[1]), | ||||||||||||||
jnp.less(z, bounds_Z[0]), | ||||||||||||||
jnp.greater(z, bounds_Z[1]), | ||||||||||||||
] | ||||||||||||||
), | ||||||||||||||
jnp.array( | ||||||||||||||
[ | ||||||||||||||
# we multiply by decay_accel to accelerate the decay so that the | ||||||||||||||
# integration is stopped soon after the bounds are exited. | ||||||||||||||
jnp.exp(-(decay_accel * (r - bounds_R[0]) ** 2)), | ||||||||||||||
jnp.exp(-(decay_accel * (r - bounds_R[1]) ** 2)), | ||||||||||||||
jnp.exp(-(decay_accel * (z - bounds_Z[0]) ** 2)), | ||||||||||||||
jnp.exp(-(decay_accel * (z - bounds_Z[1]) ** 2)), | ||||||||||||||
] | ||||||||||||||
), | ||||||||||||||
1.0, | ||||||||||||||
) | ||||||||||||||
# multiply all together, the conditions that are not violated | ||||||||||||||
# are just one while the violated ones are continuous decaying exponentials | ||||||||||||||
decay_factor = jnp.prod(decay_factor, axis=0) | ||||||||||||||
|
||||||||||||||
br, bp, bz = field.compute_magnetic_field( | ||||||||||||||
rpz, params, basis="rpz", source_grid=source_grid | ||||||||||||||
).T | ||||||||||||||
return ( | ||||||||||||||
decay_factor | ||||||||||||||
* jnp.array( | ||||||||||||||
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)] | ||||||||||||||
).squeeze() | ||||||||||||||
) | ||||||||||||||
return jnp.array( | ||||||||||||||
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)] | ||||||||||||||
).squeeze() | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not that important if you don't change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's not that we're normalizing the step or anything, its just because we're using the toroidal angle as our "time" coordinate so that modifies the ODE slightly. |
||||||||||||||
|
||||||||||||||
# diffrax parameters | ||||||||||||||
|
||||||||||||||
def default_terminating_event_fxn(state, **kwargs): | ||||||||||||||
R_out = jnp.any(jnp.array([state.y[0] < bounds_R[0], state.y[0] > bounds_R[1]])) | ||||||||||||||
Z_out = jnp.any(jnp.array([state.y[2] < bounds_Z[0], state.y[2] > bounds_Z[1]])) | ||||||||||||||
return jnp.any(jnp.array([R_out, Z_out])) | ||||||||||||||
|
||||||||||||||
kwargs.setdefault( | ||||||||||||||
"stepsize_controller", PIDController(rtol=rtol, atol=atol, dtmin=min_step_size) | ||||||||||||||
ddudt marked this conversation as resolved.
Show resolved
Hide resolved
YigitElma marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
) | ||||||||||||||
kwargs.setdefault( | ||||||||||||||
"discrete_terminating_event", | ||||||||||||||
DiscreteTerminatingEvent(default_terminating_event_fxn), | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
term = ODETerm(odefun) | ||||||||||||||
saveat = SaveAt(ts=phis) | ||||||||||||||
|
||||||||||||||
intfun = lambda x: diffeqsolve( | ||||||||||||||
term, | ||||||||||||||
solver, | ||||||||||||||
y0=x, | ||||||||||||||
t0=phis[0], | ||||||||||||||
t1=phis[-1], | ||||||||||||||
saveat=saveat, | ||||||||||||||
max_steps=maxstep * len(phis), | ||||||||||||||
dt0=min_step_size, | ||||||||||||||
**kwargs, | ||||||||||||||
).ys | ||||||||||||||
|
||||||||||||||
# suppress warnings till its fixed upstream: | ||||||||||||||
# https://github.com/patrick-kidger/diffrax/issues/445 | ||||||||||||||
# also ignore deprecation warning for now until we actually need to deal with it | ||||||||||||||
with warnings.catch_warnings(): | ||||||||||||||
warnings.filterwarnings("ignore", message="unhashable type") | ||||||||||||||
warnings.filterwarnings("ignore", message="`diffrax.*discrete_terminating") | ||||||||||||||
x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) | ||||||||||||||
|
||||||||||||||
x = jnp.where(jnp.isinf(x), jnp.nan, x) | ||||||||||||||
r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape)) | ||||||||||||||
z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape)) | ||||||||||||||
|
||||||||||||||
intfun = lambda x: odeint(odefun, x, phis, rtol=rtol, atol=atol, mxstep=maxstep) | ||||||||||||||
x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) | ||||||||||||||
r = x[:, :, 0].T.reshape((len(phis), *rshape)) | ||||||||||||||
z = x[:, :, 2].T.reshape((len(phis), *rshape)) | ||||||||||||||
return r, z | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
colorama | ||
diffrax >= 0.4.1 | ||
h5py >= 3.0.0, < 4.0 | ||
interpax >= 0.3.3 | ||
jax[cpu] >= 0.3.2, < 0.5.0 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
import numpy as np | ||
import pytest | ||
from diffrax import Dopri5 | ||
from scipy.constants import mu_0 | ||
|
||
from desc.backend import jit, jnp | ||
|
@@ -1038,24 +1039,51 @@ def test_field_line_integrate(self): | |
np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) | ||
|
||
@pytest.mark.unit | ||
def test_field_line_integrate_bounds(self): | ||
"""Test field line integration with bounding box.""" | ||
def test_field_line_integrate_long(self): | ||
"""Test field line integration for long distance along line.""" | ||
# q=4, field line should rotate 1/4 turn after 1 toroidal transit | ||
# from outboard midplane to top center | ||
field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) | ||
# test that bounds work correctly, and stop integration when trajectory | ||
# hits the bounds | ||
r0 = [10.1] | ||
r0 = [10.001] | ||
z0 = [0.0] | ||
phis = [0, 2 * np.pi] | ||
# this will hit the R bound | ||
# (there is no Z bound, and R would go to 10.0 if not bounded) | ||
r, z = field_line_integrate(r0, z0, phis, field, bounds_R=(10.05, np.inf)) | ||
np.testing.assert_allclose(r[-1], 10.05, rtol=3e-4) | ||
# this will hit the Z bound | ||
# (there is no R bound, and Z would go to 0.1 if not bounded) | ||
r, z = field_line_integrate(r0, z0, phis, field, bounds_Z=(-np.inf, 0.05)) | ||
np.testing.assert_allclose(z[-1], 0.05, atol=3e-3) | ||
phis = [0, 2 * np.pi * 25] | ||
r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we use another integrator for this test (an implicit one?)? Maybe we should specify some integrator options for the user to choose depending on the case. Or basically reference https://docs.kidger.site/diffrax/api/solvers/ode_solvers/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dopri5 and Tsit5 are both rk45 methods, just with slightly different coefficients. There's not any real reason to prefer one over the other, other than to just test that they both work. |
||
np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6) | ||
np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) | ||
|
||
@pytest.mark.unit | ||
def test_field_line_integrate_early_terminate_default(self): | ||
"""Test field line integration with default early termination criterion.""" | ||
# q=4, field line should rotate 1/4 turn after 1 toroidal transit | ||
# from outboard midplane to top center | ||
# early terminate when it crosses towards the inboard side (R=10), | ||
field1 = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) | ||
# make a SplineMagneticField only defined in a tiny region around initial point | ||
field = SplineMagneticField.from_field( | ||
field=field1, | ||
R=np.linspace(10.0, 10.005, 40), | ||
phi=np.linspace(0, 2 * np.pi, 40), | ||
Z=np.linspace(-5e-3, 5e-3, 40), | ||
extrap=True, | ||
) | ||
r0 = [10.001] | ||
z0 = [0.0] | ||
phis = [0, 2 * np.pi, 2 * np.pi * 2] | ||
|
||
r, z = field_line_integrate( | ||
r0, | ||
z0, | ||
phis, | ||
field, | ||
bounds_R=(np.min(field._R), np.max(field._R)), | ||
bounds_Z=(np.min(field._Z), np.max(field._Z)), | ||
min_step_size=1e-2, | ||
) | ||
np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) | ||
np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6) | ||
# if early terinated, the values at the un-integrated phi points are inf | ||
assert np.isnan(r[-1]) | ||
assert np.isnan(z[-1]) | ||
|
||
@pytest.mark.unit | ||
def test_Bnormal_calculation(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried a few different values for this and it didn't seem to have a major effect on speed, and accuracy is fine as long as its reasonably small. We could do something fancy using the data from Albanese 2015 but I think probably not worth it, and if the user wants that they can still give a particular value for this or even give a different step size controller.