diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl
index 9f06415..1e1672a 100644
--- a/src/bytecode/encodings.jl
+++ b/src/bytecode/encodings.jl
@@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
                           result_types::Vector{TypeId},
                           operands::Vector{Value},
                           dim::Int,
-                          identities::Vector{<:ReduceIdentity},
+                          identities::Vector{<:IdentityVal},
                           body_scalar_types::Vector{TypeId})
 
     encode_varint!(cb.buf, Opcode.ReduceOp)
diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl
index eb87585..19beb8e 100644
--- a/src/bytecode/writer.jl
+++ b/src/bytecode/writer.jl
@@ -234,30 +234,41 @@ end
 =============================================================================#
 
 """
-    ReduceIdentity
+    IdentityVal
 
-Abstract type for reduce identity attributes.
+Abstract type for binary operation identity attributes (reduce, scan, etc.).
 """
-abstract type ReduceIdentity end
+abstract type IdentityVal end
 
 """
-    FloatIdentity(value, type_id, dtype)
+    FloatIdentityVal(value, type_id, dtype)
 
-Float identity value for reduce operations.
+Float identity value for binary operations.
 """
-struct FloatIdentity <: ReduceIdentity
+struct FloatIdentityVal <: IdentityVal
     value::Float64
     type_id::TypeId
    dtype::Type  # Float16, Float32, Float64, etc.
 end
 
 """
-    encode_tagged_float!(cb, identity::FloatIdentity)
+    IntegerIdentityVal(value, type_id, dtype)
+
+Integer identity value for binary operations.
+"""
+struct IntegerIdentityVal <: IdentityVal
+    value::UInt128  # UInt128 holds the two's-complement bit pattern of any integer up to 64 bits
+    type_id::TypeId
+    dtype::Type  # Int8, Int16, Int32, Int64, UInt8, etc. (signedness is inferred from dtype)
+end
+
+"""
+    encode_tagged_float!(cb, identity::FloatIdentityVal)
 
 Encode a tagged float attribute for reduce identity.
 Format: tag(Float=0x02) + typeid + ap_int(value_bits)
 """
-function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
+function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityVal)
     # Tag for Float attribute
     push!(cb.buf, 0x02)
     # Type ID
@@ -267,6 +278,42 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
     encode_signed_varint!(cb.buf, bits)
 end
 
+"""
+    encode_tagged_int!(cb, identity::IntegerIdentityVal)
+
+Encode a tagged integer identity attribute.
+Format: tag(Int=0x01) + typeid + ap_int(value)
+"""
+function encode_tagged_int!(cb::CodeBuilder, identity::IntegerIdentityVal)
+    # Tag for Int attribute
+    push!(cb.buf, 0x01)
+    # Type ID
+    encode_typeid!(cb.buf, identity.type_id)
+    # Mask the value to the correct bit width; signed types are zig-zag encoded
+    masked_value = mask_to_width(identity.value, identity.dtype)
+    encode_varint!(cb.buf, masked_value)  # masked_value is already zig-zag encoded
+end
+
+"""
+    mask_to_width(value, dtype)
+
+Mask a UInt128 value to the correct bit width for the given type.
+Applies zig-zag encoding for signed types.
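+
+For example (illustrative; inputs are the two's-complement bit patterns):
+
+    mask_to_width(UInt128(0xff), Int8)   # 0x01 (bit pattern of Int8(-1); zig-zag of -1)
+    mask_to_width(UInt128(2), Int8)      # 0x04 (zig-zag of +2)
+    mask_to_width(UInt128(0xff), UInt8)  # 0xff (unsigned: mask only, no zig-zag)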
""" -function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity}) +function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityVal}) encode_varint!(cb.buf, length(identities)) for identity in identities - encode_tagged_float!(cb, identity) + encode_identity!(cb, identity) end end +""" + encode_identity!(cb, identity) + +Encode a single identity attribute, dispatching on type. +""" +encode_identity!(cb::CodeBuilder, identity::FloatIdentityVal) = encode_tagged_float!(cb, identity) +encode_identity!(cb::CodeBuilder, identity::IntegerIdentityVal) = encode_tagged_int!(cb, identity) + """ BytecodeWriter @@ -544,7 +600,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder, end #============================================================================= - Optimization Hints + Optimization Hints =============================================================================# """ diff --git a/src/compiler/intrinsics.jl b/src/compiler/intrinsics.jl index 16c55da..45dc34f 100644 --- a/src/compiler/intrinsics.jl +++ b/src/compiler/intrinsics.jl @@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual +using ..cuTile: IdentityVal, FloatIdentityVal, IntegerIdentityVal end diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 7fb2530..91e1dcf 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -512,7 +512,7 @@ end Sum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with ADD. """ - @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -523,7 +523,7 @@ end Maximum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with MAX. """ - @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -562,28 +562,81 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol) # Scalar type for reduction body (0D tile) scalar_tile_type = tile_type!(tt, dtype, Int[]) - # Create identity value - use simple dtype (f32), not tile type - identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? 
 """
-function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity})
+function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityVal})
     encode_varint!(cb.buf, length(identities))
     for identity in identities
-        encode_tagged_float!(cb, identity)
+        encode_identity!(cb, identity)
     end
 end
 
+"""
+    encode_identity!(cb, identity)
+
+Encode a single identity attribute, dispatching on its type.
+"""
+encode_identity!(cb::CodeBuilder, identity::FloatIdentityVal) = encode_tagged_float!(cb, identity)
+encode_identity!(cb::CodeBuilder, identity::IntegerIdentityVal) = encode_tagged_int!(cb, identity)
+
 """
     BytecodeWriter
 
@@ -544,7 +600,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
 end
 
 #=============================================================================
-  Optimization Hints
+  Optimization Hints
 =============================================================================#
 
 """
diff --git a/src/compiler/intrinsics.jl b/src/compiler/intrinsics.jl
index 16c55da..45dc34f 100644
--- a/src/compiler/intrinsics.jl
+++ b/src/compiler/intrinsics.jl
@@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete
 using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView
 using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned
 using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual
+using ..cuTile: IdentityVal, FloatIdentityVal, IntegerIdentityVal
 
 end
diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl
index 7fb2530..91e1dcf 100644
--- a/src/compiler/intrinsics/core.jl
+++ b/src/compiler/intrinsics/core.jl
@@ -512,7 +512,7 @@ end
 
 Sum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with ADD.
 """
-    @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
+    @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis}
         reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1)
         Tile{T, reduced_shape}()
     end
@@ -523,7 +523,7 @@ end
 
 Maximum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with MAX.
 """
-    @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
+    @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis}
         reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1)
         Tile{T, reduced_shape}()
     end
@@ -562,28 +562,81 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol)
     # Scalar type for reduction body (0D tile)
     scalar_tile_type = tile_type!(tt, dtype, Int[])
 
-    # Create identity value - use simple dtype (f32), not tile type
-    identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? -Inf : 0.0)
-    identity = FloatIdentity(identity_val, dtype, elem_type)
+    # Create the identity value via dispatch on the reduction function and element type
+    identity = operation_identity(Val(reduce_fn), dtype, elem_type)
 
     # Emit ReduceOp
     results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis,
                                [identity], [scalar_tile_type]) do block_args
         acc, elem = block_args[1], block_args[2]
-        if reduce_fn == :add
-            res = encode_AddFOp!(cb, scalar_tile_type, acc, elem)
-        elseif reduce_fn == :max
-            res = encode_MaxFOp!(cb, scalar_tile_type, acc, elem)
-        else
-            error("Unsupported reduction function: $reduce_fn")
-        end
-
+        res = encode_reduce_body(cb, scalar_tile_type, acc, elem, reduce_fn, elem_type)
         encode_YieldOp!(cb, [res])
     end
 
     CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape)
 end
 
+#=============================================================================#
+# Reduce Identity Values via Dispatch
+#=============================================================================#
+
+"""
+    to_uint128(value) -> UInt128
+
+Convert an integer value to UInt128 for storage in IntegerIdentityVal.
+For signed types, this returns the two's-complement bit pattern.
+"""
+# Unsigned types: convert directly
+to_uint128(value::UInt64) = UInt128(value)
+to_uint128(value::UInt32) = UInt128(value)
+to_uint128(value::UInt16) = UInt128(value)
+to_uint128(value::UInt8) = UInt128(value)
+# Signed types: reinterpret as unsigned first, then convert
+to_uint128(value::Int64) = UInt128(reinterpret(UInt64, value))
+to_uint128(value::Int32) = UInt128(reinterpret(UInt32, value))
+to_uint128(value::Int16) = UInt128(reinterpret(UInt16, value))
+to_uint128(value::Int8) = UInt128(reinterpret(UInt8, value))
+
+"""
+    operation_identity(fn, dtype, elem_type) -> IdentityVal
+
+Return the identity value for a binary operation (reduce, scan, etc.).
+The identity must satisfy `identity ⊕ x == x` for the operation `⊕`.
+"""
+
+# Addition identity: 0 + x == x (for floats, -0.0, since -0.0 + x == x even when x == -0.0)
+operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: AbstractFloat =
+    FloatIdentityVal(-zero(T), dtype, T)
+operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: Integer =
+    IntegerIdentityVal(to_uint128(zero(T)), dtype, T)
+
+# Maximum identity: max(typemin(T), x) == x
+operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: AbstractFloat =
+    FloatIdentityVal(typemin(T), dtype, T)
+operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer =
+    IntegerIdentityVal(to_uint128(typemin(T)), dtype, T)
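+
+# Further operations follow the same pattern. A minimal sketch, assuming a
+# future :min reduction (not part of this change):
+#
+#   operation_identity(::Val{:min}, dtype, ::Type{T}) where T <: AbstractFloat =
+#       FloatIdentityVal(typemax(T), dtype, T)   # min(typemax(T), x) == x
+#   operation_identity(::Val{:min}, dtype, ::Type{T}) where T <: Integer =
+#       IntegerIdentityVal(to_uint128(typemax(T)), dtype, T)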
+
+#=============================================================================#
+# Reduce Body Operations
+#=============================================================================#
+function encode_reduce_body(cb, type, acc, elem, op::Symbol, ::Type{T}) where T
+    if T <: AbstractFloat
+        if op == :add
+            encode_AddFOp!(cb, type, acc, elem)
+        elseif op == :max
+            encode_MaxFOp!(cb, type, acc, elem)
+        else
+            error("Unsupported float reduction function: $op")
+        end
+    else  # Integer
+        signedness = T <: Signed ? SignednessSigned : SignednessUnsigned
+        if op == :add
+            encode_AddIOp!(cb, type, acc, elem)
+        elseif op == :max
+            encode_MaxIOp!(cb, type, acc, elem; signedness)
+        else
+            error("Unsupported integer reduction function: $op")
+        end
+    end
+end
+
 # cuda_tile.reshape
 @eval Intrinsics begin
diff --git a/src/cuTile.jl b/src/cuTile.jl
index 375aaa2..7cb13b4 100644
--- a/src/cuTile.jl
+++ b/src/cuTile.jl
@@ -40,4 +40,7 @@ include("language/atomics.jl")
 public launch
 launch() = error("Please import CUDA.jl before using `cuTile.launch`.")
 
+# Mark the identity types for reduction operations as public API
+public IdentityVal, FloatIdentityVal, IntegerIdentityVal
+
 end # module cuTile
diff --git a/src/language/operations.jl b/src/language/operations.jl
index 463a358..2dcb24f 100644
--- a/src/language/operations.jl
+++ b/src/language/operations.jl
@@ -529,10 +529,10 @@ Returns a tile with the specified dimension removed.
 sums = ct.reduce_sum(tile, 2)  # Returns (128,) tile
 ```
 """
-@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S}
+@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: Number, S}
     Intrinsics.reduce_sum(tile, Val(axis - 1))
 end
-@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
+@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis}
     Intrinsics.reduce_sum(tile, Val(axis - 1))
 end
 
@@ -546,10 +546,10 @@ Maximum reduction along the specified axis (1-indexed).
 maxes = ct.reduce_max(tile, 2)  # Max along axis 2
 ```
 """
-@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S}
+@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: Number, S}
     Intrinsics.reduce_max(tile, Val(axis - 1))
 end
-@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
+@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis}
     Intrinsics.reduce_max(tile, Val(axis - 1))
 end
 
@@ -628,4 +628,3 @@ br = ct.extract(tile, (2, 2), (4, 4))  # Bottom-right (rows 5-8, cols 5-8)
     Intrinsics.extract(tile, Val(map(i -> i - 1, index)), Val(shape))
 @inline extract(tile::Tile{T}, ::Val{Index}, ::Val{Shape}) where {T, Index, Shape} =
     Intrinsics.extract(tile, Val(map(i -> i - 1, Index)), Val(Shape))
-
diff --git a/test/codegen.jl b/test/codegen.jl
index ae4b42e..c189310 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -387,6 +387,62 @@ end
         end
     end
 
+    # Integer reduce_sum (Int32)
+    @test @filecheck begin
+        @check_label "entry"
+        code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
+            pid = ct.bid(1)
+            tile = ct.load(a, pid, (4, 16))
+            @check "reduce"
+            @check "addi"
+            sums = ct.reduce_sum(tile, 2)
+            ct.store(b, pid, sums)
+            return
+        end
+    end
+
+    # Integer reduce_max (Int32)
+    @test @filecheck begin
+        @check_label "entry"
+        code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
+            pid = ct.bid(1)
+            tile = ct.load(a, pid, (4, 16))
+            @check "reduce"
+            @check "maxi"
+            maxes = ct.reduce_max(tile, 2)
+            ct.store(b, pid, maxes)
+            return
+        end
+    end
+
+    # Unsigned reduce_sum (UInt32)
+    @test @filecheck begin
+        @check_label "entry"
+        code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
+            pid = ct.bid(1)
+            tile = ct.load(a, pid, (4, 16))
+            @check "reduce"
+            @check "addi"
+            sums = ct.reduce_sum(tile, 2)
+            ct.store(b, pid, sums)
+            return
+        end
+    end
+
+    # Unsigned reduce_max (UInt32)
+    @test @filecheck begin
+        @check_label "entry"
+        code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
+            pid = ct.bid(1)
+            tile = ct.load(a, pid, (4, 16))
+            @check "reduce"
+            @check "maxi"
+            maxes = ct.reduce_max(tile, 2)
+            ct.store(b, pid, maxes)
+            return
+        end
+    end
+
     @testset "select" begin
         @test @filecheck begin
            @check_label "entry"
diff --git a/test/execution.jl b/test/execution.jl
index a2072f3..073f106 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -1842,6 +1842,107 @@ end
     end
 end
 
+# Kernel factory for reduce operations - extendable pattern
+function makeReduceKernel(::Type{T}, op::Symbol) where {T}
+    reduceFunc = if op == :reduce_sum
+        ct.reduce_sum
+    elseif op == :reduce_max
+        ct.reduce_max
+    # ADD NEW OPERATIONS HERE
+    # elseif op == :reduce_min
+    #     ct.reduce_min
+    # elseif op == :reduce_mul
+    #     ct.reduce_mul
+    else
+        error("Unsupported reduce operation: $op")
+    end
+
+    @inline function kernel(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, tileSz::ct.Constant{Int})
+        ct.store(b, ct.bid(1), reduceFunc(ct.load(a, ct.bid(1), (tileSz[],)), Val(1)))
+        return nothing
+    end
+    return kernel
+end
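+
+# Example (illustrative): build and launch an Int32 sum kernel by hand
+#
+#   k = makeReduceKernel(Int32, :reduce_sum)
+#   ct.launch(k, cld(length(a_gpu), 32), a_gpu, b_gpu, ct.Constant(32))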
+
+# CPU reference implementation for reduce operations - extendable pattern
+function cpu_reduce(a_reshaped::AbstractArray{T}, op::Symbol) where {T}
+    if op == :reduce_sum
+        result = sum(a_reshaped, dims=1)[:]
+        # For unsigned types, mask to the element width so the reference wraps
+        # around (modulo 2^nbits) exactly like the in-type GPU accumulation
+        if T <: Unsigned
+            result .= result .& typemax(T)
+        end
+        return result
+    elseif op == :reduce_max
+        return maximum(a_reshaped, dims=1)[:]
+    # ADD NEW OPERATIONS HERE
+    # elseif op == :reduce_min
+    #     return minimum(a_reshaped, dims=1)[:]
+    # elseif op == :reduce_mul
+    #     return prod(a_reshaped, dims=1)[:]
+    end
+end
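+
+# Wraparound sanity check (illustrative): an unsigned tile sum can exceed
+# typemax; the GPU result and the masked CPU reference agree modulo 2^nbits:
+#
+#   sum(UInt64[60_000, 10_000]) & typemax(UInt16)  # 0x1170 == UInt16 wraparound of 70_000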
+
+@testset "1D reduce operations (extendable)" begin
+    # Test parameters - easily extendable
+    TILE_SIZE = 32
+    N = 1024
+
+    # Supported types - add new types here
+    TEST_TYPES = [Int8, Int16, Int32, Int64, UInt16, UInt32, UInt64, Float16, Float32, Float64]
+
+    # Supported operations - add new operations here
+    TEST_OPS = [:reduce_sum, :reduce_max]
+
+    @testset "Type: $elType, Operation: $op" for elType in TEST_TYPES, op in TEST_OPS
+        # Create kernel using factory
+        reduceKernel = try
+            makeReduceKernel(elType, op)
+        catch e
+            @test_broken false
+            rethrow()
+        end
+
+        # Generate input data with type-appropriate ranges:
+        #   Int8:   -3:3     (32 * 3 = 96, safely within -128..127)
+        #   Int16:  -800:800 (32 * 800 = 25,600, safely within -32,768..32,767)
+        #   UInt16: 1:2000   (32 * 2000 = 64,000, safely within 0..65,535)
+        #   other signed types: -1000:1000 (arbitrary, covers positive and negative)
+        #   UInt32/UInt64 and floats: CUDA.rand (uniform over the type for unsigned; [0, 1) for floats)
+        if elType == Int8
+            a_gpu = CuArray{Int8}(rand(-3:3, N))
+        elseif elType == Int16
+            a_gpu = CuArray{Int16}(rand(-800:800, N))
+        elseif elType == UInt16
+            a_gpu = CuArray{UInt16}(rand(1:2000, N))
+        elseif elType <: Signed
+            a_gpu = CuArray{elType}(rand(-1000:1000, N))
+        else
+            a_gpu = CUDA.rand(elType, N)
+        end
+        b_gpu = CUDA.zeros(elType, cld(N, TILE_SIZE))
+
+        # Launch kernel
+        try
+            CUDA.@sync ct.launch(reduceKernel, cld(N, TILE_SIZE), a_gpu, b_gpu, ct.Constant(TILE_SIZE))
+        catch e
+            @test_broken false
+            rethrow()
+        end
+
+        # Verify results
+        a_cpu = Array(a_gpu)
+        b_cpu = Array(b_gpu)
+        a_reshaped = reshape(a_cpu, TILE_SIZE, :)
+        cpu_result = cpu_reduce(a_reshaped, op)
+
+        # Use an appropriate comparison for the element type
+        if elType <: AbstractFloat
+            @test b_cpu ≈ cpu_result rtol=1e-3
+        else
+            @test b_cpu == cpu_result
+        end
+    end
+end
+
 @testset "transpose with hints" begin
     function transpose_with_hints(x::ct.TileArray{Float32,2},
                                   y::ct.TileArray{Float32,2})