29 commits
3bc8521
intrinsics: extend reduce operations with mul, min, and, or, xor for …
arhik Jan 11, 2026
2b21f73
operations: add axis(i) helper for 1-based to 0-based axis conversion
arhik Jan 16, 2026
3f9584f
make axis public
arhik Jan 11, 2026
43cd64b
make new reduce_{ops} public
arhik Jan 11, 2026
12ebf37
reduce ops update and axis convenience
arhik Jan 11, 2026
d1c977a
reduce ops: add wrapper functions and correct identity values via dis…
arhik Jan 11, 2026
31282c4
add IntIdentity type for integer reduce operations
arhik Jan 11, 2026
e689672
rename IntIdentity to IntegerIdentity for clarity
arhik Jan 11, 2026
aa67f22
fix: remove AbstractFloat constraint from reduce_sum and reduce_max
arhik Jan 11, 2026
eb98ec1
use Number constraint for numeric reduce operations
arhik Jan 11, 2026
e097fa7
add signed field to IntegerIdentity for proper signed/unsigned encoding
arhik Jan 11, 2026
6652087
remove unused is_reduce kwarg from encode_tagged_int
arhik Jan 11, 2026
8b62221
rename ReduceIdentity to OperationIdentity
arhik Jan 11, 2026
6ddecaa
rename identity types to IdentityOp hierarchy
arhik Jan 11, 2026
9fc8d0a
fix is_signed to use proper Julia type hierarchy check
arhik Jan 11, 2026
00f0de9
intrinsics: use -one(T) instead of -1 for signed AND identity
arhik Jan 11, 2026
55b601e
test: restore codegen and types tests, fix reduce_ops reference
arhik Jan 11, 2026
959daa4
multiline comment mess in reduce_ops.jl
arhik Jan 11, 2026
9b0418f
intrinsics: rename reduce_identity -> operation_identity
arhik Jan 11, 2026
47a30bc
test: fix CPU reference functions for bitwise ops
arhik Jan 11, 2026
baa5e65
bytecode: fix zigzag encoding for signed varint
arhik Jan 11, 2026
90fca21
reverting zigzag encoding
arhik Jan 11, 2026
2eb3171
bytecode: fix zigzag encoding for signed varint
arhik Jan 11, 2026
5dea6f3
bytecode: remove duplicate encode_signed_varint!
arhik Jan 11, 2026
6dc89b4
intrinsics: pass signedness to encode_MinIOp! and encode_MaxIOp!
arhik Jan 11, 2026
db685ae
reverting original encode_signed_varint
arhik Jan 12, 2026
6a1b21d
revert comment inside encode_signed_varint
arhik Jan 12, 2026
be4265f
Fix SLEB128 zigzag encoding for 64-bit and small integer types
arhik Jan 16, 2026
eb34ca3
Simplify reduce_ops.jl tests with broadcasting
arhik Jan 16, 2026
10 changes: 10 additions & 0 deletions README.md
@@ -160,7 +160,12 @@ conservative token threading in the compiler (see https://github.com/JuliaGPU/cu
| Operation | Description |
|-----------|-------------|
| `reduce_sum(tile, axis)` | Sum along axis |
| `reduce_mul(tile, axis)` | Product along axis |
| `reduce_max(tile, axis)` | Maximum along axis |
| `reduce_min(tile, axis)` | Minimum along axis |
| `reduce_and(tile, axis)` | Bitwise AND along axis (integer) |
| `reduce_or(tile, axis)` | Bitwise OR along axis (integer) |
| `reduce_xor(tile, axis)` | Bitwise XOR along axis (integer) |

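A minimal sketch of how these are used in a kernel, adapted from `examples/reducekernel.jl` added in this PR (the per-block load/reduce/store pattern comes from that file; the names `tile_sum_kernel`, `a`, `b`, and `sz` here are illustrative):

```julia
using CUDA
import cuTile as ct

function tile_sum_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                         tile_size::ct.Constant{Int})
    t = ct.load(a, ct.bid(1), (tile_size[],))         # load one tile per block
    ct.store(b, ct.bid(1), ct.reduce_sum(t, Val(1)))  # reduce along the only axis
    return nothing
end

sz = 32
a = CUDA.rand(Float32, 1024)
b = CUDA.zeros(Float32, cld(length(a), sz))
ct.launch(tile_sum_kernel, cld(length(a), sz), a, b, ct.Constant(sz))
```
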
### Math
| Operation | Description |
@@ -275,6 +280,11 @@ ct.permute(tile, (3, 1, 2))

This applies to `bid`, `num_blocks`, `permute`, `reshape`, dimension arguments, etc.

### `axis` convenience helper

| Operation | Description |
|-----------|-------------|
| `axis(i)` | Convert a 1-based axis to 0-based |

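A minimal sketch of the intended behaviour (assuming the helper is exported as `ct.axis` and simply shifts the index down by one, per its description):

```julia
ct.axis(1)  # == 0: first Julia axis, expressed in 0-based form
ct.axis(2)  # == 1
```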

### `Val`-like constants

CuTile.jl uses `ct.Constant{T}` to encode compile-time constant values in the type domain, similar to how `Val` works. An explicit `[]` is needed to extract the value at runtime:
153 changes: 153 additions & 0 deletions examples/reducekernel.jl
@@ -0,0 +1,153 @@
using Test
using CUDA
using cuTile
import cuTile as ct

# Kernel factory to properly capture element type and operation
function makeReduceKernel(::Type{T}, op::Symbol) where {T}
reduceFunc = if op == :reduce_min
ct.reduce_min
elseif op == :reduce_max
ct.reduce_max
elseif op == :reduce_sum
ct.reduce_sum
elseif op == :reduce_xor
ct.reduce_xor
elseif op == :reduce_or
ct.reduce_or
elseif op == :reduce_and
ct.reduce_and
end

@inline function kernel(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, tileSz::ct.Constant{Int})
ct.store(b, ct.bid(1), reduceFunc(ct.load(a, ct.bid(1), (tileSz[],)), Val(1)))
return nothing
end
return kernel
end

# Test with UInt types
@testset for elType in [UInt16, UInt32, UInt64]
@testset for op in [:reduce_min, :reduce_max, :reduce_sum, :reduce_xor, :reduce_or, :reduce_and]
sz = 32
N = 2^15

# Create kernel using factory
reduceKernel = try
makeReduceKernel(elType, op)
catch e
@test_broken false
rethrow()
end

# Create data and run kernel
a_gpu = CUDA.rand(elType, N)
b_gpu = CUDA.zeros(elType, cld(N, sz))
try
CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz))
catch e
@test_broken false
rethrow()
end
res = Array(b_gpu)

# CPU computation
a_cpu = Array(a_gpu)
a_reshaped = reshape(a_cpu, sz, :)

if op == :reduce_min
cpu_result = minimum(a_reshaped, dims=1)[:]
elseif op == :reduce_max
cpu_result = maximum(a_reshaped, dims=1)[:]
elseif op == :reduce_sum
raw_sum = sum(a_reshaped, dims=1)[:]
cpu_result = raw_sum .& typemax(elType)
elseif op == :reduce_xor
cpu_result = mapslices(x -> reduce(⊻, x), a_reshaped, dims=1)[:]
elseif op == :reduce_or
cpu_result = mapslices(x -> reduce(|, x), a_reshaped, dims=1)[:]
elseif op == :reduce_and
cpu_result = mapslices(x -> reduce(&, x), a_reshaped, dims=1)[:]
end

@test cpu_result == res
end
end

# Test with signed Int types
@testset for elType in [Int16, Int32, Int64]
@testset for op in [:reduce_min, :reduce_max, :reduce_sum, :reduce_xor, :reduce_or, :reduce_and]
sz = 32
N = 2^15

# Create kernel using factory
reduceKernel = try
makeReduceKernel(elType, op)
catch e
@test_broken false
rethrow()
end

# Create data and run kernel - use range to get negative values too
a_gpu = CuArray{elType}(rand(-1000:1000, N))
b_gpu = CUDA.zeros(elType, cld(N, sz))
try
CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz))
catch e
@test_broken false
rethrow()
end
res = Array(b_gpu)

# CPU computation
a_cpu = Array(a_gpu)
a_reshaped = reshape(a_cpu, sz, :)

if op == :reduce_min
cpu_result = minimum(a_reshaped, dims=1)[:]
elseif op == :reduce_max
cpu_result = maximum(a_reshaped, dims=1)[:]
elseif op == :reduce_sum
cpu_result = sum(a_reshaped, dims=1)[:]
elseif op == :reduce_xor
cpu_result = mapslices(x -> reduce(⊻, x), a_reshaped, dims=1)[:]
elseif op == :reduce_or
cpu_result = mapslices(x -> reduce(|, x), a_reshaped, dims=1)[:]
elseif op == :reduce_and
cpu_result = mapslices(x -> reduce(&, x), a_reshaped, dims=1)[:]
end

@test cpu_result == res
end
end

# Test with Float types
@testset for elType in [Float16, Float32, Float64]
@testset for op in [:reduce_min, :reduce_max, :reduce_sum]
sz = 32
N = 2^15

# Create kernel using factory
reduceKernel = makeReduceKernel(elType, op)

# Create data and run kernel
a_gpu = CUDA.rand(elType, N)
b_gpu = CUDA.zeros(elType, cld(N, sz))
CUDA.@sync ct.launch(reduceKernel, cld(length(a_gpu), sz), a_gpu, b_gpu, ct.Constant(sz))
res = Array(b_gpu)

# CPU computation
a_cpu = Array(a_gpu)
a_reshaped = reshape(a_cpu, sz, :)

if op == :reduce_min
cpu_result = minimum(a_reshaped, dims=1)[:]
elseif op == :reduce_max
cpu_result = maximum(a_reshaped, dims=1)[:]
elseif op == :reduce_sum
cpu_result = sum(a_reshaped, dims=1)[:]
end

@test isapprox(cpu_result, res)
end
end
2 changes: 1 addition & 1 deletion src/bytecode/encodings.jl
@@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
identities::Vector{<:ReduceIdentity},
identities::Vector{<:IdentityOp},
body_scalar_types::Vector{TypeId})
encode_varint!(cb.buf, Opcode.ReduceOp)

99 changes: 87 additions & 12 deletions src/bytecode/writer.jl
@@ -234,30 +234,42 @@ end
=============================================================================#

"""
ReduceIdentity
IdentityOp

Abstract type for reduce identity attributes.
Abstract type for binary operation identity attributes (reduce, scan, etc.).
"""
abstract type ReduceIdentity end
abstract type IdentityOp end

"""
FloatIdentity(value, type_id, dtype)
FloatIdentityOp(value, type_id, dtype)

Float identity value for reduce operations.
Float identity value for binary operations.
"""
struct FloatIdentity <: ReduceIdentity
struct FloatIdentityOp <: IdentityOp
value::Float64
type_id::TypeId
dtype::Type # Float16, Float32, Float64, etc.
end

"""
encode_tagged_float!(cb, identity::FloatIdentity)
IntegerIdentityOp(value, type_id, dtype, signed)

Integer identity value for binary operations.
"""
struct IntegerIdentityOp <: IdentityOp
value::UInt128 # Store as UInt128 to handle all unsigned values up to 64 bits
type_id::TypeId
dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc.
signed::Bool # true for signed, false for unsigned
end

"""
encode_tagged_float!(cb, identity::FloatIdentityOp)

Encode a tagged float attribute for reduce identity.
Format: tag(Float=0x02) + typeid + ap_int(value_bits)
"""
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityOp)
# Tag for Float attribute
push!(cb.buf, 0x02)
# Type ID
@@ -267,6 +279,59 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
encode_signed_varint!(cb.buf, bits)
end

"""
encode_tagged_int!(cb, identity::IntegerIdentityOp)

Encode a tagged integer identity attribute.
Format: tag(Int=0x01) + typeid + ap_int(value)
"""
function encode_tagged_int!(cb::CodeBuilder, identity::IntegerIdentityOp)
# Tag for Int attribute
push!(cb.buf, 0x01)
# Type ID
encode_typeid!(cb.buf, identity.type_id)
# Value: signed uses zigzag varint, unsigned uses plain varint
# Mask value to correct bit width and apply zigzag for signed types
masked_value = mask_to_width(identity.value, identity.dtype, identity.signed)
if identity.signed
encode_signed_varint!(cb.buf, masked_value)
else
encode_varint!(cb.buf, masked_value)
end
end

"""
mask_to_width(value, dtype, signed)

Mask a UInt128 value to the correct bit width for the given type and apply zigzag if signed.
For signed types, this masks first, then applies zigzag encoding.
"""
# Signed Int64: mask to 64 bits first, then zigzag encode
mask_to_width(value::UInt128, ::Type{Int64}, signed::Bool) =
let masked = UInt64(value & 0xFFFFFFFFFFFFFFFF)
UInt64((masked << 1) ⊻ (masked >>> 63))
end
# Signed Int32: mask to 32 bits first, then zigzag encode
mask_to_width(value::UInt128, ::Type{Int32}, signed::Bool) =
let masked = UInt32(value & 0xFFFFFFFF)
UInt32((masked << 1) ⊻ (masked >>> 31))
end
# Signed Int16: mask to 16 bits first, then zigzag encode
mask_to_width(value::UInt128, ::Type{Int16}, signed::Bool) =
let masked = UInt16(value & 0xFFFF)
UInt16((masked << 1) ⊻ (masked >>> 15))
end
# Signed Int8: mask to 8 bits first, then zigzag encode
mask_to_width(value::UInt128, ::Type{Int8}, signed::Bool) =
let masked = UInt8(value & 0xFF)
UInt8((masked << 1) ⊻ (masked >>> 7))
end
# Unsigned types: just mask to bit width, no zigzag
mask_to_width(value::UInt128, ::Type{UInt64}, signed::Bool) = UInt64(value & 0xFFFFFFFFFFFFFFFF)
mask_to_width(value::UInt128, ::Type{UInt32}, signed::Bool) = UInt32(value & 0xFFFFFFFF)
mask_to_width(value::UInt128, ::Type{UInt16}, signed::Bool) = UInt16(value & 0xFFFF)
mask_to_width(value::UInt128, ::Type{UInt8}, signed::Bool) = UInt8(value & 0xFF)
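
# Worked illustration (a minimal sketch, assuming an Int16 identity of -1 arrives
# with 0xFFFF in the low 16 bits of the UInt128): the mask keeps 0xFFFF, and the
# shift/xor step then produces (0xFFFE ⊻ 0x0001) == 0xFFFF, which is the value
# handed to encode_signed_varint!. Unsigned types are only truncated to their bit
# width, with no shift/xor:
#
#     mask_to_width(UInt128(0xFFFF), Int16, true)  === 0xFFFF
#     mask_to_width(UInt128(0xFFFF), UInt16, false) === 0xFFFF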

"""
float_to_bits(value, dtype)

@@ -296,6 +361,7 @@ end
Encode a signed integer as a variable-length integer.
Uses zigzag encoding for signed values.
"""

function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt16, UInt32, UInt64, Int64})
# For float bits, encode as unsigned varint
encode_varint!(buf, UInt64(value))
@@ -304,15 +370,24 @@ end
"""
encode_identity_array!(cb, identities)

Encode an array of reduce identity attributes.
Encode an array of binary operation identity attributes.
Dispatches on identity type to encode correctly.
"""
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity})
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityOp})
encode_varint!(cb.buf, length(identities))
for identity in identities
encode_tagged_float!(cb, identity)
encode_identity!(cb, identity)
end
end

"""
encode_identity!(cb, identity)

Encode a single identity attribute, dispatching on type.
"""
encode_identity!(cb::CodeBuilder, identity::FloatIdentityOp) = encode_tagged_float!(cb, identity)
encode_identity!(cb::CodeBuilder, identity::IntegerIdentityOp) = encode_tagged_int!(cb, identity)

"""
BytecodeWriter

@@ -544,7 +619,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
end

#=============================================================================
Optimization Hints
Optimization Hints
=============================================================================#

"""
1 change: 1 addition & 0 deletions src/compiler/intrinsics.jl
@@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete
using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView
using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned
using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual
using ..cuTile: IdentityOp, FloatIdentityOp, IntegerIdentityOp

end
