diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index 51810a63f..2d36e423d 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -156,7 +156,8 @@ end # Metal's pivoting sequence needs to be iterated sequentially... # TODO: figure out a GPU-compatible way to get the permutation matrix -LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = LinearAlgebra.ipiv2perm(Array(v), maxi) +LinearAlgebra.ipiv2perm(v::MtlVector{T,S}, maxi::Integer) where {T,S} = LinearAlgebra.ipiv2perm(Array(v), maxi) +LinearAlgebra.ipiv2perm(v::MtlVector{T,S}, maxi::Integer) where {T,S<:CPUStorage} = LinearAlgebra.ipiv2perm(unsafe_wrap(Array, v), maxi) function LinearAlgebra.lu(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFloat} M,N = size(A)