From 34f81f7be61690e0e0a8d6ce5a65b1ab498778df Mon Sep 17 00:00:00 2001 From: odunbar Date: Fri, 3 Nov 2023 13:01:17 -0700 Subject: [PATCH] Accelerators remove docstring format --- src/Accelerators.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/Accelerators.jl b/src/Accelerators.jl index a55ab8613..b11f29a59 100644 --- a/src/Accelerators.jl +++ b/src/Accelerators.jl @@ -86,14 +86,15 @@ function accelerate!( u::MA, ) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix} ## update "v" state: - #v = u .+ 2 / (get_N_iterations(ekp) + 2) * (u .- ekp.accelerator.u_prev) - Δt_prev = length(ekp.Δt) == 1 ? 1 : ekp.Δt[end - 1] + Δt_prev = length(ekp.Δt) == 1 ? ekp.Δt[end] : ekp.Δt[end - 1] Δt = ekp.Δt[end] θ_prev = ekp.accelerator.θ_prev - # condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2 - a = sqrt(θ_prev^2 * Δt / Δt_prev) - θ = (-a + sqrt(a^2 + 4)) / 2 + # condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2 + b = θ_prev^2 * Δt / Δt_prev + + θ_lowbd = (-b + sqrt(b^2 + 4 * b)) / 2 + θ = min((θ_lowbd + 1) / 2, θ_lowbd + 1.0 / length(ekp.Δt)^2) # can be unstable close to the boundary, so add a quadratic penalization v = u .+ θ * (1 / θ_prev - 1) * (u .- ekp.accelerator.u_prev) @@ -118,14 +119,15 @@ function accelerate!( #identical update stage as before ## update "v" state: - Δt_prev = length(uki.Δt) == 1 ? 1 : uki.Δt[end - 1] + Δt_prev = length(uki.Δt) == 1 ? uki.Δt[end] : uki.Δt[end - 1] Δt = uki.Δt[end] θ_prev = uki.accelerator.θ_prev - # condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2 - a = sqrt(θ_prev^2 * Δt / Δt_prev) - θ = (-a + sqrt(a^2 + 4)) / 2 + # condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2 + b = θ_prev^2 * Δt / Δt_prev + θ_lowbd = (-b + sqrt(b^2 + 4 * b)) / 2 + θ = min((θ_lowbd + 1) / 2, θ_lowbd + 1.0 / length(uki.Δt)^2) # can be unstable at the boundary, so add a quadratic penalization v = u .+ θ * (1 / θ_prev - 1) * (u .- uki.accelerator.u_prev)