Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7caaa16
intrinsics: extend reduce operations with mul, min, and, or, xor for …
arhik Jan 11, 2026
5d39500
operations: add axis(i) helper for 1-based to 0-based axis conversion
arhik Jan 11, 2026
554d2b0
make axis public
arhik Jan 11, 2026
a10c086
make new reduce_{ops} public
arhik Jan 11, 2026
abd8299
reduce ops update and axis convenience
arhik Jan 11, 2026
fc48338
reduce ops: add wrapper functions and correct identity values via dis…
arhik Jan 11, 2026
dd25158
add IntIdentity type for integer reduce operations
arhik Jan 11, 2026
4401d37
rename IntIdentity to IntegerIdentity for clarity
arhik Jan 11, 2026
a70f7b7
fix: remove AbstractFloat constraint from reduce_sum and reduce_max
arhik Jan 11, 2026
bbda0f1
use Number constraint for numeric reduce operations
arhik Jan 11, 2026
33d93ae
add signed field to IntegerIdentity for proper signed/unsigned encoding
arhik Jan 11, 2026
175fc9a
remove unused is_reduce kwarg from encode_tagged_int
arhik Jan 11, 2026
b60a327
rename ReduceIdentity to OperationIdentity
arhik Jan 11, 2026
28a8875
rename identity types to IdentityOp hierarchy
arhik Jan 11, 2026
56f3376
fix is_signed to use proper Julia type hierarchy check
arhik Jan 11, 2026
0327a2e
intrinsics: use -one(T) instead of -1 for signed AND identity
arhik Jan 11, 2026
d1c9c0a
test: restore codegen and types tests, fix reduce_ops reference
arhik Jan 11, 2026
2c3642a
multiline comment mess in reduce_ops.jl
arhik Jan 11, 2026
212786c
intrinsics: rename reduce_identity -> operation_identity
arhik Jan 11, 2026
5dfde9e
test: fix CPU reference functions for bitwise ops
arhik Jan 11, 2026
b256403
bytecode: fix zigzag encoding for signed varint
arhik Jan 11, 2026
a50055f
reverting zigzag encoding
arhik Jan 11, 2026
7ba27c4
bytecode: fix zigzag encoding for signed varint
arhik Jan 11, 2026
95555c6
bytecode: remove duplicate encode_signed_varint!
arhik Jan 11, 2026
55efe4b
intrinsics: pass signedness to encode_MinIOp! and encode_MaxIOp!
arhik Jan 11, 2026
77118eb
reverting original encode_signed_varint
arhik Jan 12, 2026
dfffe98
revert comment inside encode_signed_varint
arhik Jan 12, 2026
eeb4d24
scanops related changes
arhik Jan 11, 2026
6773baf
consolidating common identity ops with scanops branch
arhik Jan 12, 2026
9b04ce7
scan: add min/max support for float and integer types
arhik Jan 12, 2026
5e6e1a1
simple scan kernel example.
arhik Jan 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,12 @@ conservative token threading in the compiler (see https://github.com/JuliaGPU/cu
| Operation | Description |
|-----------|-------------|
| `reduce_sum(tile, axis)` | Sum along axis |
| `reduce_mul(tile, axis)` | Product along axis |
| `reduce_max(tile, axis)` | Maximum along axis |
| `reduce_min(tile, axis)` | Minimum along axis |
| `reduce_and(tile, axis)` | Bitwise AND along axis (integer) |
| `reduce_or(tile, axis)` | Bitwise OR along axis (integer) |
| `reduce_xor(tile, axis)` | Bitwise XOR along axis (integer) |

### Math
| Operation | Description |
Expand Down Expand Up @@ -274,6 +279,11 @@ ct.permute(tile, (3, 1, 2))

This applies to `bid`, `num_blocks`, `permute`, `reshape`, dimension arguments, etc.

### axis convenience

| Operation | Description |
|-----------|-------------|
| `axis(i)` | Convert 1-based axis to 0-based (helper) |


### `Val`-like constants

CuTile.jl uses `ct.Constant{T}` to encode compile-time constant values in the type domain, similar to how `Val` works. An explicit `[]` is needed to extract the value at runtime:
Expand Down
20 changes: 20 additions & 0 deletions examples/reducekernel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using Test
using CUDA
using cuTile
import cuTile as ct

# Element type shared by the kernel signature and the host-side buffers below.
elType = UInt16

# Per-block minimum: every block loads one tile of `tileSz` elements from `a`,
# reduces it with `ct.reduce_min`, and stores the result in `b` at its own block index.
function reduceKernel(a::ct.TileArray{elType,1}, b::ct.TileArray{elType,1}, tileSz::ct.Constant{Int})
    block = ct.bid(1)
    chunk = ct.load(a, block, (tileSz[],))
    ct.store(b, block, ct.reduce_min(chunk, Val(1)))
    return nothing
end

# Host-side driver: N elements are split into tiles of `sz`, one block per tile.
sz = 32
N = 2^15
a = CUDA.rand(elType, N)
b = CUDA.zeros(elType, cld(N, sz))
CUDA.@sync ct.launch(reduceKernel, cld(length(a), sz), a, b, ct.Constant(sz))
res = Array(b)

# CPU reference: N is a multiple of `sz`, so column j of the reshaped input holds
# exactly the elements of tile j, and its minimum is what block j must produce.
expected = vec(minimum(reshape(Array(a), sz, :); dims=1))
@test res == expected
62 changes: 62 additions & 0 deletions examples/scanKernel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Test
using CUDA
using cuTile
import cuTile as ct

# Tile-local cumulative sum: each block loads one tile of `tile_size` elements from `a`,
# applies `ct.cumsum` along the first axis, and stores the scanned tile at the same
# block index in `b`. Tiles are scanned independently of each other.
function cumsum_1d_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                          tile_size::ct.Constant{Int})
    block = ct.bid(1)
    chunk = ct.load(a, block, (tile_size[],))
    ct.store(b, block, ct.cumsum(chunk, ct.axis(1)))
    return nothing
end

# Host-side setup for `cumsum_1d_kernel`: `N` elements processed in tiles of `sz`.
sz = 32
N = 2^15
a = CUDA.rand(Float32, N)
b = CUDA.zeros(Float32, N)
# One block per tile. The kernel scans each tile on its own, so after this launch
# `b` holds per-tile cumulative sums rather than a cumulative sum over all of `a`.
CUDA.@sync ct.launch(cumsum_1d_kernel, cld(length(a), sz), a, b, ct.Constant(sz))

# This is supposed to be a single-pass kernel, but this is a simpler version than the memory-ordering version.
# The idea is to show the scan operation.

# CSDL phase 1: scan every tile independently and record one total per tile.
# `b` receives the tile-local cumulative sums; `tile_sums` receives, at each block
# index, the last element of that block's scanned tile.
function cumsum_csdl_phase1(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                            tile_sums::ct.TileArray{Float32,1},
                            tile_size::ct.Constant{Int})
    block = ct.bid(1)
    scanned = ct.cumsum(ct.load(a, block, (tile_size[],)), ct.axis(1))
    ct.store(b, block, scanned)
    # The final entry of the scanned tile is used as the tile's total.
    tile_total = ct.extract(scanned, (tile_size[],), (1,))
    ct.store(tile_sums, block, tile_total)
    return nothing
end

# CSDL phase 2: add the combined total of all *preceding* tiles to this tile of `b`.
# This is a simple serial look-back: the block walks the earlier entries of
# `tile_sums` one at a time (tiles bid-1, bid-2, ..., 1) and accumulates them.
# It relies on phase 1 having completed, so every entry of `tile_sums` is final.
function cumsum_csdl_phase2(b::ct.TileArray{Float32,1},
                            tile_sums::ct.TileArray{Float32,1},
                            tile_size::ct.Constant{Int})
    bid = ct.bid(1)
    prev_sum = ct.zeros((tile_size[],), Float32)
    # Start at the tile just before this one. The tile's own total is already part of
    # its local scan and must not be added again; tile 1 has no predecessors, so the
    # loop body never runs for it.
    k = Int32(bid) - Int32(1)
    while k > 0
        # Load the single-element tile holding tile k's total; broadcasting spreads
        # it across the whole accumulator tile.
        tile_sum_k = ct.load(tile_sums, (k,), (1,))
        prev_sum = prev_sum .+ tile_sum_k
        k -= Int32(1)
    end
    tile = ct.load(b, bid, (tile_size[],))
    result = tile .+ prev_sum
    ct.store(b, bid, result)
    return nothing
end

# Run both CSDL phases over every tile, then compare against a CPU cumulative sum.
n = length(a)
num_tiles = cld(n, sz)
tile_sums = CUDA.zeros(Float32, num_tiles)
CUDA.@sync ct.launch(cumsum_csdl_phase1, num_tiles, a, b, tile_sums, ct.Constant(sz))
CUDA.@sync ct.launch(cumsum_csdl_phase2, num_tiles, b, tile_sums, ct.Constant(sz))

# Copy both arrays to the host and check the device result against Base's `cumsum`.
b_cpu = cumsum(collect(a), dims=1)
@test isapprox(collect(b), b_cpu)
74 changes: 73 additions & 1 deletion src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId},
operands::Vector{Value},
dim::Int,
identities::Vector{<:ReduceIdentity},
identities::Vector{<:IdentityOp},
body_scalar_types::Vector{TypeId})
encode_varint!(cb.buf, Opcode.ReduceOp)

Expand Down Expand Up @@ -1331,6 +1331,78 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
end
end


#=============================================================================
Scan operations
=============================================================================#

"""
    encode_ScanOp!(body::Function, cb::CodeBuilder,
                   result_types::Vector{TypeId},
                   operands::Vector{Value},
                   dim::Int,
                   reverse::Bool,
                   identities::Vector{<:IdentityOp},
                   body_scalar_types::Vector{TypeId})

Emit a ScanOp (prefix scan along one dimension) into `cb`, including its single
body region, and return the values it defines.

# Arguments
- `body`: Function that takes block args and yields result(s)
- `cb`: CodeBuilder receiving the bytecode
- `result_types`: Output tile types
- `operands`: Input tiles to scan
- `dim`: Dimension to scan along (0-indexed)
- `reverse`: Whether to scan in reverse order
- `identities`: Identity values for each operand
- `body_scalar_types`: 0D tile types for body arguments

# Returns
A `Vector{Value}` with one newly numbered value per entry of `result_types`
(empty when `result_types` is empty).
"""
function encode_ScanOp!(body::Function, cb::CodeBuilder,
                        result_types::Vector{TypeId},
                        operands::Vector{Value},
                        dim::Int,
                        reverse::Bool,
                        identities::Vector{<:IdentityOp},
                        body_scalar_types::Vector{TypeId})
    # Opcode, followed by the variadic list of result types.
    encode_varint!(cb.buf, Opcode.ScanOp)
    encode_typeid_seq!(cb.buf, result_types)

    # Attributes, in order: dim (int), reverse (bool), identities (array).
    encode_opattr_int!(cb, dim)
    encode_opattr_bool!(cb, reverse)
    encode_identity_array!(cb, identities)

    # Variadic operands: count first, then the operands themselves.
    encode_varint!(cb.buf, length(operands))
    encode_operands!(cb.buf, operands)

    # Record the debug attribute and count this op, then announce the one body region.
    push!(cb.debug_attrs, cb.cur_debug_attr)
    cb.num_ops += 1
    encode_varint!(cb.buf, 1)

    # The body works on 0D tiles (scalars); each operand contributes an
    # (accumulator, element) pair of block arguments of the same scalar type.
    region_arg_types = TypeId[]
    for ty in body_scalar_types
        push!(region_arg_types, ty, ty)
    end
    with_region(body, cb, region_arg_types)

    # Hand out consecutive value ids for the results.
    nresults = length(result_types)
    nresults == 0 && return Value[]
    first_id = cb.next_value_id
    cb.next_value_id += nresults
    return [Value(first_id + offset) for offset in 0:nresults-1]
end

#=============================================================================
Comparison and selection operations
=============================================================================#
Expand Down
63 changes: 52 additions & 11 deletions src/bytecode/writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,30 +234,42 @@ end
=============================================================================#

"""
ReduceIdentity
IdentityOp

Abstract type for reduce identity attributes.
Abstract type for binary operation identity attributes (reduce, scan, etc.).
"""
abstract type ReduceIdentity end
abstract type IdentityOp end

"""
FloatIdentity(value, type_id, dtype)
FloatIdentityOp(value, type_id, dtype)

Float identity value for reduce operations.
Float identity value for binary operations.
"""
struct FloatIdentity <: ReduceIdentity
struct FloatIdentityOp <: IdentityOp
value::Float64
type_id::TypeId
dtype::Type # Float16, Float32, Float64, etc.
end

"""
encode_tagged_float!(cb, identity::FloatIdentity)
IntegerIdentityOp(value, type_id, dtype, signed)

Integer identity value for binary operations.
"""
struct IntegerIdentityOp <: IdentityOp
# Identity value of the operation, always held in a signed 64-bit field.
value::Int64 # Store as signed Int64, will be reinterpreted as unsigned
# Type id written right after the attribute tag when the identity is encoded.
type_id::TypeId
dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc.
# Selects the signed or the unsigned varint path when the value is encoded.
signed::Bool # true for signed, false for unsigned
end

"""
encode_tagged_float!(cb, identity::FloatIdentityOp)

Encode a tagged float attribute for reduce identity.
Format: tag(Float=0x02) + typeid + ap_int(value_bits)
"""
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityOp)
# Tag for Float attribute
push!(cb.buf, 0x02)
# Type ID
Expand All @@ -267,6 +279,25 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
encode_signed_varint!(cb.buf, bits)
end

"""
encode_tagged_integer!(cb, identity::IntegerIdentityOp)

Encode a tagged integer identity attribute.
Format: tag(Int=0x01) + typeid + ap_int(value)
"""
function encode_tagged_integer!(cb::CodeBuilder, identity::IntegerIdentityOp)
    # Tag for Int attribute
    push!(cb.buf, 0x01)
    # Type ID
    encode_typeid!(cb.buf, identity.type_id)
    # Value: signed identities go through encode_signed_varint!, unsigned ones are
    # written as a plain varint.
    if identity.signed
        encode_signed_varint!(cb.buf, identity.value)
    else
        # `value` is stored as Int64, so an unsigned identity with its top bit set
        # (e.g. the all-ones pattern of a 64-bit unsigned type) arrives as a negative
        # number. `UInt64(x)` would throw InexactError for it; reinterpreting keeps the
        # exact bit pattern and is identical to the conversion for non-negative values.
        encode_varint!(cb.buf, reinterpret(UInt64, identity.value))
    end
end

"""
float_to_bits(value, dtype)

Expand Down Expand Up @@ -296,6 +327,7 @@ end
Encode a signed integer as a variable-length integer.
Uses zigzag encoding for signed values.
"""

function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt16, UInt32, UInt64, Int64})
# For float bits, encode as unsigned varint
encode_varint!(buf, UInt64(value))
Expand All @@ -304,15 +336,24 @@ end
"""
encode_identity_array!(cb, identities)

Encode an array of reduce identity attributes.
Encode an array of binary operation identity attributes.
Dispatches on identity type to encode correctly.
"""
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity})
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityOp})
encode_varint!(cb.buf, length(identities))
for identity in identities
encode_tagged_float!(cb, identity)
encode_identity!(cb, identity)
end
end

"""
encode_identity!(cb, identity)

Encode a single identity attribute, dispatching on type.
"""
function encode_identity!(cb::CodeBuilder, identity::FloatIdentityOp)
    # Float identities are written as tagged float attributes.
    return encode_tagged_float!(cb, identity)
end

function encode_identity!(cb::CodeBuilder, identity::IntegerIdentityOp)
    # Integer identities are written as tagged integer attributes.
    return encode_tagged_integer!(cb, identity)
end

"""
BytecodeWriter

Expand Down
1 change: 1 addition & 0 deletions src/compiler/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete
using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView
using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned
using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual
using ..cuTile: IdentityOp, FloatIdentityOp, IntegerIdentityOp

end

Expand Down
Loading