Skip to content
2 changes: 1 addition & 1 deletion src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
identities::Vector{<:ReduceIdentity},
identities::Vector{<:IdentityVal},
body_scalar_types::Vector{TypeId})
encode_varint!(cb.buf, Opcode.ReduceOp)

Expand Down
80 changes: 68 additions & 12 deletions src/bytecode/writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,30 +234,41 @@ end
=============================================================================#

"""
ReduceIdentity
IdentityVal

Abstract type for reduce identity attributes.
Abstract type for binary operation identity attributes (reduce, scan, etc.).
"""
abstract type ReduceIdentity end
abstract type IdentityVal end

"""
FloatIdentity(value, type_id, dtype)
FloatIdentityVal(value, type_id, dtype)

Float identity value for reduce operations.
Float identity value for binary operations.
"""
struct FloatIdentity <: ReduceIdentity
struct FloatIdentityVal <: IdentityVal
value::Float64
type_id::TypeId
dtype::Type # Float16, Float32, Float64, etc.
end

"""
encode_tagged_float!(cb, identity::FloatIdentity)
IntegerIdentityVal(value, type_id, dtype)

Integer identity value for binary operations.
"""
struct IntegerIdentityVal <: IdentityVal
    # Bit pattern of the identity value, widened to UInt128. The encoder masks
    # it back down to the width of `dtype` before writing it out.
    value::UInt128 # Store as UInt128 to handle all unsigned values up to 64 bits
    # Type id written right after the attribute tag when this value is encoded.
    type_id::TypeId
    # Julia integer type of the element; its size selects the mask width.
    dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc. (signedness inferred from dtype)
end

"""
encode_tagged_float!(cb, identity::FloatIdentityVal)

Encode a tagged float attribute for reduce identity.
Format: tag(Float=0x02) + typeid + ap_int(value_bits)
"""
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityVal)
# Tag for Float attribute
push!(cb.buf, 0x02)
# Type ID
Expand All @@ -267,6 +278,42 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
encode_signed_varint!(cb.buf, bits)
end

"""
    encode_tagged_int!(cb, identity::IntegerIdentityVal)

Append a tagged integer identity attribute to `cb.buf`.

The bytes are written in this order:
1. the Int attribute tag (`0x01`),
2. the type id (`identity.type_id`),
3. the identity value as a varint.
"""
function encode_tagged_int!(cb::CodeBuilder, identity::IntegerIdentityVal)
    push!(cb.buf, 0x01)  # attribute tag: Int
    encode_typeid!(cb.buf, identity.type_id)
    # `mask_to_width` narrows the stored bits to the width of `dtype` and
    # applies the signed-value encoding itself, which is why the plain
    # (unsigned) varint writer is used for the payload.
    return encode_varint!(cb.buf, mask_to_width(identity.value, identity.dtype))
end

"""
    mask_to_width(value, dtype)

Truncate a `UInt128` bit pattern to the bit width of `dtype` and return it as
the unsigned integer type of that width (`unsigned(dtype)`).

For signed `dtype`s the truncated bits are interpreted as a two's complement
number `n` and zigzag encoded, `(n << 1) ⊻ (n >> (bits - 1))`, so that values
of small magnitude map to small unsigned numbers:
`0 -> 0`, `-1 -> 1`, `1 -> 2`, `-2 -> 3`, ...

For unsigned `dtype`s the truncated value is returned unchanged.
"""
function mask_to_width(value::UInt128, ::Type{T}) where T <: Integer
    bits = sizeof(T) * 8
    U = unsigned(T)
    # For bits == 128 the shift yields 0 and the subtraction wraps to all ones,
    # so the mask keeps every bit.
    mask = (UInt128(1) << bits) - 1
    raw = U(value & mask)
    T <: Signed || return raw
    # Zigzag needs an *arithmetic* right shift so that the sign bit is
    # replicated across the whole word; a logical shift on the unsigned value
    # would only rotate the sign bit into bit 0 and mis-encode negatives.
    n = reinterpret(T, raw)
    return reinterpret(U, (n << 1) ⊻ (n >> (bits - 1)))
end


"""
float_to_bits(value, dtype)

Expand Down Expand Up @@ -304,15 +351,24 @@ end
"""
encode_identity_array!(cb, identities)

Encode an array of reduce identity attributes.
Encode an array of binary operation identity attributes.
Dispatches on identity type to encode correctly.
"""
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity})
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityVal})
encode_varint!(cb.buf, length(identities))
for identity in identities
encode_tagged_float!(cb, identity)
encode_identity!(cb, identity)
end
end

"""
    encode_identity!(cb, identity)

Write one identity attribute into `cb`. The concrete type of `identity`
selects the encoder: float identities go through `encode_tagged_float!`,
integer identities through `encode_tagged_int!`.
"""
function encode_identity!(cb::CodeBuilder, identity::FloatIdentityVal)
    return encode_tagged_float!(cb, identity)
end

function encode_identity!(cb::CodeBuilder, identity::IntegerIdentityVal)
    return encode_tagged_int!(cb, identity)
end

"""
BytecodeWriter

Expand Down Expand Up @@ -544,7 +600,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
end

#=============================================================================
Optimization Hints
Optimization Hints
=============================================================================#

"""
Expand Down
1 change: 1 addition & 0 deletions src/compiler/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete
using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView
using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned
using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual
using ..cuTile: IdentityVal, FloatIdentityVal, IntegerIdentityVal

end

Expand Down
79 changes: 66 additions & 13 deletions src/compiler/intrinsics/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ end
Sum reduction along 0-indexed axis.
Compiled to cuda_tile.reduce with ADD.
"""
@noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
@noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis}
reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1)
Tile{T, reduced_shape}()
end
Expand All @@ -523,7 +523,7 @@ end
Maximum reduction along 0-indexed axis.
Compiled to cuda_tile.reduce with MAX.
"""
@noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
@noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis}
reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1)
Tile{T, reduced_shape}()
end
Expand Down Expand Up @@ -562,28 +562,81 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol)
# Scalar type for reduction body (0D tile)
scalar_tile_type = tile_type!(tt, dtype, Int[])

# Create identity value - use simple dtype (f32), not tile type
identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? -Inf : 0.0)
identity = FloatIdentity(identity_val, dtype, elem_type)
# Create identity value via dispatch on reduction function and element type
identity = operation_identity(Val(reduce_fn), dtype, elem_type)

# Emit ReduceOp
results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis, [identity], [scalar_tile_type]) do block_args
acc, elem = block_args[1], block_args[2]

if reduce_fn == :add
res = encode_AddFOp!(cb, scalar_tile_type, acc, elem)
elseif reduce_fn == :max
res = encode_MaxFOp!(cb, scalar_tile_type, acc, elem)
else
error("Unsupported reduction function: $reduce_fn")
end

res = encode_reduce_body(cb, scalar_tile_type, acc, elem, reduce_fn, elem_type)
encode_YieldOp!(cb, [res])
end

CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape)
end

#=============================================================================#
# Reduce Identity Values via Dispatch
#=============================================================================#

"""
    to_uint128(value)

Widen a fixed-width integer to `UInt128` for storage in `IntegerIdentityVal`.

Unsigned inputs are zero-extended. Signed inputs are first reinterpreted as the
unsigned type of the same width, so the result holds the two's complement bit
pattern of `value` in its low bits (for example `Int8(-1)` becomes `0xff`).
"""
to_uint128(value::Union{UInt8, UInt16, UInt32, UInt64, UInt128}) = UInt128(value)
# Signed types: take the raw bits at the original width before widening, so a
# negative value is not sign-extended (and does not throw on conversion).
to_uint128(value::Union{Int8, Int16, Int32, Int64, Int128}) =
    UInt128(reinterpret(unsigned(typeof(value)), value))

"""
    operation_identity(fn, dtype, elem_type) -> IdentityVal

Return the identity value for a binary operation (reduce, scan, etc.).
The identity `e` of an operation `⊕` must satisfy `e ⊕ x = x` for every `x`.

# Arguments
- `fn`: the operation, lifted to the type domain (`Val(:add)` or `Val(:max)`).
- `dtype`: stored as the `type_id` of the returned identity.
- `elem_type`: Julia element type; selects a `FloatIdentityVal` for
  `AbstractFloat` types and an `IntegerIdentityVal` for `Integer` types.

Any other combination has no method and raises a `MethodError`.
"""
function operation_identity end

# Addition identity: e + x = x.
# For floats this is -0.0, not +0.0: (+0.0) + (-0.0) is +0.0, so a +0.0
# identity would turn a negative-zero input into positive zero.
operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: AbstractFloat =
    FloatIdentityVal(-zero(T), dtype, T)
operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: Integer =
    IntegerIdentityVal(to_uint128(zero(T)), dtype, T)

# Maximum identity: max(typemin(T), x) = x
operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: AbstractFloat =
    FloatIdentityVal(typemin(T), dtype, T)
operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer =
    IntegerIdentityVal(to_uint128(typemin(T)), dtype, T)

#=============================================================================#
# Reduce Body Operations
#=============================================================================#
"""
    encode_reduce_body(cb, type, acc, elem, op, T)

Emit the combining operation of a reduction body and return the value produced
by the selected encoder.

# Arguments
- `cb`, `type`, `acc`, `elem`: forwarded unchanged to the `encode_*Op!` call.
- `op`: `:add` or `:max`.
- `T`: element type. `AbstractFloat` types use the float ops (`AddF`/`MaxF`);
  every other type takes the integer path (`AddI`/`MaxI`).

# Throws
- `ErrorException` if `op` is not a supported reduction function.
"""
function encode_reduce_body(cb, type, acc, elem, op::Symbol, ::Type{T}) where T
    if T <: AbstractFloat
        op == :add && return encode_AddFOp!(cb, type, acc, elem)
        op == :max && return encode_MaxFOp!(cb, type, acc, elem)
    else # Integer
        op == :add && return encode_AddIOp!(cb, type, acc, elem)
        if op == :max
            # Only the integer max takes a signedness argument here.
            signedness = T <: Signed ? SignednessSigned : SignednessUnsigned
            return encode_MaxIOp!(cb, type, acc, elem; signedness)
        end
    end
    # Fail loudly instead of returning `nothing`, which the caller would then
    # pass on as the reduction result.
    error("Unsupported reduction function: $op")
end


# cuda_tile.reshape
@eval Intrinsics begin
Expand Down
3 changes: 3 additions & 0 deletions src/cuTile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,7 @@ include("language/atomics.jl")
public launch
launch() = error("Please import CUDA.jl before using `cuTile.launch`.")

# Declare the identity types as public API (`public` marks names public without exporting them)
public IdentityVal, FloatIdentityVal, IntegerIdentityVal

end # module cuTile
9 changes: 4 additions & 5 deletions src/language/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,10 @@ Returns a tile with the specified dimension removed.
sums = ct.reduce_sum(tile, 2) # Returns (128,) tile
```
"""
@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S}
@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: Number, S}
Intrinsics.reduce_sum(tile, Val(axis - 1))
end
@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis}
Intrinsics.reduce_sum(tile, Val(axis - 1))
end

Expand All @@ -546,10 +546,10 @@ Maximum reduction along the specified axis (1-indexed).
maxes = ct.reduce_max(tile, 2) # Max along axis 2
```
"""
@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S}
@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: Number, S}
Intrinsics.reduce_max(tile, Val(axis - 1))
end
@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis}
Intrinsics.reduce_max(tile, Val(axis - 1))
end

Expand Down Expand Up @@ -628,4 +628,3 @@ br = ct.extract(tile, (2, 2), (4, 4)) # Bottom-right (rows 5-8, cols 5-8)
Intrinsics.extract(tile, Val(map(i -> i - 1, index)), Val(shape))
@inline extract(tile::Tile{T}, ::Val{Index}, ::Val{Shape}) where {T, Index, Shape} =
Intrinsics.extract(tile, Val(map(i -> i - 1, Index)), Val(Shape))

56 changes: 56 additions & 0 deletions test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,62 @@
end
end

# Integer reduce_sum (Int32)
# Builds a kernel that loads a (4, 16) tile from an Int32 array, sums it along
# axis 2 and stores the result, then checks the generated code for "reduce" and
# "addi" (the integer add, as opposed to the float add).
# NOTE(review): presumably the @check patterns must match in the order given — confirm.
@test @filecheck begin
    @check_label "entry"
    code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
        pid = ct.bid(1)
        tile = ct.load(a, pid, (4, 16))
        @check "reduce"
        @check "addi"
        sums = ct.reduce_sum(tile, 2)
        ct.store(b, pid, sums)
        return
    end
end

# Integer reduce_max (Int32)
# Same kernel shape as the Int32 sum test, but reduces with the maximum along
# axis 2 and checks the generated code for "reduce" and "maxi" (the integer max).
# NOTE(review): presumably the @check patterns must match in the order given — confirm.
@test @filecheck begin
    @check_label "entry"
    code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
        pid = ct.bid(1)
        tile = ct.load(a, pid, (4, 16))
        @check "reduce"
        @check "maxi"
        maxes = ct.reduce_max(tile, 2)
        ct.store(b, pid, maxes)
        return
    end
end

# Unsigned reduce_sum (UInt32)
# Sums a (4, 16) tile of UInt32 along axis 2; unsigned elements are expected to
# go through the same integer add ("addi") as signed ones.
# NOTE(review): presumably the @check patterns must match in the order given — confirm.
@test @filecheck begin
    @check_label "entry"
    code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
        pid = ct.bid(1)
        tile = ct.load(a, pid, (4, 16))
        @check "reduce"
        @check "addi"
        sums = ct.reduce_sum(tile, 2)
        ct.store(b, pid, sums)
        return
    end
end

# Unsigned reduce_max (UInt32)
# Takes the maximum of a (4, 16) tile of UInt32 along axis 2 and checks the
# generated code for "reduce" and "maxi".
# NOTE(review): only the "maxi" substring is matched, so this test does not
# appear to verify that the unsigned variant of the op was selected — confirm.
@test @filecheck begin
    @check_label "entry"
    code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
        pid = ct.bid(1)
        tile = ct.load(a, pid, (4, 16))
        @check "reduce"
        @check "maxi"
        maxes = ct.reduce_max(tile, 2)
        ct.store(b, pid, maxes)
        return
    end
end

@testset "select" begin
@test @filecheck begin
@check_label "entry"
Expand Down
Loading