Skip to content

Commit

Permalink
Fix KernelAbstractions.copyto!
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Mar 12, 2024
1 parent c919da8 commit 09f92c4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
18 changes: 13 additions & 5 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,28 @@ Adapt.adapt_storage(::KA.CPU, a::MtlArray) = convert(Array, a)

function KA.copyto!(::MetalBackend, A::MtlArray{T}, B::MtlArray{T}) where T
if device(dest) == device(src)
GC.@preserve A B unsafe_copyto!(device(A), pointer(A), pointer(B), length(A); async=true)
GC.@preserve A B unsafe_copyto!(device(A), pointer(A, storage=Private), pointer(B, storage=Private), length(A); async=true)
return A
else
error("Copy between different devices not implemented")
end
end

function KA.copyto!(::MetalBackend, A::Array{T}, B::MtlArray{T}) where T
GC.@preserve A B unsafe_copyto!(device(B), pointer(A), pointer(B), length(A); async=true)
function KA.copyto!(::MetalBackend, A::Array{T}, B::MtlArray{T,N,S}) where {T,N,S}
if (S == Metal.Private)
GC.@preserve A B unsafe_copyto!(device(B), pointer(A), pointer(B, storage=S), length(A); async=true)
else
GC.@preserve A B unsafe_copyto!(pointer(A), pointer(B, storage=S), length(A))
end
return A
end

function KA.copyto!(::MetalBackend, A::MtlArray{T}, B::Array{T}) where T
GC.@preserve A B unsafe_copyto!(device(A), pointer(A), pointer(B), length(A); async=true)
function KA.copyto!(::MetalBackend, A::MtlArray{T,N,S}, B::Array{T}) where {T,N,S}
if S == Private
GC.@preserve A B unsafe_copyto!(device(A), pointer(A, storage=S), pointer(B), length(A); async=true)
else
GC.@preserve A B unsafe_copyto!(pointer(A, storage=S), pointer(B), length(A))
end
return A
end

Expand Down
1 change: 0 additions & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ end

function Base.unsafe_convert(::Type{MtlPointer{T}}, x::MtlArray) where {T}
buf = x.data[]
synchronize()
MtlPointer{T}(buf, x.offset*Base.elsize(x))
end

Expand Down

0 comments on commit 09f92c4

Please sign in to comment.