From a2649ce527bc1e6d9d62a22dd5ca0913f4684af1 Mon Sep 17 00:00:00 2001 From: Maxim Vassiliev <76599693+max-vassili3v@users.noreply.github.com> Date: Tue, 23 Jul 2024 22:32:11 +0100 Subject: [PATCH] implement bandwidths for OneElement (#447) * implement bandwidths for OneElement * make improvements * fix sparse(::SparseMatrixCSC) * fix bandwidths for SparseMatrixCSC, add for SparseVector * add bandwidths(::Zeros) behaviour for empty sparse structures * add unit tests * cleanup bandwidths * Update interfaceimpl.jl --------- Co-authored-by: Sheehan Olver --- ext/BandedMatricesSparseArraysExt.jl | 40 +++++++++++++++++++--------- src/BandedMatrices.jl | 2 +- src/interfaceimpl.jl | 21 +++++++++++++++ test/test_interface.jl | 16 +++++++++-- test/test_miscs.jl | 9 +++++++ 5 files changed, 73 insertions(+), 15 deletions(-) diff --git a/ext/BandedMatricesSparseArraysExt.jl b/ext/BandedMatricesSparseArraysExt.jl index d5f371d0..0fe92d80 100644 --- a/ext/BandedMatricesSparseArraysExt.jl +++ b/ext/BandedMatricesSparseArraysExt.jl @@ -2,7 +2,7 @@ module BandedMatricesSparseArraysExt using BandedMatrices using BandedMatrices: _banded_rowval, _banded_colval, _banded_nzval -using SparseArrays +using SparseArrays, FillArrays import SparseArrays: sparse function sparse(B::BandedMatrix) @@ -10,29 +10,45 @@ function sparse(B::BandedMatrix) end function BandedMatrices.bandwidths(A::SparseMatrixCSC) - l,u = -size(A,1),-size(A,2) - - m,n = size(A) + l = u = -max(size(A,1),size(A,2)) + n = size(A)[2] rows = rowvals(A) vals = nonzeros(A) + + if isempty(vals) + return bandwidths(Zeros(1)) + end + for j = 1:n for ind in nzrange(A, j) i = rows[ind] # We skip non-structural zeros when computing the # bandwidths. iszero(vals[ind]) && continue - ij = abs(i-j) - if i ≥ j - l = max(l, ij) - u = max(u, -ij) - elseif i < j - l = max(l, -ij) - u = max(u, ij) - end + u = max(u, j-i) + l = max(l, i-j) end end l,u end +#Treat as n x 1 matrix +function BandedMatrices.bandwidths(A::SparseVector) + l = u = -size(A,1) + rows = rowvals(A) + + if isempty(rows) + return bandwidths(Zeros(1)) + end + + for i in rows + iszero(i) && continue + u = max(u, 1-i) + l = max(l, i-1) + end + + l,u +end + end diff --git a/src/BandedMatrices.jl b/src/BandedMatrices.jl index 5028f2ea..33900c67 100644 --- a/src/BandedMatrices.jl +++ b/src/BandedMatrices.jl @@ -34,7 +34,7 @@ import ArrayLayouts: AbstractTridiagonalLayout, BidiagonalLayout, BlasMatLdivVec symmetricuplo, transposelayout, triangulardata, triangularlayout, zero!, QRPackedQLayout, AdjQRPackedQLayout -import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal +import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector const libblas = LinearAlgebra.BLAS.libblas const liblapack = LinearAlgebra.BLAS.liblapack diff --git a/src/interfaceimpl.jl b/src/interfaceimpl.jl index ce22de65..789588a9 100644 --- a/src/interfaceimpl.jl +++ b/src/interfaceimpl.jl @@ -56,6 +56,27 @@ bandwidths(::Tridiagonal) = (1,1) sublayout(::AbstractTridiagonalLayout, ::Type{<:Tuple{AbstractUnitRange{Int},AbstractUnitRange{Int}}}) = BandedLayout() +#Implement bandwidths for OneElement structure +function bandwidths(o::OneElementVector) + k = FillArrays.nzind(o)[1] # index of non-zero + n = length(o) + if k > n || k < 1 + bandwidths(Zeros(o)) + else + (k-1, 1-k) + end +end + +function bandwidths(o::OneElementMatrix) + n,m = size(o) + k,j = Tuple(FillArrays.nzind(o)) # indices of non-zero entries + if k > n || j > m || k < 1 || j < 1 + bandwidths(Zeros(o)) + else + (k-j,j-k) + end +end + ### # rot180 ### diff --git a/test/test_interface.jl b/test/test_interface.jl index 6dbdb2f4..21d58538 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -1,10 +1,10 @@ module TestInterface -using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test +using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test, Random import BandedMatrices: isbanded, AbstractBandedLayout, BandedStyle, BandedColumns, bandeddata import ArrayLayouts: OnesLayout, UnknownLayout -using InfiniteArrays +using InfiniteArrays, SparseArrays struct PseudoBandedMatrix{T} <: AbstractMatrix{T} data::Array{T} @@ -310,6 +310,18 @@ end @test layout_getindex(T,1:10,1:10) isa BandedMatrix end +@testset "OneElement" begin + o = OneElement(1, 3, 5) + @test bandwidths(o) == (2,-2) + n,m = rand(1:10,2) + o = OneElement(1, (rand(1:n),rand(1:m)), (n, m)) + @test bandwidths(o) == bandwidths(sparse(o)) + o = OneElement(1, (n+1,m+1), (n, m)) + @test bandwidths(o) == bandwidths(Zeros(o)) + o = OneElement(1, 6, 5) + @test bandwidths(o) == bandwidths(Zeros(o)) +end + @testset "rot180" begin A = brand(5,5,1,2) R = rot180(A) diff --git a/test/test_miscs.jl b/test/test_miscs.jl index 23c06e1f..2e8250e4 100644 --- a/test/test_miscs.jl +++ b/test/test_miscs.jl @@ -50,8 +50,17 @@ import BandedMatrices: _BandedMatrix, DefaultBandedMatrix @test bA isa BandedMatrix @test bA == A @test bandwidths(bA) == min.((l,u),9) + v = sparsevec(brand(10, 1, l, u)) + @test bandwidths(v) == (l, min(0, u)) end + l, u = -1, 0 + A = brand(10, 10, l, u) + sA = sparse(A) + @test bandwidths(sA) == bandwidths(Zeros(1)) + v = sparsevec(brand(10, 1, l, u)) + @test bandwidths(v) == bandwidths(Zeros(1)) + for diags = [(-1 => ones(Int, 5),), (-2 => ones(Int, 5),), (2 => ones(Int, 5),),