Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensemble transform Kalman inversion #329

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ struct EnsembleKalmanProcess{FT <: AbstractFloat, IT <: Int, P <: Process, LRS <
scheduler::LRS
"stored vector of timesteps used in each EK iteration"
Δt::Vector{FT}
"the particular EK process (`Inversion` or `Sampler` or `Unscented` or `SparseInversion`)"
"the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)"
process::P
"Random number generator object (algorithm + seed) used for sampling and noise, for reproducibility. Defaults to `Random.GLOBAL_RNG`."
rng::AbstractRNG
"struct storing failsafe update directives, implemented for (`Inversion`, `SparseInversion`, `Unscented`)"
"struct storing failsafe update directives, implemented for (`Inversion`, `SparseInversion`, `Unscented`, `TransformInversion`)"
failure_handler::FailureHandler
"Localization kernel, implemented for (`Inversion`, `SparseInversion`, `Unscented`)"
localizer::Localizer
Expand Down Expand Up @@ -165,6 +165,10 @@ function EnsembleKalmanProcess(
# error store
err = FT[]

if (typeof(process) <: TransformInversion) & !(typeof(localization_method) == NoLocalization)
throw(ArgumentError("`TransformInversion` cannot currently be used with localization."))
end

# set the timestep methods (being cautious of EKS scheduler)
if isnothing(scheduler)
if !(isnothing(Δt))
Expand Down Expand Up @@ -643,6 +647,10 @@ end
export Inversion
include("EnsembleKalmanInversion.jl")

# struct TransformInversion
export TransformInversion
include("EnsembleTransformKalmanInversion.jl")

# struct SparseInversion
export SparseInversion
include("SparseEnsembleKalmanInversion.jl")
Expand Down
139 changes: 139 additions & 0 deletions src/EnsembleTransformKalmanInversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#Ensemble Transform Kalman Inversion: specific structures and function definitions

"""
    TransformInversion <: Process

An ensemble transform Kalman inversion process.

The process carries the (pre-computed) inverse of the observation error covariance,
which `etki_update` reads as `ekp.process.Γ_inv` when forming the ensemble transform.
A typical construction is `TransformInversion(inv(Γ))`, where `Γ` is the observation
error covariance matrix.

# Fields

$(TYPEDFIELDS)
"""
struct TransformInversion{FT <: AbstractFloat} <: Process
    # Either a dense/structured matrix or a `UniformScaling` (e.g. a multiple of `I`);
    # the element type `FT` is the type parameter of the process.
    "Inverse of the observation error covariance matrix"
    Γ_inv::Union{AbstractMatrix{FT}, UniformScaling{FT}}
end

"""
    FailureHandler(process::TransformInversion, method::IgnoreFailures)

Provides an update that ignores particle failures: the ETKI update (`etki_update`) is
applied to the full ensemble, and the `failed_ens` indices are accepted but not used.
"""
function FailureHandler(process::TransformInversion, method::IgnoreFailures)
    # Same call signature as the other failsafe updates; `failed_ens` is deliberately unused.
    function failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens)
        return etki_update(ekp, u, g, y, obs_noise_cov)
    end
    return FailureHandler{TransformInversion, IgnoreFailures}(failsafe_update)
end

"""
    FailureHandler(process::TransformInversion, method::SampleSuccGauss)

Provides a failsafe update that
 - applies the ETKI update (`etki_update`) to the successful ensemble members only,
 - replaces the failed ensemble members with draws from `sample_empirical_gaussian`
   evaluated on the updated successful members.

The returned closure writes into `u` in place and also returns it.
"""
function FailureHandler(process::TransformInversion, method::SampleSuccGauss)
    function failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens)
        # Column indices of the particles that did not fail, in increasing order.
        succ_idx = [j for j in 1:size(g, 2) if !(j in failed_ens)]

        # ETKI update restricted to the successful columns.
        u_succ = etki_update(ekp, u[:, succ_idx], g[:, succ_idx], y, obs_noise_cov)
        u[:, succ_idx] = u_succ

        # Re-draw the failed particles from the updated successful ensemble.
        if !isempty(failed_ens)
            u[:, failed_ens] = sample_empirical_gaussian(u[:, succ_idx], length(failed_ens))
        end
        return u
    end
    return FailureHandler{TransformInversion, SampleSuccGauss}(failsafe_update)
end

"""
    etki_update(
        ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}},
        u::AbstractMatrix{FT},
        g::AbstractMatrix{FT},
        y::AbstractVector{FT},
        obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}},
    ) where {FT <: Real, IT, CT <: Real}

Returns the updated parameter vectors given their current values and
the corresponding forward model evaluations.

# Arguments
 - `ekp`: the process object; its stored `process.Γ_inv` and latest timestep `Δt[end]` are read.
 - `u`: current parameter ensemble, one particle per column (`N_par × N_ens`).
 - `g`: forward model evaluations, one particle per column (`N_obs × N_ens`).
 - `y`: observation vector (length `N_obs`).
 - `obs_noise_cov`: the timestep-scaled observation noise covariance. It is not inverted
   here; the equivalent inverse is formed as `Δt * Γ_inv` from the pre-computed `Γ_inv`,
   so that no `N_obs × N_obs` inversion is performed per iteration.

# Returns
The updated ensemble as an `N_par × N_ens` matrix.
"""
function etki_update(
    ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}},
    u::AbstractMatrix{FT},
    g::AbstractMatrix{FT},
    y::AbstractVector{FT},
    obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}},
) where {FT <: Real, IT, CT <: Real}
    m = size(u, 2)
    Γ_inv = ekp.process.Γ_inv

    # The caller scales the noise covariance as Γ / Δt, hence its inverse is Δt * Γ⁻¹.
    # Without this factor the scheduler's timestep would have no effect on the update.
    # Fall back to a unit step if no timestep has been recorded (direct calls).
    Δt = isempty(ekp.Δt) ? one(FT) : ekp.Δt[end]

    # Ensemble means are reused below; compute each once.
    u_mean = mean(u, dims = 2)
    g_mean = mean(g, dims = 2)
    scale = sqrt(m - 1)

    # Normalised perturbation matrices in parameter and observation space.
    X = FT.((u .- u_mean) / scale)
    Y = FT.((g .- g_mean) / scale)

    # Apply the timestep to the small (N_ens × N_obs) product instead of copying Γ_inv.
    YtΓ_inv = Δt * (Y' * Γ_inv)
    Ω = inv(I + YtΓ_inv * Y) # [N_ens × N_ens]
    # Multiply the residual first so that only matrix-vector products are formed.
    w = FT.(Ω * (YtΓ_inv * (y .- g_mean)))

    # Mean shift `w` plus the transform `sqrt(Ω)`; `real` drops any round-off imaginary part.
    return u_mean .+ X * (w .+ scale * real(sqrt(Ω))) # [N_par × N_ens]
end

"""
    update_ensemble!(
        ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}},
        g::AbstractMatrix{FT},
        process::TransformInversion{FT};
        failed_ens = nothing,
    ) where {FT, IT}

Updates the ensemble according to a TransformInversion process.

Inputs:
 - ekp :: The EnsembleKalmanProcess to update.
 - g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columns).
 - process :: Type of the EKP.
 - failed_ens :: Indices of failed particles. If nothing, failures are computed as columns of `g` with NaN entries.

The new parameter ensemble and `g` are appended to `ekp.u` and `ekp.g`, and the error is
stored via `compute_error!`. Covariance diagnostics are only evaluated when `ekp.verbose`
is set. Returns `nothing`.
"""
function update_ensemble!(
    ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}},
    g::AbstractMatrix{FT},
    process::TransformInversion{FT};
    failed_ens = nothing,
) where {FT, IT}

    # u: N_par × N_ens
    # g: N_obs × N_ens
    u = get_u_final(ekp)

    # The failsafe update may overwrite `u` in place, so the prior trace has to be taken
    # now; it is only needed for logging, hence skipped when not verbose.
    tr_cov_init = ekp.verbose ? tr(cov(u, dims = 2)) : nothing

    if ekp.verbose
        if get_N_iterations(ekp) == 0
            @info "Iteration 0 (prior)"
            @info "Covariance trace: $(tr_cov_init)"
        end

        @info "Iteration $(get_N_iterations(ekp)+1) (T=$(sum(ekp.Δt)))"
    end

    fh = ekp.failure_handler

    # Scale noise using Δt
    scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end]

    y = ekp.obs_mean

    if isnothing(failed_ens)
        _, failed_ens = split_indices_by_success(g)
    end
    if !isempty(failed_ens)
        @info "$(length(failed_ens)) particle failure(s) detected. Handler used: $(nameof(typeof(fh).parameters[2]))."
    end

    u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

    # store new parameters (and model outputs)
    push!(ekp.u, DataContainer(u, data_are_columns = true))
    push!(ekp.g, DataContainer(g, data_are_columns = true))

    # Store error
    compute_error!(ekp)

    # Diagnostics
    if ekp.verbose
        tr_cov_new = tr(cov(get_u_final(ekp), dims = 2))
        @info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr_cov_new)\nCovariance trace ratio (current/previous): $(tr_cov_new/tr_cov_init)"
    end

    return nothing
end
165 changes: 165 additions & 0 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,171 @@ end
end
end

@testset "EnsembleTransformKalmanInversion" begin

    # Seed for pseudo-random number generator
    rng = Random.MersenneTwister(rng_seed)

    initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

    ekiobj = nothing
    eki_final_result = nothing
    # Iterations at which the first particle is forced to fail (NaN model output)
    iters_with_failure = [5, 8, 9, 15]

    for (i_prob, inv_problem) in enumerate(inv_problems)

        # Get inverse problem
        y_obs, G, Γy, A = inv_problem
        if i_prob == 1
            scheduler = DataMisfitController(on_terminate = "continue")
        else
            scheduler = DefaultScheduler()
        end

        ekiobj = EKP.EnsembleKalmanProcess(
            initial_ensemble,
            y_obs,
            Γy,
            TransformInversion(inv(Γy));
            rng = rng,
            failure_handler_method = SampleSuccGauss(),
            scheduler = scheduler,
        )

        ekiobj_unsafe = EKP.EnsembleKalmanProcess(
            initial_ensemble,
            y_obs,
            Γy,
            TransformInversion(inv(Γy));
            rng = rng,
            failure_handler_method = IgnoreFailures(),
            scheduler = scheduler,
        )


        g_ens = G(get_ϕ_final(prior, ekiobj))
        g_ens_t = permutedims(g_ens, (2, 1))

        @test size(g_ens) == (n_obs, N_ens)

        # ETKI iterations
        u_i_vec = Array{Float64, 2}[]
        g_ens_vec = Array{Float64, 2}[]
        for i in 1:N_iter
            params_i = get_ϕ_final(prior, ekiobj)
            push!(u_i_vec, get_u_final(ekiobj))
            g_ens = G(params_i)

            # Add random failures
            if i in iters_with_failure
                g_ens[:, 1] .= NaN
            end

            EKP.update_ensemble!(ekiobj, g_ens)
            push!(g_ens_vec, g_ens)
            if i == 1
                if !(size(g_ens, 1) == size(g_ens, 2))
                    g_ens_t = permutedims(g_ens, (2, 1))
                    @test_throws DimensionMismatch EKP.update_ensemble!(ekiobj, g_ens_t)
                end
            end

            # Correct handling of failures
            @test !any(isnan.(params_i))

            # Check IgnoreFailures handler
            if i <= iters_with_failure[1]
                params_i_unsafe = get_ϕ_final(prior, ekiobj_unsafe)
                g_ens_unsafe = G(params_i_unsafe)
                if i < iters_with_failure[1]
                    EKP.update_ensemble!(ekiobj_unsafe, g_ens_unsafe)
                elseif i == iters_with_failure[1]
                    g_ens_unsafe[:, 1] .= NaN
                    #inconsistent behaviour before/after v1.9 regarding NaNs in matrices
                    # Compare whole version numbers: a component-wise check on
                    # (major, minor) would wrongly exclude e.g. v2.0.
                    if VERSION >= v"1.9"
                        # new versions the NaNs break LinearAlgebra.jl
                        @test_throws ArgumentError EKP.update_ensemble!(ekiobj_unsafe, g_ens_unsafe)
                    end
                end
            end
        end

        push!(u_i_vec, get_u_final(ekiobj))

        @test get_u_prior(ekiobj) == u_i_vec[1]
        @test get_u(ekiobj) == u_i_vec
        @test isequal(get_g(ekiobj), g_ens_vec)
        @test isequal(get_g_final(ekiobj), g_ens_vec[end])
        @test isequal(get_error(ekiobj), ekiobj.err)

        # ETKI results: Test if ensemble has collapsed toward the true parameter
        # values
        eki_init_result = vec(mean(get_u_prior(ekiobj), dims = 2))
        eki_final_result = get_u_mean_final(ekiobj)
        eki_init_spread = tr(get_u_cov(ekiobj, 1))
        eki_final_spread = tr(get_u_cov_final(ekiobj))

        g_mean_init = get_g_mean(ekiobj, 1)
        g_mean_final = get_g_mean_final(ekiobj)

        @test eki_init_result == get_u_mean(ekiobj, 1)
        @test eki_final_result == vec(mean(get_u_final(ekiobj), dims = 2))

        @test eki_final_spread < 2 * eki_init_spread # we wouldn't expect the spread to increase much in any one dimension

        ϕ_final_mean = get_ϕ_mean_final(prior, ekiobj)
        ϕ_init_mean = get_ϕ_mean(prior, ekiobj, 1)

        # NOTE(review): `nameof` returns a `Symbol`, so comparing it with the
        # `NoLocalization` type looks like it can never be true, which would silently
        # skip the assertions in this branch and the next one — confirm intent.
        if nameof(typeof(ekiobj.localizer)) == EKP.Localizers.NoLocalization
            @test norm(ϕ_star - ϕ_final_mean) < norm(ϕ_star - ϕ_init_mean)
            @test norm(y_obs .- G(eki_final_result))^2 < norm(y_obs .- G(eki_init_result))^2
            @test norm(y_obs .- g_mean_final)^2 < norm(y_obs .- g_mean_init)^2
        end

        if i_prob <= n_lin_inv_probs && nameof(typeof(ekiobj.localizer)) == EKP.Localizers.NoLocalization

            posterior_cov_inv = (A' * (Γy \ A) + 1 * Matrix(I, n_par, n_par) / prior_cov)
            ols_mean = (A' * (Γy \ A)) \ (A' * (Γy \ y_obs))
            posterior_mean = posterior_cov_inv \ ((A' * (Γy \ A)) * ols_mean + (prior_cov \ prior_mean))

            # ETKI provides a solution closer to the ordinary Least Squares estimate
            @test norm(ols_mean - ϕ_final_mean) < norm(ols_mean - ϕ_init_mean)
        end

        # Plot evolution of the ETKI particles
        if TEST_PLOT_OUTPUT
            plot_inv_problem_ensemble(prior, ekiobj, joinpath(@__DIR__, "ETKI_test_$(i_prob).png"))
        end
    end

    # Timing runs over an increasing number of observations
    for (i, n_obs_test) in enumerate([10, 10, 100, 1000, 10000])
        initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

        y_obs_test, G_test, Γ_test, A_test =
            linear_inv_problem(ϕ_star, noise_level, n_obs_test, rng; return_matrix = true)

        ekiobj = EKP.EnsembleKalmanProcess(
            initial_ensemble,
            y_obs_test,
            Γ_test,
            TransformInversion(inv(Γ_test));
            rng = rng,
            failure_handler_method = SampleSuccGauss(),
        )
        T = 0.0
        # The iteration counter is unused; `_` avoids shadowing the outer `i`.
        for _ in 1:N_iter
            params_i = get_ϕ_final(prior, ekiobj)
            g_ens = G_test(params_i)

            dt = @elapsed EKP.update_ensemble!(ekiobj, g_ens)
            T += dt
        end
        # Skip timing of first due to precompilation
        if i >= 2
            @info "ETKI with $n_obs_test observations took $T seconds."
        end
    end
end

@testset "EnsembleKalmanProcess utils" begin
# Success/failure splitting
Expand Down