tile scan ops related changes #22
Conversation
```julia
using Test
using CUDA
using cuTile
import cuTile as ct

# intra-tile scan
function cumsum_1d_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                          tile_size::ct.Constant{Int})
    bid = ct.bid(1)
    tile = ct.load(a, bid, (tile_size[],))
    result = ct.cumsum(tile, ct.axis(1))
    ct.store(b, bid, result)
    return nothing
end

sz = 32
N = 2^15
a = CUDA.rand(Float32, N)
b = CUDA.zeros(Float32, N)
CUDA.@sync ct.launch(cumsum_1d_kernel, cld(length(a), sz), a, b, ct.Constant(sz))
```
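Not part of the original example, but a quick CPU check for this intra-tile kernel would compare against each `sz`-length block scanned independently:

```julia
# Hypothetical per-tile reference check: cumsum each sz-length block
# on its own and compare with the kernel output (N is divisible by sz here).
a_cpu = collect(a)
ref = reduce(vcat, [cumsum(a_cpu[i:i+sz-1]) for i in 1:sz:length(a_cpu)])
@test isapprox(collect(b), ref)
```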
```julia
# This is meant to become a single-pass kernel, but this version is simpler
# than the memory-ordering one. The idea is to demonstrate the scan operation.

# CSDL phase 1: intra-tile scan + store tile sums
function cumsum_csdl_phase1(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                            tile_sums::ct.TileArray{Float32,1},
                            tile_size::ct.Constant{Int})
    bid = ct.bid(1)
    tile = ct.load(a, bid, (tile_size[],))
    result = ct.cumsum(tile, ct.axis(1))
    ct.store(b, bid, result)
    tile_sum = ct.extract(result, (tile_size[],), (1,))
    ct.store(tile_sums, bid, tile_sum)
    return nothing
end

# CSDL phase 2: decoupled lookback to accumulate previous tile sums
function cumsum_csdl_phase2(b::ct.TileArray{Float32,1},
                            tile_sums::ct.TileArray{Float32,1},
                            tile_size::ct.Constant{Int})
    bid = ct.bid(1)
    prev_sum = ct.zeros((tile_size[],), Float32)
    k = Int32(bid)
    while k > 1
        tile_sum_k = ct.load(tile_sums, (k,), (1,))
        prev_sum = prev_sum .+ tile_sum_k
        k -= Int32(1)
    end
    tile = ct.load(b, bid, (tile_size[],))
    result = tile .+ prev_sum
    ct.store(b, bid, result)
    return nothing
end

n = length(a)
num_tiles = cld(n, sz)
tile_sums = CUDA.zeros(Float32, num_tiles)
CUDA.@sync ct.launch(cumsum_csdl_phase1, num_tiles, a, b, tile_sums, ct.Constant(sz))
CUDA.@sync ct.launch(cumsum_csdl_phase2, num_tiles, b, tile_sums, ct.Constant(sz))
```
```julia
b_cpu = cumsum(a |> collect, dims=1)
@test isapprox(b |> collect, b_cpu)

using BenchmarkTools
@benchmark CUDA.@sync begin
    CUDA.@sync ct.launch(cumsum_csdl_phase1, num_tiles, a, b, tile_sums, ct.Constant(sz))
    CUDA.@sync ct.launch(cumsum_csdl_phase2, num_tiles, b, tile_sums, ct.Constant(sz))
end
```
Wanted to add modular sum and modular prod. Next PR, maybe.
…both float and integer types

- Add reduce_mul, reduce_min, reduce_and, reduce_or, reduce_xor intrinsic functions
- Remove AbstractFloat constraint from reduce_sum and reduce_max
- Add corresponding emit_intrinsic! handlers for new operations
- Add integer encode_reduce_body methods for and, or, xor operations

Summary of Additions

| Function | Symbol | Types |
|----------|--------|-------|
| `reduce_sum` | `:add` | Any |
| `reduce_max` | `:max` | Any |
| `reduce_mul` | `:mul` | Any |
| `reduce_min` | `:min` | Any |
| `reduce_and` | `:and` | Integer only |
| `reduce_or` | `:or` | Integer only |
| `reduce_xor` | `:xor` | Integer only |
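For illustration only, a sketch of how such a wrapper might be used in a kernel, mirroring the load/store pattern of the cumsum kernels above (the reduce signature and result shape here are assumptions, not confirmed by this PR):

```julia
# Hypothetical sketch: per-tile sum via the reduce_sum wrapper; the exact
# cuTile reduce API is assumed, only the load/store pattern is from the PR.
function tile_sum_kernel(a::ct.TileArray{Float32,1}, out::ct.TileArray{Float32,1},
                         tile_size::ct.Constant{Int})
    bid = ct.bid(1)
    tile = ct.load(a, bid, (tile_size[],))
    total = ct.reduce_sum(tile, ct.axis(1))  # assumed signature
    ct.store(out, bid, total)
    return nothing
end
```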
- Add axis(i::Integer) -> Val{i-1} convenience function
- Use it instead of raw Val for self-documenting axis selection
The axis convenience is a small helper function around `Val`. But I see reduce is already one-based; not sure if we should go with it. It doesn't harm anything, it's just a convenience.
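A minimal sketch of what the helper amounts to, assuming the backend wants zero-based axes:

```julia
# One-based, self-documenting axis selection that lowers to the
# zero-based Val the backend expects.
axis(i::Integer) = Val(i - 1)

# e.g. ct.cumsum(tile, ct.axis(1)) rather than ct.cumsum(tile, Val(0))
```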
…patch

- Add wrapper functions in operations.jl for reduce_mul, reduce_min, reduce_and, reduce_or, reduce_xor with appropriate type constraints
- Refactor identity value selection to use dispatch instead of if-else chain
- Correct identity values:
  - add: 0.0
  - max: -Inf (float) or 0 (int)
  - mul: 1.0
  - min: +Inf (float) or typemax(Int64) (int)
  - and: 0 (interpreted as -1 bits by backend)
  - or: 0
  - xor: 0
- Add IntIdentity struct to bytecode/writer.jl for proper integer identity encoding
- Add encode_tagged_int! function for encoding integer identity attributes (tag 0x01)
- Dispatch encode_identity! on identity type for proper encoding
- Update reduce_identity to return IntIdentity for integer operations
- Import ReduceIdentity, FloatIdentity, IntIdentity in intrinsics.jl

Identity values now properly typed:
- Float operations → FloatIdentity
- Integer operations → IntIdentity
The intrinsics were updated to support all types, but the wrapper functions in operations.jl still had a T <: AbstractFloat constraint, causing method lookup failures for integer types.
- reduce_sum, reduce_max, reduce_mul, reduce_min now use T <: Number
- Provides type safety while supporting all numeric types
- More self-documenting than an unconstrained T
- IntegerIdentity now has a signed::Bool field
- encode_tagged_int! encodes signed values with zigzag varint, unsigned with plain varint
- Add is_signed() helper that checks T <: SignedInteger
- Update all reduce_identity calls to pass is_signed(T)
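A minimal sketch of the tagged encoding these commits describe (the 0x01 tag and the zigzag/plain varint split are from the commit text; the buffer layout and signature are assumptions):

```julia
# Hypothetical shape of encode_tagged_int!: tag byte 0x01 followed by a
# zigzag varint for signed identities or a plain varint for unsigned ones.
function encode_tagged_int!(buf::Vector{UInt8}, x::Integer, signed::Bool)
    push!(buf, 0x01)                          # integer-identity attribute tag
    if signed
        encode_signed_varint!(buf, Int64(x))  # zigzag varint (see below)
    else
        encode_varint!(buf, UInt64(x))        # plain varint
    end
end
```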
- Abstract type now called OperationIdentity to reflect use by both reduce and scan operations
- FloatIdentity and IntegerIdentity now inherit from OperationIdentity
- Updated comments and docs to reflect the broader scope
- Updated import in intrinsics.jl
- IdentityOp: abstract type for binary operation identities
- FloatIdentityOp: concrete type for float identities
- IntegerIdentityOp: concrete type for integer identities (with signed field)
- Applied consistently across writer.jl, encodings.jl, intrinsics.jl, and core.jl
The check T <: Integer && !(T <: Unsigned) correctly identifies:
- Int32, Int64, etc. as signed (true)
- UInt32, UInt64, etc. as unsigned (false)
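Putting the pieces above together, a minimal sketch of the hierarchy and the signedness check (field types and exact definitions are assumptions):

```julia
# Sketch of the described type hierarchy; names from the commits above.
abstract type IdentityOp end

struct FloatIdentityOp <: IdentityOp
    value::Float64
end

struct IntegerIdentityOp <: IdentityOp
    value::Int64
    signed::Bool  # zigzag varint when true, plain varint when false
end

# The signedness check quoted above, as a helper:
is_signed(::Type{T}) where {T} = T <: Integer && !(T <: Unsigned)
```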
Ensures type-consistent encoding for Int8, Int16, etc.

intrinsics: use type-dependent identity values for reduce ops
- add: zero(T)
- max: typemin(T)
- mul: one(T)
- min: typemax(T)
- and: is_signed ? -1 : typemax(T) for proper bit representation
- or, xor: zero(T)

Fixes encoding error for UInt32 (9223372036854775807 does not fit in 32 bits).

Update core.jl: fix reduce_min identity to use typemax(T) instead of typemax(Int64)
- For UInt32, typemax(UInt32) = 4294967295 fits in 32 bits
- typemax(Int64) = 9223372036854775807 does not fit and caused the encoding error
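A sketch of what these type-dependent identities look like as dispatch (identity_value is a hypothetical name; the PR's reduce_identity additionally wraps results in FloatIdentityOp/IntegerIdentityOp, which is elided here):

```julia
# Per-op, per-type identity selection via dispatch. Note typemin/typemax
# also cover floats: typemin(Float32) == -Inf32, typemax(Float32) == Inf32.
identity_value(::Val{:add}, ::Type{T}) where {T<:Number}   = zero(T)
identity_value(::Val{:mul}, ::Type{T}) where {T<:Number}   = one(T)
identity_value(::Val{:max}, ::Type{T}) where {T<:Number}   = typemin(T)
identity_value(::Val{:min}, ::Type{T}) where {T<:Number}   = typemax(T)
identity_value(::Val{:and}, ::Type{T}) where {T<:Signed}   = T(-1)      # all bits set
identity_value(::Val{:and}, ::Type{T}) where {T<:Unsigned} = typemax(T) # all bits set
identity_value(::Val{:or},  ::Type{T}) where {T<:Integer}  = zero(T)
identity_value(::Val{:xor}, ::Type{T}) where {T<:Integer}  = zero(T)
```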
test: add comprehensive reduce operations tests
- Tests for all reduce ops: add, mul, min, max, and, or, xor
- Tests for Float32, Float64, Int32, UInt32, Int8 types
- Tests for axis 0 and axis 1 reductions
- Compares GPU results against CPU reference implementations
- Includes UInt32 and Int8 tests for the identity encoding fix
Used an agent to create the tests, hence the wrath.
Prepares for reuse by scan operations. The function is shape-agnostic and depends only on the operation type and element type.
Julia's reduce with dims= requires an explicit init for the &, |, and ⊻ operators. Use typemax(T) for AND (the identity, with all bits set).
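For example, on the CPU reference side (plain Julia; an unsigned element type is chosen so typemax really is all bits set):

```julia
# CPU reference reductions for the bitwise ops; as noted above, these
# operators need an explicit init when reducing along dims.
x = rand(UInt32, 4, 4)
and_ref = reduce(&, x; dims=1, init=typemax(UInt32))  # AND identity: all ones
or_ref  = reduce(|, x; dims=1, init=zero(UInt32))
xor_ref = reduce(⊻, x; dims=1, init=zero(UInt32))
```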
The original code tried to convert an Int64 directly to UInt64, which fails for negative values like typemin(Int32) = -2147483648. Zigzag encoding maps n to (n << 1) ⊻ (n >> 63), enabling proper encoding of negative integers in varint format.
Zigzag encoding, (n << 1) ⊻ (n >> 63), properly handles negative values like typemin(Int32) = -2147483648. Unsigned values use plain varint encoding, since they don't need zigzag.
The correct implementation is in src/bytecode/basic.jl:

```julia
function encode_signed_varint!(buf, x)
    # zigzag: maps n ≥ 0 to 2n and n < 0 to -2n - 1 before varint encoding
    x = x << 1
    if x < 0
        x = ~x
    end
    encode_varint!(buf, x)
end
```

The duplicate in writer.jl was shadowing the correct one.
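A standalone check of the zigzag mapping, independent of the buffer plumbing:

```julia
# Branchless form of the mapping quoted above, checked on the values
# from the discussion (reinterpret just exposes the unsigned result).
zigzag(n::Int64) = reinterpret(UInt64, (n << 1) ⊻ (n >> 63))
zigzag(Int64(typemin(Int32))) == 0x00000000ffffffff  # -2147483648 → 4294967295
zigzag(Int64(-1)) == 0x0000000000000001              # -1 → 1
zigzag(Int64(1))  == 0x0000000000000002              #  1 → 2
```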
For unsigned integer types like UInt16 and UInt32, the comparison must use unsigned signedness, not the default SignednessSigned. This fixes wrong reduction results for unsigned types, where signed comparison caused values to be interpreted incorrectly (e.g., 0xFFFF interpreted as -1).
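The effect is easy to reproduce on the CPU: the same bit pattern compares differently under the two interpretations.

```julia
# 0xFFFF under unsigned vs. signed interpretation:
reinterpret(Int16, 0xFFFF)                  # -1
max(0xFFFF, 0x0001)                         # 0xffff (unsigned compare, correct)
max(reinterpret(Int16, 0xFFFF), Int16(1))   # 1, because -1 < 1 signed
```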
This seems to conflict weirdly with #21? How did you develop this?
I branched out from #21's last commit ff8b77b and made changes for scan. What conflicts do you see? I made a silly force push in the tilereduce commit; that must be problematic. I can't think of any other reasons. Edit: Actually, I think I branched much earlier, merged the reduce branch into the tile-scan branch, and continued on that merge. I was working on those branches in parallel; it was not strictly linear work. I am not sure what you mean by "How did you develop this?", so I am trying to answer all the possibilities I can think of. In case you are asking about my setup: I don't have a compatible GPU, so I was launching 5060-5090 GPUs on vast.ai with a CUDA 13.1 filter applied for development.
Ah ok. You can change the target branch of this PR to make that clear; right now the diff is massive, with significant overlap with that other PR. Some of the conflicting changes I saw were
- We use the same IdentityOp abstraction
- Scan uses the same encodings as the reduce op, since they are mathematically the same ops
- Left room for dispatch (say, for scan) to take a custom path if needed
- Added encode_scan_body methods for :min and :max operations
- Float types use MinFOp/MaxFOp (no signedness needed)
- Integer types use MinIOp/MaxIOp with signedness parameter
- Aligns scan implementation with the existing reduce pattern
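A hypothetical sketch of that dispatch (the op names are from the commit; treating them as constructors and eliding the bytecode buffer are assumptions):

```julia
# Shape of the min/max scan-body dispatch described above; the real
# methods presumably write into the bytecode stream instead.
encode_scan_body(::Val{:min}, ::Type{T}) where {T<:AbstractFloat} = MinFOp()
encode_scan_body(::Val{:max}, ::Type{T}) where {T<:AbstractFloat} = MaxFOp()
encode_scan_body(::Val{:min}, ::Type{T}) where {T<:Integer} = MinIOp(is_signed(T))
encode_scan_body(::Val{:max}, ::Type{T}) where {T<:Integer} = MaxIOp(is_signed(T))
```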
#39 will be a simpler PR. Closing this.
Scan Operations - Pull Request Summary
Overview
This PR adds comprehensive scan (parallel prefix sum) operation support to the cuTile compiler, building on the existing reduce operations framework with shared abstractions.
Key Changes
1. Commit ca48190: Consolidate Common Identity Operations
- Unified IdentityOp abstraction between reduce and scan operations
- Reused the same encoding patterns, since mathematically both ops are identical
- Left room for future dispatch paths if scan needs custom handling

Files modified:
- src/bytecode/encodings.jl - Simplified encoding interface
- src/bytecode/writer.jl - Removed 96 lines of duplicate code
- src/compiler/intrinsics/core.jl - Unified identity handling

2. Commit 1840d16: Add Min/Max Support for Scan
- Added encode_scan_body methods for :min and :max operations
- Float types: MinFOp/MaxFOp (no signedness needed)
- Integer types: MinIOp/MaxIOp with signedness parameter
- Aligns scan implementation with existing reduce pattern

Files modified:
- src/compiler/intrinsics/core.jl - 8 new lines

3. Commit 2769c85: Simple Scan Kernel Example
- Added examples/scanKernel.jl demonstrating basic usage
- Shows how to construct and execute scan operations

Files added:
- examples/scanKernel.jl - 62 lines

Supported Operations (as of this PR)

| Op | Float | Integer |
|--------|----------|----------|
| `:add` | `AddFOp` | `AddIOp` |
| `:mul` | `MulFOp` | `MulIOp` |
| `:min` | `MinFOp` | `MinIOp` (with signedness) |
| `:max` | `MaxFOp` | `MaxIOp` (with signedness) |

Architectural Decisions
- OperationIdentity for dispatch
- is_signed(T) parameter
- encode_scan_body pattern