Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ using CUDA_Compiler_jll
public launch

# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level) => CuFunction
const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction

"""
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3)
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing)

Compile and launch a kernel function with the given grid size and arguments.

Expand All @@ -26,6 +26,8 @@ are expanded to their constituent ptr, sizes, and strides parameters.
- `name`: Optional kernel name for debugging
- `sm_arch`: Target GPU architecture (default: current device's capability)
- `opt_level`: Optimization level 0-3 (default: 3)
- `num_ctas`: Number of CTAs in a CGA, 1-16, must be power of 2 (default: nothing)
- `occupancy`: Expected active CTAs per SM, 1-32 (default: nothing)

# Example
```julia
Expand All @@ -51,7 +53,9 @@ cuTile.launch(vadd_kernel, 64, a, b, c)
function cuTile.launch(@nospecialize(f), grid, args...;
name::Union{String, Nothing}=nothing,
sm_arch::String=default_sm_arch(),
opt_level::Int=3)
opt_level::Int=3,
num_ctas::Union{Int, Nothing}=nothing,
occupancy::Union{Int, Nothing}=nothing)
# Convert CuArray -> TileArray (and other conversions)
tile_args = map(to_tile_arg, args)

Expand All @@ -62,10 +66,10 @@ function cuTile.launch(@nospecialize(f), grid, args...;
kernel_name = name !== nothing ? name : string(nameof(f))

# Check compilation cache - returns CuFunction directly
cache_key = (f, argtypes, sm_arch, opt_level)
cache_key = (f, argtypes, sm_arch, opt_level, num_ctas, occupancy)
cufunc = get(_compilation_cache, cache_key, nothing)
if cufunc === nothing || cuTile.compile_hook[] !== nothing
cubin = compile(f, argtypes; name, sm_arch, opt_level)
cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy)
if cufunc === nothing
cumod = CuModule(cubin)
cufunc = CuFunction(cumod, kernel_name)
Expand Down Expand Up @@ -98,15 +102,18 @@ function cuTile.launch(@nospecialize(f), grid, args...;
end

"""
compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3) -> Vector{UInt8}
compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) -> Vector{UInt8}

Compile a Julia kernel function to a CUDA binary.
"""
function compile(@nospecialize(f), @nospecialize(argtypes);
name::Union{String, Nothing}=nothing,
sm_arch::String=default_sm_arch(),
opt_level::Int=3)
tile_bytecode = emit_tileir(f, argtypes; name)
opt_level::Int=3,
num_ctas::Union{Int, Nothing}=nothing,
occupancy::Union{Int, Nothing}=nothing)
tile_bytecode = emit_tileir(f, argtypes; name, sm_arch,
num_ctas, occupancy)

# Dump bytecode if JULIA_CUTILE_DUMP_BYTECODE is set
dump_dir = get(ENV, "JULIA_CUTILE_DUMP_BYTECODE", nothing)
Expand Down
68 changes: 68 additions & 0 deletions src/bytecode/writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,71 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
encode_varint!(func_buf, length(cb.buf))
append!(func_buf, cb.buf)
end

#=============================================================================
EntryHints: Kernel-level compilation hints
=============================================================================#

"""
Kernel-level compilation hints (num_ctas, occupancy).
Encoded as a dictionary attribute in bytecode.
"""
@kwdef struct EntryHints
num_ctas::Union{Int, Nothing} = nothing # 1, 2, 4, 8, 16
occupancy::Union{Int, Nothing} = nothing # 1-32
end

function validate_num_ctas(num_ctas::Union{Int, Nothing})
isnothing(num_ctas) && return
1 <= num_ctas <= 16 || throw(ArgumentError("num_ctas must be between 1 and 16, got $num_ctas"))
ispow2(num_ctas) || throw(ArgumentError("num_ctas must be a power of 2, got $num_ctas"))
end

function validate_occupancy(occupancy::Union{Int, Nothing})
isnothing(occupancy) && return
1 <= occupancy <= 32 || throw(ArgumentError("occupancy must be between 1 and 32, got $occupancy"))
end

"""
Encode EntryHints as OptimizationHints format.
Returns raw bytes for entry_hints parameter or nothing.
"""
function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{String, Nothing}, hints::EntryHints)
validate_num_ctas(hints.num_ctas)
validate_occupancy(hints.occupancy)

# Build items list (only non-nothing values)
items = Tuple{String, Int}[]
isnothing(hints.num_ctas) || push!(items, ("num_cta_in_cga", hints.num_ctas))
isnothing(hints.occupancy) || push!(items, ("occupancy", hints.occupancy))
isempty(items) && return nothing

# Use default architecture if not specified and hints are present
arch = @something sm_arch throw(ArgumentError("sm_arch must be specified when entry hints are present"))

buf = UInt8[]

# Start with OptimizationHints tag
push!(buf, AttributeTag.OptimizationHints)

# Encode as architecture-specific dictionary
# Format: num_archs, then for each arch: arch_id, dictionary
encode_varint!(buf, 1) # 1 architecture

# Architecture string ID
arch_id = writer.string_table[arch]
encode_varint!(buf, arch_id.id)

# Encode dictionary
push!(buf, AttributeTag.Dictionary)
encode_varint!(buf, length(items))
for (key, value) in items
key_id = writer.string_table[key]
encode_varint!(buf, key_id.id)
push!(buf, AttributeTag.Integer)
encode_typeid!(buf, I32(writer.type_table))
encode_varint!(buf, UInt32(value))
end

return buf
end
13 changes: 10 additions & 3 deletions src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# kernel and argument handling

"""
emit_kernel!(writer, func_buf, target; name, is_entry=true)
emit_kernel!(writer, func_buf, target; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing)

Compile a TileTarget to Tile IR bytecode.
"""
function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
target::TileTarget;
name::String = string(target.mi.def.name),
is_entry::Bool = true)
sm_arch::Union{String, Nothing} = nothing,
is_entry::Bool = true,
num_ctas::Union{Int, Nothing} = nothing,
occupancy::Union{Int, Nothing} = nothing)
ctx = CGCtx(writer, target)
tt = ctx.tt

Expand Down Expand Up @@ -58,8 +61,12 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
push!(result_types, tile_type_for_julia!(ctx, target.rettype))
end

# Create entry hints if provided
entry_hints = encode_entry_hints(writer, sm_arch, EntryHints(; num_ctas, occupancy))

# Create function
cb = add_function!(writer, func_buf, name, param_types, result_types; is_entry)
cb = add_function!(writer, func_buf, name, param_types, result_types;
is_entry, entry_hints)
ctx.cb = cb

# Set up argument values
Expand Down
16 changes: 10 additions & 6 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
export code_tiled, @code_tiled

"""
emit_tileir(f, argtypes; name=nothing) -> Vector{UInt8}
emit_tileir(f, argtypes; name, sm_arch, num_ctas, occupancy) -> Vector{UInt8}

Compile a Julia function to Tile IR bytecode.
"""
function emit_tileir(@nospecialize(f), @nospecialize(argtypes);
name::Union{String, Nothing} = nothing)
name::Union{String, Nothing} = nothing,
sm_arch::Union{String, Nothing} = nothing,
num_ctas::Union{Int, Nothing} = nothing,
occupancy::Union{Int, Nothing} = nothing)
target = TileTarget(f, argtypes)
kernel_name = name === nothing ? string(target.mi.def.name) : name

Expand All @@ -15,7 +18,8 @@ function emit_tileir(@nospecialize(f), @nospecialize(argtypes);
end

buf = write_bytecode!(1) do writer, func_buf
emit_kernel!(writer, func_buf, target; name=kernel_name)
emit_kernel!(writer, func_buf, target; name=kernel_name, sm_arch,
num_ctas, occupancy)
end

return buf
Expand All @@ -31,14 +35,14 @@ function disassemble_tileir(bytecode::Vector{UInt8})::String
end

"""
code_tiled(f, argtypes; name=nothing) -> String
code_tiled(f, argtypes; name, sm_arch, num_ctas, occupancy) -> String

Return the CUDA Tile IR for a Julia function as a textual MLIR representation.
Analogous to `code_typed` or `code_structured`.
"""
function code_tiled(@nospecialize(f), @nospecialize(argtypes);
name::Union{String, Nothing} = nothing)
bytecode = emit_tileir(f, argtypes; name)
kwargs...)
bytecode = emit_tileir(f, argtypes; kwargs...)
disassemble_tileir(bytecode)
end

Expand Down
115 changes: 115 additions & 0 deletions test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1846,3 +1846,118 @@ end
end
end
end

#=============================================================================
Entry Hints (optimization_hints attribute)
=============================================================================#

@testset "Entry Hints" begin
# Common ArraySpecs for tests
spec1d = ct.ArraySpec{1}(16, true)

@testset "num_ctas only" begin
@test @filecheck begin
@check "optimization_hints=<sm_100 = {num_cta_in_cga = 4}>"
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=4) do a
pid = ct.bid(1)
t = ct.load(a, pid, (16,))
ct.store(a, pid, t)
return nothing
end
end
end

@testset "occupancy only" begin
@test @filecheck begin
@check "optimization_hints=<sm_100 = {occupancy = 8}>"
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=8) do a
pid = ct.bid(1)
t = ct.load(a, pid, (16,))
ct.store(a, pid, t)
return nothing
end
end
end

@testset "both hints" begin
@test @filecheck begin
@check "optimization_hints=<sm_120 = {num_cta_in_cga = 2, occupancy = 4}"
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_120", num_ctas=2, occupancy=4) do a
pid = ct.bid(1)
t = ct.load(a, pid, (16,))
ct.store(a, pid, t)
return nothing
end
end
end

@testset "no hints" begin
@test @filecheck begin
@check_not "optimization_hints"
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100") do a
pid = ct.bid(1)
t = ct.load(a, pid, (16,))
ct.store(a, pid, t)
return nothing
end
end
end

@testset "architecture parameter" begin
@test @filecheck begin
@check "optimization_hints=<sm_120 = {num_cta_in_cga = 4}>"
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_120", num_ctas=4) do a
pid = ct.bid(1)
t = ct.load(a, pid, (16,))
ct.store(a, pid, t)
return nothing
end
end
end

@testset "num_ctas validation" begin
# Too small
@test_throws "num_ctas must be between 1 and 16" begin
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=0)
end

# Too large
@test_throws "num_ctas must be between 1 and 16" begin
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=17)
end

# Not power of 2
@test_throws "num_ctas must be a power of 2" begin
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=3)
end

@test_throws "num_ctas must be a power of 2" begin
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=5)
end

# Valid values should succeed
for num_ctas in [1, 2, 4, 8, 16]
bytecode = code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas)
@test !isempty(bytecode)
end
end

@testset "occupancy validation" begin
# Too small
@test_throws "occupancy must be between 1 and 32" begin
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=0)
end

# Too large
@test_throws "occupancy must be between 1 and 32" begin
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=33)
end

# Valid boundaries
bytecode1 = code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=1)
@test !isempty(bytecode1)

bytecode32 = code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=32)
@test !isempty(bytecode32)
end
end
Loading