Skip to content

Commit

Permalink
use unified memory to check decomposition status
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Mar 11, 2024
1 parent b5c4140 commit 69fa85d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ function LinearAlgebra.lu(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFloa
end

P = MtlMatrix{UInt32}(undef, 1, min(N, M))
status = MtlArray{MPSMatrixDecompositionStatus}(undef)
status = MtlArray{MPSMatrixDecompositionStatus,0,Shared}(undef)

cmdbuf_lu = MTLCommandBuffer(queue) do cmdbuf
mps_p = MPSMatrix(P)
Expand All @@ -196,7 +196,7 @@ function LinearAlgebra.lu(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFloa

wait_completed(cmdbuf_lu)

status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[])
status = convert(LinearAlgebra.BlasInt, status[])
check && checknonsingular(status)

return LinearAlgebra.LU(B, p, status)
Expand All @@ -219,7 +219,7 @@ function LinearAlgebra.lu!(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFlo
end

P = MtlMatrix{UInt32}(undef, 1, min(N, M))
status = MtlArray{MPSMatrixDecompositionStatus}(undef)
status = MtlArray{MPSMatrixDecompositionStatus,0,Shared}(undef)

cmdbuf_lu = MTLCommandBuffer(queue) do cmdbuf
mps_p = MPSMatrix(P)
Expand All @@ -237,7 +237,7 @@ function LinearAlgebra.lu!(A::MtlMatrix{T}; check::Bool = true) where {T<:MtlFlo

wait_completed(cmdbuf_lu)

status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[])
status = convert(LinearAlgebra.BlasInt, status[])
check && checknonsingular(status)

return LinearAlgebra.LU(A, p, status)
Expand Down

0 comments on commit 69fa85d

Please sign in to comment.