Skip to content

remove positive definiteness constraints, allow user defined additive inflation #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions docs/src/inflation.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,25 @@ Multiplicative inflation can be used by flagging the `update_ensemble!` method a
```

## Additive Inflation
Additive inflation is implemented by systematically adding stochastic perturbations to the parameter ensemble in the form of Gaussian noise. Additive inflation breaks the linear subspace property, meaning the parameter ensemble can evolve outside of the span of the initial ensemble. In additive inflation, the ensemble is perturbed in the following manner after the standard Kalman update:
Additive inflation is implemented by systematically adding stochastic perturbations to the parameter ensemble in the form of Gaussian noise. Additive inflation is capable of breaking the linear subspace property, meaning the parameter ensemble can evolve outside of the span of the current ensemble. In additive inflation, the ensemble is perturbed in the following manner after the standard Kalman update:

```math
u_{n+1} = u_n + \zeta_{n} \qquad (3) \\
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_n) \qquad (4)
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} \Sigma) \qquad (4)
```
This can be seen as a stochastic modification of the ensemble covariance, while the mean remains fixed
```math
C_{n + 1} = C_{n} + \frac{s \Delta{t} }{1 - s \Delta{t}} \Sigma \qquad (5)
```
This inflates the parameter covariance by a factor of ``\frac{1}{1 - s \Delta{t}}`` as in eqn. 2 , while the ensemble mean remains fixed.

Additive inflation can be used by flagging the `update_ensemble!` method as follows:
For example, if ``\Sigma = C_{n}`` we see inflation that is statistically equivalent to scaling the parameter covariance by a factor of ``\frac{1}{1 - s \Delta{t}}`` as in eqn. 2.

Additive inflation, by default takes ``\Sigma = C_0`` (the prior covariance), and can be used by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 1.0)
```
Alternatively, the prior covariance matrix may be used to generate additive noise, following:
```math
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_{0}) \qquad (5)
```
This results in an additive increase in the parameter covariance by `` \frac{s \Delta{t} }{1 - s \Delta{t}} * C_{0}`` , while the mean remains fixed.
```math
C_{n + 1} = C_{n} + \frac{s \Delta{t} }{1 - s \Delta{t}} C_{0} \qquad (6)
```

Additive inflation using the scaled prior covariance (parameter covariance of initial ensemble) can be used by flagging the `update_ensemble!` method as follows:
Any positive semi-definite matrix (or uniform scaling) ``\Sigma`` may be provided to generate additive noise to the ensemble by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, use_prior_cov = true, s = 1.0)
Σ = 0.01*I # user defined inflation
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, additive_inflation_cov = Σ, s = 1.0)
```
4 changes: 2 additions & 2 deletions src/EnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function FailureHandler(process::Inversion, method::SampleSuccGauss)
u[:, successful_ens] =
eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
end
return u
end
Expand Down Expand Up @@ -123,7 +123,7 @@ function update_ensemble!(

# Scale noise using Δt
scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end]
noise = rand(ekp.rng, MvNormal(zeros(N_obs), scaled_obs_noise_cov), ekp.N_ens)
noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), ekp.N_ens)

# Add obs_mean (N_obs) to each column of noise (N_obs × N_ens) if
# G is deterministic
Expand Down
51 changes: 32 additions & 19 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,16 @@ get_error(ekp::EnsembleKalmanProcess) = ekp.err

"""
sample_empirical_gaussian(
rng::AbstractRNG,
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
) where {FT <: Real, IT <: Int}

Returns `n` samples from an empirical Gaussian based on point estimates `u`, adding inflation
if the covariance is singular.
Returns `n` samples from an empirical Gaussian based on point estimates `u`, adding inflation if the covariance is singular.
"""
function sample_empirical_gaussian(
rng::AbstractRNG,
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
Expand All @@ -525,9 +526,18 @@ function sample_empirical_gaussian(
cov_u_new = cov_u_new + inflation * I
end
mean_u_new = mean(u, dims = 2)
return rand(MvNormal(mean_u_new[:], cov_u_new), n)
return mean_u_new .+ sqrt(cov_u_new) * rand(rng, MvNormal(zeros(length(mean_u_new[:])), I), n)
end

function sample_empirical_gaussian(
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
) where {FT <: Real, IT <: Int}
return sample_empirical_gaussian(Random.GLOBAL_RNG, u, n, inflation = inflation)
end


"""
split_indices_by_success(g::AbstractMatrix{FT}) where {FT <: Real}

Expand Down Expand Up @@ -588,33 +598,35 @@ end

"""
additive_inflation!(
ekp::EnsembleKalmanProcess;
use_prior_cov::Bool = false,
ekp::EnsembleKalmanProcess
inflation_cov::AM;
s::FT = 1.0,
) where {FT <: Real}
Applies additive Gaussian noise to particles. Noise is drawn from normal distribution with 0 mean
and scaled parameter covariance. If use_prior_cov=false (default), scales parameter covariance matrix from
current ekp iteration. Otherwise, scales parameter covariance of initial ensemble.
and scaled parameter covariance. The original parameter covariance is a provided matrix, assumed positive semi-definite.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for time step in additive perturbation.
- use_prior_cov :: Bool specifying whether to use prior covariance estimate for additive inflation.
If false (default), parameter covariance from the current iteration is used.
- inflation_cov :: AbstractMatrix provide a N_par x N_par matrix to use.
"""
function additive_inflation!(ekp::EnsembleKalmanProcess; use_prior_cov::Bool = false, s::FT = 1.0) where {FT <: Real}
function additive_inflation!(
ekp::EnsembleKalmanProcess,
inflation_cov::MorUS;
s::FT = 1.0,
) where {FT <: Real, MorUS <: Union{AbstractMatrix, UniformScaling}}

scaled_Δt = s * ekp.Δt[end]

if scaled_Δt >= 1.0
error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
end

Σ = use_prior_cov ? get_u_cov_prior(ekp) : get_u_cov_final(ekp)

u = get_u_final(ekp)

Σ_sqrt = sqrt(scaled_Δt / (1 - scaled_Δt) .* inflation_cov)

# add multivariate noise with 0 mean and scaled covariance
noise_multivariate = MvNormal((scaled_Δt / (1 - scaled_Δt)) .* Σ)
u_updated = u + rand(noise_multivariate, size(u, 2))
u_updated = u .+ Σ_sqrt * rand(ekp.rng, MvNormal(zeros(size(u, 1)), I), size(u, 2))
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)
end

Expand All @@ -627,7 +639,7 @@ end
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
use_prior_cov::Bool = false,
additive_inflation_cov::MorUS = get_u_cov_prior(ekp),
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}
Expand All @@ -637,7 +649,7 @@ Inputs:
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- multiplicative_inflation :: Flag indicating whether to use multiplicative inflation.
- additive_inflation :: Flag indicating whether to use additive inflation.
- use_prior_cov :: Bool specifying whether to use prior covariance estimate for additive inflation.
- additive_inflation_cov :: specifying an additive inflation matrix (default is the prior covariance) assumed positive semi-definite
If false (default), parameter covariance from the current iteration is used.
- s :: Scaling factor for time step in inflation step.
- ekp_kwargs :: Keyword arguments to pass to standard ekp update_ensemble!.
Expand All @@ -647,11 +659,11 @@ function update_ensemble!(
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
use_prior_cov::Bool = false,
additive_inflation_cov::MorUS = get_u_cov_prior(ekp),
s::FT = 0.0,
Δt_new::NFT = nothing,
ekp_kwargs...,
) where {FT, NFT <: Union{Nothing, AbstractFloat}}
) where {FT, NFT <: Union{Nothing, AbstractFloat}, MorUS <: Union{AbstractMatrix, UniformScaling}}

#catch works when g non-square
if !(size(g)[2] == ekp.N_ens)
Expand All @@ -668,8 +680,9 @@ function update_ensemble!(
accelerate!(ekp, u)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
additive_inflation ? additive_inflation!(ekp, additive_inflation_cov, s = s) : nothing
end

else
return terminate # true if scheduler has not stepped
end
Expand Down
4 changes: 2 additions & 2 deletions src/EnsembleKalmanSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ function eks_update(
# Default: Δt = 1 / (norm(D) + eps(FT))
Δt = ekp.Δt[end]

noise = MvNormal(u_cov)
noise = MvNormal(zeros(size(u_cov, 1)), I)

implicit =
(1 * Matrix(I, size(u)[2], size(u)[2]) + Δt * (ekp.process.prior_cov' \ u_cov')') \
(u' .- Δt * (u' .- u_mean) * D .+ Δt * u_cov * (ekp.process.prior_cov \ ekp.process.prior_mean))

u = implicit' + sqrt(2 * Δt) * rand(ekp.rng, noise, ekp.N_ens)'
u = implicit' + sqrt(2 * Δt) * (sqrt(u_cov) * rand(ekp.rng, noise, ekp.N_ens))'

return u
end
Expand Down
2 changes: 1 addition & 1 deletion src/EnsembleTransformKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function FailureHandler(process::TransformInversion, method::SampleSuccGauss)
n_failed = length(failed_ens)
u[:, successful_ens] = etki_update(ekp, u[:, successful_ens], g[:, successful_ens], y, obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
end
return u
end
Expand Down
8 changes: 5 additions & 3 deletions src/LearningRateSchedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,11 @@ function posdef_correct(mat::AbstractMatrix; tol::Real = 1e8 * eps())
out = mat
end

nugget = abs(minimum(eigvals(out)))
for i in 1:size(out, 1)
out[i, i] += nugget + tol #add to diag
if !isposdef(out)
nugget = abs(minimum(eigvals(out)))
for i in 1:size(out, 1)
out[i, i] += nugget + tol # add to diag
end
end
return out
end
Expand Down
4 changes: 2 additions & 2 deletions src/SparseEnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function FailureHandler(process::SparseInversion, method::SampleSuccGauss)
u[:, successful_ens] =
sparse_eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
end
return u
end
Expand Down Expand Up @@ -206,7 +206,7 @@ function update_ensemble!(

# Scale noise using Δt
scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end]
noise = rand(ekp.rng, MvNormal(zeros(N_obs), scaled_obs_noise_cov), ekp.N_ens)
noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), ekp.N_ens)

# Add obs_mean (N_obs) to each column of noise (N_obs × N_ens) if
# G is deterministic
Expand Down
26 changes: 23 additions & 3 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ end


@testset "LearningRateSchedulers" begin

# Utility
X = [2 1; 1.1 2] # correct with symmetrisation
@test isposdef(posdef_correct(X))
@test posdef_correct(X) ≈ 0.5 * (X + permutedims(X, (2, 1))) atol = 1e-8
Y = [0 1; -1 0]
tol = 1e-8
@test isposdef(posdef_correct(Y, tol = tol)) # symmetrize and add to diagonal
@test posdef_correct(Y, tol = tol) ≈ tol * I(2) atol = 1e-8



# Default
Δt = 3
dlrs1 = EKP.DefaultScheduler()
Expand Down Expand Up @@ -944,15 +956,23 @@ end
end
@test_logs (:warn, r"More than 50% of runs produced NaNs") match_mode = :any split_indices_by_success(g)


rng = Random.MersenneTwister(rng_seed)

u = rand(10, 4)
@test_logs (:warn, r"Sample covariance matrix over ensemble is singular.") match_mode = :any sample_empirical_gaussian(
u,
2,
)
@test_throws PosDefException sample_empirical_gaussian(u, 2, inflation = 0.0)

# Initial ensemble construction
rng = Random.MersenneTwister(rng_seed)
u2 = rand(rng, 5, 20)
@test all(
isapprox.(
sample_empirical_gaussian(copy(rng), u2, 2),
sample_empirical_gaussian(copy(rng), u2, 2, inflation = 0.0);
atol = 1e-8,
),
)

### sanity check on rng:
d = Parameterized(Normal(0, 1))
Expand Down
9 changes: 7 additions & 2 deletions test/Inflation/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,25 @@ initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)
eki_mult_inflation = deepcopy(ekiobj)
eki_add_inflation = deepcopy(ekiobj)
eki_add_inflation_prior = deepcopy(ekiobj)
eki_add_inflation_I = deepcopy(ekiobj)

# multiplicative inflation after standard update
EKP.multiplicative_inflation!(eki_mult_inflation)
# additive inflation after standard update
EKP.additive_inflation!(eki_add_inflation)
EKP.additive_inflation!(eki_add_inflation, get_u_cov_final(eki_add_inflation))
# additive inflation (scaling prior cov) after standard update
EKP.additive_inflation!(eki_add_inflation_prior; use_prior_cov = true)
EKP.additive_inflation!(eki_add_inflation_prior, get_u_cov_prior(eki_add_inflation_prior))
# additive inflation (scaling prior cov) after standard update
EKP.additive_inflation!(eki_add_inflation_I, I)

# ensure multiplicative inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_mult_inflation) atol = 0.2
# ensure additive inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_add_inflation) atol = 0.2
# ensure additive inflation (scaling prior cov) approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_add_inflation_prior) atol = 0.2
# ensure additive inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_add_inflation_I) atol = 0.2

# ensure inflation expands ensemble variance as expected
expected_var_gain = 1 / (1 - Δt)
Expand Down
7 changes: 3 additions & 4 deletions test/SparseInversion/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,14 @@ include("../EnsembleKalmanProcess/inverse_problem.jl")

## Repeat first test with several schedulers
y_obs, G, Γy = nl_inv_problems[1]

T_end = 3
schedulers = [
DefaultScheduler(0.1),
MutableScheduler(0.1),
DataMisfitController(terminate_at = T_end),
DataMisfitController(on_terminate = "continue"),
DataMisfitController(on_terminate = "continue_fixed"),
# DataMisfitController(terminate_at = T_end), # This test can be unstable
]
N_iters = [10, 10, 50, 50, 50]
N_iters = [10, 10]# ..., 20]

final_ensembles = []
init_means = []
Expand Down