Skip to content

Commit

Permalink
Fix adapt call
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 23, 2024
1 parent 72cc110 commit 71f1a38
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ include("data_layouts_fused_copyto.jl")
include("data_layouts_mapreduce.jl")
include("data_layouts_threadblock.jl")

adapt_f(to, f::F) where {F} = Adapt.adapt(to, f)
adapt_f(to, ::Type{F}) where {F} = (x...) -> F(x...)

function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
bc::DataLayouts.NonExtrudedBroadcasted{Style},
) where {Style}
DataLayouts.NonExtrudedBroadcasted{Style}(
Adapt.adapt_f(to, bc.f),
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
)
end

adapt_f(to, f::F) where {F} = Adapt.adapt(to, f)
adapt_f(to, ::Type{F}) where {F} = (x...) -> F(x...)

function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
fmbc::FusedMultiBroadcast,
Expand Down

0 comments on commit 71f1a38

Please sign in to comment.