From 711758d725da0a6783bf42d7908534c0f807cb7e Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Fri, 18 Oct 2024 08:08:13 +0200
Subject: [PATCH] Adapt to GPUArrays.jl transition to KernelAbstractions.jl. (#461)

Co-authored-by: James Schloss
---
 Project.toml     |  2 +-
 src/gpuarrays.jl | 54 ------------------------------------------------
 test/random.jl   |  3 +--
 3 files changed, 2 insertions(+), 57 deletions(-)

diff --git a/Project.toml b/Project.toml
index b32b0d84a..4e7f1bd88 100644
--- a/Project.toml
+++ b/Project.toml
@@ -40,7 +40,7 @@ BFloat16s = "0.5"
 CEnum = "0.4, 0.5"
 CodecBzip2 = "0.8"
 ExprTools = "0.1"
-GPUArrays = "10.1"
+GPUArrays = "11"
 GPUCompiler = "0.26, 0.27, 1"
 KernelAbstractions = "0.9.1"
 LLVM = "7.2, 8, 9"
diff --git a/src/gpuarrays.jl b/src/gpuarrays.jl
index d8aaae548..bdffc0873 100644
--- a/src/gpuarrays.jl
+++ b/src/gpuarrays.jl
@@ -1,59 +1,5 @@
 ## GPUArrays interfaces
 
-## execution
-
-struct mtlArrayBackend <: AbstractGPUBackend end
-
-struct mtlKernelContext <: AbstractKernelContext end
-
-@inline function GPUArrays.launch_heuristic(::mtlArrayBackend, f::F, args::Vararg{Any,N};
-                                            elements::Int, elements_per_thread::Int) where {F,N}
-    kernel = @metal launch=false f(mtlKernelContext(), args...)
-
-    # The pipeline state automatically computes occupancy stats
-    threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
-    blocks = cld(elements, threads)
-
-    return (; threads=Int(threads), blocks=Int(blocks))
-end
-
-function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
-                            name::Union{String,Nothing})
-    @metal threads groups name f(mtlKernelContext(), args...)
-end
-
-
-## on-device
-
-# indexing
-GPUArrays.blockidx(ctx::mtlKernelContext) = threadgroup_position_in_grid_1d()
-GPUArrays.blockdim(ctx::mtlKernelContext) = threads_per_threadgroup_1d()
-GPUArrays.threadidx(ctx::mtlKernelContext) = thread_position_in_threadgroup_1d()
-GPUArrays.griddim(ctx::mtlKernelContext) = threadgroups_per_grid_1d()
-GPUArrays.global_index(ctx::mtlKernelContext) = thread_position_in_grid_1d()
-GPUArrays.global_size(ctx::mtlKernelContext) = threads_per_grid_1d()
-
-# memory
-
-@inline function GPUArrays.LocalMemory(::mtlKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
-                                      ) where {T, dims, id}
-    ptr = emit_threadgroup_memory(T, Val(prod(dims)))
-    MtlDeviceArray(dims, ptr)
-end
-
-# synchronization
-
-@inline GPUArrays.synchronize_threads(::mtlKernelContext) =
-    threadgroup_barrier(MemoryFlagThreadGroup)
-
-
-
-#
-# Host abstractions
-#
-
-GPUArrays.backend(::Type{<:MtlArray}) = mtlArrayBackend()
-
 const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
 function GPUArrays.default_rng(::Type{<:MtlArray})
     dev = device()
diff --git a/test/random.jl b/test/random.jl
index 608f03b08..3066acb50 100644
--- a/test/random.jl
+++ b/test/random.jl
@@ -246,8 +246,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
             a = f(T, d)
             Metal.seed!(1)
             b = f(T, d)
-            # TODO: Remove broken parameter once https://github.com/JuliaGPU/GPUArrays.jl/issues/530 is fixed
-            @test Array(a) == Array(b) broken = (T == Float16 && d == (1000,1000))
+            @test Array(a) == Array(b)
         end
     end
 end # testset