Skip to content

Commit

Permalink
Specialize HF for fused kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 24, 2024
1 parent 44c7339 commit a822d7e
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 28 deletions.
20 changes: 5 additions & 15 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,9 @@ function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
fmbc::FusedMultiBroadcast,
)
FusedMultiBroadcast(
map(fmbc.pairs) do pair
dest = pair.first
bc = pair.second
Pair(
Adapt.adapt(to, dest),
Base.Broadcast.Broadcasted(
bc.style,
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
),
)
end,
)
FusedMultiBroadcast(map(fmbc.pairs) do pair
dest = pair.first
bc = pair.second
Pair(Adapt.adapt(to, dest), Adapt.adapt(to, bc))
end)
end
90 changes: 79 additions & 11 deletions ext/cuda/data_layouts_fused_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,92 @@ function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us)
return nothing
end

Base.@propagate_inbounds function rcopyto_at_linear!(
pair::Pair{<:AbstractData, <:DataLayouts.NonExtrudedBroadcasted},
I,
)
(dest, bc) = pair.first, pair.second
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = bcI
return nothing
end
Base.@propagate_inbounds function rcopyto_at_linear!(
pair::Pair{<:DataF, <:DataLayouts.NonExtrudedBroadcasted},
I,
)
(dest, bc) = pair.first, pair.second
bcI = isascalar(bc) ? bc[] : bc[I]
dest[] = bcI
return nothing
end
Base.@propagate_inbounds function rcopyto_at_linear!(pairs::Tuple, I)
rcopyto_at_linear!(first(pairs), I)
rcopyto_at_linear!(Base.tail(pairs), I)
end
Base.@propagate_inbounds rcopyto_at_linear!(pairs::Tuple{<:Any}, I) =
rcopyto_at_linear!(first(pairs), I)
@inline rcopyto_at_linear!(pairs::Tuple{}, I) = nothing

function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us)
@inbounds begin
I = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
if linear_is_valid_index(I, us)
(; pairs) = fmbc
rcopyto_at_linear!(pairs, I)
end
end
return nothing
end

# https://github.com/JuliaLang/julia/issues/56295
# Julia 1.11's Base.Broadcast currently requires
# multiple integer indexing, wheras Julia 1.10 did not.
# This means that we cannot reserve linear indexing to
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
# (including the GPU-variant related issue resolution efforts:
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::DataLayouts.AbstractData,
::ToCUDA,
)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest1)
if Nv > 0 && Nh > 0
us = DataLayouts.UniversalSize(dest1)
args = (fmbc, dest1, us)
threads = threads_via_occupancy(knl_fused_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = partition(dest1, n_max_threads)
auto_launch!(
knl_fused_copyto!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
bcs = map(p -> p.second, fmbc.pairs)
destinations = map(p -> p.first, fmbc.pairs)
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
!(VERSION v"1.11.0-beta")
pairs′ = map(fmbc.pairs) do p
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
Pair(p.first, Base.Broadcast.instantiate(bc′))
end
us = DataLayouts.UniversalSize(dest1)
fmbc′ = FusedMultiBroadcast(pairs′)
args = (fmbc′, us)
threads = threads_via_occupancy(knl_fused_copyto_linear!, args)
n_max_threads = min(threads, get_N(us))
p = linear_partition(prod(size(dest1)), n_max_threads)
auto_launch!(
knl_fused_copyto_linear!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
always_inline = false,
)
else
us = DataLayouts.UniversalSize(dest1)
args = (fmbc, dest1, us)
threads = threads_via_occupancy(knl_fused_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = partition(dest1, n_max_threads)
auto_launch!(
knl_fused_copyto!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
end
end
return nothing
end
2 changes: 1 addition & 1 deletion src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,7 @@ Base.copy(data::AbstractData) =
union_all(singleton(data)){type_params(data)...}(copy(parent(data)))

# broadcast machinery
include("non_extruded_broadcasted.jl")
include("broadcast.jl")

Adapt.adapt_structure(to, data::AbstractData{S}) where {S} =
Expand Down Expand Up @@ -2191,7 +2192,6 @@ device_dispatch(x::MArray) = ToCPU()
@inline singleton(::Type{IV1JH2}) = IV1JH2Singleton()


include("non_extruded_broadcasted.jl")
include("has_uniform_datalayouts.jl")

include("copyto.jl")
Expand Down
6 changes: 6 additions & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,10 @@ isascalar(
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
} = true
isascalar(
bc::NonExtrudedBroadcasted{Style},
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
} = true
isascalar(bc) = false
48 changes: 47 additions & 1 deletion src/DataLayouts/fused_copyto.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,30 @@

Base.@propagate_inbounds function rcopyto_at_linear!(
pair::Pair{<:AbstractData, <:Any},
I,
)
dest, bc = pair.first, pair.second
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = bcI
return nothing
end
Base.@propagate_inbounds function rcopyto_at_linear!(
pair::Pair{<:DataF, <:Any},
I,
)
dest, bc = pair.first, pair.second
bcI = isascalar(bc) ? bc[] : bc[I]
dest[] = bcI
return nothing
end
Base.@propagate_inbounds function rcopyto_at_linear!(pairs::Tuple, I)
rcopyto_at_linear!(first(pairs), I)
rcopyto_at_linear!(Base.tail(pairs), I)
end
Base.@propagate_inbounds rcopyto_at_linear!(pairs::Tuple{<:Any}, I) =
rcopyto_at_linear!(first(pairs), I)
@inline rcopyto_at_linear!(pairs::Tuple{}, I) = nothing

# Fused multi-broadcast entry point for DataLayouts
function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
Expand All @@ -18,7 +44,27 @@ function Base.copyto!(
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
fused_copyto!(fmb_inst, dest1, device_dispatch(parent(dest1)))

bcs = map(p -> p.second, fmb_inst.pairs)
destinations = map(p -> p.first, fmb_inst.pairs)
dest1 = first(destinations)
us = DataLayouts.UniversalSize(dest1)
dev = device_dispatch(parent(dest1))
if dev isa ClimaComms.AbstractCPUDevice &&
all(bc -> has_uniform_datalayouts(bc), bcs) &&
all(d -> d isa EndsWithField, destinations) &&
!(VERSION v"1.11.0-beta")
pairs′ = map(fmb_inst.pairs) do p
bc′ = to_non_extruded_broadcasted(p.second)
Pair(p.first, bc′)
end
fmbc′ = FusedMultiBroadcast(pairs′)
@inbounds for I in 1:get_N(us)
rcopyto_at_linear!(fmbc′.pairs, I)
end
else
fused_copyto!(fmb_inst, dest1, dev)
end
end

function fused_copyto!(
Expand Down
7 changes: 7 additions & 0 deletions src/DataLayouts/non_extruded_broadcasted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ struct NonExtrudedBroadcasted{
end
end

@inline to_broadcasted(bc::NonExtrudedBroadcasted) =
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, bc.axes)
@inline to_non_extruded_broadcasted(bc::Base.Broadcast.Broadcasted) =
NonExtrudedBroadcasted(bc.style, bc.f, to_non_extruded_broadcasted_args(bc.args), bc.axes)
@inline to_non_extruded_broadcasted(x) = x
Expand Down Expand Up @@ -77,7 +79,12 @@ end
Base.checkbounds_indices(Bool, (Base.OneTo(n_dofs(bc)),), (I,)) || Base.throw_boundserror(bc, (I,))
end

# To handle scalar cases, let's just switch back to
# Base.Broadcast.Broadcasted and allow cartesian indexing:
Base.@propagate_inbounds Base.getindex(bc::NonExtrudedBroadcasted) = to_broadcasted(bc)[CartesianIndex(())]

import StaticArrays
to_tuple(::Tuple{}) = ()
to_tuple(t::Tuple) = t
to_tuple(t::NTuple{N, <: Base.OneTo}) where {N} = map(x->x.stop, t)
to_tuple(t::NTuple{N, <: StaticArrays.SOneTo}) where {N} = map(x->x.stop, t)
Expand Down
28 changes: 28 additions & 0 deletions test/Fields/benchmark_field_multi_broadcast_fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,34 @@ include("utils_field_multi_broadcast_fusion.jl")
nothing
end

@testset "FusedMultiBroadcast VIJHF" begin
FT = Float64
device = ClimaComms.device()
space = TU.CenterExtrudedFiniteDifferenceSpace(
FT;
zelem = 63,
helem = 30,
Nq = 4,
HorizontalLayout = DataLayouts.IJHF,
context = ClimaComms.context(device),
)
X = Fields.FieldVector(
x1 = rand_field(FT, space),
x2 = rand_field(FT, space),
x3 = rand_field(FT, space),
)
Y = Fields.FieldVector(
y1 = rand_field(FT, space),
y2 = rand_field(FT, space),
y3 = rand_field(FT, space),
)
test_kernel!(; fused!, unfused!, X, Y)

benchmark_kernel!(unfused!, X, Y, device)
benchmark_kernel!(fused!, X, Y, device)
nothing
end

@testset "FusedMultiBroadcast VIFH" begin
FT = Float64
device = ClimaComms.device()
Expand Down
26 changes: 26 additions & 0 deletions test/Fields/unit_field_multi_broadcast_fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,32 @@ end
nothing
end

@testset "FusedMultiBroadcast VIJHF and VF" begin
FT = Float64
device = ClimaComms.device()
space = TU.CenterExtrudedFiniteDifferenceSpace(
FT;
zelem = 3,
helem = 4,
context = ClimaComms.context(device),
HorizontalLayout = DataLayouts.IJHF,
)
X = Fields.FieldVector(
x1 = rand_field(FT, space),
x2 = rand_field(FT, space),
x3 = rand_field(FT, space),
)
Y = Fields.FieldVector(
y1 = rand_field(FT, space),
y2 = rand_field(FT, space),
y3 = rand_field(FT, space),
)
test_kernel!(; fused!, unfused!, X, Y)
test_kernel!(; fused! = fused_bycolumn!, unfused! = unfused_bycolumn!, X, Y)

nothing
end

@testset "FusedMultiBroadcast VIFH" begin
FT = Float64
device = ClimaComms.device()
Expand Down
5 changes: 5 additions & 0 deletions test/Operators/finitedifference/benchmark_stencils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,29 @@ include("benchmark_stencils_utils.jl")
# column_benchmark_arrays(device, z_elems = 63, bm.float_type)
# sphere_benchmark_arrays(device, z_elems = 63, helem = 30, Nq = 4, bm.float_type)

@info "Column"
bm = Benchmark(;float_type = Float64, device_name)
# benchmark_operators_column(bm; z_elems = 63, helem = 30, Nq = 4, compile = true)
(;t_min) = benchmark_operators_column(bm; z_elems = 63, helem = 30, Nq = 4)
test_results_column(t_min)

@info "sphere, IJFH, Float64"
bm = Benchmark(;float_type = Float64, device_name)
# benchmark_operators_sphere(bm; z_elems = 63, helem = 30, Nq = 4, compile = true)
(;t_min) = benchmark_operators_sphere(bm; z_elems = 63, helem = 30, Nq = 4, HorizontalLayout = DataLayouts.IJFH)
test_results_sphere(t_min)

@info "sphere, IJHF, Float64"
bm = Benchmark(;float_type = Float64, device_name)
(;t_min) = benchmark_operators_sphere(bm; z_elems = 63, helem = 30, Nq = 4, HorizontalLayout = DataLayouts.IJHF)
test_results_sphere(t_min)

@info "sphere, IJFH, Float32"
bm = Benchmark(;float_type = Float32, device_name)
# benchmark_operators_sphere(bm; z_elems = 63, helem = 30, Nq = 4, compile = true)
(;t_min) = benchmark_operators_sphere(bm; z_elems = 63, helem = 30, Nq = 4, HorizontalLayout = DataLayouts.IJFH)

@info "sphere, IJHF, Float32"
bm = Benchmark(;float_type = Float32, device_name)
# benchmark_operators_sphere(bm; z_elems = 63, helem = 30, Nq = 4, compile = true)
(;t_min) = benchmark_operators_sphere(bm; z_elems = 63, helem = 30, Nq = 4, HorizontalLayout = DataLayouts.IJHF)
Expand Down

0 comments on commit a822d7e

Please sign in to comment.