diff --git a/Project.toml b/Project.toml index 67a0d96363..a73abff701 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manopt" uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5" authors = ["Ronny Bergmann "] -version = "0.4.28" +version = "0.4.29" [deps] ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" diff --git a/src/plans/cache.jl b/src/plans/cache.jl index 8f83774c0a..3926825b22 100644 --- a/src/plans/cache.jl +++ b/src/plans/cache.jl @@ -785,3 +785,38 @@ function objective_cache_factory(M, o, cache::Tuple{Symbol,<:AbstractArray}) (cache[1] === :LRU) && return ManifoldCachedObjective(M, o, cache[2]) return o end +function show(io::IO, smco::SimpleManifoldCachedObjective{E}) where {E} + return print(io, "SimpleManifoldCachedObjective{$E,$(smco.objective)}") +end +function show( + io::IO, t::Tuple{<:SimpleManifoldCachedObjective,S} +) where {S<:AbstractManoptSolverState} + return print(io, "$(t[2])\n\n$(status_summary(t[1]))") +end +function show(io::IO, mco::ManifoldCachedObjective) + return print(io, "$(status_summary(mco))") +end +function show( + io::IO, t::Tuple{<:ManifoldCachedObjective,S} +) where {S<:AbstractManoptSolverState} + return print(io, "$(t[2])\n\n$(status_summary(t[1]))") +end + +function status_summary(smco::SimpleManifoldCachedObjective) + s = "## Cache\nA `SimpleManifoldCachedObjective` to cache one point and one tangent vector for the iterate and gradient, respectively" + s2 = status_summary(smco.objective) + length(s2) > 0 && (s2 = "\n$(s2)") + return "$(s)$(s2)" +end +function status_summary(mco::ManifoldCachedObjective) + s = "## Cache\n" + longest_key_length = max(length.(["$k" for k in keys(mco.cache)])...) + cache_strings = [ + " * :" * + rpad("$k", longest_key_length, " ") * + " : $(v.currentsize)/$(v.maxsize) entries of type $(valtype(v)) used" for + (k, v) in zip(keys(mco.cache), values(mco.cache)) + ] + s2 = status_summary(mco.objective) + return "$(s)$(join(cache_strings,"\n"))\n\n$s2" +end diff --git a/src/plans/count.jl b/src/plans/count.jl index 676c9c5923..08198c2d4e 100644 --- a/src/plans/count.jl +++ b/src/plans/count.jl @@ -414,8 +414,8 @@ function status_summary(co::ManifoldCountObjective) " * :$(rpad("$(c[1])",longest_key_length)) : $(c[2])" for c in co.counts ] s2 = status_summary(co.objective) - !(co.objective isa AbstractDecoratedManifoldObjective) && (s2 = "on a $(s2)") - return "$(s)$(join(count_strings,"\n"))\n$s2" + (length(s2) > 0) && (s2 = "\n$(s2)") + return "$(s)$(join(count_strings,"\n"))$s2" end function show(io::IO, co::ManifoldCountObjective) diff --git a/src/plans/objective.jl b/src/plans/objective.jl index 96b3487712..6642b1861f 100644 --- a/src/plans/objective.jl +++ b/src/plans/objective.jl @@ -110,7 +110,7 @@ function _get_objective(o::AbstractManifoldObjective, ::Val{true}, rec=true) return rec ? get_objective(o.objective) : o.objective end function status_summary(o::AbstractManifoldObjective{E}) where {E} - return "$(nameof(typeof(o))){$E}" + return ""#"$(nameof(typeof(o))){$E}" end # Default undecorate for summary function status_summary(co::AbstractDecoratedManifoldObjective) @@ -126,11 +126,9 @@ function show(io::IO, co::AbstractDecoratedManifoldObjective) end function show(io::IO, t::Tuple{<:AbstractManifoldObjective,P}) where {P} + s = "$(status_summary(t[1]))" + length(s) > 0 && (s = "$(s)\n\n") return print( - io, - """ -$(status_summary(t[1])) - -To access the solver result, call `get_solver_result` on this variable.""", + io, "$(s)To access the solver result, call `get_solver_result` on this variable." ) end diff --git a/src/plans/problem.jl b/src/plans/problem.jl index 8d1df6dbfa..1b49334ebd 100644 --- a/src/plans/problem.jl +++ b/src/plans/problem.jl @@ -53,13 +53,16 @@ get_manifold(::AbstractManoptProblem) get_manifold(amp::DefaultManoptProblem) = amp.manifold @doc raw""" - get_objective(mp::AbstractManoptProblem) + get_objective(mp::AbstractManoptProblem, recursive=false) return the objective [`AbstractManifoldObjective`](@ref) stored within an [`AbstractManoptProblem`](@ref). +If `recursive is set to true, it additionally unwraps all decorators of the objective` """ get_objective(::AbstractManoptProblem) -get_objective(amp::DefaultManoptProblem) = amp.objective +function get_objective(amp::DefaultManoptProblem, recursive=false) + return recursive ? get_objective(amp.objective, true) : amp.objective +end @doc raw""" get_cost(amp::AbstractManoptProblem, p) diff --git a/src/solvers/cyclic_proximal_point.jl b/src/solvers/cyclic_proximal_point.jl index c75c28bb3f..d60e5e6f68 100644 --- a/src/solvers/cyclic_proximal_point.jl +++ b/src/solvers/cyclic_proximal_point.jl @@ -137,7 +137,7 @@ function cyclic_proximal_point!( return get_solver_return(get_objective(dmp), dcpps) end function initialize_solver!(amp::AbstractManoptProblem, cpps::CyclicProximalPointState) - c = length(get_objective(amp).proximal_maps!!) + c = length(get_objective(amp, true).proximal_maps!!) cpps.order = collect(1:c) (cpps.order_type == :FixedRandom) && shuffle!(cpps.order) return cpps diff --git a/test/plans/test_cache.jl b/test/plans/test_cache.jl index be859c0d6b..5fdfa15464 100644 --- a/test/plans/test_cache.jl +++ b/test/plans/test_cache.jl @@ -69,6 +69,12 @@ end mgoa = ManifoldGradientObjective(TestCostCount(0), TestGradCount(0)) mcgoa = ManifoldGradientObjective(TestCostCount(0), TestGradCount(0)) sco1 = Manopt.SimpleManifoldCachedObjective(M, mgoa; p=p) + @test repr(sco1) == "SimpleManifoldCachedObjective{AllocatingEvaluation,$(mgoa)}" + @test startswith(repr((sco1, 1.0)), "## Cache\nA `SimpleManifoldCachedObjective`") + @test startswith( + repr((sco1, DummyState())), + "DummyState(Float64[])\n\n## Cache\nA `SimpleManifoldCachedObjective`", + ) # We evaluated on init -> 1 @test sco1.objective.gradient!!.i == 1 @test sco1.objective.cost.i == 1 @@ -177,6 +183,10 @@ end o = ManifoldGradientObjective(f, grad_f) co = ManifoldCountObjective(M, o, [:Cost, :Gradient]) lco = objective_cache_factory(M, co, (:LRU, [:Cost, :Gradient])) + @test startswith(repr(lco), "## Cache\n * ") + @test startswith( + repr((lco, DummyState())), "DummyState(Float64[])\n\n## Cache\n * " + ) ro = DummyDecoratedObjective(o) #indecorated works as well lco2 = objective_cache_factory(M, o, (:LRU, [:Cost, :Gradient])) diff --git a/test/plans/test_objective.jl b/test/plans/test_objective.jl index 0c33812ad9..9f3ad175cb 100644 --- a/test/plans/test_objective.jl +++ b/test/plans/test_objective.jl @@ -15,12 +15,10 @@ include("../utils/dummy_types.jl") r = Manopt.ReturnManifoldObjective(o) @test repr(o) == "ManifoldCostObjective{AllocatingEvaluation}" @test repr(r) == "ManifoldCostObjective{AllocatingEvaluation}" - @test Manopt.status_summary(o) == "ManifoldCostObjective{AllocatingEvaluation}" - @test Manopt.status_summary(r) == "ManifoldCostObjective{AllocatingEvaluation}" - @test repr((o, 1.0)) == """ - ManifoldCostObjective{AllocatingEvaluation} - - To access the solver result, call `get_solver_result` on this variable.""" + @test Manopt.status_summary(o) == "" # both simplified to empty + @test Manopt.status_summary(r) == "" + @test repr((o, 1.0)) == + "To access the solver result, call `get_solver_result` on this variable." d = DummyDecoratedObjective(o) r2 = Manopt.ReturnManifoldObjective(d) @test repr(r) == "ManifoldCostObjective{AllocatingEvaluation}" diff --git a/test/solvers/test_cyclic_proximal_point.jl b/test/solvers/test_cyclic_proximal_point.jl index d6440390a8..ae03a60364 100644 --- a/test/solvers/test_cyclic_proximal_point.jl +++ b/test/solvers/test_cyclic_proximal_point.jl @@ -1,4 +1,4 @@ -using Manifolds, Manopt, Test, Dates +using Manifolds, Manopt, Test, Dates, LRUCache @testset "Cyclic Proximal Point" begin @testset "Allocating" begin @@ -70,6 +70,20 @@ using Manifolds, Manopt, Test, Dates @test startswith( repr(r), "# Solver state for `Manopt.jl`s Cyclic Proximal Point Algorithm" ) + @testset "Caching" begin + r2 = cyclic_proximal_point( + N, + f, + proxes!, + q; + λ=i -> π / (2 * i), + cache=(:LRU, [:Cost, :ProximalMap], 50), + stopping_criterion=StopAfterIteration(100), + evaluation=InplaceEvaluation(), + return_state=true, + return_objective=true, + ) + end end @testset "Problem access functions" begin n = 3