Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions examples/scankernel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Test
using CUDA
using cuTile
import cuTile as ct

function cumsum_1d_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
tile_size::ct.Constant{Int})
bid = ct.bid(1)
tile = ct.load(a, bid, (tile_size[],))
result = ct.cumsum(tile, Val(1)) # Val(1) means 1st (0th) dimension for 1D tile
ct.store(b, bid, result)
return nothing
end

sz = 32
N = 2^15
a = CUDA.rand(Float32, N)
b = CUDA.zeros(Float32, N)
CUDA.@sync ct.launch(cumsum_1d_kernel, cld(length(a), sz), a, b, ct.Constant(sz))

# This is supposed to be a single pass kernel but its simpler version than memory ordering version.
# The idea is to show how device scan operation can be done.

# CSDL phase 1: Intra-tile scan + store tile sums
function cumsum_csdl_phase1(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
tile_sums::ct.TileArray{Float32,1},
tile_size::ct.Constant{Int})
bid = ct.bid(1)
tile = ct.load(a, bid, (tile_size[],))
result = ct.cumsum(tile, Val(1))
ct.store(b, bid, result)
tile_sum = ct.extract(result, (tile_size[],), (1,)) # Extract last element (1 element shape)
ct.store(tile_sums, bid, tile_sum)
return
end

# CSDL phase 2: Decoupled lookback to accumulate previous tile sums
function cumsum_csdl_phase2(b::ct.TileArray{Float32,1},
tile_sums::ct.TileArray{Float32,1},
tile_size::ct.Constant{Int})
bid = ct.bid(1)
prev_sum = ct.zeros((tile_size[],), Float32)
k = Int32(bid)
while k > 1
tile_sum_k = ct.load(tile_sums, (k,), (1,))
prev_sum = prev_sum .+ tile_sum_k
k -= Int32(1)
end
tile = ct.load(b, bid, (tile_size[],))
result = tile .+ prev_sum
ct.store(b, bid, result)
return nothing
end

n = length(a)
num_tiles = cld(n, sz)
tile_sums = CUDA.zeros(Float32, num_tiles)
CUDA.@sync ct.launch(cumsum_csdl_phase1, num_tiles, a, b, tile_sums, ct.Constant(sz))
CUDA.@sync ct.launch(cumsum_csdl_phase2, num_tiles, b, tile_sums, ct.Constant(sz))

b_cpu = cumsum(a |> collect, dims=1)
@test isapprox(b |> collect, b_cpu) # This might fail occasionally
74 changes: 73 additions & 1 deletion src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
identities::Vector{<:ReduceIdentity},
identities::Vector{<:IdentityOp},
body_scalar_types::Vector{TypeId})
encode_varint!(cb.buf, Opcode.ReduceOp)

Expand Down Expand Up @@ -1331,6 +1331,78 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
end
end


#=============================================================================
Scan operations
=============================================================================#

"""
encode_ScanOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
reverse::Bool,
identities::Vector{<:IdentityOp},
body_scalar_types::Vector{TypeId})

Encode a ScanOp (parallel prefix sum) operation.

# Arguments
- body: Function that takes block args and yields result(s)
- cb: CodeBuilder for the bytecode
- result_types: Output tile types
- operands: Input tiles to scan
- dim: Dimension to scan along (0-indexed)
- reverse: Whether to scan in reverse order
- identities: Identity values for each operand (reuses IdentityOp from IntegerReduce)
- body_scalar_types: 0D tile types for body arguments
"""
function encode_ScanOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
reverse::Bool,
identities::Vector{<:IdentityOp},
body_scalar_types::Vector{TypeId})
encode_varint!(cb.buf, Opcode.ScanOp)

# Variadic result types
encode_typeid_seq!(cb.buf, result_types)

# Attributes: dim (int), reverse (bool), identities (array)
encode_opattr_int!(cb, dim)
encode_opattr_bool!(cb, reverse)
encode_identity_array!(cb, identities)

# Variadic operands
encode_varint!(cb.buf, length(operands))
encode_operands!(cb.buf, operands)

# Number of regions
push!(cb.debug_attrs, cb.cur_debug_attr)
cb.num_ops += 1
encode_varint!(cb.buf, 1) # 1 region: body

# Body region - block args are pairs of (acc, elem) for each operand
# The body operates on 0D tiles (scalars)
body_arg_types = TypeId[]
for scalar_type in body_scalar_types
push!(body_arg_types, scalar_type) # accumulator
push!(body_arg_types, scalar_type) # element
end
with_region(body, cb, body_arg_types)

# Create result values
num_results = length(result_types)
if num_results == 0
return Value[]
else
vals = [Value(cb.next_value_id + i) for i in 0:num_results-1]
cb.next_value_id += num_results
return vals
end
end

#=============================================================================
Comparison and selection operations
=============================================================================#
Expand Down
92 changes: 80 additions & 12 deletions src/bytecode/writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,30 +234,42 @@ end
=============================================================================#

"""
ReduceIdentity
IdentityOp

Abstract type for reduce identity attributes.
Abstract type for binary operation identity attributes (reduce, scan, etc.).
"""
abstract type ReduceIdentity end
abstract type IdentityOp end

"""
FloatIdentity(value, type_id, dtype)
FloatIdentityOp(value, type_id, dtype)

Float identity value for reduce operations.
Float identity value for binary operations.
"""
struct FloatIdentity <: ReduceIdentity
struct FloatIdentityOp <: IdentityOp
value::Float64
type_id::TypeId
dtype::Type # Float16, Float32, Float64, etc.
end

"""
encode_tagged_float!(cb, identity::FloatIdentity)
IntegerIdentityOp(value, type_id, dtype, signed)

Integer identity value for binary operations.
"""
struct IntegerIdentityOp <: IdentityOp
value::UInt128 # Store as UInt128 to handle all unsigned values up to 64 bits
type_id::TypeId
dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc.
signed::Bool # true for signed, false for unsigned
end

"""
encode_tagged_float!(cb, identity::FloatIdentityOp)

Encode a tagged float attribute for reduce identity.
Format: tag(Float=0x02) + typeid + ap_int(value_bits)
"""
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityOp)
# Tag for Float attribute
push!(cb.buf, 0x02)
# Type ID
Expand All @@ -267,6 +279,53 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
encode_signed_varint!(cb.buf, bits)
end

"""
encode_tagged_int!(cb, identity::IntegerIdentityOp)

Encode a tagged integer identity attribute.
Format: tag(Int=0x01) + typeid + ap_int(value)
"""
function encode_tagged_int!(cb::CodeBuilder, identity::IntegerIdentityOp)
# Tag for Int attribute
push!(cb.buf, 0x01)
# Type ID
encode_typeid!(cb.buf, identity.type_id)
# Value: signed uses zigzag varint, unsigned uses plain varint
# Mask value to correct bit width and apply zigzag if signed
masked_value = mask_to_width(identity.value, identity.dtype, identity.signed)
if identity.signed
encode_signed_varint!(cb.buf, masked_value)
else
encode_varint!(cb.buf, masked_value)
end
end

"""
mask_to_width(value, dtype, signed)

Mask a UInt128 value to the correct bit width for the given type and apply zigzag if signed.
"""
mask_to_width(value::UInt128, ::Type{Int64}, signed::Bool) =
let masked = UInt64(value & 0xFFFFFFFFFFFFFFFF)
UInt64((masked << 1) ⊻ (masked >>> 63))
end
mask_to_width(value::UInt128, ::Type{Int32}, signed::Bool) =
let masked = UInt32(value & 0xFFFFFFFF)
UInt32((masked << 1) ⊻ (masked >>> 31))
end
mask_to_width(value::UInt128, ::Type{Int16}, signed::Bool) =
let masked = UInt16(value & 0xFFFF)
UInt16((masked << 1) ⊻ (masked >>> 15))
end
mask_to_width(value::UInt128, ::Type{Int8}, signed::Bool) =
let masked = UInt8(value & 0xFF)
UInt8((masked << 1) ⊻ (masked >>> 7))
end
mask_to_width(value::UInt128, ::Type{UInt64}, signed::Bool) = UInt64(value & 0xFFFFFFFFFFFFFFFF)
mask_to_width(value::UInt128, ::Type{UInt32}, signed::Bool) = UInt32(value & 0xFFFFFFFF)
mask_to_width(value::UInt128, ::Type{UInt16}, signed::Bool) = UInt16(value & 0xFFFF)
mask_to_width(value::UInt128, ::Type{UInt8}, signed::Bool) = UInt8(value & 0xFF)

"""
float_to_bits(value, dtype)

Expand Down Expand Up @@ -304,15 +363,24 @@ end
"""
encode_identity_array!(cb, identities)

Encode an array of reduce identity attributes.
Encode an array of binary operation identity attributes.
Dispatches on identity type to encode correctly.
"""
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity})
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityOp})
encode_varint!(cb.buf, length(identities))
for identity in identities
encode_tagged_float!(cb, identity)
encode_identity!(cb, identity)
end
end

"""
encode_identity!(cb, identity)

Encode a single identity attribute, dispatching on type.
"""
encode_identity!(cb::CodeBuilder, identity::FloatIdentityOp) = encode_tagged_float!(cb, identity)
encode_identity!(cb::CodeBuilder, identity::IntegerIdentityOp) = encode_tagged_int!(cb, identity)

"""
BytecodeWriter

Expand Down Expand Up @@ -544,7 +612,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
end

#=============================================================================
Optimization Hints
Optimization Hints
=============================================================================#

"""
Expand Down
1 change: 1 addition & 0 deletions src/compiler/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete
using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView
using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned
using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual
using ..cuTile: IdentityOp, FloatIdentityOp, IntegerIdentityOp

end

Expand Down
Loading