Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] cholesky #44

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ include("kroneckerpowers.jl")
include("vectrick.jl")
include("indexedkroncker.jl")
include("eigen.jl")
include("factorization.jl")
include("cholesky.jl")
# include("factorization.jl")
include("kroneckersum.jl")
include("kroneckergraphs.jl")

Expand Down
11 changes: 4 additions & 7 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
abstract type GeneralizedKroneckerProduct <: AbstractMatrix{Number} end
abstract type GeneralizedKroneckerProduct{T} <: AbstractMatrix{T} end

abstract type AbstractKroneckerProduct <: GeneralizedKroneckerProduct end
abstract type AbstractKroneckerProduct{T} <: GeneralizedKroneckerProduct{T} end

Base.IndexStyle(::Type{<:GeneralizedKroneckerProduct}) = IndexCartesian()

# general Kronecker product between two matrices
struct KroneckerProduct{T,TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerProduct
struct KroneckerProduct{T, TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerProduct{T}
A::TA
B::TB
function KroneckerProduct(A::AbstractMatrix{T}, B::AbstractMatrix{V}) where {T, V}
Expand Down Expand Up @@ -63,10 +63,7 @@ Returns a matrix itself. Needed for recursion.
"""
getmatrices(A::AbstractArray) = (A,)

function Base.eltype(K::AbstractKroneckerProduct)
A, B = getmatrices(K)
return promote_type(eltype(A), eltype(B))
end
Base.eltype(K::AbstractKroneckerProduct{T}) where {T} = T

function Base.size(K::AbstractKroneckerProduct)
A, B = getmatrices(K)
Expand Down
65 changes: 65 additions & 0 deletions src/cholesky.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import Base: getproperty
import LinearAlgebra: cholesky, Cholesky, char_uplo, UpperTriangular, LowerTriangular, \,
istril, istriu

const KroneckerCholesky{T} = Cholesky{T, <:AbstractKroneckerProduct{T}} where {T}

function cholesky(A::AbstractKroneckerProduct; check=true)
P, Q = getmatrices(A)
chol_P, chol_Q = cholesky(P; check=true), cholesky(Q; check=true)
MichielStock marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be

chol_P, chol_Q = cholesky(P; check=check), cholesky(Q; check=check)

? Otherwise, the kwarg doesn't enter anywhere.

return Cholesky(chol_P.factors ⊗ chol_Q.factors, 'U', 0)
end

function Cholesky(factors::KroneckerProduct{T}, uplo::AbstractChar, info::Integer) where {T}
return Cholesky{T, typeof(factors)}(factors, uplo, info)
end

function UpperTriangular(C::AbstractKroneckerProduct)
A, B = getmatrices(C)
return UpperTriangular(A) ⊗ UpperTriangular(B)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UpperTriangular is the constructor of the UpperTriangular type. So I guess this should return an object of that type, and you may wish to consider returning

UpperTriangular(UpperTriangular(A)  UpperTriangular(B))

? Otherwise, subsequent code will not benefit from upper-triangularizing an AbstractKroneckerProduct, it will be "invisible" from the type point of view.

end

function LowerTriangular(C::AbstractKroneckerProduct)
A, B = getmatrices(C)
return LowerTriangular(A) ⊗ LowerTriangular(B)
end

function istril(C::AbstractKroneckerProduct)
A, B = getmatrices(C)
return istril(A) && istril(B)
end

function istriu(C::AbstractKroneckerProduct)
A, B = getmatrices(C)
return istriu(A) && istriu(B)
end

function getproperty(C::KroneckerCholesky, d::Symbol)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed an error here by changing it into

function getproperty(C::KroneckerCholesky, d::Symbol)
    Cfactors = getfield(C, :factors)
    Cuplo    = getfield(C, :uplo)
    info     = getfield(C, :info)

    Cuplo != 'U' && throw(NotImplementedError(""))

    if d == :U
        return UpperTriangular(Cfactors)
    elseif d == :L
        return LowerTriangular(copy(Cfactors'))
        #return LowerTriangular(Cuplo === char_uplo(d) ? Cfactors : copy(Cfactors'))
    elseif d == :UL
        return (Cuplo === 'U' ? UpperTriangular(Cfactors) : LowerTriangular(Cfactors))
    else
        return getfield(C, d)
    end
end

I also needed my own version of copy which provides shallow copies of Kronecker products. This gives the right answer.

Cfactors = getfield(C, :factors)
Cuplo = getfield(C, :uplo)
info = getfield(C, :info)

Cuplo != 'U' && throw(NotImplementedError(""))

if d == :U
return UpperTriangular(Cuplo === char_uplo(d) ? Cfactors : copy(Cfactors'))
elseif d == :L
return LowerTriangular(Cuplo === char_uplo(d) ? Cfactors : copy(Cfactors'))
elseif d == :UL
return (Cuplo === 'U' ? UpperTriangular(Cfactors) : LowerTriangular(Cfactors))
else
return getfield(C, d)
end
end

function logdet(C::KroneckerCholesky)
A, B = getmatrices(C.factors)
logdet_A = logdet(Cholesky(A, C.uplo, 0))
logdet_B = logdet(Cholesky(B, C.uplo, 0))
return size(B, 1) * logdet_A + size(A, 1) * logdet_B
end

function \(C::KroneckerCholesky, x::AbstractVecOrMat)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure this is correct. If I understand it correctly, the main performance of Cholesky decomposition originates from the fact that it is easy to solve a triangular system. The vec trick kind of destroys this structure...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I admit that I hadn't really thought very hard about this when writing it, but I'm pretty sure that you can use the vec trick to compute backsolves. Note that

vec(inv(L2) X inv(L1)) = (inv(L1)  inv(L2)) vec(X) = inv(L1  L2) vec(X)

so we should be able to implement backsolving with kronecker matrices efficiently provided that the individual backsolves are efficient.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you would leave it as is for the moment?

C_upper = C.U
return C_upper \ (C_upper' \ x)
end
2 changes: 1 addition & 1 deletion src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Standard matrix factorization algorithms applied on Kronecker systems.
=#


abstract type FactorizedKronecker <: AbstractKroneckerProduct end
abstract type FactorizedKronecker{T} <: AbstractKroneckerProduct{T} end

# CHOLESKY DECOMPOSITION
# ----------------------
Expand Down
6 changes: 3 additions & 3 deletions src/indexedkroncker.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Index = Union{UnitRange{I}, AbstractVector{I}} where I <: Int #TODO: make for general indices

struct IndexedKroneckerProduct <: GeneralizedKroneckerProduct
K::AbstractKroneckerProduct
struct IndexedKroneckerProduct{T} <: GeneralizedKroneckerProduct{T}
K::AbstractKroneckerProduct{T}
p::Index
q::Index
r::Index
Expand All @@ -25,7 +25,7 @@ struct IndexedKroneckerProduct <: GeneralizedKroneckerProduct
if !(maximum(r) ≤ n && maximum(t) ≤ l)
throw(BoundsError("Indices exeed matrix bounds"))
end
return new(K, p, q, r, t)
return new{eltype(K)}(K, p, q, r, t)
end
end

Expand Down
18 changes: 11 additions & 7 deletions src/kroneckerpowers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ Efficient way of storing Kronecker powers, e.g.

K = A ⊗ A ⊗ ... ⊗ A.
"""
struct KroneckerPower{TA<:AbstractMatrix, N} <: AbstractKroneckerProduct
struct KroneckerPower{T, TA<:AbstractMatrix{T}, N} <: AbstractKroneckerProduct{T}
A::TA
pow::Integer
function KroneckerPower(A::AbstractMatrix{T}, pow::Integer) where {T}
@assert pow ≥ 2 "KroneckerPower only makes sense for powers greater than 1"
return new{typeof(A), pow}(A, pow)
return new{T, typeof(A), pow}(A, pow)
end
end

Expand All @@ -32,9 +32,11 @@ kronecker(A::AbstractMatrix, pow::Int) = KroneckerPower(A, pow)

⊗(A::AbstractMatrix, pow::Int) = kronecker(A, pow)

getmatrices(K::KroneckerPower{T, N}) where {T, N} = (K.A, KroneckerPower(K.A, K.pow-1))
getmatrices(K::KroneckerPower{T, 2}) where {T} = (K.A, K.A)
getmatrices(K::KroneckerPower{T, 1}) where {T} = (K.A, )
function getmatrices(K::KroneckerPower{T, TA, N}) where {T, TA, N}
return (K.A, KroneckerPower(K.A, K.pow-1))
end
getmatrices(K::KroneckerPower{T, TA, 2}) where {T, TA} = (K.A, K.A)
getmatrices(K::KroneckerPower{T, TA, 1}) where {T, TA} = (K.A, )

order(K::KroneckerPower) = K.pow
Base.size(K::KroneckerPower) = size(K.A).^K.pow
Expand Down Expand Up @@ -118,8 +120,10 @@ function Base.conj(K::KroneckerPower)
end

# mixed-product property
function Base.:*(K1::KroneckerPower{T1, N},
K2::KroneckerPower{T2, N}) where {T1, T2, N}
function Base.:*(
K1::KroneckerPower{T1, TA1, N},
K2::KroneckerPower{T1, TA2, N},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you require that both Kronecker factors have the same eltype T1? I don't think that's necessary.

) where {T1, TA1, T2, TA2, N}
if size(K1, 2) != size(K2, 1)
throw(DimensionMismatch("Mismatch between K1 and K2"))
end
Expand Down
4 changes: 2 additions & 2 deletions src/kroneckersum.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
abstract type AbstractKroneckerSum <: GeneralizedKroneckerProduct end
abstract type AbstractKroneckerSum{T} <: GeneralizedKroneckerProduct{T} end

struct KroneckerSum{T, TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerSum
struct KroneckerSum{T, TA<:AbstractMatrix, TB<:AbstractMatrix} <: AbstractKroneckerSum{T}
A::TA
B::TB
function KroneckerSum(A::AbstractMatrix{T},
Expand Down
File renamed without changes.
32 changes: 32 additions & 0 deletions test/cholesky.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
to_psd(A) = A * A' + I

@testset "cholesky" begin
rng = MersenneTwister(123456)
M, N = 7, 3
A, B = to_psd(randn(rng, M, M)), to_psd(randn(rng, N, N))

# Construct Kronecker-factored Cholesky
A_kron_B = A ⊗ B
chol_A_kron_B = cholesky(A_kron_B)

# Construct equivalent dense Cholesky.
A_kron_B_dense = kron(A, B)
chol_A_kron_B_dense = cholesky(A_kron_B_dense)

# Check for agreement in user-facing properties of Kronecker-factored and dense.
@test chol_A_kron_B.U ≈ chol_A_kron_B_dense.U
@test chol_A_kron_B.L ≈ chol_A_kron_B_dense.L
@test det(chol_A_kron_B) ≈ det(chol_A_kron_B_dense)
@test logdet(chol_A_kron_B) ≈ logdet(chol_A_kron_B_dense)

@show typeof(chol_A_kron_B.U)
@show typeof(chol_A_kron_B.U')

# Test backsolve vs dense vector from the left.
x = randn(rng, M * N)
@test chol_A_kron_B \ x ≈ chol_A_kron_B_dense \ x

# Test backsolve vs dense matrix from the left.
X = randn(rng, M * N, 11)
@test chol_A_kron_B \ X ≈ chol_A_kron_B_dense \ X
end
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
17 changes: 9 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ using Kronecker, Test, LinearAlgebra, Random, FillArrays
using SparseArrays: SparseMatrixCSC, sprand, AbstractSparseMatrix

@testset "Kronecker" begin
include("testbase.jl")
include("testkroneckerpowers.jl")
include("testvectrick.jl")
include("testindexed.jl")
include("testeigen.jl")
include("testkroneckersum.jl")
include("testfactorization.jl")
include("testkroneckergraphs.jl")
# include("base.jl")
# include("kroneckerpowers.jl")
# include("vectrick.jl")
# include("indexed.jl")
# include("eigen.jl")
include("cholesky.jl")
# include("kroneckersum.jl")
# include("factorization.jl")
# include("kroneckergraphs.jl")
end
File renamed without changes.