Skip to content

Commit

Permalink
refactor CMA-ES stopping criteria slightly and 📈 Part 8
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Jun 13, 2024
1 parent ade064a commit 1110082
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
45 changes: 30 additions & 15 deletions src/solvers/cma_es.jl
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,10 @@ function status_summary(c::StopWhenBestCostInGenerationConstant)
return "c.iterations_since_change > $(c.iteration_range):\t$s"
end
function get_reason(c::StopWhenBestCostInGenerationConstant)
return "For the last $(c.iterations_since_change) generation the best objective value in each generation was equal to $(c.best_objective_at_last_change).\n"
if c.at_iteration >= 0
return "At iteration $(c.at_iteration): for the last $(c.iterations_since_change) generatiosn the best objective value in each generation was equal to $(c.best_objective_at_last_change).\n"
end
return ""
end
function show(io::IO, c::StopWhenBestCostInGenerationConstant)
return print(
Expand Down Expand Up @@ -714,45 +717,57 @@ function show(io::IO, c::StopWhenEvolutionStagnates)
)
end

"""
@doc raw"""
StopWhenPopulationStronglyConcentrated{TParam<:Real} <: StoppingCriterion
Stop if the standard deviation in all coordinates is smaller than `tol` and
norm of `σ * p_c` is smaller than `tol`. This corresponds to `TolX` condition from
[Hansen:2023](@cite).
# Fields
* `tol` the tolerance to check against
* `at_iteration` an internal field to indicate at with iteration ``i \geq 0`` the tolerance was met.
# Constructor
StopWhenPopulationStronglyConcentrated(tol::Real)
"""
mutable struct StopWhenPopulationStronglyConcentrated{TParam<:Real} <: StoppingCriterion
tol::TParam
is_active::Bool
at_iteration::Int
end
function StopWhenPopulationStronglyConcentrated(tol::Real)
return StopWhenPopulationStronglyConcentrated{typeof(tol)}(tol, false)
return StopWhenPopulationStronglyConcentrated{typeof(tol)}(tol, -1)
end

# It just indicates stagnation, not convergence to a minimizer
indicates_convergence(c::StopWhenPopulationStronglyConcentrated) = true
function is_active_stopping_criterion(c::StopWhenPopulationStronglyConcentrated)
return c.is_active
return c.at_iteration >= 0
end
function (c::StopWhenPopulationStronglyConcentrated)(
::AbstractManoptProblem, s::CMAESState, i::Int
)
if i == 0 # reset on init
c.is_active = false
c.at_iteration = -1
return false
end
norm_inf_dev = norm(s.deviations, Inf)
norm_inf_p_c = norm(s.p_c, Inf)
c.is_active = norm_inf_dev < c.tol && s.σ * norm_inf_p_c < c.tol
return c.is_active
if norm_inf_dev < c.tol && s.σ * norm_inf_p_c < c.tol
c.at_iteration = i
return true
end
return false
end
function status_summary(c::StopWhenPopulationStronglyConcentrated)
has_stopped = is_active_stopping_criterion(c)
s = has_stopped ? "reached" : "not reached"
return "norm(s.deviations, Inf) < $(c.tol) && norm(s.σ * s.p_c, Inf) < $(c.tol) :\t$s"
end
function get_reason(c::StopWhenPopulationStronglyConcentrated)
if c.is_active
if c.at_iteration >= 0
return "Standard deviation in all coordinates is smaller than $(c.tol) and `σ * p_c` has Inf norm lower than $(c.tol).\n"
end
return ""
Expand Down Expand Up @@ -824,32 +839,32 @@ and all function values in the current generation is below `tol`. This correspon
mutable struct StopWhenPopulationCostConcentrated{TParam<:Real} <: StoppingCriterion
tol::TParam
best_value_history::CircularBuffer{TParam}
is_active::Bool
at_iteration::Int
end
function StopWhenPopulationCostConcentrated(tol::TParam, max_size::Int) where {TParam<:Real}
return StopWhenPopulationCostConcentrated{TParam}(
tol, CircularBuffer{TParam}(max_size), false
tol, CircularBuffer{TParam}(max_size), -1
)
end

# It just indicates stagnation, not convergence to a minimizer
indicates_convergence(c::StopWhenPopulationCostConcentrated) = true
function is_active_stopping_criterion(c::StopWhenPopulationCostConcentrated)
return c.is_active
return c.at_iteration >= 0
end
function (c::StopWhenPopulationCostConcentrated)(
::AbstractManoptProblem, s::CMAESState, i::Int
)
if i == 0 # reset on init
c.is_active = false
c.at_iteration = -1
return false
end
push!(c.best_value_history, s.best_fitness_current_gen)
if isfull(c.best_value_history)
min_hist, max_hist = extrema(c.best_value_history)
if max_hist - min_hist < c.tol &&
s.best_fitness_current_gen - s.worst_fitness_current_gen < c.tol
c.is_active = true
c.at_iteration = i
return true
end
end
Expand All @@ -861,7 +876,7 @@ function status_summary(c::StopWhenPopulationCostConcentrated)
return "range of best objective values in the last $(length(c.best_value_history)) generations and all objective values in the current one < $(c.tol) :\t$s"
end
function get_reason(c::StopWhenPopulationCostConcentrated)
if c.is_active
if c.at_iteration >= 0
return "Range of best objective function values in the last $(length(c.best_value_history)) gnerations and all values in the current generation is below $(c.tol)\n"
end
return ""
Expand Down
14 changes: 13 additions & 1 deletion test/solvers/test_cma_es.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,17 @@ flat_example(::AbstractManifold, p) = 0.0
p1 = cma_es(M, griewank, [0.0, 1.0, 0.0]; σ=1.0, rng=MersenneTwister(123))
@test griewank(M, p1) < 0.17
end
@testset "Special Stopping Criteria" begin end
@testset "Special Stopping Criteria" begin
sc1 = StopWhenBestCostInGenerationConstant{Float64}(10)
sc2 = StopWhenEvolutionStagnates(1, 2, 0.5)
sc3 = StopWhenPopulationStronglyConcentrated(0.1)
sc4 = StopWhenPopulationCostConcentrated(0.1, 5)

for sc in [sc1, sc2, sc3, sc4]
@test get_reason(sc) == ""
# Manually set is active
sc.at_iteration = 10
@test length(get_reason(sc)) > 0
end
end
end

0 comments on commit 1110082

Please sign in to comment.