Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ projects = ["test", "docs"]
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -32,6 +33,7 @@ TestParticleRecursiveArrayToolsExt = "RecursiveArrayTools"
[compat]
Adapt = "4.4"
ChunkSplitters = "3"
DiffResults = "1"
ForwardDiff = "0.10, 1"
Interpolations = "0.14, 0.15, 0.16"
KernelAbstractions = "0.9"
Expand Down
1 change: 1 addition & 0 deletions src/TestParticle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using SciMLBase: AbstractODEProblem, AbstractODEFunction, AbstractODESolution, R
using StaticArrays: SVector, @SMatrix, MVector, SA, StaticArray, SMatrix
using Meshes: coords, spacing, paramdim, CartesianGrid, RectilinearGrid, StructuredGrid
import ForwardDiff
import DiffResults
using ChunkSplitters: index_chunks
using PrecompileTools: @setup_workload, @compile_workload
using MuladdMacro: @muladd
Expand Down
14 changes: 9 additions & 5 deletions src/equations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,11 @@ function trace_relativistic_normalized(y, p, t)
end

@inline function get_B_parameters(x, t, Bfunc)
# Compute B and its Jacobian in a single pass using ForwardDiff
JB = ForwardDiff.jacobian(r -> Bfunc(r, t), x)
B = Bfunc(x, t)
# Compute B and its Jacobian in a single pass using DiffResults
result = DiffResults.JacobianResult(x)
result = ForwardDiff.jacobian!(result, r -> Bfunc(r, t), x)
JB = DiffResults.jacobian(result)
B = DiffResults.value(result)

Bmag = norm(B)
b̂ = B / Bmag
Expand All @@ -211,8 +213,10 @@ end
end

@inline function get_E_parameters(x, t, Efunc)
JE = ForwardDiff.jacobian(r -> Efunc(r, t), x)
E = Efunc(x, t)
result = DiffResults.JacobianResult(x)
result = ForwardDiff.jacobian!(result, r -> Efunc(r, t), x)
JE = DiffResults.jacobian(result)
E = DiffResults.value(result)

return E, JE
end
Expand Down
7 changes: 4 additions & 3 deletions src/utility/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,9 +567,10 @@ end
end

@inline function _get_B_jacobian(x, t, Bfunc)
#TODO: We may consider DiffResults to get JB and B in one pass.
JB = ForwardDiff.jacobian(r -> Bfunc(r, t), x)
B = Bfunc(x, t)
result = DiffResults.JacobianResult(x)
result = ForwardDiff.jacobian!(result, r -> Bfunc(r, t), x)
JB = DiffResults.jacobian(result)
B = DiffResults.value(result)
return B, JB
end

Expand Down
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,13 @@ end
sol_gc_analytic = solve(prob_gc_analytic, Vern9(); save_idxs = [1, 2, 3])
@test sol_gc[1, end] ≈ 0.9896155284173717
@test sol_gc_analytic[1, end] ≈ 0.9906923500002904 rtol = 1.0e-5

# Test get_E_parameters with constant E field (not used currently)
E_const(x, t) = SA[1.0e-9, 0.0, 0.0] # Time-dependent signature
x_test, t_test = SA[1.0, 0.0, 0.0], 0.0
E_expected = E_const(x_test, t_test)
E, JE = TP.get_E_parameters(x_test, t_test, E_const)
@test E == E_expected && JE == zeros(3, 3)
end
end

Expand Down
Loading