Skip to content

Commit

Permalink
Use ClimaTimeSteppers
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 10, 2023
1 parent b720858 commit bd04d8f
Show file tree
Hide file tree
Showing 13 changed files with 247 additions and 124 deletions.
14 changes: 8 additions & 6 deletions examples/column/hydrostatic_implicit.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LinearAlgebra
import ClimaTimeSteppers as CTS
import ClimaCore:
Fields,
Domains,
Expand Down Expand Up @@ -233,11 +234,12 @@ ndays = 1.0

# Solve the ODE operator
prob = ODEProblem(
ODEFunction(
tendency!,
jac = jacobian!,
jac_prototype = zeros(length(Y), length(Y)),
tgrad = (dT, Y, p, t) -> fill!(dT, 0),
CTS.ClimaODEFunction(;
T_imp! = ODEFunction(
tendency!,
jac = jacobian!,
jac_prototype = zeros(length(Y), length(Y)),
),
),
Y,
(0.0, 60 * 60 * 24 * ndays),
Expand All @@ -246,7 +248,7 @@ prob = ODEProblem(
sol = solve(
prob,
# ImplicitEuler(),
Rosenbrock23(linsolve = linsolve!),
CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod()),
dt = Δt,
saveat = 60 * 60, # save every hour
progress = true,
Expand Down
50 changes: 23 additions & 27 deletions examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ postprocessing(sol, output_dir) = nothing

################################################################################

import ClimaTimeSteppers as CTS
using ClimaComms
const comms_ctx = ClimaComms.context()
is_distributed = comms_ctx isa ClimaComms.MPICommsContext
Expand Down Expand Up @@ -99,28 +100,22 @@ else
)
end
p = get_cache(ᶜlocal_geometry, ᶠlocal_geometry, Y, dt, upwinding_mode)
if ode_algorithm <: Union{
OrdinaryDiffEq.OrdinaryDiffEqImplicitAlgorithm,
OrdinaryDiffEq.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
}
use_transform = !(ode_algorithm in (Rosenbrock23, Rosenbrock32))
W = SchurComplementW(Y, use_transform, jacobian_flags, test_implicit_solver)
jac_kwargs =
use_transform ? (; jac_prototype = W, Wfact_t = Wfact!) :
(; jac_prototype = W, Wfact = Wfact!)

alg_kwargs = (; linsolve = linsolve!)
if ode_algorithm <: Union{
OrdinaryDiffEq.OrdinaryDiffEqNewtonAlgorithm,
OrdinaryDiffEq.OrdinaryDiffEqNewtonAdaptiveAlgorithm,
}
alg_kwargs =
(; alg_kwargs..., nlsolve = NLNewton(; max_iter = max_newton_iters))
end
else
jac_kwargs = alg_kwargs = (;)

include("ode_config.jl")

import LinearAlgebra
# Function required by Krylov.jl (x and b can be AbstractVectors)
# See https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/605 for a
# related issue that requires the same workaround.
function LinearAlgebra.ldiv!(x, A::SchurComplementW, b)
A.temp1 .= b
LinearAlgebra.ldiv!(A.temp2, A, A.temp1)
x .= A.temp2
end

ode_algo =
ode_configuration(FT; ode_name = string(ode_algorithm), max_newton_iters)

if haskey(ENV, "OUTPUT_DIR")
output_dir = ENV["OUTPUT_DIR"]
else
Expand Down Expand Up @@ -161,20 +156,21 @@ end
callback =
CallbackSet(dss_callback, save_to_disk_callback, additional_callbacks...)

problem = SplitODEProblem(
ODEFunction(
implicit_tendency!;
jac_kwargs...,
tgrad = (∂Y∂t, Y, p, t) -> (∂Y∂t .= FT(0)),
problem = ODE.ODEProblem(
CTS.ClimaODEFunction(;
T_imp! = ODEFunction(
implicit_tendency!;
jac_kwargs(ode_algo, Y, jacobian_flags)...,
),
T_exp! = remaining_tendency!,
),
remaining_tendency!,
Y,
(t_start, t_end),
p,
)
integrator = OrdinaryDiffEq.init(
problem,
ode_algorithm(; alg_kwargs...);
ode_algo;
saveat = dt_save_to_sol == 0 ? [] : dt_save_to_sol,
callback = callback,
dt = dt,
Expand Down
102 changes: 102 additions & 0 deletions examples/hybrid/ode_config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import DiffEqBase
import ClimaTimeSteppers as CTS
import OrdinaryDiffEq as ODE

is_explicit_CTS_algo_type(alg_or_tableau) =
alg_or_tableau <: CTS.ERKAlgorithmName

is_imex_CTS_algo_type(alg_or_tableau) =
alg_or_tableau <: CTS.IMEXARKAlgorithmName

is_implicit_type(::typeof(ODE.IMEXEuler)) = true
is_implicit_type(alg_or_tableau) =
alg_or_tableau <: Union{
ODE.OrdinaryDiffEqImplicitAlgorithm,
ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
} || is_imex_CTS_algo_type(alg_or_tableau)

is_ordinary_diffeq_newton(::typeof(ODE.IMEXEuler)) = true
is_ordinary_diffeq_newton(alg_or_tableau) =
alg_or_tableau <: Union{
ODE.OrdinaryDiffEqNewtonAlgorithm,
ODE.OrdinaryDiffEqNewtonAdaptiveAlgorithm,
}

is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true
is_imex_CTS_algo(::DiffEqBase.AbstractODEAlgorithm) = false

is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true
is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true
is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo)

is_rosenbrock(::ODE.Rosenbrock23) = true
is_rosenbrock(::ODE.Rosenbrock32) = true
is_rosenbrock(::DiffEqBase.AbstractODEAlgorithm) = false
use_transform(ode_algo) =
!(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo))

additional_integrator_kwargs(::DiffEqBase.AbstractODEAlgorithm) = (;
adaptive = false,
progress = isinteractive(),
progress_steps = isinteractive() ? 1 : 1000,
)
additional_integrator_kwargs(::CTS.DistributedODEAlgorithm) = (;
kwargshandle = ODE.KeywordArgSilent, # allow custom kwargs
adjustfinal = true,
# TODO: enable progress bars in ClimaTimeSteppers
)

is_cts_algo(::DiffEqBase.AbstractODEAlgorithm) = false
is_cts_algo(::CTS.DistributedODEAlgorithm) = true

function jac_kwargs(ode_algo, Y, jacobi_flags)
if is_implicit(ode_algo)
W = SchurComplementW(Y, use_transform(ode_algo), jacobi_flags)
if use_transform(ode_algo)
return (; jac_prototype = W, Wfact_t = Wfact!)
else
return (; jac_prototype = W, Wfact = Wfact!)
end
else
return NamedTuple()
end
end

function ode_configuration(
::Type{FT};
ode_name::Union{String, Nothing} = nothing,
max_newton_iters = nothing,
) where {FT}
if occursin(".", ode_name)
ode_name = split(ode_name, ".")[end]
end
ode_sym = Symbol(ode_name)
alg_or_tableau = if hasproperty(ODE, ode_sym)
@warn "apply_limiter flag is ignored for OrdinaryDiffEq algorithms"
getproperty(ODE, ode_sym)
else
getproperty(CTS, ode_sym)
end
@info "Using ODE config: `$alg_or_tableau`"

if is_explicit_CTS_algo_type(alg_or_tableau)
return CTS.ExplicitAlgorithm(alg_or_tableau())
elseif !is_implicit_type(alg_or_tableau)
return alg_or_tableau()
elseif is_ordinary_diffeq_newton(alg_or_tableau)
if max_newton_iters == 1
error("OridinaryDiffEq requires at least 2 Newton iterations")
end
# κ like a relative tolerance; its default value in ODE is 0.01
nlsolve = ODE.NLNewton(;
κ = max_newton_iters == 2 ? Inf : 0.01,
max_iter = max_newton_iters,
)
return alg_or_tableau(; linsolve = linsolve!, nlsolve)
elseif is_imex_CTS_algo_type(alg_or_tableau)
newtons_method = CTS.NewtonsMethod(; max_iters = max_newton_iters)
return CTS.IMEXAlgorithm(alg_or_tableau(), newtons_method)
else
return alg_or_tableau(; linsolve = linsolve!)
end
end
2 changes: 1 addition & 1 deletion examples/hybrid/plane/inertial_gravity_wave.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ horizontal_mesh = periodic_line_mesh(; x_max, x_elem = x_elem)
dt = is_small_scale ? FT(1.5) : FT(20)
t_end = is_small_scale ? FT(60 * 60 * 0.5) : FT(60 * 60 * 8)
dt_save_to_sol = t_end / (animation_duration * fps)
ode_algorithm = OrdinaryDiffEq.Rosenbrock23
ode_algorithm = CTS.ARS343
jacobian_flags = (;
∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = ᶜ𝔼_name == :ρe ? :no_∂ᶜp∂ᶜK : :exact,
∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact,
Expand Down
Loading

0 comments on commit bd04d8f

Please sign in to comment.