Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,24 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"

[extensions]
StridedAMDGPUExt = "AMDGPU"
StridedCUDAExt = "CUDA"
StridedGPUArraysExt = "GPUArrays"
StridedJLArraysExt = "JLArrays"

[compat]
AMDGPU = "2"
Aqua = "0.8"
CUDA = "5"
GPUArrays = "11.4.1"
JLArrays = "0.3.1"
LinearAlgebra = "1.6"
Metal = "1.9"
Random = "1.6"
StridedViews = "0.4.6"
Test = "1.6"
Expand All @@ -37,10 +38,11 @@ julia = "1.6"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random", "Aqua", "AMDGPU", "CUDA", "GPUArrays", "JLArrays", "Metal"]
164 changes: 137 additions & 27 deletions ext/StridedGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,154 @@ module StridedGPUArraysExt
using Strided, GPUArrays
using GPUArrays: Adapt, KernelAbstractions
using GPUArrays.KernelAbstractions: @kernel, @index
using StridedViews: ParentIndex

ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}

KernelAbstractions.get_backend(sv::StridedView{T, N, TA}) where {T, N, TA <: AnyGPUArray{T}} = KernelAbstractions.get_backend(parent(sv))
# StridedView backed by any GPU array type, with element type linked to the parent.
const GPUStridedView{T, N} = StridedView{T, N, <:AnyGPUArray{T}}

# Broadcasting over a GPU-backed StridedView should use the broadcast style of
# the parent array type, re-parameterised with the view's dimensionality.
function Base.Broadcast.BroadcastStyle(gpu_sv::StridedView{T, N, TA}) where {T, N, TA <: AnyGPUArray{T}}
    parent_style = Base.Broadcast.BroadcastStyle(TA)
    # Rebuild the same style type with the view's own number of dimensions N.
    return typeof(parent_style)(Val(N))
end
KernelAbstractions.get_backend(sv::GPUStridedView) = KernelAbstractions.get_backend(parent(sv))

# Copy a GPU-backed StridedView into `dst` without scalar indexing, by routing
# the copy through GPUArrays' broadcast machinery as an `identity` broadcast.
function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TS <: Number, NS, TAS <: AbstractGPUArray{TS}, FS <: ALL_FS}
    style = Base.Broadcast.BroadcastStyle(TAS)
    broadcasted = Base.Broadcast.Broadcasted(style, identity, (src,), axes(dst))
    GPUArrays._copyto!(dst, broadcasted)
    return dst
end

# lifted from GPUArrays.jl
function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractGPUArray{T}, F <: ALL_FS}
isempty(A) && return A
@kernel function fill_kernel!(a, val)
idx = @index(Global, Cartesian)
@inbounds a[idx] = val
end
# ndims check for 0D support
kernel = fill_kernel!(KernelAbstractions.get_backend(A))
f_x = F <: Union{typeof(conj), typeof(adjoint)} ? conj(x) : x
kernel(A, f_x; ndrange = size(A))
return A
# Conversion to CPU Array: materialise into a contiguous GPU array first (so the
# GPU-to-GPU copy! path is used), then let the GPU array type handle the transfer.
# Conversion to a CPU Array: first materialise into a fresh contiguous buffer
# of the same GPU array type (so the device-side copy! path is used), then let
# that array type perform the device-to-host transfer.
function Base.Array(a::GPUStridedView)
    dense = similar(parent(a), eltype(a), size(a))
    copy!(StridedView(dense), a)
    return Array(dense)
end

# Matrix-matrix multiplication C = α·A·B + β·C for GPU-backed operands,
# delegated to GPUArrays' stride-aware generic kernel (no scalar indexing).
# NOTE: the scraped diff had both the old `StridedView{_, 2, <:AnyGPUArray}`
# and the new `GPUStridedView{_, 2}` argument lists fused into one signature,
# binding C/A/B twice; this is the reconstructed post-change method.
function Strided.__mul!(
        C::GPUStridedView{TC, 2},
        A::GPUStridedView{TA, 2},
        B::GPUStridedView{TB, 2},
        α::Number, β::Number
    ) where {TC, TA, TB}
    return GPUArrays.generic_matmatmul!(C, A, B, α, β)
end

# ---------- GPU mapreduce support ----------

# Seed the accumulator for the reduction kernel: with no `initop` the current
# output value is used unchanged, otherwise `initop` transforms it first.
@inline function _gpu_init_acc(::Nothing, cur)
    return cur
end
@inline function _gpu_init_acc(initop, cur)
    return initop(cur)
end

# Fold one mapped value into the accumulator: with no `op` the value simply
# replaces the accumulator, otherwise `op` combines them.
@inline function _gpu_accum(::Nothing, acc, val)
    return val
end
@inline function _gpu_accum(op, acc, val)
    return op(acc, val)
end

# Zero-based linear displacement of cartesian index `cidx` under `strides`:
# index (1, 1, …, 1) maps to displacement 0.
@inline function _strides_dot(strides::NTuple{N, Int}, cidx::CartesianIndex{N}) where {N}
    acc = 0
    @inbounds for d in 1:N
        acc += strides[d] * (cidx[d] - 1)
    end
    return acc
end

# One GPU work-item per element of the non-reduction subspace: each work-item
# walks the whole reduction subspace sequentially and accumulates into `out`.
# Convention (matching the caller): dimension d is a reduction dimension iff
# out.strides[d] == 0, so writes from the reduction loop never race.
@kernel function _mapreduce_gpu_kernel!(
        f, op, initop,
        dims::NTuple{N, Int},
        out::OT,
        inputs::IT
    ) where {N, OT <: StridedView, IT <: Tuple}

    out_linear = @index(Global, Linear)

    # Non-reduction subspace sizes (1 for reduction dims)
    nred_sizes = ntuple(Val(N)) do d
        @inbounds iszero(out.strides[d]) ? 1 : dims[d]
    end
    # Reduction subspace sizes (1 for non-reduction dims)
    red_sizes = ntuple(Val(N)) do d
        @inbounds iszero(out.strides[d]) ? dims[d] : 1
    end

    # Map out_linear → cartesian in non-reduction subspace
    nred_cidx = CartesianIndices(nred_sizes)[out_linear]
    # Linear position of this work-item's output element in the parent array.
    out_parent = out.offset + 1 + _strides_dot(out.strides, nred_cidx)

    # Initialize accumulator from current output value (or apply initop)
    @inbounds acc = _gpu_init_acc(initop, out[ParentIndex(out_parent)])

    # Sequential reduction loop over reduction subspace
    @inbounds for red_linear in Base.OneTo(prod(red_sizes))
        red_cidx = CartesianIndices(red_sizes)[red_linear]
        # Merge the fixed non-reduction coordinates with the moving reduction
        # coordinates; per dimension one of the two is always 1, so the sum
        # minus 1 yields the full N-dimensional index.
        complete_cidx = CartesianIndex(
            ntuple(Val(N)) do d
                @inbounds nred_cidx[d] + red_cidx[d] - 1
            end
        )

        # Gather one element from every input at the same logical position,
        # addressing each parent array through its own offset and strides.
        val = f(
            ntuple(Val(length(inputs))) do m
                @inbounds begin
                    a = inputs[m]
                    ip = a.offset + 1 + _strides_dot(a.strides, complete_cidx)
                    a[ParentIndex(ip)]
                end
            end...
        )

        acc = _gpu_accum(op, acc, val)
    end

    @inbounds out[ParentIndex(out_parent)] = acc
end

# GPU-compatible _mapreduce: avoids the scalar indexing (first(A),
# out[ParentIndex(1)]) that JLArrays and real GPUs prohibit. Mirrors GPUArrays'
# neutral_element approach: infer the output eltype via the Broadcast
# machinery (erroring on unknown ops unless an explicit `init` is given), seed
# a one-element device buffer, reduce into it on-device, and transfer only the
# final scalar back to the host.
function Strided._mapreduce(
        f, op, A::GPUStridedView{T, N}, nt = nothing
    ) where {T, N}
    if isempty(A)
        empty_val = Base.mapreduce_empty(f, op, T)
        return nt === nothing ? empty_val : op(empty_val, nt.init)
    end

    local ET, init
    if nt === nothing
        mapped = Base.Broadcast.combine_eltypes(f, (A,))
        ET = Base.promote_op(op, mapped, mapped)
        if ET === Union{} || ET === Any
            error("cannot infer output element type for mapreduce; pass an explicit `init`")
        end
        init = GPUArrays.neutral_element(op, ET)
    else
        init = nt.init
        ET = typeof(init)
    end

    sz = size(A)
    # One-element device buffer seeded with the neutral/init value.
    out = similar(parent(A), ET, (1,))
    fill!(out, init)

    # Reduce over every dimension: the output view is reshaped to all-ones
    # size. NOTE(review): the kernel marks stride-0 output dims as reduced;
    # sreshape to size (1,…,1) presumably yields zero strides — confirm in
    # StridedViews.
    Strided._mapreducedim!(f, op, nothing, sz, (sreshape(StridedView(out), one.(sz)), A))

    return Array(out)[1]
end

# GPU lowering of the fused mapreduce step: launch one kernel work-item per
# output element; each work-item sequentially reduces over the collapsed
# reduction dimensions. A dimension d counts as reduced iff out.strides[d] == 0.
# NOTE(review): the scraped page had two review-comment threads spliced into
# the middle of this function (splitting the signature); this is the
# reconstructed code. A reviewer suggested intercepting at _mapreduce_order!
# instead of _mapreduce_fuse! — worth confirming against Strided's pipeline.
function Strided._mapreduce_fuse!(
        f, op, initop,
        dims::Dims{N},
        arrays::Tuple{GPUStridedView{TO, N}, Vararg{GPUStridedView{<:Any, N}}}
    ) where {TO, N}
    out = arrays[1]
    inputs_raw = Base.tail(arrays)
    M = length(inputs_raw)
    inputs = ntuple(i -> inputs_raw[i], Val(M))

    # Number of output elements = product of the non-reduction dimensions.
    out_total = prod(
        ntuple(Val(N)) do d
            @inbounds iszero(out.strides[d]) ? 1 : dims[d]
        end
    )

    backend = KernelAbstractions.get_backend(parent(out))
    kernel! = _mapreduce_gpu_kernel!(backend)
    kernel!(f, op, initop, dims, out, inputs; ndrange = out_total)

    return nothing
end

end
15 changes: 14 additions & 1 deletion src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,21 @@ function Broadcast.BroadcastStyle(
end

# Allocate the destination for a strided broadcast. If any StridedView takes
# part, allocate storage of the same kind as its parent (preserving e.g. GPU
# backing); otherwise fall back to a plain dense Array.
# NOTE(review): the scraped diff left the superseded pre-change `return`
# statement in front of the new body, making the new logic dead code; this is
# the reconstructed post-change method.
function Base.similar(bc::Broadcasted{<:StridedArrayStyle{N}}, ::Type{T}) where {N, T}
    sv = _find_strided_view(bc)
    if sv !== nothing
        return StridedView(similar(parent(sv), T, size(bc)))
    end
    return StridedView(similar(Array{T}, axes(bc)))
end

# Depth-first search through a (possibly nested) Broadcasted argument tree for
# the first StridedView; returns `nothing` when no StridedView participates.
@inline _find_strided_view(bc::Broadcasted) = _find_strided_view(bc.args...)
@inline _find_strided_view(found::StridedView, rest...) = found
@inline function _find_strided_view(nested::Broadcasted, rest...)
    # Search the nested broadcast first; fall back to the remaining siblings.
    found = _find_strided_view(nested)
    if found === nothing
        return _find_strided_view(rest...)
    end
    return found
end
@inline _find_strided_view(x, rest...) = _find_strided_view(rest...)
@inline _find_strided_view() = nothing

Base.dotview(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...)

Expand Down
20 changes: 5 additions & 15 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
# Methods based on map!
# Thin wrappers that route whole-array operations through `map!`, reusing
# Strided's fused strided kernels. NOTE(review): the scraped diff contained
# both the old multi-line and the new one-line definitions of copy!/adjoint!/
# permutedims! (duplicate methods and stray `end`s); this is the reconstructed
# post-change state.
Base.copy!(dst::StridedView{<:Any, N}, src::StridedView{<:Any, N}) where {N} = map!(identity, dst, src)
Base.conj!(a::StridedView{<:Real}) = a  # conjugation is a no-op for real eltypes
Base.conj!(a::StridedView) = map!(conj, a, a)
LinearAlgebra.adjoint!(dst::StridedView, src::StridedView) = copy!(dst, adjoint(src))
LinearAlgebra.transpose!(C::StridedView, A::StridedView) = copy!(C, transpose(A))
Base.permutedims!(dst::StridedView, src::StridedView, p) = copy!(dst, permutedims(src, p))
# NOTE(review): `Returns` requires Julia ≥ 1.7, but Project.toml declares
# `julia = "1.6"` compat — either bump compat or use `_ -> val`.
Base.fill!(A::StridedView, val) = map!(Returns(val), A)

function Base.mapreduce(f, op, A::StridedView; dims = :, kw...)
return Base._mapreduce_dim(f, op, values(kw), A, dims)
Expand Down
19 changes: 0 additions & 19 deletions test/jlarrays.jl

This file was deleted.

Loading
Loading