Skip to content

Commit

Permalink
Define Accelerator type
Browse files Browse the repository at this point in the history
Move EKP state update to Accelerator function

initial setup of Nesterov momentum

fixes for indexing, typos, constructors

undo accidental test changes

accelerator struct fixes

sanity-check comparison on simple problem

fixed index shift, added function to set accelerator ICs

add exp sin example

convergence plots

reproduced multi-trial convergence results on exp sin IP

fix bug with accelerator setup in default case

visualize momentum acceleration in EKI,EKS processes

darcy in progress

accelerator work and unit tests

formatting

undo formatting

Move EKP state update to Accelerator function

initial setup of Nesterov momentum

fixes for indexing, typos, constructors

undo accidental test changes

accelerator struct fixes

sanity-check comparison on simple problem

fixed index shift, added function to set accelerator ICs

add exp sin example

convergence plots

reproduced multi-trial convergence results on exp sin IP

fix bug with accelerator setup in default case

visualize momentum acceleration in EKI,EKS processes

darcy in progress

accelerator work and unit tests

ignore UKI, fix tests

Delete examples/LearningRateSchedulers/compare_schedulers_accelerated.jl

Delete examples/Sinusoid/exp_sin_multi_comparison.pdf

Delete examples/Sinusoid/exp_sin.pdf

Delete examples/Sinusoid/exp_sin_.pdf

Delete examples/Sinusoid/exp_sin_eki.pdf

Delete examples/Sinusoid/exp_sin_eks.pdf

Delete examples/Sinusoid/exp_sin_multi_comparison_a.pdf

Delete examples/Sinusoid/exp_sin_multi_comparison_b.pdf

Delete examples/Sinusoid/exp_sin_multi_comparison_c.pdf

Delete examples/Sinusoid/exp_sin_narrow.pdf

Delete examples/Sinusoid/exp_sin_targeted.pdf

Delete examples/Sinusoid/exp_sin_shifted.pdf

Delete examples/Sinusoid/exp_sin_wide.pdf

Delete examples/Sinusoid/exp_sinusoid_example_accelerated.jl

Delete examples/Sinusoid/exp_sinusoid_example_comparison.jl

Delete examples/Sinusoid/sinusoid_example_accelerated.jl

Delete test/Accelerators directory

Accelerator tests will be condensed with EKP tests

Delete output/ensembles_acc.pdf

Delete output/error_vs_spread_over_iteration_acc.pdf

Delete output/error_vs_spread_over_time_acc.pdf

Delete exp_sin_.pdf

restore original Project.toml

Delete examples/Darcy/darcy_accelerated.jl

changed test @info printing, formatting

Delete output/ensembles.pdf

Delete output/error_vs_spread_over_iteration.pdf

Delete output/error_vs_spread_over_time.pdf

Delete examples/Darcy/output/data_storage.jld2

Delete examples/Darcy/output/parameter_storage.jld2

fix test file

include EKTI

test code coverage fixes

code cleanup

Warning for accelerated EKS process
  • Loading branch information
sydneyvernon committed Oct 10, 2023
1 parent 15fad65 commit cd0df60
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StableRNGs", "Test", "Plots"]
test = ["StableRNGs", "Test", "Plots"]
94 changes: 94 additions & 0 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# included in EnsembleKalmanProcess.jl

export DefaultAccelerator, NesterovAccelerator
export update_state!, set_initial_acceleration!

"""
$(TYPEDEF)
Default accelerator provides no acceleration, runs traditional EKI
"""
struct DefaultAccelerator <: Accelerator end

"""
$(TYPEDEF)
Accelerator that adapts Nesterov's momentum method for EKI.
Stores a previous state value u_prev for computational purposes (note this is distinct from state returned as "ensemble value")
$(TYPEDFIELDS)
"""
mutable struct NesterovAccelerator{FT <: AbstractFloat} <: Accelerator
r::FT
u_prev::Any
end

function NesterovAccelerator(r = 3.0, initial = Float64[])
return NesterovAccelerator(r, initial)
end


"""
Sets u_prev to the initial parameter values
"""
function set_ICs!(accelerator::NesterovAccelerator{FT}, u::MA) where {FT <: AbstractFloat, MA <: AbstractMatrix{FT}}
accelerator.u_prev = u
end


"""
Performs traditional state update with no momentum.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
push!(ekp.u, DataContainer(u, data_are_columns = true))
end

"""
Performs state update with modified Nesterov momentum approach.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
## update "v" state:
k = get_N_iterations(ekp) + 2
v = u .+ (1 - ekp.accelerator.r / k) * (u .- ekp.accelerator.u_prev)

## update "u" state:
ekp.accelerator.u_prev = u

## push "v" state to EKP object
push!(ekp.u, DataContainer(v, data_are_columns = true))
end


"""
State update method for UKI with no acceleration.
The Accelerator framework has not yet been integrated with UKI process;
UKI tracks its own states, so this method is empty.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}

end

"""
Placeholder state update method for UKI with Nesterov Accelerator.
The Accelerator framework has not yet been integrated with UKI process, so this
method throws an error.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
throw(
ArgumentError(
"option `accelerator = NesterovAccelerator` is not implemented for UKI, please use `DefaultAccelerator`",
),
)
end
6 changes: 3 additions & 3 deletions src/EnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,17 @@ function update_ensemble!(

u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))

# Store error
compute_error!(ekp)

# Diagnostics
cov_new = cov(get_u_final(ekp), dims = 2)
cov_new = cov(u, dims = 2)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

return u
end
52 changes: 46 additions & 6 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export get_u, get_g, get_ϕ
export get_u_prior, get_u_final, get_g_final, get_ϕ_final
export get_N_iterations, get_error, get_cov_blocks
export get_u_mean, get_u_cov, get_g_mean, get_ϕ_mean
export get_u_mean_final, get_u_cov_prior, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final
export get_u_mean_final, get_u_cov_prior, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final, get_accelerator
export compute_error!
export update_ensemble!
export sample_empirical_gaussian, split_indices_by_success
Expand All @@ -29,6 +29,9 @@ abstract type LearningRateScheduler end
# Failure handlers
abstract type FailureHandlingMethod end

# Accelerators
abstract type Accelerator end



"Failure handling method that ignores forward model failures"
Expand Down Expand Up @@ -104,7 +107,13 @@ Inputs:
$(METHODLIST)
"""
struct EnsembleKalmanProcess{FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler}
struct EnsembleKalmanProcess{
FT <: AbstractFloat,
IT <: Int,
P <: Process,
LRS <: LearningRateScheduler,
ACC <: Accelerator,
}
"array of stores for parameters (`u`), each of size [`N_par × N_ens`]"
u::Array{DataContainer{FT}}
"vector of the observed vector size [`N_obs`]"
Expand All @@ -119,6 +128,8 @@ struct EnsembleKalmanProcess{FT <: AbstractFloat, IT <: Int, P <: Process, LRS <
err::Vector{FT}
"Scheduler to calculate the timestep size in each EK iteration"
scheduler::LRS
"accelerator object that informs EK update steps, stores additional state variables as needed"
accelerator::ACC
"stored vector of timesteps used in each EK iteration"
Δt::Vector{FT}
"the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)"
Expand All @@ -139,6 +150,7 @@ function EnsembleKalmanProcess(
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
process::P;
scheduler::Union{Nothing, LRS} = nothing,
accelerator::Union{Nothing, ACC} = nothing,
Δt = nothing,
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
Expand All @@ -147,6 +159,7 @@ function EnsembleKalmanProcess(
) where {
FT <: AbstractFloat,
LRS <: LearningRateScheduler,
ACC <: Accelerator,
P <: Process,
FM <: FailureHandlingMethod,
LM <: LocalizationMethod,
Expand Down Expand Up @@ -193,23 +206,39 @@ function EnsembleKalmanProcess(
# timestep store
Δt = FT[]

# set up accelerator
if isnothing(accelerator)
acc = DefaultAccelerator()
else
acc = accelerator
end
AC = typeof(acc)

if AC <: NesterovAccelerator
set_ICs!(acc, params)
if P <: Sampler
@warn "Acceleration is experimental for Sampler processes and may affect convergence."
end
end

# failure handler
fh = FailureHandler(process, failure_handler_method)
# localizer
loc = Localizer(localization_method, N_par, N_obs, N_ens, FT)

if verbose
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(localization_method)))\nFailure handler: $(nameof(typeof(failure_handler_method)))\nScheduler: $(nameof(typeof(lrs)))"
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(localization_method)))\nFailure handler: $(nameof(typeof(failure_handler_method)))\nScheduler: $(nameof(typeof(lrs)))\nAccelerator: $(nameof(typeof(acc)))"

Check warning on line 230 in src/EnsembleKalmanProcess.jl

View check run for this annotation

Codecov / codecov/patch

src/EnsembleKalmanProcess.jl#L230

Added line #L230 was not covered by tests
end

EnsembleKalmanProcess{FT, IT, P, RS}(
EnsembleKalmanProcess{FT, IT, P, RS, AC}(
[init_params],
obs_mean,
obs_noise_cov,
N_ens,
g,
err,
lrs,
acc,
Δt,
process,
rng,
Expand All @@ -222,7 +251,6 @@ end

include("LearningRateSchedulers.jl")


"""
get_u(ekp::EnsembleKalmanProcess, iteration::IT; return_array=true) where {IT <: Integer}
Expand Down Expand Up @@ -423,6 +451,14 @@ function get_scheduler(ekp::EnsembleKalmanProcess)
return ekp.scheduler
end

"""
get_accelerator(ekp::EnsembleKalmanProcess)
Return accelerator type of EnsembleKalmanProcess.
"""
function get_accelerator(ekp::EnsembleKalmanProcess)
return ekp.accelerator
end


"""
construct_initial_ensemble(
Expand Down Expand Up @@ -628,7 +664,8 @@ function update_ensemble!(

terminate = calculate_timestep!(ekp, g, Δt_new)
if isnothing(terminate)
update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
u = update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
update_state!(ekp, u)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
Expand Down Expand Up @@ -664,3 +701,6 @@ export Unscented
export Gaussian_2d
export construct_initial_ensemble, construct_mean, construct_cov
include("UnscentedKalmanInversion.jl")

# struct Accelerator
include("Accelerators.jl")
5 changes: 3 additions & 2 deletions src/EnsembleKalmanSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,18 @@ function update_ensemble!(
u = fh.failsafe_update(ekp, u_old, g, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))
# u_old is N_ens × N_par, g is N_ens × N_obs,
# but stored in data container with N_ens as the 2nd dim

compute_error!(ekp)

# Diagnostics
cov_new = get_u_cov_final(ekp)
cov_new = cov(u, dims = 2)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

return u
end
5 changes: 3 additions & 2 deletions src/EnsembleTransformKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,17 @@ function update_ensemble!(
u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))

# Store error
compute_error!(ekp)

# Diagnostics
cov_new = cov(get_u_final(ekp), dims = 2)
cov_new = cov(u, dims = 2)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

return u
end
3 changes: 2 additions & 1 deletion src/SparseEnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,13 @@ function update_ensemble!(
u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))

# Store error
compute_error!(ekp)

# Check convergence
cov_new = cov(get_u_final(ekp), dims = 2)

return u
end
Loading

0 comments on commit cd0df60

Please sign in to comment.