Skip to content

Commit

Permalink
move Unscented update into accelerators
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Oct 12, 2023
1 parent 9c5acb3 commit 0d7dfe4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 43 deletions.
41 changes: 41 additions & 0 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,44 @@ function accelerate!(
## push "v" state to EKP object
push!(ekp.u, DataContainer(v, data_are_columns = true))
end


"""
State update method for UKI with Nesterov Accelerator.
Performs identical update as with other methods, but requires reconstruction of mean and covariance of the accelerated positions prior to saving.
"""
function accelerate!(
uki::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::AM,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, AM <: AbstractMatrix}

#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))

# additional complication: the stored "u_mean" and "uu_cov" are not the mean/cov of this ensemble
# the ensemble comes from the prediction operation acted upon this. we invert the prediction of the mean/cov of the sigma ensemble from u_mean/uu_cov
# u_mean = 1/alpha*(mean(v) - r) + r
# uu_cov = 1/alpha^2*(cov(v) - Σ_ω)
α_reg = uki.process.α_reg
r = uki.process.r
Σ_ω = uki.process.Σ_ω
Δt = uki.Δt[end]

v_mean = construct_mean(uki, v)
vv_cov = construct_cov(uki, v, v_mean)
u_mean = 1 / α_reg * (v_mean - r) + r
uu_cov = (1 / α_reg)^2 * (vv_cov - Σ_ω * Δt)

# overwrite the saved u_mean/uu_cov
uki.process.u_mean[end] = u_mean # N_ens x N_params
uki.process.uu_cov[end] = uu_cov # N_ens x N_data

end
8 changes: 4 additions & 4 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -696,13 +696,13 @@ include("SparseEnsembleKalmanInversion.jl")
export Sampler
include("EnsembleKalmanSampler.jl")

# struct Accelerator
# [Must go before Unscented include]
include("Accelerators.jl")


# struct Unscented
export Unscented
export Gaussian_2d
export construct_initial_ensemble, construct_mean, construct_cov
include("UnscentedKalmanInversion.jl")


# struct Accelerator
include("Accelerators.jl")
39 changes: 0 additions & 39 deletions src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,45 +356,6 @@ function construct_sigma_ensemble(
return x
end

"""
State update method for UKI with Nesterov Accelerator.
Performs identical update as with other methods, but requires reconstruction of mean and covariance of the accelerated positions prior to saving.
"""
function accelerate!(
uki::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::AM,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, AM <: AbstractMatrix}

#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))

# additional complication: the stored "u_mean" and "uu_cov" are not the mean/cov of this ensemble
# the ensemble comes from the prediction operation acted upon this. we invert the prediction of the mean/cov of the sigma ensemble from u_mean/uu_cov
# u_mean = 1/alpha*(mean(v) - r) + r
# uu_cov = 1/alpha^2*(cov(v) - Σ_ω)
α_reg = uki.process.α_reg
r = uki.process.r
Σ_ω = uki.process.Σ_ω
Δt = uki.Δt[end]

v_mean = construct_mean(uki, v)
vv_cov = construct_cov(uki, v, v_mean)
u_mean = 1 / α_reg * (v_mean - r) + r
uu_cov = (1 / α_reg)^2 * (vv_cov - Σ_ω * Δt)

# overwrite the saved u_mean/uu_cov
uki.process.u_mean[end] = u_mean # N_ens x N_params
uki.process.uu_cov[end] = uu_cov # N_ens x N_data

end

"""
construct_mean(
Expand Down

0 comments on commit 0d7dfe4

Please sign in to comment.