Skip to content

Commit

Permalink
Stiefel, SPD, minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Sep 6, 2023
1 parent f8961e7 commit b399539
Show file tree
Hide file tree
Showing 18 changed files with 325 additions and 251 deletions.
31 changes: 28 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Sizes of all manifolds can now be either encoded in type or stored in a field to avoid over-specialization.
The default is set to store the size in a field. To obtain the old behavior, pass the `parameter=:type` keyword
argument to manifold constructor. Related changes:
- Statically sized `SpecialEuclidean{N}` is now `SpecialEuclidean{TypeParameter{Tuple{N}}}`, whereas the type of special Euclidean group with field-stored size is `SpecialEuclidean{Tuple{Int}}`. Similar change applies to `GeneralUnitaryMultiplicationGroup{n}`, `Orthogonal{n}`, `SpecialOrthogonal{n}`, `SpecialUnitary{n}`, `SpecialEuclideanManifold{n}`, `TranslationGroup`. For example
- Statically sized `SpecialEuclidean{N}` is now `SpecialEuclidean{TypeParameter{Tuple{N}}}`, whereas the type of special Euclidean group with field-stored size is `SpecialEuclidean{Tuple{Int}}`. Similar change applies to:
- `CholeskySpace{N}`,
- `Euclidean`,
- `GeneralUnitaryMultiplicationGroup{n}`,
- `Grassmann{n,k}`,
- `Orthogonal{n}`,
- `SpecialOrthogonal{n}`,
- `SpecialUnitary{n}`,
- `SpecialEuclideanManifold{n}`,
- `Stiefel{n,k}`,
- `SymmetricPositiveDefinite{n}`,
- `TranslationGroup`.

For example

```{julia}
function Base.show(io::IO, ::SpecialEuclidean{n}) where {n}
Expand All @@ -43,10 +56,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
end
```

for groups with size stored in field.
for groups with size stored in field. Alternatively, you can use a single generic method like this:

```{julia}
function Base.show(io::IO, G::SpecialEuclidean{T}) where {T}
n = get_n(G)
if T <: TypeParameter
return print(io, "SpecialEuclidean($(n); parameter=:type)")
else
return print(io, "SpecialEuclidean($(n))")
end
end
```

- Argument order for type alias `RotationActionOnVector`: most often dispatched on argument is now first.

### Removed

- `ProductRepr` is removed; please use `ArrayPartition` instead.
- Default methods throwing "not implemented" `ErrorException` for some group-related operations.
- Default methods throwing "not implemented" `ErrorException` for some group-related operations.
34 changes: 27 additions & 7 deletions src/manifolds/CholeskySpace.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
@doc raw"""
CholeskySpace{N} <: AbstractManifold{ℝ}
CholeskySpace{T} <: AbstractManifold{ℝ}
The manifold of lower triangular matrices with positive diagonal and
a metric based on the cholesky decomposition. The formulae for this manifold
are for example summarized in Table 1 of [Lin:2019](@cite).
# Constructor
CholeskySpace(n)
CholeskySpace(n; parameter::Symbol=:field)
Generate the manifold of $n× n$ lower triangular matrices with positive diagonal.
"""
struct CholeskySpace{N} <: AbstractManifold{ℝ} end
struct CholeskySpace{T} <: AbstractManifold{ℝ}
size::T
end

CholeskySpace(n::Int) = CholeskySpace{n}()
function CholeskySpace(n::Int; parameter::Symbol=:field)
size = wrap_type_parameter(parameter, (n,))
return CholeskySpace{typeof(size)}(size)
end

@doc raw"""
check_point(M::CholeskySpace, p; kwargs...)
Expand Down Expand Up @@ -105,6 +110,9 @@ function exp!(::CholeskySpace, q, p, X)
return q
end

get_n(::CholeskySpace{TypeParameter{N}}) where {N} = N
get_n(M::CholeskySpace{Tuple{Int}}) = get_parameter(M.size)[1]

@doc raw"""
inner(M::CholeskySpace, p, X, Y)
Expand Down Expand Up @@ -164,16 +172,28 @@ Return the manifold dimension for the [`CholeskySpace`](@ref) `M`, i.e.
\dim(\mathcal M) = \frac{N(N+1)}{2}.
````
"""
@generated manifold_dimension(::CholeskySpace{N}) where {N} = div(N * (N + 1), 2)
function manifold_dimension(M::CholeskySpace)
N = get_n(M)
return div(N * (N + 1), 2)
end

@doc raw"""
representation_size(M::CholeskySpace)
Return the representation size for the [`CholeskySpace`](@ref)`{N}` `M`, i.e. `(N,N)`.
"""
@generated representation_size(::CholeskySpace{N}) where {N} = (N, N)
function representation_size(M::CholeskySpace)
N = get_n(M)
return (N, N)
end

Base.show(io::IO, ::CholeskySpace{N}) where {N} = print(io, "CholeskySpace($(N))")
function Base.show(io::IO, ::CholeskySpace{TypeParameter{Tuple{n}}}) where {n}
return print(io, "CholeskySpace($(n); parameter=:type)")
end
function Base.show(io::IO, M::CholeskySpace{Tuple{Int}})
n = get_n(M)
return print(io, "CholeskySpace($(n))")
end

# two small helpers for strictly lower and upper triangulars
strictlyLowerTriangular(p) = LowerTriangular(p) - Diagonal(diag(p))
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/Grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ A good overview can be found in[BendokatZimmermannAbsil:2020](@cite).
# Constructor
Grassmann(n,k,field=ℝ)
Grassmann(n, k, field=ℝ, parameter::Symbol=:field)
Generate the Grassmann manifold $\operatorname{Gr}(n,k)$, where the real-valued
case `field = ℝ` is the default.
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/GrassmannStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end
ManifoldsBase.@manifold_element_forwards StiefelPoint value
ManifoldsBase.@manifold_vector_forwards StiefelTVector value
ManifoldsBase.@default_manifold_fallbacks Stiefel StiefelPoint StiefelTVector value value
ManifoldsBase.@default_manifold_fallbacks (Stiefel{n,k,ℝ} where {n,k}) StiefelPoint StiefelTVector value value
ManifoldsBase.@default_manifold_fallbacks (Stiefel{<:Any,ℝ}) StiefelPoint StiefelTVector value value
ManifoldsBase.@default_manifold_fallbacks Grassmann StiefelPoint StiefelTVector value value

function default_vector_transport_method(::Grassmann, ::Type{<:AbstractArray})
Expand Down
3 changes: 2 additions & 1 deletion src/manifolds/Rotations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,8 @@ and that means the inverse has to be appliead to the (Euclidean) Hessian
to map it into the Lie algebra.
"""
riemannian_Hessian(M::Rotations, p, G, H, X)
function riemannian_Hessian!(::Rotations{N}, Y, p, G, H, X) where {N}
function riemannian_Hessian!(M::Rotations, Y, p, G, H, X)
N = get_n(M)
symmetrize!(Y, G' * p)
project!(SkewSymmetricMatrices(N), Y, p' * H - X * Y)
return Y
Expand Down
Loading

0 comments on commit b399539

Please sign in to comment.