Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change field line integration to use diffrax #610

Merged
merged 32 commits into from
Sep 18, 2024

Conversation

dpanici
Copy link
Collaborator

@dpanici dpanici commented Aug 3, 2023

Updated comment by @f0uriest

  • field_line_integrate now uses diffrax instead of jax.experimental.ode.odeint which has been soft deprecated.
  • Allows better control of bounding box (hard stop of field lines rather than ad-hoc expoential decay)
  • In practice seems to be a bit faster. test_plot_poincare which calls field_line_integrate under the hood went from ~65s to ~55s with these changes.

Resolves #609

@dpanici dpanici added the test_jax Run tests against different versions of JAX label Aug 3, 2023
requirements.txt Outdated Show resolved Hide resolved
desc/magnetic_fields.py Outdated Show resolved Hide resolved
desc/magnetic_fields.py Outdated Show resolved Hide resolved
@dpanici
Copy link
Collaborator Author

dpanici commented Aug 10, 2023

The RuntimeWarning for numpy header size being different is still an issue, I think because the warning is coming from C level code (@f0uriest @unalmis any ideas on how to ignore those? seems like a filterwarnings problem), right now I have in a blanket ignore RuntimeWarning but that is not really what we want, I opened an issue to see if pytest knows how to resolve it pytest-dev/pytest#11304

@codecov
Copy link

codecov bot commented Aug 11, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 95.45%. Comparing base (a4220a4) to head (acf1c98).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #610      +/-   ##
==========================================
+ Coverage   92.78%   95.45%   +2.66%     
==========================================
  Files          95       95              
  Lines       23419    23429      +10     
==========================================
+ Hits        21729    22363     +634     
+ Misses       1690     1066     -624     
Files with missing lines Coverage Δ
desc/magnetic_fields/_core.py 96.61% <100.00%> (+0.04%) ⬆️

... and 26 files with indirect coverage changes

@f0uriest
Copy link
Member

f0uriest commented Oct 6, 2023

it looks like there are older versions of diffrax that work with older versions of jax. They might not have the discrete terminating event stuff but might still be useful for being able to select different integration schemes etc.

@dpanici
Copy link
Collaborator Author

dpanici commented Oct 18, 2023

exp. decrease RHS outside the bounding box (exp(-r)*B)
and try older version

Copy link
Contributor

github-actions bot commented Jan 21, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     -3.98 +/- 5.49     | -2.57e-02 +/- 3.55e-02 |  6.21e-01 +/- 3.1e-02  |  6.46e-01 +/- 1.7e-02  |
 test_build_transform_fft_highres        |     -5.12 +/- 3.82     | -5.40e-02 +/- 4.03e-02 |  1.00e+00 +/- 3.8e-02  |  1.06e+00 +/- 1.4e-02  |
 test_equilibrium_init_lowres            |     -6.77 +/- 2.85     | -2.81e-01 +/- 1.19e-01 |  3.87e+00 +/- 8.7e-02  |  4.15e+00 +/- 8.0e-02  |
 test_objective_compile_atf              |     -0.44 +/- 3.12     | -3.49e-02 +/- 2.48e-01 |  7.93e+00 +/- 2.1e-01  |  7.96e+00 +/- 1.3e-01  |
 test_objective_compute_atf              |     -0.53 +/- 2.12     | -5.40e-05 +/- 2.17e-04 |  1.02e-02 +/- 1.1e-04  |  1.02e-02 +/- 1.9e-04  |
 test_objective_jac_atf                  |     +0.26 +/- 1.38     | +4.96e-03 +/- 2.65e-02 |  1.92e+00 +/- 1.7e-02  |  1.92e+00 +/- 2.1e-02  |
 test_perturb_1                          |     +2.56 +/- 5.52     | +3.16e-01 +/- 6.80e-01 |  1.26e+01 +/- 5.8e-01  |  1.23e+01 +/- 3.5e-01  |
 test_proximal_jac_atf                   |     +0.34 +/- 1.47     | +2.78e-02 +/- 1.19e-01 |  8.12e+00 +/- 6.9e-02  |  8.09e+00 +/- 9.7e-02  |
 test_proximal_freeb_compute             |     +2.33 +/- 1.08     | +4.26e-03 +/- 1.98e-03 |  1.87e-01 +/- 1.5e-03  |  1.83e-01 +/- 1.2e-03  |
 test_build_transform_fft_lowres         |     +0.30 +/- 6.13     | +1.60e-03 +/- 3.26e-02 |  5.33e-01 +/- 2.4e-02  |  5.31e-01 +/- 2.1e-02  |
 test_equilibrium_init_medres            |     +1.40 +/- 5.57     | +5.74e-02 +/- 2.29e-01 |  4.17e+00 +/- 2.3e-01  |  4.11e+00 +/- 3.5e-02  |
 test_equilibrium_init_highres           |     +1.59 +/- 2.35     | +8.64e-02 +/- 1.28e-01 |  5.53e+00 +/- 1.2e-01  |  5.45e+00 +/- 4.5e-02  |
 test_objective_compile_dshape_current   |     +0.40 +/- 1.18     | +1.54e-02 +/- 4.49e-02 |  3.82e+00 +/- 8.8e-03  |  3.81e+00 +/- 4.4e-02  |
 test_objective_compute_dshape_current   |     +0.63 +/- 1.60     | +2.19e-05 +/- 5.52e-05 |  3.48e-03 +/- 3.9e-05  |  3.46e-03 +/- 3.9e-05  |
 test_objective_jac_dshape_current       |     -1.70 +/- 5.33     | -6.86e-04 +/- 2.14e-03 |  3.96e-02 +/- 1.4e-03  |  4.03e-02 +/- 1.6e-03  |
 test_perturb_2                          |     +0.48 +/- 1.96     | +8.30e-02 +/- 3.38e-01 |  1.73e+01 +/- 1.9e-01  |  1.72e+01 +/- 2.8e-01  |
 test_proximal_freeb_jac                 |     -0.42 +/- 0.77     | -3.18e-02 +/- 5.76e-02 |  7.48e+00 +/- 3.2e-02  |  7.51e+00 +/- 4.8e-02  |
 test_solve_fixed_iter                   |     -0.31 +/- 61.51    | -1.52e-02 +/- 3.05e+00 |  4.95e+00 +/- 2.2e+00  |  4.96e+00 +/- 2.1e+00  |

desc/magnetic_fields.py Outdated Show resolved Hide resolved
desc/magnetic_fields.py Outdated Show resolved Hide resolved
@f0uriest f0uriest requested review from a team, rahulgaur104, f0uriest, ddudt, kianorr, sinaatalay, unalmis and YigitElma and removed request for a team August 27, 2024 17:13
@f0uriest f0uriest marked this pull request as ready for review August 27, 2024 17:13
ddudt
ddudt previously approved these changes Sep 9, 2024
desc/magnetic_fields/_core.py Show resolved Hide resolved
@f0uriest
Copy link
Member

@dpanici I know you can't formally approve this but lmk if you have any comments

@f0uriest f0uriest requested a review from ddudt September 10, 2024 01:57
r, z = field_line_integrate(r0, z0, phis, field, bounds_Z=(-np.inf, 0.05))
np.testing.assert_allclose(z[-1], 0.05, atol=3e-3)
phis = [0, 2 * np.pi * 25]
r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we use another integrator for this test (an implicit one?)? Maybe we should specify some integrator options for the user to choose depending on the case. Or basically reference https://docs.kidger.site/diffrax/api/solvers/ode_solvers/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dopri5 and Tsit5 are both rk45 methods, just with slightly different coefficients. There's not any real reason to prefer one over the other, other than to just test that they both work.

)
return jnp.array(
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)]
).squeeze()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
).squeeze()
# step along the field line
# Torodial component of the field is used to normalize the step size
return jnp.array(
jnp.sign(bp) * [r * br / bp, 1, r * bz / bp]
).squeeze()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that important if you don't change

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not that we're normalizing the step or anything, its just because we're using the toroidal angle as our "time" coordinate so that modifies the ODE slightly.

desc/magnetic_fields/_core.py Show resolved Hide resolved
@YigitElma YigitElma self-requested a review September 10, 2024 11:26
@f0uriest f0uriest requested a review from ddudt September 16, 2024 22:46
@ddudt ddudt merged commit 54becdb into master Sep 18, 2024
24 checks passed
@ddudt ddudt deleted the dp/field-line-integrate-diffrax branch September 18, 2024 19:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
hackathon Stuff to work on during hackathon test_jax Run tests against different versions of JAX
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use Diffrax for field line integration
4 participants