Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a :Subsolver debug symbol and a DebugWhenActive #285

Merged
merged 10 commits into from
Aug 14, 2023
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manopt"
uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
authors = ["Ronny Bergmann <manopt@ronnybergmann.net>"]
version = "0.4.30"
version = "0.4.31"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ makedocs(
"Speedup using Inplace computations" => "tutorials/InplaceGradient.md",
"Use Automatic Differentiation" => "tutorials/AutomaticDifferentiation.md",
"Count and use a Cache" => "tutorials/CountAndCache.md",
"Perform Debug Output" => "tutorials/HowToDebug.md",
"Record values" => "tutorials/HowToRecord.md",
"Implement a Solver" => "tutorials/ImplementASolver.md",
"Do Contrained Optimization" => "tutorials/ConstrainedOptimization.md",
Expand Down
7 changes: 7 additions & 0 deletions docs/src/plans/state.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,10 @@ AbstractGradientSolverState
AbstractHessianSolverState
AbstractPrimalDualSolverState
```

For the sub problem state, there are two access functions

```@docs
get_sub_problem
get_sub_state
```
3 changes: 1 addition & 2 deletions src/Manopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,6 @@ export get_proximal_map,
set_gradient!,
set_iterate!,
set_manopt_parameter!,
set_manopt_parameter!,
set_manopt_parameter!,
linearized_forward_operator,
linearized_forward_operator!,
adjoint_linearized_operator,
Expand Down Expand Up @@ -482,6 +480,7 @@ export DebugDualBaseChange, DebugDualBaseIterate, DebugDualChange, DebugDualIter
export DebugDualResidual, DebugPrimalDualResidual, DebugPrimalResidual
export DebugProximalParameter, DebugWarnIfCostIncreases
export DebugGradient, DebugGradientNorm, DebugStepsize
export DebugWhenActive
export DebugWarnIfCostNotFinite, DebugWarnIfFieldNotFinite, DebugMessages
#
# Records - and access functions
Expand Down
91 changes: 84 additions & 7 deletions src/plans/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ function show(io::IO, dg::DebugGroup)
s = join(["$(di)" for di in dg.group], ", ")
return print(io, "DebugGroup([$s])")
end
function set_manopt_parameter!(dg::DebugGroup, v::Val, args...)
for di in dg.group
set_manopt_parameter!(di, v, args...)
end
return dg
end
function set_manopt_parameter!(dg::DebugGroup, e::Symbol, args...)
set_manopt_parameter!(dg, Val(e), args...)
return dg
end

@doc raw"""
DebugEvery <: DebugAction
Expand Down Expand Up @@ -154,6 +164,9 @@ function (d::DebugEvery)(p::AbstractManoptProblem, st::AbstractManoptSolverState
elseif d.always_update
d.debug(p, st, -1)
end
# set activity for the next iterate in subsolvers
set_manopt_parameter!(st, :SubState, :Debug, :active, !(i<1) && (rem(i + 1, d.every) == 0))
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
return nothing
end
function show(io::IO, de::DebugEvery)
return print(io, "DebugEvery($(de.debug), $(de.every), $(de.always_update))")
Expand All @@ -167,6 +180,15 @@ function status_summary(de::DebugEvery)
end
return "[$s, $(de.every)]"
end
function set_manopt_parameter!(de::DebugEvery, e::Symbol, args...)
set_manopt_parameter!(de, Val(e), args...)
return de
end
function set_manopt_parameter!(de::DebugEvery, args...)
set_manopt_parameter!(de.debug, args...)
return de
end

#
# Special single ones
#
Expand Down Expand Up @@ -256,13 +278,13 @@ print the current cost function value, see [`get_cost`](@ref).

* `format` - (`"$prefix %f"`) format to print the output using sprintf and a prefix (see `long`).
* `io` – (`stdout`) default steream to print the debug to.
* `long` - (`false`) short form to set the format to `F(x):` (default) or `current cost: ` and the cost
* `long` - (`false`) short form to set the format to `f(x):` (default) or `current cost: ` and the cost
"""
mutable struct DebugCost <: DebugAction
io::IO
format::String
function DebugCost(;
long::Bool=false, io::IO=stdout, format=long ? "current cost: %f" : "F(x): %f"
long::Bool=false, io::IO=stdout, format=long ? "current cost: %f" : "f(x): %f"
)
return new(io, format)
end
Expand Down Expand Up @@ -576,6 +598,57 @@ end
show(io::IO, ::DebugStoppingCriterion) = print(io, "DebugStoppingCriterion()")
status_summary(::DebugStoppingCriterion) = ":Stop"

@doc raw"""
DebugWhenActive <: DebugAction

evaluate and print debug only if the active boolean is set.
This can be set from outside and is for example triggered by [`DebugEvery`](@ref)
on debugs on the subsolver.

This method does not perform any print itself but relies on it's childrens print.

For now, the main interaction is with [`DebugEvery`](@ref) which might activate or
deactivate this debug

# Fields
* `always_update` – whether or not to call the order debugs with iteration `-1` in in active state
* `active` – a boolean that can (de-)activated from outside to enable/disable debug

# Constructor

DebugWhenActive(d::DebugAction, active=true, always_update=true)

Initialise the DebugSubsolver.
"""
mutable struct DebugWhenActive <: DebugAction
debug::DebugAction
active::Bool
always_update::Bool
function DebugWhenActive(d::DebugAction, active::Bool=true, always_update::Bool=true)
return new(d, active, always_update)
end
end
function (dwa::DebugWhenActive)(p::AbstractManoptProblem, st::AbstractManoptSolverState, i)
if dwa.active
dwa.debug(p, st, i)
elseif dwa.always_update
dwa.debug(p, st, -1)
end
end
function show(io::IO, dwa::DebugWhenActive)
return print(io, "DebugWhenActive($(dwa.debug), $(dwa.active), $(dwa.always_update))")
end
function status_summary(dwa::DebugWhenActive)
return repr(dwa)
end
function set_manopt_parameter!(dwa::DebugWhenActive, v::Val, args...)
set_manopt_parameter!(dwa.debug, v, args...)
return dwa
end
function set_manopt_parameter!(dwa::DebugWhenActive, ::Val{:active}, v)
return dwa.active = v
end

@doc raw"""
DebugTime()

Expand Down Expand Up @@ -808,6 +881,7 @@ given an array of `Symbol`s, `String`s [`DebugAction`](@ref)s and `Ints`

* The symbol `:Stop` creates an entry of to display the stopping criterion at the end
(`:Stop => DebugStoppingCriterion()`), for further symbols see [`DebugActionFactory`](@ref DebugActionFactory(::Symbol))
* The symbol `:Subsolver` wraps all `dictionary` entries with [`DebugWhenActive`](@ref) that can be set from outside.
* Tuples of a symbol and a string can be used to also specify a format, see [`DebugActionFactory`](@ref DebugActionFactory(::Tuple{Symbol,String}))
* any string creates a [`DebugDivider`](@ref)
* any [`DebugAction`](@ref) is directly included
Expand Down Expand Up @@ -836,7 +910,7 @@ It also adds the [`DebugStoppingCriterion`](@ref) to the `:Stop` entry of the di
function DebugFactory(a::Array{<:Any,1})
# filter out every
group = Array{DebugAction,1}()
for s in filter(x -> !isa(x, Int) && x != :Stop, a) # filter ints and stop
for s in filter(x -> !isa(x, Int) && (x ∉ [:Stop, :Subsolver]), a) # filter ints and stop
push!(group, DebugActionFactory(s))
end
dictionary = Dict{Symbol,DebugAction}()
Expand All @@ -849,8 +923,11 @@ function DebugFactory(a::Array{<:Any,1})
end
dictionary[:All] = debug
end
if :Stop in a
dictionary[:Stop] = DebugStoppingCriterion()
(:Stop in a) && (dictionary[:Stop] = DebugStoppingCriterion())
if (:Subsolver in a)
for k in keys(dictionary)
dictionary[k] = DebugWhenActive(dictionary[k])
end
end
return dictionary
end
Expand Down Expand Up @@ -887,8 +964,8 @@ Note that the Shortcut symbols should all start with a capital letter.
* `:Time` creates a [`DebugTime`](@ref)
* `:WarningMessages`creates a [`DebugMessages`](@ref)`(:Warning)`
* `:InfoMessages`creates a [`DebugMessages`](@ref)`(:Info)`
* `:ErrorMessages`creates a [`DebugMessages`](@ref)`(:Error)`
* `:Messages`creates a [`DebugMessages`](@ref)`()` (i.e. the same as `:InfoMessages`)
* `:ErrorMessages` creates a [`DebugMessages`](@ref)`(:Error)`
* `:Messages` creates a [`DebugMessages`](@ref)`()` (i.e. the same as `:InfoMessages`)

any other symbol creates a `DebugEntry(s)` to print the entry (o.:s) from the options.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/plans/difference_of_convex_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function set_manopt_parameter!(ldc::LinearizedDCCost, ::Val{:p}, p)
return ldc
end
function set_manopt_parameter!(ldc::LinearizedDCCost, ::Val{:X}, X)
ldc.Xk = X
ldc.Xk .= X
return ldc
end

Expand Down Expand Up @@ -203,7 +203,7 @@ function set_manopt_parameter!(ldcg::LinearizedDCGrad, ::Val{:p}, p)
return ldcg
end
function set_manopt_parameter!(ldcg::LinearizedDCGrad, ::Val{:X}, X)
ldcg.Xk = X
ldcg.Xk .= X
return ldcg
end

Expand Down
33 changes: 27 additions & 6 deletions src/plans/objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,26 @@ _get_objective(o::AbstractManifoldObjective, ::Val{false}, rec=true) = o
function _get_objective(o::AbstractManifoldObjective, ::Val{true}, rec=true)
return rec ? get_objective(o.objective) : o.objective
end
function status_summary(o::AbstractManifoldObjective{E}) where {E}
return ""#"$(nameof(typeof(o))){$E}"

"""
set_manopt_parameter!(amo::AbstractManifoldObjective, element::Symbol, args...)

Set a certain `args...` from the [`AbstractManifoldObjective`](@ref) `amo` to `value.
This function should dispatch on `Val(element)`.

Currently supported
* `:Cost` passes to the [`get_cost_function`](@ref)
* `:Gradient` passes to the [`get_gradient_function`](@ref)
"""
set_manopt_parameter!(amo::AbstractManifoldObjective, e::Symbol, args...)

function set_manopt_parameter!(amo::AbstractManifoldObjective, ::Val{:Cost}, args...)
set_manopt_parameter!(get_cost_function(amo), args...)
return amo
end
# Default undecorate for summary
function status_summary(co::AbstractDecoratedManifoldObjective)
return status_summary(get_objective(co, false))
function set_manopt_parameter!(amo::AbstractManifoldObjective, ::Val{:Gradient}, args...)
set_manopt_parameter!(get_gradient_function(amo), args...)
return amo
end

function show(io::IO, o::AbstractManifoldObjective{E}) where {E}
Expand All @@ -124,11 +138,18 @@ end
function show(io::IO, co::AbstractDecoratedManifoldObjective)
return show(io, get_objective(co, false))
end

function show(io::IO, t::Tuple{<:AbstractManifoldObjective,P}) where {P}
s = "$(status_summary(t[1]))"
length(s) > 0 && (s = "$(s)\n\n")
return print(
io, "$(s)To access the solver result, call `get_solver_result` on this variable."
)
end

function status_summary(o::AbstractManifoldObjective{E}) where {E}
return ""
end
# Default undecorate for summary
function status_summary(co::AbstractDecoratedManifoldObjective)
return status_summary(get_objective(co, false))
end
13 changes: 13 additions & 0 deletions src/plans/plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@ It might also be more verbose in explaining, or hide internal information.
"""
status_summary(e) = "$(e)"

"""
set_manopt_parameter!(f, element::Symbol , args...)

For any `f` and a `Symbol` `e` we dispatch on its value so by default, to
set some `args...` in `f` or one of uts sub elements.
"""
function set_manopt_parameter!(f, e::Symbol, args...)
return set_manopt_parameter!(f, Val(e), args...)
end
function set_manopt_parameter!(f, args...)
return f
end

include("objective.jl")
include("problem.jl")
include("solver_state.jl")
Expand Down
15 changes: 15 additions & 0 deletions src/plans/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,18 @@ end
evaluate the cost function `f` defined on `M` stored within the [`AbstractManifoldObjective`](@ref) at the point `p`.
"""
get_cost(::AbstractManifold, ::AbstractManifoldObjective, p)

"""
set_manopt_parameter!(ams::AbstractManoptProblem, element::Symbol, field::Symbol , value)

Set a certain field/element from the [`AbstractManoptProblem`](@ref) `ams` to `value.
This function should dispatch on `Val(element)`.

By default this passes on to the inner objective, see [`set_manopt_parameter!`](@ref)
"""
set_manopt_parameter!(amp::AbstractManoptProblem, e::Symbol, args...)

function set_manopt_parameter!(amp::AbstractManoptProblem, ::Val{:Objective}, args...)
set_manopt_parameter!(get_objective(amp), args...)
return amp
end
Loading
Loading