From b7e081532574aed2bbd17774216ac89abeef8f9e Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Mon, 10 Jul 2023 14:22:26 -0700 Subject: [PATCH] Use ClimaTimeSteppers --- examples/column/hydrostatic_implicit.jl | 14 +- examples/hybrid/driver.jl | 50 +++-- examples/hybrid/ode_config.jl | 102 ++++++++++ .../hybrid/plane/inertial_gravity_wave.jl | 2 +- examples/hybrid/schur_complement_W.jl | 178 ++++++++++-------- .../hybrid/sphere/balanced_flow_rhotheta.jl | 2 +- .../hybrid/sphere/baroclinic_wave_rhoe.jl | 2 +- .../hybrid/sphere/baroclinic_wave_rhotheta.jl | 2 +- examples/hybrid/sphere/held_suarez_rhoe.jl | 2 +- .../hybrid/sphere/held_suarez_rhoe_int.jl | 2 +- .../hybrid/sphere/held_suarez_rhotheta.jl | 2 +- .../sphere/held_suarez_rhotheta_scaling.jl | 2 +- .../sphere/held_suarez_rhotheta_tempest.jl | 2 +- 13 files changed, 238 insertions(+), 124 deletions(-) create mode 100644 examples/hybrid/ode_config.jl diff --git a/examples/column/hydrostatic_implicit.jl b/examples/column/hydrostatic_implicit.jl index de97568ff9..e052fef671 100644 --- a/examples/column/hydrostatic_implicit.jl +++ b/examples/column/hydrostatic_implicit.jl @@ -1,4 +1,5 @@ using LinearAlgebra +import ClimaTimeSteppers as CTS import ClimaCore: Fields, Domains, @@ -233,11 +234,12 @@ ndays = 1.0 # Solve the ODE operator prob = ODEProblem( - ODEFunction( - tendency!, - jac = jacobian!, - jac_prototype = zeros(length(Y), length(Y)), - tgrad = (dT, Y, p, t) -> fill!(dT, 0), + CTS.ClimaODEFunction(; + T_imp! = ODEFunction( + tendency!, + jac = jacobian!, + jac_prototype = zeros(length(Y), length(Y)), + ), ), Y, (0.0, 60 * 60 * 24 * ndays), @@ -246,7 +248,7 @@ prob = ODEProblem( sol = solve( prob, # ImplicitEuler(), - Rosenbrock23(linsolve = linsolve!), + CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod()), dt = Δt, saveat = 60 * 60, # save every hour progress = true, diff --git a/examples/hybrid/driver.jl b/examples/hybrid/driver.jl index 140b5fad43..fd2cbb97ca 100644 --- a/examples/hybrid/driver.jl +++ b/examples/hybrid/driver.jl @@ -24,6 +24,7 @@ postprocessing(sol, output_dir) = nothing ################################################################################ +import ClimaTimeSteppers as CTS using ClimaComms const comms_ctx = ClimaComms.context() is_distributed = comms_ctx isa ClimaComms.MPICommsContext @@ -99,28 +100,22 @@ else ) end p = get_cache(ᶜlocal_geometry, ᶠlocal_geometry, Y, dt, upwinding_mode) -if ode_algorithm <: Union{ - OrdinaryDiffEq.OrdinaryDiffEqImplicitAlgorithm, - OrdinaryDiffEq.OrdinaryDiffEqAdaptiveImplicitAlgorithm, -} - use_transform = !(ode_algorithm in (Rosenbrock23, Rosenbrock32)) - W = SchurComplementW(Y, use_transform, jacobian_flags, test_implicit_solver) - jac_kwargs = - use_transform ? (; jac_prototype = W, Wfact_t = Wfact!) : - (; jac_prototype = W, Wfact = Wfact!) - - alg_kwargs = (; linsolve = linsolve!) - if ode_algorithm <: Union{ - OrdinaryDiffEq.OrdinaryDiffEqNewtonAlgorithm, - OrdinaryDiffEq.OrdinaryDiffEqNewtonAdaptiveAlgorithm, - } - alg_kwargs = - (; alg_kwargs..., nlsolve = NLNewton(; max_iter = max_newton_iters)) - end -else - jac_kwargs = alg_kwargs = (;) + +include("ode_config.jl") + +import LinearAlgebra +# Function required by Krylov.jl (x and b can be AbstractVectors) +# See https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/605 for a +# related issue that requires the same workaround. +function LinearAlgebra.ldiv!(x, A::SchurComplementW, b) + A.temp1 .= b + LinearAlgebra.ldiv!(A.temp2, A, A.temp1) + x .= A.temp2 end +ode_algo = + ode_configuration(FT; ode_name = string(ode_algorithm), max_newton_iters) + if haskey(ENV, "OUTPUT_DIR") output_dir = ENV["OUTPUT_DIR"] else @@ -161,20 +156,21 @@ end callback = CallbackSet(dss_callback, save_to_disk_callback, additional_callbacks...) -problem = SplitODEProblem( - ODEFunction( - implicit_tendency!; - jac_kwargs..., - tgrad = (∂Y∂t, Y, p, t) -> (∂Y∂t .= FT(0)), +problem = ODE.ODEProblem( + CTS.ClimaODEFunction(; + T_imp! = ODEFunction( + implicit_tendency!; + jac_kwargs(ode_algo, Y, jacobian_flags)..., + ), + T_exp! = remaining_tendency!, ), - remaining_tendency!, Y, (t_start, t_end), p, ) integrator = OrdinaryDiffEq.init( problem, - ode_algorithm(; alg_kwargs...); + ode_algo; saveat = dt_save_to_sol == 0 ? [] : dt_save_to_sol, callback = callback, dt = dt, diff --git a/examples/hybrid/ode_config.jl b/examples/hybrid/ode_config.jl new file mode 100644 index 0000000000..f3ec67d218 --- /dev/null +++ b/examples/hybrid/ode_config.jl @@ -0,0 +1,102 @@ +import DiffEqBase +import ClimaTimeSteppers as CTS +import OrdinaryDiffEq as ODE + +is_explicit_CTS_algo_type(alg_or_tableau) = + alg_or_tableau <: CTS.ERKAlgorithmName + +is_imex_CTS_algo_type(alg_or_tableau) = + alg_or_tableau <: CTS.IMEXARKAlgorithmName + +is_implicit_type(::typeof(ODE.IMEXEuler)) = true +is_implicit_type(alg_or_tableau) = + alg_or_tableau <: Union{ + ODE.OrdinaryDiffEqImplicitAlgorithm, + ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm, + } || is_imex_CTS_algo_type(alg_or_tableau) + +is_ordinary_diffeq_newton(::typeof(ODE.IMEXEuler)) = true +is_ordinary_diffeq_newton(alg_or_tableau) = + alg_or_tableau <: Union{ + ODE.OrdinaryDiffEqNewtonAlgorithm, + ODE.OrdinaryDiffEqNewtonAdaptiveAlgorithm, + } + +is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true +is_imex_CTS_algo(::DiffEqBase.AbstractODEAlgorithm) = false + +is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true +is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true +is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo) + +is_rosenbrock(::ODE.Rosenbrock23) = true +is_rosenbrock(::ODE.Rosenbrock32) = true +is_rosenbrock(::DiffEqBase.AbstractODEAlgorithm) = false +use_transform(ode_algo) = + !(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo)) + +additional_integrator_kwargs(::DiffEqBase.AbstractODEAlgorithm) = (; + adaptive = false, + progress = isinteractive(), + progress_steps = isinteractive() ? 1 : 1000, +) +additional_integrator_kwargs(::CTS.DistributedODEAlgorithm) = (; + kwargshandle = ODE.KeywordArgSilent, # allow custom kwargs + adjustfinal = true, + # TODO: enable progress bars in ClimaTimeSteppers +) + +is_cts_algo(::DiffEqBase.AbstractODEAlgorithm) = false +is_cts_algo(::CTS.DistributedODEAlgorithm) = true + +function jac_kwargs(ode_algo, Y, jacobi_flags) + if is_implicit(ode_algo) + W = SchurComplementW(Y, use_transform(ode_algo), jacobi_flags) + if use_transform(ode_algo) + return (; jac_prototype = W, Wfact_t = Wfact!) + else + return (; jac_prototype = W, Wfact = Wfact!) + end + else + return NamedTuple() + end +end + +function ode_configuration( + ::Type{FT}; + ode_name::Union{String, Nothing} = nothing, + max_newton_iters = nothing, +) where {FT} + if occursin(".", ode_name) + ode_name = split(ode_name, ".")[end] + end + ode_sym = Symbol(ode_name) + alg_or_tableau = if hasproperty(ODE, ode_sym) + @warn "apply_limiter flag is ignored for OrdinaryDiffEq algorithms" + getproperty(ODE, ode_sym) + else + getproperty(CTS, ode_sym) + end + @info "Using ODE config: `$alg_or_tableau`" + + if is_explicit_CTS_algo_type(alg_or_tableau) + return CTS.ExplicitAlgorithm(alg_or_tableau()) + elseif !is_implicit_type(alg_or_tableau) + return alg_or_tableau() + elseif is_ordinary_diffeq_newton(alg_or_tableau) + if max_newton_iters == 1 + error("OridinaryDiffEq requires at least 2 Newton iterations") + end + # κ like a relative tolerance; its default value in ODE is 0.01 + nlsolve = ODE.NLNewton(; + κ = max_newton_iters == 2 ? Inf : 0.01, + max_iter = max_newton_iters, + ) + return alg_or_tableau(; linsolve = linsolve!, nlsolve) + elseif is_imex_CTS_algo_type(alg_or_tableau) + newtons_method = CTS.NewtonsMethod(; max_iters = max_newton_iters) + return CTS.IMEXAlgorithm(alg_or_tableau(), newtons_method) + else + return alg_or_tableau(; linsolve = linsolve!) + end +end diff --git a/examples/hybrid/plane/inertial_gravity_wave.jl b/examples/hybrid/plane/inertial_gravity_wave.jl index 90d4a2f70f..4cfaee5002 100644 --- a/examples/hybrid/plane/inertial_gravity_wave.jl +++ b/examples/hybrid/plane/inertial_gravity_wave.jl @@ -63,7 +63,7 @@ horizontal_mesh = periodic_line_mesh(; x_max, x_elem = x_elem) dt = is_small_scale ? FT(1.5) : FT(20) t_end = is_small_scale ? FT(60 * 60 * 0.5) : FT(60 * 60 * 8) dt_save_to_sol = t_end / (animation_duration * fps) -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = ᶜ𝔼_name == :ρe ? :no_∂ᶜp∂ᶜK : :exact, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact, diff --git a/examples/hybrid/schur_complement_W.jl b/examples/hybrid/schur_complement_W.jl index 0a0948488a..a4c2cf94d0 100644 --- a/examples/hybrid/schur_complement_W.jl +++ b/examples/hybrid/schur_complement_W.jl @@ -6,7 +6,7 @@ using ClimaCore.Utilities: half const compose = Operators.ComposeStencils() const apply = Operators.ApplyStencil() -struct SchurComplementW{F, FT, J1, J2, J3, J4, S, A} +struct SchurComplementW{F, FT, J1, J2, J3, J4, S, A, T} # whether this struct is used to compute Wfact_t or Wfact transform::Bool @@ -29,6 +29,10 @@ struct SchurComplementW{F, FT, J1, J2, J3, J4, S, A} # whether to test the Jacobian and linear solver test::Bool + + # cache that is used to evaluate ldiv! + temp1::T + temp2::T end function SchurComplementW(Y, transform, flags, test = false) @@ -68,6 +72,7 @@ function SchurComplementW(Y, transform, flags, test = false) typeof(∂ᶠ𝕄ₜ∂ᶠ𝕄), typeof(S), typeof(S_column_array), + typeof(Y), }( transform, flags, @@ -80,6 +85,8 @@ function SchurComplementW(Y, transform, flags, test = false) S, S_column_array, test, + similar(Y), + similar(Y), ) end @@ -88,6 +95,7 @@ end # is a temporary workaround to avoid unnecessary allocations. Base.similar(w::SchurComplementW) = w + #= A = [-I 0 dtγ ∂ᶜρₜ∂ᶠ𝕄 ; 0 -I dtγ ∂ᶜ𝔼ₜ∂ᶠ𝕄 ; @@ -109,91 +117,97 @@ Finally, use (1) and (2) to get x1 and x2. Note: The matrix S = A31 A13 + A32 A23 + A33 - I is the "Schur complement" of [-I 0; 0 -I] (the top-left 4 blocks) in A. =# -function linsolve!(::Type{Val{:init}}, f, u0; kwargs...) - function _linsolve!(x, A, b, update_matrix = false; kwargs...) - (; dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = A - (; S, S_column_array) = A - dtγ = dtγ_ref[] - - xᶜρ = x.c.ρ - bᶜρ = b.c.ρ - if :ρθ in propertynames(x.c) - xᶜ𝔼 = x.c.ρθ - bᶜ𝔼 = b.c.ρθ - elseif :ρe in propertynames(x.c) - xᶜ𝔼 = x.c.ρe - bᶜ𝔼 = b.c.ρe - elseif :ρe_int in propertynames(x.c) - xᶜ𝔼 = x.c.ρe_int - bᶜ𝔼 = b.c.ρe_int - end - if :ρw in propertynames(x.f) - xᶠ𝕄 = x.f.ρw.components.data.:1 - bᶠ𝕄 = b.f.ρw.components.data.:1 - elseif :w in propertynames(x.f) - xᶠ𝕄 = x.f.w.components.data.:1 - bᶠ𝕄 = b.f.w.components.data.:1 - end +linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve! +_linsolve!(x, A, b, update_matrix = false; kwargs...) = + LinearAlgebra.ldiv!(x, A, b) + +function LinearAlgebra.ldiv!( + x::Fields.FieldVector, + A::SchurComplementW, + b::Fields.FieldVector, +) + (; dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = A + (; S, S_column_array) = A + dtγ = dtγ_ref[] + + xᶜρ = x.c.ρ + bᶜρ = b.c.ρ + if :ρθ in propertynames(x.c) + xᶜ𝔼 = x.c.ρθ + bᶜ𝔼 = b.c.ρθ + elseif :ρe in propertynames(x.c) + xᶜ𝔼 = x.c.ρe + bᶜ𝔼 = b.c.ρe + elseif :ρe_int in propertynames(x.c) + xᶜ𝔼 = x.c.ρe_int + bᶜ𝔼 = b.c.ρe_int + end + if :ρw in propertynames(x.f) + xᶠ𝕄 = x.f.ρw.components.data.:1 + bᶠ𝕄 = b.f.ρw.components.data.:1 + elseif :w in propertynames(x.f) + xᶠ𝕄 = x.f.w.components.data.:1 + bᶠ𝕄 = b.f.w.components.data.:1 + end - # TODO: Extend LinearAlgebra.I to work with stencil fields. - FT = eltype(eltype(S)) - I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT)))) - if Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) != (-half, half) - str = "The linear solver cannot yet be run with the given ∂ᶜ𝔼ₜ/∂ᶠ𝕄 \ - block, since it has more than 2 diagonals. So, ∂ᶜ𝔼ₜ/∂ᶠ𝕄 will \ - be set to 0 for the Schur complement computation. Consider \ - changing the ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode or the energy variable." - @warn str maxlog = 1 - @. S = dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I - else - @. S = - dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + - dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶜ𝔼ₜ∂ᶠ𝕄) + - dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I - end + # TODO: Extend LinearAlgebra.I to work with stencil fields. + FT = eltype(eltype(S)) + I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT)))) + if Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) != (-half, half) + str = "The linear solver cannot yet be run with the given ∂ᶜ𝔼ₜ/∂ᶠ𝕄 \ + block, since it has more than 2 diagonals. So, ∂ᶜ𝔼ₜ/∂ᶠ𝕄 will \ + be set to 0 for the Schur complement computation. Consider \ + changing the ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode or the energy variable." + @warn str maxlog = 1 + @. S = dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I + else + @. S = + dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + + dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶜ𝔼ₜ∂ᶠ𝕄) + + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I + end - @. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼)) - - Operators.column_thomas_solve!(S, xᶠ𝕄) - - @. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄) - @. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄) - - if A.test && Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) == (-half, half) - Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xᶜρ))) - ∂Yₜ∂Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1) - ΔY = Array{FT}(undef, 3 * Nv + 1) - ΔΔY = Array{FT}(undef, 3 * Nv + 1) - for h in 1:Nh, j in 1:Nj, i in 1:Ni - ∂Yₜ∂Y .= zero(FT) - ∂Yₜ∂Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(x.f), i, j, h) - ∂Yₜ∂Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(x.f), i, j, h) - ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .= - matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(x.c), i, j, h) - ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .= - matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(x.c), i, j, h) - ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .= - matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(x.f), i, j, h) - ΔY[1:Nv] .= vector_column(xᶜρ, i, j, h) - ΔY[(Nv + 1):(2 * Nv)] .= vector_column(xᶜ𝔼, i, j, h) - ΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xᶠ𝕄, i, j, h) - ΔΔY[1:Nv] .= vector_column(bᶜρ, i, j, h) - ΔΔY[(Nv + 1):(2 * Nv)] .= vector_column(bᶜ𝔼, i, j, h) - ΔΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bᶠ𝕄, i, j, h) - @assert (-LinearAlgebra.I + dtγ * ∂Yₜ∂Y) * ΔY ≈ ΔΔY - end + @. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼)) + + Operators.column_thomas_solve!(S, xᶠ𝕄) + + @. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄) + @. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄) + + if A.test && Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) == (-half, half) + Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xᶜρ))) + ∂Yₜ∂Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1) + ΔY = Array{FT}(undef, 3 * Nv + 1) + ΔΔY = Array{FT}(undef, 3 * Nv + 1) + for h in 1:Nh, j in 1:Nj, i in 1:Ni + ∂Yₜ∂Y .= zero(FT) + ∂Yₜ∂Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .= + matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(x.f), i, j, h) + ∂Yₜ∂Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .= + matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(x.f), i, j, h) + ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .= + matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(x.c), i, j, h) + ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .= + matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(x.c), i, j, h) + ∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .= + matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(x.f), i, j, h) + ΔY[1:Nv] .= vector_column(xᶜρ, i, j, h) + ΔY[(Nv + 1):(2 * Nv)] .= vector_column(xᶜ𝔼, i, j, h) + ΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xᶠ𝕄, i, j, h) + ΔΔY[1:Nv] .= vector_column(bᶜρ, i, j, h) + ΔΔY[(Nv + 1):(2 * Nv)] .= vector_column(bᶜ𝔼, i, j, h) + ΔΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bᶠ𝕄, i, j, h) + @assert (-LinearAlgebra.I + dtγ * ∂Yₜ∂Y) * ΔY ≈ ΔΔY end + end - if :ρuₕ in propertynames(x.c) - @. x.c.ρuₕ = -b.c.ρuₕ - elseif :uₕ in propertynames(x.c) - @. x.c.uₕ = -b.c.uₕ - end + if :ρuₕ in propertynames(x.c) + @. x.c.ρuₕ = -b.c.ρuₕ + elseif :uₕ in propertynames(x.c) + @. x.c.uₕ = -b.c.uₕ + end - if A.transform - x .*= dtγ - end + if A.transform + x .*= dtγ end end diff --git a/examples/hybrid/sphere/balanced_flow_rhotheta.jl b/examples/hybrid/sphere/balanced_flow_rhotheta.jl index 94f6049729..f59fe6f4c0 100644 --- a/examples/hybrid/sphere/balanced_flow_rhotheta.jl +++ b/examples/hybrid/sphere/balanced_flow_rhotheta.jl @@ -14,7 +14,7 @@ t_end = FT(60 * 60 * 24 * 10) dt = FT(400) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :exact, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge( diff --git a/examples/hybrid/sphere/baroclinic_wave_rhoe.jl b/examples/hybrid/sphere/baroclinic_wave_rhoe.jl index 3540b77029..6c4f2f83f7 100644 --- a/examples/hybrid/sphere/baroclinic_wave_rhoe.jl +++ b/examples/hybrid/sphere/baroclinic_wave_rhoe.jl @@ -14,7 +14,7 @@ t_end = FT(60 * 60 * 24 * 10) dt = FT(400) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :no_∂ᶜp∂ᶜK, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge( diff --git a/examples/hybrid/sphere/baroclinic_wave_rhotheta.jl b/examples/hybrid/sphere/baroclinic_wave_rhotheta.jl index 6d47ead9b2..ea680018c3 100644 --- a/examples/hybrid/sphere/baroclinic_wave_rhotheta.jl +++ b/examples/hybrid/sphere/baroclinic_wave_rhotheta.jl @@ -14,7 +14,7 @@ t_end = FT(60 * 60 * 24 * 10) dt = FT(400) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :exact, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge( diff --git a/examples/hybrid/sphere/held_suarez_rhoe.jl b/examples/hybrid/sphere/held_suarez_rhoe.jl index 420e986bc5..fec7075932 100644 --- a/examples/hybrid/sphere/held_suarez_rhoe.jl +++ b/examples/hybrid/sphere/held_suarez_rhoe.jl @@ -14,7 +14,7 @@ t_end = FT(60 * 60 * 24 * 10) dt = FT(400) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :no_∂ᶜp∂ᶜK, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) # Additional values required for driver diff --git a/examples/hybrid/sphere/held_suarez_rhoe_int.jl b/examples/hybrid/sphere/held_suarez_rhoe_int.jl index a59b678293..712f43cd2d 100644 --- a/examples/hybrid/sphere/held_suarez_rhoe_int.jl +++ b/examples/hybrid/sphere/held_suarez_rhoe_int.jl @@ -14,7 +14,7 @@ t_end = FT(60 * 60 * 24 * 10) dt = FT(400) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :exact, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge( diff --git a/examples/hybrid/sphere/held_suarez_rhotheta.jl b/examples/hybrid/sphere/held_suarez_rhotheta.jl index fb775cfc16..9c22ab0d96 100644 --- a/examples/hybrid/sphere/held_suarez_rhotheta.jl +++ b/examples/hybrid/sphere/held_suarez_rhotheta.jl @@ -15,7 +15,7 @@ t_end = FT(60 * 60 * 24 * 10) dt = FT(400) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :exact, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge( diff --git a/examples/hybrid/sphere/held_suarez_rhotheta_scaling.jl b/examples/hybrid/sphere/held_suarez_rhotheta_scaling.jl index c7834e42a6..391a1180c0 100644 --- a/examples/hybrid/sphere/held_suarez_rhotheta_scaling.jl +++ b/examples/hybrid/sphere/held_suarez_rhotheta_scaling.jl @@ -16,7 +16,7 @@ t_end = FT(60 * 60 * 1) dt = FT(100) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :exact, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge( diff --git a/examples/hybrid/sphere/held_suarez_rhotheta_tempest.jl b/examples/hybrid/sphere/held_suarez_rhotheta_tempest.jl index 277dc88c58..2de222d8a5 100644 --- a/examples/hybrid/sphere/held_suarez_rhotheta_tempest.jl +++ b/examples/hybrid/sphere/held_suarez_rhotheta_tempest.jl @@ -14,7 +14,7 @@ t_end = FT(60 * 60 * 24 * 10) dt = FT(400) dt_save_to_sol = FT(60 * 60 * 24) dt_save_to_disk = FT(0) # 0 means don't save to disk -ode_algorithm = OrdinaryDiffEq.Rosenbrock23 +ode_algorithm = CTS.ARS343 jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :exact, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact) additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge(