Skip to content

Commit

Permalink
added terminate cancel for gnki
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Nov 12, 2024
1 parent d9f8d85 commit 24a5f79
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ end
# Get inverse problem
y_obs, G, Γy, A = inv_problem
if i_prob == 1
scheduler = DataMisfitController()
scheduler = DataMisfitController() # if terminated too early can miss out later tests
else
scheduler = DefaultScheduler(0.001)
end
Expand Down Expand Up @@ -1111,6 +1111,7 @@ end
# GNKI iterations
u_i_vec = Array{Float64, 2}[]
g_ens_vec = Array{Float64, 2}[]
terminated = nothing
for i in 1:N_iter
# Check SampleSuccGauss handler
params_i = get_ϕ_final(prior, gnkiobj)
Expand All @@ -1120,7 +1121,11 @@ end
if i in iters_with_failure
g_ens[:, 1] .= NaN
end
EKP.update_ensemble!(gnkiobj, g_ens)
terminated = EKP.update_ensemble!(gnkiobj, g_ens)
if !isnothing(terminated)
break
end

push!(g_ens_vec, g_ens)
if i == 1
if !(size(g_ens, 1) == size(g_ens, 2))
Expand All @@ -1141,7 +1146,7 @@ end
params_i_unsafe = get_ϕ_final(prior, gnkiobj_unsafe)
g_ens_unsafe = G(params_i_unsafe)
if i < iters_with_failure[1]
EKP.update_ensemble!(gnkiobj_unsafe, g_ens_unsafe)
terminated = EKP.update_ensemble!(gnkiobj_unsafe, g_ens_unsafe)
elseif i == iters_with_failure[1]
g_ens_unsafe[:, 1] .= NaN
#inconsistent behaviour before/after v1.9 regarding NaNs in matrices
Expand All @@ -1157,12 +1162,15 @@ end
end
end
end

end
push!(u_i_vec, get_u_final(gnkiobj))

@test get_u_prior(gnkiobj) == u_i_vec[1]
if isnothing(terminated)
push!(u_i_vec, get_u_final(gnkiobj))
end # if cancelled early then don't need "final iteration"

@test get_u_prior(gnkiobj) == u_i_vec[1]
@test get_u(gnkiobj) == u_i_vec
@test isequal(get_g(gnkiobj), g_ens_vec)
@test isequal(get_g(gnkiobj), g_ens_vec) # can deal with NaNs
@test isequal(get_g_final(gnkiobj), g_ens_vec[end])
@test isequal(get_error(gnkiobj), gnkiobj.error)

Expand Down

0 comments on commit 24a5f79

Please sign in to comment.