Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ function encode_LoadViewTkoOp!(cb::CodeBuilder,
token::Union{Value, Nothing}=nothing,
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
memory_scope::Union{MemoryScope, Nothing}=nothing,
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
encode_varint!(cb.buf, Opcode.LoadViewTkoOp)
# Variadic result types
encode_typeid_seq!(cb.buf, [tile_type, token_type])
Expand All @@ -447,7 +447,7 @@ function encode_LoadViewTkoOp!(cb::CodeBuilder,
encode_enum!(cb.buf, memory_scope)
end
if optimization_hints !== nothing
append!(cb.buf, optimization_hints)
encode_opattr_optimization_hints!(cb, optimization_hints)
end

# Operands
Expand All @@ -472,7 +472,7 @@ function encode_StoreViewTkoOp!(cb::CodeBuilder,
token::Union{Value, Nothing}=nothing,
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
memory_scope::Union{MemoryScope, Nothing}=nothing,
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
encode_varint!(cb.buf, Opcode.StoreViewTkoOp)
# Variadic result types (just token)
encode_typeid_seq!(cb.buf, [token_type])
Expand All @@ -496,7 +496,7 @@ function encode_StoreViewTkoOp!(cb::CodeBuilder,
encode_enum!(cb.buf, memory_scope)
end
if optimization_hints !== nothing
append!(cb.buf, optimization_hints)
encode_opattr_optimization_hints!(cb, optimization_hints)
end

# Operands
Expand Down Expand Up @@ -541,7 +541,7 @@ function encode_LoadPtrTkoOp!(cb::CodeBuilder,
token::Union{Value, Nothing}=nothing,
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
memory_scope::Union{MemoryScope, Nothing}=nothing,
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
encode_varint!(cb.buf, Opcode.LoadPtrTkoOp)
# Result types
encode_typeid!(cb.buf, result_type)
Expand Down Expand Up @@ -572,7 +572,7 @@ function encode_LoadPtrTkoOp!(cb::CodeBuilder,
encode_enum!(cb.buf, memory_scope)
end
if optimization_hints !== nothing
append!(cb.buf, optimization_hints)
encode_opattr_optimization_hints!(cb, optimization_hints)
end

# Operands
Expand Down Expand Up @@ -600,7 +600,7 @@ function encode_StorePtrTkoOp!(cb::CodeBuilder,
token::Union{Value, Nothing}=nothing,
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
memory_scope::Union{MemoryScope, Nothing}=nothing,
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
encode_varint!(cb.buf, Opcode.StorePtrTkoOp)
# Result type (token)
encode_typeid!(cb.buf, token_type)
Expand All @@ -627,7 +627,7 @@ function encode_StorePtrTkoOp!(cb::CodeBuilder,
encode_enum!(cb.buf, memory_scope)
end
if optimization_hints !== nothing
append!(cb.buf, optimization_hints)
encode_opattr_optimization_hints!(cb, optimization_hints)
end

# Operands
Expand Down
76 changes: 68 additions & 8 deletions src/bytecode/writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,75 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
end

#=============================================================================
EntryHints: Kernel-level compilation hints
Optimization Hints
=============================================================================#

"""
encode_tagged_value!(cb, value)

Encode a value with its type tag.
"""
function encode_tagged_value!(buf::Vector{UInt8}, type_table::TypeTable, value::Bool)
push!(buf, AttributeTag.Bool)
push!(buf, value)
end

# Integer attribute layout: the `AttributeTag.Integer` tag, then the id of the I32
# type from `type_table`, then the value converted to `UInt32` and written as a varint.
function encode_tagged_value!(buf::Vector{UInt8}, type_table::TypeTable, value::Integer)
    push!(buf, AttributeTag.Integer)
    i32_type = I32(type_table)
    encode_typeid!(buf, i32_type)
    # The conversion is done after the tag and type id are written, as before;
    # it throws for values that do not fit in a UInt32.
    payload = UInt32(value)
    return encode_varint!(buf, payload)
end

"""
Optimization hints for load/store operations.
- `latency`: Optional latency hint (1-10), or nothing for default
- `allow_tma`: Whether TMA (Tensor Memory Accelerator) is allowed (default: true)
"""
@kwdef struct LoadStoreHints
latency::Union{Int, Nothing} = nothing
allow_tma::Bool = true
end

"""
Optimization hints for load/store operations.
- `hints_by_arch`: List of (SM architecture, load/store hints) pairs
"""
struct OptimizationHints
hints_by_arch::Vector{Tuple{String, LoadStoreHints}}
end

"""
    make_load_store_hints(sm_arch, hints)

Wrap `hints` in an `OptimizationHints` holding the single pair `(sm_arch, hints)`.

# Throws
- `ArgumentError` if `sm_arch` is `nothing`.
"""
function make_load_store_hints(sm_arch::Union{String, Nothing}, hints::LoadStoreHints)
    if sm_arch === nothing
        throw(ArgumentError("sm_arch must be explicitly passed when load/store hints are present"))
    end
    return OptimizationHints([(sm_arch, hints)])
end

"""
    encode_opattr_optimization_hints!(cb, hints)

Write `hints` into `cb.buf` as an outer dictionary keyed by architecture.

The output is the number of entries in `hints.hints_by_arch` as a varint, followed,
for each entry in order, by the string-table id of the architecture name (varint)
and the tagged inner dictionary written by `encode_load_store_hints_dict!`.
"""
function encode_opattr_optimization_hints!(cb::CodeBuilder, hints::OptimizationHints)
    entries = hints.hints_by_arch
    encode_varint!(cb.buf, length(entries))
    foreach(entries) do (arch, per_arch_hints)
        # Key: id of the architecture name in the string table.
        encode_varint!(cb.buf, cb.string_table[arch].id)
        # Value: the hints for that architecture, as a tagged dictionary.
        encode_load_store_hints_dict!(cb, per_arch_hints)
    end
    return nothing
end

"""
    encode_load_store_hints_dict!(cb, hints)

Write `hints` into `cb.buf` as a tagged dictionary attribute.

Only non-default hints are written: `"allow_tma"` appears (as `false`) only when TMA
is disallowed, and `"latency"` appears only when a latency is set. The output is
the `AttributeTag.Dictionary` tag, the entry count as a varint, then for each entry
the string-table id of its key (varint) and its tagged value.
"""
function encode_load_store_hints_dict!(cb::CodeBuilder, hints::LoadStoreHints)
    entries = Tuple{String, Any}[]
    if !hints.allow_tma
        push!(entries, ("allow_tma", false))
    end
    if hints.latency !== nothing
        push!(entries, ("latency", hints.latency))
    end

    push!(cb.buf, AttributeTag.Dictionary)
    encode_varint!(cb.buf, length(entries))
    for (name, val) in entries
        encode_varint!(cb.buf, cb.string_table[name].id)
        encode_tagged_value!(cb.buf, cb.type_table, val)
    end
    return nothing
end

"""
Kernel-level compilation hints (num_ctas, occupancy).
Encoded as a dictionary attribute in bytecode.
Expand All @@ -567,10 +633,6 @@ function validate_occupancy(occupancy::Union{Int, Nothing})
1 <= occupancy <= 32 || throw(ArgumentError("occupancy must be between 1 and 32, got $occupancy"))
end

"""
Encode EntryHints as OptimizationHints format.
Returns raw bytes for entry_hints parameter or nothing.
"""
function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{String, Nothing}, hints::EntryHints)
validate_num_ctas(hints.num_ctas)
validate_occupancy(hints.occupancy)
Expand Down Expand Up @@ -603,9 +665,7 @@ function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{String, Nothi
for (key, value) in items
key_id = writer.string_table[key]
encode_varint!(buf, key_id.id)
push!(buf, AttributeTag.Integer)
encode_typeid!(buf, I32(writer.type_table))
encode_varint!(buf, UInt32(value))
encode_tagged_value!(buf, writer.type_table, value)
end

return buf
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
is_entry::Bool = true,
num_ctas::Union{Int, Nothing} = nothing,
occupancy::Union{Int, Nothing} = nothing)
ctx = CGCtx(writer, target)
ctx = CGCtx(writer, target, sm_arch)
tt = ctx.tt

# Validate non-ghost argument types are concrete
Expand Down
8 changes: 8 additions & 0 deletions src/compiler/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ end

emit_intrinsic!(ctx::CGCtx, @nospecialize(func), args) = missing

"""
    create_optimization_hints(ctx, latency, allow_tma=true)

Shared helper that builds the load/store optimization hints for an intrinsic.

Returns `nothing` when `latency` is `nothing` and `allow_tma` is `true`, i.e. when
no hint differs from its default. Otherwise returns the result of
`make_load_store_hints` for `ctx.sm_arch` and a `LoadStoreHints` built from the
arguments.

Raises an error if `latency` is given and is not between 1 and 10.
"""
function create_optimization_hints(ctx::CGCtx, latency::Union{Int, Nothing}, allow_tma::Bool=true)
    if latency === nothing
        # Every hint is at its default: there is nothing to encode.
        allow_tma && return nothing
    elseif !(1 <= latency <= 10)
        error("latency must be between 1 and 10, got $latency")
    end
    load_store_hints = LoadStoreHints(; latency=latency, allow_tma=allow_tma)
    return make_load_store_hints(ctx.sm_arch, load_store_hints)
end

include("intrinsics/core.jl")
include("intrinsics/conversions.jl")
include("intrinsics/arithmetic.jl")
Expand Down
62 changes: 44 additions & 18 deletions src/compiler/intrinsics/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,28 @@
# cuda_tile.load_ptr_tko
@eval Intrinsics begin
"""
load_ptr_tko(ptrs, mask=nothing, padding=nothing)
load_ptr_tko(ptrs, latency, mask=nothing, padding=nothing)

Load values from a tile of pointers.
If mask is provided, masked-out positions return the padding value.
Compiled to cuda_tile.load_ptr_tko.

Note: TMA (allow_tma) is not applicable for pointer-based loads as they
support irregular access patterns incompatible with TMA requirements.
"""
@noinline function load_ptr_tko(ptrs::Tile{Ptr{T}, S},
latency::Union{Int, Nothing}=nothing,
mask::Union{Tile{Bool, S}, Nothing}=nothing,
padding::Union{Tile{T, S}, Nothing}=nothing) where {T, S}
donotdelete(ptrs, mask, padding)
donotdelete(ptrs, latency, mask, padding)
Tile{T, S}()
end
end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args)
cb = ctx.cb
tt = ctx.tt

# args: (ptrs, latency, mask?, padding?)
# Get pointer tile (arg 1)
ptrs_tv = emit_value!(ctx, args[1])
ptrs_tv === nothing && error("load_ptr_tko: cannot resolve pointer tile")
Expand All @@ -36,29 +41,37 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args)
result_tile_type = tile_type!(tt, dtype, tile_shape)
token_type = Token(tt)

# Check if mask is provided (arg 2 is not nothing)
has_mask = length(args) >= 2 && get_constant(ctx, args[2]) !== nothing
# Extract latency hint (args[2])
latency = get_constant(ctx, args[2])

# Create optimization hints if provided
optimization_hints = create_optimization_hints(ctx, latency)

# Check if mask is provided (arg 3 is not nothing)
has_mask = length(args) >= 3 && get_constant(ctx, args[3]) !== nothing

if has_mask
# Get mask tile (arg 2)
mask_tv = emit_value!(ctx, args[2])
# Get mask tile (arg 3)
mask_tv = emit_value!(ctx, args[3])
mask_tv === nothing && error("load_ptr_tko: cannot resolve mask tile")
mask = mask_tv.v

# Get padding tile (arg 3)
padding_tv = emit_value!(ctx, args[3])
# Get padding tile (arg 4)
padding_tv = emit_value!(ctx, args[4])
padding_tv === nothing && error("load_ptr_tko: cannot resolve padding tile")
padding = padding_tv.v

# Load with mask and padding
tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers;
mask=mask,
padding_value=padding,
token=ctx.token)
token=ctx.token,
optimization_hints)
else
# Load without mask
tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers;
token=ctx.token)
token=ctx.token,
optimization_hints)
end
ctx.token = new_token

Expand All @@ -71,22 +84,27 @@ end
# cuda_tile.store_ptr_tko
@eval Intrinsics begin
"""
store_ptr_tko(ptrs, values, mask=nothing)
store_ptr_tko(ptrs, values, latency, mask=nothing)

Store values to a tile of pointers.
If mask is provided, masked-out positions are not written.
Compiled to cuda_tile.store_ptr_tko.

Note: TMA (allow_tma) is not applicable for pointer-based stores as they
support irregular access patterns incompatible with TMA requirements.
"""
@noinline function store_ptr_tko(ptrs::Tile{Ptr{T}, S}, values::Tile{T, S},
latency::Union{Int, Nothing},
mask::Union{Tile{Bool, S}, Nothing}=nothing) where {T, S}
donotdelete(ptrs, values, mask)
donotdelete(ptrs, values, latency, mask)
nothing
end
end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args)
cb = ctx.cb
tt = ctx.tt

# args: (ptrs, values, latency, mask?)
# Get pointer tile (arg 1)
ptrs_tv = emit_value!(ctx, args[1])
ptrs_tv === nothing && error("store_ptr_tko: cannot resolve pointer tile")
Expand All @@ -99,23 +117,31 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args)

token_type = Token(tt)

# Check if mask is provided (arg 3 is not nothing)
has_mask = length(args) >= 3 && get_constant(ctx, args[3]) !== nothing
# Extract latency hint (args[3])
latency = get_constant(ctx, args[3])

# Create optimization hints if provided
optimization_hints = create_optimization_hints(ctx, latency)

# Check if mask is provided (arg 4 is not nothing)
has_mask = length(args) >= 4 && get_constant(ctx, args[4]) !== nothing

if has_mask
# Get mask tile (arg 3)
mask_tv = emit_value!(ctx, args[3])
# Get mask tile (arg 4)
mask_tv = emit_value!(ctx, args[4])
mask_tv === nothing && error("store_ptr_tko: cannot resolve mask tile")
mask = mask_tv.v

# Store with mask
new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values;
mask=mask,
token=ctx.token)
token=ctx.token,
optimization_hints)
else
# Store without mask
new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values;
token=ctx.token)
token=ctx.token,
optimization_hints)
end
ctx.token = new_token

Expand Down
Loading