From 1110082453bdb25706e0d6a03d4d58faf61ae706 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Thu, 13 Jun 2024 12:06:25 +0200 Subject: [PATCH] =?UTF-8?q?refactor=20CMA-ES=20stopping=20criteria=20sligh?= =?UTF-8?q?tly=20and=20=F0=9F=93=88=20Part=208?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/solvers/cma_es.jl | 45 ++++++++++++++++++++++++------------- test/solvers/test_cma_es.jl | 14 +++++++++++- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/src/solvers/cma_es.jl b/src/solvers/cma_es.jl index e47d663e5a..bc34ab5c93 100644 --- a/src/solvers/cma_es.jl +++ b/src/solvers/cma_es.jl @@ -619,7 +619,10 @@ function status_summary(c::StopWhenBestCostInGenerationConstant) return "c.iterations_since_change > $(c.iteration_range):\t$s" end function get_reason(c::StopWhenBestCostInGenerationConstant) - return "For the last $(c.iterations_since_change) generation the best objective value in each generation was equal to $(c.best_objective_at_last_change).\n" + if c.at_iteration >= 0 + return "At iteration $(c.at_iteration): for the last $(c.iterations_since_change) generatiosn the best objective value in each generation was equal to $(c.best_objective_at_last_change).\n" + end + return "" end function show(io::IO, c::StopWhenBestCostInGenerationConstant) return print( @@ -714,37 +717,49 @@ function show(io::IO, c::StopWhenEvolutionStagnates) ) end -""" +@doc raw""" StopWhenPopulationStronglyConcentrated{TParam<:Real} <: StoppingCriterion Stop if the standard deviation in all coordinates is smaller than `tol` and norm of `σ * p_c` is smaller than `tol`. This corresponds to `TolX` condition from [Hansen:2023](@cite). + +# Fields + +* `tol` the tolerance to check against +* `at_iteration` an internal field to indicate at with iteration ``i \geq 0`` the tolerance was met. + +# Constructor + + StopWhenPopulationStronglyConcentrated(tol::Real) """ mutable struct StopWhenPopulationStronglyConcentrated{TParam<:Real} <: StoppingCriterion tol::TParam - is_active::Bool + at_iteration::Int end function StopWhenPopulationStronglyConcentrated(tol::Real) - return StopWhenPopulationStronglyConcentrated{typeof(tol)}(tol, false) + return StopWhenPopulationStronglyConcentrated{typeof(tol)}(tol, -1) end # It just indicates stagnation, not convergence to a minimizer indicates_convergence(c::StopWhenPopulationStronglyConcentrated) = true function is_active_stopping_criterion(c::StopWhenPopulationStronglyConcentrated) - return c.is_active + return c.at_iteration >= 0 end function (c::StopWhenPopulationStronglyConcentrated)( ::AbstractManoptProblem, s::CMAESState, i::Int ) if i == 0 # reset on init - c.is_active = false + c.at_iteration = -1 return false end norm_inf_dev = norm(s.deviations, Inf) norm_inf_p_c = norm(s.p_c, Inf) - c.is_active = norm_inf_dev < c.tol && s.σ * norm_inf_p_c < c.tol - return c.is_active + if norm_inf_dev < c.tol && s.σ * norm_inf_p_c < c.tol + c.at_iteration = i + return true + end + return false end function status_summary(c::StopWhenPopulationStronglyConcentrated) has_stopped = is_active_stopping_criterion(c) @@ -752,7 +767,7 @@ function status_summary(c::StopWhenPopulationStronglyConcentrated) return "norm(s.deviations, Inf) < $(c.tol) && norm(s.σ * s.p_c, Inf) < $(c.tol) :\t$s" end function get_reason(c::StopWhenPopulationStronglyConcentrated) - if c.is_active + if c.at_iteration >= 0 return "Standard deviation in all coordinates is smaller than $(c.tol) and `σ * p_c` has Inf norm lower than $(c.tol).\n" end return "" @@ -824,24 +839,24 @@ and all function values in the current generation is below `tol`. This correspon mutable struct StopWhenPopulationCostConcentrated{TParam<:Real} <: StoppingCriterion tol::TParam best_value_history::CircularBuffer{TParam} - is_active::Bool + at_iteration::Int end function StopWhenPopulationCostConcentrated(tol::TParam, max_size::Int) where {TParam<:Real} return StopWhenPopulationCostConcentrated{TParam}( - tol, CircularBuffer{TParam}(max_size), false + tol, CircularBuffer{TParam}(max_size), -1 ) end # It just indicates stagnation, not convergence to a minimizer indicates_convergence(c::StopWhenPopulationCostConcentrated) = true function is_active_stopping_criterion(c::StopWhenPopulationCostConcentrated) - return c.is_active + return c.at_iteration >= 0 end function (c::StopWhenPopulationCostConcentrated)( ::AbstractManoptProblem, s::CMAESState, i::Int ) if i == 0 # reset on init - c.is_active = false + c.at_iteration = -1 return false end push!(c.best_value_history, s.best_fitness_current_gen) @@ -849,7 +864,7 @@ function (c::StopWhenPopulationCostConcentrated)( min_hist, max_hist = extrema(c.best_value_history) if max_hist - min_hist < c.tol && s.best_fitness_current_gen - s.worst_fitness_current_gen < c.tol - c.is_active = true + c.at_iteration = i return true end end @@ -861,7 +876,7 @@ function status_summary(c::StopWhenPopulationCostConcentrated) return "range of best objective values in the last $(length(c.best_value_history)) generations and all objective values in the current one < $(c.tol) :\t$s" end function get_reason(c::StopWhenPopulationCostConcentrated) - if c.is_active + if c.at_iteration >= 0 return "Range of best objective function values in the last $(length(c.best_value_history)) gnerations and all values in the current generation is below $(c.tol)\n" end return "" diff --git a/test/solvers/test_cma_es.jl b/test/solvers/test_cma_es.jl index 1d79592284..a88757a8b7 100644 --- a/test/solvers/test_cma_es.jl +++ b/test/solvers/test_cma_es.jl @@ -138,5 +138,17 @@ flat_example(::AbstractManifold, p) = 0.0 p1 = cma_es(M, griewank, [0.0, 1.0, 0.0]; σ=1.0, rng=MersenneTwister(123)) @test griewank(M, p1) < 0.17 end - @testset "Special Stopping Criteria" begin end + @testset "Special Stopping Criteria" begin + sc1 = StopWhenBestCostInGenerationConstant{Float64}(10) + sc2 = StopWhenEvolutionStagnates(1, 2, 0.5) + sc3 = StopWhenPopulationStronglyConcentrated(0.1) + sc4 = StopWhenPopulationCostConcentrated(0.1, 5) + + for sc in [sc1, sc2, sc3, sc4] + @test get_reason(sc) == "" + # Manually set is active + sc.at_iteration = 10 + @test length(get_reason(sc)) > 0 + end + end end