From 978e0173ec4d2edd6d4b4c12218f9838037eca8b Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 22 May 2023 11:42:01 +0200 Subject: [PATCH] rand for GL and Heisenberg group (#610) * rand for GL and Heisenberg group * Some docs * format * Update src/groups/general_linear.jl Co-authored-by: Ronny Bergmann * fix compose on semidirect product group * bump version * Upload coverage from all Julia versions (less chance of random upload error leading to incomplete coverage) --------- Co-authored-by: Ronny Bergmann --- .github/workflows/ci.yml | 2 +- Project.toml | 2 +- src/groups/general_linear.jl | 20 +++++++++++++ src/groups/heisenberg.jl | 37 +++++++++++++++++++++++++ src/groups/semidirect_product_group.jl | 4 ++- test/groups/general_linear.jl | 2 ++ test/groups/heisenberg.jl | 2 ++ test/groups/semidirect_product_group.jl | 5 ++++ 8 files changed, 71 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9adb54efe0..995f2b3632 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,4 +32,4 @@ jobs: - uses: codecov/codecov-action@v3 with: fail_ci_if_error: false - if: ${{ matrix.julia-version == '1.8' && matrix.os =='ubuntu-latest' }} + if: ${{ matrix.os =='ubuntu-latest' }} diff --git a/Project.toml b/Project.toml index 29353750a1..dfba41c6b9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.8.62" +version = "0.8.63" [deps] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" diff --git a/src/groups/general_linear.jl b/src/groups/general_linear.jl index 37d7c51212..c8f171bf65 100644 --- a/src/groups/general_linear.jl +++ b/src/groups/general_linear.jl @@ -246,6 +246,26 @@ project(::GeneralLinear, p, X) = X project!(::GeneralLinear, q, p) = copyto!(q, p) project!(::GeneralLinear, Y, p, X) = copyto!(Y, X) +@doc raw""" + Random.rand(G::GeneralLinear; vector_at=nothing, kwargs...) + +If `vector_at` is `nothing`, return a random point on the [`GeneralLinear`](@ref) group `G` +by using `rand` in the embedding. + +If `vector_at` is not `nothing`, return a random tangent vector from the tangent space of +the point `vector_at` on the [`GeneralLinear`](@ref) by using by using `rand` in the embedding. +""" +rand(G::GeneralLinear; kwargs...) + +function Random.rand!(G::GeneralLinear, pX; kwargs...) + rand!(get_embedding(G), pX; kwargs...) + return pX +end +function Random.rand!(rng::AbstractRNG, G::GeneralLinear, pX; kwargs...) + rand!(rng, get_embedding(G), pX; kwargs...) + return pX +end + Base.show(io::IO, ::GeneralLinear{n,𝔽}) where {n,𝔽} = print(io, "GeneralLinear($n, $𝔽)") translate_diff(::GeneralLinear, p, q, X, ::LeftAction) = X diff --git a/src/groups/heisenberg.jl b/src/groups/heisenberg.jl index 3910228025..e9c737d36c 100644 --- a/src/groups/heisenberg.jl +++ b/src/groups/heisenberg.jl @@ -353,6 +353,43 @@ function project!(M::HeisenbergGroup{n}, Y, p, X) where {n} return Y end +@doc raw""" + Random.rand(M::HeisenbergGroup; vector_at = nothing, σ::Real=1.0) + +If `vector_at` is `nothing`, return a random point on the [`HeisenbergGroup`](@ref) `M` +by sampling elements of the first row and the last column from the normal distribution with +mean 0 and standard deviation `σ`. + +If `vector_at` is not `nothing`, return a random tangent vector from the tangent space of +the point `vector_at` on the [`HeisenbergGroup`](@ref) by using a normal distribution with +mean 0 and standard deviation `σ`. +""" +rand(M::HeisenbergGroup; vector_at=nothing, σ::Real=1.0) + +function Random.rand!( + rng::AbstractRNG, + ::HeisenbergGroup{n}, + pX; + σ::Real=one(eltype(pX)), + vector_at=nothing, +) where {n} + if vector_at === nothing + copyto!(pX, I) + va = view(pX, 1, 2:(n + 2)) + randn!(rng, va) + va .*= σ + vb = view(pX, 2:(n + 1), n + 2) + randn!(rng, vb) + vb .*= σ + else + fill!(pX, 0) + randn!(rng, view(pX, 1, 2:(n + 2))) + randn!(rng, view(pX, 2:(n + 1), n + 2)) + pX .*= σ + end + return pX +end + Base.show(io::IO, ::HeisenbergGroup{n}) where {n} = print(io, "HeisenbergGroup($n)") translate_diff(::HeisenbergGroup, p, q, X, ::LeftAction) = X diff --git a/src/groups/semidirect_product_group.jl b/src/groups/semidirect_product_group.jl index a7b6e2c514..11743df2b8 100644 --- a/src/groups/semidirect_product_group.jl +++ b/src/groups/semidirect_product_group.jl @@ -123,13 +123,15 @@ function _compose!(G::SemidirectProductGroup, x, p, q) M = base_manifold(G) N, H = M.manifolds A = G.op.action + x_tmp = allocate(x) np, hp = submanifold_components(G, p) nq, hq = submanifold_components(G, q) - nx, hx = submanifold_components(G, x) + nx, hx = submanifold_components(G, x_tmp) compose!(H, hx, hp, hq) nxtmp = apply(A, hp, nq) compose!(N, nx, np, nxtmp) @inbounds _padpoint!(G, x) + copyto!(x, x_tmp) return x end diff --git a/test/groups/general_linear.jl b/test/groups/general_linear.jl index 04a0c07bc4..9462273c99 100644 --- a/test/groups/general_linear.jl +++ b/test/groups/general_linear.jl @@ -131,6 +131,8 @@ using NLsolve basis_types_to_from=basis_types, exp_log_atol_multiplier=1e7, retraction_atol_multiplier=1e7, + test_rand_point=true, + test_rand_tvector=true, ) end end diff --git a/test/groups/heisenberg.jl b/test/groups/heisenberg.jl index 7780e4447b..120bd43ab1 100644 --- a/test/groups/heisenberg.jl +++ b/test/groups/heisenberg.jl @@ -52,5 +52,7 @@ include("group_utils.jl") test_project_point=true, test_project_tangent=true, test_musical_isomorphisms=false, + test_rand_point=true, + test_rand_tvector=true, ) end diff --git a/test/groups/semidirect_product_group.jl b/test/groups/semidirect_product_group.jl index 679f9c9195..8d37078b67 100644 --- a/test/groups/semidirect_product_group.jl +++ b/test/groups/semidirect_product_group.jl @@ -43,6 +43,11 @@ include("group_utils.jl") @test compose(G, pts[1], e) == pts[1] @test compose(G, e, e) === e + # test in-place composition + o1 = copy(pts[1]) + compose!(G, o1, o1, pts[2]) + @test isapprox(G, o1, compose(G, pts[1], pts[2])) + eA = identity_element(G) @test isapprox(G, eA, e) @test isapprox(G, e, eA)