From 2ead6db474789a1bca6cf2547be6b235dc73d6d4 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Tue, 3 Feb 2026 12:42:07 -0500 Subject: [PATCH 01/25] refactor: remove temporary field arrays and reuse CPU buffers for saving --- src/boris_kernel.jl | 207 ++++++++------------------------------------ 1 file changed, 36 insertions(+), 171 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index d19a0c5e0..58b586c53 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -31,9 +31,13 @@ function adapt_field_to_gpu(field::Field, backend::KA.Backend) # Adapt interpolation object to GPU adapted_func = Adapt.adapt(backend, itp) - return Field{is_time_dependent(field), typeof(adapted_func)}(adapted_func) + + # Re-wrap in FieldInterpolator to maintain calling convention f(r) + adapted_wrapper = FieldInterpolator(adapted_func) + + return Field{is_time_dependent(field), typeof(adapted_wrapper)}(adapted_wrapper) end - # Analytic fields don't need adaptation + # Analytic fields don't need adaptation (assuming they are GPU compatible functions) return field end @@ -73,8 +77,7 @@ end @kernel function boris_push_kernel!( @Const(xv_in), xv_out, @Const(q2m), @Const(dt), - @Const(Ex_arr), @Const(Ey_arr), @Const(Ez_arr), - @Const(Bx_arr), @Const(By_arr), @Const(Bz_arr) + Efunc, Bfunc, @Const(t) ) i = @index(Global) @@ -85,12 +88,17 @@ end vy = xv_in[5, i] vz = xv_in[6, i] - Ex = Ex_arr[i] - Ey = Ey_arr[i] - Ez = Ez_arr[i] - Bx = Bx_arr[i] - By = By_arr[i] - Bz = Bz_arr[i] + # Evaluate fields directly + r_vec = SVector{3}(x, y, z) + E_val = Efunc(r_vec, t) + B_val = Bfunc(r_vec, t) + + Ex = E_val[1] + Ey = E_val[2] + Ez = E_val[3] + Bx = B_val[1] + By = B_val[2] + Bz = B_val[3] qdt_2m = q2m * 0.5 * dt @@ -108,17 +116,22 @@ end @kernel function velocity_back_kernel!( xv_out, @Const(xv_in), @Const(q2m_val), @Const(dt_val), - @Const(Ex), @Const(Ey), @Const(Ez), @Const(Bx), @Const(By), @Const(Bz) + Efunc, Bfunc, @Const(t) ) i = @index(Global) vx = xv_in[4, i] vy = xv_in[5, i] vz = xv_in[6, i] + # Evaluate fields at current position + r_vec = SVector{3}(xv_in[1, i], xv_in[2, i], xv_in[3, i]) + E_val = Efunc(r_vec, t) + B_val = Bfunc(r_vec, t) + qdt_2m = q2m_val * 0.5 * dt_val vx_new, vy_new, vz_new = boris_velocity_update( - vx, vy, vz, Ex[i], Ey[i], Ez[i], Bx[i], By[i], Bz[i], qdt_2m + vx, vy, vz, E_val[1], E_val[2], E_val[3], B_val[1], B_val[2], B_val[3], qdt_2m ) xv_out[4, i] = vx_new @@ -127,98 +140,6 @@ end end -""" -GPU kernel to evaluate interpolated fields directly on GPU. -This eliminates CPU-GPU data transfers for numerical fields. -""" -@kernel function evaluate_interp_fields_kernel!( - Ex_arr, Ey_arr, Ez_arr, Bx_arr, By_arr, Bz_arr, - @Const(xv), E_interp, B_interp, t - ) - i = @index(Global) - - x = xv[1, i] - y = xv[2, i] - z = xv[3, i] - - # Evaluate interpolation directly on GPU - # For 3D interpolation, call with (x, y, z) - E_val = E_interp(x, y, z) - B_val = B_interp(x, y, z) - - Ex_arr[i] = E_val[1] - Ey_arr[i] = E_val[2] - Ez_arr[i] = E_val[3] - Bx_arr[i] = B_val[1] - By_arr[i] = B_val[2] - Bz_arr[i] = B_val[3] -end - -function evaluate_fields_on_particles!( - Ex_cpu, Ey_cpu, Ez_cpu, Bx_cpu, By_cpu, Bz_cpu, xv_cpu, Efunc, Bfunc, t - ) - n_particles = size(xv_cpu, 2) - - for i in 1:n_particles - r = SVector(xv_cpu[1, i], xv_cpu[2, i], xv_cpu[3, i]) - E_val = Efunc(r, t) - B_val = Bfunc(r, t) - - Ex_cpu[i] = E_val[1] - Ey_cpu[i] = E_val[2] - Ez_cpu[i] = E_val[3] - Bx_cpu[i] = B_val[1] - By_cpu[i] = B_val[2] - Bz_cpu[i] = B_val[3] - end - - return -end - -function gpu_field_evaluation!( - backend, use_gpu_interp, eval_kernel!, - Ex_arr, Ey_arr, Ez_arr, Bx_arr, By_arr, Bz_arr, - xv_current, - Efunc, Bfunc, # Adapted fields - Ex_cpu, Ey_cpu, Ez_cpu, Bx_cpu, By_cpu, Bz_cpu, # CPU buffers - xv_cpu, # CPU buffer for positions - t, n_particles - ) - return if use_gpu_interp - # Use GPU kernel for interpolation - eval_kernel!( - Ex_arr, Ey_arr, Ez_arr, Bx_arr, By_arr, Bz_arr, - xv_current, Efunc.field_function, Bfunc.field_function, t; - ndrange = n_particles - ) - KA.synchronize(backend) - else - # CPU evaluation for analytic fields - # If xv_current is not on CPU, copy to xv_cpu buffer - if xv_current isa Array - xv_target = xv_current - else - copyto!(xv_cpu, xv_current) - xv_target = xv_cpu - end - - evaluate_fields_on_particles!( - Ex_cpu, Ey_cpu, Ez_cpu, Bx_cpu, By_cpu, Bz_cpu, - xv_target, Efunc, Bfunc, t - ) - - # Copy back if needed (if Ex_arr is not aliased to Ex_cpu) - if Ex_arr !== Ex_cpu - copyto!(Ex_arr, Ex_cpu) - copyto!(Ey_arr, Ey_cpu) - copyto!(Ez_arr, Ez_cpu) - copyto!(Bx_arr, Bx_cpu) - copyto!(By_arr, By_cpu) - copyto!(Bz_arr, Bz_cpu) - end - end -end - function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) T = eltype(xv) # Extract position and velocity (v^{n-1/2}) @@ -256,10 +177,6 @@ function solve( q2m, m, Efunc, Bfunc, _ = p T = eltype(u0) - # Check if fields are interpolation objects and adapt to GPU - use_gpu_interp = is_interpolation_field(Efunc.field_function) || - is_interpolation_field(Bfunc.field_function) - # Adapt interpolation fields to GPU memory Efunc_gpu = adapt_field_to_gpu(Efunc, backend) Bfunc_gpu = adapt_field_to_gpu(Bfunc, backend) @@ -289,13 +206,7 @@ function solve( xv_current = KA.zeros(backend, T, 6, n_particles) xv_next = KA.zeros(backend, T, 6, n_particles) - Ex_arr = KA.zeros(backend, T, n_particles) - Ey_arr = KA.zeros(backend, T, n_particles) - Ez_arr = KA.zeros(backend, T, n_particles) - Bx_arr = KA.zeros(backend, T, n_particles) - By_arr = KA.zeros(backend, T, n_particles) - Bz_arr = KA.zeros(backend, T, n_particles) - + # xv_init on CPU xv_init = zeros(T, 6, n_particles) for i in 1:n_particles new_prob = prob.prob_func(prob, i, false) @@ -304,44 +215,8 @@ function solve( end copyto!(xv_current, xv_init) - args_cpu = (T, n_particles) - if Ex_arr isa Array - Ex_cpu = Ex_arr - Ey_cpu = Ey_arr - Ez_cpu = Ez_arr - Bx_cpu = Bx_arr - By_cpu = By_arr - Bz_cpu = Bz_arr - else - Ex_cpu = zeros(args_cpu...) - Ey_cpu = zeros(args_cpu...) - Ez_cpu = zeros(args_cpu...) - Bx_cpu = zeros(args_cpu...) - By_cpu = zeros(args_cpu...) - Bz_cpu = zeros(args_cpu...) - end - - # Buffer for particle positions on CPU (if needed) - xv_cpu_buffer = if xv_current isa Array - xv_current # Placeholder, will use actual xv_current dynamically - else - zeros(T, 6, n_particles) - end - - # Initial field evaluation - # Determine integration strategy and prepare kernel if needed - eval_kernel! = use_gpu_interp ? evaluate_interp_fields_kernel!(backend, workgroup_size) : nothing - - # Initial field evaluation - gpu_field_evaluation!( - backend, use_gpu_interp, eval_kernel!, - Ex_arr, Ey_arr, Ez_arr, Bx_arr, By_arr, Bz_arr, - xv_current, - Efunc_gpu, Bfunc_gpu, - Ex_cpu, Ey_cpu, Ez_cpu, Bx_cpu, By_cpu, Bz_cpu, - xv_cpu_buffer, - tspan[1], n_particles - ) + # Buffer for particle positions on CPU (used for saving data) + xv_cpu_buffer = zeros(T, 6, n_particles) kernel! = boris_push_kernel!(backend, workgroup_size) @@ -354,10 +229,10 @@ function solve( iout_counters = zeros(Int, trajectories) if save_start - xv_cpu = Array(xv_current) + copyto!(xv_cpu_buffer, xv_current) for i in 1:n_particles iout_counters[i] += 1 - saved_data[i][iout_counters[i]] = SVector{6, T}(xv_cpu[:, i]) + saved_data[i][iout_counters[i]] = SVector{6, T}(xv_cpu_buffer[:, i]) saved_times[i][iout_counters[i]] = tspan[1] end end @@ -365,26 +240,16 @@ function solve( vback_kernel! = velocity_back_kernel!(backend, workgroup_size) vback_kernel!( xv_current, xv_current, q2m, -0.5 * dt, - Ex_arr, Ey_arr, Ez_arr, Bx_arr, By_arr, Bz_arr; ndrange = n_particles + Efunc_gpu, Bfunc_gpu, tspan[1]; ndrange = n_particles ) KA.synchronize(backend) for it in 1:nt t = tspan[1] + (it - 0.5) * dt - gpu_field_evaluation!( - backend, use_gpu_interp, eval_kernel!, - Ex_arr, Ey_arr, Ez_arr, Bx_arr, By_arr, Bz_arr, - xv_current, - Efunc_gpu, Bfunc_gpu, - Ex_cpu, Ey_cpu, Ez_cpu, Bx_cpu, By_cpu, Bz_cpu, - xv_cpu_buffer, - t, n_particles - ) - kernel!( xv_current, xv_next, q2m, dt, - Ex_arr, Ey_arr, Ez_arr, Bx_arr, By_arr, Bz_arr; + Efunc_gpu, Bfunc_gpu, t; ndrange = n_particles ) KA.synchronize(backend) @@ -392,7 +257,7 @@ function solve( xv_current, xv_next = xv_next, xv_current if save_everystep && it % savestepinterval == 0 - xv_cpu = Array(xv_current) + copyto!(xv_cpu_buffer, xv_current) t_current = tspan[1] + it * dt qdt_2m_half = q2m * 0.5 * (0.5 * dt) @@ -400,7 +265,7 @@ function solve( if iout_counters[i] < nout iout_counters[i] += 1 saved_data[i][iout_counters[i]] = _leapfrog_to_output( - @view(xv_cpu[:, i]), Efunc, Bfunc, t_current, qdt_2m_half + @view(xv_cpu_buffer[:, i]), Efunc, Bfunc, t_current, qdt_2m_half ) saved_times[i][iout_counters[i]] = t_current end @@ -409,7 +274,7 @@ function solve( end if save_end - xv_cpu = Array(xv_current) + copyto!(xv_cpu_buffer, xv_current) t_current = tspan[2] qdt_2m_half = q2m * 0.5 * (0.5 * dt) @@ -417,7 +282,7 @@ function solve( if iout_counters[i] < nout iout_counters[i] += 1 saved_data[i][iout_counters[i]] = _leapfrog_to_output( - @view(xv_cpu[:, i]), Efunc, Bfunc, t_current, qdt_2m_half + @view(xv_cpu_buffer[:, i]), Efunc, Bfunc, t_current, qdt_2m_half ) saved_times[i][iout_counters[i]] = t_current end From 7e8522bf6fbf0abb766928d3459fa2c26512896c Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Tue, 3 Feb 2026 13:59:20 -0500 Subject: [PATCH 02/25] feat: support time-dependent field interpolation in kernel. --- src/utility/interpolation.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 665994785..31fdc14ed 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -16,6 +16,10 @@ function (fi::FieldInterpolator)(xu) return fi.itp(xu[1], xu[2], xu[3]) end +function (fi::FieldInterpolator)(xu, t) + return fi.itp(xu[1], xu[2], xu[3]) +end + function getinterp_scalar(A, grid1, grid2, grid3, args...) return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) end From f1d394b2abba5e9508e492c5a8398373bb8c2309 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Tue, 3 Feb 2026 18:39:26 -0500 Subject: [PATCH 03/25] revert interpolation changes --- src/boris_kernel.jl | 2 +- src/utility/interpolation.jl | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 58b586c53..69936e16d 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -34,7 +34,7 @@ function adapt_field_to_gpu(field::Field, backend::KA.Backend) # Re-wrap in FieldInterpolator to maintain calling convention f(r) adapted_wrapper = FieldInterpolator(adapted_func) - + #TODO: time interpolation support needs to be checked return Field{is_time_dependent(field), typeof(adapted_wrapper)}(adapted_wrapper) end # Analytic fields don't need adaptation (assuming they are GPU compatible functions) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 31fdc14ed..665994785 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -16,10 +16,6 @@ function (fi::FieldInterpolator)(xu) return fi.itp(xu[1], xu[2], xu[3]) end -function (fi::FieldInterpolator)(xu, t) - return fi.itp(xu[1], xu[2], xu[3]) -end - function getinterp_scalar(A, grid1, grid2, grid3, args...) return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) end From 958143600bffbe2f2a5e52ec5b1e2317b6019ce6 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Feb 2026 00:07:19 -0500 Subject: [PATCH 04/25] Specialize the kernel in CPU version --- src/boris_kernel.jl | 39 +++++++++++++++++++++++++++++------- src/prepare.jl | 1 + src/utility/interpolation.jl | 4 ++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 69936e16d..cd69ecf36 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -24,6 +24,9 @@ Adapt interpolation fields to GPU memory using Adapt.jl. Analytic functions are returned unchanged. """ function adapt_field_to_gpu(field::Field, backend::KA.Backend) + if backend isa KA.CPU + return field + end f = field.field_function if is_interpolation_field(f) # Unwrap FieldInterpolator to get the inner interpolation object @@ -34,7 +37,6 @@ function adapt_field_to_gpu(field::Field, backend::KA.Backend) # Re-wrap in FieldInterpolator to maintain calling convention f(r) adapted_wrapper = FieldInterpolator(adapted_func) - #TODO: time interpolation support needs to be checked return Field{is_time_dependent(field), typeof(adapted_wrapper)}(adapted_wrapper) end # Analytic fields don't need adaptation (assuming they are GPU compatible functions) @@ -206,17 +208,33 @@ function solve( xv_current = KA.zeros(backend, T, 6, n_particles) xv_next = KA.zeros(backend, T, 6, n_particles) - # xv_init on CPU - xv_init = zeros(T, 6, n_particles) + # Optimization for CPU backend: alias buffers to avoid allocations + # On CPU, KA.zeros returns an Array, so we can check this to detect CPU backend availability + is_cpu_accessible = xv_current isa Array + + if is_cpu_accessible + xv_init = xv_current + else + xv_init = zeros(T, 6, n_particles) + end + for i in 1:n_particles new_prob = prob.prob_func(prob, i, false) u0_i = new_prob.u0 xv_init[:, i] .= u0_i end - copyto!(xv_current, xv_init) + + if !is_cpu_accessible + copyto!(xv_current, xv_init) + end # Buffer for particle positions on CPU (used for saving data) - xv_cpu_buffer = zeros(T, 6, n_particles) + # If xv_current is already on CPU, we don't need a separate buffer allocation + if is_cpu_accessible + xv_cpu_buffer = xv_current + else + xv_cpu_buffer = zeros(T, 6, n_particles) + end kernel! = boris_push_kernel!(backend, workgroup_size) @@ -229,7 +247,11 @@ function solve( iout_counters = zeros(Int, trajectories) if save_start - copyto!(xv_cpu_buffer, xv_current) + if !is_cpu_accessible + copyto!(xv_cpu_buffer, xv_current) + end + # If is_cpu_accessible, xv_cpu_buffer aliases xv_current, so it's already up to date + for i in 1:n_particles iout_counters[i] += 1 saved_data[i][iout_counters[i]] = SVector{6, T}(xv_cpu_buffer[:, i]) @@ -257,7 +279,10 @@ function solve( xv_current, xv_next = xv_next, xv_current if save_everystep && it % savestepinterval == 0 - copyto!(xv_cpu_buffer, xv_current) + if !is_cpu_accessible + copyto!(xv_cpu_buffer, xv_current) + end + t_current = tspan[1] + it * dt qdt_2m_half = q2m * 0.5 * (0.5 * dt) diff --git a/src/prepare.jl b/src/prepare.jl index 2cb194c45..56514a403 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -38,6 +38,7 @@ end Field(f::Function) = Field{is_time_dependent(f), typeof(f)}(f) is_time_dependent(::AbstractField{itd}) where {itd} = itd +is_time_dependent(::FieldInterpolator) = false # Always treat as static by default (f::AbstractField{true})(xu, t) = f.field_function(xu, t) function (f::AbstractField{true})(xu) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 665994785..0c37ef1c7 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -16,6 +16,10 @@ function (fi::FieldInterpolator)(xu) return fi.itp(xu[1], xu[2], xu[3]) end +function (fi::FieldInterpolator)(xu, t) + return fi(xu) +end + function getinterp_scalar(A, grid1, grid2, grid3, args...) return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) end From 1b1ddadc078929e58adc1ec1ea4337b58b3ac624 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Feb 2026 13:21:55 -0500 Subject: [PATCH 05/25] Check CPU for save_end --- src/boris_kernel.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index cd69ecf36..9a0251b81 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -209,7 +209,6 @@ function solve( xv_next = KA.zeros(backend, T, 6, n_particles) # Optimization for CPU backend: alias buffers to avoid allocations - # On CPU, KA.zeros returns an Array, so we can check this to detect CPU backend availability is_cpu_accessible = xv_current isa Array if is_cpu_accessible @@ -229,7 +228,6 @@ function solve( end # Buffer for particle positions on CPU (used for saving data) - # If xv_current is already on CPU, we don't need a separate buffer allocation if is_cpu_accessible xv_cpu_buffer = xv_current else @@ -299,7 +297,9 @@ function solve( end if save_end - copyto!(xv_cpu_buffer, xv_current) + if !is_cpu_accessible + copyto!(xv_cpu_buffer, xv_current) + end t_current = tspan[2] qdt_2m_half = q2m * 0.5 * (0.5 * dt) From eedd0eef379d53fb62959ab1a078e60bca84eb41 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Feb 2026 14:07:50 -0500 Subject: [PATCH 06/25] Specialized CPU methods --- src/boris_kernel.jl | 101 ++++++++++++++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 32 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 9a0251b81..937935c56 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -77,12 +77,7 @@ adapt_field_to_gpu(field::ZeroField, backend::KA.Backend) = field return vx_new, vy_new, vz_new end -@kernel function boris_push_kernel!( - @Const(xv_in), xv_out, @Const(q2m), @Const(dt), - Efunc, Bfunc, @Const(t) - ) - i = @index(Global) - +@inline function boris_push_node!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) x = xv_in[1, i] y = xv_in[2, i] z = xv_in[3, i] @@ -113,20 +108,23 @@ end xv_out[3, i] = z + vz_new * dt xv_out[4, i] = vx_new xv_out[5, i] = vy_new - xv_out[6, i] = vz_new + return xv_out[6, i] = vz_new end -@kernel function velocity_back_kernel!( - xv_out, @Const(xv_in), @Const(q2m_val), @Const(dt_val), +@kernel function boris_push_kernel!( + @Const(xv_in), xv_out, @Const(q2m), @Const(dt), Efunc, Bfunc, @Const(t) ) i = @index(Global) - vx = xv_in[4, i] - vy = xv_in[5, i] - vz = xv_in[6, i] + boris_push_node!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) +end + +@inline function velocity_back_node!(i, xv_out, xv_in, q2m_val, dt_val, Efunc, Bfunc, t) + x, y, z = xv_in[1, i], xv_in[2, i], xv_in[3, i] + vx, vy, vz = xv_in[4, i], xv_in[5, i], xv_in[6, i] # Evaluate fields at current position - r_vec = SVector{3}(xv_in[1, i], xv_in[2, i], xv_in[3, i]) + r_vec = SVector{3}(x, y, z) E_val = Efunc(r_vec, t) B_val = Bfunc(r_vec, t) @@ -138,19 +136,63 @@ end xv_out[4, i] = vx_new xv_out[5, i] = vy_new - xv_out[6, i] = vz_new + return xv_out[6, i] = vz_new +end + +@kernel function velocity_back_kernel!( + xv_out, @Const(xv_in), @Const(q2m_val), @Const(dt_val), + Efunc, Bfunc, @Const(t) + ) + i = @index(Global) + velocity_back_node!(i, xv_out, xv_in, q2m_val, dt_val, Efunc, Bfunc, t) +end + +function boris_step!( + backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, + n_particles, workgroup_size + ) + kernel! = boris_push_kernel!(backend, workgroup_size) + kernel!(xv_in, xv_out, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) + KA.synchronize(backend) + return +end + +function boris_step!( + ::KA.CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, + workgroup_size + ) + @inbounds for i in 1:n_particles + boris_push_node!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) + end + return +end + +function velocity_back_step!( + backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, + t, n_particles, workgroup_size + ) + kernel! = velocity_back_kernel!(backend, workgroup_size) + kernel!(xv_out, xv_in, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) + KA.synchronize(backend) + return +end + +function velocity_back_step!( + ::KA.CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, + n_particles, workgroup_size + ) + @inbounds for i in 1:n_particles + velocity_back_node!(i, xv_out, xv_in, q2m, dt, Efunc, Bfunc, t) + end + return end function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) T = eltype(xv) # Extract position and velocity (v^{n-1/2}) - x_p = xv[1] - y_p = xv[2] - z_p = xv[3] - vx = xv[4] - vy = xv[5] - vz = xv[6] + x_p, y_p, z_p = xv[1], xv[2], xv[3] + vx, vy, vz = xv[4], xv[5], xv[6] # Evaluate fields at current position and time r_vec = SVector(x_p, y_p, z_p) @@ -234,7 +276,6 @@ function solve( xv_cpu_buffer = zeros(T, 6, n_particles) end - kernel! = boris_push_kernel!(backend, workgroup_size) sols = Vector{ typeof(build_solution(prob, :boris, [tspan[1]], [SVector{6, T}(u0)])), @@ -249,7 +290,6 @@ function solve( copyto!(xv_cpu_buffer, xv_current) end # If is_cpu_accessible, xv_cpu_buffer aliases xv_current, so it's already up to date - for i in 1:n_particles iout_counters[i] += 1 saved_data[i][iout_counters[i]] = SVector{6, T}(xv_cpu_buffer[:, i]) @@ -257,22 +297,19 @@ function solve( end end - vback_kernel! = velocity_back_kernel!(backend, workgroup_size) - vback_kernel!( - xv_current, xv_current, q2m, -0.5 * dt, - Efunc_gpu, Bfunc_gpu, tspan[1]; ndrange = n_particles + # Initial backward half-step + velocity_back_step!( + backend, xv_current, xv_current, q2m, -0.5 * dt, + Efunc_gpu, Bfunc_gpu, tspan[1], n_particles, workgroup_size ) - KA.synchronize(backend) for it in 1:nt t = tspan[1] + (it - 0.5) * dt - kernel!( - xv_current, xv_next, q2m, dt, - Efunc_gpu, Bfunc_gpu, t; - ndrange = n_particles + boris_step!( + backend, xv_current, xv_next, q2m, dt, + Efunc_gpu, Bfunc_gpu, t, n_particles, workgroup_size ) - KA.synchronize(backend) xv_current, xv_next = xv_next, xv_current From 35cb33b9a7ac9f7401966cc12ac6beb339d180f6 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Feb 2026 16:27:00 -0500 Subject: [PATCH 07/25] refactor: use svector velocity update methods --- src/boris.jl | 29 +++++++++----- src/boris_kernel.jl | 98 ++++++++++----------------------------------- 2 files changed, 41 insertions(+), 86 deletions(-) diff --git a/src/boris.jl b/src/boris.jl index a43527a06..51205f739 100644 --- a/src/boris.jl +++ b/src/boris.jl @@ -76,28 +76,39 @@ end @inline ODE_DEFAULT_ISOUTOFDOMAIN(u, p, t) = false """ - update_velocity(v, r, param, dt, t) + boris_velocity_update(v, E, B, qdt_2m) Update velocity using the Boris method, returning the new velocity as an SVector. +This is the core logic shared between the standard solver and the kernel solver. """ -@muladd function update_velocity(v, r, param, dt, t) - q2m, _, Efunc, Bfunc, _ = param - E = Efunc(r, t) - B = Bfunc(r, t) - - t_rotate = q2m * B * 0.5 * dt +@inline @muladd function boris_velocity_update(v, E, B, qdt_2m) + t_rotate = qdt_2m * B t_mag2 = sum(abs2, t_rotate) s_rotate = 2 * t_rotate / (1 + t_mag2) - v_minus = v + q2m * E * 0.5 * dt + v_minus = v + qdt_2m * E v_prime = v_minus + (v_minus × t_rotate) v_plus = v_minus + (v_prime × s_rotate) - v_new = v_plus + q2m * E * 0.5 * dt + v_new = v_plus + qdt_2m * E return v_new end +""" + update_velocity(v, r, param, dt, t) + +Update velocity using the Boris method, returning the new velocity as an SVector. +""" +@muladd function update_velocity(v, r, param, dt, t) + q2m, _, Efunc, Bfunc, _ = param + E = Efunc(r, t) + B = Bfunc(r, t) + qdt_2m = q2m * 0.5 * dt + + return boris_velocity_update(v, E, B, qdt_2m) +end + """ update_velocity!(xv, paramBoris, param, dt, t) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 937935c56..e66171e50 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -47,68 +47,25 @@ end adapt_field_to_gpu(field::ZeroField, backend::KA.Backend) = field -@inline function boris_velocity_update(vx, vy, vz, Ex, Ey, Ez, Bx, By, Bz, qdt_2m) - vx_minus = vx + qdt_2m * Ex - vy_minus = vy + qdt_2m * Ey - vz_minus = vz + qdt_2m * Ez - - tx = qdt_2m * Bx - ty = qdt_2m * By - tz = qdt_2m * Bz - - t_mag2 = tx * tx + ty * ty + tz * tz - factor = 2 / (1 + t_mag2) - sx = factor * tx - sy = factor * ty - sz = factor * tz - - vpx = vx_minus + (vy_minus * tz - vz_minus * ty) - vpy = vy_minus + (vz_minus * tx - vx_minus * tz) - vpz = vz_minus + (vx_minus * ty - vy_minus * tx) - - vx_plus = vx_minus + (vpy * sz - vpz * sy) - vy_plus = vy_minus + (vpz * sx - vpx * sz) - vz_plus = vz_minus + (vpx * sy - vpy * sx) - - vx_new = vx_plus + qdt_2m * Ex - vy_new = vy_plus + qdt_2m * Ey - vz_new = vz_plus + qdt_2m * Ez - - return vx_new, vy_new, vz_new +@inline function get_particle(xv, i) + return SVector(xv[1, i], xv[2, i], xv[3, i]), SVector(xv[4, i], xv[5, i], xv[6, i]) end @inline function boris_push_node!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) - x = xv_in[1, i] - y = xv_in[2, i] - z = xv_in[3, i] - vx = xv_in[4, i] - vy = xv_in[5, i] - vz = xv_in[6, i] + r_vec, v_vec = get_particle(xv_in, i) # Evaluate fields directly - r_vec = SVector{3}(x, y, z) E_val = Efunc(r_vec, t) B_val = Bfunc(r_vec, t) - Ex = E_val[1] - Ey = E_val[2] - Ez = E_val[3] - Bx = B_val[1] - By = B_val[2] - Bz = B_val[3] - qdt_2m = q2m * 0.5 * dt - vx_new, vy_new, vz_new = boris_velocity_update( - vx, vy, vz, Ex, Ey, Ez, Bx, By, Bz, qdt_2m - ) + v_new = boris_velocity_update(v_vec, E_val, B_val, qdt_2m) - xv_out[1, i] = x + vx_new * dt - xv_out[2, i] = y + vy_new * dt - xv_out[3, i] = z + vz_new * dt - xv_out[4, i] = vx_new - xv_out[5, i] = vy_new - return xv_out[6, i] = vz_new + xv_out[1:3, i] = r_vec + v_new * dt + xv_out[4:6, i] = v_new + + return end @kernel function boris_push_kernel!( @@ -120,23 +77,17 @@ end end @inline function velocity_back_node!(i, xv_out, xv_in, q2m_val, dt_val, Efunc, Bfunc, t) - x, y, z = xv_in[1, i], xv_in[2, i], xv_in[3, i] - vx, vy, vz = xv_in[4, i], xv_in[5, i], xv_in[6, i] + r_vec, v_vec = get_particle(xv_in, i) # Evaluate fields at current position - r_vec = SVector{3}(x, y, z) E_val = Efunc(r_vec, t) B_val = Bfunc(r_vec, t) qdt_2m = q2m_val * 0.5 * dt_val - vx_new, vy_new, vz_new = boris_velocity_update( - vx, vy, vz, E_val[1], E_val[2], E_val[3], B_val[1], B_val[2], B_val[3], qdt_2m - ) + xv_out[4:6, i] = boris_velocity_update(v_vec, E_val, B_val, qdt_2m) - xv_out[4, i] = vx_new - xv_out[5, i] = vy_new - return xv_out[6, i] = vz_new + return end @kernel function velocity_back_kernel!( @@ -147,7 +98,7 @@ end velocity_back_node!(i, xv_out, xv_in, q2m_val, dt_val, Efunc, Bfunc, t) end -function boris_step!( +@inline function boris_step!( backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) @@ -157,7 +108,7 @@ function boris_step!( return end -function boris_step!( +@inline function boris_step!( ::KA.CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) @@ -167,7 +118,7 @@ function boris_step!( return end -function velocity_back_step!( +@inline function velocity_back_step!( backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) @@ -177,7 +128,7 @@ function velocity_back_step!( return end -function velocity_back_step!( +@inline function velocity_back_step!( ::KA.CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) @@ -191,34 +142,28 @@ end function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) T = eltype(xv) # Extract position and velocity (v^{n-1/2}) - x_p, y_p, z_p = xv[1], xv[2], xv[3] - vx, vy, vz = xv[4], xv[5], xv[6] + r_vec = SVector(xv[1], xv[2], xv[3]) + v_vec = SVector{3}(xv[4], xv[5], xv[6]) # Evaluate fields at current position and time - r_vec = SVector(x_p, y_p, z_p) E_val = Efunc(r_vec, t) B_val = Bfunc(r_vec, t) # Correct velocity to v^n using half-step push - vx_n, vy_n, vz_n = boris_velocity_update( - vx, vy, vz, - E_val[1], E_val[2], E_val[3], - B_val[1], B_val[2], B_val[3], - qdt_2m_half - ) + v_n = boris_velocity_update(v_vec, E_val, B_val, qdt_2m_half) - return SVector{6, T}(x_p, y_p, z_p, vx_n, vy_n, vz_n) + return vcat(r_vec, v_n) end -function solve( +@inbounds function solve( prob::TraceProblem, backend::KA.Backend; dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, workgroup_size::Int = 256 ) (; tspan, p, u0) = prob - q2m, m, Efunc, Bfunc, _ = p + q2m, _, Efunc, Bfunc, _ = p T = eltype(u0) # Adapt interpolation fields to GPU memory @@ -276,7 +221,6 @@ function solve( xv_cpu_buffer = zeros(T, 6, n_particles) end - sols = Vector{ typeof(build_solution(prob, :boris, [tspan[1]], [SVector{6, T}(u0)])), }(undef, trajectories) From 6310f849019cddf2cc1df4c8ca70cd892a00aacf Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Feb 2026 17:31:32 -0500 Subject: [PATCH 08/25] fix: compatibility with CUDA --- src/boris_kernel.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index e66171e50..91f965010 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -61,9 +61,13 @@ end qdt_2m = q2m * 0.5 * dt v_new = boris_velocity_update(v_vec, E_val, B_val, qdt_2m) - - xv_out[1:3, i] = r_vec + v_new * dt - xv_out[4:6, i] = v_new + # Use scalar indexing for GPU compilation + xv_out[1, i] = r_vec[1] + v_new[1] * dt + xv_out[2, i] = r_vec[2] + v_new[2] * dt + xv_out[3, i] = r_vec[3] + v_new[3] * dt + xv_out[4, i] = v_new[1] + xv_out[5, i] = v_new[2] + xv_out[6, i] = v_new[3] return end @@ -85,7 +89,11 @@ end qdt_2m = q2m_val * 0.5 * dt_val - xv_out[4:6, i] = boris_velocity_update(v_vec, E_val, B_val, qdt_2m) + v_new = boris_velocity_update(v_vec, E_val, B_val, qdt_2m) + # Use scalar indexing for GPU compilation + xv_out[4, i] = v_new[1] + xv_out[5, i] = v_new[2] + xv_out[6, i] = v_new[3] return end From 31cfc83123c935d3b8006cb0a61ce932458345b9 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Wed, 4 Feb 2026 21:58:56 -0500 Subject: [PATCH 09/25] refactor kernel Boris internals --- src/boris_kernel.jl | 72 ++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 40 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 91f965010..543c82296 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -47,12 +47,9 @@ end adapt_field_to_gpu(field::ZeroField, backend::KA.Backend) = field -@inline function get_particle(xv, i) - return SVector(xv[1, i], xv[2, i], xv[3, i]), SVector(xv[4, i], xv[5, i], xv[6, i]) -end - -@inline function boris_push_node!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) - r_vec, v_vec = get_particle(xv_in, i) +@inline function get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) + r_vec = SVector(xv_in[1, i], xv_in[2, i], xv_in[3, i]) + v_vec = SVector(xv_in[4, i], xv_in[5, i], xv_in[6, i]) # Evaluate fields directly E_val = Efunc(r_vec, t) @@ -61,10 +58,16 @@ end qdt_2m = q2m * 0.5 * dt v_new = boris_velocity_update(v_vec, E_val, B_val, qdt_2m) - # Use scalar indexing for GPU compilation - xv_out[1, i] = r_vec[1] + v_new[1] * dt - xv_out[2, i] = r_vec[2] + v_new[2] * dt - xv_out[3, i] = r_vec[3] + v_new[3] * dt + + return v_new +end + +@inline @muladd function boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) + v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) + # Scalar write for GPU compatibility + xv_out[1, i] = xv_in[1, i] + v_new[1] * dt + xv_out[2, i] = xv_in[2, i] + v_new[2] * dt + xv_out[3, i] = xv_in[3, i] + v_new[3] * dt xv_out[4, i] = v_new[1] xv_out[5, i] = v_new[2] xv_out[6, i] = v_new[3] @@ -72,45 +75,31 @@ end return end -@kernel function boris_push_kernel!( - @Const(xv_in), xv_out, @Const(q2m), @Const(dt), +@kernel function boris_velocity_kernel!( + xv_out, @Const(xv_in), @Const(q2m), @Const(dt), Efunc, Bfunc, @Const(t) ) i = @index(Global) - boris_push_node!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) -end - -@inline function velocity_back_node!(i, xv_out, xv_in, q2m_val, dt_val, Efunc, Bfunc, t) - r_vec, v_vec = get_particle(xv_in, i) - - # Evaluate fields at current position - E_val = Efunc(r_vec, t) - B_val = Bfunc(r_vec, t) - - qdt_2m = q2m_val * 0.5 * dt_val - - v_new = boris_velocity_update(v_vec, E_val, B_val, qdt_2m) - # Use scalar indexing for GPU compilation + v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) + # Scalar write for GPU compatibility xv_out[4, i] = v_new[1] xv_out[5, i] = v_new[2] xv_out[6, i] = v_new[3] - - return end -@kernel function velocity_back_kernel!( - xv_out, @Const(xv_in), @Const(q2m_val), @Const(dt_val), +@kernel function boris_update_kernel!( + @Const(xv_in), xv_out, @Const(q2m), @Const(dt), Efunc, Bfunc, @Const(t) ) i = @index(Global) - velocity_back_node!(i, xv_out, xv_in, q2m_val, dt_val, Efunc, Bfunc, t) + boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) end @inline function boris_step!( backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) - kernel! = boris_push_kernel!(backend, workgroup_size) + kernel! = boris_update_kernel!(backend, workgroup_size) kernel!(xv_in, xv_out, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) KA.synchronize(backend) return @@ -121,27 +110,30 @@ end workgroup_size ) @inbounds for i in 1:n_particles - boris_push_node!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) + boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) end return end -@inline function velocity_back_step!( +@inline function boris_velocity_step!( backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) - kernel! = velocity_back_kernel!(backend, workgroup_size) + kernel! = boris_velocity_kernel!(backend, workgroup_size) kernel!(xv_out, xv_in, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) KA.synchronize(backend) return end -@inline function velocity_back_step!( +@inline function boris_velocity_step!( ::KA.CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) @inbounds for i in 1:n_particles - velocity_back_node!(i, xv_out, xv_in, q2m, dt, Efunc, Bfunc, t) + v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) + xv_out[4, i] = v_new[1] + xv_out[5, i] = v_new[2] + xv_out[6, i] = v_new[3] end return end @@ -150,8 +142,8 @@ end function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) T = eltype(xv) # Extract position and velocity (v^{n-1/2}) - r_vec = SVector(xv[1], xv[2], xv[3]) - v_vec = SVector{3}(xv[4], xv[5], xv[6]) + r_vec = SVector{3, T}(xv[1], xv[2], xv[3]) + v_vec = SVector{3, T}(xv[4], xv[5], xv[6]) # Evaluate fields at current position and time E_val = Efunc(r_vec, t) @@ -250,7 +242,7 @@ end end # Initial backward half-step - velocity_back_step!( + boris_velocity_step!( backend, xv_current, xv_current, q2m, -0.5 * dt, Efunc_gpu, Bfunc_gpu, tspan[1], n_particles, workgroup_size ) From cde649753a98396b69a58bd1f73c82d6e05f2a0a Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 11:14:25 -0500 Subject: [PATCH 10/25] refactor: use FieldInterpolator instead of function capture --- src/TestParticle.jl | 1 + src/boris_kernel.jl | 29 +----- src/prepare.jl | 2 +- src/utility/interpolation.jl | 184 ++++++++++++++++++++++------------- test/test_utility.jl | 1 - 5 files changed, 120 insertions(+), 97 deletions(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index 150283072..ece5a1fc2 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -19,6 +19,7 @@ using MuladdMacro: @muladd import Tensors import Base: +, -, *, /, setindex!, getindex import LinearAlgebra: × +import Adapt export prepare, prepare_gc, get_gc, get_gc_func export trace!, trace_relativistic!, trace_normalized!, trace_relativistic_normalized!, diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 543c82296..0900b1e37 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -7,16 +7,6 @@ using Interpolations: AbstractInterpolation, AbstractExtrapolation # Helper functions for GPU interpolation support -""" - is_interpolation_field(f) - -Check if a field function is an interpolation object that can be adapted to GPU. -""" -function is_interpolation_field(f) - return f isa AbstractInterpolation || f isa AbstractExtrapolation || - f isa FieldInterpolator -end - """ adapt_field_to_gpu(field::Field, backend::KA.Backend) @@ -24,23 +14,12 @@ Adapt interpolation fields to GPU memory using Adapt.jl. Analytic functions are returned unchanged. """ function adapt_field_to_gpu(field::Field, backend::KA.Backend) - if backend isa KA.CPU - return field - end - f = field.field_function - if is_interpolation_field(f) - # Unwrap FieldInterpolator to get the inner interpolation object - itp = f isa FieldInterpolator ? f.itp : f + backend isa KA.CPU && return field - # Adapt interpolation object to GPU - adapted_func = Adapt.adapt(backend, itp) + # Adapt the inner function (FieldInterpolator or analytic) + adapted_func = Adapt.adapt(backend, field.field_function) - # Re-wrap in FieldInterpolator to maintain calling convention f(r) - adapted_wrapper = FieldInterpolator(adapted_func) - return Field{is_time_dependent(field), typeof(adapted_wrapper)}(adapted_wrapper) - end - # Analytic fields don't need adaptation (assuming they are GPU compatible functions) - return field + return Field{is_time_dependent(field), typeof(adapted_func)}(adapted_func) end # Fallback for ZeroField diff --git a/src/prepare.jl b/src/prepare.jl index 56514a403..6d88585fd 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -38,7 +38,7 @@ end Field(f::Function) = Field{is_time_dependent(f), typeof(f)}(f) is_time_dependent(::AbstractField{itd}) where {itd} = itd -is_time_dependent(::FieldInterpolator) = false # Always treat as static by default +is_time_dependent(::AbstractFieldInterpolator) = false # Always treat as static by default (f::AbstractField{true})(xu, t) = f.field_function(xu, t) function (f::AbstractField{true})(xu) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 0c37ef1c7..0db234919 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -2,13 +2,19 @@ @inline getinterp(A, grid1, args...) = getinterp(CartesianGrid, A, grid1, args...) +""" + AbstractFieldInterpolator + +Abstract type for all field interpolators. +""" +abstract type AbstractFieldInterpolator <: Function end + """ FieldInterpolator{T} -A callable struct that wraps an interpolation object. -It enables compatibility with `boris_kernel` by exposing the inner `itp` object for GPU adaptation. +A callable struct that wraps a 3D interpolation object. """ -struct FieldInterpolator{T} <: Function +struct FieldInterpolator{T} <: AbstractFieldInterpolator itp::T end @@ -20,6 +26,107 @@ function (fi::FieldInterpolator)(xu, t) return fi(xu) end +Adapt.adapt_structure(to, fi::FieldInterpolator) = FieldInterpolator(Adapt.adapt(to, fi.itp)) + +""" + FieldInterpolator2D{T} + +A callable struct that wraps a 2D interpolation object. +""" +struct FieldInterpolator2D{T} <: AbstractFieldInterpolator + itp::T +end + +function (fi::FieldInterpolator2D)(xu) + # 2D interpolation usually involves x and y + return fi.itp(xu[1], xu[2]) +end + +function (fi::FieldInterpolator2D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator2D) = FieldInterpolator2D(Adapt.adapt(to, fi.itp)) + +""" + FieldInterpolator1D{T} + +A callable struct that wraps a 1D interpolation object. +""" +struct FieldInterpolator1D{T} <: AbstractFieldInterpolator + itp::T + dir::Int +end + +function (fi::FieldInterpolator1D)(xu) + return fi.itp(xu[fi.dir]) +end + +function (fi::FieldInterpolator1D)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::FieldInterpolator1D) = FieldInterpolator1D(Adapt.adapt(to, fi.itp), fi.dir) + +""" + SphericalFieldInterpolator{T} + +A callable struct for spherical grid interpolation (scalar or combined vector). +""" +struct SphericalFieldInterpolator{T} <: AbstractFieldInterpolator + itp::T +end + +function (fi::SphericalFieldInterpolator)(xu) + r_val, θ_val, ϕ_val = cart2sph(xu) + res = fi.itp(r_val, θ_val, ϕ_val) + if length(res) > 1 + # Convert vector result from spherical to cartesian basis + Br, Bθ, Bϕ = res + return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) + else + return res + end +end + +function (fi::SphericalFieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) + +""" + SphericalVectorFieldInterpolator{Tr, Tth, Tph} + +A callable struct for spherical vector field interpolation where components are interpolated separately. +""" +struct SphericalVectorFieldInterpolator{Tr, Tth, Tph} <: AbstractFieldInterpolator + itpr::Tr + itpθ::Tth + itpϕ::Tph +end + +function (fi::SphericalVectorFieldInterpolator)(xu) + r_val, θ_val, ϕ_val = cart2sph(xu) + + Br = fi.itpr(r_val, θ_val, ϕ_val) + Bθ = fi.itpθ(r_val, θ_val, ϕ_val) + Bϕ = fi.itpϕ(r_val, θ_val, ϕ_val) + + return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) +end + +function (fi::SphericalVectorFieldInterpolator)(xu, t) + return fi(xu) +end + +Adapt.adapt_structure(to, fi::SphericalVectorFieldInterpolator) = SphericalVectorFieldInterpolator( + Adapt.adapt(to, fi.itpr), + Adapt.adapt(to, fi.itpθ), + Adapt.adapt(to, fi.itpϕ) +) + + function getinterp_scalar(A, grid1, grid2, grid3, args...) return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) end @@ -82,17 +189,9 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, gridy, order::Int = 1, bc: end itp = _get_interp_object(As, order, bc) - interp = scale(itp, gridx, gridy) - # Return field value at a given location. - function get_field(xu) - r = @view xu[1:2] - - return interp(r...) - end - - return get_field + return FieldInterpolator2D(interp) end function get_interpolator( @@ -125,12 +224,7 @@ function get_interpolator( itp = extrapolate(interpolate!((gridx, gridy, gridz), A, Gridded(Linear())), bctype) - function get_field(xu) - r = @view xu[1:3] - return itp(r...) - end - - return get_field + return FieldInterpolator(itp) end function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = 3; dir = 1) @@ -143,17 +237,9 @@ function getinterp(::Type{<:CartesianGrid}, A, gridx, order::Int = 1, bc::Int = end itp = _get_interp_object(As, order, bc) - interp = scale(itp, gridx) - # Return field value at a given location. - function get_field(xu) - r = xu[dir] - - return interp(r) - end - - return get_field + return FieldInterpolator1D(interp, dir) end function _get_bspline(order::Int, periodic::Bool) @@ -256,7 +342,6 @@ function get_interpolator( gridx, gridy, gridz, order::Int = 1, bc::Int = 1 ) where {T} itp = _get_interp_object(A, order, bc) - interp = scale(itp, gridx, gridy, gridz) # Return field value at a given location. @@ -314,35 +399,6 @@ function get_interpolator( return get_interpolator(StructuredGrid, As, gridr, gridθ, gridϕ, order, bc) end -function _create_spherical_vector_field_interpolator(interpr, interpθ, interpϕ) - function get_field(xu) - r_val, θ_val, ϕ_val = cart2sph(xu) - - Br = interpr(r_val, θ_val, ϕ_val) - Bθ = interpθ(r_val, θ_val, ϕ_val) - Bϕ = interpϕ(r_val, θ_val, ϕ_val) - - Bvec = sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) - - return Bvec - end - return get_field -end - -function _create_spherical_vector_field_interpolator(itp) - function get_field(xu) - r_val, θ_val, ϕ_val = cart2sph(xu) - - B_local = itp(r_val, θ_val, ϕ_val) - Br, Bθ, Bϕ = B_local - - Bvec = sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) - - return Bvec - end - return get_field -end - function get_interpolator( ::Type{<:StructuredGrid}, A::AbstractArray{T, 3}, gridr, gridθ, gridϕ, order::Int = 1, bc::Int = 1 @@ -361,19 +417,7 @@ function get_interpolator( itp = extrapolate(interpolate!((gridr, gridθ, gridϕ), A, Gridded(Linear())), bctype) end - if T <: SVector - return _create_spherical_vector_field_interpolator(itp) - else - return _create_spherical_scalar_field_interpolator(itp) - end -end - -function _create_spherical_scalar_field_interpolator(interp) - function get_field(xu) - r_val, θ_val, ϕ_val = cart2sph(xu[1], xu[2], xu[3]) - return interp(r_val, θ_val, ϕ_val) - end - return get_field + return SphericalFieldInterpolator(itp) end function _get_interp_object(A, order::Int, bc::Int) diff --git a/test/test_utility.jl b/test/test_utility.jl index eae7b9ea4..2353a7da4 100644 --- a/test/test_utility.jl +++ b/test/test_utility.jl @@ -638,7 +638,6 @@ import TestParticle as TP struct MockArray{T, N} <: AbstractArray{T, N} data::Array{T, N} end - MockArray(A::Array{T, N}) where {T, N} = MockArray{T, N}(A) Base.size(M::MockArray) = size(M.data) Base.getindex(M::MockArray, I...) = getindex(M.data, I...) Base.setindex!(M::MockArray, v::Number, I...) = setindex!(M.data, v, I...) From d8092de92ae7be0ab1ac7e133a977b8fc49bb438 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 11:14:41 -0500 Subject: [PATCH 11/25] Cleanup imports --- src/TestParticle.jl | 4 +++- src/boris_kernel.jl | 27 ++++++++++----------------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index ece5a1fc2..4492b27fa 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -15,11 +15,13 @@ import ForwardDiff using ChunkSplitters: index_chunks using PrecompileTools: @setup_workload, @compile_workload using MuladdMacro: @muladd +using KernelAbstractions: @kernel, @index, @Const, @synchronize, Backend, CPU +import KernelAbstractions as KA +import Adapt import Tensors import Base: +, -, *, /, setindex!, getindex import LinearAlgebra: × -import Adapt export prepare, prepare_gc, get_gc, get_gc_func export trace!, trace_relativistic!, trace_normalized!, trace_relativistic_normalized!, diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 0900b1e37..a2ceb97ee 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -1,20 +1,13 @@ # GPU Boris solver using KernelAbstractions.jl -using KernelAbstractions -const KA = KernelAbstractions -using Adapt -using Interpolations: AbstractInterpolation, AbstractExtrapolation - -# Helper functions for GPU interpolation support - """ adapt_field_to_gpu(field::Field, backend::KA.Backend) Adapt interpolation fields to GPU memory using Adapt.jl. Analytic functions are returned unchanged. """ -function adapt_field_to_gpu(field::Field, backend::KA.Backend) - backend isa KA.CPU && return field +function adapt_field_to_gpu(field::Field, backend::Backend) + backend isa CPU && return field # Adapt the inner function (FieldInterpolator or analytic) adapted_func = Adapt.adapt(backend, field.field_function) @@ -23,7 +16,7 @@ function adapt_field_to_gpu(field::Field, backend::KA.Backend) end # Fallback for ZeroField -adapt_field_to_gpu(field::ZeroField, backend::KA.Backend) = field +adapt_field_to_gpu(field::ZeroField, backend::Backend) = field @inline function get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) @@ -75,17 +68,17 @@ end end @inline function boris_step!( - backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, + backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) kernel! = boris_update_kernel!(backend, workgroup_size) kernel!(xv_in, xv_out, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) - KA.synchronize(backend) + synchronize(backend) return end @inline function boris_step!( - ::KA.CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, + ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) @inbounds for i in 1:n_particles @@ -95,17 +88,17 @@ end end @inline function boris_velocity_step!( - backend::KA.Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, + backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) kernel! = boris_velocity_kernel!(backend, workgroup_size) kernel!(xv_out, xv_in, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) - KA.synchronize(backend) + synchronize(backend) return end @inline function boris_velocity_step!( - ::KA.CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, + ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, workgroup_size ) @inbounds for i in 1:n_particles @@ -136,7 +129,7 @@ end @inbounds function solve( - prob::TraceProblem, backend::KA.Backend; + prob::TraceProblem, backend::Backend; dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, workgroup_size::Int = 256 From 779f9540f994c68e39328a43eb5cd8b6c50ed75e Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 11:52:51 -0500 Subject: [PATCH 12/25] Cleanup unused SphericalVectorFieldInterpolator --- src/prepare.jl | 4 ++++ src/utility/interpolation.jl | 32 -------------------------------- 2 files changed, 4 insertions(+), 32 deletions(-) diff --git a/src/prepare.jl b/src/prepare.jl index 6d88585fd..d36c17cee 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -38,6 +38,10 @@ end Field(f::Function) = Field{is_time_dependent(f), typeof(f)}(f) is_time_dependent(::AbstractField{itd}) where {itd} = itd +# Note: Without it, AbstractFieldInterpolators are treated as time-dependent (due to their (x,t) method). +# This causes TestParticle to wrap them in Field{true}, which forbids calling f(x) (without time). +# Several components/tests rely on f(x) for static fields, leading to ArgumentError. +#TODO: Have a proper treatment for the time dependency with LazyTimeInterpolator. is_time_dependent(::AbstractFieldInterpolator) = false # Always treat as static by default (f::AbstractField{true})(xu, t) = f.field_function(xu, t) diff --git a/src/utility/interpolation.jl b/src/utility/interpolation.jl index 0db234919..8663a8e9d 100644 --- a/src/utility/interpolation.jl +++ b/src/utility/interpolation.jl @@ -95,38 +95,6 @@ end Adapt.adapt_structure(to, fi::SphericalFieldInterpolator) = SphericalFieldInterpolator(Adapt.adapt(to, fi.itp)) -""" - SphericalVectorFieldInterpolator{Tr, Tth, Tph} - -A callable struct for spherical vector field interpolation where components are interpolated separately. -""" -struct SphericalVectorFieldInterpolator{Tr, Tth, Tph} <: AbstractFieldInterpolator - itpr::Tr - itpθ::Tth - itpϕ::Tph -end - -function (fi::SphericalVectorFieldInterpolator)(xu) - r_val, θ_val, ϕ_val = cart2sph(xu) - - Br = fi.itpr(r_val, θ_val, ϕ_val) - Bθ = fi.itpθ(r_val, θ_val, ϕ_val) - Bϕ = fi.itpϕ(r_val, θ_val, ϕ_val) - - return sph_to_cart_vector(Br, Bθ, Bϕ, θ_val, ϕ_val) -end - -function (fi::SphericalVectorFieldInterpolator)(xu, t) - return fi(xu) -end - -Adapt.adapt_structure(to, fi::SphericalVectorFieldInterpolator) = SphericalVectorFieldInterpolator( - Adapt.adapt(to, fi.itpr), - Adapt.adapt(to, fi.itpθ), - Adapt.adapt(to, fi.itpϕ) -) - - function getinterp_scalar(A, grid1, grid2, grid3, args...) return getinterp_scalar(CartesianGrid, A, grid1, grid2, grid3, args...) end From 734a76474a41f6550217466c538277aff126d5fa Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 11:57:45 -0500 Subject: [PATCH 13/25] Add the kernel solver to precompilation --- src/precompile.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/precompile.jl b/src/precompile.jl index 154d4f7fe..8d6fdabdd 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -36,6 +36,9 @@ sol = solve(prob; dt, savestepinterval = 100) sol = solve(prob, EnsembleThreads(); dt, savestepinterval = 100) + # Kernel Boris (CPU) + sol_kernel = solve(prob, CPU(); dt, savestepinterval = 100) + # Adaptive Boris alg_adaptive = AdaptiveBoris(dtmax = 1.0) sol_adaptive = solve(prob, alg_adaptive)[1] From f3b14b222a377d886f90f0afdba251f88cadc59d Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 12:27:14 -0500 Subject: [PATCH 14/25] fix synchronize import --- src/TestParticle.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TestParticle.jl b/src/TestParticle.jl index 4492b27fa..2311ba5f5 100644 --- a/src/TestParticle.jl +++ b/src/TestParticle.jl @@ -15,7 +15,7 @@ import ForwardDiff using ChunkSplitters: index_chunks using PrecompileTools: @setup_workload, @compile_workload using MuladdMacro: @muladd -using KernelAbstractions: @kernel, @index, @Const, @synchronize, Backend, CPU +using KernelAbstractions: @kernel, @index, @Const, synchronize, Backend, CPU import KernelAbstractions as KA import Adapt From c7e692d4a6aebed707ac52082d700f1d85993027 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 13:57:11 -0500 Subject: [PATCH 15/25] feat: support multithreading kernel Boris solver --- src/boris_kernel.jl | 44 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index a2ceb97ee..e56335863 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -69,7 +69,7 @@ end @inline function boris_step!( backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, - n_particles, workgroup_size + n_particles, workgroup_size, ensemblealg::BasicEnsembleAlgorithm ) kernel! = boris_update_kernel!(backend, workgroup_size) kernel!(xv_in, xv_out, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) @@ -79,7 +79,7 @@ end @inline function boris_step!( ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, - workgroup_size + workgroup_size, ::EnsembleSerial ) @inbounds for i in 1:n_particles boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) @@ -87,9 +87,22 @@ end return end +@inline function boris_step!( + ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, + workgroup_size, ::EnsembleThreads + ) + nchunks = Threads.nthreads() + Threads.@threads for irange in index_chunks(1:n_particles; n = nchunks) + @inbounds for i in irange + boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) + end + end + return +end + @inline function boris_velocity_step!( backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, - t, n_particles, workgroup_size + t, n_particles, workgroup_size, ensemblealg::BasicEnsembleAlgorithm ) kernel! = boris_velocity_kernel!(backend, workgroup_size) kernel!(xv_out, xv_in, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) @@ -99,7 +112,7 @@ end @inline function boris_velocity_step!( ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, - n_particles, workgroup_size + n_particles, workgroup_size, ::EnsembleSerial ) @inbounds for i in 1:n_particles v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) @@ -110,6 +123,22 @@ end return end +@inline function boris_velocity_step!( + ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, + n_particles, workgroup_size, ::EnsembleThreads + ) + nchunks = Threads.nthreads() + Threads.@threads for irange in index_chunks(1:n_particles; n = nchunks) + @inbounds for i in irange + v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) + xv_out[4, i] = v_new[1] + xv_out[5, i] = v_new[2] + xv_out[6, i] = v_new[3] + end + end + return +end + function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) T = eltype(xv) @@ -129,7 +158,8 @@ end @inbounds function solve( - prob::TraceProblem, backend::Backend; + prob::TraceProblem, backend::Backend, + ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial(); dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, workgroup_size::Int = 256 @@ -216,7 +246,7 @@ end # Initial backward half-step boris_velocity_step!( backend, xv_current, xv_current, q2m, -0.5 * dt, - Efunc_gpu, Bfunc_gpu, tspan[1], n_particles, workgroup_size + Efunc_gpu, Bfunc_gpu, tspan[1], n_particles, workgroup_size, ensemblealg ) for it in 1:nt @@ -224,7 +254,7 @@ end boris_step!( backend, xv_current, xv_next, q2m, dt, - Efunc_gpu, Bfunc_gpu, t, n_particles, workgroup_size + Efunc_gpu, Bfunc_gpu, t, n_particles, workgroup_size, ensemblealg ) xv_current, xv_next = xv_next, xv_current From 58b9de0e354002ed42f1a23cd2f02df68d9016bc Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 14:39:27 -0500 Subject: [PATCH 16/25] Use threading as the default kernel solver option based on benchmarks --- src/boris_kernel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index e56335863..d7c6d3d59 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -159,7 +159,7 @@ end @inbounds function solve( prob::TraceProblem, backend::Backend, - ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial(); + ensemblealg::BasicEnsembleAlgorithm = EnsembleThreads(); dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, workgroup_size::Int = 256 From 38b39659ffc42f2c4aa90d9edbf432d61dd711b0 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 20:41:22 -0500 Subject: [PATCH 17/25] Switch back to EnsembleSerial for default --- src/boris_kernel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index d7c6d3d59..e56335863 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -159,7 +159,7 @@ end @inbounds function solve( prob::TraceProblem, backend::Backend, - ensemblealg::BasicEnsembleAlgorithm = EnsembleThreads(); + ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial(); dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, workgroup_size::Int = 256 From 5b719e72375a46d17552b03a9225f18ac239ec01 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Thu, 5 Feb 2026 21:55:16 -0500 Subject: [PATCH 18/25] redesign the multithreading kernel boris solver --- src/boris_kernel.jl | 298 ++++++++++++++++++++++++++++---------------- 1 file changed, 191 insertions(+), 107 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index e56335863..102903c78 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -34,7 +34,7 @@ adapt_field_to_gpu(field::ZeroField, backend::Backend) = field return v_new end -@inline @muladd function boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) +@inline @muladd function boris_update_xv!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) # Scalar write for GPU compatibility xv_out[1, i] = xv_in[1, i] + v_new[1] * dt @@ -64,7 +64,7 @@ end Efunc, Bfunc, @Const(t) ) i = @index(Global) - boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) + boris_update_xv!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) end @inline function boris_step!( @@ -82,20 +82,7 @@ end workgroup_size, ::EnsembleSerial ) @inbounds for i in 1:n_particles - boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) - end - return -end - -@inline function boris_step!( - ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, - workgroup_size, ::EnsembleThreads - ) - nchunks = Threads.nthreads() - Threads.@threads for irange in index_chunks(1:n_particles; n = nchunks) - @inbounds for i in irange - boris_update!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) - end + boris_update_xv!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) end return end @@ -123,23 +110,6 @@ end return end -@inline function boris_velocity_step!( - ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, - n_particles, workgroup_size, ::EnsembleThreads - ) - nchunks = Threads.nthreads() - Threads.@threads for irange in index_chunks(1:n_particles; n = nchunks) - @inbounds for i in irange - v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) - xv_out[4, i] = v_new[1] - xv_out[5, i] = v_new[2] - xv_out[6, i] = v_new[3] - end - end - return -end - - function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) T = eltype(xv) # Extract position and velocity (v^{n-1/2}) @@ -157,9 +127,110 @@ function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) end +@inbounds function _solve_serial( + prob::TraceProblem, backend::Backend, irange; + dt::AbstractFloat, savestepinterval::Int, save_start::Bool, + save_end::Bool, save_everystep::Bool, workgroup_size::Int, + xv_current, xv_next, xv_cpu_buffer, is_cpu_accessible, + Efunc_gpu, Bfunc_gpu, Efunc, Bfunc, nout, nt + ) + (; tspan, p) = prob + q2m, _, _, _, _ = p + T = eltype(xv_current) + n_particles = length(irange) + + sols = Vector{ + typeof(build_solution(prob, :boris, [tspan[1]], [SVector{6, T}(prob.u0)])), + }(undef, n_particles) + + saved_data = [Vector{SVector{6, T}}(undef, nout) for _ in 1:n_particles] + saved_times = [Vector{typeof(tspan[1] + dt)}(undef, nout) for _ in 1:n_particles] + iout_counters = zeros(Int, n_particles) + + if save_start + if !is_cpu_accessible + copyto!(xv_cpu_buffer, xv_current) + end + for (local_i, i) in enumerate(irange) + iout_counters[local_i] += 1 + saved_data[local_i][iout_counters[local_i]] = + SVector{6, T}(xv_cpu_buffer[:, i]) + saved_times[local_i][iout_counters[local_i]] = tspan[1] + end + end + + boris_velocity_step!( + backend, xv_current, xv_current, q2m, -0.5 * dt, + Efunc_gpu, Bfunc_gpu, tspan[1], n_particles, workgroup_size, EnsembleSerial() + ) + + for it in 1:nt + t = tspan[1] + (it - 0.5) * dt + + boris_step!( + backend, xv_current, xv_next, q2m, dt, + Efunc_gpu, Bfunc_gpu, t, n_particles, workgroup_size, EnsembleSerial() + ) + + xv_current, xv_next = xv_next, xv_current + + if save_everystep && it % savestepinterval == 0 + if !is_cpu_accessible + copyto!(xv_cpu_buffer, xv_current) + end + + t_current = tspan[1] + it * dt + qdt_2m_half = q2m * 0.5 * (0.5 * dt) + + for (local_i, i) in enumerate(irange) + if iout_counters[local_i] < nout + iout_counters[local_i] += 1 + saved_data[local_i][iout_counters[local_i]] = _leapfrog_to_output( + @view(xv_cpu_buffer[:, i]), Efunc, Bfunc, t_current, qdt_2m_half + ) + saved_times[local_i][iout_counters[local_i]] = t_current + end + end + end + end + + if save_end + if !is_cpu_accessible + copyto!(xv_cpu_buffer, xv_current) + end + t_current = tspan[2] + qdt_2m_half = q2m * 0.5 * (0.5 * dt) + + for (local_i, i) in enumerate(irange) + if iout_counters[local_i] < nout + iout_counters[local_i] += 1 + saved_data[local_i][iout_counters[local_i]] = _leapfrog_to_output( + @view(xv_cpu_buffer[:, i]), Efunc, Bfunc, t_current, qdt_2m_half + ) + saved_times[local_i][iout_counters[local_i]] = t_current + end + end + end + + for local_i in 1:n_particles + actual_len = iout_counters[local_i] + if actual_len < nout + resize!(saved_data[local_i], actual_len) + resize!(saved_times[local_i], actual_len) + end + + interp = LinearInterpolation(saved_times[local_i], saved_data[local_i]) + sols[local_i] = build_solution( + prob, :boris, saved_times[local_i], saved_data[local_i]; + interp = interp, retcode = ReturnCode.Default, stats = nothing + ) + end + + return sols +end + @inbounds function solve( - prob::TraceProblem, backend::Backend, - ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial(); + prob::TraceProblem, backend::Backend, ::EnsembleSerial; dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, workgroup_size::Int = 256 @@ -168,7 +239,6 @@ end q2m, _, Efunc, Bfunc, _ = p T = eltype(u0) - # Adapt interpolation fields to GPU memory Efunc_gpu = adapt_field_to_gpu(Efunc, backend) Bfunc_gpu = adapt_field_to_gpu(Bfunc, backend) @@ -197,7 +267,6 @@ end xv_current = KA.zeros(backend, T, 6, n_particles) xv_next = KA.zeros(backend, T, 6, n_particles) - # Optimization for CPU backend: alias buffers to avoid allocations is_cpu_accessible = xv_current isa Array if is_cpu_accessible @@ -216,100 +285,115 @@ end copyto!(xv_current, xv_init) end - # Buffer for particle positions on CPU (used for saving data) if is_cpu_accessible xv_cpu_buffer = xv_current else xv_cpu_buffer = zeros(T, 6, n_particles) end - sols = Vector{ - typeof(build_solution(prob, :boris, [tspan[1]], [SVector{6, T}(u0)])), - }(undef, trajectories) + return _solve_serial( + prob, backend, 1:trajectories; + dt, savestepinterval, save_start, save_end, save_everystep, workgroup_size, + xv_current, xv_next, xv_cpu_buffer, is_cpu_accessible, + Efunc_gpu, Bfunc_gpu, Efunc, Bfunc, nout, nt + ) +end - saved_data = [Vector{SVector{6, T}}(undef, nout) for _ in 1:trajectories] - saved_times = [Vector{typeof(tspan[1] + dt)}(undef, nout) for _ in 1:trajectories] - iout_counters = zeros(Int, trajectories) +@inbounds function solve( + prob::TraceProblem, backend::Backend, ::EnsembleThreads; + dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, + save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, + workgroup_size::Int = 256 + ) + (; tspan, p, u0) = prob + q2m, _, Efunc, Bfunc, _ = p + T = eltype(u0) + # Adapt interpolation fields to GPU memory + Efunc_gpu = adapt_field_to_gpu(Efunc, backend) + Bfunc_gpu = adapt_field_to_gpu(Bfunc, backend) + + ttotal = tspan[2] - tspan[1] + nt = round(Int, abs(ttotal / dt)) + + nout = 0 if save_start - if !is_cpu_accessible - copyto!(xv_cpu_buffer, xv_current) + nout += 1 + end + if save_everystep + steps = nt ÷ savestepinterval + last_is_step = (nt > 0) && (nt % savestepinterval == 0) + nout += steps + if !save_end && last_is_step + nout -= 1 end - # If is_cpu_accessible, xv_cpu_buffer aliases xv_current, so it's already up to date - for i in 1:n_particles - iout_counters[i] += 1 - saved_data[i][iout_counters[i]] = SVector{6, T}(xv_cpu_buffer[:, i]) - saved_times[i][iout_counters[i]] = tspan[1] + if save_end && !last_is_step + nout += 1 end + elseif save_end + nout += 1 end - # Initial backward half-step - boris_velocity_step!( - backend, xv_current, xv_current, q2m, -0.5 * dt, - Efunc_gpu, Bfunc_gpu, tspan[1], n_particles, workgroup_size, ensemblealg - ) - - for it in 1:nt - t = tspan[1] + (it - 0.5) * dt - - boris_step!( - backend, xv_current, xv_next, q2m, dt, - Efunc_gpu, Bfunc_gpu, t, n_particles, workgroup_size, ensemblealg - ) - - xv_current, xv_next = xv_next, xv_current + n_particles = trajectories + xv_current = KA.zeros(backend, T, 6, n_particles) + xv_next = KA.zeros(backend, T, 6, n_particles) - if save_everystep && it % savestepinterval == 0 - if !is_cpu_accessible - copyto!(xv_cpu_buffer, xv_current) - end + # Optimization for CPU backend: alias buffers to avoid allocations + is_cpu_accessible = xv_current isa Array - t_current = tspan[1] + it * dt - qdt_2m_half = q2m * 0.5 * (0.5 * dt) + if is_cpu_accessible + xv_init = xv_current + else + xv_init = zeros(T, 6, n_particles) + end - for i in 1:n_particles - if iout_counters[i] < nout - iout_counters[i] += 1 - saved_data[i][iout_counters[i]] = _leapfrog_to_output( - @view(xv_cpu_buffer[:, i]), Efunc, Bfunc, t_current, qdt_2m_half - ) - saved_times[i][iout_counters[i]] = t_current - end - end - end + for i in 1:n_particles + new_prob = prob.prob_func(prob, i, false) + u0_i = new_prob.u0 + xv_init[:, i] .= u0_i end - if save_end - if !is_cpu_accessible - copyto!(xv_cpu_buffer, xv_current) - end - t_current = tspan[2] - qdt_2m_half = q2m * 0.5 * (0.5 * dt) + if !is_cpu_accessible + copyto!(xv_current, xv_init) + end - for i in 1:n_particles - if iout_counters[i] < nout - iout_counters[i] += 1 - saved_data[i][iout_counters[i]] = _leapfrog_to_output( - @view(xv_cpu_buffer[:, i]), Efunc, Bfunc, t_current, qdt_2m_half - ) - saved_times[i][iout_counters[i]] = t_current - end - end + # Buffer for particle positions on CPU (used for saving data) + if is_cpu_accessible + xv_cpu_buffer = xv_current + else + xv_cpu_buffer = zeros(T, 6, n_particles) end - for i in 1:trajectories - actual_len = iout_counters[i] - if actual_len < nout - resize!(saved_data[i], actual_len) - resize!(saved_times[i], actual_len) - end + sols = Vector{ + typeof(build_solution(prob, :boris, [tspan[1]], [SVector{6, T}(u0)])), + }(undef, trajectories) - interp = LinearInterpolation(saved_times[i], saved_data[i]) - sols[i] = build_solution( - prob, :boris, saved_times[i], saved_data[i]; - interp = interp, retcode = ReturnCode.Default, stats = nothing + nchunks = Threads.nthreads() + Threads.@threads for irange in index_chunks(1:trajectories; n = nchunks) + chunk_sols = _solve_serial( + prob, backend, irange; + dt, savestepinterval, save_start, save_end, save_everystep, workgroup_size, + xv_current, xv_next, xv_cpu_buffer, is_cpu_accessible, + Efunc_gpu, Bfunc_gpu, Efunc, Bfunc, nout, nt ) + for (local_i, i) in enumerate(irange) + sols[i] = chunk_sols[local_i] + end end return sols end + +@inbounds function solve( + prob::TraceProblem, backend::Backend, + ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial(); + dt::AbstractFloat, trajectories::Int = 1, savestepinterval::Int = 1, + save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, + workgroup_size::Int = 256 + ) + return solve( + prob, backend, ensemblealg; + dt, trajectories, savestepinterval, save_start, save_end, save_everystep, + workgroup_size + ) +end From 56d49023519afa2325b43c395783d04a7c16f768 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Feb 2026 00:19:59 -0500 Subject: [PATCH 19/25] Fix the race condition bug --- src/boris_kernel.jl | 47 ++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/boris_kernel.jl b/src/boris_kernel.jl index 102903c78..14064c5c0 100644 --- a/src/boris_kernel.jl +++ b/src/boris_kernel.jl @@ -48,10 +48,10 @@ end end @kernel function boris_velocity_kernel!( - xv_out, @Const(xv_in), @Const(q2m), @Const(dt), - Efunc, Bfunc, @Const(t) + xv_out, @Const(xv_in), @Const(q2m), @Const(dt), Efunc, Bfunc, + @Const(t), @Const(offset) ) - i = @index(Global) + i = @index(Global) + offset v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) # Scalar write for GPU compatibility xv_out[4, i] = v_new[1] @@ -60,48 +60,48 @@ end end @kernel function boris_update_kernel!( - @Const(xv_in), xv_out, @Const(q2m), @Const(dt), - Efunc, Bfunc, @Const(t) + @Const(xv_in), xv_out, @Const(q2m), @Const(dt), Efunc, Bfunc, + @Const(t), @Const(offset) ) - i = @index(Global) + i = @index(Global) + offset boris_update_xv!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) end @inline function boris_step!( - backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, - n_particles, workgroup_size, ensemblealg::BasicEnsembleAlgorithm + backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, irange, workgroup_size ) + offset = irange.start - 1 + n_particles = length(irange) kernel! = boris_update_kernel!(backend, workgroup_size) - kernel!(xv_in, xv_out, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) + kernel!(xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, offset; ndrange = n_particles) synchronize(backend) return end @inline function boris_step!( - ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, n_particles, - workgroup_size, ::EnsembleSerial + ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, irange, workgroup_size ) - @inbounds for i in 1:n_particles + @inbounds for i in irange boris_update_xv!(i, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t) end return end @inline function boris_velocity_step!( - backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, - t, n_particles, workgroup_size, ensemblealg::BasicEnsembleAlgorithm + backend::Backend, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, irange, workgroup_size ) + offset = irange.start - 1 + n_particles = length(irange) kernel! = boris_velocity_kernel!(backend, workgroup_size) - kernel!(xv_out, xv_in, q2m, dt, Efunc, Bfunc, t; ndrange = n_particles) + kernel!(xv_out, xv_in, q2m, dt, Efunc, Bfunc, t, offset; ndrange = n_particles) synchronize(backend) return end @inline function boris_velocity_step!( - ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, - n_particles, workgroup_size, ::EnsembleSerial + ::CPU, xv_in, xv_out, q2m, dt, Efunc, Bfunc, t, irange, workgroup_size ) - @inbounds for i in 1:n_particles + @inbounds for i in irange v_new = get_boris_velocity(i, xv_in, q2m, dt, Efunc, Bfunc, t) xv_out[4, i] = v_new[1] xv_out[5, i] = v_new[2] @@ -110,7 +110,7 @@ end return end -function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) +@inline function _leapfrog_to_output(xv, Efunc, Bfunc, t, qdt_2m_half) T = eltype(xv) # Extract position and velocity (v^{n-1/2}) r_vec = SVector{3, T}(xv[1], xv[2], xv[3]) @@ -154,14 +154,17 @@ end for (local_i, i) in enumerate(irange) iout_counters[local_i] += 1 saved_data[local_i][iout_counters[local_i]] = - SVector{6, T}(xv_cpu_buffer[:, i]) + SVector{6, T}( + xv_cpu_buffer[1, i], xv_cpu_buffer[2, i], xv_cpu_buffer[3, i], + xv_cpu_buffer[4, i], xv_cpu_buffer[5, i], xv_cpu_buffer[6, i] + ) saved_times[local_i][iout_counters[local_i]] = tspan[1] end end boris_velocity_step!( backend, xv_current, xv_current, q2m, -0.5 * dt, - Efunc_gpu, Bfunc_gpu, tspan[1], n_particles, workgroup_size, EnsembleSerial() + Efunc_gpu, Bfunc_gpu, tspan[1], irange, workgroup_size ) for it in 1:nt @@ -169,7 +172,7 @@ end boris_step!( backend, xv_current, xv_next, q2m, dt, - Efunc_gpu, Bfunc_gpu, t, n_particles, workgroup_size, EnsembleSerial() + Efunc_gpu, Bfunc_gpu, t, irange, workgroup_size ) xv_current, xv_next = xv_next, xv_current From a9933670277340ea78b1ac25d8d182ee04bb30ed Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Feb 2026 13:59:16 -0500 Subject: [PATCH 20/25] refactor: inline Field call --- src/prepare.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/prepare.jl b/src/prepare.jl index d36c17cee..69826b44f 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -44,12 +44,12 @@ is_time_dependent(::AbstractField{itd}) where {itd} = itd #TODO: Have a proper treatment for the time dependency with LazyTimeInterpolator. is_time_dependent(::AbstractFieldInterpolator) = false # Always treat as static by default -(f::AbstractField{true})(xu, t) = f.field_function(xu, t) -function (f::AbstractField{true})(xu) +@inline (f::Field{true})(xu, t) = f.field_function(xu, t) +@inline function (f::Field{true})(xu) throw(ArgumentError("Time-dependent field function must have a time argument.")) end -(f::AbstractField{false})(xu, t) = SVector{3}(f.field_function(xu)) -(f::AbstractField{false})(xu) = SVector{3}(f.field_function(xu)) +@inline (f::Field{false})(xu, t) = SVector{3, eltype(xu)}(f.field_function(xu)) +@inline (f::Field{false})(xu) = SVector{3, eltype(xu)}(f.field_function(xu)) function Base.show(io::IO, f::Field) println(io, "Field with interpolation support") From bf6647b6a877e91ed2b9cfab917b2b0dbeb8518b Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Feb 2026 14:00:22 -0500 Subject: [PATCH 21/25] refactor: make isoutofdomain and velocity_updater type stable --- src/boris.jl | 114 ++++++++++++++++++++++++++------------------------- 1 file changed, 58 insertions(+), 56 deletions(-) diff --git a/src/boris.jl b/src/boris.jl index 51205f739..95594a7b9 100644 --- a/src/boris.jl +++ b/src/boris.jl @@ -100,7 +100,7 @@ end Update velocity using the Boris method, returning the new velocity as an SVector. """ -@muladd function update_velocity(v, r, param, dt, t) +@inline @muladd function update_velocity(v, r, param::P, dt, t) where {P} q2m, _, Efunc, Bfunc, _ = param E = Efunc(r, t) B = Bfunc(r, t) @@ -195,12 +195,12 @@ Trace particles using the Boris method with specified `prob`. - `save_work::Bool`: save the work done by the electric field. Default is `false`. """ @inline function solve( - prob::TraceProblem, ensemblealg::BasicEnsembleAlgorithm = EnsembleSerial(); + prob::TraceProblem, ensemblealg::EA = EnsembleSerial(); trajectories::Int = 1, savestepinterval::Int = 1, dt::AbstractFloat, - isoutofdomain::Function = ODE_DEFAULT_ISOUTOFDOMAIN, n::Int = 1, + isoutofdomain::F = ODE_DEFAULT_ISOUTOFDOMAIN, n::Int = 1, save_start::Bool = true, save_end::Bool = true, save_everystep::Bool = true, save_fields::Bool = false, save_work::Bool = false - ) + ) where {EA <: BasicEnsembleAlgorithm, F} return _solve( ensemblealg, prob, trajectories, dt, savestepinterval, isoutofdomain, n, save_start, save_end, save_everystep, Val(save_fields), Val(save_work) @@ -208,9 +208,9 @@ Trace particles using the Boris method with specified `prob`. end function _dispatch_boris!( - sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n, + sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F, n, save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork} - ) where {SaveFields, SaveWork} + ) where {SaveFields, SaveWork, F} return if n == 1 _boris!( sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, @@ -224,10 +224,10 @@ function _dispatch_boris!( end end -function _solve( - ::EnsembleSerial, prob, trajectories, dt, savestepinterval, isoutofdomain, n, +@inline function _solve( + ::EnsembleSerial, prob::TraceProblem, trajectories, dt, savestepinterval, isoutofdomain::F, n, save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork} - ) where {SaveFields, SaveWork} + ) where {SaveFields, SaveWork, F} sols, nt, nout = _prepare( prob, trajectories, dt, savestepinterval, @@ -242,10 +242,10 @@ function _solve( return sols end -function _solve( - ::EnsembleThreads, prob, trajectories, dt, savestepinterval, isoutofdomain, n, +@inline function _solve( + ::EnsembleThreads, prob::TraceProblem, trajectories, dt, savestepinterval, isoutofdomain::F, n, save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork} - ) where {SaveFields, SaveWork} + ) where {SaveFields, SaveWork, F} sols, nt, nout = _prepare( prob, trajectories, dt, savestepinterval, @@ -362,13 +362,41 @@ end """ Apply Boris method for particles with index in `irange`. """ -@muladd function _generic_boris!( - sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, +@inline function _boris_loop!( + traj, tsave, iout, r, v, p, dt, nt, tspan, + savestepinterval, save_everystep, isoutofdomain::F1, velocity_updater::F2, + ::Val{SaveFields}, ::Val{SaveWork} + ) where {F1, F2, SaveFields, SaveWork} + it = 1 + while it <= nt + v_prev = v + t = (it - 0.5) * dt + v = velocity_updater(v, r, p, dt, t) + + if save_everystep && (it - 1) > 0 && (it - 1) % savestepinterval == 0 + iout += 1 + if iout <= length(traj) + t_current = tspan[1] + (it - 1) * dt + v_save = velocity_updater(v_prev, r, p, 0.5 * dt, t_current) + data = vcat(r, v_save) + traj[iout] = _prepare_saved_data(data, p, t_current, Val(SaveFields), Val(SaveWork)) + tsave[iout] = t_current + end + end + + r += v * dt + isoutofdomain(vcat(r, v), p, it * dt) && break + it += 1 + end + return it, iout, r, v +end + +@inline @muladd function _generic_boris!( + sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F1, save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork}, - velocity_updater, alg_name - ) where {SaveFields, SaveWork} + velocity_updater::F2, alg_name + ) where {SaveFields, SaveWork, F1, F2} (; tspan, p, u0) = prob - q2m, m, Efunc, Bfunc, _ = p T = eltype(u0) vars_dim = 6 @@ -379,57 +407,31 @@ Apply Boris method for particles with index in `irange`. vars_dim += 4 end - @fastmath @inbounds for i in irange + @inbounds for i in irange traj = Vector{SVector{vars_dim, T}}(undef, nout) tsave = Vector{typeof(tspan[1] + dt)}(undef, nout) # set initial conditions for each trajectory i iout = 0 new_prob = prob.prob_func(prob, i, false) - # Load independent r and v SVector from u0 u0_i = SVector{6, T}(new_prob.u0) r = u0_i[SVector(1, 2, 3)] v = u0_i[SVector(4, 5, 6)] if save_start iout += 1 - traj[iout] = _prepare_saved_data( - u0_i, p, tspan[1], Val(SaveFields), Val(SaveWork) - ) + traj[iout] = _prepare_saved_data(u0_i, p, tspan[1], Val(SaveFields), Val(SaveWork)) tsave[iout] = tspan[1] end # push velocity back in time by 1/2 dt v = velocity_updater(v, r, p, -0.5 * dt, tspan[1]) - it = 1 - while it <= nt - v_prev = v - t = (it - 0.5) * dt - v = velocity_updater(v, r, p, dt, t) - - if save_everystep && (it - 1) > 0 && (it - 1) % savestepinterval == 0 - iout += 1 - if iout <= nout - t_current = tspan[1] + (it - 1) * dt - # Approximate v_n from v_{n-1/2} (v_prev) - v_save = velocity_updater( - v_prev, r, p, 0.5 * dt, - t_current - ) - - data = vcat(r, v_save) - traj[iout] = _prepare_saved_data( - data, p, t_current, Val(SaveFields), Val(SaveWork) - ) - tsave[iout] = t_current - end - end - - r += v * dt - isoutofdomain(vcat(r, v), p, it * dt) && break - it += 1 - end + it, iout, r, v = _boris_loop!( + traj, tsave, iout, r, v, p, dt, nt, tspan, + savestepinterval, save_everystep, isoutofdomain, velocity_updater, + Val(SaveFields), Val(SaveWork) + ) final_step = min(it, nt) should_save_final = false @@ -475,10 +477,10 @@ end """ Apply Boris method for particles with index in `irange`. """ -@muladd function _boris!( - sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, +@inline @muladd function _boris!( + sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F, save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork} - ) where {SaveFields, SaveWork} + ) where {SaveFields, SaveWork, F} _generic_boris!( sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, @@ -552,10 +554,10 @@ Reference: [Zenitani & Kato 2025](https://arxiv.org/abs/2505.02270) return v_new end -@muladd function _multistep_boris!( - sols, prob, irange, savestepinterval, dt, nt, nout, isoutofdomain, n_steps::Int, +@inline @muladd function _multistep_boris!( + sols, prob::TraceProblem, irange, savestepinterval, dt, nt, nout, isoutofdomain::F, n_steps::Int, save_start, save_end, save_everystep, ::Val{SaveFields}, ::Val{SaveWork} - ) where {SaveFields, SaveWork} + ) where {SaveFields, SaveWork, F} velocity_updater = (v, r, p, dt, t) -> update_velocity_multistep(v, r, p, dt, t, n_steps) From 17496b2231b4f30f4ac340fa1a651ef1f408d19c Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Feb 2026 14:38:55 -0500 Subject: [PATCH 22/25] fix: less restrictive SVector type for Unitful integration --- src/prepare.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/prepare.jl b/src/prepare.jl index 69826b44f..53750dc08 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -48,8 +48,8 @@ is_time_dependent(::AbstractFieldInterpolator) = false # Always treat as static @inline function (f::Field{true})(xu) throw(ArgumentError("Time-dependent field function must have a time argument.")) end -@inline (f::Field{false})(xu, t) = SVector{3, eltype(xu)}(f.field_function(xu)) -@inline (f::Field{false})(xu) = SVector{3, eltype(xu)}(f.field_function(xu)) +@inline (f::Field{false})(xu, t) = SVector{3}(f.field_function(xu)) +@inline (f::Field{false})(xu) = SVector{3}(f.field_function(xu)) function Base.show(io::IO, f::Field) println(io, "Field with interpolation support") From 21c3b045e86a76c2b1532c9aaeb19a74cbc8d854 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Feb 2026 15:08:45 -0500 Subject: [PATCH 23/25] Remove redundant SVector wrapper --- src/prepare.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/prepare.jl b/src/prepare.jl index 53750dc08..302736389 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -48,8 +48,8 @@ is_time_dependent(::AbstractFieldInterpolator) = false # Always treat as static @inline function (f::Field{true})(xu) throw(ArgumentError("Time-dependent field function must have a time argument.")) end -@inline (f::Field{false})(xu, t) = SVector{3}(f.field_function(xu)) -@inline (f::Field{false})(xu) = SVector{3}(f.field_function(xu)) +@inline (f::Field{false})(xu, t) = f.field_function(xu) +@inline (f::Field{false})(xu) = f.field_function(xu) function Base.show(io::IO, f::Field) println(io, "Field with interpolation support") From 4c9dd2f1800ed34cb7baad46b3a3511c09d69fcb Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Feb 2026 16:58:55 -0500 Subject: [PATCH 24/25] test: add kernel boris threading test --- test/test_boris_kernel.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test_boris_kernel.jl b/test/test_boris_kernel.jl index 6d7e963d2..44a7fea35 100644 --- a/test/test_boris_kernel.jl +++ b/test/test_boris_kernel.jl @@ -56,6 +56,28 @@ const KA = KernelAbstractions end end + @testset "EnsembleThreads" begin + backend = CPU() + + prob_func_gpu(prob, i, repeat) = remake( + prob; u0 = [prob.u0[1:3]..., i * 1.0e4, 0.0, 0.0] + ) + prob_multi = TraceProblem(stateinit, tspan, param; prob_func = prob_func_gpu) + + trajectories = 10 + sols_serial = TP.solve( + prob_multi, backend, EnsembleSerial(); + dt, trajectories, savestepinterval = 100 + ) + sols_threads = TP.solve( + prob_multi, backend, EnsembleThreads(); + dt, trajectories, savestepinterval = 100 + ) + + @test length(sols_threads) == trajectories + @test sum(sol.u[end] for sol in sols_threads) ≈ sum(sol.u[end] for sol in sols_serial) + end + @testset "Kernel vs Native Solver Equivalence" begin backend = CPU() From e01136a7f03dfb8a420197a0190a3b5ffce96eb4 Mon Sep 17 00:00:00 2001 From: Hongyang Zhou Date: Fri, 6 Feb 2026 18:22:54 -0500 Subject: [PATCH 25/25] Remove unnecessary SVector wrappers. --- src/equations.jl | 10 +++++----- src/gc_solver.jl | 15 +++++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/equations.jl b/src/equations.jl index f25c229e3..1fbe78d40 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -9,7 +9,7 @@ function get_dv(x, v, p, t) B = Bfunc(x, t) F = Ffunc(x, t) - return SVector{3}(q2m * (v × B + E) + F / m) + return q2m * (v × B + E) + F / m end function get_relativistic_v(γv; c = c) @@ -144,7 +144,7 @@ function trace_normalized!(dy, y, p, t) B = get_BField(p)(y, t) @inbounds dy[1:3] = v - @inbounds dy[4:6] = SVector{3}(v × B + E) + @inbounds dy[4:6] = v × B + E return end @@ -158,7 +158,7 @@ function trace_normalized(y, p, t) v = get_v(y) E = get_EField(p)(y, t) B = get_BField(p)(y, t) - dv = SVector{3}(v × B + E) + dv = v × B + E return vcat(v, dv) end @@ -175,7 +175,7 @@ function trace_relativistic_normalized!(dy, y, p, t) v = get_relativistic_v(γv; c = 1) @inbounds dy[1:3] = v - @inbounds dy[4:6] = SVector{3}(v × B + E) + @inbounds dy[4:6] = v × B + E return end @@ -191,7 +191,7 @@ function trace_relativistic_normalized(y, p, t) γv = get_v(y) v = get_relativistic_v(γv; c = 1) - dv = SVector{3}(v × B + E) + dv = v × B + E return vcat(v, dv) end diff --git a/src/gc_solver.jl b/src/gc_solver.jl index 0e061788c..a8ea23aca 100644 --- a/src/gc_solver.jl +++ b/src/gc_solver.jl @@ -303,8 +303,8 @@ end # p = (q, q2m, μ, Efunc, Bfunc) r = get_x(xv) T = eltype(xv) - E = SVector{3, T}(p[4](r, t)) - B = SVector{3, T}(p[5](r, t)) + E = p[4](r, t) + B = p[5](r, t) data = vcat(data, E, B) end if SaveWork @@ -342,11 +342,14 @@ function _rk4!( # set initial conditions for each trajectory i iout = 0 new_prob = prob.prob_func(prob, i, false) - xv = SVector{4, T}(new_prob.u0) + xv = new_prob.u0 if save_start iout += 1 - push!(traj, _prepare_saved_data_gc(xv, p, tspan[1], Val(SaveFields), Val(SaveWork))) + push!( + traj, + _prepare_saved_data_gc(xv, p, tspan[1], Val(SaveFields), Val(SaveWork)) + ) push!(tsave, tspan[1]) end @@ -436,7 +439,7 @@ function _rk45!( tsave = typeof(tspan[1] + one(T))[] new_prob = prob.prob_func(prob, i, false) - xv = SVector{4, T}(new_prob.u0) + xv = new_prob.u0 t = tspan[1] @@ -445,7 +448,7 @@ function _rk45!( # p = (q, q2m, μ, E, B) q2m = p[2] B_field = p[5] - R = SVector{3}(xv[1], xv[2], xv[3]) + R = get_x(xv) B_vec = B_field(R, t) Bmag = norm(B_vec)