diff --git a/README.md b/README.md index 1b381ba..45336a7 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,12 @@ conservative token threading in the compiler (see https://github.com/JuliaGPU/cu | Operation | Description | |-----------|-------------| | `reduce_sum(tile, axis)` | Sum along axis | +| `reduce_mul(tile, axis)` | Product along axis | | `reduce_max(tile, axis)` | Maximum along axis | +| `reduce_min(tile, axis)` | Minimum along axis | +| `reduce_and(tile, axis)` | Bitwise AND along axis (integer) | +| `reduce_or(tile, axis)` | Bitwise OR along axis (integer) | +| `reduce_xor(tile, axis)` | Bitwise XOR along axis (integer) | ### Math | Operation | Description | @@ -274,6 +279,12 @@ ct.permute(tile, (3, 1, 2)) This applies to `bid`, `num_blocks`, `permute`, `reshape`, dimension arguments, etc. +### Axis convenience + +| Operation | Description | +|-----------|-------------| +| `axis(i)` | Compile-time axis selector: converts a 1-based axis index to the 0-based form used internally | + ### `Val`-like constants CuTile.jl uses `ct.Constant{T}` to encode compile-time constant values in the type domain, similar to how `Val` works. 
An explicit `[]` is needed to extract the value at runtime: diff --git a/examples/reducekernel.jl b/examples/reducekernel.jl new file mode 100644 index 0000000..7bb485b --- /dev/null +++ b/examples/reducekernel.jl @@ -0,0 +1,20 @@ +using Test +using CUDA +using cuTile +import cuTile as ct + +elType = UInt16 +function reduceKernel(a::ct.TileArray{elType,1}, b::ct.TileArray{elType,1}, tileSz::ct.Constant{Int}) + bid = ct.bid(1) + tile = ct.load(a, bid, (tileSz[],)) + result = ct.reduce_min(tile, Val(1)) + ct.store(b, bid, result) + return nothing +end + +sz = 32 +N = 2^15 +a = CUDA.rand(elType, N) +b = CUDA.zeros(elType, cld(N, sz)) +CUDA.@sync ct.launch(reduceKernel, cld(length(a), sz), a, b, ct.Constant(sz)) +res = Array(b) diff --git a/examples/scanKernel.jl b/examples/scanKernel.jl new file mode 100644 index 0000000..b5dd8ea --- /dev/null +++ b/examples/scanKernel.jl @@ -0,0 +1,62 @@ +using Test +using CUDA +using cuTile +import cuTile as ct + +function cumsum_1d_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, + tile_size::ct.Constant{Int}) + bid = ct.bid(1) + tile = ct.load(a, bid, (tile_size[],)) + result = ct.cumsum(tile, ct.axis(1)) + ct.store(b, bid, result) + return nothing +end + +sz = 32 +N = 2^15 +a = CUDA.rand(Float32, N) +b = CUDA.zeros(Float32, N) +CUDA.@sync ct.launch(cumsum_1d_kernel, cld(length(a), sz), a, b, ct.Constant(sz)) + +# The scan below is meant to be a single-pass kernel, but this is a simpler two-launch version than the memory-ordering-based one. +# The idea is just to show the scan operation. 
+ +# CSDL phase 1: Intra-tile scan + store tile sums +function cumsum_csdl_phase1(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, + tile_sums::ct.TileArray{Float32,1}, + tile_size::ct.Constant{Int}) + bid = ct.bid(1) + tile = ct.load(a, bid, (tile_size[],)) + result = ct.cumsum(tile, ct.axis(1)) + ct.store(b, bid, result) + tile_sum = ct.extract(result, (tile_size[],), (1,)) + ct.store(tile_sums, bid, tile_sum) + return +end + +# CSDL phase 2: Decoupled lookback to accumulate previous tile sums +function cumsum_csdl_phase2(b::ct.TileArray{Float32,1}, + tile_sums::ct.TileArray{Float32,1}, + tile_size::ct.Constant{Int}) + bid = ct.bid(1) + prev_sum = ct.zeros((tile_size[],), Float32) + k = Int32(bid) + while k > 1 + tile_sum_k = ct.load(tile_sums, (k,), (1,)) + prev_sum = prev_sum .+ tile_sum_k + k -= Int32(1) + end + tile = ct.load(b, bid, (tile_size[],)) + result = tile .+ prev_sum + ct.store(b, bid, result) + return nothing +end + +n = length(a) +num_tiles = cld(n, sz) +tile_sums = CUDA.zeros(Float32, num_tiles) +CUDA.@sync ct.launch(cumsum_csdl_phase1, num_tiles, a, b, tile_sums, ct.Constant(sz)) +CUDA.@sync ct.launch(cumsum_csdl_phase2, num_tiles, b, tile_sums, ct.Constant(sz)) + +b_cpu = cumsum(a |> collect, dims=1) +@test isapprox(b |> collect, b_cpu) diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl index 9305bee..a301257 100644 --- a/src/bytecode/encodings.jl +++ b/src/bytecode/encodings.jl @@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder, result_types::Vector{TypeId}, operands::Vector{Value}, dim::Int, - identities::Vector{<:ReduceIdentity}, + identities::Vector{<:IdentityOp}, body_scalar_types::Vector{TypeId}) encode_varint!(cb.buf, Opcode.ReduceOp) @@ -1331,6 +1331,78 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder, end end + +#============================================================================= + Scan operations 
+=============================================================================# + +""" + encode_ScanOp!(body::Function, cb::CodeBuilder, + result_types::Vector{TypeId}, + operands::Vector{Value}, + dim::Int, + reverse::Bool, + identities::Vector{<:IdentityOp}, + body_scalar_types::Vector{TypeId}) + +Encode a ScanOp (parallel prefix sum) operation. + +# Arguments +- body: Function that takes block args and yields result(s) +- cb: CodeBuilder for the bytecode +- result_types: Output tile types +- operands: Input tiles to scan +- dim: Dimension to scan along (0-indexed) +- reverse: Whether to scan in reverse order +- identities: Identity values for each operand +- body_scalar_types: 0D tile types for body arguments +""" +function encode_ScanOp!(body::Function, cb::CodeBuilder, + result_types::Vector{TypeId}, + operands::Vector{Value}, + dim::Int, + reverse::Bool, + identities::Vector{<:IdentityOp}, + body_scalar_types::Vector{TypeId}) + encode_varint!(cb.buf, Opcode.ScanOp) + + # Variadic result types + encode_typeid_seq!(cb.buf, result_types) + + # Attributes: dim (int), reverse (bool), identities (array) + encode_opattr_int!(cb, dim) + encode_opattr_bool!(cb, reverse) + encode_identity_array!(cb, identities) + + # Variadic operands + encode_varint!(cb.buf, length(operands)) + encode_operands!(cb.buf, operands) + + # Number of regions + push!(cb.debug_attrs, cb.cur_debug_attr) + cb.num_ops += 1 + encode_varint!(cb.buf, 1) # 1 region: body + + # Body region - block args are pairs of (acc, elem) for each operand + # The body operates on 0D tiles (scalars) + body_arg_types = TypeId[] + for scalar_type in body_scalar_types + push!(body_arg_types, scalar_type) # accumulator + push!(body_arg_types, scalar_type) # element + end + with_region(body, cb, body_arg_types) + + # Create result values + num_results = length(result_types) + if num_results == 0 + return Value[] + else + vals = [Value(cb.next_value_id + i) for i in 0:num_results-1] + cb.next_value_id += num_results + 
return vals + end +end + #============================================================================= Comparison and selection operations =============================================================================# diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl index c7a2ac3..cf37cbc 100644 --- a/src/bytecode/writer.jl +++ b/src/bytecode/writer.jl @@ -234,30 +234,42 @@ end =============================================================================# """ - ReduceIdentity + IdentityOp -Abstract type for reduce identity attributes. +Abstract type for binary operation identity attributes (reduce, scan, etc.). """ -abstract type ReduceIdentity end +abstract type IdentityOp end """ - FloatIdentity(value, type_id, dtype) + FloatIdentityOp(value, type_id, dtype) -Float identity value for reduce operations. +Float identity value for binary operations. """ -struct FloatIdentity <: ReduceIdentity +struct FloatIdentityOp <: IdentityOp value::Float64 type_id::TypeId dtype::Type # Float16, Float32, Float64, etc. end """ - encode_tagged_float!(cb, identity::FloatIdentity) + IntegerIdentityOp(value, type_id, dtype, signed) + +Integer identity value for binary operations. +""" +struct IntegerIdentityOp <: IdentityOp + value::Int64 # Store as signed Int64, will be reinterpreted as unsigned + type_id::TypeId + dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc. + signed::Bool # true for signed, false for unsigned +end + +""" + encode_tagged_float!(cb, identity::FloatIdentityOp) Encode a tagged float attribute for reduce identity. 
Format: tag(Float=0x02) + typeid + ap_int(value_bits) """ -function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity) +function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityOp) # Tag for Float attribute push!(cb.buf, 0x02) # Type ID @@ -267,6 +279,25 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity) encode_signed_varint!(cb.buf, bits) end +""" + encode_tagged_integer!(cb, identity::IntegerIdentityOp) + +Encode a tagged integer identity attribute. +Format: tag(Int=0x01) + typeid + ap_int(value) +""" +function encode_tagged_integer!(cb::CodeBuilder, identity::IntegerIdentityOp) + # Tag for Int attribute + push!(cb.buf, 0x01) + # Type ID + encode_typeid!(cb.buf, identity.type_id) + # Value: signed uses zigzag varint, unsigned uses plain varint + if identity.signed + encode_signed_varint!(cb.buf, identity.value) + else + encode_varint!(cb.buf, UInt64(identity.value)) + end +end + """ float_to_bits(value, dtype) @@ -296,6 +327,7 @@ end Encode a signed integer as a variable-length integer. Uses zigzag encoding for signed values. """ + function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt16, UInt32, UInt64, Int64}) # For float bits, encode as unsigned varint encode_varint!(buf, UInt64(value)) @@ -304,15 +336,24 @@ end """ encode_identity_array!(cb, identities) -Encode an array of reduce identity attributes. +Encode an array of binary operation identity attributes. +Dispatches on identity type to encode correctly. """ -function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity}) +function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityOp}) encode_varint!(cb.buf, length(identities)) for identity in identities - encode_tagged_float!(cb, identity) + encode_identity!(cb, identity) end end +""" + encode_identity!(cb, identity) + +Encode a single identity attribute, dispatching on type. 
+""" +encode_identity!(cb::CodeBuilder, identity::FloatIdentityOp) = encode_tagged_float!(cb, identity) +encode_identity!(cb::CodeBuilder, identity::IntegerIdentityOp) = encode_tagged_integer!(cb, identity) + """ BytecodeWriter diff --git a/src/compiler/intrinsics.jl b/src/compiler/intrinsics.jl index 06a8f40..66791bd 100644 --- a/src/compiler/intrinsics.jl +++ b/src/compiler/intrinsics.jl @@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual +using ..cuTile: IdentityOp, FloatIdentityOp, IntegerIdentityOp end diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index c4ce4f1..71e6f65 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -525,7 +525,7 @@ end Sum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with ADD. """ - @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -536,7 +536,65 @@ end Maximum reduction along 0-indexed axis. Compiled to cuda_tile.reduce with MAX. """ - @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} + @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_mul(tile, axis_val) + + Product reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with MUL. 
+ """ + @noinline function reduce_mul(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_min(tile, axis_val) + + Minimum reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with MIN. + """ + @noinline function reduce_min(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_and(tile, axis_val) + + Bitwise AND reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with AND. + Integer types only. + """ + @noinline function reduce_and(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_or(tile, axis_val) + + Bitwise OR reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with OR. + Integer types only. + """ + @noinline function reduce_or(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) + Tile{T, reduced_shape}() + end + + """ + reduce_xor(tile, axis_val) + + Bitwise XOR reduction along 0-indexed axis. + Compiled to cuda_tile.reduce with XOR. + Integer types only. + """ + @noinline function reduce_xor(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} reduced_shape = ntuple(i -> S[i < axis + 1 ? 
i : i + 1], length(S) - 1) Tile{T, reduced_shape}() end @@ -547,6 +605,22 @@ end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_max), args) emit_reduce!(ctx, args, :max) end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_mul), args) + emit_reduce!(ctx, args, :mul) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_min), args) + emit_reduce!(ctx, args, :min) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_and), args) + emit_reduce!(ctx, args, :and) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_or), args) + emit_reduce!(ctx, args, :or) +end +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reduce_xor), args) + emit_reduce!(ctx, args, :xor) +end + function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol) cb = ctx.cb tt = ctx.tt @@ -571,32 +645,109 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol) # Output tile type output_tile_type = tile_type!(tt, dtype, output_shape) - # Scalar type for reduction body (0D tile) scalar_tile_type = tile_type!(tt, dtype, Int[]) - # Create identity value - use simple dtype (f32), not tile type - identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? 
-Inf : 0.0) - identity = FloatIdentity(identity_val, dtype, elem_type) + # Create identity value via dispatch on reduction function and element type + identity = operation_identity(Val(reduce_fn), dtype, elem_type) # Emit ReduceOp results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis, [identity], [scalar_tile_type]) do block_args acc, elem = block_args[1], block_args[2] - if reduce_fn == :add - res = encode_AddFOp!(cb, scalar_tile_type, acc, elem) - elseif reduce_fn == :max - res = encode_MaxFOp!(cb, scalar_tile_type, acc, elem) - else - error("Unsupported reduction function: $reduce_fn") - end - + res = encode_reduce_body(cb, scalar_tile_type, acc, elem, Val(reduce_fn), elem_type) encode_YieldOp!(cb, [res]) end CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape) end +#============================================================================= + Reduce Identity Values via Dispatch +=============================================================================# + +""" + is_signed(::Type{T}) -> Bool + +Return true if type T is signed, false for unsigned types. +""" +is_signed(::Type{T}) where T <: Integer = T <: Integer && !(T <: Unsigned) +is_signed(::Type{T}) where T <: AbstractFloat = false + +""" + operation_identity(fn, dtype, elem_type) -> IdentityOp + +Return the identity value for a binary operation (reduce, scan, etc.). +Identity must satisfy: identity ⊕ x = x for the operation. 
+""" +# Addition identity: 0 + x = x +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(zero(T), dtype, T) +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(zero(T), dtype, T, is_signed(T)) + +# Maximum identity: max(typemin(T), x) = x +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(typemin(T), dtype, T) +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(typemin(T), dtype, T, is_signed(T)) + +# Multiplication identity: 1 * x = x +operation_identity(::Val{:mul}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(one(T), dtype, T) +operation_identity(::Val{:mul}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(one(T), dtype, T, is_signed(T)) + +# Minimum identity: min(typemax(T), x) = x +operation_identity(::Val{:min}, dtype, ::Type{T}) where T <: AbstractFloat = + FloatIdentityOp(typemax(T), dtype, T) +operation_identity(::Val{:min}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(typemax(T), dtype, T, is_signed(T)) + +# AND identity: all bits set (x & identity == x) +# For signed: -one(T) has all bits set in two's complement +# For unsigned: typemax(T) has all bits set +operation_identity(::Val{:and}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(is_signed(T) ? 
-one(T) : typemax(T), dtype, T, is_signed(T)) + +# OR identity: 0 | x = x +operation_identity(::Val{:or}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(zero(T), dtype, T, is_signed(T)) + +# XOR identity: 0 ⊕ x = x +operation_identity(::Val{:xor}, dtype, ::Type{T}) where T <: Integer = + IntegerIdentityOp(zero(T), dtype, T, is_signed(T)) + +#============================================================================= + Reduce Body Operations - dispatch on Val{fn} and elem_type +=============================================================================# + +encode_reduce_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: AbstractFloat = + encode_AddFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: AbstractFloat = + encode_MaxFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:mul}, ::Type{T}) where T <: AbstractFloat = + encode_MulFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:min}, ::Type{T}) where T <: AbstractFloat = + encode_MinFOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: Integer = + encode_AddIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: Integer = + encode_MaxIOp!(cb, type, acc, elem; signedness=is_signed(T) ? SignednessSigned : SignednessUnsigned) +encode_reduce_body(cb, type, acc, elem, ::Val{:mul}, ::Type{T}) where T <: Integer = + encode_MulIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:min}, ::Type{T}) where T <: Integer = + encode_MinIOp!(cb, type, acc, elem; signedness=is_signed(T) ? SignednessSigned : SignednessUnsigned) + + +# less likely commutative/associative ops can be reduced too for whatever reason. +# eg: and, or, xor. 
+encode_reduce_body(cb, type, acc, elem, ::Val{:and}, ::Type{T}) where T <: Integer = + encode_AndIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:or}, ::Type{T}) where T <: Integer = + encode_OrIOp!(cb, type, acc, elem) +encode_reduce_body(cb, type, acc, elem, ::Val{:xor}, ::Type{T}) where T <: Integer = + encode_XOrIOp!(cb, type, acc, elem) # cuda_tile.reshape @eval Intrinsics begin @@ -668,6 +819,109 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reshape), args) end # TODO: cuda_tile.scan +@eval Intrinsics begin + """ + scan(tile, axis_val, fn_type; reverse=false) + + Parallel prefix scan along specified dimension. + fn_type=:add for cumulative sum, :mul for cumulative product. + reverse=false for forward scan, true for reverse scan. + Compiled to cuda_tile.scan. + """ + @noinline function scan(tile::Tile{T, S}, ::Val{axis}, fn::Symbol, reverse::Bool=false) where {T, S, axis} + # Scan preserves shape - result has same dimensions as input + Tile{T, S}() + end +end + +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args) + cb = ctx.cb + tt = ctx.tt + + # Get input tile + input_tv = emit_value!(ctx, args[1]) + input_tv === nothing && error("Cannot resolve input tile for scan") + + # Get scan axis + axis = @something get_constant(ctx, args[2]) error("Scan axis must be a compile-time constant") + + # Get scan function type + fn_type = @something get_constant(ctx, args[3]) error("Scan function type must be a compile-time constant") + fn_type == :add || fn_type == :mul || error("Scan function must be :add or :mul") + + # Get reverse flag (optional, defaults to false) + reverse = false + if length(args) >= 4 + reverse_val = get_constant(ctx, args[4]) + reverse = reverse_val === true + end + + # Get element type and shapes + input_type = unwrap_type(input_tv.jltype) + elem_type = input_type <: Tile ? 
input_type.parameters[1] : input_type + input_shape = input_tv.shape + + # For scan, output shape is same as input shape + output_shape = copy(input_shape) + + dtype = julia_to_tile_dtype!(tt, elem_type) + + # Output tile type (same shape as input) + output_tile_type = tile_type!(tt, dtype, output_shape) + + # Scalar type for scan body (0D tile) + scalar_tile_type = tile_type!(tt, dtype, Int[]) + + # Create identity value + # For cumsum: identity is 0.0 (represented as -0.0 for float) + # For cumprod: identity is 1.0 + if fn_type == :add + identity_val = -0.0 # Negative zero works as additive identity + else # :mul + identity_val = 1.0 + end + + # Choose identity type based on element type + if elem_type <: AbstractFloat + # Use float identity for float types + identity = FloatIdentityOp(identity_val, dtype, elem_type) + elseif elem_type <: Integer + # Use integer identity for integer types + identity_val_int = fn_type == :add ? Int64(0) : Int64(1) + is_signed = elem_type <: Signed + identity = IntegerIdentityOp(identity_val_int, dtype, elem_type, is_signed) + else + error("Unsupported element type for scan: $elem_type") + end + + # Emit ScanOp + results = encode_ScanOp!(cb, [output_tile_type], [input_tv.v], axis, reverse, [identity], [scalar_tile_type]) do block_args + acc, elem = block_args[1], block_args[2] + res = encode_scan_body(cb, scalar_tile_type, acc, elem, Val(fn_type), elem_type) + encode_YieldOp!(cb, [res]) + end + + CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape) +end + +# Dispatch helpers for scan body operations - dispatch on Val{fn} and elem_type +encode_scan_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: AbstractFloat = + encode_AddFOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:add}, ::Type{T}) where T <: Integer = + encode_AddIOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:mul}, ::Type{T}) where T <: AbstractFloat = + encode_MulFOp!(cb, type, 
acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:mul}, ::Type{T}) where T <: Integer = + encode_MulIOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:min}, ::Type{T}) where T <: AbstractFloat = + encode_MinFOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: AbstractFloat = + encode_MaxFOp!(cb, type, acc, elem) +encode_scan_body(cb, type, acc, elem, ::Val{:min}, ::Type{T}) where T <: Integer = + encode_MinIOp!(cb, type, acc, elem; signedness=is_signed(T) ? SignednessSigned : SignednessUnsigned) +encode_scan_body(cb, type, acc, elem, ::Val{:max}, ::Type{T}) where T <: Integer = + encode_MaxIOp!(cb, type, acc, elem; signedness=is_signed(T) ? SignednessSigned : SignednessUnsigned) + # cuda_tile.select @eval Intrinsics begin diff --git a/src/language/operations.jl b/src/language/operations.jl index bf20bb2..846a1f7 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -10,7 +10,7 @@ Load/Store =============================================================================# -public bid, num_blocks, num_tiles, load, store, gather, scatter +public bid, num_blocks, num_tiles, axis, load, store, gather, scatter """ Padding mode for load operations. @@ -59,6 +59,24 @@ Axis is 1-indexed. Equivalent to cld(arr.sizes[axis], shape[axis]). Intrinsics.get_index_space_shape(pv, axis - One()) # convert to 0-indexed end +""" + axis(i::Integer) -> Val{i-1} + +Return a compile-time axis selector for tile operations. +Axis indices are 1-based (axis(1) = first dimension, axis(2) = second, etc.). +Internally converts to 0-based for Tile IR. + +Use this instead of raw `Val` for self-documenting code. 
+ +# Examples +```julia +ct.cumsum(tile, ct.axis(1)) # Scan along first axis +ct.cumsum(tile, ct.axis(2)) # Scan along second axis +ct.scan(tile, ct.axis(1), :add) +``` +""" +@inline axis(i::Integer) = Val(i - One()) + """ load(arr::TileArray, index, shape; padding_mode=PaddingMode.Undetermined) -> Tile @@ -473,7 +491,7 @@ result = ct.astype(acc, ct.TFloat32) # Convert to TF32 for tensor cores Reduction =============================================================================# -public reduce_sum, reduce_max +public reduce_sum, reduce_max, reduce_mul, reduce_min, reduce_and, reduce_or, reduce_xor """ reduce_sum(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} @@ -481,16 +499,18 @@ public reduce_sum, reduce_max Sum reduction along the specified axis (1-indexed). Returns a tile with the specified dimension removed. +Supports any numeric type (Float16, Float32, Float64, and integer types). + # Example ```julia # For a (128, 64) tile, reducing along axis 2: sums = ct.reduce_sum(tile, 2) # Returns (128,) tile ``` """ -@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S} +@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} Intrinsics.reduce_sum(tile, Val(axis - 1)) end -@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} +@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} Intrinsics.reduce_sum(tile, Val(axis - 1)) end @@ -498,19 +518,132 @@ end reduce_max(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} Maximum reduction along the specified axis (1-indexed). +Supports any numeric type (Float16, Float32, Float64, and integer types). 
# Example ```julia maxes = ct.reduce_max(tile, 2) # Max along axis 2 ``` """ -@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S} +@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} Intrinsics.reduce_max(tile, Val(axis - 1)) end -@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} +@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} Intrinsics.reduce_max(tile, Val(axis - 1)) end + +""" + reduce_mul(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Product reduction along the specified axis (1-indexed). +Returns a tile with the specified dimension removed. + +# Example +```julia +# For a (128, 64) tile, reducing along axis 2: +products = ct.reduce_mul(tile, 2) # Returns (128,) tile +``` +""" +@inline function reduce_mul(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} + Intrinsics.reduce_mul(tile, Val(axis - 1)) +end +@inline function reduce_mul(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} + Intrinsics.reduce_mul(tile, Val(axis - 1)) +end + +""" + reduce_min(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Minimum reduction along the specified axis (1-indexed). + +# Example +```julia +mins = ct.reduce_min(tile, 2) # Min along axis 2 +``` +""" +@inline function reduce_min(tile::Tile{T, S}, axis::Integer) where {T <: Number, S} + Intrinsics.reduce_min(tile, Val(axis - 1)) +end +@inline function reduce_min(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis} + Intrinsics.reduce_min(tile, Val(axis - 1)) +end + +""" + reduce_and(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Bitwise AND reduction along the specified axis (1-indexed). +Integer types only. 
+ +# Example +```julia +# For an Int32 tile, reducing along axis 2: +result = ct.reduce_and(tile, 2) # Returns (128,) tile of Int32 +``` +""" +@inline function reduce_and(tile::Tile{T, S}, axis::Integer) where {T <: Integer, S} + Intrinsics.reduce_and(tile, Val(axis - 1)) +end +@inline function reduce_and(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + Intrinsics.reduce_and(tile, Val(axis - 1)) +end + +""" + reduce_or(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Bitwise OR reduction along the specified axis (1-indexed). +Integer types only. + +# Example +```julia +# For an Int32 tile, reducing along axis 2: +result = ct.reduce_or(tile, 2) # Returns (128,) tile of Int32 +``` +""" +@inline function reduce_or(tile::Tile{T, S}, axis::Integer) where {T <: Integer, S} + Intrinsics.reduce_or(tile, Val(axis - 1)) +end +@inline function reduce_or(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} + Intrinsics.reduce_or(tile, Val(axis - 1)) +end + +""" + reduce_xor(tile::Tile{T, S}, axis::Integer) -> Tile{T, reduced_shape} + +Bitwise XOR reduction along the specified axis (1-indexed). +Integer types only. 
+ +# Example +```julia +# For an Int32 tile, reducing along axis 2: +result = ct.reduce_xor(tile, 2) # Returns (128,) tile of Int32 +``` +""" +@inline function reduce_xor(tile::Tile{T, S}, axis::Integer) where {T <: Integer, S} + Intrinsics.reduce_xor(tile, Val(axis - 1)) +end +@inline function reduce_xor(tile::Tile{T, S}, ::Val{axis}) where {T <: Integer, S, axis} +Intrinsics.reduce_xor(tile, Val(axis - 1)) +end + +# Scan (Prefix Sum) Operations + +@inline function scan(tile::Tile{T, S}, ::Val{axis}, + fn::Symbol=:add, + reverse::Bool=false) where {T, S, axis} + Intrinsics.scan(tile, Val(axis), fn, reverse) +end + +@inline function cumsum(tile::Tile{T, S}, ::Val{axis}, + reverse::Bool=false) where {T, S, axis} + scan(tile, Val(axis), :add, reverse) +end + +@inline function cumprod(tile::Tile{T, S}, ::Val{axis}, + reverse::Bool=false) where {T, S, axis} + scan(tile, Val(axis), :mul, reverse) +end + #============================================================================= Matmul =============================================================================# diff --git a/test/reduce_ops.jl b/test/reduce_ops.jl new file mode 100644 index 0000000..36dc9ee --- /dev/null +++ b/test/reduce_ops.jl @@ -0,0 +1,750 @@ +using cuTile +import cuTile as ct +using CUDA +using Test + +@testset "reduce operations" begin + +#======================================================================# +# CPU reference implementations +# =====================================================================# + +cpu_reduce_add(a::AbstractArray, dims::Integer) = sum(a, dims=dims) +cpu_reduce_mul(a::AbstractArray, dims::Integer) = prod(a, dims=dims) +cpu_reduce_max(a::AbstractArray, dims::Integer) = maximum(a, dims=dims) +cpu_reduce_min(a::AbstractArray, dims::Integer) = minimum(a, dims=dims) + +cpu_reduce_and(a::AbstractArray{<:Unsigned}, dims::Integer) = reduce((x, y) -> x & y, a, init=typemax(eltype(a)), dims=dims) +cpu_reduce_and(a::AbstractArray{<:Signed}, dims::Integer) = reduce((x, 
y) -> x & y, a, init=Int64(-1), dims=dims) +cpu_reduce_or(a::AbstractArray{<:Integer}, dims::Integer) = reduce((x, y) -> x | y, a, init=zero(eltype(a)), dims=dims) +cpu_reduce_xor(a::AbstractArray{<:Integer}, dims::Integer) = reduce((x, y) -> x ⊻ y, a, init=zero(eltype(a)), dims=dims) + +#======================================================================# +# Float32 operations +#======================================================================# + +@testset "Float32 reduce_add" begin + function reduce_add_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 128)) + sums = ct.reduce_sum(tile, 2) + ct.store(b, pid, sums) + return + end + + m, n = 64, 128 + a = CUDA.rand(Float32, m, n) + b = CUDA.zeros(Float32, m) + + ct.launch(reduce_add_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_add(a_cpu[i:i, :], 2)[1] rtol=1e-3 + end +end + +@testset "Float32 reduce_mul" begin + function reduce_mul_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + products = ct.reduce_mul(tile, 2) + ct.store(b, pid, products) + return + end + + m, n = 32, 32 + a = CUDA.rand(Float32, m, n) .+ 0.1f0 + b = CUDA.ones(Float32, m) + + ct.launch(reduce_mul_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_mul(a_cpu[i:i, :], 2)[1] rtol=1e-2 + end +end + +@testset "Float32 reduce_max" begin + function reduce_max_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 128)) + maxes = ct.reduce_max(tile, 2) + ct.store(b, pid, maxes) + return + end + + m, n = 64, 128 + a = CUDA.rand(Float32, m, n) + b = CUDA.zeros(Float32, m) + + ct.launch(reduce_max_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_max(a_cpu[i:i, :], 2)[1] rtol=1e-5 + end +end + 
+@testset "Float32 reduce_min" begin + function reduce_min_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 128)) + mins = ct.reduce_min(tile, 2) + ct.store(b, pid, mins) + return + end + + m, n = 64, 128 + a = CUDA.rand(Float32, m, n) + b = CUDA.zeros(Float32, m) + + ct.launch(reduce_min_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_min(a_cpu[i:i, :], 2)[1] rtol=1e-5 + end +end + +#======================================================================# +# Float64 operations +#======================================================================# + +@testset "Float64 reduce_add" begin + function reduce_add_f64_kernel(a::ct.TileArray{Float64,2}, b::ct.TileArray{Float64,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + sums = ct.reduce_sum(tile, 2) + ct.store(b, pid, sums) + return + end + + m, n = 32, 64 + a = CUDA.rand(Float64, m, n) + b = CUDA.zeros(Float64, m) + + ct.launch(reduce_add_f64_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_add(a_cpu[i:i, :], 2)[1] rtol=1e-5 + end +end + +@testset "Float64 reduce_max" begin + function reduce_max_f64_kernel(a::ct.TileArray{Float64,2}, b::ct.TileArray{Float64,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + maxes = ct.reduce_max(tile, 2) + ct.store(b, pid, maxes) + return + end + + m, n = 32, 64 + a = CUDA.rand(Float64, m, n) + b = CUDA.zeros(Float64, m) + + ct.launch(reduce_max_f64_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_max(a_cpu[i:i, :], 2)[1] rtol=1e-5 + end +end + +@testset "Float64 reduce_min" begin + function reduce_min_f64_kernel(a::ct.TileArray{Float64,2}, b::ct.TileArray{Float64,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + mins = ct.reduce_min(tile, 2) + ct.store(b, pid, mins) + return + end + + m, n = 32, 64 + a = 
CUDA.rand(Float64, m, n) + b = CUDA.zeros(Float64, m) + + ct.launch(reduce_min_f64_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_min(a_cpu[i:i, :], 2)[1] rtol=1e-5 + end +end + +@testset "Float64 reduce_mul" begin + function reduce_mul_f64_kernel(a::ct.TileArray{Float64,2}, b::ct.TileArray{Float64,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + products = ct.reduce_mul(tile, 2) + ct.store(b, pid, products) + return + end + + m, n = 16, 32 + a = CUDA.rand(Float64, m, n) .+ 0.1 + b = CUDA.ones(Float64, m) + + ct.launch(reduce_mul_f64_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] ≈ cpu_reduce_mul(a_cpu[i:i, :], 2)[1] rtol=1e-2 + end +end + +#======================================================================# +# Int32 operations +#======================================================================# + +@testset "Int32 reduce_add" begin + function reduce_add_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + sums = ct.reduce_sum(tile, 2) + ct.store(b, pid, sums) + return + end + + m, n = 32, 64 + a = CUDA.rand(Int32, m, n) .+ 1 + b = CUDA.zeros(Int32, m) + + ct.launch(reduce_add_i32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_add(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int32 reduce_mul" begin + function reduce_mul_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 16)) + products = ct.reduce_mul(tile, 2) + ct.store(b, pid, products) + return + end + + m, n = 8, 16 + a = CUDA.rand(Int32, m, n) .% 10 .+ 2 + b = CUDA.ones(Int32, m) + + ct.launch(reduce_mul_i32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_mul(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int32 reduce_max" begin + function 
reduce_max_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + maxes = ct.reduce_max(tile, 2) + ct.store(b, pid, maxes) + return + end + + m, n = 32, 64 + a = CUDA.rand(Int32, m, n) + b = CUDA.fill(typemin(Int32), m) + + ct.launch(reduce_max_i32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_max(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int32 reduce_min" begin + function reduce_min_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + mins = ct.reduce_min(tile, 2) + ct.store(b, pid, mins) + return + end + + m, n = 32, 64 + a = CUDA.rand(Int32, m, n) + b = CUDA.fill(typemax(Int32), m) + + ct.launch(reduce_min_i32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_min(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int32 reduce_and" begin + function reduce_and_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + result = ct.reduce_and(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 16, 32 + a = CUDA.rand(Int32, m, n) + b = CUDA.zeros(Int32, m) + + ct.launch(reduce_and_i32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_and(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int32 reduce_or" begin + function reduce_or_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + result = ct.reduce_or(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 16, 32 + a = CUDA.rand(Int32, m, n) + b = CUDA.zeros(Int32, m) + + ct.launch(reduce_or_i32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_or(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int32 reduce_xor" begin + function 
reduce_xor_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + result = ct.reduce_xor(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 16, 32 + a = CUDA.rand(Int32, m, n) + b = CUDA.zeros(Int32, m) + + ct.launch(reduce_xor_i32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_xor(a_cpu[i:i, :], 2)[1] + end +end + +#======================================================================# +# UInt32 operations - tests AND identity encoding fix +#======================================================================# + +@testset "UInt32 reduce_add" begin + function reduce_add_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + sums = ct.reduce_sum(tile, 2) + ct.store(b, pid, sums) + return + end + + m, n = 32, 64 + a = CUDA.rand(UInt32, m, n) + b = CUDA.zeros(UInt32, m) + + ct.launch(reduce_add_u32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_add(a_cpu[i:i, :], 2)[1] + end +end + +@testset "UInt32 reduce_mul" begin + function reduce_mul_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 16)) + products = ct.reduce_mul(tile, 2) + ct.store(b, pid, products) + return + end + + m, n = 8, 16 + a = CUDA.rand(UInt32, m, n) .% 10 .+ 2 + b = CUDA.ones(UInt32, m) + + ct.launch(reduce_mul_u32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_mul(a_cpu[i:i, :], 2)[1] + end +end + +@testset "UInt32 reduce_max" begin + function reduce_max_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + maxes = ct.reduce_max(tile, 2) + ct.store(b, pid, maxes) + return + end + + m, n = 32, 64 + a = CUDA.rand(UInt32, m, n) + b = CUDA.zeros(UInt32, m) 
+ + ct.launch(reduce_max_u32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_max(a_cpu[i:i, :], 2)[1] + end +end + +@testset "UInt32 reduce_min" begin + function reduce_min_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 64)) + mins = ct.reduce_min(tile, 2) + ct.store(b, pid, mins) + return + end + + m, n = 32, 64 + a = CUDA.rand(UInt32, m, n) + b = CUDA.fill(typemax(UInt32), m) + + ct.launch(reduce_min_u32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_min(a_cpu[i:i, :], 2)[1] + end +end + +@testset "UInt32 reduce_and" begin + function reduce_and_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + result = ct.reduce_and(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 16, 32 + a = CUDA.rand(UInt32, m, n) + b = CUDA.zeros(UInt32, m) + + ct.launch(reduce_and_u32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_and(a_cpu[i:i, :], 2)[1] + end +end + +@testset "UInt32 reduce_or" begin + function reduce_or_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + result = ct.reduce_or(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 16, 32 + a = CUDA.rand(UInt32, m, n) + b = CUDA.zeros(UInt32, m) + + ct.launch(reduce_or_u32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_or(a_cpu[i:i, :], 2)[1] + end +end + +@testset "UInt32 reduce_xor" begin + function reduce_xor_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + result = ct.reduce_xor(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 16, 32 + a = CUDA.rand(UInt32, m, n) + b = 
CUDA.zeros(UInt32, m) + + ct.launch(reduce_xor_u32_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_xor(a_cpu[i:i, :], 2)[1] + end +end + +#======================================================================# +# Int8 operations - smaller integer type for encoding tests +#======================================================================# + +@testset "Int8 reduce_add" begin + function reduce_add_i8_kernel(a::ct.TileArray{Int8,2}, b::ct.TileArray{Int8,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + sums = ct.reduce_sum(tile, 2) + ct.store(b, pid, sums) + return + end + + m, n = 16, 32 + a = CUDA.rand(Int8, m, n) + b = CUDA.zeros(Int8, m) + + ct.launch(reduce_add_i8_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test Int32(b_cpu[i]) == cpu_reduce_add(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int8 reduce_max" begin + function reduce_max_i8_kernel(a::ct.TileArray{Int8,2}, b::ct.TileArray{Int8,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + maxes = ct.reduce_max(tile, 2) + ct.store(b, pid, maxes) + return + end + + m, n = 16, 32 + a = CUDA.rand(Int8, m, n) + b = CUDA.fill(typemin(Int8), m) + + ct.launch(reduce_max_i8_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_max(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int8 reduce_min" begin + function reduce_min_i8_kernel(a::ct.TileArray{Int8,2}, b::ct.TileArray{Int8,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 32)) + mins = ct.reduce_min(tile, 2) + ct.store(b, pid, mins) + return + end + + m, n = 16, 32 + a = CUDA.rand(Int8, m, n) + b = CUDA.fill(typemax(Int8), m) + + ct.launch(reduce_min_i8_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test b_cpu[i] == cpu_reduce_min(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int8 reduce_and" begin + function reduce_and_i8_kernel(a::ct.TileArray{Int8,2}, 
b::ct.TileArray{Int8,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 16)) + result = ct.reduce_and(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 8, 16 + a = CUDA.rand(Int8, m, n) + b = CUDA.zeros(Int8, m) + + ct.launch(reduce_and_i8_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test Int32(b_cpu[i]) == cpu_reduce_and(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int8 reduce_or" begin + function reduce_or_i8_kernel(a::ct.TileArray{Int8,2}, b::ct.TileArray{Int8,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 16)) + result = ct.reduce_or(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 8, 16 + a = CUDA.rand(Int8, m, n) + b = CUDA.zeros(Int8, m) + + ct.launch(reduce_or_i8_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test Int32(b_cpu[i]) == cpu_reduce_or(a_cpu[i:i, :], 2)[1] + end +end + +@testset "Int8 reduce_xor" begin + function reduce_xor_i8_kernel(a::ct.TileArray{Int8,2}, b::ct.TileArray{Int8,1}) + pid = ct.bid(1) + tile = ct.load(a, (pid, 1), (1, 16)) + result = ct.reduce_xor(tile, 2) + ct.store(b, pid, result) + return + end + + m, n = 8, 16 + a = CUDA.rand(Int8, m, n) + b = CUDA.zeros(Int8, m) + + ct.launch(reduce_xor_i8_kernel, m, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for i in 1:m + @test Int32(b_cpu[i]) == cpu_reduce_xor(a_cpu[i:i, :], 2)[1] + end +end + +#======================================================================# +# Axis 0 reductions - verify both axes work +#======================================================================# + +@testset "axis 0 reduce_sum Float32" begin + function reduce_sum_axis0_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, (1, pid), (64, 1)) + sums = ct.reduce_sum(tile, 1) + ct.store(b, pid, sums) + return + end + + m, n = 64, 128 + a = CUDA.rand(Float32, m, n) + b = CUDA.zeros(Float32, n) + + ct.launch(reduce_sum_axis0_kernel, n, a, b) + + a_cpu = Array(a) 
+ b_cpu = Array(b) + for j in 1:n + @test b_cpu[j] ≈ cpu_reduce_add(a_cpu[:, j:j], 1)[1] rtol=1e-3 + end +end + +@testset "axis 0 reduce_min Int32" begin + function reduce_min_axis0_i32_kernel(a::ct.TileArray{Int32,2}, b::ct.TileArray{Int32,1}) + pid = ct.bid(1) + tile = ct.load(a, (1, pid), (32, 1)) + mins = ct.reduce_min(tile, 1) + ct.store(b, pid, mins) + return + end + + m, n = 32, 64 + a = CUDA.rand(Int32, m, n) + b = CUDA.fill(typemax(Int32), n) + + ct.launch(reduce_min_axis0_i32_kernel, n, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for j in 1:n + @test b_cpu[j] == cpu_reduce_min(a_cpu[:, j:j], 1)[1] + end +end + +@testset "axis 0 reduce_max UInt32" begin + function reduce_max_axis0_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (1, pid), (32, 1)) + maxes = ct.reduce_max(tile, 1) + ct.store(b, pid, maxes) + return + end + + m, n = 32, 64 + a = CUDA.rand(UInt32, m, n) + b = CUDA.zeros(UInt32, n) + + ct.launch(reduce_max_axis0_u32_kernel, n, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for j in 1:n + @test b_cpu[j] == cpu_reduce_max(a_cpu[:, j:j], 1)[1] + end +end + +@testset "axis 0 reduce_and UInt32" begin + function reduce_and_axis0_u32_kernel(a::ct.TileArray{UInt32,2}, b::ct.TileArray{UInt32,1}) + pid = ct.bid(1) + tile = ct.load(a, (1, pid), (16, 1)) + result = ct.reduce_and(tile, 1) + ct.store(b, pid, result) + return + end + + m, n = 16, 32 + a = CUDA.rand(UInt32, m, n) + b = CUDA.fill(typemax(UInt32), n) + + ct.launch(reduce_and_axis0_u32_kernel, n, a, b) + + a_cpu = Array(a) + b_cpu = Array(b) + for j in 1:n + @test b_cpu[j] == cpu_reduce_and(a_cpu[:, j:j], 1)[1] + end +end + +end # @testset "reduce operations" diff --git a/test/runtests.jl b/test/runtests.jl index 6b9aed9..a52163a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,7 +45,7 @@ if filter_tests!(testsuite, args) cuda_functional = CUDA.functional() filter!(testsuite) do (test, _) - if in(test, ["execution"]) || 
startswith(test, "examples/") + if in(test, ["execution", "reduce_ops"]) || startswith(test, "examples/") return cuda_functional else return true