Skip to content

Commit

Permalink
adding interface for getting current state of integrator.
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Jan 3, 2024
1 parent 4aba651 commit 412732d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DormandPrince"
uuid = "5e45e72d-22b8-4dd0-9c8b-f96714864bcd"
authors = ["John Long<jlong@quera.com>", "Phillip Weinberg<pweinberg@quera.com>"]
version = "0.2.0"
version = "0.3.0"

[deps]

Expand Down
7 changes: 4 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function Base.iterate(solver_iter::SolverIterator)
# integrate to first time
integrate(solver_iter.solver, first(solver_iter.times))
# return value and index which is the state
return (solver_iter.times[1], solver_iter.solver.y), 2
return (solver_iter.times[1], get_current_state(solver_iter.solver)), 2
end

# gets the next (t,y), return index+! which is the updated state
Expand All @@ -21,14 +21,15 @@ function Base.iterate(solver_iter::SolverIterator, index::Int)
# integrate to next time
integrate(solver_iter.solver, solver_iter.times[index])
# return time and state
return (solver_iter.times[index], solver_iter.solver.y), index+1
return (solver_iter.times[index], get_current_state(solver_iter.solver)), index+1
end

# 3 modes of operation for integrate
# 1. integrate(solver, time) -> state (modify solver object in place)
# 2. integrate(solver, times) -> iterator
# 3. integrate(callback, solver, times) -> vector of states with callback applied

get_current_state(::AbstractDPSolver) = error("not implemented")

Check warning on line 32 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L32

Added line #L32 was not covered by tests
integrate(solver::AbstractDPSolver{T}, times::AbstractVector{T}) where {T <: Real} = SolverIterator(solver, times)

function integrate(callback, solver::AbstractDPSolver{T}, times::AbstractVector{T}; sort_times::Bool = true) where {T <: Real}
Expand All @@ -37,7 +38,7 @@ function integrate(callback, solver::AbstractDPSolver{T}, times::AbstractVector{
result = []
for time in times
integrate(solver, time)
push!(result, callback(time, solver.y))
push!(result, callback(time, get_current_state(solver)))
end

return result
Expand Down
2 changes: 2 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,5 @@ function DP8Solver(
DP8Solver(f, x, y, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, y1;kw...)
end

get_current_state(solver::DP5Solver) = solver.y
get_current_state(solver::DP8Solver) = solver.y

0 comments on commit 412732d

Please sign in to comment.