Skip to content

Commit

Permalink
Unscented accelerator and added/modified unit test
Browse files Browse the repository at this point in the history
format

remove verbose

make EKS run with stable timestepper

reduce long-term release as it causes breaking tests

runtest deepcopy schedulers bugfix

unwrap prediction step when saving u_mean, uu_cov

unscented

remove empty method

src/UnscentedKalmanInversion.jl

rm print statement

change func name to accelerate

removed FT for disambiguation

back to 1.6 tests
  • Loading branch information
odunbar committed Oct 12, 2023
1 parent 17c6c46 commit 9c5acb3
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 84 deletions.
41 changes: 6 additions & 35 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# included in EnsembleKalmanProcess.jl

export DefaultAccelerator, NesterovAccelerator
export update_state!, set_initial_acceleration!
export accelerate!, set_initial_acceleration!

"""
$(TYPEDEF)
Expand Down Expand Up @@ -31,28 +31,28 @@ end
"""
Sets u_prev to the initial parameter values
"""
function set_ICs!(accelerator::NesterovAccelerator{FT}, u::MA) where {FT <: AbstractFloat, MA <: AbstractMatrix{FT}}
function set_ICs!(accelerator::NesterovAccelerator{FT}, u::MA) where {FT <: AbstractFloat, MA <: AbstractMatrix}
accelerator.u_prev = u
end


"""
Performs traditional state update with no momentum.
"""
function update_state!(
function accelerate!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
push!(ekp.u, DataContainer(u, data_are_columns = true))
end

"""
Performs state update with modified Nesterov momentum approach.
"""
function update_state!(
function accelerate!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
k = get_N_iterations(ekp) + 2
v = u .+ (1 - ekp.accelerator.r / k) * (u .- ekp.accelerator.u_prev)
Expand All @@ -63,32 +63,3 @@ function update_state!(
## push "v" state to EKP object
push!(ekp.u, DataContainer(v, data_are_columns = true))
end


"""
State update method for UKI with no acceleration.
The Accelerator framework has not yet been integrated with UKI process;
UKI tracks its own states, so this method is empty.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}

end

"""
Placeholder state update method for UKI with Nesterov Accelerator.
The Accelerator framework has not yet been integrated with UKI process, so this
method throws an error.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
throw(
ArgumentError(
"option `accelerator = NesterovAccelerator` is not implemented for UKI, please use `DefaultAccelerator`",
),
)
end
10 changes: 6 additions & 4 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ function update_ensemble!(
terminate = calculate_timestep!(ekp, g, Δt_new)
if isnothing(terminate)
u = update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
update_state!(ekp, u)
accelerate!(ekp, u)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
Expand Down Expand Up @@ -696,11 +696,13 @@ include("SparseEnsembleKalmanInversion.jl")
export Sampler
include("EnsembleKalmanSampler.jl")

# struct Accelerator
# [Must go before Unscented include]
include("Accelerators.jl")


# struct Unscented
export Unscented
export Gaussian_2d
export construct_initial_ensemble, construct_mean, construct_cov
include("UnscentedKalmanInversion.jl")

# struct Accelerator
include("Accelerators.jl")
64 changes: 53 additions & 11 deletions src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ Inputs:
$(METHODLIST)
"""
mutable struct Unscented{FT <: AbstractFloat, IT <: Int} <: Process
"an interable of arrays of size `N_parameters` containing the mean of the parameters (in each `uki` iteration a new array of mean is added)"
"an interable of arrays of size `N_parameters` containing the mean of the parameters (in each `uki` iteration a new array of mean is added), note - this is not the same as the ensemble mean of the sigma ensemble as it is taken prior to prediction"
u_mean::Any # ::Iterable{AbtractVector{FT}}
"an iterable of arrays of size (`N_parameters x N_parameters`) containing the covariance of the parameters (in each `uki` iteration a new array of `cov` is added)"
"an iterable of arrays of size (`N_parameters x N_parameters`) containing the covariance of the parameters (in each `uki` iteration a new array of `cov` is added), note - this is not the same as the ensemble cov of the sigma ensemble as it is taken prior to prediction"
uu_cov::Any # ::Iterable{AbstractMatrix{FT}}
"an iterable of arrays of size `N_y` containing the predicted observation (in each `uki` iteration a new array of predicted observation is added)"
obs_pred::Any # ::Iterable{AbstractVector{FT}}
Expand Down Expand Up @@ -226,6 +226,7 @@ function EnsembleKalmanProcess(
process::Unscented{FT, IT};
kwargs...,
) where {FT <: AbstractFloat, IT <: Int}

# use the distribution stored in process to generate initial ensemble
init_params = update_ensemble_prediction!(process, 0.0)

Expand All @@ -237,8 +238,8 @@ function FailureHandler(process::Unscented, method::IgnoreFailures)
#perform analysis on the model runs
update_ensemble_analysis!(uki, u, g)
#perform new prediction output to model parameters u_p
u = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u
u_p = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u_p
end
return FailureHandler{Unscented, IgnoreFailures}(failsafe_update)
end
Expand Down Expand Up @@ -292,17 +293,17 @@ function FailureHandler(process::Unscented, method::SampleSuccGauss)
push!(uki.process.obs_pred, g_mean) # N_ens x N_data
push!(uki.process.u_mean, u_mean) # N_ens x N_params
push!(uki.process.uu_cov, uu_cov) # N_ens x N_data

push!(uki.g, DataContainer(g, data_are_columns = true))

compute_error!(uki)

end
function failsafe_update(uki, u, g, failed_ens)
#perform analysis on the model runs
succ_gauss_analysis!(uki, u, g, failed_ens)
#perform new prediction output to model parameters u_p
u = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u
u_p = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u_p
end
return FailureHandler{Unscented, SampleSuccGauss}(failsafe_update)
end
Expand Down Expand Up @@ -355,6 +356,45 @@ function construct_sigma_ensemble(
return x
end

"""
State update method for UKI with Nesterov Accelerator.
Performs identical update as with other methods, but requires reconstruction of mean and covariance of the accelerated positions prior to saving.
"""
function accelerate!(
uki::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::AM,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, AM <: AbstractMatrix}

#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))

# additional complication: the stored "u_mean" and "uu_cov" are not the mean/cov of this ensemble
# the ensemble comes from the prediction operation acted upon this. we invert the prediction of the mean/cov of the sigma ensemble from u_mean/uu_cov
# u_mean = 1/alpha*(mean(v) - r) + r
# uu_cov = 1/alpha^2*(cov(v) - Σ_ω)
α_reg = uki.process.α_reg
r = uki.process.r
Σ_ω = uki.process.Σ_ω
Δt = uki.Δt[end]

v_mean = construct_mean(uki, v)
vv_cov = construct_cov(uki, v, v_mean)
u_mean = 1 / α_reg * (v_mean - r) + r
uu_cov = (1 / α_reg)^2 * (vv_cov - Σ_ω * Δt)

# overwrite the saved u_mean/uu_cov
uki.process.u_mean[end] = u_mean # N_ens x N_params
uki.process.uu_cov[end] = uu_cov # N_ens x N_data

end

"""
construct_mean(
Expand Down Expand Up @@ -550,7 +590,10 @@ end
UKI prediction step : generate sigma points.
"""
function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: AbstractFloat}
function update_ensemble_prediction!(
process::Unscented,
Δt::FT,
) where {FT <: AbstractFloat, AV <: AbstractVector, AM <: AbstractMatrix}

process.iter += 1
# update evolution covariance matrix
Expand All @@ -565,7 +608,7 @@ function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: A
r = process.r
Σ_ω = process.Σ_ω

N_par = length(process.u_mean[1])
N_par = length(u_mean[1])
############# Prediction step:

u_p_mean = α_reg * u_mean + (1 - α_reg) * r
Expand Down Expand Up @@ -623,10 +666,10 @@ function update_ensemble_analysis!(
push!(uki.process.obs_pred, g_mean) # N_ens x N_data
push!(uki.process.u_mean, u_mean) # N_ens x N_params
push!(uki.process.uu_cov, uu_cov) # N_ens x N_data

push!(uki.g, DataContainer(g, data_are_columns = true))

compute_error!(uki)

end

"""
Expand Down Expand Up @@ -677,7 +720,6 @@ function update_ensemble!(

u_p = fh.failsafe_update(uki, u_p_old, g_in, failed_ens)

push!(uki.u, DataContainer(u_p, data_are_columns = true))

if uki.verbose
cov_new = get_u_cov_final(uki)
Expand Down
Loading

0 comments on commit 9c5acb3

Please sign in to comment.