Skip to content

Commit

Permalink
Fix decorated objectives and Cache Display (#279)
Browse files Browse the repository at this point in the history
* Fix a bug in decorated objectives being passed to cppa; improve display when using caches in the summary.
* Test Coverage I.
* Test Coverage II.
  • Loading branch information
kellertuer authored Jul 12, 2023
1 parent b07016c commit 6482858
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manopt"
uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
authors = ["Ronny Bergmann <manopt@ronnybergmann.net>"]
version = "0.4.28"
version = "0.4.29"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Expand Down
35 changes: 35 additions & 0 deletions src/plans/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,38 @@ function objective_cache_factory(M, o, cache::Tuple{Symbol,<:AbstractArray})
(cache[1] === :LRU) && return ManifoldCachedObjective(M, o, cache[2])
return o
end
function show(io::IO, smco::SimpleManifoldCachedObjective{E}) where {E}
return print(io, "SimpleManifoldCachedObjective{$E,$(smco.objective)}")
end
function show(
io::IO, t::Tuple{<:SimpleManifoldCachedObjective,S}
) where {S<:AbstractManoptSolverState}
return print(io, "$(t[2])\n\n$(status_summary(t[1]))")
end
function show(io::IO, mco::ManifoldCachedObjective)
return print(io, "$(status_summary(mco))")
end
function show(
io::IO, t::Tuple{<:ManifoldCachedObjective,S}
) where {S<:AbstractManoptSolverState}
return print(io, "$(t[2])\n\n$(status_summary(t[1]))")
end

function status_summary(smco::SimpleManifoldCachedObjective)
s = "## Cache\nA `SimpleManifoldCachedObjective` to cache one point and one tangent vector for the iterate and gradient, respectively"
s2 = status_summary(smco.objective)
length(s2) > 0 && (s2 = "\n$(s2)")
return "$(s)$(s2)"
end
function status_summary(mco::ManifoldCachedObjective)
s = "## Cache\n"
longest_key_length = max(length.(["$k" for k in keys(mco.cache)])...)
cache_strings = [
" * :" *
rpad("$k", longest_key_length, " ") *
" : $(v.currentsize)/$(v.maxsize) entries of type $(valtype(v)) used" for
(k, v) in zip(keys(mco.cache), values(mco.cache))
]
s2 = status_summary(mco.objective)
return "$(s)$(join(cache_strings,"\n"))\n\n$s2"
end
4 changes: 2 additions & 2 deletions src/plans/count.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ function status_summary(co::ManifoldCountObjective)
" * :$(rpad("$(c[1])",longest_key_length)) : $(c[2])" for c in co.counts
]
s2 = status_summary(co.objective)
!(co.objective isa AbstractDecoratedManifoldObjective) && (s2 = "on a $(s2)")
return "$(s)$(join(count_strings,"\n"))\n$s2"
(length(s2) > 0) && (s2 = "\n$(s2)")
return "$(s)$(join(count_strings,"\n"))$s2"
end

function show(io::IO, co::ManifoldCountObjective)
Expand Down
10 changes: 4 additions & 6 deletions src/plans/objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function _get_objective(o::AbstractManifoldObjective, ::Val{true}, rec=true)
return rec ? get_objective(o.objective) : o.objective
end
function status_summary(o::AbstractManifoldObjective{E}) where {E}
return "$(nameof(typeof(o))){$E}"
return ""#"$(nameof(typeof(o))){$E}"
end
# Default undecorate for summary
function status_summary(co::AbstractDecoratedManifoldObjective)
Expand All @@ -126,11 +126,9 @@ function show(io::IO, co::AbstractDecoratedManifoldObjective)
end

function show(io::IO, t::Tuple{<:AbstractManifoldObjective,P}) where {P}
s = "$(status_summary(t[1]))"
length(s) > 0 && (s = "$(s)\n\n")
return print(
io,
"""
$(status_summary(t[1]))
To access the solver result, call `get_solver_result` on this variable.""",
io, "$(s)To access the solver result, call `get_solver_result` on this variable."
)
end
7 changes: 5 additions & 2 deletions src/plans/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,16 @@ get_manifold(::AbstractManoptProblem)
get_manifold(amp::DefaultManoptProblem) = amp.manifold

@doc raw"""
get_objective(mp::AbstractManoptProblem)
get_objective(mp::AbstractManoptProblem, recursive=false)
return the objective [`AbstractManifoldObjective`](@ref) stored within an [`AbstractManoptProblem`](@ref).
If `recursive is set to true, it additionally unwraps all decorators of the objective`
"""
get_objective(::AbstractManoptProblem)

get_objective(amp::DefaultManoptProblem) = amp.objective
function get_objective(amp::DefaultManoptProblem, recursive=false)
return recursive ? get_objective(amp.objective, true) : amp.objective
end

@doc raw"""
get_cost(amp::AbstractManoptProblem, p)
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/cyclic_proximal_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ function cyclic_proximal_point!(
return get_solver_return(get_objective(dmp), dcpps)
end
function initialize_solver!(amp::AbstractManoptProblem, cpps::CyclicProximalPointState)
c = length(get_objective(amp).proximal_maps!!)
c = length(get_objective(amp, true).proximal_maps!!)
cpps.order = collect(1:c)
(cpps.order_type == :FixedRandom) && shuffle!(cpps.order)
return cpps
Expand Down
10 changes: 10 additions & 0 deletions test/plans/test_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ end
mgoa = ManifoldGradientObjective(TestCostCount(0), TestGradCount(0))
mcgoa = ManifoldGradientObjective(TestCostCount(0), TestGradCount(0))
sco1 = Manopt.SimpleManifoldCachedObjective(M, mgoa; p=p)
@test repr(sco1) == "SimpleManifoldCachedObjective{AllocatingEvaluation,$(mgoa)}"
@test startswith(repr((sco1, 1.0)), "## Cache\nA `SimpleManifoldCachedObjective`")
@test startswith(
repr((sco1, DummyState())),
"DummyState(Float64[])\n\n## Cache\nA `SimpleManifoldCachedObjective`",
)
# We evaluated on init -> 1
@test sco1.objective.gradient!!.i == 1
@test sco1.objective.cost.i == 1
Expand Down Expand Up @@ -177,6 +183,10 @@ end
o = ManifoldGradientObjective(f, grad_f)
co = ManifoldCountObjective(M, o, [:Cost, :Gradient])
lco = objective_cache_factory(M, co, (:LRU, [:Cost, :Gradient]))
@test startswith(repr(lco), "## Cache\n * ")
@test startswith(
repr((lco, DummyState())), "DummyState(Float64[])\n\n## Cache\n * "
)
ro = DummyDecoratedObjective(o)
#indecorated works as well
lco2 = objective_cache_factory(M, o, (:LRU, [:Cost, :Gradient]))
Expand Down
10 changes: 4 additions & 6 deletions test/plans/test_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ include("../utils/dummy_types.jl")
r = Manopt.ReturnManifoldObjective(o)
@test repr(o) == "ManifoldCostObjective{AllocatingEvaluation}"
@test repr(r) == "ManifoldCostObjective{AllocatingEvaluation}"
@test Manopt.status_summary(o) == "ManifoldCostObjective{AllocatingEvaluation}"
@test Manopt.status_summary(r) == "ManifoldCostObjective{AllocatingEvaluation}"
@test repr((o, 1.0)) == """
ManifoldCostObjective{AllocatingEvaluation}
To access the solver result, call `get_solver_result` on this variable."""
@test Manopt.status_summary(o) == "" # both simplified to empty
@test Manopt.status_summary(r) == ""
@test repr((o, 1.0)) ==
"To access the solver result, call `get_solver_result` on this variable."
d = DummyDecoratedObjective(o)
r2 = Manopt.ReturnManifoldObjective(d)
@test repr(r) == "ManifoldCostObjective{AllocatingEvaluation}"
Expand Down
16 changes: 15 additions & 1 deletion test/solvers/test_cyclic_proximal_point.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Manifolds, Manopt, Test, Dates
using Manifolds, Manopt, Test, Dates, LRUCache

@testset "Cyclic Proximal Point" begin
@testset "Allocating" begin
Expand Down Expand Up @@ -70,6 +70,20 @@ using Manifolds, Manopt, Test, Dates
@test startswith(
repr(r), "# Solver state for `Manopt.jl`s Cyclic Proximal Point Algorithm"
)
@testset "Caching" begin
r2 = cyclic_proximal_point(
N,
f,
proxes!,
q;
λ=i -> π / (2 * i),
cache=(:LRU, [:Cost, :ProximalMap], 50),
stopping_criterion=StopAfterIteration(100),
evaluation=InplaceEvaluation(),
return_state=true,
return_objective=true,
)
end
end
@testset "Problem access functions" begin
n = 3
Expand Down

2 comments on commit 6482858

@kellertuer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/87325

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.29 -m "<description of version>" 6482858d7b540d6174dc0bea60122ec6f94f1b85
git push origin v0.4.29

Please sign in to comment.