Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NICE sample-error correction #367

Merged
merged 16 commits into from
May 9, 2024
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GaussianRandomFields = "e4b2fa32-6e09-5554-b718-106ed5adafe9"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Expand All @@ -24,8 +25,8 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Convex = "0.15"
Distributions = "~0.24.14, ^0.25"
DocStringExtensions = "^0.8, 0.9"
LinearAlgebra = "1"
GaussianRandomFields = "2"
LinearAlgebra = "1"
MathOptInterface = "1"
Optim = "1.7"
QuadGK = "^2.4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/API/Localizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ RBF
BernoulliDropout
SEC
SECFisher
SECNice
Delta
NoLocalization
```
Binary file added docs/src/assets/sec_comparison_lorenz96.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 14 additions & 2 deletions docs/src/localization.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,21 @@ y = 10.0 * rand(d)
Γ = 1.0 * I

# Construct EKP object with localization. Some examples of localization methods:
locs = [Delta(), RBF(1.0), RBF(0.1), BernoulliDropout(0.1), SEC(10.0), SECFisher(), SEC(1.0, 0.1)]
locs = [Delta(), RBF(1.0), RBF(0.1), BernoulliDropout(0.1), SEC(10.0), SECFisher(), SEC(1.0, 0.1), SECNice()]
for loc in locs
ekiobj = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); localization_method = loc)
end
```
```
!!! note
Currently, localization and SEC are implemented only for the `Inversion()` process; we are working on extensions to `TransformInversion()`.

## The following example is found in `examples/Localization/localization_example_lorenz96.jl`
This example was originally taken from [Tong and Morzfeld (2022)](https://doi.org/10.48550/arXiv.2201.10821). Here, a single-scale Lorenz 96 system with a state of dimension 200 is configured to be in a chaotic parameter regime, and integrated forward with timestep ``\Delta t`` until time ``T``. The goal is to perform ensemble inversion for the state at time ``T-20\Delta t``, given a noisy observation of the state at time ``T``.

To perform this state estimation we use ensemble inversion with ensembles of size 20. This problem is severely ill-posed, and we compensate for this by applying sample-error correction (SEC) methods to our state. Note that the SEC methods do not assume any spatial structure in the state (unlike traditional state localization), and so are also well suited to other types of inversion over parameter space.
odunbar marked this conversation as resolved.
Show resolved Hide resolved

![SEC_compared](assets/sec_comparison_lorenz96.png)

!!! note "Our recommendation"
Based on these results, our recommendation is to use the `SECNice()` approach for the `Inversion()` process. Not only does it perform well, but it also requires no tuning parameters, unlike, for example, `SEC()`.

1 change: 1 addition & 0 deletions examples/Localization/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
60 changes: 48 additions & 12 deletions examples/Localization/localization_example_lorenz96.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,29 @@ for i in 1:N_iter
EKP.update_ensemble!(ekiobj_vanilla, g_ens_vanilla, deterministic_forward_map = true)
end
nonlocalized_error = get_error(ekiobj_vanilla)[end]

# Test Bernoulli
ekiobj_bernoulli = EKP.EnsembleKalmanProcess(
@info "EKI - complete"
# Test Inflated
ekiobj_inflated = EKP.EnsembleKalmanProcess(
initial_ensemble,
y,
Γ,
Inversion();
rng = rng,
localization_method = BernoulliDropout(0.98),
# localization_method = BernoulliDropout(0.98),
)

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_bernoulli))
EKP.update_ensemble!(ekiobj_bernoulli, g_ens, deterministic_forward_map = true)
g_ens = G(get_ϕ_final(prior, ekiobj_inflated))
EKP.update_ensemble!(
ekiobj_inflated,
g_ens,
deterministic_forward_map = true,
additive_inflation = true,
additive_inflation_cov = 0.1 * I,
)
end
#@info "EKI (Benoulli) - complete"
@info "EKI (inflated) - complete"

# Test SEC
ekiobj_sec = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SEC(1.0))
Expand All @@ -106,6 +114,7 @@ for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec))
EKP.update_ensemble!(ekiobj_sec, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC) - complete"

# Test SEC with cutoff
ekiobj_sec_cutoff =
Expand All @@ -115,6 +124,7 @@ for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_cutoff))
EKP.update_ensemble!(ekiobj_sec_cutoff, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC cut-off) - complete"

# Test SECFisher
ekiobj_sec_fisher =
Expand All @@ -124,18 +134,44 @@ for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_fisher))
EKP.update_ensemble!(ekiobj_sec_fisher, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC Fisher) - complete"

# Test SECNice
ekiobj_sec_nice =
EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SECNice())

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_nice))
EKP.update_ensemble!(ekiobj_sec_nice, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC Nice) - complete"


u_final = get_u_final(ekiobj_sec)
g_final = get_g_final(ekiobj_sec)
cov_est = cov([u_final; g_final], [u_final; g_final], dims = 2, corrected = false)
cov_localized = ekiobj_sec.localizer.localize(cov_est)

fig = plot(get_error(ekiobj_vanilla), label = "No localization")
plot!(get_error(ekiobj_bernoulli), label = "Bernoulli")
plot!(get_error(ekiobj_sec), label = "SEC (Lee, 2021)")
plot!(get_error(ekiobj_sec_fisher), label = "SECFisher (Flowerdew, 2015)")
plot!(get_error(ekiobj_sec_cutoff), label = "SEC with cutoff")
fig = plot(
get_error(ekiobj_vanilla),
label = "No localization",
c = :Accent_6,
lw = 6,
size = (800 * 1.618, 800),
xtickfont = 16,
ytickfont = 16,
guidefont = 16,
legendfont = 16,
bottom_margin = 5Plots.mm,
left_margin = 5Plots.mm,
)
plot!(get_error(ekiobj_inflated), label = "Inflation only", lw = 6)
plot!(get_error(ekiobj_sec), label = "SEC (Lee, 2021)", lw = 6)
plot!(get_error(ekiobj_sec_fisher), label = "SECFisher (Flowerdew, 2015)", lw = 6)
plot!(get_error(ekiobj_sec_cutoff), label = "SEC with cutoff", lw = 6)
plot!(get_error(ekiobj_sec_nice), label = "SEC NICE", lw = 6)


xlabel!("Iterations")
ylabel!("Error")
savefig(fig, "result.png")
savefig(fig, "sec_comparison_lorenz96.png")
1 change: 1 addition & 0 deletions examples/Sinusoid/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
odunbar marked this conversation as resolved.
Show resolved Hide resolved
135 changes: 132 additions & 3 deletions src/Localizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
using Distributions
using LinearAlgebra
using DocStringExtensions
using Interpolations

export NoLocalization, Delta, RBF, BernoulliDropout, SEC, SECFisher
export NoLocalization, Delta, RBF, BernoulliDropout, SEC, SECFisher, SECNice
export LocalizationMethod, Localizer

export approximate_corr_std
abstract type LocalizationMethod end

"Idempotent localization method."
Expand Down Expand Up @@ -89,6 +90,34 @@
"""
struct SECFisher <: LocalizationMethod end


"""
    SECNice{FT <: Real, AV <: AbstractVector} <: LocalizationMethod

Sample error correction following Morzfeld, Vishny et al. (2024).
odunbar marked this conversation as resolved.
Show resolved Hide resolved
Correlations are shrunk by a factor determined by correlation and ensemble size.
odunbar marked this conversation as resolved.
Show resolved Hide resolved
The factors are automatically determined by a discrepancy principle.
Thus no algorithm parameters are required, though some tuning of the discrepancy principle tolerances is made available.

# Fields

$(TYPEDFIELDS)

"""
struct SECNice{FT <: Real, AV <: AbstractVector} <: LocalizationMethod
    "number of samples to approximate the std of correlation distribution (default 1000)"
    n_samples::Int
    "scaling for discrepancy principle for ug correlation (default 1.0)"
    tol_ug::FT
    "scaling for discrepancy principle for gg correlation (default 1.0)"
    tol_gg::FT
    "A vector that houses an interpolation object (correlation -> std of correlation); if empty, it is filled when a `Localizer` is constructed from this method"
    std_of_corr::AV
end
# Convenience constructors for `SECNice`.
# Defaults: 1000 samples and unit discrepancy-principle tolerances.
SECNice() = SECNice(1000, 1.0, 1.0)
SECNice(n_samples) = SECNice(n_samples, 1.0, 1.0)
# Both tolerances share the type parameter `FT`, so they are promoted to a common type
# (e.g. `SECNice(1000, 1, 0.5)` is accepted). The interpolation container always starts
# empty and untyped, since the type of the interpolation object is not known here.
SECNice(n_samples, tol_ug, tol_gg) = SECNice(n_samples, promote(tol_ug, tol_gg)..., Any[])

"""
Localizer{LM <: LocalizationMethod, T}

Expand Down Expand Up @@ -212,7 +241,8 @@
V = Diagonal(v)
V_inv = inv(V)
R = V_inv * cov * V_inv

bd_tol = 1e8 * eps()
clamp!(R, -1 + bd_tol, 1 - bd_tol)
R_sec = zeros(size(R))
for i in 1:size(R)[1]
for j in 1:i
Expand All @@ -239,6 +269,99 @@
return Localizer{SECFisher, T}((cov) -> sec_fisher(cov, J))
end

"""
    approximate_corr_std(r, N_ens, n_samples; rng = nothing)

Approximate the standard deviation of a sample correlation coefficient whose value is `r`,
for an ensemble of size `N_ens`, by Monte Carlo sampling with `n_samples` draws.

The samples are drawn in Fisher-transformed space from
`Normal(atanh(r), 1 / sqrt(N_ens - 3))`, mapped back through `tanh`, and their standard
deviation is returned. The input `r` is treated as the mean value, i.e. it is assumed that
`atanh(E(r)) = E(atanh(r))`.

# Arguments
- `r`: correlation value in `[-1, 1]` (values outside this interval throw a `DomainError`).
- `N_ens`: ensemble size; the sampling standard deviation `1 / sqrt(N_ens - 3)` is only
  finite and real for `N_ens > 3`.
- `n_samples`: number of random draws.

# Keywords
- `rng`: optional random number generator for reproducible draws; when `nothing`
  (default) the global generator is used.
"""
function approximate_corr_std(r, N_ens, n_samples; rng = nothing)
    # distribution of the Fisher-transformed correlation, centred on atanh(r)
    fisher_dist = Normal(atanh(r), 1 / sqrt(N_ens - 3))

    # sample in Fisher (atanh) space
    z_samples = isnothing(rng) ? rand(fisher_dist, n_samples) : rand(rng, fisher_dist, n_samples)

    # map back through tanh to get the std of r from the samples
    return std(tanh.(z_samples))
end


"""
Perform sampling error correction as per Morzfeld, Vishny et al. (2024).
odunbar marked this conversation as resolved.
Show resolved Hide resolved
The input is assumed to be a covariance matrix, hence square.
"""
function sec_nice(cov, std_of_corr, δ_ug, δ_gg, N_ens, p, d)
    # Sampling error correction of the square matrix `cov`, treated as a covariance whose
    # leading `p` rows/columns and trailing `d` rows/columns form the blocks named `u` and
    # `g` below. Only the `ug` (and its transpose) and `gg` correlation blocks are
    # corrected; the leading `p x p` block is left as it is.
    #
    # Arguments:
    # - `cov`: (p + d) x (p + d) covariance matrix.
    # - `std_of_corr`: callable mapping a correlation value to the standard deviation of
    #   its estimate; it is broadcast over the correlation blocks.
    # - `δ_ug`, `δ_gg`: scalings of the discrepancy tolerance for the `ug` and `gg` blocks.
    # - `N_ens`: ensemble size, used here only to warn about small ensembles.
    # - `p`, `d`: sizes of the leading and trailing blocks.
    #
    # Returns the corrected covariance matrix, rebuilt from the corrected correlations.
    if N_ens < 6
        @warn "significant localization approximation error may occur for ensemble size below 6. Here, ensemble size = $N_ens"
    end
    bd_tol = 1e8 * eps()

    v = sqrt.(diag(cov))
    V = Diagonal(v) # stds
    V_inv = inv(V)
    # full correlation matrix, kept strictly inside (-1, 1)
    corr = clamp.(V_inv * cov * V_inv, -1 + bd_tol, 1 - bd_tol)

    # parameter sweep over the exponents
    max_exponent = 2 * 5 # must be even
    interp_steps = 10

    ug_idx = (1:p, (p + 1):(p + d))
    ugt_idx = ((p + 1):(p + d), 1:p) # not swept: filled from the corrected ug block below
    gg_idx = ((p + 1):(p + d), (p + 1):(p + d))

    corr_updated = copy(corr)
    for (idx_set, δ) in zip((ug_idx, gg_idx), (δ_ug, δ_gg))
        corr_tmp = corr[idx_set...]

        # variability of the entries of the correlation block, and the resulting
        # discrepancy tolerance for this block
        std_corrs = std_of_corr.(corr_tmp)
        std_tol = sqrt(sum(std_corrs .^ 2))
        tol = δ * std_tol

        # find the first even exponent whose correction exceeds the noise tolerance in norm
        # (even exponents give a PSD correction; abs is not needed as the exponent is even)
        α_min_exceeded = max_exponent
        for α in 2:2:max_exponent
            corr_psd = corr_tmp .^ (α + 1)
            if norm(corr_psd - corr_tmp) > tol
                α_min_exceeded = α
                break
            end
        end
        corr_psd = corr_tmp .^ α_min_exceeded
        corr_psd_prev = corr_tmp .^ (α_min_exceeded - 2) # previous PSD correction

        # interpolate from the exceeding correction back towards the previous one and
        # keep the first (i.e. strongest) correction that is within the tolerance
        for w in LinRange(1.0, 0.0, interp_steps)
            corr_interp = ((1 - w) * corr_psd_prev + w * corr_psd) .* corr_tmp
            if norm(corr_interp - corr_tmp) < tol
                corr_updated[idx_set...] = corr_interp # update the correlation matrix block
                break
            end
        end
    end

    # finally correct the ug' block by symmetry
    corr_updated[ugt_idx...] = corr_updated[ug_idx...]'

    return V * corr_updated * V # rebuild the cov matrix
end


"Sampling error correction (Morzfeld, Vishny et al., 2024) constructor"
odunbar marked this conversation as resolved.
Show resolved Hide resolved
function Localizer(localization::SECNice, p::IT, d::IT, J::IT, T = Float64) where {IT <: Int}
    # Build the `SECNice` localizer for block sizes `p` and `d` and ensemble size `J`.
    # If `localization` holds no interpolation yet, one is computed here (using the global
    # random number generator through `approximate_corr_std`) and stored in
    # `localization.std_of_corr`, so that it is reused by later constructions.
    # NOTE(review): the stored interpolation depends on the `J` of the construction that
    # created it; reusing the same `SECNice` instance with a different `J` reuses the
    # stored interpolation unchanged - confirm this is intended.
    if isempty(localization.std_of_corr) # i.e. the user hasn't provided an interpolation
        # 1001 equispaced correlation values on [-1, 1] (spacing 0.002);
        # an odd number of points so that 0 is included
        n_grid = 1001
        grid = LinRange(-1, 1, n_grid)
        std_grid = approximate_corr_std.(grid, J, localization.n_samples)
        push!(localization.std_of_corr, linear_interpolation(grid, std_grid)) # pw-linear interpolation
    end

    # look the interpolation up once, rather than indexing the untyped container on
    # every call of the localizer
    std_of_corr = localization.std_of_corr[1]

    return Localizer{SECNice, T}(
        (cov) -> sec_nice(cov, std_of_corr, localization.tol_ug, localization.tol_gg, J, p, d),
    )
end





# utilities
"""
get_localizer(loc::Localizer)
Return localizer type.
Expand All @@ -247,4 +370,10 @@
return T1
end







end # module
2 changes: 1 addition & 1 deletion test/Localizers/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const EKP = EnsembleKalmanProcesses
nonlocalized_error = get_error(ekiobj_vanilla)[end]

# Test different localizers
loc_methods = [Delta(), RBF(1.0), RBF(0.1), BernoulliDropout(0.1), SEC(10.0), SECFisher(), SEC(1.0, 0.1)]
loc_methods = [Delta(), RBF(1.0), RBF(0.1), BernoulliDropout(0.1), SEC(10.0), SECFisher(), SEC(1.0, 0.1), SECNice()]

for loc_method in loc_methods
ekiobj =
Expand Down
Loading