Skip to content

Commit

Permalink
Accelerators
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Nov 3, 2023
1 parent 06c3edf commit ded75fb
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,25 @@ end
"""
Performs state update with modified Nesterov momentum approach.
The dependence of the momentum parameter for variable timestep can be found e.g. here "https://www.damtp.cam.ac.uk/user/hf323/M19-OPT/lecture5.pdf"
This update takes the perspective that the EKP time step Δt = C/h, where C is a constant and h is the step size of a typical gradient method
"""
function accelerate!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
#v = u .+ 2 / (get_N_iterations(ekp) + 2) * (u .- ekp.accelerator.u_prev)
Δt_prev = length(ekp.Δt) == 1 ? 1 : ekp.Δt[end - 1]
Δt_prev = length(ekp.Δt) == 1 ? ekp.Δt[end] : ekp.Δt[end - 1]
Δt = ekp.Δt[end]
θ_prev = ekp.accelerator.θ_prev

# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2

# condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2
b = θ_prev^2 * Δt / Δt_prev

θ_lowbd = (-b + sqrt(b^2 + 4*b)) / 2
θ = min((θ_lowbd + 1)/2, θ_lowbd + 1.0/length(ekp.Δt)^2) # can be unstable close to the boundary, so add a quadratic penalization

v = u .+ θ * (1 / θ_prev - 1) * (u .- ekp.accelerator.u_prev)

## update "u" state:
ekp.accelerator.u_prev = u
ekp.accelerator.θ_prev = θ
Expand All @@ -109,6 +111,7 @@ end
"""
State update method for UKI with Nesterov Accelerator.
The dependence of the momentum parameter for variable timestep can be found e.g. here "https://www.damtp.cam.ac.uk/user/hf323/M19-OPT/lecture5.pdf"
This update takes the perspective that the EKP time step Δt = C/h, where C is a constant and h is the step size of a typical gradient method
Performs identical update as with other methods, but requires reconstruction of mean and covariance of the accelerated positions prior to saving.
"""
function accelerate!(
Expand All @@ -118,14 +121,15 @@ function accelerate!(

#identical update stage as before
## update "v" state:
Δt_prev = length(uki.Δt) == 1 ? 1 : uki.Δt[end - 1]
Δt_prev = length(uki.Δt) == 1 ? uki.Δt[end] : uki.Δt[end - 1]
Δt = uki.Δt[end]
θ_prev = uki.accelerator.θ_prev


# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2
# condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2
b = θ_prev^2 * Δt / Δt_prev
θ_lowbd = (-b + sqrt(b^2 + 4*b)) / 2
θ = min((θ_lowbd + 1)/2, θ_lowbd + 1.0/length(uki.Δt)^2)# can be unstable at the boundary, so add a quadratic penalization

v = u .+ θ * (1 / θ_prev - 1) * (u .- uki.accelerator.u_prev)

Expand Down

0 comments on commit ded75fb

Please sign in to comment.