Skip to content

Commit

Permalink
Merge #332
Browse files Browse the repository at this point in the history
332: Unscented accelerator  r=odunbar a=odunbar

<!--- THESE LINES ARE COMMENTED -->
## Purpose 
<!--- One sentence to describe the purpose of this PR, refer to any linked issues:
#14 -- this will link to issue 14
Closes #2 -- this will automatically close issue 2 on PR merge
-->
Closes #333


## Content
<!---  specific tasks that are currently complete 
- Solution implemented
-->
- Adds the accelerator method to `Unscented`. This implementation moves the sigma points (as it is the "ensemble") and then saves the reconstructed the mean and covariance from the new sigma point locations. Note one additional subtlety, the saved "u_mean" and "uu_cov" are not just the mean and cov of the sigma points, and so are back calculated
- Adds unit tests & removed some unnecessary ensemble members / iterations in the accelerator test
- renames the ambiguous `update_state!` to `accelerate!`

## Unit test output

```
Accelerator: DefaultAccelerator Process: Inversion
┌ Info: Convergence:
│   cost_initial = 48.24023862169107
└   cost_final = 4.489933318021739
Accelerator: NesterovAccelerator Process: Inversion
┌ Info: Convergence:
│   cost_initial = 48.24023862169107
└   cost_final = 5.6012259995349245
Accelerator: DefaultAccelerator Process: TransformInversion
┌ Info: Convergence:
│   cost_initial = 48.24023862169107
└   cost_final = 4.434232345531697
Accelerator: NesterovAccelerator Process: TransformInversion
┌ Info: Convergence:
│   cost_initial = 48.24023862169107
└   cost_final = 4.955048376140534
Accelerator: DefaultAccelerator Process: Unscented
┌ Info: Convergence:
│   cost_initial = 53.28933708166123
└   cost_final = 5.84078511368908
Accelerator: NesterovAccelerator Process: Unscented
┌ Info: Convergence:
│   cost_initial = 53.28933708166123
└   cost_final = 6.29489453592627
Accelerator: DefaultAccelerator Process: Sampler
┌ Info: Convergence:
│   cost_initial = 48.24023862169107
└   cost_final = 14.943442588378828
Accelerator: NesterovAccelerator Process: Sampler
┌ Warning: Acceleration is experimental for Sampler processes and may affect convergence.
└ @ EnsembleKalmanProcesses ~/work/EnsembleKalmanProcesses.jl/EnsembleKalmanProcesses.jl/src/EnsembleKalmanProcess.jl:220
┌ Info: Convergence:
│   cost_initial = 48.24023862169107
└   cost_final = 15.733653014250967
```

<!---
Review checklist

I have:
- followed the codebase contribution guide: https://clima.github.io/ClimateMachine.jl/latest/Contributing/
- followed the style guide: https://clima.github.io/ClimateMachine.jl/latest/DevDocs/CodeStyle/
- followed the documentation policy: https://github.com/CliMA/policies/wiki/Documentation-Policy
- checked that this PR does not duplicate an open PR.

In the Content, I have included 
- relevant unit tests, and integration tests, 
- appropriate docstrings on all functions, structs, and modules, and included relevant documentation.

-->

----
- [x] I have read and checked the items on the review checklist.


Co-authored-by: odunbar <odunbar@caltech.edu>
  • Loading branch information
bors[bot] and odunbar authored Oct 13, 2023
2 parents 17c6c46 + 8f1e3df commit 52d14d0
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 74 deletions.
68 changes: 40 additions & 28 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# included in EnsembleKalmanProcess.jl

export DefaultAccelerator, NesterovAccelerator
export update_state!, set_initial_acceleration!
export accelerate!, set_initial_acceleration!

"""
$(TYPEDEF)
Expand Down Expand Up @@ -31,28 +31,28 @@ end
"""
Sets u_prev to the initial parameter values
"""
function set_ICs!(accelerator::NesterovAccelerator{FT}, u::MA) where {FT <: AbstractFloat, MA <: AbstractMatrix{FT}}
function set_ICs!(accelerator::NesterovAccelerator{FT}, u::MA) where {FT <: AbstractFloat, MA <: AbstractMatrix}
accelerator.u_prev = u
end


"""
Performs traditional state update with no momentum.
"""
function update_state!(
function accelerate!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
push!(ekp.u, DataContainer(u, data_are_columns = true))
end

"""
Performs state update with modified Nesterov momentum approach.
"""
function update_state!(
function accelerate!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
k = get_N_iterations(ekp) + 2
v = u .+ (1 - ekp.accelerator.r / k) * (u .- ekp.accelerator.u_prev)
Expand All @@ -66,29 +66,41 @@ end


"""
State update method for UKI with no acceleration.
The Accelerator framework has not yet been integrated with UKI process;
UKI tracks its own states, so this method is empty.
State update method for UKI with Nesterov Accelerator.
Performs identical update as with other methods, but requires reconstruction of mean and covariance of the accelerated positions prior to saving.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
function accelerate!(
uki::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::AM,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, AM <: AbstractMatrix}

end
#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))

# additional complication: the stored "u_mean" and "uu_cov" are not the mean/cov of this ensemble
# the ensemble comes from the prediction operation acted upon this. we invert the prediction of the mean/cov of the sigma ensemble from u_mean/uu_cov
# u_mean = 1/alpha*(mean(v) - r) + r
# uu_cov = 1/alpha^2*(cov(v) - Σ_ω)
α_reg = uki.process.α_reg
r = uki.process.r
Σ_ω = uki.process.Σ_ω
Δt = uki.Δt[end]

v_mean = construct_mean(uki, v)
vv_cov = construct_cov(uki, v, v_mean)
u_mean = 1 / α_reg * (v_mean - r) + r
uu_cov = (1 / α_reg)^2 * (vv_cov - Σ_ω * Δt)

# overwrite the saved u_mean/uu_cov
uki.process.u_mean[end] = u_mean # N_ens x N_params
uki.process.uu_cov[end] = uu_cov # N_ens x N_data

"""
Placeholder state update method for UKI with Nesterov Accelerator.
The Accelerator framework has not yet been integrated with UKI process, so this
method throws an error.
"""
function update_state!(
ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
throw(
ArgumentError(
"option `accelerator = NesterovAccelerator` is not implemented for UKI, please use `DefaultAccelerator`",
),
)
end
4 changes: 3 additions & 1 deletion src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ function update_ensemble!(
terminate = calculate_timestep!(ekp, g, Δt_new)
if isnothing(terminate)
u = update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
update_state!(ekp, u)
accelerate!(ekp, u)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
Expand Down Expand Up @@ -696,11 +696,13 @@ include("SparseEnsembleKalmanInversion.jl")
export Sampler
include("EnsembleKalmanSampler.jl")


# struct Unscented
export Unscented
export Gaussian_2d
export construct_initial_ensemble, construct_mean, construct_cov
include("UnscentedKalmanInversion.jl")


# struct Accelerator
include("Accelerators.jl")
25 changes: 14 additions & 11 deletions src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ Inputs:
$(METHODLIST)
"""
mutable struct Unscented{FT <: AbstractFloat, IT <: Int} <: Process
"an interable of arrays of size `N_parameters` containing the mean of the parameters (in each `uki` iteration a new array of mean is added)"
"an interable of arrays of size `N_parameters` containing the mean of the parameters (in each `uki` iteration a new array of mean is added), note - this is not the same as the ensemble mean of the sigma ensemble as it is taken prior to prediction"
u_mean::Any # ::Iterable{AbtractVector{FT}}
"an iterable of arrays of size (`N_parameters x N_parameters`) containing the covariance of the parameters (in each `uki` iteration a new array of `cov` is added)"
"an iterable of arrays of size (`N_parameters x N_parameters`) containing the covariance of the parameters (in each `uki` iteration a new array of `cov` is added), note - this is not the same as the ensemble cov of the sigma ensemble as it is taken prior to prediction"
uu_cov::Any # ::Iterable{AbstractMatrix{FT}}
"an iterable of arrays of size `N_y` containing the predicted observation (in each `uki` iteration a new array of predicted observation is added)"
obs_pred::Any # ::Iterable{AbstractVector{FT}}
Expand Down Expand Up @@ -226,6 +226,7 @@ function EnsembleKalmanProcess(
process::Unscented{FT, IT};
kwargs...,
) where {FT <: AbstractFloat, IT <: Int}

# use the distribution stored in process to generate initial ensemble
init_params = update_ensemble_prediction!(process, 0.0)

Expand All @@ -237,8 +238,8 @@ function FailureHandler(process::Unscented, method::IgnoreFailures)
#perform analysis on the model runs
update_ensemble_analysis!(uki, u, g)
#perform new prediction output to model parameters u_p
u = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u
u_p = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u_p
end
return FailureHandler{Unscented, IgnoreFailures}(failsafe_update)
end
Expand Down Expand Up @@ -292,17 +293,17 @@ function FailureHandler(process::Unscented, method::SampleSuccGauss)
push!(uki.process.obs_pred, g_mean) # N_ens x N_data
push!(uki.process.u_mean, u_mean) # N_ens x N_params
push!(uki.process.uu_cov, uu_cov) # N_ens x N_data

push!(uki.g, DataContainer(g, data_are_columns = true))

compute_error!(uki)

end
function failsafe_update(uki, u, g, failed_ens)
#perform analysis on the model runs
succ_gauss_analysis!(uki, u, g, failed_ens)
#perform new prediction output to model parameters u_p
u = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u
u_p = update_ensemble_prediction!(uki.process, uki.Δt[end])
return u_p
end
return FailureHandler{Unscented, SampleSuccGauss}(failsafe_update)
end
Expand Down Expand Up @@ -550,7 +551,10 @@ end
UKI prediction step : generate sigma points.
"""
function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: AbstractFloat}
function update_ensemble_prediction!(
process::Unscented,
Δt::FT,
) where {FT <: AbstractFloat, AV <: AbstractVector, AM <: AbstractMatrix}

process.iter += 1
# update evolution covariance matrix
Expand All @@ -565,7 +569,7 @@ function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: A
r = process.r
Σ_ω = process.Σ_ω

N_par = length(process.u_mean[1])
N_par = length(u_mean[1])
############# Prediction step:

u_p_mean = α_reg * u_mean + (1 - α_reg) * r
Expand Down Expand Up @@ -623,10 +627,10 @@ function update_ensemble_analysis!(
push!(uki.process.obs_pred, g_mean) # N_ens x N_data
push!(uki.process.u_mean, u_mean) # N_ens x N_params
push!(uki.process.uu_cov, uu_cov) # N_ens x N_data

push!(uki.g, DataContainer(g, data_are_columns = true))

compute_error!(uki)

end

"""
Expand Down Expand Up @@ -677,7 +681,6 @@ function update_ensemble!(

u_p = fh.failsafe_update(uki, u_p_old, g_in, failed_ens)

push!(uki.u, DataContainer(u_p, data_are_columns = true))

if uki.verbose
cov_new = get_u_cov_final(uki)
Expand Down
Loading

0 comments on commit 52d14d0

Please sign in to comment.