Skip to content

Commit

Permalink
Adding support for Hermitian/Symmetric wrapped sparse matrices. (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
mipals authored Sep 19, 2024
1 parent 34212b7 commit 9ef1b9b
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 1 deletion.
99 changes: 98 additions & 1 deletion src/Metis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Metis

using SparseArrays
using LinearAlgebra: ishermitian
using LinearAlgebra: ishermitian, Hermitian, Symmetric
using METIS_jll: libmetis

# Metis C API: Clang.jl auto-generated bindings and some manual methods
Expand Down Expand Up @@ -33,6 +33,7 @@ struct Graph
end
end


"""
Metis.graph(G::SparseMatrixCSC; weights=false, check_hermitian=true)
Expand Down Expand Up @@ -71,6 +72,102 @@ function graph(G::SparseMatrixCSC; weights::Bool=false, check_hermitian::Bool=tr
return Graph(idx_t(N), xadj, adjncy, vwgt, adjwgt)
end

const HermOrSymCSC{Tv, Ti} = Union{
Hermitian{Tv, SparseMatrixCSC{Tv, Ti}}, Symmetric{Tv, SparseMatrixCSC{Tv, Ti}},
}

if VERSION < v"1.10"
# From https://github.com/JuliaSparse/SparseArrays.jl/blob/313a04f4a78bbc534f89b6b4d9c598453e2af17c/src/linalg.jl#L1106-L1117
# MIT license: https://github.com/JuliaSparse/SparseArrays.jl/blob/main/LICENSE.md
function nzrangeup(A, i, excl=false)
r = nzrange(A, i); r1 = r.start; r2 = r.stop
rv = rowvals(A)
@inbounds r2 < r1 || rv[r2] <= i - excl ? r : r1:(searchsortedlast(view(rv, r1:r2), i - excl) + r1-1)
end
function nzrangelo(A, i, excl=false)
r = nzrange(A, i); r1 = r.start; r2 = r.stop
rv = rowvals(A)
@inbounds r2 < r1 || rv[r1] >= i + excl ? r : (searchsortedfirst(view(rv, r1:r2), i + excl) + r1-1):r2
end
else
using SparseArrays: nzrangeup, nzrangelo
end

"""
Metis.graph(G::Union{Hermitian, Symmetric}; weights::Bool=false)
Construct the 1-based CSR representation of the `Hermitian` or `Symmetric` wrapped sparse
matrix `G`.
Weights are not currently supported for this method so passing `weights=true` will throw an
error.
"""
function graph(H::HermOrSymCSC; weights::Bool=false)
# This method is derived from the method `SparseMatrixCSC(::HermOrSymCSC)` from
# SparseArrays.jl
# (https://github.com/JuliaSparse/SparseArrays.jl/blob/313a04f4a78bbc534f89b6b4d9c598453e2af17c/src/sparseconvert.jl#L124-L173)
# with MIT license
# (https://github.com/JuliaSparse/SparseArrays.jl/blob/main/LICENSE.md).
weights && throw(ArgumentError("weights not supported yet"))
# Extract data
A = H.data
upper = H.uplo == 'U'
rowval = rowvals(A)
m, n = size(A)
@assert m == n
# New colptr for the full matrix
newcolptr = Vector{idx_t}(undef, n + 1)
newcolptr[1] = 1
# SparseArrays.nzrange for the upper/lower part excluding the diagonal
nzrng = if upper
(A, col) -> nzrangeup(A, col, #=exclude diagonal=# true)
else
(A, col) -> nzrangelo(A, col, #=exclude diagonal=# true)
end
# If the upper part is stored we loop forward, otherwise backwards
colrange = upper ? (1:1:n) : (n:-1:1)
@inbounds for col in colrange
r = nzrng(A, col)
# Number of entries in the stored part of this column, excluding the diagonal entry
newcolptr[col + 1] = length(r)
# Increment columnptr corresponding to the stored rows
for k in r
row = rowval[k]
@assert upper ? row < col : row > col
@assert row != col # Diagonal entries should not be here
newcolptr[row + 1] += 1
end
end
# Accumulate the colptr and allocate new rowval
cumsum!(newcolptr, newcolptr)
nz = newcolptr[n + 1] - 1
newrowval = Vector{idx_t}(undef, nz)
# Populate the rowvals
@inbounds for col = 1:n
newk = newcolptr[col]
for k in nzrng(A, col)
row = rowval[k]
@assert col != row
newrowval[newk] = row
newk += 1
ni = newcolptr[row]
newrowval[ni] = col
newcolptr[row] = ni + 1
end
newcolptr[col] = newk
end
# Shuffle back the colptrs
@inbounds for j = n:-1:1
newcolptr[j+1] = newcolptr[j]
end
newcolptr[1] = 1
# Return Graph
N = n
xadj = newcolptr
adjncy = newrowval
vwgt = C_NULL
adjwgt = C_NULL
return Graph(idx_t(N), xadj, adjncy, vwgt, adjwgt)
end

"""
perm, iperm = Metis.permutation(G)
Expand Down
19 changes: 19 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Metis
using Random
using Test
using SparseArrays
using LinearAlgebra: Symmetric, Hermitian
import LightGraphs, Graphs

@testset "Metis.graph(::SparseMatrixCSC)" begin
Expand All @@ -20,6 +21,24 @@ import LightGraphs, Graphs
@test iszero(S - GW)
end

@testset "Metis.graph(::Union{Hermitian, Symmetric})" begin
rng = MersenneTwister(0)
for T in (Symmetric, Hermitian), uplo in (:U, :L)
S = sprand(rng, Int, 10, 10, 0.2); fill!(S.nzval, 1)
TS = T(S, uplo)
CSCS = SparseMatrixCSC(TS)
@test TS == CSCS
g1 = Metis.graph(TS)
g2 = Metis.graph(CSCS)
@test g1.nvtxs == g2.nvtxs
@test g1.xadj == g2.xadj
@test g1.adjncy == g2.adjncy
@test g1.vwgt == g2.vwgt == C_NULL
@test g1.adjwgt == g2.adjwgt == C_NULL
@test_throws ArgumentError Metis.graph(TS; weights = true)
end
end

@testset "Metis.permutation" begin
rng = MersenneTwister(0)
S = sprand(rng, 10, 10, 0.5); S = S + S'; fill!(S.nzval, 1)
Expand Down

0 comments on commit 9ef1b9b

Please sign in to comment.