diff --git a/README.md b/README.md index 117267d..140e106 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,12 @@ conservative token threading in the compiler (see https://github.com/JuliaGPU/cu | Operation | Description | |-----------|-------------| | `reduce_sum(tile, axis)` | Sum along axis | +| `reduce_mul(tile, axis)` | Product along axis | | `reduce_max(tile, axis)` | Maximum along axis | +| `reduce_min(tile, axis)` | Minimum along axis | +| `reduce_and(tile, axis)` | Bitwise AND along axis (integer) | +| `reduce_or(tile, axis)` | Bitwise OR along axis (integer) | +| `reduce_xor(tile, axis)` | Bitwise XOR along axis (integer) | ### Math | Operation | Description | @@ -275,6 +280,11 @@ ct.permute(tile, (3, 1, 2)) This applies to `bid`, `num_blocks`, `permute`, `reshape`, dimension arguments, etc. +### axis convenience + +| `axis(i)` | Convert 1-based axis to 0-based (helper) | + + ### `Val`-like constants CuTile.jl uses `ct.Constant{T}` to encode compile-time constant values in the type domain, similar to how `Val` works. An explicit `[]` is needed to extract the value at runtime: diff --git a/examples/reducekernel.jl b/examples/reducekernel.jl new file mode 100644 index 0000000..f93c094 --- /dev/null +++ b/examples/reducekernel.jl @@ -0,0 +1,153 @@ +using Test +using CUDA +using cuTile +import cuTile as ct + +# Kernel factory to properly capture element type and operation +function makeReduceKernel(::Type{T}, op::Symbol) where {T} + reduceFunc = if op == :reduce_min + ct.reduce_min + elseif op == :reduce_max + ct.reduce_max + elseif op == :reduce_sum + ct.reduce_sum + elseif op == :reduce_xor + ct.reduce_xor + elseif op == :reduce_or + ct.reduce_or + elseif op == :reduce_and + ct.reduce_and + end + + @inline function kernel(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, tileSz::ct.Constant{Int}) + ct.store(b, ct.bid(1), reduceFunc(ct.load(a, ct.bid(1), (tileSz[],)), Val(1))) + return nothing + end + return kernel +end + +# Test with UInt types +@testset for elType in [UInt16, UInt32, UInt64] + @testset for op in [:reduce_min, :reduce_max, :reduce_sum, :reduce_xor, :reduce_or, :reduce_and] + sz = 32 + N = 2^15 + + # Create kernel using factory + reduceKernel = try + makeReduceKernel(elType, op) + catch e + @test_broken false + rethrow() + end + + # Create data and run kernel + a_gpu = CUDA.rand(elType, N) + b_gpu = CUDA.zeros(elType, cld(N, sz)) + try + CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz)) + catch e + @test_broken false + rethrow() + end + res = Array(b_gpu) + + # CPU computation + a_cpu = Array(a_gpu) + a_reshaped = reshape(a_cpu, sz, :) + + if op == :reduce_min + cpu_result = minimum(a_reshaped, dims=1)[:] + elseif op == :reduce_max + cpu_result = maximum(a_reshaped, dims=1)[:] + elseif op == :reduce_sum + raw_sum = sum(a_reshaped, dims=1)[:] + cpu_result = raw_sum .& typemax(elType) + elseif op == :reduce_xor + cpu_result = mapslices(x -> reduce(⊻, x), a_reshaped, dims=1)[:] + elseif op == :reduce_or + cpu_result = mapslices(x -> reduce(|, x), a_reshaped, dims=1)[:] + elseif op == :reduce_and + cpu_result = mapslices(x -> reduce(&, x), a_reshaped, dims=1)[:] + end + + @test cpu_result == res + end +end + +# Test with signed Int types +@testset for elType in [Int16, Int32, Int64] + @testset for op in [:reduce_min, :reduce_max, :reduce_sum, :reduce_xor, :reduce_or, :reduce_and] + sz = 32 + N = 2^15 + + # Create kernel using factory + reduceKernel = try + makeReduceKernel(elType, op) + catch e + @test_broken false + rethrow() + end + + # Create data and run kernel - use range to get negative values too + a_gpu = CuArray{elType}(rand(-1000:1000, N)) + b_gpu = CUDA.zeros(elType, cld(N, sz)) + try + CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz)) + catch e + @test_broken false + rethrow() + end + res = Array(b_gpu) + + # CPU computation + a_cpu = Array(a_gpu) + a_reshaped = reshape(a_cpu, sz, :) + + if op == :reduce_min + cpu_result = minimum(a_reshaped, dims=1)[:] + elseif op == :reduce_max + cpu_result = maximum(a_reshaped, dims=1)[:] + elseif op == :reduce_sum + cpu_result = sum(a_reshaped, dims=1)[:] + elseif op == :reduce_xor + cpu_result = mapslices(x -> reduce(⊻, x), a_reshaped, dims=1)[:] + elseif op == :reduce_or + cpu_result = mapslices(x -> reduce(|, x), a_reshaped, dims=1)[:] + elseif op == :reduce_and + cpu_result = mapslices(x -> reduce(&, x), a_reshaped, dims=1)[:] + end + + @test cpu_result == res + end +end + +# Test with Float types +@testset for elType in [Float16, Float32, Float64] + @testset for op in [:reduce_min, :reduce_max, :reduce_sum] + sz = 32 + N = 2^15 + + # Create kernel using factory + reduceKernel = makeReduceKernel(elType, op) + + # Create data and run kernel + a_gpu = CUDA.rand(elType, N) + b_gpu = CUDA.zeros(elType, cld(N, sz)) + CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz)) + res = Array(b_gpu) + + # CPU computation + a_cpu = Array(a_gpu) + a_reshaped = reshape(a_cpu, sz, :) + + if op == :reduce_min + cpu_result = minimum(a_reshaped, dims=1)[:] + elseif op == :reduce_max + cpu_result = maximum(a_reshaped, dims=1)[:] + elseif op == :reduce_sum + cpu_result = sum(a_reshaped, dims=1)[:] + end + + @test isapprox(cpu_result, res) + end +end \ No newline at end of file diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl index 9f06415..9d20820 100644 --- a/src/bytecode/encodings.jl +++ b/src/bytecode/encodings.jl @@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder, result_types::Vector{TypeId}, operands::Vector{Value}, dim::Int, - identities::Vector{<:ReduceIdentity}, + identities::Vector{<:IdentityOp}, body_scalar_types::Vector{TypeId}) encode_varint!(cb.buf, Opcode.ReduceOp) diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl index eb87585..a6f34c1 100644 --- a/src/bytecode/writer.jl +++ b/src/bytecode/writer.jl @@ -234,30 +234,42 @@ end =============================================================================# """ - ReduceIdentity + IdentityOp -Abstract type for reduce identity attributes. +Abstract type for binary operation identity attributes (reduce, scan, etc.). """ -abstract type ReduceIdentity end +abstract type IdentityOp end """ - FloatIdentity(value, type_id, dtype) + FloatIdentityOp(value, type_id, dtype) -Float identity value for reduce operations. +Float identity value for binary operations. """ -struct FloatIdentity <: ReduceIdentity +struct FloatIdentityOp <: IdentityOp value::Float64 type_id::TypeId dtype::Type # Float16, Float32, Float64, etc. end """ - encode_tagged_float!(cb, identity::FloatIdentity) + IntegerIdentityOp(value, type_id, dtype, signed) + +Integer identity value for binary operations. +""" +struct IntegerIdentityOp <: IdentityOp + value::UInt128 # Store as UInt128 to handle all unsigned values up to 64 bits + type_id::TypeId + dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc. + signed::Bool # true for signed, false for unsigned +end + +""" + encode_tagged_float!(cb, identity::FloatIdentityOp) Encode a tagged float attribute for reduce identity. Format: tag(Float=0x02) + typeid + ap_int(value_bits) """ -function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity) +function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityOp) # Tag for Float attribute push!(cb.buf, 0x02) # Type ID @@ -267,6 +279,59 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity) encode_signed_varint!(cb.buf, bits) end +""" + encode_tagged_int!(cb, identity::IntegerIdentityOp) + +Encode a tagged integer identity attribute. +Format: tag(Int=0x01) + typeid + ap_int(value) +""" +function encode_tagged_int!(cb::CodeBuilder, identity::IntegerIdentityOp) + # Tag for Int attribute + push!(cb.buf, 0x01) + # Type ID + encode_typeid!(cb.buf, identity.type_id) + # Value: signed uses zigzag varint, unsigned uses plain varint + # Mask value to correct bit width and apply zigzag for signed types + masked_value = mask_to_width(identity.value, identity.dtype, identity.signed) + if identity.signed + encode_signed_varint!(cb.buf, masked_value) + else + encode_varint!(cb.buf, masked_value) + end +end + +""" + mask_to_width(value, dtype, signed) + +Mask a UInt128 value to the correct bit width for the given type and apply zigzag if signed. +For signed types, this masks first, then applies zigzag encoding. +""" +# Signed Int64: mask to 64 bits first, then zigzag encode +mask_to_width(value::UInt128, ::Type{Int64}, signed::Bool) = + let masked = UInt64(value & 0xFFFFFFFFFFFFFFFF) + UInt64((masked << 1) ⊻ (masked >>> 63)) + end +# Signed Int32: mask to 32 bits first, then zigzag encode +mask_to_width(value::UInt128, ::Type{Int32}, signed::Bool) = + let masked = UInt32(value & 0xFFFFFFFF) + UInt32((masked << 1) ⊻ (masked >>> 31)) + end +# Signed Int16: mask to 16 bits first, then zigzag encode +mask_to_width(value::UInt128, ::Type{Int16}, signed::Bool) = + let masked = UInt16(value & 0xFFFF) + UInt16((masked << 1) ⊻ (masked >>> 15)) + end +# Signed Int8: mask to 8 bits first, then zigzag encode +mask_to_width(value::UInt128, ::Type{Int8}, signed::Bool) = + let masked = UInt8(value & 0xFF) + UInt8((masked << 1) ⊻ (masked >>> 7)) + end +# Unsigned types: just mask to bit width, no zigzag +mask_to_width(value::UInt128, ::Type{UInt64}, signed::Bool) = UInt64(value & 0xFFFFFFFFFFFFFFFF) +mask_to_width(value::UInt128, ::Type{UInt32}, signed::Bool) = UInt32(value & 0xFFFFFFFF) +mask_to_width(value::UInt128, ::Type{UInt16}, signed::Bool) = UInt16(value & 0xFFFF) +mask_to_width(value::UInt128, ::Type{UInt8}, signed::Bool) = UInt8(value & 0xFF) + """ float_to_bits(value, dtype) @@ -296,6 +361,7 @@ end Encode a signed integer as a variable-length integer. Uses zigzag encoding for signed values. """ + function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt16, UInt32, UInt64, Int64}) # For float bits, encode as unsigned varint encode_varint!(buf, UInt64(value)) @@ -304,15 +370,24 @@ end """ encode_identity_array!(cb, identities) -Encode an array of reduce identity attributes. +Encode an array of binary operation identity attributes. +Dispatches on identity type to encode correctly. """ -function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity}) +function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityOp}) encode_varint!(cb.buf, length(identities)) for identity in identities - encode_tagged_float!(cb, identity) + encode_identity!(cb, identity) end end +""" + encode_identity!(cb, identity) + +Encode a single identity attribute, dispatching on type. +""" +encode_identity!(cb::CodeBuilder, identity::FloatIdentityOp) = encode_tagged_float!(cb, identity) +encode_identity!(cb::CodeBuilder, identity::IntegerIdentityOp) = encode_tagged_int!(cb, identity) + """ BytecodeWriter @@ -544,7 +619,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder, end #============================================================================= - Optimization Hints + Optimization Hints =============================================================================# """ diff --git a/src/compiler/intrinsics.jl b/src/compiler/intrinsics.jl index 16c55da..e522141 100644 --- a/src/compiler/intrinsics.jl +++ b/src/compiler/intrinsics.jl @@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual +using ..cuTile: IdentityOp, FloatIdentityOp, IntegerIdentityOp end diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 7fb2530..53d35af 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -512,7 +512,7 @@ end Sum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with ADD. """ - @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -523,7 +523,65 @@ end Maximum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with MAX. """ - @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_mul(tile, axis_val) + + Product reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with MUL. + """ + @noinline function reduce_mul(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_min(tile, axis_val) + + Minimum reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with MIN. + """ + @noinline function reduce_min(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_and(tile, axis_val) + + Bitwise AND reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with AND. + Integer types only. + """ + @noinline function reduce_and(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_or(tile, axis_val) + + Bitwise OR reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with OR. + Integer types only. + """ + @noinline function reduce_or(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_xor(tile, axis_val) + + Bitwise XOR reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with XOR. + Integer types only. + """ + @noinline function reduce_xor(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -534,6 +592,22 @@ end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_max), args) emit_reduce!(ctx, args, :max) end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_mul), args) + emit_reduce!(ctx, args, :mul) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_min), args) + emit_reduce!(ctx, args, :min) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_and), args) + emit_reduce!(ctx, args, :and) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_or), args) + emit_reduce!(ctx, args, :or) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_xor), args) + emit_reduce!(ctx, args, :xor) +end + function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol) cb = ctx.cb tt = ctx.tt @@ -558,32 +632,127 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol) # Output tile type output_tile_type = tile_type!(tt, dtype, output_shape) - # Scalar type for reduction body (0D tile) scalar_tile_type = tile_type!(tt, dtype, Int[]) - # Create identity value - use simple dtype (f32), not tile type - identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? -Inf : 0.0) - identity = FloatIdentity(identity_val, dtype, elem_type) + # Create identity value via dispatch on reduction function and element type + identity = operation_identity(Val(reduce_fn), dtype, elem_type) # Emit ReduceOp results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis, [identity], [scalar_tile_type]) do block_args acc, elem = block_args[1], block_args[2] - if reduce_fn == :add - res = encode_AddFOp!(cb, scalar_tile_type, acc, elem) - elseif reduce_fn == :max - res = encode_MaxFOp!(cb, scalar_tile_type, acc, elem) - else - error("Unsupported reduction function: $reduce_fn") - end - + res = encode_reduce_body(cb, scalar_tile_type, acc, elem, Val(reduce_fn), elem_type) encode_YieldOp!(cb, [res]) end CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape) end +#============================================================================= + Reduce Identity Values via Dispatch +=============================================================================# + +""" + is_signed(::Type{T}) -> Bool + +Return true if type T is signed, false for unsigned types. +""" +is_signed(::Type{T}) where T <: Integer = T <: Integer && !(T <: Unsigned) +is_signed(::Type{T}) where T <: AbstractFloat = false + +""" + operation_identity(fn, dtype, elem_type) -> IdentityOp + +Return the identity value for a binary operation (reduce, scan, etc.). +Identity must satisfy: identity ⊕ x = x for the operation. +""" + +""" + to_uint128(value, dtype) + +Convert an integer value to UInt128 for storage in IntegerIdentityOp. +For signed types, this returns the two's complement bit representation. +""" +# Unsigned types: directly convert +to_uint128(value::UInt64) = UInt128(value) +to_uint128(value::UInt32) = UInt128(value) +to_uint128(value::UInt16) = UInt128(value) +to_uint128(value::UInt8) = UInt128(value) +# Signed types: reinterpret as unsigned first, then convert +to_uint128(value::Int64) = UInt128(reinterpret(UInt64, value)) +to_uint128(value::Int32) = UInt128(reinterpret(UInt32, value)) +to_uint128(value::Int16) = UInt128(reinterpret(UInt16, value)) +to_uint128(value::Int8) = UInt128(reinterpret(UInt8, value)) + +# Addition identity: 0 + x = x +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(zero(T), dtype, T) +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(zero(T)), dtype, T, is_signed(T)) + +# Maximum identity: max(typemin(T), x) = x +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(typemin(T), dtype, T) +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(typemin(T)), dtype, T, is_signed(T)) + +# Multiplication identity: 1 * x = x +operation_identity(::Val{:mul}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(one(T), dtype, T) +operation_identity(::Val{:mul}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(one(T)), dtype, T, is_signed(T)) + +# Minimum identity: min(typemax(T), x) = x +operation_identity(::Val{:min}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(typemax(T), dtype, T) +operation_identity(::Val{:min}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(typemax(T)), dtype, T, is_signed(T)) + +# AND identity: all bits set (x & identity == x) +# For signed: -one(T) has all bits set in two's complement +# For unsigned: typemax(T) has all bits set +operation_identity(::Val{:and}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(is_signed(T) ? -one(T) : typemax(T)), dtype, T, is_signed(T)) + +# OR identity: 0 | x = x +operation_identity(::Val{:or}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(zero(T)), dtype, T, is_signed(T)) + +# XOR identity: 0 ⊕ x = x +operation_identity(::Val{:xor}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(to_uint128(zero(T)), dtype, T, is_signed(T)) + +#============================================================================= + Reduce Body Operations - dispatch on Val{fn} and elem_type +=============================================================================# + +encode_reduce_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: AbstractFloat = + encode_AddFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: AbstractFloat = + encode_MaxFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:mul}, ::Type{T}) where T <: AbstractFloat = + encode_MulFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:min}, ::Type{T}) where T <: AbstractFloat = + encode_MinFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: Integer = + encode_AddIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: Integer = + encode_MaxIOp!(cb, type, acc, elem; signedness=is_signed(T) ? SignednessSigned : SignednessUnsigned) +encode_reduce_body(cb, type, acc, elem, ::Val{:mul}, ::Type{T}) where T <: Integer = + encode_MulIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:min}, ::Type{T}) where T <: Integer = + encode_MinIOp!(cb, type, acc, elem; signedness=is_signed(T) ? SignednessSigned : SignednessUnsigned) + + +# less likely commutative/associative ops can be reduced too for whatever reason. +# eg: and, or, xor. +encode_reduce_body(cb, type, acc, elem, ::Val{:and}, ::Type{T}) where T <: Integer = + encode_AndIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:or}, ::Type{T}) where T <: Integer = + encode_OrIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:xor}, ::Type{T}) where T <: Integer = + encode_XOrIOp!(cb, type, acc, elem) # cuda_tile.reshape @eval Intrinsics begin diff --git a/src/language/operations.jl b/src/language/operations.jl index 463a358..4e1de57 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -10,7 +10,7 @@ Load/Store =============================================================================# -public bid, num_blocks, num_tiles, load, store, gather, scatter +public bid, num_blocks, num_tiles, axis, load, store, gather, scatter """ Padding mode for load operations. @@ -59,6 +59,24 @@ Axis is 1-indexed. Equivalent to cld(arr.sizes[axis], shape[axis]). Intrinsics.get_index_space_shape(pv, axis - One()) # convert to 0-indexed end +""" + axis(i::Integer) -> Val{i-1} + +Return a compile-time axis selector for tile operations. +Axis indices are 1-based (axis(1) = first dimension, axis(2) = second, etc.). +Internally converts to 0-based for Tile IR. + +Use this instead of raw `Val` for self-documenting code. + +# Examples +```julia +ct.cumsum(tile, ct.axis(1)) # Scan along first axis +ct.cumsum(tile, ct.axis(2)) # Scan along second axis +ct.scan(tile, ct.axis(1), :add) +``` +""" +@inline axis(i::Integer) = Val(i - One()) + """ load(arr::TileArray, index, shape; padding_mode=PaddingMode.Undetermined, latency=nothing, allow_tma=true) -> Tile @@ -515,7 +533,7 @@ result = ct.astype(acc, ct.TFloat32) # Convert to TF32 for tensor cores Reduction =============================================================================# -public reduce_sum, reduce_max +public reduce_sum, reduce_max, reduce_mul, reduce_min, reduce_and, reduce_or, reduce_xor """ reduce_sum(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} @@ -523,16 +541,18 @@ public reduce_sum, reduce_max Sum reduction along the specified axis (1-indexed). Returns a tile with the specified dimension removed. +Supports any numeric type (Float16, Float32, Float64, and integer types). + # Example ```julia # For a (128, 64) tile, reducing along axis 2: sums = ct.reduce_sum(tile, 2) # Returns (128,) tile ``` """ -@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S} +@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} Intrinsics.reduce_sum(tile, Val(axis - 1)) end -@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} +@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} Intrinsics.reduce_sum(tile, Val(axis - 1)) end @@ -540,19 +560,113 @@ end reduce_max(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} Maximum reduction along the specified axis (1-indexed). +Supports any numeric type (Float16, Float32, Float64, and integer types). # Example ```julia maxes = ct.reduce_max(tile, 2) # Max along axis 2 ``` """ -@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S} +@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} Intrinsics.reduce_max(tile, Val(axis - 1)) end -@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} +@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} Intrinsics.reduce_max(tile, Val(axis - 1)) end +""" + reduce_mul(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Product reduction along the specified axis (1-indexed). +Returns a tile with the specified dimension removed. + +# Example +```julia +# For a (128, 64) tile, reducing along axis 2: +products = ct.reduce_mul(tile, 2) # Returns (128,) tile +``` +""" +@inline function reduce_mul(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} + Intrinsics.reduce_mul(tile, Val(axis - 1)) +end +@inline function reduce_mul(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} + Intrinsics.reduce_mul(tile, Val(axis - 1)) +end + +""" + reduce_min(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Minimum reduction along the specified axis (1-indexed). + +# Example +```julia +mins = ct.reduce_min(tile, 2) # Min along axis 2 +``` +""" +@inline function reduce_min(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} + Intrinsics.reduce_min(tile, Val(axis - 1)) +end +@inline function reduce_min(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} + Intrinsics.reduce_min(tile, Val(axis - 1)) +end + +""" + reduce_and(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Bitwise AND reduction along the specified axis (1-indexed). +Integer types only. + +# Example +```julia +# For an Int32 tile, reducing along axis 2: +result = ct.reduce_and(tile, 2) # Returns (128,) tile of Int32 +``` +""" +@inline function reduce_and(tile::Tile{T, S}, axis::Integer) where {T <: Integer, S} + Intrinsics.reduce_and(tile, Val(axis - 1)) +end +@inline function reduce_and(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + Intrinsics.reduce_and(tile, Val(axis - 1)) +end + +""" + reduce_or(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Bitwise OR reduction along the specified axis (1-indexed). +Integer types only. + +# Example +```julia +# For an Int32 tile, reducing along axis 2: +result = ct.reduce_or(tile, 2) # Returns (128,) tile of Int32 +``` +""" +@inline function reduce_or(tile::Tile{T, S}, axis::Integer) where {T <: Integer, S} + Intrinsics.reduce_or(tile, Val(axis - 1)) +end +@inline function reduce_or(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + Intrinsics.reduce_or(tile, Val(axis - 1)) +end + +""" + reduce_xor(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Bitwise XOR reduction along the specified axis (1-indexed). +Integer types only. + +# Example +```julia +# For an Int32 tile, reducing along axis 2: +result = ct.reduce_xor(tile, 2) # Returns (128,) tile of Int32 +``` +""" +@inline function reduce_xor(tile::Tile{T, S}, axis::Integer) where {T <: Integer, S} + Intrinsics.reduce_xor(tile, Val(axis - 1)) +end +@inline function reduce_xor(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + Intrinsics.reduce_xor(tile, Val(axis - 1)) +end + #============================================================================= Matrix multiplication =============================================================================# @@ -628,4 +742,3 @@ br = ct.extract(tile, (2, 2), (4, 4)) # Bottom-right (rows 5-8, cols 5-8) Intrinsics.extract(tile, Val(map(i -> i - 1, index)), Val(shape)) @inline extract(tile::Tile{T}, ::Val{Index}, ::Val{Shape}) where {T, Index, Shape} = Intrinsics.extract(tile, Val(map(i -> i - 1, Index)), Val(Shape)) - diff --git a/test/reduce_ops.jl b/test/reduce_ops.jl new file mode 100644 index 0000000..4250904 --- /dev/null +++ b/test/reduce_ops.jl @@ -0,0 +1,154 @@ + +using cuTile +import cuTile as ct +using CUDA +using Test + +# Kernel factory to properly capture element type and operation +function makeReduceKernel(::Type{T}, op::Symbol) where {T} + reduceFunc = if op == :reduce_min + ct.reduce_min + elseif op == :reduce_max + ct.reduce_max + elseif op == :reduce_sum + ct.reduce_sum + elseif op == :reduce_xor + ct.reduce_xor + elseif op == :reduce_or + ct.reduce_or + elseif op == :reduce_and + ct.reduce_and + end + + @inline function kernel(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, tileSz::ct.Constant{Int}) + ct.store(b, ct.bid(1), reduceFunc(ct.load(a, ct.bid(1), (tileSz[],)), Val(1))) + return nothing + end + return kernel +end + +# Test with UInt types +@testset for elType in [UInt16, UInt32, UInt64] + @testset for op in [:reduce_min, :reduce_max, :reduce_sum, :reduce_xor, :reduce_or, :reduce_and] + sz = 32 + N = 2^15 + + # Create kernel using factory + reduceKernel = try + makeReduceKernel(elType, op) + catch e + @test_broken false + rethrow() + end + + # Create data and run kernel + a_gpu = CUDA.rand(elType, N) + b_gpu = CUDA.zeros(elType, cld(N, sz)) + try + CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz)) + catch e + @test_broken false + rethrow() + end + res = Array(b_gpu) + + # CPU computation + a_cpu = Array(a_gpu) + a_reshaped = reshape(a_cpu, sz, :) + + if op == :reduce_min + cpu_result = minimum(a_reshaped, dims=1)[:] + elseif op == :reduce_max + cpu_result = maximum(a_reshaped, dims=1)[:] + elseif op == :reduce_sum + raw_sum = sum(a_reshaped, dims=1)[:] + cpu_result = raw_sum .& typemax(elType) + elseif op == :reduce_xor + cpu_result = mapslices(x -> reduce(⊻, x), a_reshaped, dims=1)[:] + elseif op == :reduce_or + cpu_result = mapslices(x -> reduce(|, x), a_reshaped, dims=1)[:] + elseif op == :reduce_and + cpu_result = mapslices(x -> reduce(&, x), a_reshaped, dims=1)[:] + end + + @test cpu_result == res + end +end + +# Test with signed Int types +@testset for elType in [Int16, Int32, Int64] + @testset for op in [:reduce_min, :reduce_max, :reduce_sum, :reduce_xor, :reduce_or, :reduce_and] + sz = 32 + N = 2^15 + + # Create kernel using factory + reduceKernel = try + makeReduceKernel(elType, op) + catch e + @test_broken false + rethrow() + end + + # Create data and run kernel - use range to get negative values too + a_gpu = CuArray{elType}(rand(-1000:1000, N)) + b_gpu = CUDA.zeros(elType, cld(N, sz)) + try + CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz)) + catch e + @test_broken false + rethrow() + end + res = Array(b_gpu) + + # CPU computation + a_cpu = Array(a_gpu) + a_reshaped = reshape(a_cpu, sz, :) + + if op == :reduce_min + cpu_result = minimum(a_reshaped, dims=1)[:] + elseif op == :reduce_max + cpu_result = maximum(a_reshaped, dims=1)[:] + elseif op == :reduce_sum + cpu_result = sum(a_reshaped, dims=1)[:] + elseif op == :reduce_xor + cpu_result = mapslices(x -> reduce(⊻, x), a_reshaped, dims=1)[:] + elseif op == :reduce_or + cpu_result = mapslices(x -> reduce(|, x), a_reshaped, dims=1)[:] + elseif op == :reduce_and + cpu_result = mapslices(x -> reduce(&, x), a_reshaped, dims=1)[:] + end + + @test cpu_result == res + end +end + +# Test with Float types +@testset for elType in [Float16, Float32, Float64] + @testset for op in [:reduce_min, :reduce_max, :reduce_sum] + sz = 32 + N = 2^15 + + # Create kernel using factory + reduceKernel = makeReduceKernel(elType, op) + + # Create data and run kernel + a_gpu = CUDA.rand(elType, N) + b_gpu = CUDA.zeros(elType, cld(N, sz)) + CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz)) + res = Array(b_gpu) + + # CPU computation + a_cpu = Array(a_gpu) + a_reshaped = reshape(a_cpu, sz, :) + + if op == :reduce_min + cpu_result = minimum(a_reshaped, dims=1)[:] + elseif op == :reduce_max + cpu_result = maximum(a_reshaped, dims=1)[:] + elseif op == :reduce_sum + cpu_result = sum(a_reshaped, dims=1)[:] + end + + @test isapprox(cpu_result, res) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 6b9aed9..a52163a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,7 +45,7 @@ if filter_tests!(testsuite, args) cuda_functional = CUDA.functional() filter!(testsuite) do (test, _) - if in(test, ["execution"]) || startswith(test, "examples/") + if in(test, ["execution", "reduce_ops"]) || startswith(test, "examples/") return cuda_functional else return true