Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions ext/StridedGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module StridedGPUArraysExt

using Strided, GPUArrays
using GPUArrays: Adapt, KernelAbstractions
using GPUArrays.KernelAbstractions: @kernel, @index

ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}

Expand All @@ -19,4 +20,27 @@ function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS
return dst
end

# lifted from GPUArrays.jl
# Fill a GPU-backed StridedView `A` with the value `x` by launching a simple
# element-wise kernel through KernelAbstractions.
#
# Returns `A` (mutated), matching the `Base.fill!` contract.
function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractGPUArray{T}, F <: ALL_FS}
    # Empty views have nothing to write; bail out before launching a kernel
    # with a zero-sized ndrange.
    isempty(A) && return A
    # Naive one-work-item-per-element kernel using linear indexing.
    # NOTE(review): linear indexing may be suboptimal for strided access on
    # GPU; an @index(Global, Cartesian) variant is worth benchmarking.
    @kernel function fill_kernel!(a, val)
        idx = @index(Global, Linear)
        @inbounds a[idx] = val
    end
    kernel = fill_kernel!(KernelAbstractions.get_backend(A))
    # For the conjugating wrappers (conj, adjoint) the fill value is
    # pre-conjugated before being written through the view.
    # NOTE(review): this relies on how the device-adapted view applies `F`
    # on writes — confirm against StridedViews' setindex! semantics.
    f_x = F <: Union{typeof(conj), typeof(adjoint)} ? conj(x) : x
    # `length(A)` is 1 for a 0-dimensional view, so 0-D arrays are covered
    # by the same launch configuration.
    kernel(A, f_x; ndrange = length(A))
    return A
end

# Matrix-matrix multiplication hook for 2-D GPU-backed StridedViews:
# computes C = α*A*B + β*C by delegating to GPUArrays' generic fallback
# kernel instead of Strided's own multiplication machinery, so the work
# happens on the device. Returns `C` (mutated).
function Strided.__mul!(
        C::StridedView{TC, 2, <:AnyGPUArray{TC}},
        A::StridedView{TA, 2, <:AnyGPUArray{TA}},
        B::StridedView{TB, 2, <:AnyGPUArray{TB}},
        α::Number, β::Number
    ) where {TC, TA, TB}
    return GPUArrays.generic_matmatmul!(C, A, B, α, β)
end

end
3 changes: 1 addition & 2 deletions ext/StridedJLArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ module StridedJLArraysExt

using Strided, StridedViews, JLArrays
using JLArrays: Adapt
using JLArrays: GPUArrays

const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}

# Copy `src` into `dst` for JLArray-backed StridedViews (any combination of
# identity/conj/transpose/adjoint wrappers), lowering to GPUArrays'
# broadcast-based copy kernel. Returns `dst` (mutated).
function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TAD <: JLArray{TD}, FD <: ALL_FS, TS <: Number, NS, TAS <: JLArray{TS}, FS <: ALL_FS}
    # Wrap `src` in a Broadcasted with the parent array type's broadcast
    # style so GPUArrays' copy machinery treats it as an element-wise
    # identity broadcast over the axes of `dst`.
    bc_style = Base.Broadcast.BroadcastStyle(TAS)
    bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
    # Single, fully qualified call: reaching GPUArrays through JLArrays
    # avoids a direct GPUArrays dependency in this extension.
    # NOTE(review): `_copyto!` is a GPUArrays internal — version-fragile.
    JLArrays.GPUArrays._copyto!(dst, bc)
    return dst
end

Expand Down
2 changes: 2 additions & 0 deletions test/amd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
axes(f1(A1)) == axes(f2(A2)) || continue
@test collect(ROCMatrix(copy!(f2(A2), f1(A1)))) == AMDGPU.Adapt.adapt(Vector{T}, copy!(B2, B1))
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
x = rand(T)
@test f1(StridedView(AMDGPU.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == AMDGPU.Adapt.adapt(Vector{T}, fill!(B1, x))
end
end
end
2 changes: 2 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
axes(f1(A1)) == axes(f2(A2)) || continue
@test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDA.Adapt.adapt(Vector{T}, copy!(B2, B1))
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
x = rand(T)
@test f1(StridedView(CUDA.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == CUDA.Adapt.adapt(Vector{T}, fill!(B1, x))
end
end
end
4 changes: 3 additions & 1 deletion test/jlarrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "Copy with JLArrayStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint)
for m1 in (0, 16, 32), m2 in (0, 16, 32)
A1 = JLArray(randn(T, (m1, m2)))
Expand All @@ -12,6 +12,8 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
axes(f1(A1)) == axes(f2(A2)) || continue
@test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1))
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
x = rand(T)
@test f1(StridedView(JLArrays.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == JLArrays.Adapt.adapt(Vector{T}, fill!(B1, x))
end
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Random.seed!(1234)
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

if !is_buildkite
include("jlarrays.jl")
println("Base.Threads.nthreads() = $(Base.Threads.nthreads())")

println("Running tests single-threaded:")
Expand All @@ -28,7 +29,6 @@ if !is_buildkite
include("blasmultests.jl")
Strided.disable_threaded_mul()

include("jlarrays.jl")
Aqua.test_all(Strided; piracies = false)
end

Expand Down
Loading