Skip to content

Commit

Permalink
adapt TR
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Aug 1, 2024
1 parent ab0bb15 commit 505145a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
6 changes: 1 addition & 5 deletions src/solvers/difference_of_convex_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,7 @@ end
function initialize_solver!(::AbstractManoptProblem, dcs::DifferenceOfConvexState)
return dcs
end
function step_solver!(
amp::AbstractManoptProblem,
dcs::DifferenceOfConvexState{<:AbstractManoptProblem,<:AbstractManoptSolverState},
i,
)
function step_solver!(amp::AbstractManoptProblem, dcs::DifferenceOfConvexState, i)
M = get_manifold(amp)
get_subtrahend_gradient!(amp, dcs.X, dcs.p)
set_manopt_parameter!(dcs.sub_problem, :Objective, :Cost, :p, dcs.p)
Expand Down
43 changes: 28 additions & 15 deletions src/solvers/trust_regions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ A trust region state, where the sub problem is solved in closed form by a functi
[`trust_regions`](@ref), [`trust_regions!`](@ref)
"""
mutable struct TrustRegionsState{
P,T,Pr,St,SC<:StoppingCriterion,RTR<:AbstractRetractionMethod,R<:Real,Proj
P,
T,
Pr,
St<:AbstractManoptSolverState,
SC<:StoppingCriterion,
RTR<:AbstractRetractionMethod,
R<:Real,
Proj,
} <: AbstractSubProblemSolverState
p::P
X::T
Expand Down Expand Up @@ -111,7 +118,16 @@ mutable struct TrustRegionsState{
reduction_factor=0.25,
augmentation_factor=2.0,
σ::R=random ? 1e-6 : 0.0,
) where {P,T,Pr,St,SC<:StoppingCriterion,RTR<:AbstractRetractionMethod,R<:Real,Proj}
) where {
P,
T,
Pr,
St<:AbstractManoptSolverState,
SC<:StoppingCriterion,
RTR<:AbstractRetractionMethod,
R<:Real,
Proj,
}
trs = new{P,T,Pr,St,SC,RTR,R,Proj}()
trs.p = p
trs.X = X
Expand Down Expand Up @@ -140,23 +156,20 @@ function TrustRegionsState(
return TrustRegionsState(M, rand(M), sub_problem; kwargs...)
end
# No point but state -> add point
function TrustRegionsState(
M, sub_problem::Pr, sub_state::St; kwargs...
) where {
Pr<:Union{AbstractManoptProblem,T} where {T},
St<:Union{AbstractManoptSolverState,AbstractEvaluationType},
}
function TrustRegionsState(M, sub_problem, sub_state::AbstractManoptSolverState; kwargs...)
return TrustRegionsState(M, rand(M), sub_problem, sub_state; kwargs...)
end
# point, but no state for a function -> add evaluation as state
function TrustRegionsState(
M,
p,
sub_problem::Pr;
sub_problem::Function;
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
) where {Pr<:Function}
return TrustRegionsState(M, p, sub_problem, evaluation; kwargs...)
)
return TrustRegionsState(

Check warning on line 170 in src/solvers/trust_regions.jl

View check run for this annotation

Codecov / codecov/patch

src/solvers/trust_regions.jl#L170

Added line #L170 was not covered by tests
M, p, sub_problem, ClosedFormSubSolverState(evaluation); kwargs...
)
end
function TrustRegionsState(
M, p, mho::H; kwargs...
Expand All @@ -170,7 +183,7 @@ function TrustRegionsState(
M::TM,
p::P,
sub_problem::Pr,
sub_state::St=TruncatedConjugateGradientState(
sub_state::Union{AbstractEvaluationType,AbstractManoptSolverState}=TruncatedConjugateGradientState(
TangentSpace(M, copy(M, p)), zero_vector(M, p)
);
X::T=zero_vector(M, p),
Expand All @@ -191,15 +204,15 @@ function TrustRegionsState(
) where {
TM<:AbstractManifold,
Pr<:AbstractManoptProblem,
St,
P,
T,
R<:Real,
SC<:StoppingCriterion,
RTR<:AbstractRetractionMethod,
Proj,
}
return TrustRegionsState{P,T,Pr,St,SC,RTR,R,Proj}(
sub_state_storage = maybe_wrap_evaluation_type(sub_state)
return TrustRegionsState{P,T,Pr,typeof(sub_state_storage),SC,RTR,R,Proj}(
p,
X,
trust_region_radius,
Expand All @@ -212,7 +225,7 @@ function TrustRegionsState(
reduction_threshold,
augmentation_threshold,
sub_problem,
sub_state,
sub_state_storage,
project!,
reduction_factor,
augmentation_factor,
Expand Down

0 comments on commit 505145a

Please sign in to comment.