From 961c04c0e0e11d59cf9e95ce9d09bbe36113c1bd Mon Sep 17 00:00:00 2001 From: Jinguo Liu Date: Sun, 29 May 2022 01:58:37 -0400 Subject: [PATCH] Fix complex dense-perm matmul (#69) * fix complex dense-perm matmul * bump version * fix test * fix nightly test --- Project.toml | 3 ++- src/linalg.jl | 2 +- test/PermMatrix.jl | 6 ++++++ test/kronecker.jl | 46 ++++++++++++++++++++++++---------------------- test/linalg.jl | 18 +++++++++--------- 5 files changed, 42 insertions(+), 33 deletions(-) diff --git a/Project.toml b/Project.toml index 45ff7dc..3e4b86b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "LuxurySparse" uuid = "d05aeea4-b7d4-55ac-b691-9e7fabb07ba2" authors = ["GiggleLiu ", "Roger-luo "] -version = "0.6.11" +version = "0.6.12" [deps] +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/linalg.jl b/src/linalg.jl index dc9bd7b..76b10ae 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -101,7 +101,7 @@ end function *(X::AbstractMatrix, A::PermMatrix) mX, nX = size(X) nX == size(A, 1) || throw(DimensionMismatch()) - return @views (A.vals'.*X)[:, fast_invperm(A.perm)] + return @views (transpose(A.vals) .* X)[:, fast_invperm(A.perm)] end # NOTE: this is just a temperory fix for v0.7. We should overload mul! in diff --git a/test/PermMatrix.jl b/test/PermMatrix.jl index a57d445..e535b8e 100644 --- a/test/PermMatrix.jl +++ b/test/PermMatrix.jl @@ -96,4 +96,10 @@ end pm = PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3]) res = pm .* 3im @test res == PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3] .* 3im) && res isa PermMatrix +end + +@testset "fix dense-perm multiplication" begin + A = randn(ComplexF64, 4, 4) + pm = PermMatrix([3, 2, 4, 1], [0.2im, 0.6im, 0.1, 0.3]) + @test A * pm ≈ A * Matrix(pm) end \ No newline at end of file diff --git a/test/kronecker.jl b/test/kronecker.jl index 738b19d..08acbfc 100644 --- a/test/kronecker.jl +++ b/test/kronecker.jl @@ -1,28 +1,30 @@ using Test, Random, SparseArrays, LinearAlgebra import LuxurySparse: IMatrix, PermMatrix -Random.seed!(2) +@testset "kron" begin + Random.seed!(2) -p1 = IMatrix{4}() -sp = sprand(ComplexF64, 4, 4, 0.5) -ds = rand(ComplexF64, 4, 4) -pm = PermMatrix([2, 3, 4, 1], randn(4)) -v = [0.5, 0.3im, 0.2, 1.0] -dv = Diagonal(v) + p1 = IMatrix{4}() + sp = sprand(ComplexF64, 4, 4, 0.5) + ds = rand(ComplexF64, 4, 4) + pm = PermMatrix([2, 3, 4, 1], randn(4)) + pm = PermMatrix([2, 3, 4, 1], randn(4)) + v = [0.5, 0.3im, 0.2, 1.0] + dv = Diagonal(v) - -@testset "kron(::$(typeof(source)), ::$(typeof(target)))" for source in [p1, sp, ds, dv, pm], - target in [p1, sp, ds, dv, pm] - lres = kron(source, target) - rres = kron(target, source) - flres = kron(Matrix(source), Matrix(target)) - frres = kron(Matrix(target), Matrix(source)) - @test lres == flres - @test rres == frres - @test eltype(lres) == eltype(flres) - @test eltype(rres) == eltype(frres) - if !(target === ds && source === ds) - @test !(typeof(lres) <: StridedMatrix) - @test !(typeof(rres) <: StridedMatrix) + for source in Any[p1, sp, ds, dv, pm], + target in Any[p1, sp, ds, dv, pm] + lres = kron(source, target) + rres = kron(target, source) + flres = kron(Matrix(source), Matrix(target)) + frres = kron(Matrix(target), Matrix(source)) + @test lres == flres + @test rres == frres + @test eltype(lres) == eltype(flres) + @test eltype(rres) == eltype(frres) + if !(target === ds && source === ds) + @test !(typeof(lres) <: StridedMatrix) + @test !(typeof(rres) <: StridedMatrix) + end end -end +end \ No newline at end of file diff --git a/test/linalg.jl b/test/linalg.jl index 7b81b0b..9f52550 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -19,12 +19,12 @@ dv = Diagonal(v) @test logdet(p1) == 0 @test inv(pm) == inv(Matrix(pm)) - for m in [pm, sp, p1, dv] + for m in Any[pm, sp, p1, dv] @test !(m |> isdense) @test !(m' |> isdense) @test !(transpose(m) |> isdense) end - for m in [ds, v] + for m in Any[ds, v] @test m |> isdense @test m' |> isdense @test transpose(m) |> isdense @@ -32,9 +32,9 @@ dv = Diagonal(v) end @testset "multiply" begin - for source_ in [p1, sp, ds, dv, pm] - for target in [p1, sp, ds, dv, pm] - for source in [source_, source_', transpose(source_)] + for source_ in Any[p1, sp, ds, dv, pm] + for target in Any[p1, sp, ds, dv, pm] + for source in Any[source_, source_', transpose(source_)] lres = source * target rres = target * source flres = Matrix(source) * Matrix(target) @@ -82,7 +82,7 @@ end @testset "randn" begin Random.seed!(2) T = ComplexF64 - for m in [sprand(T, 5, 5, 0.5)] + for m in Any[sprand(T, 5, 5, 0.5)] zm = zero(m) @test zm ≈ zeros(T, 5, 5) if VERSION < v"1.4.0" @@ -93,7 +93,7 @@ end @test !(zm ≈ zeros(T, 5, 5)) end end - for m in [pmrand(T, 5), Diagonal(randn(T, 5))] + for m in Any[pmrand(T, 5), Diagonal(randn(T, 5))] zm = zero(m) @test zm ≈ zeros(T, 5, 5) rand!(zm) @@ -112,8 +112,8 @@ end end @testset "findnz" begin - for m in [p1, sp, ds, dv, pm] - for _m in [m, staticize(m)] + for m in Any[p1, sp, ds, dv, pm] + for _m in Any[m, staticize(m)] out = zeros(eltype(m), size(m)...) for (i, j, v) in zip(LuxurySparse.findnz(_m)...) out[i, j] = v