From ded75fb386015b8c577f7a99a67f590ba9a1c609 Mon Sep 17 00:00:00 2001 From: odunbar Date: Fri, 3 Nov 2023 13:01:17 -0700 Subject: [PATCH] Accelerators --- src/Accelerators.jl | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/Accelerators.jl b/src/Accelerators.jl index a55ab8613..40d58bd99 100644 --- a/src/Accelerators.jl +++ b/src/Accelerators.jl @@ -80,23 +80,25 @@ end """ Performs state update with modified Nesterov momentum approach. The dependence of the momentum parameter for variable timestep can be found e.g. here "https://www.damtp.cam.ac.uk/user/hf323/M19-OPT/lecture5.pdf" +This update takes the perspective that the EKP time step Δt = C/h where C constant, and h is the timestepping of a typical gradient method """ function accelerate!( ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}}, u::MA, ) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix} ## update "v" state: - #v = u .+ 2 / (get_N_iterations(ekp) + 2) * (u .- ekp.accelerator.u_prev) - Δt_prev = length(ekp.Δt) == 1 ? 1 : ekp.Δt[end - 1] + Δt_prev = length(ekp.Δt) == 1 ? ekp.Δt[end] : ekp.Δt[end - 1] Δt = ekp.Δt[end] θ_prev = ekp.accelerator.θ_prev - # condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2 - a = sqrt(θ_prev^2 * Δt / Δt_prev) - θ = (-a + sqrt(a^2 + 4)) / 2 - + # condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2 + b = θ_prev^2 * Δt / Δt_prev + + θ_lowbd = (-b + sqrt(b^2 + 4*b)) / 2 + θ = min((θ_lowbd + 1)/2, θ_lowbd + 1.0/length(ekp.Δt)^2) # can be unstable close to the boundary, so add a quadratic penalization + v = u .+ θ * (1 / θ_prev - 1) * (u .- ekp.accelerator.u_prev) - + ## update "u" state: ekp.accelerator.u_prev = u ekp.accelerator.θ_prev = θ @@ -109,6 +111,7 @@ end """ State update method for UKI with Nesterov Accelerator. The dependence of the momentum parameter for variable timestep can be found e.g. here "https://www.damtp.cam.ac.uk/user/hf323/M19-OPT/lecture5.pdf" +This update takes the perspective that the EKP time step Δt = C/h where C constant, and h is the timestepping of a typical gradient method Performs identical update as with other methods, but requires reconstruction of mean and covariance of the accelerated positions prior to saving. """ function accelerate!( @@ -118,14 +121,15 @@ function accelerate!( #identical update stage as before ## update "v" state: - Δt_prev = length(uki.Δt) == 1 ? 1 : uki.Δt[end - 1] + Δt_prev = length(uki.Δt) == 1 ? uki.Δt[end] : uki.Δt[end - 1] Δt = uki.Δt[end] θ_prev = uki.accelerator.θ_prev - # condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2 - a = sqrt(θ_prev^2 * Δt / Δt_prev) - θ = (-a + sqrt(a^2 + 4)) / 2 + # condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2 + b = θ_prev^2 * Δt / Δt_prev + θ_lowbd = (-b + sqrt(b^2 + 4*b)) / 2 + θ = min((θ_lowbd + 1)/2, θ_lowbd + 1.0/length(uki.Δt)^2)# can be unstable at the boundary, so add a quadratic penalization v = u .+ θ * (1 / θ_prev - 1) * (u .- uki.accelerator.u_prev)