Skip to content

Commit

Permalink
support iterate over nonzeros & rewrite IMatrix (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu authored Jun 25, 2022
1 parent 620f10e commit 401ba18
Show file tree
Hide file tree
Showing 28 changed files with 454 additions and 335 deletions.
1 change: 0 additions & 1 deletion .codecov.yml

This file was deleted.

2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxurySparse"
uuid = "d05aeea4-b7d4-55ac-b691-9e7fabb07ba2"
authors = ["GiggleLiu <cacate0129@gmail.com>", "Roger-luo <hiroger@qq.com>"]
version = "0.6.13"
version = "0.7.0"

[deps]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ BenchmarkTools.Trial:
## Identity Matrix
Identity matrix is static, which is defined as
```
struct IMatrix{N, Tv} <: AbstractMatrix{Tv} end
struct IMatrix{Tv} <: AbstractMatrix{Tv} end
```

With this type, the [Kronecker product](https://en.wikipedia.org/wiki/Kronecker_product) operation can be much faster. Now let's see a benchmark

```@example identity
using LuxurySparse: IMatrix
Id = IMatrix{1, Float64}()
Id = IMatrix{Float64}(1)
B = randn(7,7);
```

Expand Down
28 changes: 0 additions & 28 deletions src/Core.jl

This file was deleted.

43 changes: 23 additions & 20 deletions src/IMatrix.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
export IMatrix

"""
IMatrix{N, Tv}()
IMatrix{N}() where N = IMatrix{N, Int64}()
IMatrix{Tv}
IMatrix(n) -> IMatrix
IMatrix(A::AbstractMatrix{T}) where T -> IMatrix
IMatrix matrix, with size N as label, use `Int64` as its default type, both `*` and `kron` are optimized.
"""
struct IMatrix{N,Tv} <: AbstractMatrix{Tv} end
IMatrix{N}() where {N} = IMatrix{N,Bool}()
IMatrix(N::Int) = IMatrix{N}()
struct IMatrix{Tv} <: AbstractMatrix{Tv}
n::Int
end
IMatrix(n::Integer) = IMatrix{Bool}(n)

size(A::IMatrix{N}, i::Int) where {N} = N
size(A::IMatrix{N}) where {N} = (N, N)
getindex(A::IMatrix{N,T}, i::Integer, j::Integer) where {N,T} = T(i == j)
Base.size(A::IMatrix, i::Int) = (@assert i == 1 || i == 2; A.n)
Base.size(A::IMatrix) = (A.n, A.n)
Base.getindex(::IMatrix{T}, i::Integer, j::Integer) where {T} = T(i == j)

Base.:(==)(d1::IMatrix{Na}, d2::IMatrix{Nb}) where {Na,Nb} = Na == Nb
Base.isapprox(d1::IMatrix{Na}, d2::IMatrix{Nb}; kwargs...) where {Na,Nb} = Na == Nb
Base.:(==)(d1::IMatrix, d2::IMatrix) = d1.n == d2.n
Base.isapprox(d1::IMatrix, d2::IMatrix; kwargs...) = d1 == d2

####### sparse matrix ######
nnz(M::IMatrix{N}) where {N} = N
nonzeros(M::IMatrix{N,T}) where {N,T} = ones(T, N)
findnz(M::IMatrix{N,T}) where {N,T} = (collect(1:N), collect(1:N), ones(T, N))
ishermitian(D::IMatrix) = true
isdense(::IMatrix) = false
Base.similar(A::IMatrix{Tv}, ::Type{T}) where {Tv,T} = IMatrix{T}(A.n)
function Base.copyto!(A::IMatrix, B::IMatrix)
if A.n != B.n
throw(DimensionMismatch("matrix dimension mismatch, got $(A.n) and $(B.n)"))
end
A
end
LinearAlgebra.ishermitian(D::IMatrix) = true

similar(::IMatrix{N,Tv}, ::Type{T}) where {N,Tv,T} = IMatrix{N,T}()
copyto!(A::IMatrix{N}, B::IMatrix{N}) where {N} = A
####### sparse matrix ######
nnz(M::IMatrix) = M.n
findnz(M::IMatrix{T}) where {T} = (collect(1:M.n), collect(1:M.n), ones(T, M.n))
25 changes: 19 additions & 6 deletions src/LuxurySparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@ module LuxurySparse

using LinearAlgebra, SparseArrays, Random
using StaticArrays: SVector, SMatrix, SDiagonal, SArray
using SparseArrays: SparseMatrixCSC
using SparseArrays.HigherOrderFns
using Base: @propagate_inbounds
using LinearAlgebra
using LinearAlgebra: StructuredMatrixStyle
using Base.Broadcast:
BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!

import Base: copyto!, *, kron, -
import LinearAlgebra: ishermitian
import Base: getindex, size, similar, copy, show
# static types
export SDPermMatrix, SPermMatrix, PermMatrix, pmrand,
SDSparseMatrixCSC, SSparseMatrixCSC, SparseMatrixCSC, sprand,
SparseMatrixCOO,
SDMatrix, SDVector,
SDDiagonal, Diagonal,
IMatrix,
staticize, dynamicize,
fast_invperm,
IterNz

export I, fast_invperm, isdense, allocated_coo

include("Core.jl")
include("utils.jl")
include("IMatrix.jl")
include("PermMatrix.jl")
include("SparseMatrixCOO.jl")
Expand All @@ -24,4 +35,6 @@ include("linalg.jl")
include("kronecker.jl")
include("broadcast.jl")

include("iterate.jl")

end
20 changes: 7 additions & 13 deletions src/PermMatrix.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export PermMatrix, pmrand

"""
PermMatrix{Tv, Ti}(perm::AbstractVector{Ti}, vals::AbstractVector{Tv}) where {Tv, Ti<:Integer}
PermMatrix(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti}
Expand Down Expand Up @@ -46,11 +44,12 @@ end

Base.:(==)(d1::PermMatrix, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.isapprox(d1::PermMatrix, d2::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(d1), SparseMatrixCSC(d2); kwargs...)
Base.zero(pm::PermMatrix) = PermMatrix(pm.perm, zero(pm.vals))

################# Array Functions ##################

size(M::PermMatrix) = (length(M.perm), length(M.perm))
function size(A::PermMatrix, d::Integer)
Base.size(M::PermMatrix) = (length(M.perm), length(M.perm))
function Base.size(A::PermMatrix, d::Integer)
if d < 1
throw(ArgumentError("dimension must be ≥ 1, got $d"))
elseif d <= 2
Expand All @@ -59,7 +58,7 @@ function size(A::PermMatrix, d::Integer)
return 1
end
end
getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} =
Base.getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} =
M.perm[i] == j ? M.vals[i] : zero(Tv)
function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer)
if M.perm[i] == j
Expand All @@ -69,7 +68,7 @@ function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer)
end
end

copyto!(A::PermMatrix, B::PermMatrix) =
Base.copyto!(A::PermMatrix, B::PermMatrix) =
(copyto!(A.perm, B.perm); copyto!(A.vals, B.vals); A)

"""
Expand All @@ -82,9 +81,9 @@ function pmrand end
pmrand(::Type{T}, n::Int) where {T} = PermMatrix(randperm(n), randn(T, n))
pmrand(n::Int) = pmrand(Float64, n)

similar(x::PermMatrix{Tv,Ti}) where {Tv,Ti} =
Base.similar(x::PermMatrix{Tv,Ti}) where {Tv,Ti} =
PermMatrix{Tv,Ti}(copy(x.perm), similar(x.vals))
similar(x::PermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} =
Base.similar(x::PermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} =
PermMatrix{T,Ti}(copy(x.perm), similar(x.vals, T))

# TODO: rewrite this
Expand All @@ -98,9 +97,4 @@ similar(x::PermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} =

######### sparse array interfaces #########
nnz(M::PermMatrix) = length(M.vals)
nonzeros(M::PermMatrix) = M.vals
findnz(M::PermMatrix) = (collect(1:size(M, 1)), M.perm, M.vals)
dropzeros!(M::PermMatrix; trim::Bool = false) = M
isdense(::PermMatrix) = false

Base.zero(pm::PermMatrix) = PermMatrix(pm.perm, zero(pm.vals))
10 changes: 5 additions & 5 deletions src/SSparseMatrixCSC.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export SSparseMatrixCSC

@static if VERSION < v"1.4.0"

"""
Expand Down Expand Up @@ -72,12 +70,12 @@ function SSparseMatrixCSC(
SSparseMatrixCSC{Tv,Ti,length(nzval),n + 1}(m, n, colptr, rowval, nzval)
end

function size(spm::SSparseMatrixCSC{Tv,Ti,NNZ,NP}, i::Integer) where {Tv,Ti,NNZ,NP}
function Base.size(spm::SSparseMatrixCSC{Tv,Ti,NNZ,NP}, i::Integer) where {Tv,Ti,NNZ,NP}
i == 1 ? spm.m : (i == 2 ? NP - 1 : throw(ArgumentError("dimension out of bound!")))
end
size(spm::SSparseMatrixCSC{Tv,Ti,NNZ,NP}) where {Tv,Ti,NNZ,NP} = (spm.m, NP - 1)
Base.size(spm::SSparseMatrixCSC{Tv,Ti,NNZ,NP}) where {Tv,Ti,NNZ,NP} = (spm.m, NP - 1)

function getindex(ssp::SSparseMatrixCSC{Tv}, i::Integer, j::Integer) where {Tv}
function Base.getindex(ssp::SSparseMatrixCSC{Tv}, i::Integer, j::Integer) where {Tv}
S = ssp.colptr[j]
E = ssp.colptr[j+1] - 1
for ii = S:E
Expand Down Expand Up @@ -108,3 +106,5 @@ function SparseArrays.findnz(S::SSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
return (I, J, V)
end
SparseArrays.dropzeros!(M::SSparseMatrixCSC; trim::Bool = false) = M
SparseArrays.SparseMatrixCSC(sm::SSparseMatrixCSC) = dynamicize(sm)
Base.Matrix(sm::SSparseMatrixCSC) = Matrix(SparseMatrixCSC(sm))
22 changes: 9 additions & 13 deletions src/SparseMatrixCOO.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export SparseMatrixCOO

"""
SparseMatrixCOO(is::Vector, js::Vector, vs::Vector, m::Int, n::Int) -> SparseMatrixCOO
SparseMatrixCOO{Tv, Ti}(is::Vector{Ti}, js::Vector{Ti}, vs::Vector{Tv}, m::Int, n::Int) -> SparseMatrixCOO
Expand Down Expand Up @@ -50,9 +48,9 @@ end
SparseMatrixCOO(is::Vector{Ti}, js::Vector{Ti}, vs::Vector{Tv}, m, n) where {Ti,Tv} =
SparseMatrixCOO{Tv,Ti}(is, js, vs, m, n)

copy(coo::SparseMatrixCOO{Tv,Ti}) where {Tv,Ti} =
Base.copy(coo::SparseMatrixCOO{Tv,Ti}) where {Tv,Ti} =
SparseMatrixCOO{Tv,Ti}(copy(coo.is), copy(coo.js), copy(coo.vs), coo.m, coo.n)
function copyto!(A::SparseMatrixCOO{Tv,Ti}, B::SparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
function Base.copyto!(A::SparseMatrixCOO{Tv,Ti}, B::SparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
size(A) == size(B) && nnz(A) == nnz(B) ||
throw(MethodError("size/nnz of two coo matrices do not match!"))
copyto!(A.is, B.is)
Expand Down Expand Up @@ -86,7 +84,7 @@ function allocated_coo(::Type{T}, M::Int, N::Int, nnz::Int) where {T}
SparseMatrixCOO{T}(undef, M, N, nnz)
end

function getindex(coo::SparseMatrixCOO{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
function Base.getindex(coo::SparseMatrixCOO{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
res = zero(Tv)
for k = 1:nnz(coo)
if coo.is[k] == i && coo.js[k] == j
Expand All @@ -96,21 +94,19 @@ function getindex(coo::SparseMatrixCOO{Tv,Ti}, i::Integer, j::Integer) where {Tv
res
end

size(coo::SparseMatrixCOO) = (coo.m, coo.n)
size(coo::SparseMatrixCOO, axis::Int) =
Base.size(coo::SparseMatrixCOO) = (coo.m, coo.n)
Base.size(coo::SparseMatrixCOO, axis::Int) =
axis == 1 ? coo.m : (axis == 2 ? coo.n : throw(MethodError("invalid axis parameter")))

# SparseArrays: SparseMatrixCSC, nnz, nonzeros, dropzeros!, findnz
nnz(coo::SparseMatrixCOO) = coo.is |> length
nonzeros(coo::SparseMatrixCOO) = coo.vs
SparseArrays.nnz(coo::SparseMatrixCOO) = coo.is |> length
SparseArrays.nonzeros(coo::SparseMatrixCOO) = coo.vs

function dropzeros!(coo::SparseMatrixCOO{Tv,Ti}; trim::Bool = false) where {Tv,Ti}
function SparseArrays.dropzeros!(coo::SparseMatrixCOO{Tv,Ti}; trim::Bool = false) where {Tv,Ti}
mask = abs.(coo.vs) .> 1e-15
SparseMatrixCOO{Tv,Ti}(coo.is[mask], coo.js[mask], coo.vs[mask], coo.m, coo.n)
end

findnz(coo::SparseMatrixCOO) = (coo.is, coo.js, coo.vs)
isdense(::SparseMatrixCOO) = false
SparseArrays.findnz(coo::SparseMatrixCOO) = (coo.is, coo.js, coo.vs)

Base.@propagate_inbounds function Base.setindex!(
coo::SparseMatrixCOO{Tv,Ti},
Expand Down
60 changes: 28 additions & 32 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
@@ -1,67 +1,63 @@
import Base: conj, copy, real, imag
import LinearAlgebra: transpose, transpose!, adjoint!, adjoint

# IMatrix
for func in (:conj, :real, :transpose, :adjoint, :copy)
@eval ($func)(M::IMatrix{N,T}) where {N,T} = IMatrix{N,T}()
@eval (Base.$func)(M::IMatrix{T}) where {T} = IMatrix{T}(M.n)
end
for func in (:adjoint!, :transpose!)
@eval ($func)(M::IMatrix) = M
@eval (LinearAlgebra.$func)(M::IMatrix) = M
end
imag(M::IMatrix{N,T}) where {N,T} = Diagonal(zeros(T, N))
Base.imag(M::IMatrix{T}) where {N,T} = Diagonal(zeros(T, M.n))

# PermMatrix
for func in (:conj, :real, :imag)
@eval ($func)(M::PermMatrix) = PermMatrix(M.perm, ($func)(M.vals))
@eval (Base.$func)(M::PermMatrix) = PermMatrix(M.perm, ($func)(M.vals))
end
copy(M::PermMatrix) = PermMatrix(copy(M.perm), copy(M.vals))
Base.copy(M::PermMatrix) = PermMatrix(copy(M.perm), copy(M.vals))

function transpose(M::PermMatrix)
function Base.transpose(M::PermMatrix)
new_perm = fast_invperm(M.perm)
return PermMatrix(new_perm, M.vals[new_perm])
end

adjoint(S::PermMatrix{<:Real}) = transpose(S)
adjoint(S::PermMatrix{<:Complex}) = conj(transpose(S))
Base.adjoint(S::PermMatrix{<:Real}) = transpose(S)
Base.adjoint(S::PermMatrix{<:Complex}) = conj(transpose(S))

# scalar
import Base: *, /, ==, +, -,
*(A::IMatrix{N,T}, B::Number) where {N,T} = Diagonal(fill(promote_type(T, eltype(B))(B), N))
*(B::Number, A::IMatrix{N,T}) where {N,T} = Diagonal(fill(promote_type(T, eltype(B))(B), N))
/(A::IMatrix{N,T}, B::Number) where {N,T} =
Diagonal(fill(promote_type(T, eltype(B))(1 / B), N))
Base.:*(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:*(B::Number, A::IMatrix{T}) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:/(A::IMatrix{T}, B::Number) where {T} =
Diagonal(fill(promote_type(T, eltype(B))(1 / B), A.n))

*(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals * B)
*(B::Number, A::PermMatrix) = A * B
/(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals / B)
Base.:*(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals * B)
Base.:*(B::Number, A::PermMatrix) = A * B
Base.:/(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals / B)
#+(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv+B.dv, A.ev+B.ev)
#-(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv-B.dv, A.ev-B.ev)

for op in [:+, :-]
for MT in [:IMatrix, :PermMatrix]
@eval begin
# IMatrix, PermMatrix - SparseMatrixCSC
$op(A::$MT, B::SparseMatrixCSC) = $op(SparseMatrixCSC(A), B)
$op(B::SparseMatrixCSC, A::$MT) = $op(B, SparseMatrixCSC(A))
Base.$op(A::$MT, B::SparseMatrixCSC) = $op(SparseMatrixCSC(A), B)
Base.$op(B::SparseMatrixCSC, A::$MT) = $op(B, SparseMatrixCSC(A))
end
end
@eval begin
# IMatrix, PermMatrix - Diagonal
$op(d1::IMatrix, d2::Diagonal) = Diagonal($op(diag(d1), d2.diag))
$op(d1::Diagonal, d2::IMatrix) = Diagonal($op(d1.diag, diag(d2)))
$op(d1::PermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
$op(d1::Diagonal, d2::PermMatrix) = $op(d1, SparseMatrixCSC(d2))
Base.$op(d1::IMatrix, d2::Diagonal) = Diagonal($op(diag(d1), d2.diag))
Base.$op(d1::Diagonal, d2::IMatrix) = Diagonal($op(d1.diag, diag(d2)))
Base.$op(d1::PermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
Base.$op(d1::Diagonal, d2::PermMatrix) = $op(d1, SparseMatrixCSC(d2))
# PermMatrix - IMatrix
$op(A::PermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
$op(A::IMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
$op(A::PermMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::PermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::IMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::PermMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
end
end
# NOTE: promote to integer
+(d1::IMatrix{Na,Ta}, d2::IMatrix{Nb,Tb}) where {Na,Nb,Ta,Tb} =
d1 == d2 ? Diagonal(fill(promote_type(Ta, Tb, Int)(2), Na)) : throw(DimensionMismatch())
-(d1::IMatrix{Na,Ta}, d2::IMatrix{Nb,Tb}) where {Na,Ta,Nb,Tb} =
d1 == d2 ? spzeros(promote_type(Ta, Tb), Na, Na) : throw(DimensionMismatch())
Base.:+(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} =
d1 == d2 ? Diagonal(fill(promote_type(Ta, Tb, Int)(2), d1.n)) : throw(DimensionMismatch())
Base.:-(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} =
d1 == d2 ? spzeros(promote_type(Ta, Tb), d1.n, d1.n) : throw(DimensionMismatch())

for MT in [:IMatrix, :PermMatrix]
@eval Base.:(==)(A::$MT, B::SparseMatrixCSC) = SparseMatrixCSC(A) == B
Expand Down
Loading

2 comments on commit 401ba18

@Roger-luo
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Released via Ion

@JuliaRegistrator register branch=master

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/63056

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.0 -m "<description of version>" 401ba1867a85d057aa536633a18faa0f28b2f6ef
git push origin v0.7.0

Please sign in to comment.