Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/examples/features/demo_hybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ alg = AdaptiveHybrid(;
dtmax = T_gyro,
dtmin = 1.0e-4 * T_gyro,
maxiters = 500_000,
check_interval = 100,
)

Random.seed!(1234)
Expand Down
83 changes: 45 additions & 38 deletions src/hybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,21 @@ struct AdaptiveHybrid{T}
abstol::T
reltol::T
maxiters::Int
check_interval::Int
end

function AdaptiveHybrid(;
threshold = 0.1, dtmax, dtmin = 1.0e-2 * dtmax, safety_fo = 0.1,
abstol = 1.0e-6, reltol = 1.0e-6, maxiters = 10000
abstol = 1.0e-6, reltol = 1.0e-6, maxiters = 10000, check_interval = 10
)
check_interval > 0 || throw(ArgumentError("check_interval must be positive."))
T = promote_type(
typeof(threshold), typeof(dtmin), typeof(dtmax), typeof(safety_fo), typeof(abstol),
typeof(reltol)
)
return AdaptiveHybrid{T}(
T(threshold), T(dtmin), T(dtmax), T(safety_fo), T(abstol), T(reltol), maxiters
T(threshold), T(dtmin), T(dtmax), T(safety_fo), T(abstol), T(reltol), maxiters,
check_interval
)
end

Expand Down Expand Up @@ -254,24 +257,26 @@ end
while t < tspan[2] && steps < alg.maxiters
if mode == :GC
# Adiabaticity check
ϵ = get_adiabaticity(xv_gc[SVector(1, 2, 3)], Bfunc, q, m, μ, t)
if ϵ >= alg.threshold
# Switch to FO (GC -> FO)
mode = :FO
verbose && @info "Switch GC → FO" ϵ t r = xv_gc[SVector(1, 2, 3)]
xv_fo_vec = _gc_to_full_at_t(
xv_gc, Efunc, Bfunc, q, m, μ, t, 2π * rand()
)
xv_fo = xv_fo_vec
r = xv_fo[SVector(1, 2, 3)]
v = xv_fo[SVector(4, 5, 6)]

B_mag = norm(Bfunc(r, t))
omega = abs(q2m * B_mag)
dt = alg.safety_fo / omega
dt = clamp(dt, alg.dtmin, alg.dtmax)
v = update_velocity(v, r, p, -0.5 * dt, t)
continue
if it % alg.check_interval == 0
ϵ = get_adiabaticity(xv_gc[SVector(1, 2, 3)], Bfunc, q, m, μ, t)
if ϵ >= alg.threshold
# Switch to FO (GC -> FO)
mode = :FO
verbose && @info "Switch GC → FO" ϵ t r = xv_gc[SVector(1, 2, 3)]
xv_fo_vec = _gc_to_full_at_t(
xv_gc, Efunc, Bfunc, q, m, μ, t, 2π * rand()
)
xv_fo = xv_fo_vec
r = xv_fo[SVector(1, 2, 3)]
v = xv_fo[SVector(4, 5, 6)]

B_mag = norm(Bfunc(r, t))
omega = abs(q2m * B_mag)
dt = alg.safety_fo / omega
dt = clamp(dt, alg.dtmin, alg.dtmax)
v = update_velocity(v, r, p, -0.5 * dt, t)
continue
end
end

if t + dt > tspan[2]
Expand Down Expand Up @@ -312,24 +317,26 @@ end
t_sync = is_td ? t : zero(T)
v_check = update_velocity(v, r, p, 0.5 * dt, t_sync)

xv_check = SVector{6, T}(
r[1], r[2], r[3], v_check[1], v_check[2], v_check[3]
)
X_gc, vpar, μ = _get_gc_parameters_at_t(
xv_check, Efunc, Bfunc, q, m, t
)
ϵ = get_adiabaticity(X_gc, Bfunc, q, m, μ, t)
if ϵ < alg.threshold
# Switch to GC
mode = :GC
verbose && @info "Switch FO → GC" ϵ t r = X_gc
xv_gc = SVector{4, T}(X_gc[1], X_gc[2], X_gc[3], vpar)
p_gc = (q, q2m, μ, Efunc, Bfunc)

Bmag = norm(Bfunc(X_gc, t))
omega = abs(q2m * Bmag)
dt = 0.5 * 2π / omega
continue
if it % alg.check_interval == 0
xv_check = SVector{6, T}(
r[1], r[2], r[3], v_check[1], v_check[2], v_check[3]
)
X_gc, vpar, μ = _get_gc_parameters_at_t(
xv_check, Efunc, Bfunc, q, m, t
)
ϵ = get_adiabaticity(X_gc, Bfunc, q, m, μ, t)
if ϵ < alg.threshold
# Switch to GC
mode = :GC
verbose && @info "Switch FO → GC" ϵ t r = X_gc
xv_gc = SVector{4, T}(X_gc[1], X_gc[2], X_gc[3], vpar)
p_gc = (q, q2m, μ, Efunc, Bfunc)

Bmag = norm(Bfunc(X_gc, t))
omega = abs(q2m * Bmag)
dt = 0.5 * 2π / omega
continue
end
end

if t + dt > tspan[2]
Expand Down
64 changes: 20 additions & 44 deletions test/test_hybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,12 @@ using Test
@test sols[1].retcode == TestParticle.ReturnCode.Success
end

# 3. Dynamic GC ↔ FO Switching (Magnetic Bottle)
#
# B_z(z) = B0*(1 + α*z²), with B_r from ∇·B = 0:
# B_x ≈ -B0*α*x*z, B_y ≈ -B0*α*y*z
#
# Near the midplane (z ≈ 0) the field is weak and
# curvature is high → FO. Away from midplane the
# field strengthens and becomes more uniform → GC.
# A bouncing particle triggers repeated GC → FO → GC
# transitions.
#
# Also tests EnsembleThreads against EnsembleSerial.
# 3. Demo Hybrid Case (Explicit Switching)
# Replicates the setup from docs/examples/features/demo_hybrid.jl
# Stronger curvature (α = 1.0e-2) forces GC <-> FO switching.
let
B0 = 1.0e-4 # [T] background field
α = 1.0e-4 # [m⁻²] mirror ratio parameter
B0 = 1.0e-4 # [T]
α = 1.0e-2 # [m⁻²] Stronger curvature

function bottle_B(x, t)
Bz = B0 * (1 + α * x[3]^2)
Expand All @@ -77,49 +68,34 @@ using Test
B_bottle = TestParticle.Field(bottle_B)

x0 = SA[0.0, 0.0, 0.0]
v_perp = 5.0e4 # [m/s]
v_par = 2.0e5 # [m/s]
v_perp = 5.0e4 # [m/s]
v_par = 1.0e5 # [m/s]
v0 = SA[v_perp, 0.0, v_par]
u0 = vcat(x0, v0)

Ω = abs(q2m) * B0
T_gyro = 2π / Ω
tspan = (0.0, 50 * T_gyro)
tspan = (0.0, 30 * T_gyro)

p = (q2m, m, E_field, B_bottle, TestParticle.ZeroField())

# Consistent with demo_hybrid.jl
alg = AdaptiveHybrid(;
threshold = 0.05,
threshold = 0.1,
dtmax = T_gyro,
dtmin = 1.0e-4 * T_gyro,
check_interval = 100,
)

ntraj = 4
KE_init = 0.5 * m * sum(abs2, v0)

# Serial
sols_serial = TestParticle.solve(
TraceHybridProblem(u0, tspan, p), alg, EnsembleSerial();
trajectories = ntraj,
)

@test all(s -> s.retcode == TestParticle.ReturnCode.Success, sols_serial)
@test all(s -> length(s.t) > 10, sols_serial)
@test all(sols_serial) do s
isapprox(0.5 * m * sum(abs2, s.u[end][SA[4, 5, 6]]), KE_init; rtol = 0.1)
end
@test all(s -> maximum(u -> abs(u[3]), s.u) < 1.0e6, sols_serial)
sols = TestParticle.solve(TraceHybridProblem(u0, tspan, p), alg)

# Threaded — should match serial
sols_threads = TestParticle.solve(
TraceHybridProblem(u0, tspan, p), alg, EnsembleThreads();
trajectories = ntraj,
)
@test sols[1].retcode == TestParticle.ReturnCode.Success
# Check that we actually have a reasonable number of steps (hybrid should adapt)
@test length(sols[1].t) > 100

@test all(s -> s.retcode == TestParticle.ReturnCode.Success, sols_threads)
@test all(i -> length(sols_threads[i].t) == length(sols_serial[i].t), 1:ntraj)
@test all(sols_threads) do s
isapprox(0.5 * m * sum(abs2, s.u[end][SA[4, 5, 6]]), KE_init; rtol = 0.1)
end
@test all(s -> maximum(u -> abs(u[3]), s.u) < 1.0e6, sols_threads)
# Energy conservation check (approximate for hybrid)
KE_init = 0.5 * m * sum(abs2, v0)
KE_final = 0.5 * m * sum(abs2, sols[1].u[end][SA[4, 5, 6]])
@test isapprox(KE_final, KE_init; rtol = 0.1)
end
end
Loading