Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Steady state callback #601

Merged
merged 21 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/fluid/pipe_flow_2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ ode = semidiscretize(semi, tspan)
info_callback = InfoCallback(interval=100)
saving_callback = SolutionSavingCallback(dt=0.02, prefix="")

callbacks = CallbackSet(info_callback, saving_callback, UpdateCallback())
extra_callback = nothing

callbacks = CallbackSet(info_callback, saving_callback, UpdateCallback(), extra_callback)

sol = solve(ode, RDPK3SpFSAL35(),
abstol=1e-5, # Default abstol is 1e-6 (may need to be tuned to prevent boundary penetration)
Expand Down
4 changes: 2 additions & 2 deletions src/TrixiParticles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using Printf: @printf, @sprintf
using RecipesBase: RecipesBase, @series
using Random: seed!
using SciMLBase: CallbackSet, DiscreteCallback, DynamicalODEProblem, u_modified!,
get_tmp_cache, set_proposed_dt!, ODESolution, ODEProblem
get_tmp_cache, set_proposed_dt!, ODESolution, ODEProblem, terminate!
@reexport using StaticArrays: SVector
using StaticArrays: @SMatrix, SMatrix, setindex
using StrideArrays: PtrArray, StaticInt
Expand Down Expand Up @@ -59,7 +59,7 @@ export WeaklyCompressibleSPHSystem, EntropicallyDampedSPHSystem, TotalLagrangian
BoundarySPHSystem, DEMSystem, BoundaryDEMSystem, OpenBoundarySPHSystem, InFlow,
OutFlow
export InfoCallback, SolutionSavingCallback, DensityReinitializationCallback,
PostprocessCallback, StepsizeCallback, UpdateCallback
PostprocessCallback, StepsizeCallback, UpdateCallback, SteadyStateReachedCallback
export ContinuityDensity, SummationDensity
export PenaltyForceGanzenmueller, TransportVelocityAdami
export SchoenbergCubicSplineKernel, SchoenbergQuarticSplineKernel,
Expand Down
1 change: 1 addition & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ include("density_reinit.jl")
include("post_process.jl")
include("stepsize.jl")
include("update.jl")
include("steady_state_reached.jl")
28 changes: 16 additions & 12 deletions src/callbacks/info.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,7 @@ end
# affect!
function (info_callback::InfoCallback)(integrator)
if isfinished(integrator)
println("─"^100)
println("Trixi simulation finished. Final time: ", integrator.t,
" Time steps: ", integrator.stats.naccept, " (accepted), ",
integrator.iter, " (total)")
println("─"^100)
println()

# Print timer
TimerOutputs.complement!(timer())
print_timer(timer(), title="TrixiParticles.jl",
allocations=true, linechars=:unicode, compact=false)
println()
print_summary(integrator)
else
t = integrator.t
t_initial = first(integrator.sol.prob.tspan)
Expand Down Expand Up @@ -266,3 +255,18 @@ function summary_footer(io; total_width=100, indentation_level=0)

print(io, s)
end

function print_summary(integrator)
println("─"^100)
println("Trixi simulation finished. Final time: ", integrator.t,
" Time steps: ", integrator.stats.naccept, " (accepted), ",
integrator.iter, " (total)")
println("─"^100)
println()

# Print timer
TimerOutputs.complement!(timer())
print_timer(timer(), title="TrixiParticles.jl",
allocations=true, linechars=:unicode, compact=false)
println()
end
161 changes: 161 additions & 0 deletions src/callbacks/steady_state_reached.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""
SteadyStateReachedCallback(; interval::Integer=0, dt=0.0,
interval_size::Integer=10, abstol=1.0e-8, reltol=1.0e-6)

Terminates the integration when the change of kinetic energy between time steps
falls below the threshold specified by `abstol + reltol * ekin`,
where `ekin` is the total kinetic energy of the simulation.

# Keywords
- `interval=0`: Check steady state condition every `interval` time steps.
- `dt=0.0`: Check steady state condition in regular intervals of `dt` in terms
of integration time by adding additional `tstops`
(note that this may change the solution).
- `interval_size`: The interval in which the change of the kinetic energy is considered.
`interval_size` is a (integer) multiple of `interval` or `dt`.
- `abstol`: Absolute tolerance.
- `reltol`: Relative tolerance.
"""
struct SteadyStateReachedCallback{I, ELTYPE <: Real}
interval :: I
abstol :: ELTYPE
reltol :: ELTYPE
previous_ekin :: Vector{ELTYPE}
interval_size :: Int
end

function SteadyStateReachedCallback(; interval::Integer=0, dt=0.0,
interval_size::Integer=10, abstol=1.0e-8, reltol=1.0e-6)
abstol, reltol = promote(abstol, reltol)

if dt > 0 && interval > 0
throw(ArgumentError("setting both `interval` and `dt` is not supported"))
end

if dt > 0
interval = Float64(dt)
end

steady_state_callback = SteadyStateReachedCallback(interval, abstol, reltol, [Inf64],
interval_size)

if dt > 0
return PeriodicCallback(steady_state_callback, dt, save_positions=(false, false),
final_affect=true)
else
return DiscreteCallback(steady_state_callback, steady_state_callback,
save_positions=(false, false))
end
end

function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SteadyStateReachedCallback})
@nospecialize cb # reduce precompilation time

cb_ = cb.affect!

print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", ", "reltol=", cb_.reltol,
")")
end

function Base.show(io::IO,
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:SteadyStateReachedCallback}})
@nospecialize cb # reduce precompilation time

cb_ = cb.affect!.affect!

print(io, "SteadyStateReachedCallback(abstol=", cb_.abstol, ", reltol=", cb_.reltol,
")")
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any, <:SteadyStateReachedCallback})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
cb_ = cb.affect!

setup = ["absolute tolerance" => cb_.abstol,
"relative tolerance" => cb_.reltol,
"interval" => cb_.interval,
"interval size" => cb_.interval_size]
summary_box(io, "SteadyStateReachedCallback", setup)
end
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:SteadyStateReachedCallback}})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
cb_ = cb.affect!.affect!

setup = ["absolute tolerance" => cb_.abstol,
"relative tolerance" => cb_.reltol,
"interval" => cb_.interval,
"interval_size" => cb_.interval_size]
summary_box(io, "SteadyStateReachedCallback", setup)
end
end

# `affect!` (`PeriodicCallback`)
function (cb::SteadyStateReachedCallback)(integrator)
steady_state_condition!(cb, integrator) || return nothing

print_summary(integrator)

terminate!(integrator)
end

# `affect!` (`DiscreteCallback`)
function (cb::SteadyStateReachedCallback{Int})(integrator)
print_summary(integrator)

terminate!(integrator)
end

# `condition` (`DiscreteCallback`)
function (steady_state_callback::SteadyStateReachedCallback)(vu_ode, t, integrator)
return steady_state_condition!(steady_state_callback, integrator)
end

@inline function steady_state_condition!(cb, integrator)
(; abstol, reltol, previous_ekin, interval_size) = cb

vu_ode = integrator.u
v_ode, u_ode = vu_ode.x
semi = integrator.p

# Calculate kinetic energy
ekin = sum(semi.systems) do system
v = wrap_v(v_ode, system, semi)
unused_arg = nothing

return kinetic_energy(v, unused_arg, unused_arg, system)
end

if length(previous_ekin) == interval_size

# Calculate MSE only over the `interval_size`
mse = sum(1:interval_size) do index
return (previous_ekin[index] - ekin)^2 / interval_size
end

if mse <= abstol + reltol * ekin
return true
end

# Pop old kinetic energy
popfirst!(previous_ekin)
end

# Add current kinetic energy
push!(previous_ekin, ekin)

return false
end
1 change: 1 addition & 0 deletions test/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
include("postprocess.jl")
include("update.jl")
include("solution_saving.jl")
include("steady_state_reached.jl")
end
51 changes: 51 additions & 0 deletions test/callbacks/steady_state_reached.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
@testset verbose=true "SteadyStateReachedCallback" begin
@testset verbose=true "show" begin
# Default
callback0 = SteadyStateReachedCallback()

show_compact = "SteadyStateReachedCallback(abstol=1.0e-8, reltol=1.0e-6)"
@test repr(callback0) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ SteadyStateReachedCallback │
│ ══════════════════════════ │
│ absolute tolerance: …………………………… 1.0e-8 │
│ relative tolerance: …………………………… 1.0e-6 │
│ interval: ……………………………………………………… 0.0 │
│ interval size: ………………………………………… 10.0 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback0) == show_box

callback1 = SteadyStateReachedCallback(interval=11)

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ SteadyStateReachedCallback │
│ ══════════════════════════ │
│ absolute tolerance: …………………………… 1.0e-8 │
│ relative tolerance: …………………………… 1.0e-6 │
│ interval: ……………………………………………………… 11.0 │
│ interval size: ………………………………………… 10.0 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback1) == show_box

callback2 = SteadyStateReachedCallback(dt=1.2)

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ SteadyStateReachedCallback │
│ ══════════════════════════ │
│ absolute tolerance: …………………………… 1.0e-8 │
│ relative tolerance: …………………………… 1.0e-6 │
│ interval: ……………………………………………………… 1.2 │
│ interval_size: ………………………………………… 10.0 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback2) == show_box
end

@testset "Illegal Input" begin
error_str = "setting both `interval` and `dt` is not supported"
@test_throws ArgumentError(error_str) SteadyStateReachedCallback(dt=0.1, interval=1)
end
end
27 changes: 27 additions & 0 deletions test/examples/examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,33 @@
@test count_rhs_allocations(sol, semi) == 0
end

@trixi_testset "fluid/pipe_flow_2d.jl - steady state reached (`dt`)" begin
steady_state_reached = SteadyStateReachedCallback(; dt=0.002, interval_size=10)

@test_nowarn_mod trixi_include(@__MODULE__,
joinpath(examples_dir(), "fluid",
"pipe_flow_2d.jl"),
extra_callback=steady_state_reached,
tspan=(0.0, 1.5))

@test sol.t[end] < 1.0
@test sol.retcode == ReturnCode.Terminated
end

@trixi_testset "fluid/pipe_flow_2d.jl - steady state reached (`interval`)" begin
steady_state_reached = SteadyStateReachedCallback(; interval=1,
interval_size=10,
abstol=1.0e-5, reltol=1.0e-4)
@test_nowarn_mod trixi_include(@__MODULE__,
joinpath(examples_dir(), "fluid",
"pipe_flow_2d.jl"),
extra_callback=steady_state_reached,
tspan=(0.0, 1.5))

@test sol.t[end] < 1.0
@test sol.retcode == ReturnCode.Terminated
end

@trixi_testset "fluid/pipe_flow_3d.jl" begin
@test_nowarn_mod trixi_include(@__MODULE__, tspan=(0.0, 0.5),
joinpath(examples_dir(), "fluid",
Expand Down
Loading