diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl index 608d8b5..5443e74 100644 --- a/ext/StridedGPUArraysExt.jl +++ b/ext/StridedGPUArraysExt.jl @@ -1,6 +1,6 @@ module StridedGPUArraysExt -using Strided, GPUArrays +using Strided, GPUArrays, LinearAlgebra using GPUArrays: Adapt, KernelAbstractions using GPUArrays.KernelAbstractions: @kernel, @index @@ -20,6 +20,14 @@ function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS return dst end +function Base.copyto!(dest::StridedView{T, N, <:AnyGPUArray{T}}, bc::Base.Broadcast.Broadcasted{Strided.StridedArrayStyle{N}}) where {T <: Number, N} + dims = size(dest) + any(isequal(0), dims) && return dest + + GPUArrays._copyto!(dest, bc) + return dest +end + # lifted from GPUArrays.jl function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractGPUArray{T}, F <: ALL_FS} isempty(A) && return A @@ -34,7 +42,7 @@ function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractG return A end -function Strided.__mul!( +function LinearAlgebra.mul!( C::StridedView{TC, 2, <:AnyGPUArray{TC}}, A::StridedView{TA, 2, <:AnyGPUArray{TA}}, B::StridedView{TB, 2, <:AnyGPUArray{TB}},