Skip to content

Commit

Permalink
NICE sample-error correction (#367)
Browse files Browse the repository at this point in the history
* added Nice implementation to EKP

index fix

bugfix

formatting

plots comparing to inflation-only

docstrings

docstrings in API docs

improve docs example

code for docs example

basic SECNice test

localizers

format

link fix

* added accelerated version based on an interpolant
;

* add Interpolations

* typos

* docs/src/localization.md

* final typo

* resolved some review comments

* notation

* change constructors;

* [int] -> int

* estimate for low N_ens

* SECNice branch for N<6 and test

* format

* only run small test for SECNice()

* format

* changes for review
  • Loading branch information
odunbar authored May 9, 2024
1 parent 1b7489e commit fb45f75
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 28 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GaussianRandomFields = "e4b2fa32-6e09-5554-b718-106ed5adafe9"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Expand All @@ -24,8 +25,8 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Convex = "0.15"
Distributions = "~0.24.14, ^0.25"
DocStringExtensions = "^0.8, 0.9"
LinearAlgebra = "1"
GaussianRandomFields = "2"
LinearAlgebra = "1"
MathOptInterface = "1"
Optim = "1.7"
QuadGK = "^2.4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/API/Localizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ RBF
BernoulliDropout
SEC
SECFisher
SECNice
Delta
NoLocalization
```
Binary file added docs/src/assets/sec_comparison_lorenz96.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 14 additions & 2 deletions docs/src/localization.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,21 @@ y = 10.0 * rand(d)
Γ = 1.0 * I

# Construct EKP object with localization. Some examples of localization methods:
locs = [Delta(), RBF(1.0), RBF(0.1), BernoulliDropout(0.1), SEC(10.0), SECFisher(), SEC(1.0, 0.1)]
locs = [Delta(), RBF(1.0), RBF(0.1), BernoulliDropout(0.1), SEC(10.0), SECFisher(), SEC(1.0, 0.1), SECNice()]
for loc in locs
ekiobj = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); localization_method = loc)
end
```
```
!!! note
Currently Localization and SEC are implemented only for the `Inversion()` process; we are working on extensions to `TransformInversion()`.

## The following example is found in `examples/Localization/localization_example_lorenz96.jl`
This example was originally taken from [Tong and Morzfeld (2022)](https://doi.org/10.48550/arXiv.2201.10821). Here, a single-scale Lorenz 96 system state of dimension 200 is configured to be in a chaotic parameter regime, and integrated forward with timestep ``\Delta t`` until time ``T``. The goal is to perform ensemble inversion for the state at time ``T-20\Delta t``, given a noisy observation of the state at time ``T``.

To perform this state estimation we use ensemble inversion with ensembles of size 20. This problem is severely rank-deficient, and we make up for this by applying sampling-error correction methods to the ensemble covariance matrices. Note that the SEC methods do not assume any spatial structure in the state (unlike traditional state localization), and so are well suited for other types of inversions over parameter space.

![SEC_compared](assets/sec_comparison_lorenz96.png)

!!! note "Our recommendation"
Based on these results, our recommendation is to use the `SECNice()` approach for the `Inversion()` process. Not only does it perform well, but it also requires no tuning parameters, unlike, for example, `SEC()`.

1 change: 1 addition & 0 deletions examples/Localization/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
60 changes: 48 additions & 12 deletions examples/Localization/localization_example_lorenz96.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,29 @@ for i in 1:N_iter
EKP.update_ensemble!(ekiobj_vanilla, g_ens_vanilla, deterministic_forward_map = true)
end
nonlocalized_error = get_error(ekiobj_vanilla)[end]

# Test Bernoulli
ekiobj_bernoulli = EKP.EnsembleKalmanProcess(
@info "EKI - complete"
# Test Inflated
ekiobj_inflated = EKP.EnsembleKalmanProcess(
initial_ensemble,
y,
Γ,
Inversion();
rng = rng,
localization_method = BernoulliDropout(0.98),
# localization_method = BernoulliDropout(0.98),
)

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_bernoulli))
EKP.update_ensemble!(ekiobj_bernoulli, g_ens, deterministic_forward_map = true)
g_ens = G(get_ϕ_final(prior, ekiobj_inflated))
EKP.update_ensemble!(
ekiobj_inflated,
g_ens,
deterministic_forward_map = true,
additive_inflation = true,
additive_inflation_cov = 0.1 * I,
)
end
#@info "EKI (Benoulli) - complete"
@info "EKI (inflated) - complete"

# Test SEC
ekiobj_sec = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SEC(1.0))
Expand All @@ -106,6 +114,7 @@ for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec))
EKP.update_ensemble!(ekiobj_sec, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC) - complete"

# Test SEC with cutoff
ekiobj_sec_cutoff =
Expand All @@ -115,6 +124,7 @@ for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_cutoff))
EKP.update_ensemble!(ekiobj_sec_cutoff, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC cut-off) - complete"

# Test SECFisher
ekiobj_sec_fisher =
Expand All @@ -124,18 +134,44 @@ for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_fisher))
EKP.update_ensemble!(ekiobj_sec_fisher, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC Fisher) - complete"

# Test SECNice
ekiobj_sec_nice =
EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SECNice())

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_nice))
EKP.update_ensemble!(ekiobj_sec_nice, g_ens, deterministic_forward_map = true)
end
@info "EKI (SEC Nice) - complete"


u_final = get_u_final(ekiobj_sec)
g_final = get_g_final(ekiobj_sec)
cov_est = cov([u_final; g_final], [u_final; g_final], dims = 2, corrected = false)
cov_localized = ekiobj_sec.localizer.localize(cov_est)

fig = plot(get_error(ekiobj_vanilla), label = "No localization")
plot!(get_error(ekiobj_bernoulli), label = "Bernoulli")
plot!(get_error(ekiobj_sec), label = "SEC (Lee, 2021)")
plot!(get_error(ekiobj_sec_fisher), label = "SECFisher (Flowerdew, 2015)")
plot!(get_error(ekiobj_sec_cutoff), label = "SEC with cutoff")
fig = plot(
get_error(ekiobj_vanilla),
label = "No localization",
c = :Accent_6,
lw = 6,
size = (800 * 1.618, 800),
xtickfont = 16,
ytickfont = 16,
guidefont = 16,
legendfont = 16,
bottom_margin = 5Plots.mm,
left_margin = 5Plots.mm,
)
plot!(get_error(ekiobj_inflated), label = "Inflation only", lw = 6)
plot!(get_error(ekiobj_sec), label = "SEC (Lee, 2021)", lw = 6)
plot!(get_error(ekiobj_sec_fisher), label = "SECFisher (Flowerdew, 2015)", lw = 6)
plot!(get_error(ekiobj_sec_cutoff), label = "SEC with cutoff", lw = 6)
plot!(get_error(ekiobj_sec_nice), label = "SEC NICE", lw = 6)


xlabel!("Iterations")
ylabel!("Error")
savefig(fig, "result.png")
savefig(fig, "sec_comparison_lorenz96.png")
153 changes: 149 additions & 4 deletions src/Localizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ module Localizers
using Distributions
using LinearAlgebra
using DocStringExtensions
using Interpolations

export NoLocalization, Delta, RBF, BernoulliDropout, SEC, SECFisher
export NoLocalization, Delta, RBF, BernoulliDropout, SEC, SECFisher, SECNice
export LocalizationMethod, Localizer

export approximate_corr_std
abstract type LocalizationMethod end

"Idempotent localization method."
Expand Down Expand Up @@ -76,7 +77,7 @@ SEC(α) = SEC{eltype(α)}(α, eltype(α)(0))
Sampling error correction for EKI, as per Lee (2021), but using
the method from Flowerdew (2015) based on the Fisher transformation.
Correlations are shrinked by a factor determined by the sample
Correlations are shrunk by a factor determined by the sample
correlation and the ensemble size.
Flowerdew, J. (2015). Towards a theory of optimal localisation.
Expand All @@ -89,6 +90,34 @@ http://arxiv.org/abs/2105.11341
"""
struct SECFisher <: LocalizationMethod end


"""
SECNice{FT <: Real} <: LocalizationMethod
Sampling error correction as of Vishny, Morzfeld, et al. (2024), [DOI](https://doi.org/10.22541/essoar.171501094.44068137/v1).
Correlations are shrunk by a factor determined by correlation and ensemble size.
The factors are automatically determined by a discrepancy principle.
Thus no algorithm parameters are required, though some tuning of the discrepancy principle tolerances are made available.
# Fields
$(TYPEDFIELDS)
"""
struct SECNice{FT <: Real, AV <: AbstractVector} <: LocalizationMethod
"number of samples to approximate the std of correlation distribution (default 1000)"
n_samples::Int
"scaling for discrepancy principle for ug correlation (default 1.0)"
δ_ug::FT
"scaling for discrepancy principle for gg correlation (default 1.0)"
δ_gg::FT
"A vector that will house a Interpolation object on first call to the localizer"
std_of_corr::AV
end
SECNice() = SECNice(1000, 1.0, 1.0)
SECNice(δ_ug, δ_gg) = SECNice(1000, δ_ug, δ_gg)
SECNice(n_samples, δ_ug, δ_gg) = SECNice(n_samples, δ_ug, δ_gg, []) # always start with empty

"""
Localizer{LM <: LocalizationMethod, T}
Expand Down Expand Up @@ -212,7 +241,8 @@ function sec_fisher(cov, N_ens)
V = Diagonal(v)
V_inv = inv(V)
R = V_inv * cov * V_inv

bd_tol = 1e8 * eps()
clamp!(R, -1 + bd_tol, 1 - bd_tol)
R_sec = zeros(size(R))
for i in 1:size(R)[1]
for j in 1:i
Expand All @@ -239,6 +269,115 @@ function Localizer(localization::SECFisher, p::IT, d::IT, J::IT, T = Float64) wh
return Localizer{SECFisher, T}((cov) -> sec_fisher(cov, J))
end

"""
For `N_ens >= 6`: The sampling distribution of a correlation coefficient for Gaussian random variables is, under the Fisher transformation, approximately Gaussian. To estimate the standard deviation in the sampling distribution of the correlation coefficient, we draw samples from a Gaussian, apply the inverse Fisher transformation to them, and estimate an empirical standard deviation from the transformed samples.
For `N_ens < 6`: Approximate the standard deviation of correlation coefficient empirically by sampling between two correlated Gaussians of known coefficient.
"""
function approximate_corr_std(r, N_ens, n_samples)

if N_ens >= 6 # apply Fisher Transform
# ρ = arctanh(r) from Fisher
# assume r input is the mean value, i.e. assume arctanh(E(r)) = E(arctanh(r))

ρ = r # approx solution is the identity
#sample in ρ space
ρ_samples = rand(Normal(0.5 * log((1 + ρ) / (1 - ρ)), 1 / sqrt(N_ens - 3)), n_samples)

# map back through Fisher to get std of r from samples tanh(ρ)
return std(tanh.(ρ_samples))
else # transformation not appropriate for N < 6
# Generate sample pairs with a correlation coefficient r
samples_1 = rand(Normal(0, 1), N_ens, n_samples)
samples_2 = rand(Normal(0, 1), N_ens, n_samples)
samples_corr_with_1 = r * samples_1 + sqrt(1 - r^2) * samples_2 # will have correlation r with samples_1

corrs = zeros(n_samples)
for i in 1:n_samples
corrs[i] = cor(samples_1[:, i], samples_corr_with_1[:, i])
end
return std(corrs)
end

end


"""
Function that performs sampling error correction as per Vishny, Morzfeld, et al. (2024).
The input is assumed to be a covariance matrix, hence square.
"""
function sec_nice(cov, std_of_corr, δ_ug, δ_gg, N_ens, p, d)
bd_tol = 1e8 * eps()

v = sqrt.(diag(cov))
V = Diagonal(v) #stds
V_inv = inv(V)
corr = clamp.(V_inv * cov * V_inv, -1 + bd_tol, 1 - bd_tol) # full corr

# parameter sweep over the exponents
max_exponent = 2 * 5 # must be even
interp_steps = 10

ug_idx = [1:p, (p + 1):(p + d)]
ugt_idx = [(p + 1):(p + d), 1:p] # don't loop over this one
gg_idx = [(p + 1):(p + d), (p + 1):(p + d)]

corr_updated = copy(corr)
for (idx_set, δ) in zip([ug_idx, gg_idx], [δ_ug, δ_gg])

corr_tmp = corr[idx_set...]
# use find the variability in the corr coeff matrix entries
# std_corrs = approximate_corr_std.(corr_tmp, N_ens, n_samples) # !! slowest part of code -> could speed up by precomputing/using an interpolation
std_corrs = std_of_corr.(corr_tmp)
std_tol = sqrt(sum(std_corrs .^ 2))
γ_min_exceeded = max_exponent
for γ in 2:2:max_exponent # even exponents give a PSD correction
corr_psd = corr_tmp .^+ 1) # abs not needed as γ even
# find the first exponent that exceeds the noise tolerance in norm
if norm(corr_psd - corr_tmp) > δ * std_tol
γ_min_exceeded = γ
break
end
end
corr_update = corr_tmp .^ γ_min_exceeded
corr_update_prev = corr_tmp .^ (γ_min_exceeded - 2) # previous PSD correction

for α in LinRange(1.0, 0.0, interp_steps)
corr_interp = ((1 - α) * (corr_update_prev) + α * corr_update) .* corr_tmp
if norm(corr_interp - corr_tmp) < δ * std_tol
corr_updated[idx_set...] = corr_interp #update the correlation matrix block
break
end
end

end

# finally correct the ug'
corr_updated[ugt_idx...] = corr_updated[ug_idx...]'

return V * corr_updated * V # rebuild the cov matrix

end


"Sampling error correction of Vishny, Morzfeld, et al. (2024) constructor"
function Localizer(localization::SECNice, p::IT, d::IT, J::IT, T = Float64) where {IT <: Int}
if length(localization.std_of_corr) == 0 #i.e. if the user hasn't provided an interpolation
dr = 0.001
grid = LinRange(-1, 1, Int(1 / dr + 1))
std_grid = approximate_corr_std.(grid, J, localization.n_samples) # odd number to include 0
push!(localization.std_of_corr, linear_interpolation(grid, std_grid)) # pw-linear interpolation
end

return Localizer{SECNice, T}(
(cov) -> sec_nice(cov, localization.std_of_corr[1], localization.δ_ug, localization.δ_gg, J, p, d),
)
end





# utilities
"""
get_localizer(loc::Localizer)
Return localizer type.
Expand All @@ -247,4 +386,10 @@ function get_localizer(loc::Localizer{T1, T2}) where {T1, T2}
return T1
end







end # module
Loading

0 comments on commit fb45f75

Please sign in to comment.