Skip to content

Commit

Permalink
Ensure has_uniform_datalayouts for cuda copyto
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 24, 2024
1 parent 71f1a38 commit 32143e4
Showing 1 changed file with 48 additions and 20 deletions.
68 changes: 48 additions & 20 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,18 @@ function knl_copyto_linear!(dest, src, us)
return nothing
end

function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
if !(VERSION ≥ v"1.11.0-beta") && dest isa DataLayouts.EndsWithField
bc′ = Base.Broadcast.instantiate(
DataLayouts.to_non_extruded_broadcasted(bc),
)
args = (dest, bc′, us)
threads = threads_via_occupancy(knl_copyto_linear!, args)
n_max_threads = min(threads, get_N(us))
p = linear_partition(prod(size(dest)), n_max_threads)
auto_launch!(
knl_copyto_linear!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
else
if VERSION ≥ v"1.11.0-beta"
# https://github.com/JuliaLang/julia/issues/56295
# Julia 1.11's Base.Broadcast currently requires
# multiple integer indexing, whereas Julia 1.10 did not.
# This means that we cannot reserve linear indexing to
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
# (including the GPU-variant related issue resolution efforts:
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
args = (dest, bc, us)
threads = threads_via_occupancy(knl_copyto!, args)
n_max_threads = min(threads, get_N(us))
Expand All @@ -46,8 +39,43 @@ function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
blocks_s = p.blocks,
)
end
return dest
end
else
# Copy a broadcasted expression `bc` into `dest` on the GPU, choosing between a
# linear-indexed kernel (when all data layouts in `bc` are uniform and `dest`
# ends with a field dimension) and the general Cartesian-indexed kernel.
# NOTE(review): `has_uniform_datalayouts`, `EndsWithField`, `auto_launch!`, and
# the partition helpers are project-local — semantics assumed from call sites.
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
    (_, _, nlevels, _, nelems) = DataLayouts.universal_size(dest)
    usize = DataLayouts.UniversalSize(dest)
    # Nothing to launch when either extent is empty; still return `dest`
    # to satisfy the `copyto!` contract.
    if nlevels > 0 && nelems > 0
        uniform = DataLayouts.has_uniform_datalayouts(bc) &&
                  dest isa DataLayouts.EndsWithField
        if uniform
            # Strip extruded wrappers so the kernel can use linear indexing.
            flat_bc = Base.Broadcast.instantiate(
                DataLayouts.to_non_extruded_broadcasted(bc),
            )
            kernel_args = (dest, flat_bc, usize)
            occupancy = threads_via_occupancy(knl_copyto_linear!, kernel_args)
            thread_cap = min(occupancy, get_N(usize))
            launch_plan = linear_partition(prod(size(dest)), thread_cap)
            auto_launch!(
                knl_copyto_linear!,
                kernel_args;
                threads_s = launch_plan.threads,
                blocks_s = launch_plan.blocks,
            )
        else
            # General path: Cartesian indexing over the original broadcast.
            kernel_args = (dest, bc, usize)
            occupancy = threads_via_occupancy(knl_copyto!, kernel_args)
            thread_cap = min(occupancy, get_N(usize))
            launch_plan = partition(dest, thread_cap)
            auto_launch!(
                knl_copyto!,
                kernel_args;
                threads_s = launch_plan.threads,
                blocks_s = launch_plan.blocks,
            )
        end
    end
    return dest
end
return dest
end

# broadcasting scalar assignment
Expand Down

0 comments on commit 32143e4

Please sign in to comment.