diff --git a/docs/examples/features/demo_hybrid.jl b/docs/examples/features/demo_hybrid.jl index 9b5be385f..c0f58ed39 100644 --- a/docs/examples/features/demo_hybrid.jl +++ b/docs/examples/features/demo_hybrid.jl @@ -97,6 +97,7 @@ alg = AdaptiveHybrid(; dtmax = T_gyro, dtmin = 1.0e-4 * T_gyro, maxiters = 500_000, + check_interval = 100, ) Random.seed!(1234) diff --git a/src/hybrid.jl b/src/hybrid.jl index 90a497dc4..f6e4bc7d5 100644 --- a/src/hybrid.jl +++ b/src/hybrid.jl @@ -15,18 +15,21 @@ struct AdaptiveHybrid{T} abstol::T reltol::T maxiters::Int + check_interval::Int end function AdaptiveHybrid(; threshold = 0.1, dtmax, dtmin = 1.0e-2 * dtmax, safety_fo = 0.1, - abstol = 1.0e-6, reltol = 1.0e-6, maxiters = 10000 + abstol = 1.0e-6, reltol = 1.0e-6, maxiters = 10000, check_interval = 10 ) + check_interval > 0 || throw(ArgumentError("check_interval must be positive.")) T = promote_type( typeof(threshold), typeof(dtmin), typeof(dtmax), typeof(safety_fo), typeof(abstol), typeof(reltol) ) return AdaptiveHybrid{T}( - T(threshold), T(dtmin), T(dtmax), T(safety_fo), T(abstol), T(reltol), maxiters + T(threshold), T(dtmin), T(dtmax), T(safety_fo), T(abstol), T(reltol), maxiters, + check_interval ) end @@ -254,24 +257,26 @@ end while t < tspan[2] && steps < alg.maxiters if mode == :GC # Adiabaticity check - ϵ = get_adiabaticity(xv_gc[SVector(1, 2, 3)], Bfunc, q, m, μ, t) - if ϵ >= alg.threshold - # Switch to FO (GC -> FO) - mode = :FO - verbose && @info "Switch GC → FO" ϵ t r = xv_gc[SVector(1, 2, 3)] - xv_fo_vec = _gc_to_full_at_t( - xv_gc, Efunc, Bfunc, q, m, μ, t, 2π * rand() - ) - xv_fo = xv_fo_vec - r = xv_fo[SVector(1, 2, 3)] - v = xv_fo[SVector(4, 5, 6)] - - B_mag = norm(Bfunc(r, t)) - omega = abs(q2m * B_mag) - dt = alg.safety_fo / omega - dt = clamp(dt, alg.dtmin, alg.dtmax) - v = update_velocity(v, r, p, -0.5 * dt, t) - continue + if it % alg.check_interval == 0 + ϵ = get_adiabaticity(xv_gc[SVector(1, 2, 3)], Bfunc, q, m, μ, t) + if ϵ >= alg.threshold + # Switch to FO (GC -> FO) + mode = :FO + verbose && @info "Switch GC → FO" ϵ t r = xv_gc[SVector(1, 2, 3)] + xv_fo_vec = _gc_to_full_at_t( + xv_gc, Efunc, Bfunc, q, m, μ, t, 2π * rand() + ) + xv_fo = xv_fo_vec + r = xv_fo[SVector(1, 2, 3)] + v = xv_fo[SVector(4, 5, 6)] + + B_mag = norm(Bfunc(r, t)) + omega = abs(q2m * B_mag) + dt = alg.safety_fo / omega + dt = clamp(dt, alg.dtmin, alg.dtmax) + v = update_velocity(v, r, p, -0.5 * dt, t) + continue + end end if t + dt > tspan[2] @@ -312,24 +317,26 @@ end t_sync = is_td ? t : zero(T) v_check = update_velocity(v, r, p, 0.5 * dt, t_sync) - xv_check = SVector{6, T}( - r[1], r[2], r[3], v_check[1], v_check[2], v_check[3] - ) - X_gc, vpar, μ = _get_gc_parameters_at_t( - xv_check, Efunc, Bfunc, q, m, t - ) - ϵ = get_adiabaticity(X_gc, Bfunc, q, m, μ, t) - if ϵ < alg.threshold - # Switch to GC - mode = :GC - verbose && @info "Switch FO → GC" ϵ t r = X_gc - xv_gc = SVector{4, T}(X_gc[1], X_gc[2], X_gc[3], vpar) - p_gc = (q, q2m, μ, Efunc, Bfunc) - - Bmag = norm(Bfunc(X_gc, t)) - omega = abs(q2m * Bmag) - dt = 0.5 * 2π / omega - continue + if it % alg.check_interval == 0 + xv_check = SVector{6, T}( + r[1], r[2], r[3], v_check[1], v_check[2], v_check[3] + ) + X_gc, vpar, μ = _get_gc_parameters_at_t( + xv_check, Efunc, Bfunc, q, m, t + ) + ϵ = get_adiabaticity(X_gc, Bfunc, q, m, μ, t) + if ϵ < alg.threshold + # Switch to GC + mode = :GC + verbose && @info "Switch FO → GC" ϵ t r = X_gc + xv_gc = SVector{4, T}(X_gc[1], X_gc[2], X_gc[3], vpar) + p_gc = (q, q2m, μ, Efunc, Bfunc) + + Bmag = norm(Bfunc(X_gc, t)) + omega = abs(q2m * Bmag) + dt = 0.5 * 2π / omega + continue + end end if t + dt > tspan[2] diff --git a/test/test_hybrid.jl b/test/test_hybrid.jl index 4dbbdf513..7a7ac4dce 100644 --- a/test/test_hybrid.jl +++ b/test/test_hybrid.jl @@ -52,21 +52,12 @@ using Test @test sols[1].retcode == TestParticle.ReturnCode.Success end - # 3. Dynamic GC ↔ FO Switching (Magnetic Bottle) - # - # B_z(z) = B0*(1 + α*z²), with B_r from ∇·B = 0: - # B_x ≈ -B0*α*x*z, B_y ≈ -B0*α*y*z - # - # Near the midplane (z ≈ 0) the field is weak and - # curvature is high → FO. Away from midplane the - # field strengthens and becomes more uniform → GC. - # A bouncing particle triggers repeated GC → FO → GC - # transitions. - # - # Also tests EnsembleThreads against EnsembleSerial. + # 3. Demo Hybrid Case (Explicit Switching) + # Replicates the setup from docs/examples/features/demo_hybrid.jl + # Stronger curvature (α = 1.0e-2) forces GC <-> FO switching. let - B0 = 1.0e-4 # [T] background field - α = 1.0e-4 # [m⁻²] mirror ratio parameter + B0 = 1.0e-4 # [T] + α = 1.0e-2 # [m⁻²] Stronger curvature function bottle_B(x, t) Bz = B0 * (1 + α * x[3]^2) @@ -77,49 +68,34 @@ using Test B_bottle = TestParticle.Field(bottle_B) x0 = SA[0.0, 0.0, 0.0] - v_perp = 5.0e4 # [m/s] - v_par = 2.0e5 # [m/s] + v_perp = 5.0e4 # [m/s] + v_par = 1.0e5 # [m/s] v0 = SA[v_perp, 0.0, v_par] u0 = vcat(x0, v0) Ω = abs(q2m) * B0 T_gyro = 2π / Ω - tspan = (0.0, 50 * T_gyro) + tspan = (0.0, 30 * T_gyro) p = (q2m, m, E_field, B_bottle, TestParticle.ZeroField()) + + # Consistent with demo_hybrid.jl alg = AdaptiveHybrid(; - threshold = 0.05, + threshold = 0.1, dtmax = T_gyro, dtmin = 1.0e-4 * T_gyro, + check_interval = 100, ) - ntraj = 4 - KE_init = 0.5 * m * sum(abs2, v0) - - # Serial - sols_serial = TestParticle.solve( - TraceHybridProblem(u0, tspan, p), alg, EnsembleSerial(); - trajectories = ntraj, - ) - - @test all(s -> s.retcode == TestParticle.ReturnCode.Success, sols_serial) - @test all(s -> length(s.t) > 10, sols_serial) - @test all(sols_serial) do s - isapprox(0.5 * m * sum(abs2, s.u[end][SA[4, 5, 6]]), KE_init; rtol = 0.1) - end - @test all(s -> maximum(u -> abs(u[3]), s.u) < 1.0e6, sols_serial) + sols = TestParticle.solve(TraceHybridProblem(u0, tspan, p), alg) - # Threaded — should match serial - sols_threads = TestParticle.solve( - TraceHybridProblem(u0, tspan, p), alg, EnsembleThreads(); - trajectories = ntraj, - ) + @test sols[1].retcode == TestParticle.ReturnCode.Success + # Check that we actually have a reasonable number of steps (hybrid should adapt) + @test length(sols[1].t) > 100 - @test all(s -> s.retcode == TestParticle.ReturnCode.Success, sols_threads) - @test all(i -> length(sols_threads[i].t) == length(sols_serial[i].t), 1:ntraj) - @test all(sols_threads) do s - isapprox(0.5 * m * sum(abs2, s.u[end][SA[4, 5, 6]]), KE_init; rtol = 0.1) - end - @test all(s -> maximum(u -> abs(u[3]), s.u) < 1.0e6, sols_threads) + # Energy conservation check (approximate for hybrid) + KE_init = 0.5 * m * sum(abs2, v0) + KE_final = 0.5 * m * sum(abs2, sols[1].u[end][SA[4, 5, 6]]) + @test isapprox(KE_final, KE_init; rtol = 0.1) end end