From 69fa85d71e8ceec2cce4bb782f0875baaf746ce4 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Wed, 6 Mar 2024 20:17:49 +0100 Subject: [PATCH] use unified memory to check decomposition status --- lib/mps/linalg.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index cf2c8f360..51810a63f 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -174,7 +174,7 @@ function LinearAlgebra.lu(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFloa end P = MtlMatrix{UInt32}(undef, 1, min(N, M)) - status = MtlArray{MPSMatrixDecompositionStatus}(undef) + status = MtlArray{MPSMatrixDecompositionStatus,0,Shared}(undef) cmdbuf_lu = MTLCommandBuffer(queue) do cmdbuf mps_p = MPSMatrix(P) @@ -196,7 +196,7 @@ function LinearAlgebra.lu(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFloa wait_completed(cmdbuf_lu) - status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[]) + status = convert(LinearAlgebra.BlasInt, status[]) check && checknonsingular(status) return LinearAlgebra.LU(B, p, status) @@ -219,7 +219,7 @@ function LinearAlgebra.lu!(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFlo end P = MtlMatrix{UInt32}(undef, 1, min(N, M)) - status = MtlArray{MPSMatrixDecompositionStatus}(undef) + status = MtlArray{MPSMatrixDecompositionStatus,0,Shared}(undef) cmdbuf_lu = MTLCommandBuffer(queue) do cmdbuf mps_p = MPSMatrix(P) @@ -237,7 +237,7 @@ function LinearAlgebra.lu!(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFlo wait_completed(cmdbuf_lu) - status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[]) + status = convert(LinearAlgebra.BlasInt, status[]) check && checknonsingular(status) return LinearAlgebra.LU(A, p, status)