Skip to content

Commit

Permalink
Merge pull request #13 from QuEraComputing/dp8-impl
Browse files Browse the repository at this point in the history
Implementation of DP8.
  • Loading branch information
weinbe58 authored Jan 2, 2024
2 parents ccc9854 + 3e44275 commit 4aba651
Show file tree
Hide file tree
Showing 24 changed files with 956 additions and 292 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/.vscode
**/**/**.cov

/Manifest.toml
Manifest.toml
/docs/build/
/docs/Manifest.toml
/docs/src/assets/main.css
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DormandPrince"
uuid = "5e45e72d-22b8-4dd0-9c8b-f96714864bcd"
authors = ["John Long<jlong@quera.com>", "Phillip Weinberg<pweinberg@quera.com>"]
version = "0.1.0"
version = "0.2.0"

[deps]

Expand Down
2 changes: 2 additions & 0 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
DormandPrince = "5e45e72d-22b8-4dd0-9c8b-f96714864bcd"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4 changes: 2 additions & 2 deletions benchmarks/rabi-diffeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ u0 = ComplexF64[1.0, 0.0]
tspan = (0.0, 2π)
prob = ODEProblem(f, u0, tspan)
# get precompilation out of the way
sol = solve(prob, DP5(), reltol=1e-10, abstol=1e-10)
sol = solve(prob, DP8(), reltol=1e-10, abstol=1e-10)

# terminate benchmark after maximum of 5 minutes
@benchmark solve(prob, DP5(), reltol=1e-10, abstol=1e-10) samples=10000 evals=5 seconds=300
@benchmark solve(prob, DP8(), reltol=1e-10, abstol=1e-10) samples=10000 evals=5 seconds=300
8 changes: 4 additions & 4 deletions benchmarks/rabi-dormand-prince.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using BenchmarkTools
using DormandPrince:DP5Solver, core_integrator
using DormandPrince:DP5Solver, DP8Solver, integrate

function fcn(x, y, f)
g(x) = 2.2*2π*sin(2π*x)
Expand All @@ -8,12 +8,12 @@ function fcn(x, y, f)
f[2] = -1im * g(x)/2 * y[1]
end

solver = DP5Solver(
solver = DP8Solver(
fcn,
0.0,
ComplexF64[1.0, 0.0]
)

core_integrator(solver, 2π)
integrate(solver, 2π)

@benchmark core_integrator(clean_solver, 2π) setup=(clean_solver = DP5Solver(fcn, 0.0, ComplexF64[1.0, 0.0])) samples=10000 evals=5 seconds=500
@benchmark integrate(clean_solver, 2π) setup=(clean_solver = DP8Solver(fcn, 0.0, ComplexF64[1.0, 0.0])) samples=10000 evals=5 seconds=500
11 changes: 9 additions & 2 deletions benchmarks/type_stab.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DormandPrince: DP5Solver, integrate
using DormandPrince: DP5Solver, DP8Solver, integrate
using DormandPrince.DP8: dop853, error_estimation
using JET: @report_opt


Expand All @@ -9,12 +10,18 @@ function fcn(x, y, f)
f[2] = -1im * g(x)/2 * y[1]
end

solver = DP5Solver(
solver = DP8Solver(
fcn,
0.0,
ComplexF64[1.0, 0.0]
)

h = 1e-6
# @report_opt dop853(solver, 1.0, 1.0, h)
# @code_warntype error_estimation(solver, 1e-6)
# @report_opt error_estimation(solver, 1e-6)

@code_warntype integrate(solver, 2π)
@report_opt integrate(solver, 2π)


8 changes: 4 additions & 4 deletions src/DormandPrince.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ using Base.Iterators:repeated, Repeated

# internal imports
include("types.jl")
include("hinit.jl")
include("checks.jl")
include("interface.jl")
include("dp5/mod.jl")


using DormandPrince. DP5: core_integrator
include("dp8/mod.jl")

# export Interface
export DP5Solver, integrate
export DP5Solver, DP8Solver, integrate


end # DormandPrince
8 changes: 4 additions & 4 deletions src/dp5/checks.jl → src/checks.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
# include("types.jl")

# make enums for every error type and return that instead of printing errors
function check_max_allowed_steps(options::Options{T}) where T
function check_max_allowed_steps(options::Options)
if options.maximum_allowed_steps < 0
return false
else
return true
end
end

function check_uround(options::Options{T}) where T
function check_uround(options::Options)
if (options.uround <= 1e-35) || (options.uround >= 1.0)
return false
else
return true
end
end

function check_beta(options::Options{T}) where T
function check_beta(options::Options)
if options.beta > 0.2
return false
else
return true
end
end

function check_safety_factor(options::Options{T}) where T
function check_safety_factor(options::Options)
if (options.safety_factor >= 1.0) || (options.safety_factor <= 1e-4)
return false
else
Expand Down
72 changes: 4 additions & 68 deletions src/dp5/helpers.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,4 @@
function do_step!(solver, h)

# define constants
c2=0.2
c3=0.3
c4=0.8
c5=8.0/9.0
a21=0.2
a31=3.0/40.0
a32=9.0/40.0
a41=44.0/45.0
a42=-56.0/15.0
a43=32.0/9.0
a51=19372.0/6561.0
a52=-25360.0/2187.0
a53=64448.0/6561.0
a54=-212.0/729.0
a61=9017.0/3168.0
a62=-355.0/33.0
a63=46732.0/5247.0
a64=49.0/176.0
a65=-5103.0/18656.0
a71=35.0/384.0
a73=500.0/1113.0
a74=125.0/192.0
a75=-2187.0/6784.0
a76=11.0/84.0
e1=71.0/57600.0
e3=-71.0/16695.0
e4=71.0/1920.0
e5=-17253.0/339200.0
e6=22.0/525.0
e7=-1.0/40.0
function do_step!(solver::DP5Solver{T}, h::T) where T

####### First 6 stages

Expand Down Expand Up @@ -61,28 +29,15 @@ function error_estimation(solver)

err = mapreduce(+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k4, solver.y, solver.ysti) do atoli, rtoli, k4i, yi, ystii
sk = atoli + rtoli*max(abs(yi), abs(ystii))
abs(k4i/sk)^2
(abs(k4i)/sk)^2
end

err = sqrt(err/length(solver.y))

return err
end

function estimate_second_derivative(solver, h)

der2 = mapreduce(+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k2, solver.k1, solver.y) do atoli, rtoli, f1i, f0i, yi
sk = atoli + rtoli*abs(yi)
((f1i-f0i)/sk)^2
end

der2 = sqrt(der2)/h

return der2

end

function stiffness_detection!(solver, naccpt, h)
function stiffness_detection!(solver::DP5Solver{T}, naccpt::Int, h::T) where T
if (mod(naccpt, solver.options.stiffness_test_activation_step) == 0) || (solver.vars.iasti > 0)
#stnum = 0.0
#stden = 0.0
Expand All @@ -95,7 +50,7 @@ function stiffness_detection!(solver, naccpt, h)
end

if stden > 0.0
solver.vars.hlamb = h*sqrt(stnum/stden)
solver.vars.hlamb = abs(h)*sqrt(stnum/stden)
else
solver.vars.hlamb = Inf
end
Expand All @@ -114,22 +69,3 @@ function stiffness_detection!(solver, naccpt, h)
end
end
end

function euler_first_guess(solver, hmax, posneg)

dnf, dny = mapreduce(.+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k1, solver.y) do atoli, rtoli, f0i, yi
sk = atoli + rtoli*abs(yi)
abs(f0i/sk)^2, abs(yi/sk)^2 # dnf, dny
end


if (dnf <= 1.0e-10) || (dny <= 1.0e-10)
h = 1.0e-6
else
h = 0.01*sqrt(dny/dnf)
end
h = min(h, hmax)
h = h * Base.sign(posneg)

return h, dnf
end
14 changes: 12 additions & 2 deletions src/dp5/mod.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
module DP5

using ..DormandPrince: DormandPrince, DP5Solver, Vars, Consts, Options, Report
using ..DormandPrince: DormandPrince,
DP5Solver,
Vars,
Consts,
Options,
Report,
hinit,
check_beta,
check_max_allowed_steps,
check_safety_factor,
check_uround

include("params.jl")
include("helpers.jl")
include("checks.jl")
include("solver.jl")

end
31 changes: 31 additions & 0 deletions src/dp5/params.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# define constants
const c2=0.2
const c3=0.3
const c4=0.8
const c5=8.0/9.0
const a21=0.2
const a31=3.0/40.0
const a32=9.0/40.0
const a41=44.0/45.0
const a42=-56.0/15.0
const a43=32.0/9.0
const a51=19372.0/6561.0
const a52=-25360.0/2187.0
const a53=64448.0/6561.0
const a54=-212.0/729.0
const a61=9017.0/3168.0
const a62=-355.0/33.0
const a63=46732.0/5247.0
const a64=49.0/176.0
const a65=-5103.0/18656.0
const a71=35.0/384.0
const a73=500.0/1113.0
const a74=125.0/192.0
const a75=-2187.0/6784.0
const a76=11.0/84.0
const e1=71.0/57600.0
const e3=-71.0/16695.0
const e4=71.0/1920.0
const e5=-17253.0/339200.0
const e6=22.0/525.0
const e7=-1.0/40.0
Loading

2 comments on commit 4aba651

@weinbe58
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

release via ion

@JuliaRegistrator register branch=main

Release notes:

Release Note

  • Added implementation of Dorman Prince 8th order solver
  • Adding Abstract interface for future work to follow

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/98071

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 4aba651dd73c036bf1d55e47baefaaf022997785
git push origin v0.2.0

Please sign in to comment.