Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
2ead6db
refactor: remove temporary field arrays and reuse CPU buffers for saving
henry2004y Feb 3, 2026
7e8522b
feat: support time-dependent field interpolation in kernel.
henry2004y Feb 3, 2026
f1d394b
revert interpolation changes
henry2004y Feb 3, 2026
9581436
Specialize the kernel in CPU version
henry2004y Feb 4, 2026
1b1ddad
Check CPU for save_end
henry2004y Feb 4, 2026
eedd0ee
Specialized CPU methods
henry2004y Feb 4, 2026
35cb33b
refactor: use svector velocity update methods
henry2004y Feb 4, 2026
6310f84
fix: compatibility with CUDA
henry2004y Feb 4, 2026
31cfc83
refactor kernel Boris internals
henry2004y Feb 5, 2026
cde6497
refactor: use FieldInterpolator instead of function capture
henry2004y Feb 5, 2026
d8092de
Cleanup imports
henry2004y Feb 5, 2026
779f954
Cleanup unused SphericalVectorFieldInterpolator
henry2004y Feb 5, 2026
734a764
Add the kernel solver to precompilation
henry2004y Feb 5, 2026
f3b14b2
fix synchronize import
henry2004y Feb 5, 2026
c7e692d
feat: support multithreading kernel Boris solver
henry2004y Feb 5, 2026
58b9de0
Use threading as the default kernel solver option based on benchmarks
henry2004y Feb 5, 2026
38b3965
Switch back to EnsembleSerial for default
henry2004y Feb 6, 2026
5b719e7
redesign the multithreading kernel boris solver
henry2004y Feb 6, 2026
56d4902
Fix the race condition bug
henry2004y Feb 6, 2026
a993367
refactor: inline Field call
henry2004y Feb 6, 2026
bf6647b
refactor: make isoutofdomain and velocity_updater type stable
henry2004y Feb 6, 2026
17496b2
fix: less restrictive SVector type for Unitful integration
henry2004y Feb 6, 2026
21c3b04
Remove redundant SVector wrapper
henry2004y Feb 6, 2026
4c9dd2f
test: add kernel boris threading test
henry2004y Feb 6, 2026
e01136a
Remove unnecessary SVector wrappers.
henry2004y Feb 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/TestParticle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ import ForwardDiff
using ChunkSplitters: index_chunks
using PrecompileTools: @setup_workload, @compile_workload
using MuladdMacro: @muladd
using KernelAbstractions: @kernel, @index, @Const, synchronize, Backend, CPU

import KernelAbstractions as KA
import Adapt
import Tensors
import Base: +, -, *, /, setindex!, getindex
import LinearAlgebra: ×
Expand Down
141 changes: 77 additions & 64 deletions src/boris.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,39 @@ end
# Default out-of-domain predicate: ignores the state `u`, the parameters `p` and the
# time `t`, and always answers `false`, so the solver loop is never stopped early by it.
@inline function ODE_DEFAULT_ISOUTOFDOMAIN(u, p, t)
    return false
end

"""
    boris_velocity_update(v, E, B, qdt_2m)

Update velocity using the Boris method and return the new velocity.
This is the core logic shared between the standard solver and the kernel solver.

# Arguments
- `v`: velocity before the update.
- `E`, `B`: electric and magnetic field values used for the update.
- `qdt_2m`: scalar factor multiplying both fields; `update_velocity` passes
  `q2m * 0.5 * dt` here.

# Returns
The updated velocity. Its container type follows from the arithmetic on the inputs
(an `SVector` when `v`, `E` and `B` are `SVector`s).
"""
@inline @muladd function boris_velocity_update(v, E, B, qdt_2m)
    # Rotation vectors built from the magnetic field.
    t_rotate = qdt_2m * B
    t_mag2 = sum(abs2, t_rotate)
    s_rotate = 2 * t_rotate / (1 + t_mag2)

    # First electric half-kick.
    v_minus = v + qdt_2m * E
    # Rotation about B in two cross-product stages.
    v_prime = v_minus + (v_minus × t_rotate)
    v_plus = v_minus + (v_prime × s_rotate)

    # Second electric half-kick.
    v_new = v_plus + qdt_2m * E

    return v_new
end

"""
    update_velocity(v, r, param, dt, t)

Evaluate the two field callables stored in `param` at position `r` and time `t`, then
return the velocity computed by `boris_velocity_update` for a step of size `dt`.

`param` is destructured positionally into five entries: the first is the factor `q2m`,
the third and fourth are the electric and magnetic field callables, and the second and
fifth are ignored here.
"""
@inline @muladd function update_velocity(v, r, param::P, dt, t) where {P}
    q2m, _, Efunc, Bfunc, _ = param
    E_here = Efunc(r, t)
    B_here = Bfunc(r, t)

    # The half-step factor q2m * dt / 2 is applied to both fields by the core update.
    return boris_velocity_update(v, E_here, B_here, q2m * 0.5 * dt)
end

"""
update_velocity!(xv, paramBoris, param, dt, t)

Expand Down Expand Up @@ -184,22 +195,22 @@ Trace particles using the Boris method with specified `prob`.
- `save_work::Bool`: save the work done by the electric field. Default is `false`.
"""
@inline function solve(
prob::TraceProblem, ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial();
prob::TraceProblem, ensemblealg::EA = EnsembleSerial();
trajectories::Int = 1, savestepinterval::Int = 1, dt::AbstractFloat,
isoutofdomain::Function = ODE_DEFAULT_ISOUTOFDOMAIN, n::Int = 1,
isoutofdomain::F = ODE_DEFAULT_ISOUTOFDOMAIN, n::Int = 1,
save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true,
save_fields::Bool = false, save_work::Bool = false
)
) where {EA <: BasicEnsembleAlgorithm, F}
return _solve(
ensemblealg, prob, trajectories, dt, savestepinterval, isoutofdomain, n,
save_start, save_end, save_everystep, Val(save_fields), Val(save_work)
)
end

function _dispatch_boris!(
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n,
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F, n,
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
) where {SaveFields, SaveWork}
) where {SaveFields, SaveWork, F}
return if n == 1
_boris!(
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
Expand All @@ -213,10 +224,10 @@ function _dispatch_boris!(
end
end

function _solve(
::EnsembleSerial, prob, trajectories, dt, savestepinterval, isoutofdomain, n,
@inline function _solve(
::EnsembleSerial, prob::TraceProblem, trajectories, dt, savestepinterval, isoutofdomain::F, n,
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
) where {SaveFields, SaveWork}
) where {SaveFields, SaveWork, F}
sols, nt,
nout = _prepare(
prob, trajectories, dt, savestepinterval,
Expand All @@ -231,10 +242,10 @@ function _solve(
return sols
end

function _solve(
::EnsembleThreads, prob, trajectories, dt, savestepinterval, isoutofdomain, n,
@inline function _solve(
::EnsembleThreads, prob::TraceProblem, trajectories, dt, savestepinterval, isoutofdomain::F, n,
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
) where {SaveFields, SaveWork}
) where {SaveFields, SaveWork, F}
sols, nt,
nout = _prepare(
prob, trajectories, dt, savestepinterval,
Expand Down Expand Up @@ -351,13 +362,41 @@ end
"""
Apply Boris method for particles with index in `irange`.
"""
@muladd function _generic_boris!(
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
# Time-stepping loop for a single trajectory.
#
# Starting from position `r` and velocity `v`, performs up to `nt` steps of size `dt`.
# Each step calls `velocity_updater` for the new velocity and then advances `r` by
# `v * dt`. When `save_everystep` is set, a state is written into `traj`/`tsave` at
# every `savestepinterval`-th completed step, as long as `traj` still has room.
# The loop stops early as soon as `isoutofdomain` returns `true`.
#
# Returns `(it, iout, r, v)`: the step counter (`nt + 1` after a full run, otherwise the
# step at which the domain check fired), the number of the last output slot used, and
# the final position and velocity.
#
# NOTE(review): the step-midpoint time and the time passed to `isoutofdomain` do not add
# `tspan[1]`, while the saved time does — confirm this is intended for `tspan[1] != 0`.
@inline function _boris_loop!(
        traj, tsave, iout, r, v, p, dt, nt, tspan,
        savestepinterval, save_everystep, isoutofdomain::F1, velocity_updater::F2,
        ::Val{SaveFields}, ::Val{SaveWork}
    ) where {F1, F2, SaveFields, SaveWork}
    it = 1
    while it <= nt
        # Keep the pre-update velocity: it is needed to reconstruct the saved velocity.
        v_before = v
        t_mid = (it - 0.5) * dt
        v = velocity_updater(v_before, r, p, dt, t_mid)

        n_done = it - 1
        if save_everystep && n_done > 0 && n_done % savestepinterval == 0
            iout += 1
            if iout <= length(traj)
                t_now = tspan[1] + n_done * dt
                # Advance the pre-update velocity by half a step to match position `r`.
                v_now = velocity_updater(v_before, r, p, 0.5 * dt, t_now)
                traj[iout] = _prepare_saved_data(
                    vcat(r, v_now), p, t_now, Val(SaveFields), Val(SaveWork)
                )
                tsave[iout] = t_now
            end
        end

        r += v * dt
        if isoutofdomain(vcat(r, v), p, it * dt)
            break
        end
        it += 1
    end
    return it, iout, r, v
end

@inline @muladd function _generic_boris!(
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F1,
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork},
velocity_updater, alg_name
) where {SaveFields, SaveWork}
velocity_updater::F2, alg_name
) where {SaveFields, SaveWork, F1, F2}
(; tspan, p, u0) = prob
q2m, m, Efunc, Bfunc, _ = p
T = eltype(u0)

vars_dim = 6
Expand All @@ -368,57 +407,31 @@ Apply Boris method for particles with index in `irange`.
vars_dim += 4
end

@fastmath @inbounds for i in irange
@inbounds for i in irange
traj = Vector{SVector{vars_dim, T}}(undef, nout)
tsave = Vector{typeof(tspan[1] + dt)}(undef, nout)

# set initial conditions for each trajectory i
iout = 0
new_prob = prob.prob_func(prob, i, false)
# Load independent r and v SVector from u0
u0_i = SVector{6, T}(new_prob.u0)
r = u0_i[SVector(1, 2, 3)]
v = u0_i[SVector(4, 5, 6)]

if save_start
iout += 1
traj[iout] = _prepare_saved_data(
u0_i, p, tspan[1], Val(SaveFields), Val(SaveWork)
)
traj[iout] = _prepare_saved_data(u0_i, p, tspan[1], Val(SaveFields), Val(SaveWork))
tsave[iout] = tspan[1]
end

# push velocity back in time by 1/2 dt
v = velocity_updater(v, r, p, -0.5 * dt, tspan[1])

it = 1
while it <= nt
v_prev = v
t = (it - 0.5) * dt
v = velocity_updater(v, r, p, dt, t)

if save_everystep && (it - 1) > 0 && (it - 1) % savestepinterval == 0
iout += 1
if iout <= nout
t_current = tspan[1] + (it - 1) * dt
# Approximate v_n from v_{n-1/2} (v_prev)
v_save = velocity_updater(
v_prev, r, p, 0.5 * dt,
t_current
)

data = vcat(r, v_save)
traj[iout] = _prepare_saved_data(
data, p, t_current, Val(SaveFields), Val(SaveWork)
)
tsave[iout] = t_current
end
end

r += v * dt
isoutofdomain(vcat(r, v), p, it * dt) && break
it += 1
end
it, iout, r, v = _boris_loop!(
traj, tsave, iout, r, v, p, dt, nt, tspan,
savestepinterval, save_everystep, isoutofdomain, velocity_updater,
Val(SaveFields), Val(SaveWork)
)

final_step = min(it, nt)
should_save_final = false
Expand Down Expand Up @@ -464,10 +477,10 @@ end
"""
Apply Boris method for particles with index in `irange`.
"""
@muladd function _boris!(
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
@inline @muladd function _boris!(
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F,
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
) where {SaveFields, SaveWork}
) where {SaveFields, SaveWork, F}

_generic_boris!(
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain,
Expand Down Expand Up @@ -541,10 +554,10 @@ Reference: [Zenitani & Kato 2025](https://arxiv.org/abs/2505.02270)
return v_new
end

@muladd function _multistep_boris!(
sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n_steps::Int,
@inline @muladd function _multistep_boris!(
sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F, n_steps::Int,
save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}
) where {SaveFields, SaveWork}
) where {SaveFields, SaveWork, F}

velocity_updater = (v, r, p, dt, t) ->
update_velocity_multistep(v, r, p, dt, t, n_steps)
Expand Down
Loading
Loading