Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change field line integration to use diffrax #610

Merged
merged 32 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a15ad3e
change field line integration to use diffrax
dpanici Aug 3, 2023
6896171
limit diffrax version as 0.4.1 requires jax > 0.4.13
dpanici Aug 3, 2023
c3fde2c
add to docstring
dpanici Aug 3, 2023
6ca2471
increase JAX version, restrict ml_dtypes to avoid DeprecationWarning …
dpanici Aug 10, 2023
0bc5340
add NaN check (not working currently) and ignore RuntimeWarnings (fro…
dpanici Aug 10, 2023
8b70958
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Aug 10, 2023
2900dfa
change ignore to only catch the bening ndarray size change warning (i…
dpanici Aug 11, 2023
ab4b1d4
add missing pytest markers
dpanici Aug 11, 2023
98512da
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Aug 14, 2023
140633c
add min stepsize (NaN event works now but compile time takes ages on …
dpanici Aug 14, 2023
c6118c1
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Aug 15, 2023
8107318
fix test by updating diffrax version
dpanici Aug 15, 2023
ee285b4
change default terminating event to end integration if exit domain of…
dpanici Aug 16, 2023
1705508
remove 3.8 testing as is not compatible with diffrax
dpanici Aug 16, 2023
b9cee23
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Jan 21, 2024
9709281
fix params arg
dpanici Jan 21, 2024
60cde73
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Jan 26, 2024
062fc63
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Jan 26, 2024
e81e557
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Feb 28, 2024
5a38899
Merge branch 'master' into dp/field-line-integrate-diffrax
f0uriest Aug 27, 2024
ee6fa85
Simplify API a bit, get tests working
f0uriest Aug 27, 2024
59f9175
Return NaN for points that leave box
f0uriest Aug 27, 2024
2dab0a3
Ignore diffrax deprecation warning for now
f0uriest Aug 27, 2024
4871332
Merge branch 'master' into dp/field-line-integrate-diffrax
f0uriest Aug 27, 2024
534576a
Merge branch 'master' into dp/field-line-integrate-diffrax
f0uriest Aug 27, 2024
f286045
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Aug 28, 2024
b9eb9ec
Merge branch 'master' into dp/field-line-integrate-diffrax
dpanici Aug 28, 2024
7fe6771
Merge branch 'master' into dp/field-line-integrate-diffrax
f0uriest Aug 29, 2024
3c2a82d
Merge branch 'master' into dp/field-line-integrate-diffrax
f0uriest Sep 10, 2024
129e08d
Merge branch 'master' into dp/field-line-integrate-diffrax
YigitElma Sep 10, 2024
b2adabe
Merge branch 'master' into dp/field-line-integrate-diffrax
f0uriest Sep 12, 2024
acf1c98
Merge branch 'master' into dp/field-line-integrate-diffrax
f0uriest Sep 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 73 additions & 63 deletions desc/magnetic_fields/_core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
"""Classes for magnetic fields."""

import warnings
from abc import ABC, abstractmethod
from collections.abc import MutableSequence

import numpy as np
import scipy.linalg
from diffrax import (
DiscreteTerminatingEvent,
ODETerm,
PIDController,
SaveAt,
Tsit5,
diffeqsolve,
)
from interpax import approx_df, interp1d, interp2d, interp3d
from netCDF4 import Dataset, chartostring, stringtochar

from desc.backend import fori_loop, jit, jnp, odeint, sign
from desc.backend import fori_loop, jit, jnp, sign
from desc.basis import (
ChebyshevDoubleFourierBasis,
ChebyshevPolynomial,
Expand Down Expand Up @@ -2175,11 +2184,13 @@
rtol=1e-8,
atol=1e-8,
maxstep=1000,
min_step_size=1e-8,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried a few different values for this and it didn't seem to have a major effect on speed, and accuracy is fine as long as its reasonably small. We could do something fancy using the data from Albanese 2015 but I think probably not worth it, and if the user wants that they can still give a particular value for this or even give a different step size controller.

solver=Tsit5(),
bounds_R=(0, np.inf),
bounds_Z=(-np.inf, np.inf),
decay_accel=1e6,
**kwargs,
):
"""Trace field lines by integration.
"""Trace field lines by integration, using diffrax package.

Parameters
----------
Expand All @@ -2191,34 +2202,29 @@
and the negative toroidal angle for negative Bphi
field : MagneticField
source of magnetic field to integrate
params: dict
params: dict, optional
parameters passed to field
source_grid : Grid, optional
Collocation points used to discretize source field.
rtol, atol : float
relative and absolute tolerances for ode integration
maxstep : int
maximum number of steps between different phis
min_step_size: float
minimum step size (in phi) that the integration can take. default is 1e-8
solver: diffrax.Solver
diffrax Solver object to use in integration,
defaults to Tsit5(), a RK45 explicit solver
bounds_R : tuple of (float,float), optional
R bounds for field line integration bounding box.
If supplied, the RHS of the field line equations will be
multiplied by exp(-r) where r is the distance to the bounding box,
this is meant to prevent the field lines which escape to infinity from
slowing the integration down by being traced to infinity.
defaults to (0,np.inf)
R bounds for field line integration bounding box. Trajectories that leave this
box will be stopped, and NaN returned for points outside the box.
Defaults to (0,np.inf)
bounds_Z : tuple of (float,float), optional
Z bounds for field line integration bounding box.
If supplied, the RHS of the field line equations will be
multiplied by exp(-r) where r is the distance to the bounding box,
this is meant to prevent the field lines which escape to infinity from
slowing the integration down by being traced to infinity.
Z bounds for field line integration bounding box. Trajectories that leave this
box will be stopped, and NaN returned for points outside the box.
Defaults to (-np.inf,np.inf)
decay_accel : float, optional
An extra factor to the exponential that decays the RHS, i.e.
the RHS is multiplied by exp(-r * decay_accel), this is to
accelerate the decay of the RHS and stop the integration sooner
after exiting the bounds. Defaults to 1e6

kwargs: dict
keyword arguments to be passed into the ``diffrax.diffeqsolve``

Returns
-------
Expand All @@ -2228,60 +2234,64 @@
"""
r0, z0, phis = map(jnp.asarray, (r0, z0, phis))
assert r0.shape == z0.shape, "r0 and z0 must have the same shape"
assert decay_accel > 0, "decay_accel must be positive"
rshape = r0.shape
r0 = r0.flatten()
z0 = z0.flatten()
x0 = jnp.array([r0, phis[0] * jnp.ones_like(r0), z0]).T

@jit
def odefun(rpz, s):
def odefun(s, rpz, args):

Check warning on line 2243 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2243

Added line #L2243 was not covered by tests
rpz = rpz.reshape((3, -1)).T
r = rpz[:, 0]
z = rpz[:, 2]
# if bounds are given, will decay the magnetic field line eqn
# RHS if the trajectory is outside of bounds to avoid
# integrating the field line to infinity, which is costly
# and not useful in most cases
decay_factor = jnp.where(
jnp.array(
[
jnp.less(r, bounds_R[0]),
jnp.greater(r, bounds_R[1]),
jnp.less(z, bounds_Z[0]),
jnp.greater(z, bounds_Z[1]),
]
),
jnp.array(
[
# we multiply by decay_accel to accelerate the decay so that the
# integration is stopped soon after the bounds are exited.
jnp.exp(-(decay_accel * (r - bounds_R[0]) ** 2)),
jnp.exp(-(decay_accel * (r - bounds_R[1]) ** 2)),
jnp.exp(-(decay_accel * (z - bounds_Z[0]) ** 2)),
jnp.exp(-(decay_accel * (z - bounds_Z[1]) ** 2)),
]
),
1.0,
)
# multiply all together, the conditions that are not violated
# are just one while the violated ones are continuous decaying exponentials
decay_factor = jnp.prod(decay_factor, axis=0)

br, bp, bz = field.compute_magnetic_field(
rpz, params, basis="rpz", source_grid=source_grid
).T
return (
decay_factor
* jnp.array(
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)]
).squeeze()
)
return jnp.array(

Check warning on line 2249 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2249

Added line #L2249 was not covered by tests
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)]
).squeeze()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
).squeeze()
# step along the field line
# Torodial component of the field is used to normalize the step size
return jnp.array(
jnp.sign(bp) * [r * br / bp, 1, r * bz / bp]
).squeeze()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that important if you don't change

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not that we're normalizing the step or anything, its just because we're using the toroidal angle as our "time" coordinate so that modifies the ODE slightly.


# diffrax parameters

def default_terminating_event_fxn(state, **kwargs):
R_out = jnp.any(jnp.array([state.y[0] < bounds_R[0], state.y[0] > bounds_R[1]]))
Z_out = jnp.any(jnp.array([state.y[2] < bounds_Z[0], state.y[2] > bounds_Z[1]]))
return jnp.any(jnp.array([R_out, Z_out]))

Check warning on line 2258 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2255-L2258

Added lines #L2255 - L2258 were not covered by tests

kwargs.setdefault(

Check warning on line 2260 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2260

Added line #L2260 was not covered by tests
"stepsize_controller", PIDController(rtol=rtol, atol=atol, dtmin=min_step_size)
ddudt marked this conversation as resolved.
Show resolved Hide resolved
YigitElma marked this conversation as resolved.
Show resolved Hide resolved
)
kwargs.setdefault(

Check warning on line 2263 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2263

Added line #L2263 was not covered by tests
"discrete_terminating_event",
DiscreteTerminatingEvent(default_terminating_event_fxn),
)

term = ODETerm(odefun)
saveat = SaveAt(ts=phis)

Check warning on line 2269 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2268-L2269

Added lines #L2268 - L2269 were not covered by tests

intfun = lambda x: diffeqsolve(

Check warning on line 2271 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2271

Added line #L2271 was not covered by tests
term,
solver,
y0=x,
t0=phis[0],
t1=phis[-1],
saveat=saveat,
max_steps=maxstep * len(phis),
dt0=min_step_size,
**kwargs,
).ys

# suppress warnings till its fixed upstream:
# https://github.com/patrick-kidger/diffrax/issues/445
# also ignore deprecation warning for now until we actually need to deal with it
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="unhashable type")
warnings.filterwarnings("ignore", message="`diffrax.*discrete_terminating")
x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0)

Check warning on line 2289 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2286-L2289

Added lines #L2286 - L2289 were not covered by tests

x = jnp.where(jnp.isinf(x), jnp.nan, x)
r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape))
z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape))

Check warning on line 2293 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2291-L2293

Added lines #L2291 - L2293 were not covered by tests

intfun = lambda x: odeint(odefun, x, phis, rtol=rtol, atol=atol, mxstep=maxstep)
x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0)
r = x[:, :, 0].T.reshape((len(phis), *rshape))
z = x[:, :, 2].T.reshape((len(phis), *rshape))
return r, z


Expand Down
1 change: 1 addition & 0 deletions devtools/dev-requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- pip:
# Conda only parses a single list of pip requirements.
# If two pip lists are given, all but the last list is skipped.
- diffrax >= 0.4.1
- interpax >= 0.3.3
- jax[cpu] >= 0.3.2, < 0.5.0
- nvgpu
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
colorama
diffrax >= 0.4.1
h5py >= 3.0.0, < 4.0
interpax >= 0.3.3
jax[cpu] >= 0.3.2, < 0.5.0
Expand Down
1 change: 1 addition & 0 deletions requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: desc-env
dependencies:
# standard install
- colorama
- diffrax >= 0.4.1
- h5py >= 3.0.0, < 4.0
- matplotlib >= 3.5.0, < 4.0.0
- mpmath >= 1.0.0, < 2.0
Expand Down
56 changes: 42 additions & 14 deletions tests/test_magnetic_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
from diffrax import Dopri5
from scipy.constants import mu_0

from desc.backend import jit, jnp
Expand Down Expand Up @@ -1038,24 +1039,51 @@ def test_field_line_integrate(self):
np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6)

@pytest.mark.unit
def test_field_line_integrate_bounds(self):
"""Test field line integration with bounding box."""
def test_field_line_integrate_long(self):
"""Test field line integration for long distance along line."""
# q=4, field line should rotate 1/4 turn after 1 toroidal transit
# from outboard midplane to top center
field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25)
# test that bounds work correctly, and stop integration when trajectory
# hits the bounds
r0 = [10.1]
r0 = [10.001]
z0 = [0.0]
phis = [0, 2 * np.pi]
# this will hit the R bound
# (there is no Z bound, and R would go to 10.0 if not bounded)
r, z = field_line_integrate(r0, z0, phis, field, bounds_R=(10.05, np.inf))
np.testing.assert_allclose(r[-1], 10.05, rtol=3e-4)
# this will hit the Z bound
# (there is no R bound, and Z would go to 0.1 if not bounded)
r, z = field_line_integrate(r0, z0, phis, field, bounds_Z=(-np.inf, 0.05))
np.testing.assert_allclose(z[-1], 0.05, atol=3e-3)
phis = [0, 2 * np.pi * 25]
r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we use another integrator for this test (an implicit one?)? Maybe we should specify some integrator options for the user to choose depending on the case. Or basically reference https://docs.kidger.site/diffrax/api/solvers/ode_solvers/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dopri5 and Tsit5 are both rk45 methods, just with slightly different coefficients. There's not any real reason to prefer one over the other, other than to just test that they both work.

np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6)

@pytest.mark.unit
def test_field_line_integrate_early_terminate_default(self):
"""Test field line integration with default early termination criterion."""
# q=4, field line should rotate 1/4 turn after 1 toroidal transit
# from outboard midplane to top center
# early terminate when it crosses towards the inboard side (R=10),
field1 = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25)
# make a SplineMagneticField only defined in a tiny region around initial point
field = SplineMagneticField.from_field(
field=field1,
R=np.linspace(10.0, 10.005, 40),
phi=np.linspace(0, 2 * np.pi, 40),
Z=np.linspace(-5e-3, 5e-3, 40),
extrap=True,
)
r0 = [10.001]
z0 = [0.0]
phis = [0, 2 * np.pi, 2 * np.pi * 2]

r, z = field_line_integrate(
r0,
z0,
phis,
field,
bounds_R=(np.min(field._R), np.max(field._R)),
bounds_Z=(np.min(field._Z), np.max(field._Z)),
min_step_size=1e-2,
)
np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6)
# if early terinated, the values at the un-integrated phi points are inf
assert np.isnan(r[-1])
assert np.isnan(z[-1])

@pytest.mark.unit
def test_Bnormal_calculation(self):
Expand Down
Loading