Skip to content

Commit

Permalink
upadte kron
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Dec 17, 2018
1 parent b272e33 commit b80b324
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ for MT in [:AbstractMatrix, :PermMatrix, :SparseMatrixCSC, :Diagonal]
end

####### diagonal kron ########
kron(A::Diagonal, B::Diagonal) = Diagonal(kron(A.diag, B.diag))
kron(A::StridedMatrix, B::Diagonal) = kron(A, PermMatrix(B))
kron(A::Diagonal, B::StridedMatrix) = kron(PermMatrix(A), B)
kron(A::Diagonal, B::SparseMatrixCSC) = kron(PermMatrix(A), B)
kron(A::SparseMatrixCSC, B::Diagonal) = kron(A, PermMatrix(B))
kron(A::Diagonal{<:Number}, B::Diagonal{<:Number}) = Diagonal(kron(A.diag, B.diag))
kron(A::StridedMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B))
kron(A::Diagonal{<:Number}, B::StridedMatrix{<:Number}) = kron(PermMatrix(A), B)
kron(A::Diagonal{<:Number}, B::SparseMatrixCSC{<:Number}) = kron(PermMatrix(A), B)
kron(A::SparseMatrixCSC{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B))


function kron(A::AbstractMatrix{Tv}, B::IMatrix{Nb}) where {Nb, Tv}
function kron(A::AbstractMatrix{Tv}, B::IMatrix{Nb}) where {Nb, Tv<:Number}
mA, nA = size(A)
nzval = Vector{Tv}(undef, Nb*mA*nA)
rowval = Vector{Int}(undef, Nb*mA*nA)
Expand All @@ -63,7 +63,7 @@ function kron(A::AbstractMatrix{Tv}, B::IMatrix{Nb}) where {Nb, Tv}
SparseMatrixCSC(mA*Nb, nA*Nb, colptr, rowval, nzval)
end

function kron(A::IMatrix{Na}, B::AbstractMatrix{Tv}) where {Na, Tv}
function kron(A::IMatrix{Na}, B::AbstractMatrix{Tv}) where {Na, Tv<:Number}
mB, nB = size(B)
rowval = Vector{Int}(undef, nB*mB*Na)
nzval = Vector{Tv}(undef, nB*mB*Na)
Expand All @@ -81,7 +81,7 @@ function kron(A::IMatrix{Na}, B::AbstractMatrix{Tv}) where {Na, Tv}
SparseMatrixCSC(mB*Na, Na*nB, colptr, rowval, nzval)
end

function kron(A::IMatrix{Na}, B::SparseMatrixCSC{T}) where {Na, T}
function kron(A::IMatrix{Na}, B::SparseMatrixCSC{T}) where {Na, T<:Number}
mB, nB = size(B)
nV = nnz(B)
nzval = Vector{T}(undef, Na*nV)
Expand All @@ -104,7 +104,7 @@ function kron(A::IMatrix{Na}, B::SparseMatrixCSC{T}) where {Na, T}
SparseMatrixCSC(mB*Na, nB*Na, colptr, rowval, nzval)
end

function kron(A::SparseMatrixCSC{T}, B::IMatrix{Nb}) where {T, Nb}
function kron(A::SparseMatrixCSC{T}, B::IMatrix{Nb}) where {T<:Number, Nb}
mA, nA = size(A)
nV = nnz(A)
rowval = Vector{Int}(undef, Nb*nV)
Expand All @@ -129,7 +129,7 @@ function kron(A::SparseMatrixCSC{T}, B::IMatrix{Nb}) where {T, Nb}
SparseMatrixCSC(mA*Nb, nA*Nb, colptr, rowval, nzval)
end

function kron(A::PermMatrix{T}, B::IMatrix) where T
function kron(A::PermMatrix{T}, B::IMatrix) where T<:Number
nA = size(A, 1)
nB = size(B, 1)
vals = Vector{T}(undef, nB*nA)
Expand All @@ -146,7 +146,7 @@ function kron(A::PermMatrix{T}, B::IMatrix) where T
PermMatrix(perm, vals)
end

function kron(A::IMatrix, B::PermMatrix{Tv, Ti}) where {Tv, Ti <: Integer}
function kron(A::IMatrix, B::PermMatrix{Tv, Ti}) where {Tv<:Number, Ti <: Integer}
nA = size(A, 1)
nB = size(B, 1)
perm = Vector{Int}(undef, nB*nA)
Expand All @@ -162,7 +162,7 @@ function kron(A::IMatrix, B::PermMatrix{Tv, Ti}) where {Tv, Ti <: Integer}
end


function kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv, Tb}
function kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<:Number, Tb<:Number}
mA, nA = size(A)
nB = size(B, 1)
perm = fast_invperm(B.perm)
Expand All @@ -186,7 +186,7 @@ function kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv, Tb}
SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval)
end

function kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb, Ta}
function kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<:Number, Ta<:Number}
mB, nB = size(B)
nA = size(A, 1)
perm = fast_invperm(A.perm)
Expand All @@ -210,7 +210,7 @@ function kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb, Ta}
SparseMatrixCSC(nA*mB, nA*nB, colptr, rowval, nzval)
end

function kron(A::PermMatrix, B::PermMatrix)
function kron(A::PermMatrix{<:Number}, B::PermMatrix{<:Number})
nA = size(A, 1)
nB = size(B, 1)
vals = kron(A.vals, B.vals)
Expand All @@ -225,10 +225,10 @@ function kron(A::PermMatrix, B::PermMatrix)
PermMatrix(perm, vals)
end

kron(A::PermMatrix, B::Diagonal) = kron(A, PermMatrix(B))
kron(A::Diagonal, B::PermMatrix) = kron(PermMatrix(A), B)
kron(A::PermMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B))
kron(A::Diagonal{<:Number}, B::PermMatrix{<:Number}) = kron(PermMatrix(A), B)

function kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta, Tb}
function kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta<:Number, Tb<:Number}
nA = size(A, 1)
mB, nB = size(B)
nV = nnz(B)
Expand All @@ -254,7 +254,7 @@ function kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta, Tb}
SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval)
end

function kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T, Tb}
function kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T<:Number, Tb<:Number}
nB = size(B, 1)
mA, nA = size(A)
nV = nnz(A)
Expand Down

0 comments on commit b80b324

Please sign in to comment.