diff --git a/src/EnsembleKalmanProcess.jl b/src/EnsembleKalmanProcess.jl index 1f2b355de..c901b37f3 100644 --- a/src/EnsembleKalmanProcess.jl +++ b/src/EnsembleKalmanProcess.jl @@ -121,11 +121,11 @@ struct EnsembleKalmanProcess{FT <: AbstractFloat, IT <: Int, P <: Process, LRS < scheduler::LRS "stored vector of timesteps used in each EK iteration" Δt::Vector{FT} - "the particular EK process (`Inversion` or `Sampler` or `Unscented` or `SparseInversion`)" + "the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)" process::P "Random number generator object (algorithm + seed) used for sampling and noise, for reproducibility. Defaults to `Random.GLOBAL_RNG`." rng::AbstractRNG - "struct storing failsafe update directives, implemented for (`Inversion`, `SparseInversion`, `Unscented`)" + "struct storing failsafe update directives, implemented for (`Inversion`, `SparseInversion`, `Unscented`, `TransformInversion`)" failure_handler::FailureHandler "Localization kernel, implemented for (`Inversion`, `SparseInversion`, `Unscented`)" localizer::Localizer @@ -165,6 +165,10 @@ function EnsembleKalmanProcess( # error store err = FT[] + if (typeof(process) <: TransformInversion) & !(typeof(localization_method) == NoLocalization) + throw(ArgumentError("`TransformInversion` cannot currently be used with localization.")) + end + # set the timestep methods (being cautious of EKS scheduler) if isnothing(scheduler) if !(isnothing(Δt)) @@ -643,6 +647,10 @@ end export Inversion include("EnsembleKalmanInversion.jl") +# struct TransformInversion +export TransformInversion +include("EnsembleTransformKalmanInversion.jl") + # struct SparseInversion export SparseInversion include("SparseEnsembleKalmanInversion.jl") diff --git a/src/EnsembleTransformKalmanInversion.jl b/src/EnsembleTransformKalmanInversion.jl new file mode 100644 index 000000000..0640c09a7 --- /dev/null +++ b/src/EnsembleTransformKalmanInversion.jl @@ -0,0 +1,139 @@ +#Ensemble Transform Kalman Inversion: specific structures and function definitions + +""" + TransformInversion <: Process + +An ensemble transform Kalman inversion process. + +# Fields + +$(TYPEDFIELDS) +""" +struct TransformInversion{FT <: AbstractFloat} <: Process + "Inverse of the observation error covariance matrix" + Γ_inv::Union{AbstractMatrix{FT}, UniformScaling{FT}} +end + +function FailureHandler(process::TransformInversion, method::IgnoreFailures) + failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens) = etki_update(ekp, u, g, y, obs_noise_cov) + return FailureHandler{TransformInversion, IgnoreFailures}(failsafe_update) +end + +""" + FailureHandler(process::TransformInversion, method::SampleSuccGauss) + +Provides a failsafe update that + - updates the successful ensemble according to the ETKI update, + - updates the failed ensemble by sampling from the updated successful ensemble. +""" +function FailureHandler(process::TransformInversion, method::SampleSuccGauss) + function failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens) + successful_ens = filter(x -> !(x in failed_ens), collect(1:size(g, 2))) + n_failed = length(failed_ens) + u[:, successful_ens] = etki_update(ekp, u[:, successful_ens], g[:, successful_ens], y, obs_noise_cov) + if !isempty(failed_ens) + u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed) + end + return u + end + return FailureHandler{TransformInversion, SampleSuccGauss}(failsafe_update) +end + +""" + etki_update( + ekp::EnsembleKalmanProcess{FT, IT, TransformInversion}, + u::AbstractMatrix{FT}, + g::AbstractMatrix{FT}, + y::AbstractVector{FT}, + obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}, + ) where {FT <: Real, IT, CT <: Real} + +Returns the updated parameter vectors given their current values and +the corresponding forward model evaluations. +""" +function etki_update( + ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}}, + u::AbstractMatrix{FT}, + g::AbstractMatrix{FT}, + y::AbstractVector{FT}, + obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}, +) where {FT <: Real, IT, CT <: Real} + m = size(u, 2) + Γ_inv = ekp.process.Γ_inv + + X = FT.((u .- mean(u, dims = 2)) / sqrt(m - 1)) + Y = FT.((g .- mean(g, dims = 2)) / sqrt(m - 1)) + Ω = inv(I + Y' * Γ_inv * Y) + w = FT.(Ω * Y' * Γ_inv * (y .- mean(g, dims = 2))) + + return mean(u, dims = 2) .+ X * (w .+ sqrt(m - 1) * real(sqrt(Ω))) # [N_par × N_ens] +end + +""" + update_ensemble!( + ekp::EnsembleKalmanProcess{FT, IT, TransformInversion}, + g::AbstractMatrix{FT}, + process::TransformInversion; + failed_ens = nothing, + ) where {FT, IT} + +Updates the ensemble according to a TransformInversion process. + +Inputs: + - ekp :: The EnsembleKalmanProcess to update. + - g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms). + - process :: Type of the EKP. + - failed_ens :: Indices of failed particles. If nothing, failures are computed as columns of `g` with NaN entries. +""" +function update_ensemble!( + ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}}, + g::AbstractMatrix{FT}, + process::TransformInversion{FT}; + failed_ens = nothing, +) where {FT, IT} + + # u: N_par × N_ens + # g: N_obs × N_ens + u = get_u_final(ekp) + N_obs = size(g, 1) + cov_init = cov(u, dims = 2) + + if ekp.verbose + if get_N_iterations(ekp) == 0 + @info "Iteration 0 (prior)" + @info "Covariance trace: $(tr(cov_init))" + end + + @info "Iteration $(get_N_iterations(ekp)+1) (T=$(sum(ekp.Δt)))" + end + + fh = ekp.failure_handler + + # Scale noise using Δt + scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end] + + y = ekp.obs_mean + + if isnothing(failed_ens) + _, failed_ens = split_indices_by_success(g) + end + if !isempty(failed_ens) + @info "$(length(failed_ens)) particle failure(s) detected. Handler used: $(nameof(typeof(fh).parameters[2]))." + end + + u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens) + + # store new parameters (and model outputs) + push!(ekp.u, DataContainer(u, data_are_columns = true)) + push!(ekp.g, DataContainer(g, data_are_columns = true)) + + # Store error + compute_error!(ekp) + + # Diagnostics + cov_new = cov(get_u_final(ekp), dims = 2) + + if ekp.verbose + @info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))" + end +end diff --git a/test/EnsembleKalmanProcess/runtests.jl b/test/EnsembleKalmanProcess/runtests.jl index 30f693255..9fc644dff 100644 --- a/test/EnsembleKalmanProcess/runtests.jl +++ b/test/EnsembleKalmanProcess/runtests.jl @@ -607,6 +607,177 @@ end end end +@testset "EnsembleTransformKalmanInversion" begin + + # Seed for pseudo-random number generator + rng = Random.MersenneTwister(rng_seed) + + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) + + ekiobj = nothing + eki_final_result = nothing + iters_with_failure = [5, 8, 9, 15] + + for (i_prob, inv_problem) in enumerate(inv_problems) + + # Get inverse problem + y_obs, G, Γy, A = inv_problem + if i_prob == 1 + scheduler = DataMisfitController(on_terminate = "continue") + else + scheduler = DefaultScheduler() + end + + ekiobj = EKP.EnsembleKalmanProcess( + initial_ensemble, + y_obs, + Γy, + TransformInversion(inv(Γy)); + rng = rng, + failure_handler_method = SampleSuccGauss(), + scheduler = scheduler, + ) + + ekiobj_unsafe = EKP.EnsembleKalmanProcess( + initial_ensemble, + y_obs, + Γy, + TransformInversion(inv(Γy)); + rng = rng, + failure_handler_method = IgnoreFailures(), + scheduler = scheduler, + ) + + + g_ens = G(get_ϕ_final(prior, ekiobj)) + g_ens_t = permutedims(g_ens, (2, 1)) + + @test size(g_ens) == (n_obs, N_ens) + + # ETKI iterations + u_i_vec = Array{Float64, 2}[] + g_ens_vec = Array{Float64, 2}[] + for i in 1:N_iter + params_i = get_ϕ_final(prior, ekiobj) + push!(u_i_vec, get_u_final(ekiobj)) + g_ens = G(params_i) + + # Add random failures + if i in iters_with_failure + g_ens[:, 1] .= NaN + end + + EKP.update_ensemble!(ekiobj, g_ens) + push!(g_ens_vec, g_ens) + if i == 1 + if !(size(g_ens, 1) == size(g_ens, 2)) + g_ens_t = permutedims(g_ens, (2, 1)) + @test_throws DimensionMismatch EKP.update_ensemble!(ekiobj, g_ens_t) + end + end + + # Correct handling of failures + @test !any(isnan.(params_i)) + + # Check IgnoreFailures handler + if i <= iters_with_failure[1] + params_i_unsafe = get_ϕ_final(prior, ekiobj_unsafe) + g_ens_unsafe = G(params_i_unsafe) + if i < iters_with_failure[1] + EKP.update_ensemble!(ekiobj_unsafe, g_ens_unsafe) + elseif i == iters_with_failure[1] + g_ens_unsafe[:, 1] .= NaN + #inconsistent behaviour before/after v1.9 regarding NaNs in matrices + if (VERSION.major >= 1) && (VERSION.minor >= 9) + # new versions the NaNs break LinearAlgebra.jl + @test_throws ArgumentError EKP.update_ensemble!(ekiobj_unsafe, g_ens_unsafe) + else + # old versions the NaNs pass through LinearAlgebra.jl + EKP.update_ensemble!(ekiobj_unsafe, g_ens_unsafe) + u_unsafe = get_u_final(ekiobj_unsafe) + # Propagation of unhandled failures + @test any(isnan.(u_unsafe)) + end + end + end + end + + push!(u_i_vec, get_u_final(ekiobj)) + + @test get_u_prior(ekiobj) == u_i_vec[1] + @test get_u(ekiobj) == u_i_vec + @test isequal(get_g(ekiobj), g_ens_vec) + @test isequal(get_g_final(ekiobj), g_ens_vec[end]) + @test isequal(get_error(ekiobj), ekiobj.err) + + # ETKI results: Test if ensemble has collapsed toward the true parameter + # values + eki_init_result = vec(mean(get_u_prior(ekiobj), dims = 2)) + eki_final_result = get_u_mean_final(ekiobj) + eki_init_spread = tr(get_u_cov(ekiobj, 1)) + eki_final_spread = tr(get_u_cov_final(ekiobj)) + + g_mean_init = get_g_mean(ekiobj, 1) + g_mean_final = get_g_mean_final(ekiobj) + + @test eki_init_result == get_u_mean(ekiobj, 1) + @test eki_final_result == vec(mean(get_u_final(ekiobj), dims = 2)) + + @test eki_final_spread < 2 * eki_init_spread # we wouldn't expect the spread to increase much in any one dimension + + ϕ_final_mean = get_ϕ_mean_final(prior, ekiobj) + ϕ_init_mean = get_ϕ_mean(prior, ekiobj, 1) + + if nameof(typeof(ekiobj.localizer)) == EKP.Localizers.NoLocalization + @test norm(ϕ_star - ϕ_final_mean) < norm(ϕ_star - ϕ_init_mean) + @test norm(y_obs .- G(eki_final_result))^2 < norm(y_obs .- G(eki_init_result))^2 + @test norm(y_obs .- g_mean_final)^2 < norm(y_obs .- g_mean_init)^2 + end + + if i_prob <= n_lin_inv_probs && nameof(typeof(ekiobj.localizer)) == EKP.Localizers.NoLocalization + + posterior_cov_inv = (A' * (Γy \ A) + 1 * Matrix(I, n_par, n_par) / prior_cov) + ols_mean = (A' * (Γy \ A)) \ (A' * (Γy \ y_obs)) + posterior_mean = posterior_cov_inv \ ((A' * (Γy \ A)) * ols_mean + (prior_cov \ prior_mean)) + + # ETKI provides a solution closer to the ordinary Least Squares estimate + @test norm(ols_mean - ϕ_final_mean) < norm(ols_mean - ϕ_init_mean) + end + + # Plot evolution of the ETKI particles + if TEST_PLOT_OUTPUT + plot_inv_problem_ensemble(prior, ekiobj, joinpath(@__DIR__, "ETKI_test_$(i_prob).png")) + end + end + + for (i, n_obs_test) in enumerate([10, 10, 100, 1000, 10000]) + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) + + y_obs_test, G_test, Γ_test, A_test = + linear_inv_problem(ϕ_star, noise_level, n_obs_test, rng; return_matrix = true) + + ekiobj = EKP.EnsembleKalmanProcess( + initial_ensemble, + y_obs_test, + Γ_test, + TransformInversion(inv(Γ_test)); + rng = rng, + failure_handler_method = SampleSuccGauss(), + ) + T = 0.0 + for i in 1:N_iter + params_i = get_ϕ_final(prior, ekiobj) + g_ens = G_test(params_i) + + dt = @elapsed EKP.update_ensemble!(ekiobj, g_ens) + T += dt + end + # Skip timing of first due to precompilation + if i >= 2 + @info "ETKI with $n_obs_test observations took $T seconds." + end + end +end @testset "EnsembleKalmanProcess utils" begin # Success/failure splitting