diff --git a/Project.toml b/Project.toml index e42c704..f82dce3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NamedDims" uuid = "356022a1-0364-5f58-8944-0da4b18d706f" authors = ["Invenia Technical Computing Corporation"] -version = "0.2.40" +version = "0.2.41" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/functions_linearalgebra.jl b/src/functions_linearalgebra.jl index caa1c93..310b29d 100644 --- a/src/functions_linearalgebra.jl +++ b/src/functions_linearalgebra.jl @@ -118,3 +118,39 @@ function Base.getproperty( return inner end end + +function LinearAlgebra.:\( + fact::NamedFactorization{L,T,F}, nda::NamedDimsArray{W} +) where {L,T,F<:Factorization{T},W} + n1, n2 = L + n1 != W[1] && throw( + DimensionMismatch( + "Mismatched dimensions with factorization: $L and NamedDimsArray: $W" + ), + ) + return NamedDimsArray{(n2,)}(LinearAlgebra.:\(parent(fact), parent(nda))) +end + +function LinearAlgebra.:\( + fact::NamedFactorization{L,T,F}, nda::AbstractVector +) where {L,T,F<:Factorization{T}} + n1, n2 = L + return NamedDimsArray{(n2,)}(LinearAlgebra.:\(parent(fact), nda)) +end + +# Specialised routines for \ often do in-place ops that result in the nameddim populated from B +# Leading to an incorrect named-dim +# We also unname here because that handles wrapper types +for S in (UpperTriangular, LowerTriangular) + @eval begin + function LinearAlgebra.:\( + A::$S{T,<:NamedDimsArray{L}}, B::AbstractVector + ) where {L,T} + n1, n2 = L + return NamedDimsArray{(n2,)}(LinearAlgebra.:\($S(unname(A)), parent(B))) + end + end +end + +# Diagonal on a nameddim presently loses its nameddimsness. So just pass through for now. +LinearAlgebra.:\(A::Diagonal, B::NamedDimsArray) = LinearAlgebra.:\(A, parent(B)) diff --git a/src/name_operations.jl b/src/name_operations.jl index a5283ee..eb8e96a 100644 --- a/src/name_operations.jl +++ b/src/name_operations.jl @@ -41,9 +41,17 @@ Return the input array `A` without any dimension names. For `NamedDimsArray`s this returns the parent array, equivalent to calling `parent`, but for anything else it simply returns the input. + +Supports some LinearAlgebra wrappers LowerTriangular, UpperTriangular such that +`unname(x::UpperTriangular{T, <:NamedDimsArray}) == unname(parent(x)) + """ unname(x::NamedDimsArray) = parent(x) unname(x) = x +# Unwrap LinearAlgebra wrappers +for W in (LowerTriangular, UpperTriangular) + @eval unname(x::$W{T,<:NamedDimsArray}) where {T} = unname(parent(x)) +end """ dimnames(A) -> Tuple diff --git a/test/functions_linearalgebra.jl b/test/functions_linearalgebra.jl index 454fece..13ebd29 100644 --- a/test/functions_linearalgebra.jl +++ b/test/functions_linearalgebra.jl @@ -1,3 +1,4 @@ +using Test: approx_full using LinearAlgebra using NamedDims using NamedDims: dimnames @@ -139,5 +140,31 @@ end @testset "#164 factorization eltype not same as input eltype" begin # https://github.com/invenia/NamedDims.jl/issues/164 nda = NamedDimsArray{(:foo, :bar)}([1 2 3; 4 5 6; 7 8 9]) # Int eltype - @test qr(nda) isa NamedDims.NamedFactorization{(:foo, :bar), Float64} + @test qr(nda) isa NamedDims.NamedFactorization{(:foo, :bar),Float64} +end + +@testset "LinearAlgebra.:ldiv " begin + r1 = [2 3 5; 7 11 13; 17 19 23] + r2 = r1[:, 1:2] + b = [29, 31, 37] + b_nda = NamedDimsArray{(:foo,)}(b) + + for A in (r1, r2) + (m, n) = size(A) + issquare = m == n + fn = issquare ? (identity, triu, tril, Diagonal) : (identity,) + for f in fn + for B in (b_nda, b) + nda = NamedDimsArray{(:foo, :bar)}(f(A)) + x = nda \ B + @test parent(x) ≈ f(A) \ parent(B) + # NOTE: Diagonal loses NamedDimness so specialcase + f != Diagonal && @test dimnames(x) == (:bar,) + end + end + end + + @test_throws DimensionMismatch (\)( + NamedDimsArray{(:A, :B)}(r1), NamedDimsArray{(:NotA,)}(b) + ) end diff --git a/test/name_operations.jl b/test/name_operations.jl index 870901a..e56819a 100644 --- a/test/name_operations.jl +++ b/test/name_operations.jl @@ -1,11 +1,17 @@ +using LinearAlgebra: LowerTriangular, UpperTriangular + @testset "unname" begin for orig in ([1 2; 3 4], spzeros(2, 2)) @test unname(NamedDimsArray(orig, (:x, :y))) === orig @test unname(orig) === orig end - @test unname((1,2,3)) === (1,2,3) -end + @test unname((1, 2, 3)) === (1, 2, 3) + for wrapper in (LowerTriangular, UpperTriangular) + orig = [1 2; 3 4] + @test unname(wrapper(NamedDimsArray(orig, (:x, :y)))) === orig + end +end @testset "dimnames" begin nda = NamedDimsArray([10 20; 30 40], (:x, :y))