Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unscented accelerator #332

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 40 additions & 28 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# included in EnsembleKalmanProcess.jl

export DefaultAccelerator, NesterovAccelerator
export update_state!, set_initial_acceleration!
export accelerate!, set_initial_acceleration!

"""
$(TYPEDEF)
Expand Down Expand Up @@ -31,28 +31,28 @@ end
"""
Sets u_prev to the initial parameter values
"""
function set_ICs!(accelerator::NesterovAccelerator{FT}, u::MA) where {FT <: AbstractFloat, MA <: AbstractMatrix{FT}}
function set_ICs!(accelerator::NesterovAccelerator{FT}, u::MA) where {FT <: AbstractFloat, MA <: AbstractMatrix}
accelerator.u_prev = u
end


"""
Performs traditional state update with no momentum.
"""
function update_state!(
function accelerate!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
push!(ekp.u, DataContainer(u, data_are_columns = true))
end

"""
Performs state update with modified Nesterov momentum approach.
"""
function update_state!(
function accelerate!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
k = get_N_iterations(ekp) + 2
v = u .+ (1 - ekp.accelerator.r / k) * (u .- ekp.accelerator.u_prev)
Expand All @@ -66,29 +66,41 @@ end


"""
State update method for UKI with no acceleration.
The Accelerator framework has not yet been integrated with UKI process;
UKI tracks its own states, so this method is empty.
State update method for UKI with Nesterov Accelerator.
Performs identical update as with other methods, but requires reconstruction of mean and covariance of the accelerated positions prior to saving.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
function accelerate!(
uki::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::AM,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, AM <: AbstractMatrix}

end
#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))

# additional complication: the stored "u_mean" and "uu_cov" are not the mean/cov of this ensemble
# the ensemble comes from the prediction operation acted upon this. we invert the prediction of the mean/cov of the sigma ensemble from u_mean/uu_cov
# u_mean = 1/alpha*(mean(v) - r) + r
# uu_cov = 1/alpha^2*(cov(v) - Σ_ω)
α_reg = uki.process.α_reg
r = uki.process.r
Σ_ω = uki.process.Σ_ω
Δt = uki.Δt[end]

v_mean = construct_mean(uki, v)
vv_cov = construct_cov(uki, v, v_mean)
u_mean = 1 / α_reg * (v_mean - r) + r
uu_cov = (1 / α_reg)^2 * (vv_cov - Σ_ω * Δt)

# overwrite the saved u_mean/uu_cov
uki.process.u_mean[end] = u_mean # N_ens x N_params
uki.process.uu_cov[end] = uu_cov # N_ens x N_data

"""
Placeholder state update method for UKI with Nesterov Accelerator.
The Accelerator framework has not yet been integrated with UKI process, so this
method throws an error.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
throw(
ArgumentError(
"option `accelerator = NesterovAccelerator` is not implemented for UKI, please use `DefaultAccelerator`",
),
)
end
4 changes: 3 additions & 1 deletion src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ function update_ensemble!(
terminate = calculate_timestep!(ekp, g, Δt_new)
if isnothing(terminate)
u = update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
update_state!(ekp, u)
accelerate!(ekp, u)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
Expand Down Expand Up @@ -696,11 +696,13 @@ include("SparseEnsembleKalmanInversion.jl")
export Sampler
include("EnsembleKalmanSampler.jl")


# struct Unscented
export Unscented
export Gaussian_2d
export construct_initial_ensemble, construct_mean, construct_cov
include("UnscentedKalmanInversion.jl")


# struct Accelerator
include("Accelerators.jl")
25 changes: 14 additions & 11 deletions src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ Inputs:
$(METHODLIST)
"""
mutable struct Unscented{FT <: AbstractFloat, IT <: Int} <: Process
"an interable of arrays of size `N_parameters` containing the mean of the parameters (in each `uki` iteration a new array of mean is added)"
"an interable of arrays of size `N_parameters` containing the mean of the parameters (in each `uki` iteration a new array of mean is added), note - this is not the same as the ensemble mean of the sigma ensemble as it is taken prior to prediction"
u_mean::Any # ::Iterable{AbtractVector{FT}}
"an iterable of arrays of size (`N_parameters x N_parameters`) containing the covariance of the parameters (in each `uki` iteration a new array of `cov` is added)"
"an iterable of arrays of size (`N_parameters x N_parameters`) containing the covariance of the parameters (in each `uki` iteration a new array of `cov` is added), note - this is not the same as the ensemble cov of the sigma ensemble as it is taken prior to prediction"
uu_cov::Any # ::Iterable{AbstractMatrix{FT}}
"an iterable of arrays of size `N_y` containing the predicted observation (in each `uki` iteration a new array of predicted observation is added)"
obs_pred::Any # ::Iterable{AbstractVector{FT}}
Expand Down Expand Up @@ -226,6 +226,7 @@ function EnsembleKalmanProcess(
process::Unscented{FT, IT};
kwargs...,
) where {FT <: AbstractFloat, IT <: Int}

# use the distribution stored in process to generate initial ensemble
init_params = update_ensemble_prediction!(process, 0.0)

Expand All @@ -237,8 +238,8 @@ function FailureHandler(process::Unscented, method::IgnoreFailures)
#perform analysis on the model runs
update_ensemble_analysis!(uki, u, g)
#perform new prediction output to model parameters u_p
u = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u
u_p = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u_p
end
return FailureHandler{Unscented, IgnoreFailures}(failsafe_update)
end
Expand Down Expand Up @@ -292,17 +293,17 @@ function FailureHandler(process::Unscented, method::SampleSuccGauss)
push!(uki.process.obs_pred, g_mean) # N_ens x N_data
push!(uki.process.u_mean, u_mean) # N_ens x N_params
push!(uki.process.uu_cov, uu_cov) # N_ens x N_data

push!(uki.g, DataContainer(g, data_are_columns = true))

compute_error!(uki)

end
function failsafe_update(uki, u, g, failed_ens)
#perform analysis on the model runs
succ_gauss_analysis!(uki, u, g, failed_ens)
#perform new prediction output to model parameters u_p
u = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u
u_p = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u_p
end
return FailureHandler{Unscented, SampleSuccGauss}(failsafe_update)
end
Expand Down Expand Up @@ -550,7 +551,10 @@ end

UKI prediction step : generate sigma points.
"""
function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: AbstractFloat}
function update_ensemble_prediction!(
process::Unscented,
Δt::FT,
) where {FT <: AbstractFloat, AV <: AbstractVector, AM <: AbstractMatrix}

process.iter += 1
# update evolution covariance matrix
Expand All @@ -565,7 +569,7 @@ function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: A
r = process.r
Σ_ω = process.Σ_ω

N_par = length(process.u_mean[1])
N_par = length(u_mean[1])
############# Prediction step:

u_p_mean = α_reg * u_mean + (1 - α_reg) * r
Expand Down Expand Up @@ -623,10 +627,10 @@ function update_ensemble_analysis!(
push!(uki.process.obs_pred, g_mean) # N_ens x N_data
push!(uki.process.u_mean, u_mean) # N_ens x N_params
push!(uki.process.uu_cov, uu_cov) # N_ens x N_data

push!(uki.g, DataContainer(g, data_are_columns = true))

compute_error!(uki)

end

"""
Expand Down Expand Up @@ -677,7 +681,6 @@ function update_ensemble!(

u_p = fh.failsafe_update(uki, u_p_old, g_in, failed_ens)

push!(uki.u, DataContainer(u_p, data_are_columns = true))

if uki.verbose
cov_new = get_u_cov_final(uki)
Expand Down
Loading