From 7f6d6f3396ab7529ba7fb93ed924e488c3801c91 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Sep 2019 16:46:30 +0300 Subject: [PATCH] cholesky --- src/Kronecker.jl | 3 +- src/base.jl | 11 ++-- src/cholesky.jl | 65 +++++++++++++++++++ src/factorization.jl | 2 +- src/indexedkroncker.jl | 6 +- src/kroneckerpowers.jl | 18 +++-- src/kroneckersum.jl | 4 +- test/{testbase.jl => base.jl} | 0 test/cholesky.jl | 32 +++++++++ test/{testeigen.jl => eigen.jl} | 0 ...{testfactorization.jl => factorization.jl} | 0 test/{testindexed.jl => indexed.jl} | 0 ...tkroneckergraphs.jl => kroneckergraphs.jl} | 0 ...tkroneckerpowers.jl => kroneckerpowers.jl} | 0 test/{testkroneckersum.jl => kroneckersum.jl} | 0 test/runtests.jl | 17 ++--- test/{testvectrick.jl => vectrick.jl} | 0 17 files changed, 129 insertions(+), 29 deletions(-) create mode 100644 src/cholesky.jl rename test/{testbase.jl => base.jl} (100%) create mode 100644 test/cholesky.jl rename test/{testeigen.jl => eigen.jl} (100%) rename test/{testfactorization.jl => factorization.jl} (100%) rename test/{testindexed.jl => indexed.jl} (100%) rename test/{testkroneckergraphs.jl => kroneckergraphs.jl} (100%) rename test/{testkroneckerpowers.jl => kroneckerpowers.jl} (100%) rename test/{testkroneckersum.jl => kroneckersum.jl} (100%) rename test/{testvectrick.jl => vectrick.jl} (100%) diff --git a/src/Kronecker.jl b/src/Kronecker.jl index 8daf871..80c55e5 100644 --- a/src/Kronecker.jl +++ b/src/Kronecker.jl @@ -20,7 +20,8 @@ include("kroneckerpowers.jl") include("vectrick.jl") include("indexedkroncker.jl") include("eigen.jl") -include("factorization.jl") +include("cholesky.jl") +# include("factorization.jl") include("kroneckersum.jl") include("kroneckergraphs.jl") diff --git a/src/base.jl b/src/base.jl index f24f4e6..3d2ae5b 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,11 +1,11 @@ -abstract type GeneralizedKroneckerProduct <: AbstractMatrix{Number} end +abstract type GeneralizedKroneckerProduct{T} <: AbstractMatrix{T} end -abstract type AbstractKroneckerProduct <: GeneralizedKroneckerProduct end +abstract type AbstractKroneckerProduct{T} <: GeneralizedKroneckerProduct{T} end Base.IndexStyle(::Type{<:GeneralizedKroneckerProduct}) = IndexCartesian() # general Kronecker product between two matrices -struct KroneckerProduct{T,TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerProduct +struct KroneckerProduct{T, TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerProduct{T} A::TA B::TB function KroneckerProduct(A::AbstractMatrix{T}, B::AbstractMatrix{V}) where {T, V} @@ -63,10 +63,7 @@ Returns a matrix itself. Needed for recursion. """ getmatrices(A::AbstractArray) = (A,) -function Base.eltype(K::AbstractKroneckerProduct) - A, B = getmatrices(K) - return promote_type(eltype(A), eltype(B)) -end +Base.eltype(K::AbstractKroneckerProduct{T}) where {T} = T function Base.size(K::AbstractKroneckerProduct) A, B = getmatrices(K) diff --git a/src/cholesky.jl b/src/cholesky.jl new file mode 100644 index 0000000..01beb3e --- /dev/null +++ b/src/cholesky.jl @@ -0,0 +1,65 @@ +import Base: getproperty +import LinearAlgebra: cholesky, Cholesky, char_uplo, UpperTriangular, LowerTriangular, \, + istril, istriu + +const KroneckerCholesky{T} = Cholesky{T, <:AbstractKroneckerProduct{T}} where {T} + +function cholesky(A::AbstractKroneckerProduct; check=true) + P, Q = getmatrices(A) + chol_P, chol_Q = cholesky(P; check=true), cholesky(Q; check=true) + return Cholesky(chol_P.factors ⊗ chol_Q.factors, 'U', 0) +end + +function Cholesky(factors::KroneckerProduct{T}, uplo::AbstractChar, info::Integer) where {T} + return Cholesky{T, typeof(factors)}(factors, uplo, info) +end + +function UpperTriangular(C::AbstractKroneckerProduct) + A, B = getmatrices(C) + return UpperTriangular(A) ⊗ UpperTriangular(B) +end + +function LowerTriangular(C::AbstractKroneckerProduct) + A, B = getmatrices(C) + return LowerTriangular(A) ⊗ LowerTriangular(B) +end + +function istril(C::AbstractKroneckerProduct) + A, B = getmatrices(C) + return istril(A) && istril(B) +end + +function istriu(C::AbstractKroneckerProduct) + A, B = getmatrices(C) + return istriu(A) && istriu(B) +end + +function getproperty(C::KroneckerCholesky, d::Symbol) + Cfactors = getfield(C, :factors) + Cuplo = getfield(C, :uplo) + info = getfield(C, :info) + + Cuplo != 'U' && throw(NotImplementedError("")) + + if d == :U + return UpperTriangular(Cuplo === char_uplo(d) ? Cfactors : copy(Cfactors')) + elseif d == :L + return LowerTriangular(Cuplo === char_uplo(d) ? Cfactors : copy(Cfactors')) + elseif d == :UL + return (Cuplo === 'U' ? UpperTriangular(Cfactors) : LowerTriangular(Cfactors)) + else + return getfield(C, d) + end +end + +function logdet(C::KroneckerCholesky) + A, B = getmatrices(C.factors) + logdet_A = logdet(Cholesky(A, C.uplo, 0)) + logdet_B = logdet(Cholesky(B, C.uplo, 0)) + return size(B, 1) * logdet_A + size(A, 1) * logdet_B +end + +function \(C::KroneckerCholesky, x::AbstractVecOrMat) + C_upper = C.U + return C_upper \ (C_upper' \ x) +end diff --git a/src/factorization.jl b/src/factorization.jl index 00c7879..1a3b426 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -9,7 +9,7 @@ Standard matrix factorization algorithms applied on Kronecker systems. =# -abstract type FactorizedKronecker <: AbstractKroneckerProduct end +abstract type FactorizedKronecker{T} <: AbstractKroneckerProduct{T} end # CHOLESKY DECOMPOSITION # ---------------------- diff --git a/src/indexedkroncker.jl b/src/indexedkroncker.jl index 6a3471d..109dc13 100644 --- a/src/indexedkroncker.jl +++ b/src/indexedkroncker.jl @@ -1,7 +1,7 @@ Index = Union{UnitRange{I}, AbstractVector{I}} where I <: Int #TODO: make for general indices -struct IndexedKroneckerProduct <: GeneralizedKroneckerProduct - K::AbstractKroneckerProduct +struct IndexedKroneckerProduct{T} <: GeneralizedKroneckerProduct{T} + K::AbstractKroneckerProduct{T} p::Index q::Index r::Index @@ -25,7 +25,7 @@ struct IndexedKroneckerProduct <: GeneralizedKroneckerProduct if !(maximum(r) ≤ n && maximum(t) ≤ l) throw(BoundsError("Indices exeed matrix bounds")) end - return new(K, p, q, r, t) + return new{eltype(K)}(K, p, q, r, t) end end diff --git a/src/kroneckerpowers.jl b/src/kroneckerpowers.jl index 4e88f7a..7066cd0 100644 --- a/src/kroneckerpowers.jl +++ b/src/kroneckerpowers.jl @@ -13,12 +13,12 @@ Efficient way of storing Kronecker powers, e.g. K = A ⊗ A ⊗ ... ⊗ A. """ -struct KroneckerPower{TA<:AbstractMatrix, N} <: AbstractKroneckerProduct +struct KroneckerPower{T, TA<:AbstractMatrix{T}, N} <: AbstractKroneckerProduct{T} A::TA pow::Integer function KroneckerPower(A::AbstractMatrix{T}, pow::Integer) where {T} @assert pow ≥ 2 "KroneckerPower only makes sense for powers greater than 1" - return new{typeof(A), pow}(A, pow) + return new{T, typeof(A), pow}(A, pow) end end @@ -32,9 +32,11 @@ kronecker(A::AbstractMatrix, pow::Int) = KroneckerPower(A, pow) ⊗(A::AbstractMatrix, pow::Int) = kronecker(A, pow) -getmatrices(K::KroneckerPower{T, N}) where {T, N} = (K.A, KroneckerPower(K.A, K.pow-1)) -getmatrices(K::KroneckerPower{T, 2}) where {T} = (K.A, K.A) -getmatrices(K::KroneckerPower{T, 1}) where {T} = (K.A, ) +function getmatrices(K::KroneckerPower{T, TA, N}) where {T, TA, N} + return (K.A, KroneckerPower(K.A, K.pow-1)) +end +getmatrices(K::KroneckerPower{T, TA, 2}) where {T, TA} = (K.A, K.A) +getmatrices(K::KroneckerPower{T, TA, 1}) where {T, TA} = (K.A, ) order(K::KroneckerPower) = K.pow Base.size(K::KroneckerPower) = size(K.A).^K.pow @@ -118,8 +120,10 @@ function Base.conj(K::KroneckerPower) end # mixed-product property -function Base.:*(K1::KroneckerPower{T1, N}, - K2::KroneckerPower{T2, N}) where {T1, T2, N} +function Base.:*( + K1::KroneckerPower{T1, TA1, N}, + K2::KroneckerPower{T1, TA2, N}, +) where {T1, TA1, T2, TA2, N} if size(K1, 2) != size(K2, 1) throw(DimensionMismatch("Mismatch between K1 and K2")) end diff --git a/src/kroneckersum.jl b/src/kroneckersum.jl index 66022cd..2050b31 100644 --- a/src/kroneckersum.jl +++ b/src/kroneckersum.jl @@ -1,6 +1,6 @@ -abstract type AbstractKroneckerSum <: GeneralizedKroneckerProduct end +abstract type AbstractKroneckerSum{T} <: GeneralizedKroneckerProduct{T} end -struct KroneckerSum{T, TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerSum +struct KroneckerSum{T, TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerSum{T} A::TA B::TB function KroneckerSum(A::AbstractMatrix{T}, diff --git a/test/testbase.jl b/test/base.jl similarity index 100% rename from test/testbase.jl rename to test/base.jl diff --git a/test/cholesky.jl b/test/cholesky.jl new file mode 100644 index 0000000..2355d7d --- /dev/null +++ b/test/cholesky.jl @@ -0,0 +1,32 @@ +to_psd(A) = A * A' + I + +@testset "cholesky" begin + rng = MersenneTwister(123456) + M, N = 7, 3 + A, B = to_psd(randn(rng, M, M)), to_psd(randn(rng, N, N)) + + # Construct Kronecker-factored Cholesky + A_kron_B = A ⊗ B + chol_A_kron_B = cholesky(A_kron_B) + + # Construct equivalent dense Cholesky. + A_kron_B_dense = kron(A, B) + chol_A_kron_B_dense = cholesky(A_kron_B_dense) + + # Check for agreement in user-facing properties of Kronecker-factored and dense. + @test chol_A_kron_B.U ≈ chol_A_kron_B_dense.U + @test chol_A_kron_B.L ≈ chol_A_kron_B_dense.L + @test det(chol_A_kron_B) ≈ det(chol_A_kron_B_dense) + @test logdet(chol_A_kron_B) ≈ logdet(chol_A_kron_B_dense) + + @show typeof(chol_A_kron_B.U) + @show typeof(chol_A_kron_B.U') + + # Test backsolve vs dense vector from the left. + x = randn(rng, M * N) + @test chol_A_kron_B \ x ≈ chol_A_kron_B_dense \ x + + # Test backsolve vs dense matrix from the left. + X = randn(rng, M * N, 11) + @test chol_A_kron_B \ X ≈ chol_A_kron_B_dense \ X +end diff --git a/test/testeigen.jl b/test/eigen.jl similarity index 100% rename from test/testeigen.jl rename to test/eigen.jl diff --git a/test/testfactorization.jl b/test/factorization.jl similarity index 100% rename from test/testfactorization.jl rename to test/factorization.jl diff --git a/test/testindexed.jl b/test/indexed.jl similarity index 100% rename from test/testindexed.jl rename to test/indexed.jl diff --git a/test/testkroneckergraphs.jl b/test/kroneckergraphs.jl similarity index 100% rename from test/testkroneckergraphs.jl rename to test/kroneckergraphs.jl diff --git a/test/testkroneckerpowers.jl b/test/kroneckerpowers.jl similarity index 100% rename from test/testkroneckerpowers.jl rename to test/kroneckerpowers.jl diff --git a/test/testkroneckersum.jl b/test/kroneckersum.jl similarity index 100% rename from test/testkroneckersum.jl rename to test/kroneckersum.jl diff --git a/test/runtests.jl b/test/runtests.jl index 1dfd9d5..5e36ad2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,13 @@ using Kronecker, Test, LinearAlgebra, Random, FillArrays using SparseArrays: SparseMatrixCSC, sprand, AbstractSparseMatrix @testset "Kronecker" begin - include("testbase.jl") - include("testkroneckerpowers.jl") - include("testvectrick.jl") - include("testindexed.jl") - include("testeigen.jl") - include("testkroneckersum.jl") - include("testfactorization.jl") - include("testkroneckergraphs.jl") + # include("base.jl") + # include("kroneckerpowers.jl") + # include("vectrick.jl") + # include("indexed.jl") + # include("eigen.jl") + include("cholesky.jl") + # include("kroneckersum.jl") + # include("factorization.jl") + # include("kroneckergraphs.jl") end diff --git a/test/testvectrick.jl b/test/vectrick.jl similarity index 100% rename from test/testvectrick.jl rename to test/vectrick.jl