From e6ba9731616dc03d723326447ca0cf414e1711a7 Mon Sep 17 00:00:00 2001 From: arhik Date: Sat, 17 Jan 2026 10:51:48 +0000 Subject: [PATCH 1/3] feat: Add integer reduction support for reduce_sum and reduce_max MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit enables reduce_sum and reduce_max operations on all numeric types, extending beyond the previous float-only support. ## Infrastructure Changes ### Bytecode Layer - Added IntegerIdentityOp struct with signed/unsigned handling - Added encode_tagged_int! for integer identity encoding - Added mask_to_width function with zigzag encoding for signed types - Added encode_identity! dispatch for FloatIdentityOp and IntegerIdentityOp - Refactored ReduceIdentity → IdentityOp for extensibility ### Compiler Layer - Refactored emit_reduce! to use dispatch-based approach - Added operation_identity dispatch for add/max operations - Added encode_reduce_body dispatch for float and integer operations - Removed T <: AbstractFloat constraints from intrinsics ### Language Layer - Removed type constraints from reduce_sum and reduce_max in operations.jl ## Test Coverage ### Codegen Tests - Added FileCheck tests for Int32/UInt32 reduce_sum and reduce_max - Verifies correct IR generation (addi, maxi instructions) ### Execution Tests - Factory pattern for easy extension (makeReduceKernel, cpu_reduce) - Tests 10 types: Int8, Int16, Int32, Int64, UInt16, UInt32, UInt64, Float16, Float32, Float64 - Tests 2 operations: reduce_sum, reduce_max - CPU verification for all test cases - Type-appropriate input ranges to prevent overflow ## Files Changed - src/bytecode/encodings.jl: Fix IdentityOp type annotation - src/bytecode/writer.jl: Integer identity infrastructure - src/compiler/intrinsics.jl: Import identity types - src/compiler/intrinsics/core.jl: Dispatch-based reduce implementation - src/cuTile.jl: Export identity types - src/language/operations.jl: Remove type constraints - test/codegen.jl: Add integer reduction codegen tests - test/execution.jl: Add extendable execution tests ## Extensibility The infrastructure is designed for easy extension: - Add new reduce operations by defining operation_identity and encode_reduce_body methods - Add new types by adding to TEST_TYPES array and appropriate data generation --- src/bytecode/encodings.jl | 2 +- src/bytecode/writer.jl | 90 ++++++++++++++++++++++++---- src/compiler/intrinsics.jl | 1 + src/compiler/intrinsics/core.jl | 72 +++++++++++++++++++---- src/cuTile.jl | 3 + src/language/operations.jl | 8 +-- test/codegen.jl | 56 ++++++++++++++++++ test/execution.jl | 101 ++++++++++++++++++++++++++++++++ 8 files changed, 304 insertions(+), 29 deletions(-) diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl index 9f06415..9d20820 100644 --- a/src/bytecode/encodings.jl +++ b/src/bytecode/encodings.jl @@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder, result_types::Vector{TypeId}, operands::Vector{Value}, dim::Int, - identities::Vector{<:ReduceIdentity}, + identities::Vector{<:IdentityOp}, body_scalar_types::Vector{TypeId}) encode_varint!(cb.buf, Opcode.ReduceOp) diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl index eb87585..52b693e 100644 --- a/src/bytecode/writer.jl +++ b/src/bytecode/writer.jl @@ -234,30 +234,42 @@ end =============================================================================# """ - ReduceIdentity + IdentityOp -Abstract type for reduce identity attributes. +Abstract type for binary operation identity attributes (reduce, scan, etc.). """ -abstract type ReduceIdentity end +abstract type IdentityOp end """ - FloatIdentity(value, type_id, dtype) + FloatIdentityOp(value, type_id, dtype) -Float identity value for reduce operations. +Float identity value for binary operations. """ -struct FloatIdentity <: ReduceIdentity +struct FloatIdentityOp <: IdentityOp value::Float64 type_id::TypeId dtype::Type # Float16, Float32, Float64, etc. end """ - encode_tagged_float!(cb, identity::FloatIdentity) + IntegerIdentityOp(value, type_id, dtype, signed) + +Integer identity value for binary operations. +""" +struct IntegerIdentityOp <: IdentityOp + value::UInt128 # Store as UInt128 to handle all unsigned values up to 64 bits + type_id::TypeId + dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc. + signed::Bool # true for signed, false for unsigned +end + +""" + encode_tagged_float!(cb, identity::FloatIdentityOp) Encode a tagged float attribute for reduce identity. Format: tag(Float=0x02) + typeid + ap_int(value_bits) """ -function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity) +function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityOp) # Tag for Float attribute push!(cb.buf, 0x02) # Type ID @@ -267,6 +279,53 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity) encode_signed_varint!(cb.buf, bits) end +""" + encode_tagged_int!(cb, identity::IntegerIdentityOp) + +Encode a tagged integer identity attribute. +Format: tag(Int=0x01) + typeid + ap_int(value) +""" +function encode_tagged_int!(cb::CodeBuilder, identity::IntegerIdentityOp) + # Tag for Int attribute + push!(cb.buf, 0x01) + # Type ID + encode_typeid!(cb.buf, identity.type_id) + # Value: signed uses zigzag varint, unsigned uses plain varint + # Mask value to correct bit width and apply zigzag if signed + masked_value = mask_to_width(identity.value, identity.dtype, identity.signed) + if identity.signed + encode_signed_varint!(cb.buf, masked_value) + else + encode_varint!(cb.buf, masked_value) + end +end + +""" + mask_to_width(value, dtype, signed) + +Mask a UInt128 value to the correct bit width for the given type and apply zigzag if signed. +""" +mask_to_width(value::UInt128, ::Type{Int64}, signed::Bool) = + let masked = UInt64(value & 0xFFFFFFFFFFFFFFFF) + UInt64((masked << 1) ⊻ (masked >>> 63)) + end +mask_to_width(value::UInt128, ::Type{Int32}, signed::Bool) = + let masked = UInt32(value & 0xFFFFFFFF) + UInt32((masked << 1) ⊻ (masked >>> 31)) + end +mask_to_width(value::UInt128, ::Type{Int16}, signed::Bool) = + let masked = UInt16(value & 0xFFFF) + UInt16((masked << 1) ⊻ (masked >>> 15)) + end +mask_to_width(value::UInt128, ::Type{Int8}, signed::Bool) = + let masked = UInt8(value & 0xFF) + UInt8((masked << 1) ⊻ (masked >>> 7)) + end +mask_to_width(value::UInt128, ::Type{UInt64}, signed::Bool) = UInt64(value & 0xFFFFFFFFFFFFFFFF) +mask_to_width(value::UInt128, ::Type{UInt32}, signed::Bool) = UInt32(value & 0xFFFFFFFF) +mask_to_width(value::UInt128, ::Type{UInt16}, signed::Bool) = UInt16(value & 0xFFFF) +mask_to_width(value::UInt128, ::Type{UInt8}, signed::Bool) = UInt8(value & 0xFF) + """ float_to_bits(value, dtype) @@ -304,15 +363,24 @@ end """ encode_identity_array!(cb, identities) -Encode an array of reduce identity attributes. +Encode an array of binary operation identity attributes. +Dispatches on identity type to encode correctly. """ -function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity}) +function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityOp}) encode_varint!(cb.buf, length(identities)) for identity in identities - encode_tagged_float!(cb, identity) + encode_identity!(cb, identity) end end +""" + encode_identity!(cb, identity) + +Encode a single identity attribute, dispatching on type. +""" +encode_identity!(cb::CodeBuilder, identity::FloatIdentityOp) = encode_tagged_float!(cb, identity) +encode_identity!(cb::CodeBuilder, identity::IntegerIdentityOp) = encode_tagged_int!(cb, identity) + """ BytecodeWriter diff --git a/src/compiler/intrinsics.jl b/src/compiler/intrinsics.jl index 16c55da..e522141 100644 --- a/src/compiler/intrinsics.jl +++ b/src/compiler/intrinsics.jl @@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual +using ..cuTile: IdentityOp, FloatIdentityOp, IntegerIdentityOp end diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 7fb2530..61a4a0b 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -512,7 +512,7 @@ end Sum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with ADD. """ - @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -523,7 +523,7 @@ end Maximum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with MAX. """ - @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -562,28 +562,74 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol) # Scalar type for reduction body (0D tile) scalar_tile_type = tile_type!(tt, dtype, Int[]) - # Create identity value - use simple dtype (f32), not tile type - identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? -Inf : 0.0) - identity = FloatIdentity(identity_val, dtype, elem_type) + # Create identity value via dispatch on reduction function and element type + identity = operation_identity(Val(reduce_fn), dtype, elem_type) # Emit ReduceOp results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis, [identity], [scalar_tile_type]) do block_args acc, elem = block_args[1], block_args[2] - if reduce_fn == :add - res = encode_AddFOp!(cb, scalar_tile_type, acc, elem) - elseif reduce_fn == :max - res = encode_MaxFOp!(cb, scalar_tile_type, acc, elem) - else - error("Unsupported reduction function: $reduce_fn") - end - + res = encode_reduce_body(cb, scalar_tile_type, acc, elem, Val(reduce_fn), elem_type) encode_YieldOp!(cb, [res]) end CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape) end +#=============================================================================# +# Reduce Identity Values via Dispatch +#=============================================================================# + +""" + operation_identity(fn, dtype, elem_type) -> IdentityOp + to_uint128(value) + +Convert an integer value to UInt128 for storage in IntegerIdentityOp. +For signed types, this returns the two's complement bit representation. +""" +# Unsigned types: directly convert +to_uint128(value::UInt64) = UInt128(value) +to_uint128(value::UInt32) = UInt128(value) +to_uint128(value::UInt16) = UInt128(value) +to_uint128(value::UInt8) = UInt128(value) +# Signed types: reinterpret as unsigned first, then convert +to_uint128(value::Int64) = UInt128(reinterpret(UInt64, value)) +to_uint128(value::Int32) = UInt128(reinterpret(UInt32, value)) +to_uint128(value::Int16) = UInt128(reinterpret(UInt16, value)) +to_uint128(value::Int8) = UInt128(reinterpret(UInt8, value)) + +""" + operation_identity(fn, dtype, elem_type) -> IdentityOp + +Return the identity value for a binary operation (reduce, scan, etc.). +Identity must satisfy: identity ⊕ x = x for the operation. +""" + +# Addition identity: 0 + x = x +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(zero(T), dtype, T) +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(zero(T)), dtype, T, T <: Signed) + +# Maximum identity: max(typemin(T), x) = x +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(typemin(T), dtype, T) +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(typemin(T)), dtype, T, T <: Signed) + +#=============================================================================# +# Reduce Body Operations - dispatch on Val{fn} and elem_type +#=============================================================================# + +encode_reduce_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: AbstractFloat = + encode_AddFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: AbstractFloat = + encode_MaxFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: Integer = + encode_AddIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: Integer = + encode_MaxIOp!(cb, type, acc, elem; signedness=T <: Signed ? SignednessSigned : SignednessUnsigned) + # cuda_tile.reshape @eval Intrinsics begin diff --git a/src/cuTile.jl b/src/cuTile.jl index 375aaa2..a3b019d 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -40,4 +40,7 @@ include("language/atomics.jl") public launch launch() = error("Please import CUDA.jl before using `cuTile.launch`.") +# Export identity types for reduction operations +public IdentityOp, FloatIdentityOp, IntegerIdentityOp + end # module cuTile diff --git a/src/language/operations.jl b/src/language/operations.jl index 463a358..5b00350 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -529,10 +529,10 @@ Returns a tile with the specified dimension removed. sums = ct.reduce_sum(tile, 2) # Returns (128,) tile ``` """ -@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S} +@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T, S} Intrinsics.reduce_sum(tile, Val(axis - 1)) end -@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} +@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} Intrinsics.reduce_sum(tile, Val(axis - 1)) end @@ -546,10 +546,10 @@ Maximum reduction along the specified axis (1-indexed). maxes = ct.reduce_max(tile, 2) # Max along axis 2 ``` """ -@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S} +@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T, S} Intrinsics.reduce_max(tile, Val(axis - 1)) end -@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} +@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} Intrinsics.reduce_max(tile, Val(axis - 1)) end diff --git a/test/codegen.jl b/test/codegen.jl index ae4b42e..c189310 100644 --- a/test/codegen.jl +++ b/test/codegen.jl @@ -387,6 +387,62 @@ end end + # Integer reduce_sum (Int32) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 16)) + @check "reduce" + @check "addi" + sums = ct.reduce_sum(tile, 2) + ct.store(b, pid, sums) + return + end + end + + # Integer reduce_max (Int32) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 16)) + @check "reduce" + @check "maxi" + maxes = ct.reduce_max(tile, 2) + ct.store(b, pid, maxes) + return + end + end + + # Unsigned reduce_sum (UInt32) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 16)) + @check "reduce" + @check "addi" + sums = ct.reduce_sum(tile, 2) + ct.store(b, pid, sums) + return + end + end + + # Unsigned reduce_max (UInt32) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 16)) + @check "reduce" + @check "maxi" + maxes = ct.reduce_max(tile, 2) + ct.store(b, pid, maxes) + return + end + end + @testset "select" begin @test @filecheck begin @check_label "entry" diff --git a/test/execution.jl b/test/execution.jl index 8297a9d..611d35d 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -1842,6 +1842,107 @@ end end end +# Kernel factory for reduce operations - extendable pattern +function makeReduceKernel(::Type{T}, op::Symbol) where {T} + reduceFunc = if op == :reduce_sum + ct.reduce_sum + elseif op == :reduce_max + ct.reduce_max + # ADD NEW OPERATIONS HERE + # elseif op == :reduce_min + # ct.reduce_min + # elseif op == :reduce_mul + # ct.reduce_mul + end + + @inline function kernel(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, tileSz::ct.Constant{Int}) + ct.store(b, ct.bid(1), reduceFunc(ct.load(a, ct.bid(1), (tileSz[],)), Val(1))) + return nothing + end + return kernel +end + +# CPU reference implementation for reduce operations - extendable pattern +function cpu_reduce(a_reshaped::AbstractArray{T}, op::Symbol) where {T} + if op == :reduce_sum + result = sum(a_reshaped, dims=1)[:] + # For unsigned types, apply mask to handle overflow + if T <: Unsigned + result .= result .& typemax(T) + end + return result + elseif op == :reduce_max + return maximum(a_reshaped, dims=1)[:] + # ADD NEW OPERATIONS HERE + # elseif op == :reduce_min + # return minimum(a_reshaped, dims=1)[:] + # elseif op == :reduce_mul + # return prod(a_reshaped, dims=1)[:] + end +end + +@testset "1D reduce operations (extendable)" begin + # Test parameters - easily extendable + TILE_SIZE = 32 + N = 1024 + + # Supported types - add new types here + TEST_TYPES = [Int8, Int16, Int32, Int64, UInt16, UInt32, UInt64, Float16, Float32, Float64] + + # Supported operations - add new operations here + TEST_OPS = [:reduce_sum, :reduce_max] + + @testset "Type: $elType, Operation: $op" for elType in TEST_TYPES, op in TEST_OPS + # Create kernel using factory + reduceKernel = try + makeReduceKernel(elType, op) + catch e + @test_broken false + rethrow() + end + + # Generate input data with type-appropriate ranges + # Int8: -3 to 3 (32 * 3 = 96, safely within Int8 range -128 to 127) + # Int16: -800 to 800 (32 * 800 = 25,600, safely within Int16 range -32,768 to 32,767) + # UInt16: 1 to 2000 (32 * 2000 = 64,000, safely within UInt16 range 0 to 65,535) + # Larger types: -1000 to 1000 (arbitrary but covers positive/negative) + # Floats: 0 to 1 (CUDA.rand default) + if elType == Int8 + a_gpu = CuArray{Int8}(rand(-3:3, N)) + elseif elType == Int16 + a_gpu = CuArray{Int16}(rand(-800:800, N)) + elseif elType == UInt16 + a_gpu = CuArray{UInt16}(rand(1:2000, N)) + elseif elType <: Integer && elType <: Signed + a_gpu = CuArray{elType}(rand(-1000:1000, N)) + else + a_gpu = CUDA.rand(elType, N) + end + b_gpu = CUDA.zeros(elType, cld(N, TILE_SIZE)) + + # Launch kernel + try + CUDA.@sync ct.launch(reduceKernel, cld(N, TILE_SIZE), a_gpu, b_gpu, ct.Constant(TILE_SIZE)) + catch e + @test_broken false + rethrow() + end + + # Verify results + a_cpu = Array(a_gpu) + b_cpu = Array(b_gpu) + a_reshaped = reshape(a_cpu, TILE_SIZE, :) + cpu_result = cpu_reduce(a_reshaped, op) + + # Use appropriate comparison based on type + if elType <: AbstractFloat + @test b_cpu ≈ cpu_result rtol=1e-3 + else + @test b_cpu == cpu_result + end + end +end + @testset "transpose with hints" begin function transpose_with_hints(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2}) From 5088d5411d5402c1a636f16f0e815a00389f7970 Mon Sep 17 00:00:00 2001 From: arhik Date: Sat, 17 Jan 2026 19:33:32 +0000 Subject: [PATCH 2/3] Add Number type constraint to reduce_sum and reduce_max functions - Constrain T <: Number in reduce_sum/reduce_max signatures for type safety - Ensures only numeric types can be used with reduction operations --- src/language/operations.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/language/operations.jl b/src/language/operations.jl index 5b00350..2dcb24f 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -529,10 +529,10 @@ Returns a tile with the specified dimension removed. sums = ct.reduce_sum(tile, 2) # Returns (128,) tile ``` """ -@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T, S} +@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} Intrinsics.reduce_sum(tile, Val(axis - 1)) end -@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} +@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} Intrinsics.reduce_sum(tile, Val(axis - 1)) end @@ -546,10 +546,10 @@ Maximum reduction along the specified axis (1-indexed). maxes = ct.reduce_max(tile, 2) # Max along axis 2 ``` """ -@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T, S} +@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} Intrinsics.reduce_max(tile, Val(axis - 1)) end -@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} +@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} Intrinsics.reduce_max(tile, Val(axis - 1)) end @@ -628,4 +628,3 @@ br = ct.extract(tile, (2, 2), (4, 4)) # Bottom-right (rows 5-8, cols 5-8) Intrinsics.extract(tile, Val(map(i -> i - 1, index)), Val(shape)) @inline extract(tile::Tile{T}, ::Val{Index}, ::Val{Shape}) where {T, Index, Shape} = Intrinsics.extract(tile, Val(map(i -> i - 1, Index)), Val(Shape)) - From 3731b8742253967006c0e72c5d889152e2276263 Mon Sep 17 00:00:00 2001 From: arhik Date: Sun, 18 Jan 2026 04:46:10 +0000 Subject: [PATCH 3/3] Add scan (prefix sum) operations support This commit adds support for scan (parallel prefix sum) operations to cuTile, based on the IntegerReduce branch and commit 0c9ab90. Key changes: - Added encode_ScanOp! to bytecode encodings for generating ScanOp bytecode - Added encode_scan_identity_array! to reuse existing identity encoding - Added scan intrinsic implementation using operation_identity from IntegerReduce - Added scan() and cumsum() public APIs with proper 1-indexed to 0-indexed axis conversion - Added comprehensive codegen tests for scan operations - Added scankernel.jl example demonstrating CSDL scan algorithm Features: - Supports cumulative sum (cumsum) for float and integer types - Supports both forward and reverse scan directions - Reuses FloatIdentityOp and IntegerIdentityOp from IntegerReduce - Uses operation_identity function for cleaner identity value creation - 1-indexed axis parameter (consistent with reduce operations) - Preserves tile shape (scan is an element-wise operation along one dimension) Tests: - All 142 codegen tests pass (including 6 new scan tests) - Scankernel.jl example runs successfully with CSDL algorithm - Clarify that it demonstrates device-side scan operation - Add note that test might occasionally fail (race condition in phase 2 loop) Minor comment improvements in scankernel.jl example - Clarify that it demonstrates device-side scan operation - Add note that test might occasionally fail (race condition in phase 2 loop) --- examples/scankernel.jl | 62 ++++++++++++++++++++++++++ src/bytecode/encodings.jl | 72 ++++++++++++++++++++++++++++++ src/bytecode/writer.jl | 10 ++--- src/compiler/intrinsics/core.jl | 79 ++++++++++++++++++++++++++++++++- src/language/operations.jl | 13 ++++++ test/codegen.jl | 74 +++++++++++++++++++++++++++++- 6 files changed, 303 insertions(+), 7 deletions(-) create mode 100644 examples/scankernel.jl diff --git a/examples/scankernel.jl b/examples/scankernel.jl new file mode 100644 index 0000000..b0764df --- /dev/null +++ b/examples/scankernel.jl @@ -0,0 +1,62 @@ +using Test +using CUDA +using cuTile +import cuTile as ct + +function cumsum_1d_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, + tile_size::ct.Constant{Int}) + bid = ct.bid(1) + tile = ct.load(a, bid, (tile_size[],)) + result = ct.cumsum(tile, Val(1)) # Val(1) means 1st (0th) dimension for 1D tile + ct.store(b, bid, result) + return nothing +end + +sz = 32 +N = 2^15 +a = CUDA.rand(Float32, N) +b = CUDA.zeros(Float32, N) +CUDA.@sync ct.launch(cumsum_1d_kernel, cld(length(a), sz), a, b, ct.Constant(sz)) + +# This is supposed to be a single pass kernel but its simpler version than memory ordering version. +# The idea is to show how device scan operation can be done. + +# CSDL phase 1: Intra-tile scan + store tile sums +function cumsum_csdl_phase1(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, + tile_sums::ct.TileArray{Float32,1}, + tile_size::ct.Constant{Int}) + bid = ct.bid(1) + tile = ct.load(a, bid, (tile_size[],)) + result = ct.cumsum(tile, Val(1)) + ct.store(b, bid, result) + tile_sum = ct.extract(result, (tile_size[],), (1,)) # Extract last element (1 element shape) + ct.store(tile_sums, bid, tile_sum) + return +end + +# CSDL phase 2: Decoupled lookback to accumulate previous tile sums +function cumsum_csdl_phase2(b::ct.TileArray{Float32,1}, + tile_sums::ct.TileArray{Float32,1}, + tile_size::ct.Constant{Int}) + bid = ct.bid(1) + prev_sum = ct.zeros((tile_size[],), Float32) + k = Int32(bid) + while k > 1 + tile_sum_k = ct.load(tile_sums, (k,), (1,)) + prev_sum = prev_sum .+ tile_sum_k + k -= Int32(1) + end + tile = ct.load(b, bid, (tile_size[],)) + result = tile .+ prev_sum + ct.store(b, bid, result) + return nothing +end + +n = length(a) +num_tiles = cld(n, sz) +tile_sums = CUDA.zeros(Float32, num_tiles) +CUDA.@sync ct.launch(cumsum_csdl_phase1, num_tiles, a, b, tile_sums, ct.Constant(sz)) +CUDA.@sync ct.launch(cumsum_csdl_phase2, num_tiles, b, tile_sums, ct.Constant(sz)) + +b_cpu = cumsum(a |> collect, dims=1) +@test isapprox(b |> collect, b_cpu) # This might fail occasionally diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl index 9d20820..a38f8f4 100644 --- a/src/bytecode/encodings.jl +++ b/src/bytecode/encodings.jl @@ -1331,6 +1331,78 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder, end end + +#============================================================================= + Scan operations +=============================================================================# + +""" + encode_ScanOp!(body::Function, cb::CodeBuilder, + result_types::Vector{TypeId}, + operands::Vector{Value}, + dim::Int, + reverse::Bool, + identities::Vector{<:IdentityOp}, + body_scalar_types::Vector{TypeId}) + +Encode a ScanOp (parallel prefix sum) operation. + +# Arguments +- body: Function that takes block args and yields result(s) +- cb: CodeBuilder for the bytecode +- result_types: Output tile types +- operands: Input tiles to scan +- dim: Dimension to scan along (0-indexed) +- reverse: Whether to scan in reverse order +- identities: Identity values for each operand (reuses IdentityOp from IntegerReduce) +- body_scalar_types: 0D tile types for body arguments +""" +function encode_ScanOp!(body::Function, cb::CodeBuilder, + result_types::Vector{TypeId}, + operands::Vector{Value}, + dim::Int, + reverse::Bool, + identities::Vector{<:IdentityOp}, + body_scalar_types::Vector{TypeId}) + encode_varint!(cb.buf, Opcode.ScanOp) + + # Variadic result types + encode_typeid_seq!(cb.buf, result_types) + + # Attributes: dim (int), reverse (bool), identities (array) + encode_opattr_int!(cb, dim) + encode_opattr_bool!(cb, reverse) + encode_identity_array!(cb, identities) + + # Variadic operands + encode_varint!(cb.buf, length(operands)) + encode_operands!(cb.buf, operands) + + # Number of regions + push!(cb.debug_attrs, cb.cur_debug_attr) + cb.num_ops += 1 + encode_varint!(cb.buf, 1) # 1 region: body + + # Body region - block args are pairs of (acc, elem) for each operand + # The body operates on 0D tiles (scalars) + body_arg_types = TypeId[] + for scalar_type in body_scalar_types + push!(body_arg_types, scalar_type) # accumulator + push!(body_arg_types, scalar_type) # element + end + with_region(body, cb, body_arg_types) + + # Create result values + num_results = length(result_types) + if num_results == 0 + return Value[] + else + vals = [Value(cb.next_value_id + i) for i in 0:num_results-1] + cb.next_value_id += num_results + return vals + end +end + #============================================================================= Comparison and selection operations =============================================================================# diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl index 52b693e..e985b5e 100644 --- a/src/bytecode/writer.jl +++ b/src/bytecode/writer.jl @@ -305,19 +305,19 @@ end Mask a UInt128 value to the correct bit width for the given type and apply zigzag if signed. """ -mask_to_width(value::UInt128, ::Type{Int64}, signed::Bool) = +mask_to_width(value::UInt128, ::Type{Int64}, signed::Bool) = let masked = UInt64(value & 0xFFFFFFFFFFFFFFFF) UInt64((masked << 1) ⊻ (masked >>> 63)) end -mask_to_width(value::UInt128, ::Type{Int32}, signed::Bool) = +mask_to_width(value::UInt128, ::Type{Int32}, signed::Bool) = let masked = UInt32(value & 0xFFFFFFFF) UInt32((masked << 1) ⊻ (masked >>> 31)) end -mask_to_width(value::UInt128, ::Type{Int16}, signed::Bool) = +mask_to_width(value::UInt128, ::Type{Int16}, signed::Bool) = let masked = UInt16(value & 0xFFFF) UInt16((masked << 1) ⊻ (masked >>> 15)) end -mask_to_width(value::UInt128, ::Type{Int8}, signed::Bool) = +mask_to_width(value::UInt128, ::Type{Int8}, signed::Bool) = let masked = UInt8(value & 0xFF) UInt8((masked << 1) ⊻ (masked >>> 7)) end @@ -612,7 +612,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder, end #============================================================================= - Optimization Hints + Optimization Hints =============================================================================# """ diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 61a4a0b..089070d 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -700,7 +700,84 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reshape), args) CGVal(current_val, result_type_id, Tile{elem_type, Tuple(target_shape)}, target_shape) end -# TODO: cuda_tile.scan +# cuda_tile.scan +@eval Intrinsics begin + """ + scan(tile, axis_val, fn_type; reverse=false) + + Parallel prefix scan along specified dimension. + fn_type=:add for cumulative sum (only supported operation). + reverse=false for forward scan, true for reverse scan. + Compiled to cuda_tile.scan. + """ + @noinline function scan(tile::Tile{T, S}, ::Val{axis}, fn::Symbol, reverse::Bool=false) where {T, S, axis} + # Scan preserves shape - result has same dimensions as input + Tile{T, S}() + end +end + +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args) + cb = ctx.cb + tt = ctx.tt + + # Get input tile + input_tv = emit_value!(ctx, args[1]) + input_tv === nothing && error("Cannot resolve input tile for scan") + + # Get scan axis + axis = @something get_constant(ctx, args[2]) error("Scan axis must be a compile-time constant") + + # Get scan function type (only :add is supported) + fn_type = @something get_constant(ctx, args[3]) error("Scan function type must be a compile-time constant") + fn_type == :add || error("Only :add (cumulative sum) is currently supported for scan operations") + + # Get reverse flag (optional, defaults to false) + reverse = false + if length(args) >= 4 + reverse_val = get_constant(ctx, args[4]) + reverse = reverse_val === true + end + + # Get element type and shapes + input_type = unwrap_type(input_tv.jltype) + elem_type = input_type <: Tile ? input_type.parameters[1] : input_type + input_shape = input_tv.shape + + # For scan, output shape is same as input shape + output_shape = copy(input_shape) + + dtype = julia_to_tile_dtype!(tt, elem_type) + + # Output tile type (same shape as input) + output_tile_type = tile_type!(tt, dtype, output_shape) + + # Scalar type for scan body (0D tile) + scalar_tile_type = tile_type!(tt, dtype, Int[]) + + # Create identity value using operation_identity + # Reuses FloatIdentityOp and IntegerIdentityOp from IntegerReduce + identity = operation_identity(Val(fn_type), dtype, elem_type) + + # Emit ScanOp + results = encode_ScanOp!(cb, [output_tile_type], [input_tv.v], axis, reverse, [identity], [scalar_tile_type]) do block_args + acc, elem = block_args[1], block_args[2] + res = encode_scan_body(cb, scalar_tile_type, acc, elem, Val(fn_type), elem_type) + encode_YieldOp!(cb, [res]) + end + + + CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape) +end + +# Dispatch helpers for scan body operations - dispatch on Val{fn} and elem_type +encode_scan_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: AbstractFloat = + encode_AddFOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: Integer = + encode_AddIOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: AbstractFloat = + encode_MaxFOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: Integer = + encode_MaxIOp!(cb, type, acc, elem; signedness=is_signed(T) ? SignednessSigned : SignednessUnsigned) # cuda_tile.select @eval Intrinsics begin diff --git a/src/language/operations.jl b/src/language/operations.jl index 2dcb24f..6c76da0 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -553,6 +553,19 @@ end Intrinsics.reduce_max(tile, Val(axis - 1)) end +# Scan (Prefix Sum) Operations + +@inline function scan(tile::Tile{T, S}, ::Val{axis}, + fn::Symbol=:add, + reverse::Bool=false) where {T<:Number, S, axis} + Intrinsics.scan(tile, Val(axis - 1), fn, reverse) +end + +@inline function cumsum(tile::Tile{T, S}, ::Val{axis}, + reverse::Bool=false) where {T<:Number, S, axis} + scan(tile, Val(axis), :add, reverse) +end + #============================================================================= Matrix multiplication =============================================================================# diff --git a/test/codegen.jl b/test/codegen.jl index c189310..676da71 100644 --- a/test/codegen.jl +++ b/test/codegen.jl @@ -19,7 +19,79 @@ # TODO: mmai - integer matrix multiply-accumulate # TODO: offset - tile offset computation # TODO: pack - pack tiles - # TODO: scan - parallel scan/prefix sum + @testset "scan" begin + # 1D cumulative sum (forward scan) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + result = ct.scan(tile, Val(1), :add, false) + ct.store(b, pid, result) + return + end + end + + # 2D cumulative sum along axis 1 (columns) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 8)) + result = ct.scan(tile, Val(2), :add, false) + ct.store(b, pid, result) + return + end + end + + # 2D cumulative sum along axis 2 (rows) - forward scan + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 8)) + result = ct.scan(tile, Val(1), :add, false) + ct.store(b, pid, result) + return + end + end + + # 2D cumulative sum along axis 2 (rows) - reverse scan + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 8)) + result = ct.scan(tile, Val(1), :add, true) + ct.store(b, pid, result) + return + end + end + + # Integer cumulative sum + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Int32,1,spec1d}, ct.TileArray{Int32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + result = ct.scan(tile, Val(1), :add, false) + ct.store(b, pid, result) + return + end + end + + # cumsum convenience function (forward scan) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,2,spec2d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (4, 8)) + result = ct.cumsum(tile, Val(2), false) + ct.store(b, pid, result) + return + end + end + end # TODO: unpack - unpack tiles @testset "reshape" begin