Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Jan 27, 2025
1 parent d4bcc8c commit e476e82
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 59 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.4.0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[compat]
Aqua = "0.8"
Expand All @@ -15,6 +16,7 @@ Printf = "1"
Random = "1"
ScopedValues = "1"
Test = "1"
VectorInterface = "0.5"
julia = "1.8"

[extras]
Expand Down
29 changes: 24 additions & 5 deletions src/OptimKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module OptimKit
using LinearAlgebra: LinearAlgebra
using Printf
using ScopedValues
using VectorInterface
using Base: @kwdef

# Default values for the keyword arguments using ScopedValues
Expand All @@ -14,15 +15,34 @@ const GRADTOL = ScopedValue(1e-8)
const MAXITER = ScopedValue(1_000_000)
const VERBOSITY = ScopedValue(1)

_retract(x, d, α) = (x + α * d, d)
_inner(x, v1, v2) = v1 === v2 ? LinearAlgebra.norm(v1)^2 : LinearAlgebra.dot(v1, v2)
# Default values for the manifold structure
_retract(x, d, α) = (add(x, d, α), d)
_inner(x, v1, v2) = v1 === v2 ? norm(v1)^2 : real(inner(v1, v2))
_transport!(v, xold, d, α, xnew) = v
_scale!(v, α) = LinearAlgebra.rmul!(v, α)
_add!(vdst, vsrc, α) = LinearAlgebra.axpy!, vsrc, vdst)
_scale!(v, α) = scale!!(v, α)
_add!(vdst, vsrc, α) = add!!(vdst, vsrc, α)

_precondition(x, g) = deepcopy(g)
_finalize!(x, f, g, numiter) = x, f, g

# Default structs for new convergence and termination keywords
@kwdef struct DefaultHasConverged{T<:Real}
gradtol::T
end

function (d::DefaultHasConverged)(x, f, g, normgrad)
return normgrad <= d.gradtol
end

@kwdef struct DefaultShouldStop
maxiter::Int
end

function (d::DefaultShouldStop)(x, f, g, numfg, numiter, t)
return numiter >= d.maxiter
end

# Optimization
abstract type OptimizationAlgorithm end

const _xlast = Ref{Any}()
Expand Down Expand Up @@ -85,7 +105,6 @@ Also see [`GradientDescent`](@ref), [`ConjugateGradient`](@ref), [`LBFGS`](@ref)
function optimize end

include("linesearches.jl")
include("terminate.jl")
include("gd.jl")
include("cg.jl")
include("lbfgs.jl")
Expand Down
34 changes: 0 additions & 34 deletions src/tangentvector.jl

This file was deleted.

15 changes: 0 additions & 15 deletions src/terminate.jl

This file was deleted.

24 changes: 19 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,30 @@ function quadraticproblem(B, y)
return fg
end

function quadratictupleproblem(B, y)
function fg(x)
x1, x2 = x
y1, y2 = y
g1 = B * (x1 - y1)
g2 = x2 - y2
f = dot(x1 - y1, g1) / 2 + (x2 - y2)^2 / 2
return f, (g1, g2)
end
return fg
end

algorithms = (GradientDescent, ConjugateGradient, LBFGS)

@testset "Optimization Algorithm $algtype" for algtype in algorithms
n = 10
y = randn(n)
A = randn(n, n)
fg = quadraticproblem(A' * A, y)
A = A' * A
fg = quadraticproblem(A, y)
x₀ = randn(n)
alg = algtype(; verbosity=2, gradtol=1e-12, maxiter=10_000_000)
x, f, g, numfg, normgradhistory = optimize(fg, x₀, alg)
@test x y rtol = 10 * cond(A) * 1e-12
@test x y rtol = cond(A) * 1e-12
@test f < 1e-12

n = 1000
Expand All @@ -68,11 +81,12 @@ algorithms = (GradientDescent, ConjugateGradient, LBFGS)
smax = maximum(S)
A = U * Diagonal(1 .+ S ./ smax) * U'
# well conditioned, all eigenvalues between 1 and 2
fg = quadraticproblem(A' * A, y)
x₀ = randn(n)
fg = quadratictupleproblem(A' * A, (y, 1.0))
x₀ = (randn(n), 2.0)
alg = algtype(; verbosity=3, gradtol=1e-8)
x, f, g, numfg, normgradhistory = optimize(fg, x₀, alg)
@test x y rtol = 1e-7
@test x[1] y rtol = 1e-7
@test x[2] 1 rtol = 1e-7
@test f < 1e-12
end

Expand Down

2 comments on commit e476e82

@Jutho
Copy link
Owner Author

@Jutho Jutho commented on e476e82 Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Default functions for manipulating tangents and retractions now use VectorInterface.jl, so that tuples and nested structures are more easily supported
  • Various changes to improve robustness and configurability of both the optimization algorithms and the inner linesearch, in particular:
    • custom convergence and termination criteria
    • changes to info and warning output formatting and level settings
    • easier to control linesearch maximum number of iterations and maximum number of function evaluations

Breaking changes

  • The verbosity levels have changed meaning (slightly) so that printed output might not be the same as before
  • Requires Julia 1.8 because of use of ScopedVariables.jl

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/123678

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" e476e82d6c0dd0d0b70f5865991fc0359ba07946
git push origin v0.4.0

Please sign in to comment.