Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce an EmbeddedObjective #286

Merged
merged 26 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
544e5b5
Introduce a sketch for an EmbeddedObjective.
kellertuer Aug 15, 2023
0c821f9
Improve a menu entry.
kellertuer Aug 15, 2023
99a2af2
Apply suggestions from code review
kellertuer Aug 16, 2023
215f49c
Merge branch 'kellertuer/EmbeddedObjective' of github.com:JuliaManifo…
kellertuer Aug 16, 2023
f445f49
removes a spurious variable name.
kellertuer Aug 17, 2023
db2c200
Fix docs for the object decorator.
kellertuer Aug 17, 2023
9e103c8
Write the wrapper for the Euclidean Hessian.
kellertuer Aug 17, 2023
353b211
Change the keyword to `objective_type`.
kellertuer Aug 17, 2023
687f1b6
use embed from ManifoldsBase.
kellertuer Aug 18, 2023
74c80aa
Constraint gradient conversion.
kellertuer Aug 26, 2023
ae8304d
Merge branch 'master' into kellertuer/EmbeddedObjective
kellertuer Aug 26, 2023
ebf3cf8
Work on the docs.
kellertuer Aug 27, 2023
1dc369d
Fix two typos.
kellertuer Aug 27, 2023
2b689cd
fix another typo.
kellertuer Aug 27, 2023
29818c3
Sketch Tutorial and debug the code a bit.
kellertuer Aug 28, 2023
4a9178c
Finish the tutorial.
kellertuer Aug 28, 2023
45f2b07
Update Tutorial and fix Docs.
kellertuer Aug 30, 2023
8589d51
Gradient and Hessian conversion tests.
kellertuer Aug 31, 2023
1c0c461
Test Coverage I.
kellertuer Aug 31, 2023
534bd63
Code Coverage II.
kellertuer Aug 31, 2023
bedba7e
Fix a small bug in the deco initialisation.
kellertuer Aug 31, 2023
ca9e2a8
Fix interaction of embedded objectives and sub solvers in default cre…
kellertuer Aug 31, 2023
96abe9b
Test Coverage III, read the new tutorial again, runs formatter.
kellertuer Aug 31, 2023
215bd37
Bump version.
kellertuer Sep 1, 2023
40c41a2
remove output from a few function definitions.
kellertuer Sep 1, 2023
b0d2d7a
Run the new embedding tutorial on dev.
kellertuer Sep 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 139 additions & 7 deletions src/plans/objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,155 @@
"""
struct InplaceEvaluation <: AbstractEvaluationType end

struct ReturnManifoldObjective{E,P,O<:AbstractManifoldObjective{E}} <:
AbstractDecoratedManifoldObjective{E,P}
objective::O
@doc raw"""
ReturnManifoldObjective{E,O2,O1<:AbstractManifoldObjective{E}} <:
AbstractDecoratedManifoldObjective{E,O2}

A wrapper to indicate that `get_solver_result` should return the inner objetcive.

The types are such that one can still dispatch on the undecorated type `O2` of the
original objective as well.
"""
struct ReturnManifoldObjective{E,O2,O1<:AbstractManifoldObjective{E}} <:
AbstractDecoratedManifoldObjective{E,O2}
objective::O1
end
function ReturnManifoldObjective(
o::O
) where {E<:AbstractEvaluationType,O<:AbstractManifoldObjective{E}}
return ReturnManifoldObjective{E,O,O}(o)
end
function ReturnManifoldObjective(
o::O
o::O1
) where {
E<:AbstractEvaluationType,
O2<:AbstractManifoldObjective,
O1<:AbstractDecoratedManifoldObjective{E,O2},
}
return ReturnManifoldObjective{E,O2,O1}(o)
end

@doc raw"""
EmbeddedManifoldObjective{P, T, E, O2, O1<:AbstractManifoldObjective{E}} <:
AbstractDecoratedManifoldObjective{O2, O1}

Declare an objective to be defined in the embedding.
This also declares the gradient to be defined in the embedding,
and especially being the Riesz representer with respect to the metric in the embedding.
The types can be used to still dispatch on also the undecorated objective type `O2`.

# Fields
* `objective` – the objective that is defined in the embedding
* `p` - (`nothing`) a point in the embedding.
* `X` - (`nothing`) a tangent vector in the embedding

When a point in the embedding `p` is provided, `embed!` is used in place of this point to reduce
memory allocations. Similarly `X` is used when embedding tangent vectors

"""
struct EmbeddedManifoldObjective{P,T,E,O2,O1<:AbstractManifoldObjective{E}} <:
AbstractDecoratedManifoldObjective{E,O2}
objective::O1
p::P
X::T
end
function EmbeddedManifoldObjective(

Check warning on line 102 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L102

Added line #L102 was not covered by tests
o::O, p::P=nothing, X::T=nothing
) where {P,T,E<:AbstractEvaluationType,O<:AbstractManifoldObjective{E}}
return EmbeddedManifoldObjective{P,T,E,O,O}(o, p, X)

Check warning on line 105 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L105

Added line #L105 was not covered by tests
end
function EmbeddedManifoldObjective(

Check warning on line 107 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L107

Added line #L107 was not covered by tests
o::O1, p::P=nothing, X::T=nothing
) where {
P,
T,
E<:AbstractEvaluationType,
P<:AbstractManifoldObjective,
O<:AbstractDecoratedManifoldObjective{E,P},
O2<:AbstractManifoldObjective,
O1<:AbstractDecoratedManifoldObjective{E,O2},
}
return ReturnManifoldObjective{E,P,O}(o)
return EmbeddedManifoldObjective{P,T,E,P,O}(o, p, X)

Check warning on line 116 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L116

Added line #L116 was not covered by tests
end
function EmbeddedManifoldObjective(

Check warning on line 118 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L118

Added line #L118 was not covered by tests
M::AbstractManifold,
o::O;
q=rand(M),
p::P=embed(M, q),
X::T=embed(M, q, rand(M; vector_at=q)),
) where {P,T,O<:AbstractManifoldObjective}
return EmbeddedManifoldObjective(o, p, X)

Check warning on line 125 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L125

Added line #L125 was not covered by tests
end

@doc raw"""
get_cost(M, emo::EmbeddedManifoldObjective, p)

Evaluate the cost function of an objective defined in the embedding, that is
call [`embed`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions/#ManifoldsBase.embed-Tuple{AbstractManifold,%20Any})
on the point `p` and call the original cost on this point.

if `emo.p` is not nothing, the embedding is done in place of `emo.p`
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
"""
function get_cost(M, emo::EmbeddedManifoldObjective{Nothing}, p)
return get_cost(get_embedding(M), emo.objective, embed(M, p))

Check warning on line 138 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L137-L138

Added lines #L137 - L138 were not covered by tests
end
function get_cost(M, emo::EmbeddedManifoldObjective{P}, p) where {P}
embed!(M, emo.p, p)
return get_cost(get_embedding(M), emo.objective, emo.p)

Check warning on line 142 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L140-L142

Added lines #L140 - L142 were not covered by tests
end

function get_cost_function(emo::EmbeddedManifoldObjective)
return (M, p) -> get_cost(M, emo, p)

Check warning on line 146 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L145-L146

Added lines #L145 - L146 were not covered by tests
end

@doc raw"""
get_gradient(M, emo::EmbeddedManifoldObjective, p)
get_gradient(M, X, emo::EmbeddedManifoldObjective, p)
kellertuer marked this conversation as resolved.
Show resolved Hide resolved

Evaluate the gradient function of an objective defined in the embedding, that is
call [`embed`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions/#ManifoldsBase.embed-Tuple{AbstractManifold,%20Any})
on the point `p` and call the original cost on this point.
And convert the gradient using [`riemannian_gradient`]() on the result.

if `emo.p` is not nothing, the embedding is done in place of `emo.p`
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
"""
function get_gradient(M, emo::EmbeddedManifoldObjective{Nothing,Nothing}, p)
return riemannian_gradient(

Check warning on line 161 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L160-L161

Added lines #L160 - L161 were not covered by tests
M, p, get_gradient(get_embedding(M), emo.objective, embed(M, p))
)
end
function get_gradient(M, emo::EmbeddedManifoldObjective{P,Nothing}, p) where {P}
embed!(M, emo.p, p)
return riemannian_gradient(M, p, get_gradient(get_embedding(M), emo.objective, emo.p))

Check warning on line 167 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L165-L167

Added lines #L165 - L167 were not covered by tests
end
function get_gradient(M, emo::EmbeddedManifoldObjective{Nothing,T}, p) where {T}
get_gradient!(get_embedding(M), emo.X, emo.objective, embed(M, p))
return riemannian_gradient(M, p, emo.X)

Check warning on line 171 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L169-L171

Added lines #L169 - L171 were not covered by tests
end
function get_gradient(M, emo::EmbeddedManifoldObjective{P,T}, p) where {P,T}
embed!(M, emo.p, p)
get_gradient!(get_embedding(M), emo.X, emo.objective, emo.p)
return riemannian_gradient(M, p, emo.X)

Check warning on line 176 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L173-L176

Added lines #L173 - L176 were not covered by tests
end
function get_gradient!(M, X, emo::EmbeddedManifoldObjective{Nothing,Nothing}, p)
riemannian_gradient!(

Check warning on line 179 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L178-L179

Added lines #L178 - L179 were not covered by tests
M, X, p, get_gradient(get_embedding(M), emo.objective, embed(M, p))
)
return X

Check warning on line 182 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L182

Added line #L182 was not covered by tests
end
function get_gradient!(M, X, emo::EmbeddedManifoldObjective{P,Nothing}, p) where {P}
embed!(M, emo.p, p)
riemannian_gradient!(M, X, p, get_gradient(get_embedding(M), emo, emo.p))
return X

Check warning on line 187 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L184-L187

Added lines #L184 - L187 were not covered by tests
end
function get_gradient!(M, X, emo::EmbeddedManifoldObjective{Nothing,T}, p) where {T}
get_gradient!(get_embedding(M), emo.X, emo, embed(M, p))
riemannian_gradient!(M, X, p, emo.X)
return X

Check warning on line 192 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L189-L192

Added lines #L189 - L192 were not covered by tests
end
function get_gradient!(M, X, emo::EmbeddedManifoldObjective{P,T}, p) where {P,T}
embed!(M, emo.p, p)
get_gradient!(get_embedding(M), emo.X, emo, emo.p)
riemannian_gradient!(M, X, p, emo.X)
return X

Check warning on line 198 in src/plans/objective.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/objective.jl#L194-L198

Added lines #L194 - L198 were not covered by tests
end

"""
Expand Down
26 changes: 20 additions & 6 deletions src/solvers/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@
optional arguments provide necessary details on the decorators.
A specific one is used to activate certain decorators.

* `cache` – (`missing`) specify a cache. Currenlty `:Simple` is supported and `:LRU` if you
* `cache` – (`missing`) specify a cache. Currenlty `:Simple` is supported and `:LRU` if you
load `LRUCache.jl`. For this case a tuple specifying what to cache and how many can be provided,
i.e. `(:LRU, [:Cost, :Gradient], 10)`, where the number specifies the size of each cache.
and 10 is the default if one omits the last tuple entry
* `count` – (`missing`) specify calls to the objective to be called, see [`ManifoldCountObjective`](@ref) for the full list

* `count` – (`missing`) specify calls to the objective to be called, see [`ManifoldCountObjective`](@ref) for the full list
* `objective` - (`missing`) specify that an objective is `:Euclidean`, which is equivalent to specifying the default embedding
the equivalent statement here is `:Embedded`, which fits better in naming when `M` is an []`
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
other keywords are ignored.

# See also
Expand All @@ -83,12 +84,25 @@
Missing,Symbol,Tuple{Symbol,<:AbstractArray},Tuple{Symbol,<:AbstractArray,P}
}=missing,
count::Union{Missing,AbstractVector{<:Symbol}}=missing,
objective::Union{Missing,Symbol}=missing,
_p=ismissing(objective) ? missing : rand(M),
embedded_p=ismissing(objective) ? missing : embed(M, _p),
embedded_X=ismissing(objective) ? missing : embed(M, _p, rand(M; vector_at=p)),
return_objective=false,
kwargs...,
) where {O<:AbstractManifoldObjective,P}
# Order: First count _then_ cache, so that cache is the outer decorator and
# we only count _after_ cache misses
deco_o = ismissing(count) ? o : objective_count_factory(M, o, count)
# Order:
# 1) wrap embedding,
# 2) _then_ count
# 3) _then_ cache,
# count should not be affected by 1) but cache should be on manifold not embedding
# => we only count _after_ cache misses
# and always last wrapper: ReturnObjective.
deco_o = o
if !ismissing(objective) && objective ∈ [:Embedding, :Euclidan]
deco_o = EmbeddedManifoldObjective(o, embedded_p, embedded_X)

Check warning on line 103 in src/solvers/solver.jl

View check run for this annotation

Codecov / codecov/patch

src/solvers/solver.jl#L103

Added line #L103 was not covered by tests
end
deco_o = ismissing(count) ? o : objective_count_factory(M, deco_o, count)
deco_o = ismissing(cache) ? deco_o : objective_cache_factory(M, deco_o, cache)
deco_o = return_objective ? ReturnManifoldObjective(deco_o) : deco_o
return deco_o
Expand Down
Loading