From 552bc6c21f8d93e7220c030bc0164f0a362e57c6 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 14 Jan 2026 01:34:03 +0100 Subject: [PATCH] Allow redefinition of kernel methods --- ext/CUDAExt.jl | 7 +++++-- test/execution.jl | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index b0cb87f..0012c3d 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -9,7 +9,7 @@ using CUDA_Compiler_jll public launch # Compilation cache - stores CuFunction directly to avoid re-loading CuModule -const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction +const _compilation_cache = Dict{Any, Any}() # (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction """ launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) @@ -65,8 +65,11 @@ function cuTile.launch(@nospecialize(f), grid, args...; # Determine kernel name kernel_name = name !== nothing ? 
name : string(nameof(f)) + # Key the cache on the resolved Method so a redefined kernel misses the cache + method = which(f, argtypes) + # Check compilation cache - returns CuFunction directly - cache_key = (f, argtypes, sm_arch, opt_level, num_ctas, occupancy) + cache_key = (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) cufunc = get(_compilation_cache, cache_key, nothing) if cufunc === nothing || cuTile.compile_hook[] !== nothing cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy) diff --git a/test/execution.jl b/test/execution.jl index 44e5114..2996d63 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -1590,6 +1590,41 @@ end end +@testset "redefine kernel method" begin + mod = @eval module $(gensym()) + import cuTile as ct + function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1}) + pid = ct.bid(1) + ta = ct.load(a, (pid,), (16,)) + tb = ct.load(b, (pid,), (16,)) + ct.store(c, (pid,), ta + tb) + return + end + end + + a = CUDA.ones(Float32, 1024) + b = CUDA.ones(Float32, 1024) + c = CUDA.zeros(Float32, 1024) + + ct.launch(mod.vadd_kernel, 64, a, b, c) + + @test Array(c) ≈ Array(a) + Array(b) + + @eval mod begin + function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1}) + pid = ct.bid(1) + ta = ct.load(a, (pid,), (16,)) + tb = ct.load(b, (pid,), (16,)) + ct.store(c, (pid,), ta + tb * 2) + return + end + end + + ct.launch(mod.vadd_kernel, 64, a, b, c) + + @test Array(c) ≈ Array(a) + Array(b) * 2 +end + @testset "Entry Hints Integration" begin @testset "launch with num_ctas" begin