diff --git a/.gitignore b/.gitignore index d2c7fe1..daf5565 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ LocalPreferences.toml CLAUDE.md AGENTS.md TODO.md +__pycache__ diff --git a/examples/batchmatmul.jl b/examples/batchmatmul.jl index 89b2ba4..a9c2186 100644 --- a/examples/batchmatmul.jl +++ b/examples/batchmatmul.jl @@ -57,37 +57,89 @@ function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.T return nothing end -function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T - name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk") - println("--- $name ---") - - # Batch-last ordering for optimal column-major access - A = CUDA.rand(T, M, K, Batch) - B = CUDA.rand(T, K, N, Batch) - C = CUDA.zeros(T, M, N, Batch) +#============================================================================= + Example harness +=============================================================================# + +function prepare(; benchmark::Bool=false, + M::Int=benchmark ? 1024 : 256, + K::Int=benchmark ? 512 : 128, + N::Int=benchmark ? 2048 : 256, + Batch::Int=benchmark ? 8 : 4, + T::DataType=Float32) + return (; + A = CUDA.rand(T, M, K, Batch), + B = CUDA.rand(T, K, N, Batch), + C = CuArray{T}(undef, M, N, Batch), + M, K, N, Batch + ) +end - # 3D grid: (M_tiles, N_tiles, Batch) +function run(data; tm::Int=64, tn::Int=64, tk::Int=64, nruns::Int=1, warmup::Int=0) + (; A, B, C, M, N, Batch) = data grid = (cld(M, tm), cld(N, tn), Batch) - # Launch kernel - ct.launch(batch_matmul_kernel, grid, A, B, C, - ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + CUDA.@sync for _ in 1:warmup + ct.launch(batch_matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + end - # Verify result - compute batched matmul on CPU + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(batch_matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + push!(times, t * 1000) # ms + end + + return (; C, times) +end + +function verify(data, result) + (; A, B, M, N, Batch) = data A_cpu = Array(A) B_cpu = Array(B) expected = similar(A_cpu, M, N, Batch) for b in 1:Batch expected[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b] end - result = Array(C) + @assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))" +end - if isapprox(result, expected, rtol=1e-2, atol=1e-2) - println(" passed") - else - max_diff = maximum(abs.(result - expected)) - println(" FAILED (max diff: $max_diff)") +#============================================================================= + Reference implementations for benchmarking +=============================================================================# + +function run_others(data; nruns::Int=1, warmup::Int=0) + (; A, B, M, N, Batch) = data + results = Dict{String, Vector{Float64}}() + + C_cublas = similar(A, M, N, Batch) + + # cuBLAS batched gemm via CUBLAS.gemm_strided_batched! 
+ CUDA.@sync for _ in 1:warmup + CUDA.CUBLAS.gemm_strided_batched!('N', 'N', one(eltype(A)), A, B, zero(eltype(A)), C_cublas) end + times_cublas = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed CUDA.CUBLAS.gemm_strided_batched!('N', 'N', one(eltype(A)), A, B, zero(eltype(A)), C_cublas) + push!(times_cublas, t * 1000) + end + results["cuBLAS batched"] = times_cublas + + return results +end + +#============================================================================= + Main +=============================================================================# + +function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T + name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk") + println("--- $name ---") + data = prepare(; M, K, N, Batch, T) + result = run(data; tm, tn, tk) + verify(data, result) + println(" passed") end function main() diff --git a/examples/batchmatmul.py b/examples/batchmatmul.py new file mode 100644 index 0000000..07bce7e --- /dev/null +++ b/examples/batchmatmul.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Batch Matrix Multiplication example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct +from math import ceil + +@ct.kernel +def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): + """CuTile kernel for batch matrix multiplication + A has shape (Batch, M, K), B has shape (Batch, K, N) and C has shape (Batch, M, N) + Grid: (Batch, M_tiles, N_tiles) + """ + pid_batch = ct.bid(0) + bidx = ct.bid(1) + bidy = ct.bid(2) + + num_k_tiles = ct.cdiv(A.shape[2], tk) + accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32) + zero_pad = ct.PaddingMode.ZERO + + for k in range(num_k_tiles): + a = ct.load(A, index=(pid_batch, bidx, k), shape=(1, tm, tk), padding_mode=zero_pad) + a = ct.reshape(a, (tm, tk)) + + b = ct.load(B, index=(pid_batch, k, bidy), shape=(1, tk, tn), padding_mode=zero_pad) + b = ct.reshape(b, (tk, tn)) + + accumulator = ct.mma(a, b, acc=accumulator) + + result = ct.astype(accumulator, C.dtype) + result_3d = ct.reshape(result, (1, tm, tn)) + ct.store(C, index=(pid_batch, bidx, bidy), tile=result_3d) + + +#============================================================================= +# Example harness +#============================================================================= + +def prepare(*, benchmark: bool = False, Batch: int = None, M: int = None, K: int = None, N: int = None, dtype=np.float16): + """Allocate and initialize data for batch matmul.""" + if Batch is None: + Batch = 8 if benchmark else 4 + if M is None: + M = 1024 if benchmark else 256 + if K is None: + K = 512 if benchmark else 128 + if N is None: + N = 2048 if benchmark else 256 + return { + "A": cp.random.randn(Batch, M, K).astype(dtype), + "B": cp.random.randn(Batch, K, N).astype(dtype), + "C": cp.empty((Batch, M, N), dtype=dtype), + "Batch": Batch, + "M": M, + "K": K, + "N": N + } + + +def run(data, *, tm: int = 64, tn: int = 64, tk: int = 64, nruns: int = 1, warmup: int = 0): + """Run batch matmul kernel with timing.""" + A, B, C = data["A"], data["B"], data["C"] + Batch, M, N = data["Batch"], data["M"], data["N"] + + grid = (Batch, ceil(M / tm), ceil(N / tn)) + stream = cp.cuda.get_current_stream() + + # Warmup + for _ in range(warmup): + ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk)) + cp.cuda.runtime.deviceSynchronize() + + # Timed runs + times = [] + for _ in range(nruns): + start = 
cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk)) + end.record(stream) + end.synchronize() + times.append(cp.cuda.get_elapsed_time(start, end)) # ms + + return {"C": C, "times": times} + + +def verify(data, result): + """Verify batch matmul results.""" + A_np = cp.asnumpy(data["A"]).astype(np.float32) + B_np = cp.asnumpy(data["B"]).astype(np.float32) + C_np = cp.asnumpy(result["C"]).astype(np.float32) + Batch, M, N = data["Batch"], data["M"], data["N"] + + expected = np.zeros((Batch, M, N), dtype=np.float32) + for b in range(Batch): + expected[b] = A_np[b] @ B_np[b] + assert np.allclose(C_np, expected, rtol=1e-1, atol=1e-1), \ + f"batchmatmul incorrect! max diff: {np.max(np.abs(C_np - expected))}" + + +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + import torch + + results = {} + A_cp, B_cp = data["A"], data["B"] + Batch, M, N = data["Batch"], data["M"], data["N"] + + # PyTorch bmm + A_torch = torch.as_tensor(A_cp, device='cuda') + B_torch = torch.as_tensor(B_cp, device='cuda') + C_torch = torch.zeros(Batch, M, N, dtype=A_torch.dtype, device='cuda') + + for _ in range(warmup): + torch.bmm(A_torch, B_torch, out=C_torch) + torch.cuda.synchronize() + + times_torch = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.bmm(A_torch, B_torch, out=C_torch) + end.record() + torch.cuda.synchronize() + times_torch.append(start.elapsed_time(end)) + results["PyTorch bmm"] = times_torch + + return results + + +#============================================================================= +# Main +#============================================================================= + +def test_batchmatmul(Batch, M, K, N, tm, tn, tk, dtype=np.float16, name=None): + """Test batch matmul with given parameters.""" + name = name or f"batchmatmul ({Batch}x{M}x{K}) @ ({Batch}x{K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" + print(f"--- {name} ---") + data = prepare(Batch=Batch, M=M, K=K, N=N, dtype=dtype) + result = run(data, tm=tm, tn=tn, tk=tk) + verify(data, result) + print(" passed") + + +def main(): + print("--- cuTile Batch Matrix Multiplication Examples ---\n") + + test_batchmatmul(4, 256, 128, 256, 32, 32, 32, np.float32) + test_batchmatmul(4, 512, 256, 512, 64, 64, 64, np.float32) + test_batchmatmul(4, 512, 256, 1024, 128, 256, 64, np.float16) + + print("\n--- All batchmatmul examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/benchmarks.jl b/examples/benchmarks.jl index 95071e8..e626cce 100644 --- a/examples/benchmarks.jl +++ b/examples/benchmarks.jl @@ -1,14 +1,9 @@ # EXCLUDE FROM TESTING # -# Comprehensive benchmarks for cuTile.jl -# Compares: GPUArrays (generic), SIMT (CUDA.jl), cuTile -# Kernels: vadd, transpose, matmul +# Generic benchmark runner for cuTile.jl examples +# Discovers and benchmarks all examples in the examples/ directory using CUDA -using LinearAlgebra -using CUDA: GPUArrays -using FFTW -import cuTile as ct #============================================================================= Configuration @@ -17,19 +12,6 @@ import cuTile as ct const NRUNS = 10 const WARMUP = 3 -# Data sizes - large 
enough to saturate GPU and minimize launch overhead -const VADD_SIZE = 2^27 # 512 MB (128M elements) -const TRANSPOSE_DIM = 8192 # 8192x8192 = 268 MB -const MATMUL_DIM = 4096 # 4096x4096x4096 - -# Tile sizes -const VADD_TILE = 1024 -const TRANSPOSE_TILE_M = 64 -const TRANSPOSE_TILE_N = 64 -const MATMUL_TM = 64 -const MATMUL_TN = 64 -const MATMUL_TK = 64 - #============================================================================= Benchmark Utilities =============================================================================# @@ -40,797 +22,70 @@ struct BenchmarkResult mean_ms::Float64 end -function benchmark_kernel(f, nruns::Int=NRUNS, warmup::Int=WARMUP) - # Warmup - for _ in 1:warmup - f() - end - CUDA.synchronize() - - # Benchmark - times = Float64[] - for _ in 1:nruns - t = CUDA.@elapsed f() - push!(times, t * 1000) # Convert to ms - end - - return minimum(times), sum(times) / length(times) -end - -function print_table(title::String, results::Vector{BenchmarkResult}; extra_col=nothing) +function print_table(title::String, results::Vector{BenchmarkResult}) println() println("=" ^ 60) println(" ", title) println("=" ^ 60) - - if extra_col !== nothing - println(rpad("Implementation", 20), rpad("Min (ms)", 12), rpad("Mean (ms)", 12), extra_col[1]) - println("-" ^ 60) - for (i, r) in enumerate(results) - extra = extra_col[2][i] - println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12), - rpad(round(r.mean_ms, digits=3), 12), extra) - end - else - println(rpad("Implementation", 20), rpad("Min (ms)", 12), "Mean (ms)") - println("-" ^ 60) - for r in results - println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12), - round(r.mean_ms, digits=3)) - end - end + println(rpad("Implementation", 20), rpad("Min (ms)", 12), "Mean (ms)") println("-" ^ 60) -end - -#============================================================================= - Vector Addition -=============================================================================# - -# SIMT kernel -function vadd_simt_kernel!(a, b, c) - i = (blockIdx().x - 1) * blockDim().x + threadIdx().x - if i <= length(c) - @inbounds c[i] = a[i] + b[i] - end - return -end - -# cuTile kernel -function vadd_cutile_kernel(a, b, c, tile_size::ct.Constant{Int}) - pid = ct.bid(1) - tile_a = ct.load(a, pid, (tile_size[],)) - tile_b = ct.load(b, pid, (tile_size[],)) - result = tile_a + tile_b - ct.store(c, pid, result) - return -end - -function benchmark_vadd() - println("\nBenchmarking Vector Addition...") - println(" Size: $VADD_SIZE elements ($(VADD_SIZE * 4 / 1e6) MB)") - - a = CUDA.rand(Float32, VADD_SIZE) - b = CUDA.rand(Float32, VADD_SIZE) - c = similar(a) - expected = Array(a) .+ Array(b) - - results = BenchmarkResult[] - - # GPUArrays (broadcast) - gpuarrays_f = () -> begin - c .= a .+ b - end - gpuarrays_f() - CUDA.synchronize() - @assert Array(c) ≈ expected "GPUArrays incorrect!" - min_t, mean_t = benchmark_kernel(gpuarrays_f) - push!(results, BenchmarkResult("GPUArrays", min_t, mean_t)) - - # SIMT - threads = 1024 - blocks = cld(VADD_SIZE, threads) - simt_f = () -> @cuda threads=threads blocks=blocks vadd_simt_kernel!(a, b, c) - simt_f() - CUDA.synchronize() - @assert Array(c) ≈ expected "SIMT incorrect!" - min_t, mean_t = benchmark_kernel(simt_f) - push!(results, BenchmarkResult("SIMT (CUDA.jl)", min_t, mean_t)) - - # cuTile - grid = (cld(VADD_SIZE, VADD_TILE), 1, 1) - cutile_f = () -> ct.launch(vadd_cutile_kernel, grid, a, b, c, ct.Constant(VADD_TILE)) - cutile_f() - CUDA.synchronize() - @assert Array(c) ≈ expected "cuTile incorrect!" 
- min_t, mean_t = benchmark_kernel(cutile_f) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate bandwidth - bytes = 3 * VADD_SIZE * sizeof(Float32) # 2 reads + 1 write - bandwidths = [string(round(bytes / (r.min_ms / 1000) / 1e9, digits=1), " GB/s") for r in results] - - print_table("Vector Addition (Float32)", results; extra_col=("Bandwidth", bandwidths)) - return results -end - -#============================================================================= - Matrix Transpose -=============================================================================# - -# SIMT naive kernel -function transpose_simt_naive_kernel!(input, output, M, N) - i = (blockIdx().x - 1) * blockDim().x + threadIdx().x - j = (blockIdx().y - 1) * blockDim().y + threadIdx().y - if i <= M && j <= N - @inbounds output[j, i] = input[i, j] - end - return -end - -# SIMT shared memory kernel -function transpose_simt_shared_kernel!(input, output, M, N) - TILE = 32 - tile = CuStaticSharedArray(Float32, (TILE+1, TILE)) - - x = (blockIdx().x - 1) * TILE + threadIdx().x - y = (blockIdx().y - 1) * TILE + threadIdx().y - - if x <= M && y <= N - @inbounds tile[threadIdx().x, threadIdx().y] = input[x, y] - end - sync_threads() - - x = (blockIdx().y - 1) * TILE + threadIdx().x - y = (blockIdx().x - 1) * TILE + threadIdx().y - - if x <= N && y <= M - @inbounds output[x, y] = tile[threadIdx().y, threadIdx().x] - end - return -end - -# cuTile kernel -function transpose_cutile_kernel(input, output, tile_m::ct.Constant{Int}, tile_n::ct.Constant{Int}) - pid_m = ct.bid(1) - pid_n = ct.bid(2) - tile = ct.load(input, (pid_m, pid_n), (tile_m[], tile_n[])) - tile_t = ct.transpose(tile) - ct.store(output, (pid_n, pid_m), tile_t) - return -end - -function benchmark_transpose() - println("\nBenchmarking Matrix Transpose...") - M, N = TRANSPOSE_DIM, TRANSPOSE_DIM - println(" Size: $(M)x$(N) ($(M * N * 4 / 1e6) MB)") - - input = CUDA.rand(Float32, M, N) - output = CUDA.zeros(Float32, N, M) - expected = Array(permutedims(input, (2, 1))) - - results = BenchmarkResult[] - - # GPUArrays (permutedims) - gpuarrays_f = () -> permutedims!(output, input, (2, 1)) - gpuarrays_f() - CUDA.synchronize() - @assert Array(output) ≈ expected "GPUArrays incorrect!" - min_t, mean_t = benchmark_kernel(gpuarrays_f) - push!(results, BenchmarkResult("GPUArrays", min_t, mean_t)) - - # SIMT naive - fill!(output, 0) - threads_naive = (16, 16) - blocks_naive = (cld(M, 16), cld(N, 16)) - simt_naive_f = () -> @cuda threads=threads_naive blocks=blocks_naive transpose_simt_naive_kernel!(input, output, M, N) - simt_naive_f() - CUDA.synchronize() - @assert Array(output) ≈ expected "SIMT naive incorrect!" - min_t, mean_t = benchmark_kernel(simt_naive_f) - push!(results, BenchmarkResult("SIMT naive", min_t, mean_t)) - - # SIMT shared - fill!(output, 0) - threads_shared = (32, 32) - blocks_shared = (cld(M, 32), cld(N, 32)) - simt_shared_f = () -> @cuda threads=threads_shared blocks=blocks_shared transpose_simt_shared_kernel!(input, output, M, N) - simt_shared_f() - CUDA.synchronize() - @assert Array(output) ≈ expected "SIMT shared incorrect!" 
- min_t, mean_t = benchmark_kernel(simt_shared_f) - push!(results, BenchmarkResult("SIMT shared", min_t, mean_t)) - - # cuTile - fill!(output, 0) - grid = (cld(M, TRANSPOSE_TILE_M), cld(N, TRANSPOSE_TILE_N), 1) - cutile_f = () -> ct.launch(transpose_cutile_kernel, grid, input, output, - ct.Constant(TRANSPOSE_TILE_M), ct.Constant(TRANSPOSE_TILE_N)) - cutile_f() - CUDA.synchronize() - @assert Array(output) ≈ expected "cuTile incorrect!" - min_t, mean_t = benchmark_kernel(cutile_f) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate bandwidth - bytes = 2 * M * N * sizeof(Float32) # read + write - bandwidths = [string(round(bytes / (r.min_ms / 1000) / 1e9, digits=1), " GB/s") for r in results] - - print_table("Matrix Transpose (Float32)", results; extra_col=("Bandwidth", bandwidths)) - return results -end - -#============================================================================= - Matrix Multiplication -=============================================================================# - -# 2D swizzle for better L2 cache locality (using 0-indexed block IDs) -@inline function swizzle_2d(M, N, tm, tn, GROUP_SIZE_M, bid) - num_bid_m = cld(M, Int32(tm)) - num_bid_n = cld(N, Int32(tn)) - num_bid_in_group = Int32(GROUP_SIZE_M) * num_bid_n - group_id = fld(bid, num_bid_in_group) - first_bid_m = group_id * Int32(GROUP_SIZE_M) - group_size_m = min(num_bid_m - first_bid_m, Int32(GROUP_SIZE_M)) - bid_m = first_bid_m + rem(bid, group_size_m) - bid_n = fld(rem(bid, num_bid_in_group), group_size_m) - return bid_m, bid_n -end - -# cuTile matmul kernel with TF32 tensor cores -function matmul_cutile_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArray{T,2}, - tm::ct.Constant{Int}, tn::ct.Constant{Int}, tk::ct.Constant{Int}) where {T} - bid = ct.bid(1) - M = A.sizes[1] - N = B.sizes[2] - bid_m_0, bid_n_0 = swizzle_2d(M, N, tm[], tn[], 8, bid - Int32(1)) - bid_m = bid_m_0 + Int32(1) - bid_n = bid_n_0 + Int32(1) - - num_k = ct.num_tiles(A, 2, (tm[], tk[])) - acc = ct.full((tm[], tn[]), zero(Float32), Float32) - - # Use TF32 for tensor cores - dtype = T === Float32 ? ct.TFloat32 : T - - k = Int32(1) - while k <= num_k - a = ct.astype(ct.load(A, (bid_m, k), (tm[], tk[])), dtype) - b = ct.astype(ct.load(B, (k, bid_n), (tk[], tn[])), dtype) - acc = muladd(a, b, acc) - k += Int32(1) - end - - result = ct.astype(acc, T) - ct.store(C, (bid_m, bid_n), result) - return nothing -end - -function benchmark_matmul() - println("\nBenchmarking Matrix Multiplication...") - M, N, K = MATMUL_DIM, MATMUL_DIM, MATMUL_DIM - println(" Size: $(M)x$(K) * $(K)x$(N)") - - A = CUDA.rand(Float32, M, K) - B = CUDA.rand(Float32, K, N) - C = CUDA.zeros(Float32, M, N) - - # Reference result (cuBLAS) - C_ref = similar(C) - mul!(C_ref, A, B) - CUDA.synchronize() - - results = BenchmarkResult[] - flops = 2.0 * M * N * K - - # GPUArrays (generic matmul) - gpuarrays_f = () -> GPUArrays.generic_matmatmul!(C, A, B, one(Float32), zero(Float32)) - gpuarrays_f() - CUDA.synchronize() - @assert isapprox(Array(C), Array(C_ref), rtol=1e-2, atol=1e-2) "GPUArrays incorrect!" 
- min_t, mean_t = benchmark_kernel(gpuarrays_f) - push!(results, BenchmarkResult("GPUArrays", min_t, mean_t)) - - # cuBLAS - fill!(C, 0) - cublas_f = () -> mul!(C, A, B) - cublas_f() - CUDA.synchronize() - min_t, mean_t = benchmark_kernel(cublas_f) - push!(results, BenchmarkResult("cuBLAS", min_t, mean_t)) - - # cuTile - fill!(C, 0) - grid_m = cld(M, MATMUL_TM) - grid_n = cld(N, MATMUL_TN) - grid = (grid_m * grid_n, 1, 1) - cutile_f = () -> ct.launch(matmul_cutile_kernel, grid, A, B, C, - ct.Constant(MATMUL_TM), ct.Constant(MATMUL_TN), ct.Constant(MATMUL_TK)) - cutile_f() - CUDA.synchronize() - @assert isapprox(Array(C), Array(C_ref), rtol=1e-2, atol=1e-2) "cuTile incorrect!" - min_t, mean_t = benchmark_kernel(cutile_f) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate TFLOPS - tflops_vals = [string(round(flops / (r.min_ms * 1e-3) / 1e12, digits=2), " TFLOPS") for r in results] - - print_table("Matrix Multiplication (Float32, TF32 cores)", results; extra_col=("Performance", tflops_vals)) - return results -end - -#============================================================================= - Layer Normalization -=============================================================================# - -const LAYERNORM_M = 4096 -const LAYERNORM_N = 4096 -const LAYERNORM_TILE_N = 1024 -const LAYERNORM_EPS = 1f-5 - -# Batch matmul sizes -const BATCHMATMUL_BATCH = 8 -const BATCHMATMUL_M = 1024 -const BATCHMATMUL_K = 512 -const BATCHMATMUL_N = 2048 -const BATCHMATMUL_TM = 128 -const BATCHMATMUL_TN = 256 -const BATCHMATMUL_TK = 64 - -# FFT sizes -# Tile size is (D, BS, N2D), limited by tileiras compiler. -# Current kernel loads all batches per block, limiting scalability. -const FFT_BATCH = 64 -const FFT_SIZE = 512 -const FFT_FACTORS = (8, 8, 8) -const FFT_ATOM_PACKING_DIM = 2 - -# SIMT naive kernel (2-pass: compute mean/var, then normalize) -function layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, eps) - m = blockIdx().x - - # First pass: compute mean - mean_acc = 0.0f0 - for i in 1:N - @inbounds mean_acc += X[m, i] + for r in results + println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12), + round(r.mean_ms, digits=3)) end - mean = mean_acc / N - @inbounds Mean[m] = mean - - # Second pass: compute variance - var_acc = 0.0f0 - for i in 1:N - @inbounds diff = X[m, i] - mean - var_acc += diff * diff - end - var = var_acc / N - rstd = 1.0f0 / sqrt(var + eps) - @inbounds Rstd[m] = rstd - - # Third pass: normalize and apply affine - for i in 1:N - @inbounds Y[m, i] = (X[m, i] - mean) * rstd * W[i] + B[i] - end - - return -end - -# cuTile kernel (from layernorm.jl) -function layernorm_cutile_kernel(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1}, - B::ct.TileArray{Float32, 1}, Y::ct.TileArray{Float32, 2}, - Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1}, - eps::ct.Constant{Float32}, TILE_N::ct.Constant{Int}) - bid_m = ct.bid(1) - num_tiles = ct.num_tiles(X, 2, (1, TILE_N[])) - N = X.sizes[2] - - # Compute mean - mean = ct.full((1, TILE_N[]), 0.0f0, Float32) - j = Int32(1) - while j <= num_tiles - tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero) - mean = mean .+ tx - j += Int32(1) - end - mean = ct.reduce_sum(mean, 2) / N - ct.store(Mean, bid_m, mean) - - # Compute variance - var = ct.full((1, TILE_N[]), 0.0f0, Float32) - j = Int32(1) - while j <= num_tiles - tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero) - # Mask for valid elements - mask = ct.broadcast_to(((j - Int32(1)) * Int32(TILE_N[]) 
.+ ct.arange((TILE_N[],), Int32)) .<= N, (1, TILE_N[])) - centered_tx = ct.where(mask, tx .- mean, ct.full((1, TILE_N[]), 0.0f0, Float32)) - var = var .+ (centered_tx .^ 2.0f0) - j += Int32(1) - end - var = ct.reduce_sum(var, 2) / N - rstd = 1.0f0 ./ sqrt.(var .+ eps[]) - ct.store(Rstd, bid_m, rstd) - - # Normalize and apply affine transformation - j = Int32(1) - while j <= num_tiles - tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero) - tw = ct.load(W, j, (TILE_N[],); padding_mode=ct.PaddingMode.Zero) - tb = ct.load(B, j, (TILE_N[],); padding_mode=ct.PaddingMode.Zero) - ty = (tx .- mean) .* rstd - ty = ty .* tw .+ tb - ct.store(Y, (bid_m, j), ty) - j += Int32(1) - end - - return -end - -function benchmark_layernorm() - println("\nBenchmarking Layer Normalization...") - M, N = LAYERNORM_M, LAYERNORM_N - println(" Size: $(M)x$(N) ($(M * N * 4 / 1e6) MB)") - - X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N) - W = CUDA.randn(Float32, N) - B = CUDA.randn(Float32, N) - Y = CUDA.zeros(Float32, M, N) - Mean = CUDA.zeros(Float32, M) - Rstd = CUDA.zeros(Float32, M) - - # Reference result - X_cpu = Array(X) - W_cpu = Array(W) - B_cpu = Array(B) - expected_mean = vec(sum(X_cpu, dims=2) ./ N) - expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) - expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ LAYERNORM_EPS) - normalized = (X_cpu .- expected_mean) .* expected_rstd - expected_Y = normalized .* W_cpu' .+ B_cpu' - - results = BenchmarkResult[] - - # SIMT naive (single thread per row) - fill!(Y, 0); fill!(Mean, 0); fill!(Rstd, 0) - simt_f = () -> @cuda threads=1 blocks=M layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, LAYERNORM_EPS) - simt_f() - CUDA.synchronize() - @assert isapprox(Array(Y), expected_Y, rtol=1e-2, atol=1e-2) "SIMT incorrect!" - min_t, mean_t = benchmark_kernel(simt_f) - push!(results, BenchmarkResult("SIMT naive", min_t, mean_t)) - - # cuTile - fill!(Y, 0); fill!(Mean, 0); fill!(Rstd, 0) - cutile_f = () -> ct.launch(layernorm_cutile_kernel, M, X, W, B, Y, Mean, Rstd, - ct.Constant(LAYERNORM_EPS), ct.Constant(LAYERNORM_TILE_N)) - cutile_f() - CUDA.synchronize() - @assert isapprox(Array(Y), expected_Y, rtol=1e-2, atol=1e-2) "cuTile incorrect!" 
- min_t, mean_t = benchmark_kernel(cutile_f) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y) - bytes = (3 * M * N + N + N + M * N) * sizeof(Float32) - bandwidths = [string(round(bytes / (r.min_ms / 1000) / 1e9, digits=1), " GB/s") for r in results] - - print_table("Layer Normalization (Float32)", results; extra_col=("Bandwidth", bandwidths)) - return results + println("-" ^ 60) end #============================================================================= - Batch Matrix Multiplication + Benchmark Discovery & Execution =============================================================================# -# Batch matmul kernel (3D arrays with batch-last ordering) -# A: (M, K, Batch), B: (K, N, Batch), C: (M, N, Batch) -function batchmatmul_cutile_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.TileArray{T,3}, - tm::ct.Constant{Int}, tn::ct.Constant{Int}, - tk::ct.Constant{Int}) where {T} - bid_m = ct.bid(1) - bid_n = ct.bid(2) - pid_batch = ct.bid(3) - - K = A.sizes[2] - num_k = cld(K, Int32(tk[])) - - acc = ct.full((tm[], tn[]), zero(Float32), Float32) - - k = Int32(1) - while k <= num_k - a = ct.load(A, (bid_m, k, pid_batch), (tm[], tk[], 1); - padding_mode=ct.PaddingMode.Zero) - b = ct.load(B, (k, bid_n, pid_batch), (tk[], tn[], 1); - padding_mode=ct.PaddingMode.Zero) - - a_2d = ct.reshape(a, (tm[], tk[])) - b_2d = ct.reshape(b, (tk[], tn[])) - - if T === Float32 - a_2d = convert(ct.Tile{ct.TFloat32}, a_2d) - b_2d = convert(ct.Tile{ct.TFloat32}, b_2d) - end - - acc = muladd(a_2d, b_2d, acc) - k += Int32(1) +function discover_benchmarks() + examples = String[] + for file in readdir(@__DIR__) + endswith(file, ".jl") || continue + file == "benchmarks.jl" && continue + name = replace(file, ".jl" => "") + push!(examples, name) end - - result = convert(ct.Tile{T}, acc) - result_3d = ct.reshape(result, (tm[], tn[], 1)) - ct.store(C, (bid_m, bid_n, pid_batch), result_3d) - - return nothing + return sort(examples) end -function benchmark_batchmatmul() - println("\nBenchmarking Batch Matrix Multiplication...") - Batch, M, K, N = BATCHMATMUL_BATCH, BATCHMATMUL_M, BATCHMATMUL_K, BATCHMATMUL_N - println(" Size: ($M x $K x $Batch) @ ($K x $N x $Batch), Float16") +function run_benchmark(name::String) + file = joinpath(@__DIR__, name * ".jl") - # Batch-last ordering for optimal column-major access - A = CUDA.rand(Float16, M, K, Batch) - B = CUDA.rand(Float16, K, N, Batch) - C = CUDA.zeros(Float16, M, N, Batch) + # Include file in anonymous module to avoid polluting namespace + mod = Module() + Base.include(mod, file) - # Reference result (batched matmul on CPU) - A_cpu = Float32.(Array(A)) - B_cpu = Float32.(Array(B)) - C_ref = zeros(Float32, M, N, Batch) - for b in 1:Batch - C_ref[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b] - end + # Check required functions exist (unprefixed) + isdefined(mod, :prepare) || return nothing + isdefined(mod, :run) || return nothing - results = BenchmarkResult[] - flops = 2.0 * Batch * M * N * K - - # cuBLAS batched gemm (via loop) - fill!(C, 0) - cublas_f = () -> begin - for b in 1:Batch - mul!(view(C, :, :, b), view(A, :, :, b), view(B, :, :, b)) - end - end - cublas_f() - CUDA.synchronize() - @assert isapprox(Float32.(Array(C)), C_ref, rtol=1e-1, atol=1e-1) "cuBLAS incorrect!" 
- min_t, mean_t = benchmark_kernel(cublas_f) - push!(results, BenchmarkResult("cuBLAS (loop)", min_t, mean_t)) - - # cuTile - fill!(C, 0) - grid = (cld(M, BATCHMATMUL_TM), cld(N, BATCHMATMUL_TN), Batch) - cutile_f = () -> ct.launch(batchmatmul_cutile_kernel, grid, A, B, C, - ct.Constant(BATCHMATMUL_TM), ct.Constant(BATCHMATMUL_TN), - ct.Constant(BATCHMATMUL_TK)) - cutile_f() - CUDA.synchronize() - @assert isapprox(Float32.(Array(C)), C_ref, rtol=1e-1, atol=1e-1) "cuTile incorrect!" - min_t, mean_t = benchmark_kernel(cutile_f) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate TFLOPS - tflops_vals = [string(round(flops / (r.min_ms * 1e-3) / 1e12, digits=2), " TFLOPS") for r in results] - - print_table("Batch Matrix Multiplication (Float16)", results; extra_col=("Performance", tflops_vals)) - return results -end - -#============================================================================= - FFT (3-stage Cooley-Tukey) - Column-Major Version -=============================================================================# + # Prepare data with benchmark=true for larger sizes + data = mod.prepare(; benchmark=true) -# FFT kernel - 3-stage Cooley-Tukey decomposition (column-major) -# Uses swapped dimensions and right-multiply for column-major compatibility. -# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving. -function fft_kernel( - x_packed_in::ct.TileArray{Float32, 3}, - y_packed_out::ct.TileArray{Float32, 3}, - W0::ct.TileArray{Float32, 3}, - W1::ct.TileArray{Float32, 3}, - W2::ct.TileArray{Float32, 3}, - T0::ct.TileArray{Float32, 3}, - T1::ct.TileArray{Float32, 3}, - n_const::ct.Constant{Int}, - f0_const::ct.Constant{Int}, - f1_const::ct.Constant{Int}, - f2_const::ct.Constant{Int}, - f0f1_const::ct.Constant{Int}, - f1f2_const::ct.Constant{Int}, - f0f2_const::ct.Constant{Int}, - bs_const::ct.Constant{Int}, - d_const::ct.Constant{Int}, - n2d_const::ct.Constant{Int} -) - N = n_const[] - F0 = f0_const[] - F1 = f1_const[] - F2 = f2_const[] - F0F1 = f0f1_const[] - F1F2 = f1f2_const[] - F0F2 = f0f2_const[] - BS = bs_const[] - D = d_const[] - N2D = n2d_const[] + # Run cuTile + result = mod.run(data; nruns=NRUNS, warmup=WARMUP) - bid = ct.bid(1) - - # Load input (D, BS, N2D) and reshape to (2, BS, N) - X_ri = ct.reshape(ct.load(x_packed_in, (1, bid, 1), (D, BS, N2D)), (2, BS, N)) - X_r = ct.reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0)) - X_i = ct.reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0)) - - # Load DFT matrices - W0_ri = ct.reshape(ct.load(W0, (1, 1, 1), (F0, F0, 2)), (F0, F0, 2)) - W0_r = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0)) - W0_i = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0)) - - W1_ri = ct.reshape(ct.load(W1, (1, 1, 1), (F1, F1, 2)), (F1, F1, 2)) - W1_r = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1)) - W1_i = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1)) - - W2_ri = ct.reshape(ct.load(W2, (1, 1, 1), (F2, F2, 2)), (F2, F2, 2)) - W2_r = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2)) - W2_i = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2)) - - # Load twiddle factors (column-major layout) - T0_ri = ct.reshape(ct.load(T0, (1, 1, 1), (F1F2, F0, 2)), (F1F2, F0, 2)) - T0_r = 
ct.reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N)) - T0_i = ct.reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N)) - - T1_ri = ct.reshape(ct.load(T1, (1, 1, 1), (F0F2, F1, 2)), (F0F2, F1, 2)) - T1_r = ct.reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1)) - T1_i = ct.reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1)) - - # Stage 0: F0-point DFT via right-multiply - X_r_ = X_r * W0_r - X_i * W0_i - X_i_ = X_r * W0_i + X_i * W0_r - - # Twiddle & Permute 0 - X_r_flat = ct.reshape(X_r_, (BS, N)) - X_i_flat = ct.reshape(X_i_, (BS, N)) - X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat - X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat - - X_r3 = ct.reshape(X_r2, (BS, F2, F1, F0)) - X_i3 = ct.reshape(X_i2, (BS, F2, F1, F0)) - X_r4 = ct.permute(X_r3, (1, 2, 4, 3)) - X_i4 = ct.permute(X_i3, (1, 2, 4, 3)) - X_r5 = ct.reshape(X_r4, (BS, F0F2, F1)) - X_i5 = ct.reshape(X_i4, (BS, F0F2, F1)) - - # Stage 1: F1-point DFT - X_r6 = X_r5 * W1_r - X_i5 * W1_i - X_i6 = X_r5 * W1_i + X_i5 * W1_r - - # Twiddle & Permute 1 - X_r_flat2 = ct.reshape(X_r6, (BS, N)) - X_i_flat2 = ct.reshape(X_i6, (BS, N)) - X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2 - X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2 - - X_r8 = ct.reshape(X_r7, (BS, F2, F0, F1)) - X_i8 = ct.reshape(X_i7, (BS, F2, F0, F1)) - X_r9 = ct.permute(X_r8, (1, 3, 4, 2)) - X_i9 = ct.permute(X_i8, (1, 3, 4, 2)) - X_r10 = ct.reshape(X_r9, (BS, F0F1, F2)) - X_i10 = ct.reshape(X_i9, (BS, F0F1, F2)) - - # Stage 2: F2-point DFT - X_r11 = X_r10 * W2_r - X_i10 * W2_i - X_i11 = X_r10 * W2_i + X_i10 * W2_r - - # Final output - X_r_final = ct.reshape(X_r11, (1, BS, N)) - X_i_final = ct.reshape(X_i11, (1, BS, N)) - - # Concatenate and Store - Y_ri = ct.reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D)) - ct.store(y_packed_out, (1, bid, 1), Y_ri) - - return -end - -# Helper: Generate DFT matrix -function fft_dft_matrix(size::Int) - W = zeros(ComplexF32, size, size) - for i in 0:size-1, j in 0:size-1 - W[i+1, j+1] = exp(-2π * im * i * j / size) + # Extract times (handle times_fwd/times_bwd for layernorm) + if hasproperty(result, :times) + results = Dict{String, Vector{Float64}}("cuTile" => result.times) + elseif hasproperty(result, :times_fwd) + results = Dict{String, Vector{Float64}}( + "cuTile Fwd" => result.times_fwd, + "cuTile Bwd" => result.times_bwd + ) + else + return nothing end - result = zeros(Float32, size, size, 2) - result[:, :, 1] = Float32.(real.(W)) - result[:, :, 2] = Float32.(imag.(W)) - return result -end -# Twiddle factors T0 for column-major layout (F1F2, F0) -function fft_make_twiddles_T0(F0::Int, F1F2::Int, N::Int) - T0 = zeros(Float32, F1F2, F0, 2) - for j in 0:F1F2-1, i in 0:F0-1 - val = exp(-2π * im * i * j / N) - T0[j+1, i+1, 1] = Float32(real(val)) - T0[j+1, i+1, 2] = Float32(imag(val)) + # Run others if available + if isdefined(mod, :run_others) + others = mod.run_others(data; nruns=NRUNS, warmup=WARMUP) + merge!(results, others) end - return T0 -end -# Twiddle factors T1 for column-major layout (F0F2, F1) -function fft_make_twiddles_T1(F0::Int, F1::Int, F2::Int) - F0F2 = F0 * F2 - F1F2 = F1 * F2 - T1 = zeros(Float32, F0F2, F1, 2) - for k in 0:F0F2-1, j in 0:F1-1 - f2 = k % F2 - val = exp(-2π * im * j * f2 / F1F2) - T1[k+1, j+1, 1] = Float32(real(val)) - T1[k+1, j+1, 2] = Float32(imag(val)) - end - return T1 -end - -function fft_make_twiddles(factors::NTuple{3, Int}) - F0, F1, F2 = factors - N = F0 * F1 * F2 - F1F2 = F1 * F2 - W0 = fft_dft_matrix(F0) - W1 = fft_dft_matrix(F1) - W2 = 
fft_dft_matrix(F2) - T0 = fft_make_twiddles_T0(F0, F1F2, N) - T1 = fft_make_twiddles_T1(F0, F1, F2) - return (W0, W1, W2, T0, T1) -end - -function benchmark_fft() - println("\nBenchmarking FFT...") - BS, N = FFT_BATCH, FFT_SIZE - F0, F1, F2 = FFT_FACTORS - D = FFT_ATOM_PACKING_DIM - println(" Size: $BS batches × $N FFT ($(BS * N * 8 / 1e6) MB)") - - # Create complex input - CUDA.seed!(42) - input = CUDA.randn(ComplexF32, BS, N) - - # Reference result (FFTW) - reference = FFTW.fft(Array(input), 2) - - results = BenchmarkResult[] - - # Pre-compute twiddles (one-time CPU cost) - W0, W1, W2, T0, T1 = fft_make_twiddles(FFT_FACTORS) - W0_gpu, W1_gpu, W2_gpu = CuArray(W0), CuArray(W1), CuArray(W2) - T0_gpu, T1_gpu = CuArray(T0), CuArray(T1) - - # Pre-pack input (zero-copy) - N2D = N * 2 ÷ D - x_packed = reinterpret(reshape, Float32, input) - y_packed = CUDA.zeros(Float32, D, BS, N2D) - - # Kernel launch parameters - F0F1, F1F2, F0F2 = F0 * F1, F1 * F2, F0 * F2 - grid = (BS, 1, 1) - - # Kernel-only timing function - cutile_kernel_f = () -> ct.launch(fft_kernel, grid, - x_packed, y_packed, - W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, - ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), - ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), - ct.Constant(BS), ct.Constant(D), ct.Constant(N2D)) - - # Verify correctness - cutile_kernel_f() - CUDA.synchronize() - y_complex = reinterpret(reshape, ComplexF32, y_packed) - output = copy(y_complex) - @assert isapprox(Array(output), reference, rtol=1e-3) "cuTile FFT incorrect!" - - # Benchmark kernel only - min_t, mean_t = benchmark_kernel(cutile_kernel_f) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Performance metric: GFLOPS (5 * N * log2(N) per complex FFT) - flops_per_fft = 5.0 * N * log2(N) - total_flops = BS * flops_per_fft - gflops = [string(round(total_flops / (r.min_ms * 1e-3) / 1e9, digits=1), " GFLOPS") for r in results] - - print_table("FFT (ComplexF32)", results; extra_col=("Performance", gflops)) return results end @@ -840,20 +95,35 @@ end function main() println("=" ^ 60) - println(" cuTile.jl Comprehensive Benchmarks") + println(" cuTile.jl Benchmarks") println("=" ^ 60) println() println("Configuration:") println(" Runs: $NRUNS (+ $WARMUP warmup)") println(" GPU: ", CUDA.name(CUDA.device())) - println() - vadd_results = benchmark_vadd() - transpose_results = benchmark_transpose() - matmul_results = benchmark_matmul() - layernorm_results = benchmark_layernorm() - batchmatmul_results = benchmark_batchmatmul() - fft_results = benchmark_fft() + for name in discover_benchmarks() + println("\nBenchmarking $name...") + + results = run_benchmark(name) + if results === nothing + println(" (skipped - no prepare/run functions)") + continue + end + + # Convert to BenchmarkResult for printing + benchmark_results = BenchmarkResult[] + for (impl_name, times) in results + min_t = minimum(times) + mean_t = sum(times) / length(times) + push!(benchmark_results, BenchmarkResult(impl_name, min_t, mean_t)) + end + + # Sort by min time + sort!(benchmark_results, by=r -> r.min_ms) + + print_table(name, benchmark_results) + end println() println("=" ^ 60) diff --git a/examples/benchmarks.py b/examples/benchmarks.py index fa8ac43..5ff588a 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -1,16 +1,12 @@ #!/usr/bin/env python3 -""" -Comprehensive benchmarks for cuTile Python -Compares: CuPy, PyTorch, cuTile -Kernels: vadd, transpose, matmul -""" +# EXCLUDE FROM TESTING +# +# Generic benchmark runner for cuTile 
Python examples +# Discovers and benchmarks all examples in the examples/ directory +import os +import importlib.util import cupy as cp -import numpy as np -import torch -import cuda.tile as ct -import math -from math import ceil, log2 #============================================================================= # Configuration @@ -19,25 +15,6 @@ NRUNS = 10 WARMUP = 3 -# Data sizes - large enough to saturate GPU and minimize launch overhead -VADD_SIZE = 2**27 # 512 MB (128M elements) -TRANSPOSE_DIM = 8192 # 8192x8192 = 268 MB -MATMUL_DIM = 4096 # 4096x4096x4096 - -# FFT sizes - must match Julia configuration -FFT_BATCH = 64 -FFT_SIZE = 512 -FFT_FACTORS = (8, 8, 8) -FFT_ATOM_PACKING_DIM = 2 - -# Tile sizes -VADD_TILE = 1024 -TRANSPOSE_TILE_M = 64 -TRANSPOSE_TILE_N = 64 -MATMUL_TM = 64 -MATMUL_TN = 64 -MATMUL_TK = 64 - #============================================================================= # Benchmark Utilities #============================================================================= @@ -49,684 +26,76 @@ def __init__(self, name: str, min_ms: float, mean_ms: float): self.mean_ms = mean_ms -def benchmark_cupy(f, nruns: int = NRUNS, warmup: int = WARMUP): - """Benchmark a CuPy function using CUDA events.""" - stream = cp.cuda.get_current_stream() - - # Warmup - for _ in range(warmup): - f() - cp.cuda.runtime.deviceSynchronize() - - # Benchmark - times = [] - for _ in range(nruns): - start = cp.cuda.Event() - end = cp.cuda.Event() - - start.record(stream) - f() - end.record(stream) - end.synchronize() - - elapsed_ms = cp.cuda.get_elapsed_time(start, end) - times.append(elapsed_ms) - - return min(times), sum(times) / len(times) - - -def benchmark_torch(f, nruns: int = NRUNS, warmup: int = WARMUP): - """Benchmark a PyTorch function using CUDA events.""" - # Warmup - for _ in range(warmup): - f() - torch.cuda.synchronize() - - # Benchmark - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - times = [] - for _ in range(nruns): - start_event.record() - f() - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - times.append(elapsed_ms) - - return min(times), sum(times) / len(times) - - -def print_table(title: str, results: list, extra_col=None): +def print_table(title: str, results: list): """Print formatted benchmark results table.""" print() print("=" * 60) print(f" {title}") print("=" * 60) - - if extra_col: - print(f"{'Implementation':<20}{'Min (ms)':<12}{'Mean (ms)':<12}{extra_col[0]}") - print("-" * 60) - for i, r in enumerate(results): - print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:<12.3f}{extra_col[1][i]}") - else: - print(f"{'Implementation':<20}{'Min (ms)':<12}Mean (ms)") - print("-" * 60) - for r in results: - print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:.3f}") + print(f"{'Implementation':<20}{'Min (ms)':<12}Mean (ms)") + print("-" * 60) + for r in results: + print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:.3f}") print("-" * 60) #============================================================================= -# Vector Addition -#============================================================================= - -@ct.kernel -def vadd_cutile_kernel(a, b, c, tile_size: ct.Constant[int]): - pid = ct.bid(0) - tile_a = ct.load(a, index=(pid,), shape=(tile_size,)) - tile_b = ct.load(b, index=(pid,), shape=(tile_size,)) - result = tile_a + tile_b - ct.store(c, index=(pid,), tile=result) - - -def benchmark_vadd(): - print("\nBenchmarking Vector Addition...") - print(f" Size: 
{VADD_SIZE} elements ({VADD_SIZE * 4 / 1e6} MB)") - - # CuPy arrays - a_cp = cp.random.rand(VADD_SIZE).astype(np.float32) - b_cp = cp.random.rand(VADD_SIZE).astype(np.float32) - c_cp = cp.zeros(VADD_SIZE, dtype=np.float32) - - # PyTorch tensors (from same data) - a_torch = torch.as_tensor(a_cp, device='cuda') - b_torch = torch.as_tensor(b_cp, device='cuda') - c_torch = torch.zeros(VADD_SIZE, dtype=torch.float32, device='cuda') - - expected = cp.asnumpy(a_cp) + cp.asnumpy(b_cp) - results = [] - - # CuPy - def cupy_vadd(): - cp.add(a_cp, b_cp, out=c_cp) - - cupy_vadd() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(c_cp), expected), "CuPy incorrect!" - min_t, mean_t = benchmark_cupy(cupy_vadd) - results.append(BenchmarkResult("CuPy", min_t, mean_t)) - - # PyTorch - def torch_vadd(): - torch.add(a_torch, b_torch, out=c_torch) - - torch_vadd() - torch.cuda.synchronize() - assert np.allclose(c_torch.cpu().numpy(), expected), "PyTorch incorrect!" - min_t, mean_t = benchmark_torch(torch_vadd) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # cuTile - grid = (ct.cdiv(VADD_SIZE, VADD_TILE), 1, 1) - stream = cp.cuda.get_current_stream() - c_cp.fill(0) - - def cutile_vadd(): - ct.launch(stream, grid, vadd_cutile_kernel, (a_cp, b_cp, c_cp, VADD_TILE)) - - cutile_vadd() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(c_cp), expected), "cuTile incorrect!" - min_t, mean_t = benchmark_cupy(cutile_vadd) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate bandwidth - bytes_transferred = 3 * VADD_SIZE * 4 # 2 reads + 1 write, float32 - bandwidths = [f"{bytes_transferred / (r.min_ms / 1000) / 1e9:.1f} GB/s" for r in results] - - print_table("Vector Addition (Float32)", results, extra_col=("Bandwidth", bandwidths)) - return results - - -#============================================================================= -# Matrix Transpose -#============================================================================= - -@ct.kernel -def transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct.Constant[int]): - pid_m = ct.bid(0) - pid_n = ct.bid(1) - tile = ct.load(input, index=(pid_m, pid_n), shape=(tile_m, tile_n)) - tile_t = ct.transpose(tile) - ct.store(output, index=(pid_n, pid_m), tile=tile_t) - - -def benchmark_transpose(): - print("\nBenchmarking Matrix Transpose...") - M, N = TRANSPOSE_DIM, TRANSPOSE_DIM - print(f" Size: {M}x{N} ({M * N * 4 / 1e6} MB)") - - # CuPy arrays - input_cp = cp.random.rand(M, N).astype(np.float32) - output_cp = cp.zeros((N, M), dtype=np.float32) - - # PyTorch tensors - input_torch = torch.as_tensor(input_cp, device='cuda') - output_torch = torch.zeros(N, M, dtype=torch.float32, device='cuda') - - expected = cp.asnumpy(input_cp).T - results = [] - - # CuPy - def cupy_transpose(): - cp.copyto(output_cp, input_cp.T) - - cupy_transpose() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(output_cp), expected), "CuPy incorrect!" - min_t, mean_t = benchmark_cupy(cupy_transpose) - results.append(BenchmarkResult("CuPy", min_t, mean_t)) - - # PyTorch - output_torch.fill_(0) - def torch_transpose(): - output_torch.copy_(input_torch.T) - - torch_transpose() - torch.cuda.synchronize() - assert np.allclose(output_torch.cpu().numpy(), expected), "PyTorch incorrect!" 
- min_t, mean_t = benchmark_torch(torch_transpose) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # cuTile - output_cp.fill(0) - grid = (ct.cdiv(M, TRANSPOSE_TILE_M), ct.cdiv(N, TRANSPOSE_TILE_N), 1) - stream = cp.cuda.get_current_stream() - - def cutile_transpose(): - ct.launch(stream, grid, transpose_cutile_kernel, - (input_cp, output_cp, TRANSPOSE_TILE_M, TRANSPOSE_TILE_N)) - - cutile_transpose() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(output_cp), expected), "cuTile incorrect!" - min_t, mean_t = benchmark_cupy(cutile_transpose) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate bandwidth - bytes_transferred = 2 * M * N * 4 # read + write, float32 - bandwidths = [f"{bytes_transferred / (r.min_ms / 1000) / 1e9:.1f} GB/s" for r in results] - - print_table("Matrix Transpose (Float32)", results, extra_col=("Bandwidth", bandwidths)) - return results - - -#============================================================================= -# Matrix Multiplication -#============================================================================= - -def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M): - """Get the global IDs of the current CUDA block in a 1D grid.""" - bid = ct.bid(0) - num_bid_m = ct.cdiv(M, tm) - num_bid_n = ct.cdiv(N, tn) - num_bid_in_group = GROUP_SIZE_M * num_bid_n - group_id = bid // num_bid_in_group - first_bid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M) - bid_m = first_bid_m + (bid % group_size_m) - bid_n = (bid % num_bid_in_group) // group_size_m - return bid_m, bid_n - - -@ct.kernel -def matmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): - GROUP_SIZE_M = 8 - M = A.shape[0] - N = B.shape[1] - bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) - - num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) - accumulator = ct.full((tm, tn), 0, dtype=ct.float32) - - # Convert fp32 to tf32 for tensor cores - dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype - - for k in range(num_tiles_k): - a = ct.load(A, index=(bidx, k), shape=(tm, tk)).astype(dtype) - b = ct.load(B, index=(k, bidy), shape=(tk, tn)).astype(dtype) - accumulator = ct.mma(a, b, accumulator) - - accumulator = ct.astype(accumulator, C.dtype) - ct.store(C, index=(bidx, bidy), tile=accumulator) - - -def benchmark_matmul(): - print("\nBenchmarking Matrix Multiplication...") - M, N, K = MATMUL_DIM, MATMUL_DIM, MATMUL_DIM - print(f" Size: {M}x{K} * {K}x{N}") - - # CuPy arrays (used for cuTile and cuBLAS) - A_cp = cp.random.randn(M, K, dtype=np.float32) - B_cp = cp.random.randn(K, N, dtype=np.float32) - C_cp = cp.zeros((M, N), dtype=np.float32) - - # PyTorch tensors (from same data for fair comparison) - torch.set_float32_matmul_precision("high") # Enable TF32 - A_torch = torch.as_tensor(A_cp, device='cuda') - B_torch = torch.as_tensor(B_cp, device='cuda') - C_torch = torch.zeros(M, N, dtype=torch.float32, device='cuda') - - # Compute reference using CuPy (cuBLAS) for correctness checks - # This avoids TF32 precision differences between PyTorch and CuPy - C_ref_cp = cp.matmul(A_cp, B_cp) - cp.cuda.runtime.deviceSynchronize() - C_ref = cp.asnumpy(C_ref_cp) - - results = [] - flops = 2.0 * M * N * K - - # PyTorch - def torch_matmul(): - torch.matmul(A_torch, B_torch, out=C_torch) - - torch_matmul() - torch.cuda.synchronize() - # PyTorch TF32 vs CuPy cuBLAS may differ, use relaxed tolerance - assert np.allclose(C_torch.cpu().numpy(), C_ref, rtol=1e-1, atol=1e-1), 
"PyTorch incorrect!" - min_t, mean_t = benchmark_torch(torch_matmul) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # CuPy (uses cuBLAS) - this is the reference - def cupy_matmul(): - cp.matmul(A_cp, B_cp, out=C_cp) - - cupy_matmul() - cp.cuda.runtime.deviceSynchronize() - min_t, mean_t = benchmark_cupy(cupy_matmul) - results.append(BenchmarkResult("CuPy (cuBLAS)", min_t, mean_t)) - - # cuTile - C_cp.fill(0) - grid_m = ceil(M / MATMUL_TM) - grid_n = ceil(N / MATMUL_TN) - grid = (grid_m * grid_n, 1, 1) - stream = cp.cuda.get_current_stream() - - def cutile_matmul(): - ct.launch(stream, grid, matmul_cutile_kernel, - (A_cp, B_cp, C_cp, MATMUL_TM, MATMUL_TN, MATMUL_TK)) - - cutile_matmul() - cp.cuda.runtime.deviceSynchronize() - # TF32 has reduced precision compared to FP32 cuBLAS - assert np.allclose(cp.asnumpy(C_cp), C_ref, rtol=1e-1, atol=1e-1), "cuTile incorrect!" - min_t, mean_t = benchmark_cupy(cutile_matmul) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate TFLOPS - tflops_vals = [f"{flops / (r.min_ms * 1e-3) / 1e12:.2f} TFLOPS" for r in results] - - print_table("Matrix Multiplication (Float32, TF32 cores)", results, extra_col=("Performance", tflops_vals)) - return results - - -#============================================================================= -# Layer Normalization -#============================================================================= - -LAYERNORM_M = 4096 -LAYERNORM_N = 4096 -LAYERNORM_TILE_N = 1024 -LAYERNORM_EPS = 1e-5 - -# Batch matmul sizes -BATCHMATMUL_BATCH = 8 -BATCHMATMUL_M = 1024 -BATCHMATMUL_K = 512 -BATCHMATMUL_N = 2048 -BATCHMATMUL_TM = 128 -BATCHMATMUL_TN = 256 -BATCHMATMUL_TK = 64 - - -@ct.kernel -def layernorm_cutile_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TILE_N: ct.Constant[int]): - bid_m = ct.bid(0) - num_tiles = ct.num_tiles(X, axis=1, shape=(1, TILE_N)) - N = X.shape[1] - - # Compute mean - mean = ct.full((1, TILE_N), 0, dtype=ct.float32) - for j in range(num_tiles): - tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) - mean += tx - mean = ct.sum(mean, axis=1) / N - ct.store(Mean, index=(bid_m,), tile=mean) - - # Compute variance - var = ct.full((1, TILE_N), 0, dtype=ct.float32) - for j in range(num_tiles): - tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) - mask = (j * TILE_N + ct.arange(TILE_N, dtype=ct.int32)) < N - centered_tx = ct.where(mask, tx - mean, 0) - var += centered_tx ** 2 - var = ct.sum(var, axis=1) / N - rstd = 1 / ct.sqrt(var + eps) - ct.store(Rstd, index=(bid_m,), tile=rstd) - - # Normalize and apply affine transformation - for j in range(num_tiles): - tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) - tw = ct.load(W, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) - tb = ct.load(B, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) - ty = (tx - mean) * rstd - ty = ty * tw + tb - ct.store(Y, index=(bid_m, j), tile=ty.astype(Y.dtype)) - - -def benchmark_layernorm(): - print("\nBenchmarking Layer Normalization...") - M, N = LAYERNORM_M, LAYERNORM_N - print(f" Size: {M}x{N} ({M * N * 4 / 1e6} MB)") - - # CuPy arrays - X_cp = -2.3 + 0.5 * cp.random.randn(M, N).astype(np.float32) - W_cp = cp.random.randn(N).astype(np.float32) - B_cp = cp.random.randn(N).astype(np.float32) - Y_cp = cp.zeros((M, N), dtype=np.float32) - Mean_cp = cp.zeros(M, dtype=np.float32) - Rstd_cp = cp.zeros(M, dtype=np.float32) - - # PyTorch tensors - X_torch 
= torch.as_tensor(X_cp, device='cuda') - W_torch = torch.as_tensor(W_cp, device='cuda') - B_torch = torch.as_tensor(B_cp, device='cuda') - Y_torch = torch.zeros(M, N, dtype=torch.float32, device='cuda') - - # Reference result - X_np = cp.asnumpy(X_cp) - W_np = cp.asnumpy(W_cp) - B_np = cp.asnumpy(B_cp) - expected_mean = np.mean(X_np, axis=1, keepdims=True) - expected_var = np.mean((X_np - expected_mean) ** 2, axis=1, keepdims=True) - expected_rstd = 1.0 / np.sqrt(expected_var + LAYERNORM_EPS) - normalized = (X_np - expected_mean) * expected_rstd - expected_Y = normalized * W_np + B_np - - results = [] - - # PyTorch F.layer_norm - def torch_layernorm(): - nonlocal Y_torch - Y_torch = torch.nn.functional.layer_norm(X_torch, (N,), W_torch, B_torch, LAYERNORM_EPS) - - torch_layernorm() - torch.cuda.synchronize() - assert np.allclose(Y_torch.cpu().numpy(), expected_Y, rtol=1e-2, atol=1e-2), "PyTorch incorrect!" - min_t, mean_t = benchmark_torch(torch_layernorm) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # cuTile - Y_cp.fill(0) - Mean_cp.fill(0) - Rstd_cp.fill(0) - stream = cp.cuda.get_current_stream() - - def cutile_layernorm(): - ct.launch(stream, (M,), layernorm_cutile_kernel, - (X_cp, W_cp, B_cp, Y_cp, Mean_cp, Rstd_cp, LAYERNORM_EPS, LAYERNORM_TILE_N)) - - cutile_layernorm() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(Y_cp), expected_Y, rtol=1e-2, atol=1e-2), "cuTile incorrect!" - min_t, mean_t = benchmark_cupy(cutile_layernorm) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y) - bytes_transferred = (3 * M * N + N + N + M * N) * 4 - bandwidths = [f"{bytes_transferred / (r.min_ms / 1000) / 1e9:.1f} GB/s" for r in results] - - print_table("Layer Normalization (Float32)", results, extra_col=("Bandwidth", bandwidths)) - return results - - -#============================================================================= -# Batch Matrix Multiplication -#============================================================================= - -@ct.kernel -def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): - """CuTile kernel for batch matrix multiplication - A has shape (Batch, M, K), B has shape (Batch, K, N) and C has shape (Batch, M, N) - Grid: (Batch, M_tiles, N_tiles) - """ - pid_batch = ct.bid(0) - bidx = ct.bid(1) - bidy = ct.bid(2) - - num_k_tiles = ct.cdiv(A.shape[2], tk) - accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32) - zero_pad = ct.PaddingMode.ZERO - - for k in range(num_k_tiles): - a = ct.load(A, index=(pid_batch, bidx, k), shape=(1, tm, tk), padding_mode=zero_pad) - a = ct.reshape(a, (tm, tk)) - - b = ct.load(B, index=(pid_batch, k, bidy), shape=(1, tk, tn), padding_mode=zero_pad) - b = ct.reshape(b, (tk, tn)) - - accumulator = ct.mma(a, b, acc=accumulator) - - result = ct.astype(accumulator, C.dtype) - result_3d = ct.reshape(result, (1, tm, tn)) - ct.store(C, index=(pid_batch, bidx, bidy), tile=result_3d) - - -def benchmark_batchmatmul(): - print("\nBenchmarking Batch Matrix Multiplication...") - Batch, M, K, N = BATCHMATMUL_BATCH, BATCHMATMUL_M, BATCHMATMUL_K, BATCHMATMUL_N - print(f" Size: ({Batch} x {M} x {K}) @ ({Batch} x {K} x {N}), Float16") - - # PyTorch tensors - A_torch = torch.randn(Batch, M, K, dtype=torch.float16, device='cuda') - B_torch = torch.randn(Batch, K, N, dtype=torch.float16, device='cuda') - C_torch = torch.zeros(Batch, M, N, dtype=torch.float16, device='cuda') - - # 
CuPy arrays (from same data) - A_cp = cp.asarray(A_torch) - B_cp = cp.asarray(B_torch) - C_cp = cp.zeros((Batch, M, N), dtype=np.float16) - - # Reference result (PyTorch bmm in fp32 for accuracy) - C_ref = torch.bmm(A_torch.float(), B_torch.float()).cpu().numpy() - - results = [] - flops = 2.0 * Batch * M * N * K - - # PyTorch bmm - def torch_bmm(): - torch.bmm(A_torch, B_torch, out=C_torch) - - torch_bmm() - torch.cuda.synchronize() - assert np.allclose(C_torch.float().cpu().numpy(), C_ref, rtol=1e-1, atol=1e-1), "PyTorch incorrect!" - min_t, mean_t = benchmark_torch(torch_bmm) - results.append(BenchmarkResult("PyTorch bmm", min_t, mean_t)) - - # cuTile - C_cp.fill(0) - grid = (Batch, ceil(M / BATCHMATMUL_TM), ceil(N / BATCHMATMUL_TN)) - stream = cp.cuda.get_current_stream() - - def cutile_bmm(): - ct.launch(stream, grid, batchmatmul_cutile_kernel, - (A_cp, B_cp, C_cp, BATCHMATMUL_TM, BATCHMATMUL_TN, BATCHMATMUL_TK)) - - cutile_bmm() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(C_cp).astype(np.float32), C_ref, rtol=1e-1, atol=1e-1), "cuTile incorrect!" - min_t, mean_t = benchmark_cupy(cutile_bmm) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate TFLOPS - tflops_vals = [f"{flops / (r.min_ms * 1e-3) / 1e12:.2f} TFLOPS" for r in results] - - print_table("Batch Matrix Multiplication (Float16)", results, extra_col=("Performance", tflops_vals)) - return results - - -#============================================================================= -# FFT (3-stage Cooley-Tukey) +# Benchmark Discovery & Execution #============================================================================= -@ct.kernel -def fft_kernel(x_packed_in, y_packed_out, - W0, W1, W2, T0, T1, - N: ct.Constant[int], F0: ct.Constant[int], F1: ct.Constant[int], F2: ct.Constant[int], - BS: ct.Constant[int], D: ct.Constant[int]): - """cuTile kernel for 3-stage Cooley-Tukey FFT.""" - F0F1 = F0 * F1 - F1F2 = F1 * F2 - F0F2 = F0 * F2 - - bid = ct.bid(0) - - # Load input, reshape to separate real/imag - X_ri = ct.reshape(ct.load(x_packed_in, index=(bid, 0, 0), shape=(BS, N * 2 // D, D)), (BS, N, 2)) - X_r = ct.reshape(ct.extract(X_ri, index=(0, 0, 0), shape=(BS, N, 1)), (BS, F0, F1, F2)) - X_i = ct.reshape(ct.extract(X_ri, index=(0, 0, 1), shape=(BS, N, 1)), (BS, F0, F1, F2)) - - # Load W matrices (rotation matrices) - W0_ri = ct.reshape(ct.load(W0, index=(0, 0, 0), shape=(F0, F0, 2)), (F0, F0, 2)) - W0_r = ct.reshape(ct.extract(W0_ri, index=(0, 0, 0), shape=(F0, F0, 1)), (1, F0, F0)) - W0_i = ct.reshape(ct.extract(W0_ri, index=(0, 0, 1), shape=(F0, F0, 1)), (1, F0, F0)) - - W1_ri = ct.reshape(ct.load(W1, index=(0, 0, 0), shape=(F1, F1, 2)), (F1, F1, 2)) - W1_r = ct.reshape(ct.extract(W1_ri, index=(0, 0, 0), shape=(F1, F1, 1)), (1, F1, F1)) - W1_i = ct.reshape(ct.extract(W1_ri, index=(0, 0, 1), shape=(F1, F1, 1)), (1, F1, F1)) - - W2_ri = ct.reshape(ct.load(W2, index=(0, 0, 0), shape=(F2, F2, 2)), (F2, F2, 2)) - W2_r = ct.reshape(ct.extract(W2_ri, index=(0, 0, 0), shape=(F2, F2, 1)), (1, F2, F2)) - W2_i = ct.reshape(ct.extract(W2_ri, index=(0, 0, 1), shape=(F2, F2, 1)), (1, F2, F2)) - - # Load T matrices (twiddle factors) - T0_ri = ct.reshape(ct.load(T0, index=(0, 0, 0), shape=(F0, F1F2, 2)), (F0, F1F2, 2)) - T0_r = ct.reshape(ct.extract(T0_ri, index=(0, 0, 0), shape=(F0, F1F2, 1)), (N, 1)) - T0_i = ct.reshape(ct.extract(T0_ri, index=(0, 0, 1), shape=(F0, F1F2, 1)), (N, 1)) - - T1_ri = ct.reshape(ct.load(T1, index=(0, 0, 0), shape=(F1, F2, 2)), (F1, F2, 2)) - T1_r = 
ct.reshape(ct.extract(T1_ri, index=(0, 0, 0), shape=(F1, F2, 1)), (F1F2, 1)) - T1_i = ct.reshape(ct.extract(T1_ri, index=(0, 0, 1), shape=(F1, F2, 1)), (F1F2, 1)) +def discover_benchmarks(): + """Discover all benchmark-enabled examples in the examples directory.""" + examples = [] + examples_dir = os.path.dirname(__file__) + for file in sorted(os.listdir(examples_dir)): + if not file.endswith(".py"): + continue + if file == "benchmarks.py": + continue + name = file.replace(".py", "") + examples.append(name) + return examples - # CT0: Contract over F0 dimension - X_r = ct.reshape(X_r, (BS, F0, F1F2)) - X_i = ct.reshape(X_i, (BS, F0, F1F2)) - X_r_ = ct.reshape(ct.matmul(W0_r, X_r) - ct.matmul(W0_i, X_i), (BS, N, 1)) - X_i_ = ct.reshape(ct.matmul(W0_i, X_r) + ct.matmul(W0_r, X_i), (BS, N, 1)) - # Twiddle & Permute 0 - X_r = T0_r * X_r_ - T0_i * X_i_ - X_i = T0_i * X_r_ + T0_r * X_i_ - X_r = ct.permute(ct.reshape(X_r, (BS, F0, F1, F2)), (0, 2, 3, 1)) - X_i = ct.permute(ct.reshape(X_i, (BS, F0, F1, F2)), (0, 2, 3, 1)) +def run_benchmark(name: str): + """Load and run benchmark for a given example.""" + examples_dir = os.path.dirname(__file__) + file_path = os.path.join(examples_dir, f"{name}.py") - # CT1: Contract over F1 dimension - X_r = ct.reshape(X_r, (BS, F1, F0F2)) - X_i = ct.reshape(X_i, (BS, F1, F0F2)) - X_r_ = ct.reshape(ct.matmul(W1_r, X_r) - ct.matmul(W1_i, X_i), (BS, F1F2, F0)) - X_i_ = ct.reshape(ct.matmul(W1_i, X_r) + ct.matmul(W1_r, X_i), (BS, F1F2, F0)) + # Import module dynamically + spec = importlib.util.spec_from_file_location(name, file_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) - # Twiddle & Permute 1 - X_r = T1_r * X_r_ - T1_i * X_i_ - X_i = T1_i * X_r_ + T1_r * X_i_ - X_r = ct.permute(ct.reshape(X_r, (BS, F1, F2, F0)), (0, 2, 3, 1)) - X_i = ct.permute(ct.reshape(X_i, (BS, F1, F2, F0)), (0, 2, 3, 1)) + # Check required functions exist (unprefixed) + prepare_fn = getattr(mod, "prepare", None) + run_fn = getattr(mod, "run", None) + if not prepare_fn or not run_fn: + return None - # CT2: Contract over F2 dimension - X_r = ct.reshape(X_r, (BS, F2, F0F1)) - X_i = ct.reshape(X_i, (BS, F2, F0F1)) - X_r_ = ct.matmul(W2_r, X_r) - ct.matmul(W2_i, X_i) - X_i_ = ct.matmul(W2_i, X_r) + ct.matmul(W2_r, X_i) + # Prepare data with benchmark=True for larger sizes + data = prepare_fn(benchmark=True) - # Final Permutation - X_r = ct.permute(ct.reshape(X_r_, (BS, F2, F0, F1)), (0, 1, 3, 2)) - X_i = ct.permute(ct.reshape(X_i_, (BS, F2, F0, F1)), (0, 1, 3, 2)) - X_r = ct.reshape(X_r, (BS, N, 1)) - X_i = ct.reshape(X_i, (BS, N, 1)) + # Run cuTile + result = run_fn(data, nruns=NRUNS, warmup=WARMUP) - # Concatenate and Store - Y_ri = ct.reshape(ct.cat((X_r, X_i), axis=-1), (BS, N * 2 // D, D)) - ct.store(y_packed_out, index=(bid, 0, 0), tile=Y_ri) - - -def fft_twiddles(rows: int, cols: int, factor: int, device, precision): - """Generate DFT twiddle factors.""" - I, J = torch.meshgrid(torch.arange(rows, device=device), - torch.arange(cols, device=device), indexing='ij') - W_complex = torch.exp(-2 * math.pi * 1j * (I * J) / factor) - return torch.view_as_real(W_complex).to(precision).contiguous() - - -def fft_make_twiddles(factors, precision, device): - """Generate W and T matrices for FFT.""" - F0, F1, F2 = factors - N = F0 * F1 * F2 - F1F2 = F1 * F2 - W0 = fft_twiddles(F0, F0, F0, device, precision) - W1 = fft_twiddles(F1, F1, F1, device, precision) - W2 = fft_twiddles(F2, F2, F2, device, precision) - T0 = fft_twiddles(F0, F1F2, N, device, precision) - T1 = 
fft_twiddles(F1, F2, F1F2, device, precision) - return (W0, W1, W2, T0, T1) - - -def benchmark_fft(): - print("\nBenchmarking FFT...") - BS, N = FFT_BATCH, FFT_SIZE - F0, F1, F2 = FFT_FACTORS - D = FFT_ATOM_PACKING_DIM - print(f" Size: {BS} batches × {N} FFT ({BS * N * 8 / 1e6} MB)") - - # PyTorch complex input - input_torch = torch.randn(BS, N, dtype=torch.complex64, device='cuda') - - # Reference result - reference = torch.fft.fft(input_torch, dim=-1) - torch.cuda.synchronize() - - results = [] - - # Pre-compute everything outside timing loop - x_ri = torch.view_as_real(input_torch) - x_packed = x_ri.reshape(BS, N * 2 // D, D).contiguous() - W0, W1, W2, T0, T1 = fft_make_twiddles(FFT_FACTORS, input_torch.real.dtype, input_torch.device) - y_packed = torch.empty_like(x_packed) - grid = (BS, 1, 1) - - # Kernel launch function - def fft_launch(): - ct.launch(torch.cuda.current_stream(), grid, fft_kernel, - (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D)) - - # Verify correctness - fft_launch() - torch.cuda.synchronize() - output = torch.view_as_complex(y_packed.reshape(BS, N, 2)) - assert torch.allclose(output, reference, rtol=1e-3, atol=1e-3), "cuTile FFT incorrect!" - - min_t, mean_t = benchmark_torch(fft_launch) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) + # Extract times (handle times_fwd/times_bwd for layernorm) + if "times" in result: + results = {"cuTile": result["times"]} + elif "times_fwd" in result: + results = { + "cuTile Fwd": result["times_fwd"], + "cuTile Bwd": result["times_bwd"] + } + else: + return None - # Calculate GFLOPS (5 * N * log2(N) ops per complex FFT) - flops_per_fft = 5.0 * N * log2(N) - total_flops = BS * flops_per_fft - gflops = [f"{total_flops / (r.min_ms * 1e-3) / 1e9:.1f} GFLOPS" for r in results] + # Run others if available + run_others_fn = getattr(mod, "run_others", None) + if run_others_fn: + others = run_others_fn(data, nruns=NRUNS, warmup=WARMUP) + results.update(others) - print_table("FFT (ComplexF32)", results, extra_col=("Performance", gflops)) return results @@ -735,21 +104,35 @@ def fft_launch(): #============================================================================= def main(): + import torch # For GPU name + print("=" * 60) - print(" cuTile Python Comprehensive Benchmarks") + print(" cuTile Python Benchmarks") print("=" * 60) print() print("Configuration:") print(f" Runs: {NRUNS} (+ {WARMUP} warmup)") print(f" GPU: {torch.cuda.get_device_name()}") - print() - vadd_results = benchmark_vadd() - transpose_results = benchmark_transpose() - matmul_results = benchmark_matmul() - layernorm_results = benchmark_layernorm() - batchmatmul_results = benchmark_batchmatmul() - fft_results = benchmark_fft() + for name in discover_benchmarks(): + print(f"\nBenchmarking {name}...") + + results = run_benchmark(name) + if results is None: + print(" (skipped - no prepare/run functions)") + continue + + # Convert to BenchmarkResult for printing + benchmark_results = [] + for impl_name, times in results.items(): + min_t = min(times) + mean_t = sum(times) / len(times) + benchmark_results.append(BenchmarkResult(impl_name, min_t, mean_t)) + + # Sort by min time + benchmark_results.sort(key=lambda r: r.min_ms) + + print_table(name, benchmark_results) print() print("=" * 60) diff --git a/examples/fft.jl b/examples/fft.jl index 2e5495a..be1c87b 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -211,53 +211,113 @@ function make_twiddles(factors::NTuple{3, Int}) return (W0, W1, W2, T0, T1) end -# Main FFT function -function 
cutile_fft(x::CuMatrix{ComplexF32}, factors::NTuple{3, Int}; atom_packing_dim::Int=2) - BS = size(x, 1) - N = size(x, 2) - F0, F1, F2 = factors - - @assert F0 * F1 * F2 == N "Factors must multiply to N" - @assert (N * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim" +#============================================================================= + Example harness +=============================================================================# + +function prepare(; benchmark::Bool=false, + batch::Int=benchmark ? 64 : 2, + n::Int=benchmark ? 512 : 8, + factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2), + atom_packing_dim::Int=2) + @assert factors[1] * factors[2] * factors[3] == n "Factors must multiply to N" + @assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim" - D = atom_packing_dim + CUDA.seed!(42) + input = CUDA.randn(ComplexF32, batch, n) - # Generate W and T matrices (CPU, one-time cost) + # Pre-compute twiddles (one-time CPU cost) W0, W1, W2, T0, T1 = make_twiddles(factors) - - # Upload to GPU W0_gpu = CuArray(W0) W1_gpu = CuArray(W1) W2_gpu = CuArray(W2) T0_gpu = CuArray(T0) T1_gpu = CuArray(T1) - # Pack input: complex (BS, N) → real (D, BS, N2D) - zero-copy view - N2D = N * 2 ÷ D - x_packed = reinterpret(reshape, Float32, x) # (2, BS, N) = (D, BS, N2D) + # Pack input + D = atom_packing_dim + N2D = n * 2 ÷ D + x_packed = reinterpret(reshape, Float32, input) + y_packed = CuArray{Float32}(undef, D, batch, N2D) + + return (; + input, x_packed, y_packed, + W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + factors, batch, n, D, N2D + ) +end - # Allocate output - y_packed = CUDA.zeros(Float32, D, BS, N2D) +function run(data; nruns::Int=1, warmup::Int=0) + (; x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + factors, batch, n, D, N2D) = data - # Launch kernel + F0, F1, F2 = factors F0F1 = F0 * F1 F1F2 = F1 * F2 F0F2 = F0 * F2 - grid = (BS, 1, 1) - ct.launch(fft_kernel, grid, - x_packed, y_packed, - W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, - ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), - ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), - ct.Constant(BS), ct.Constant(D), ct.Constant(N2D)) - - # Unpack output: real (D, BS, N2D) → complex (BS, N) - zero-copy view + grid = (batch, 1, 1) + + CUDA.@sync for _ in 1:warmup + ct.launch(fft_kernel, grid, + x_packed, y_packed, + W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), + ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), + ct.Constant(batch), ct.Constant(D), ct.Constant(N2D)) + end + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(fft_kernel, grid, + x_packed, y_packed, + W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), + ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), + ct.Constant(batch), ct.Constant(D), ct.Constant(N2D)) + push!(times, t * 1000) # ms + end + + # Unpack output y_complex = reinterpret(reshape, ComplexF32, y_packed) + output = copy(y_complex) - return copy(y_complex) + return (; output, times) end -# Validation and example +function verify(data, result) + reference = FFTW.fft(Array(data.input), 2) + @assert isapprox(Array(result.output), reference, rtol=1e-4) +end + +#============================================================================= + Reference implementations for benchmarking +=============================================================================# + +function 
run_others(data; nruns::Int=1, warmup::Int=0) + (; input, batch, n) = data + results = Dict{String, Vector{Float64}}() + + output_cufft = similar(input) + + # CUFFT via CUDA.CUFFT + CUDA.@sync for _ in 1:warmup + CUDA.CUFFT.fft!(copy(input), 2) + end + times_cufft = Float64[] + for _ in 1:nruns + input_copy = copy(input) + t = CUDA.@elapsed CUDA.CUFFT.fft!(input_copy, 2) + push!(times_cufft, t * 1000) + end + results["cuFFT"] = times_cufft + + return results +end + +#============================================================================= + Main +=============================================================================# + function main() println("--- Running cuTile FFT Example ---") @@ -273,36 +333,17 @@ function main() println(" FFT Factors: $FFT_FACTORS") println(" Atom Packing Dim: $ATOM_PACKING_DIM") - # Create sample input - CUDA.seed!(42) - input_complex = CUDA.randn(ComplexF32, BATCH_SIZE, FFT_SIZE) - - println("\nInput data shape: $(size(input_complex)), dtype: $(eltype(input_complex))") - - # Perform FFT using cuTile kernel - output_cutile = cutile_fft(input_complex, FFT_FACTORS; atom_packing_dim=ATOM_PACKING_DIM) + # Use prepare/run/verify pattern + data = prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM) + println("\nInput data shape: $(size(data.input)), dtype: $(eltype(data.input))") - println("cuTile FFT Output shape: $(size(output_cutile)), dtype: $(eltype(output_cutile))") + result = run(data) + println("cuTile FFT Output shape: $(size(result.output)), dtype: $(eltype(result.output))") - # Verify against reference (FFTW) - input_cpu = Array(input_complex) - reference_output = FFTW.fft(input_cpu, 2) - - output_cpu = Array(output_cutile) - - if isapprox(output_cpu, reference_output, rtol=1e-4) - println("\n✓ Correctness check PASSED") - else - max_diff = maximum(abs.(output_cpu .- reference_output)) - println("\n✗ Correctness check FAILED - max difference: $max_diff") - println("\nExpected (first 4):") - println(reference_output[1, 1:4]) - println("\nGot (first 4):") - println(output_cpu[1, 1:4]) - end + verify(data, result) + println("\n✓ Correctness check PASSED") println("\n--- cuTile FFT example execution complete ---") end -# Run validation -main() +isinteractive() || main() diff --git a/examples/fft.py b/examples/fft.py new file mode 100644 index 0000000..8c60ab6 --- /dev/null +++ b/examples/fft.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +""" +FFT (3-stage Cooley-Tukey) example - cuTile Python +""" + +import torch +import math +import cuda.tile as ct + +@ct.kernel +def fft_kernel(x_packed_in, y_packed_out, + W0, W1, W2, T0, T1, + N: ct.Constant[int], F0: ct.Constant[int], F1: ct.Constant[int], F2: ct.Constant[int], + BS: ct.Constant[int], D: ct.Constant[int]): + """cuTile kernel for 3-stage Cooley-Tukey FFT.""" + F0F1 = F0 * F1 + F1F2 = F1 * F2 + F0F2 = F0 * F2 + + bid = ct.bid(0) + + # Load input, reshape to separate real/imag + X_ri = ct.reshape(ct.load(x_packed_in, index=(bid, 0, 0), shape=(BS, N * 2 // D, D)), (BS, N, 2)) + X_r = ct.reshape(ct.extract(X_ri, index=(0, 0, 0), shape=(BS, N, 1)), (BS, F0, F1, F2)) + X_i = ct.reshape(ct.extract(X_ri, index=(0, 0, 1), shape=(BS, N, 1)), (BS, F0, F1, F2)) + + # Load W matrices (rotation matrices) + W0_ri = ct.reshape(ct.load(W0, index=(0, 0, 0), shape=(F0, F0, 2)), (F0, F0, 2)) + W0_r = ct.reshape(ct.extract(W0_ri, index=(0, 0, 0), shape=(F0, F0, 1)), (1, F0, F0)) + W0_i = ct.reshape(ct.extract(W0_ri, index=(0, 0, 1), shape=(F0, F0, 1)), (1, F0, F0)) + + W1_ri = 
ct.reshape(ct.load(W1, index=(0, 0, 0), shape=(F1, F1, 2)), (F1, F1, 2)) + W1_r = ct.reshape(ct.extract(W1_ri, index=(0, 0, 0), shape=(F1, F1, 1)), (1, F1, F1)) + W1_i = ct.reshape(ct.extract(W1_ri, index=(0, 0, 1), shape=(F1, F1, 1)), (1, F1, F1)) + + W2_ri = ct.reshape(ct.load(W2, index=(0, 0, 0), shape=(F2, F2, 2)), (F2, F2, 2)) + W2_r = ct.reshape(ct.extract(W2_ri, index=(0, 0, 0), shape=(F2, F2, 1)), (1, F2, F2)) + W2_i = ct.reshape(ct.extract(W2_ri, index=(0, 0, 1), shape=(F2, F2, 1)), (1, F2, F2)) + + # Load T matrices (twiddle factors) + T0_ri = ct.reshape(ct.load(T0, index=(0, 0, 0), shape=(F0, F1F2, 2)), (F0, F1F2, 2)) + T0_r = ct.reshape(ct.extract(T0_ri, index=(0, 0, 0), shape=(F0, F1F2, 1)), (N, 1)) + T0_i = ct.reshape(ct.extract(T0_ri, index=(0, 0, 1), shape=(F0, F1F2, 1)), (N, 1)) + + T1_ri = ct.reshape(ct.load(T1, index=(0, 0, 0), shape=(F1, F2, 2)), (F1, F2, 2)) + T1_r = ct.reshape(ct.extract(T1_ri, index=(0, 0, 0), shape=(F1, F2, 1)), (F1F2, 1)) + T1_i = ct.reshape(ct.extract(T1_ri, index=(0, 0, 1), shape=(F1, F2, 1)), (F1F2, 1)) + + # CT0: Contract over F0 dimension + X_r = ct.reshape(X_r, (BS, F0, F1F2)) + X_i = ct.reshape(X_i, (BS, F0, F1F2)) + X_r_ = ct.reshape(ct.matmul(W0_r, X_r) - ct.matmul(W0_i, X_i), (BS, N, 1)) + X_i_ = ct.reshape(ct.matmul(W0_i, X_r) + ct.matmul(W0_r, X_i), (BS, N, 1)) + + # Twiddle & Permute 0 + X_r = T0_r * X_r_ - T0_i * X_i_ + X_i = T0_i * X_r_ + T0_r * X_i_ + X_r = ct.permute(ct.reshape(X_r, (BS, F0, F1, F2)), (0, 2, 3, 1)) + X_i = ct.permute(ct.reshape(X_i, (BS, F0, F1, F2)), (0, 2, 3, 1)) + + # CT1: Contract over F1 dimension + X_r = ct.reshape(X_r, (BS, F1, F0F2)) + X_i = ct.reshape(X_i, (BS, F1, F0F2)) + X_r_ = ct.reshape(ct.matmul(W1_r, X_r) - ct.matmul(W1_i, X_i), (BS, F1F2, F0)) + X_i_ = ct.reshape(ct.matmul(W1_i, X_r) + ct.matmul(W1_r, X_i), (BS, F1F2, F0)) + + # Twiddle & Permute 1 + X_r = T1_r * X_r_ - T1_i * X_i_ + X_i = T1_i * X_r_ + T1_r * X_i_ + X_r = ct.permute(ct.reshape(X_r, (BS, F1, F2, F0)), (0, 2, 3, 1)) + X_i = ct.permute(ct.reshape(X_i, (BS, F1, F2, F0)), (0, 2, 3, 1)) + + # CT2: Contract over F2 dimension + X_r = ct.reshape(X_r, (BS, F2, F0F1)) + X_i = ct.reshape(X_i, (BS, F2, F0F1)) + X_r_ = ct.matmul(W2_r, X_r) - ct.matmul(W2_i, X_i) + X_i_ = ct.matmul(W2_i, X_r) + ct.matmul(W2_r, X_i) + + # Final Permutation + X_r = ct.permute(ct.reshape(X_r_, (BS, F2, F0, F1)), (0, 1, 3, 2)) + X_i = ct.permute(ct.reshape(X_i_, (BS, F2, F0, F1)), (0, 1, 3, 2)) + X_r = ct.reshape(X_r, (BS, N, 1)) + X_i = ct.reshape(X_i, (BS, N, 1)) + + # Concatenate and Store + Y_ri = ct.reshape(ct.cat((X_r, X_i), axis=-1), (BS, N * 2 // D, D)) + ct.store(y_packed_out, index=(bid, 0, 0), tile=Y_ri) + +def fft_twiddles(rows: int, cols: int, factor: int, device, precision): + """Generate DFT twiddle factors.""" + I, J = torch.meshgrid(torch.arange(rows, device=device), + torch.arange(cols, device=device), indexing='ij') + W_complex = torch.exp(-2 * math.pi * 1j * (I * J) / factor) + return torch.view_as_real(W_complex).to(precision).contiguous() + + +def fft_make_twiddles(factors, precision, device): + """Generate W and T matrices for FFT.""" + F0, F1, F2 = factors + N = F0 * F1 * F2 + F1F2 = F1 * F2 + W0 = fft_twiddles(F0, F0, F0, device, precision) + W1 = fft_twiddles(F1, F1, F1, device, precision) + W2 = fft_twiddles(F2, F2, F2, device, precision) + T0 = fft_twiddles(F0, F1F2, N, device, precision) + T1 = fft_twiddles(F1, F2, F1F2, device, precision) + return (W0, W1, W2, T0, T1) + + 
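+# A rough sketch of the data flow in fft_kernel above (for orientation only;
+# the kernel comments mark the exact steps). For N = F0 * F1 * F2, e.g.
+# N = 512 with factors (8, 8, 8), the FFT is computed as three batched small
+# DFT contractions:
+#   CT0: matmul with W0 (F0 x F0), then T0 twiddles and a permute
+#   CT1: matmul with W1 (F1 x F1), then T1 twiddles and a permute
+#   CT2: matmul with W2 (F2 x F2), then the final permutation
+# Real and imaginary parts travel as separate tiles, so each complex matmul
+# expands to the pair (Wr @ Xr - Wi @ Xi, Wi @ Xr + Wr @ Xi).
+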
+#============================================================================= +# Example harness +#============================================================================= + +def prepare(*, benchmark: bool = False, batch: int = None, size: int = None, factors: tuple = None, atom_packing_dim: int = 2): + """Allocate and initialize data for FFT.""" + if batch is None: + batch = 64 if benchmark else 2 + if factors is None: + factors = (8, 8, 8) if benchmark else (2, 2, 2) + F0, F1, F2 = factors + N = F0 * F1 * F2 + if size is None: + size = N + assert size == N, f"size ({size}) must equal product of factors ({N})" + D = atom_packing_dim + + input_data = torch.randn(batch, N, dtype=torch.complex64, device='cuda') + + # Pre-compute twiddles + W0, W1, W2, T0, T1 = fft_make_twiddles(factors, input_data.real.dtype, input_data.device) + + # Pack input + x_ri = torch.view_as_real(input_data) + x_packed = x_ri.reshape(batch, N * 2 // D, D).contiguous() + y_packed = torch.empty_like(x_packed) + + return { + "input": input_data, + "x_packed": x_packed, + "y_packed": y_packed, + "W0": W0, "W1": W1, "W2": W2, "T0": T0, "T1": T1, + "factors": factors, + "batch": batch, + "N": N, + "D": D + } + + +def run(data, *, nruns: int = 1, warmup: int = 0): + """Run FFT kernel with timing.""" + x_packed = data["x_packed"] + y_packed = data["y_packed"] + W0, W1, W2, T0, T1 = data["W0"], data["W1"], data["W2"], data["T0"], data["T1"] + F0, F1, F2 = data["factors"] + batch, N, D = data["batch"], data["N"], data["D"] + + grid = (batch, 1, 1) + + # Warmup + for _ in range(warmup): + ct.launch(torch.cuda.current_stream(), grid, fft_kernel, + (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + torch.cuda.synchronize() + + # Timed runs + times = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + ct.launch(torch.cuda.current_stream(), grid, fft_kernel, + (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + + output = torch.view_as_complex(y_packed.reshape(batch, N, 2)) + + return {"output": output, "times": times} + + +def verify(data, result): + """Verify FFT results.""" + reference = torch.fft.fft(data["input"], dim=-1) + assert torch.allclose(result["output"], reference, rtol=1e-3, atol=1e-3), \ + f"FFT incorrect! 
max diff: {torch.max(torch.abs(result['output'] - reference))}" + + +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + results = {} + input_data = data["input"] + + # PyTorch FFT (uses cuFFT) + for _ in range(warmup): + torch.fft.fft(input_data, dim=-1) + torch.cuda.synchronize() + + times_torch = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.fft.fft(input_data, dim=-1) + end.record() + torch.cuda.synchronize() + times_torch.append(start.elapsed_time(end)) + results["cuFFT"] = times_torch + + return results + + +#============================================================================= +# Main +#============================================================================= + +def test_fft(batch, size, factors, name=None): + """Test FFT with given parameters.""" + name = name or f"fft batch={batch}, size={size}, factors={factors}" + print(f"--- {name} ---") + data = prepare(batch=batch, size=size, factors=factors) + result = run(data) + verify(data, result) + print(" passed") + + +def main(): + print("--- cuTile FFT Examples ---\n") + + test_fft(64, 512, (8, 8, 8)) + test_fft(32, 512, (8, 8, 8)) + + print("\n--- All FFT examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/layernorm.jl b/examples/layernorm.jl index 7e2e11a..659ec6c 100644 --- a/examples/layernorm.jl +++ b/examples/layernorm.jl @@ -273,149 +273,131 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa end #============================================================================= - Test / Validation + Example harness =============================================================================# -function main() - println("=== cuTile LayerNorm Sample ===\n") +function prepare(; benchmark::Bool=false, + M::Int=benchmark ? 4096 : 256, + N::Int=benchmark ? 
4096 : 256, + eps::Float32=1f-5, GROUP_SIZE_M::Int=64) + return (; + # Forward inputs/outputs + X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N), + W = CUDA.randn(Float32, N), + B = CUDA.randn(Float32, N), + Y = CuArray{Float32}(undef, M, N), + Mean = CuArray{Float32}(undef, M), + Rstd = CuArray{Float32}(undef, M), + # Backward inputs/outputs + DY = 0.1f0 .* CUDA.randn(Float32, M, N), + DX = CuArray{Float32}(undef, M, N), + DW_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N), + DB_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N), + Locks = CuArray{Int}(undef, GROUP_SIZE_M), + FINAL_DW = CuArray{Float32}(undef, N), + FINAL_DB = CuArray{Float32}(undef, N), + # Metadata + M, N, eps, GROUP_SIZE_M + ) +end + +function run(data; TILE_N::Int=1024, TILE_M::Int=32, nruns::Int=1, warmup::Int=0) + (; X, W, B, Y, Mean, Rstd, DY, DX, DW_partial, DB_partial, Locks, FINAL_DW, FINAL_DB, + M, N, eps, GROUP_SIZE_M) = data + + function run_fwd() + ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, + ct.Constant(eps), ct.Constant(TILE_N)) + end + + function run_bwd() + fill!(DW_partial, 0) + fill!(DB_partial, 0) + fill!(Locks, 0) + ct.launch(layer_norm_bwd_dx_partial_dwdb, M, DX, DY, DW_partial, DB_partial, X, W, + Mean, Rstd, Locks, ct.Constant(GROUP_SIZE_M), ct.Constant(TILE_N)) + num_tiles_n = cld(N, TILE_N) + ct.launch(layer_norm_bwd_dwdb, num_tiles_n, DW_partial, DB_partial, FINAL_DW, FINAL_DB, + ct.Constant(TILE_M), ct.Constant(TILE_N)) + end - M, N = 1024, 2048 - TILE_N = 1024 - eps = 1f-5 + # Warmup + CUDA.@sync for _ in 1:warmup + run_fwd() + run_bwd() + end - println("Input shape: ($M, $N), dtype: Float32, eps: $eps") + # Timed forward runs + times_fwd = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed run_fwd() + push!(times_fwd, t * 1000) # ms + end - # Input data - X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N) - W = CUDA.randn(Float32, N) - B = CUDA.randn(Float32, N) + # Timed backward runs + times_bwd = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed run_bwd() + push!(times_bwd, t * 1000) # ms + end - # Output buffers for forward pass - Y = CUDA.zeros(Float32, M, N) - Mean = CUDA.zeros(Float32, M) - Rstd = CUDA.zeros(Float32, M) + return (; Y, Mean, Rstd, DX, FINAL_DW, FINAL_DB, times_fwd, times_bwd) +end - # ========================================================================= - # Forward Pass - # ========================================================================= - println("\n--- Forward Pass ---") - ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, - ct.Constant(eps), ct.Constant(TILE_N)) +function verify(data, result) + (; X, W, B, DY, N, eps) = data - # Compute expected values on CPU X_cpu = Array(X) W_cpu = Array(W) B_cpu = Array(B) + DY_cpu = Array(DY) + # Forward verification expected_mean = vec(sum(X_cpu, dims=2) ./ N) expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) - normalized = (X_cpu .- expected_mean) .* expected_rstd - expected_Y = normalized .* W_cpu' .+ B_cpu' - - # Verify forward pass results - Mean_cpu = Array(Mean) - Rstd_cpu = Array(Rstd) - Y_cpu = Array(Y) + xhat = (X_cpu .- expected_mean) .* expected_rstd + expected_Y = xhat .* W_cpu' .+ B_cpu' atol, rtol = 1f-2, 1f-2 - fwd_ok = isapprox(expected_mean, Mean_cpu; rtol, atol) && - isapprox(expected_rstd, Rstd_cpu; rtol, atol) && - isapprox(expected_Y, Y_cpu; rtol, atol) - - if fwd_ok - println("Forward pass: PASSED") - else - println("Forward pass: FAILED") - isapprox(expected_mean, Mean_cpu; rtol, atol) || println(" Mean max error: 
$(maximum(abs.(expected_mean .- Mean_cpu)))") - isapprox(expected_rstd, Rstd_cpu; rtol, atol) || println(" Rstd max error: $(maximum(abs.(expected_rstd .- Rstd_cpu)))") - isapprox(expected_Y, Y_cpu; rtol, atol) || println(" Y max error: $(maximum(abs.(expected_Y .- Y_cpu)))") - end - - # ========================================================================= - # Backward Pass (Full: dX, dW, dB) - # ========================================================================= - println("\n--- Backward Pass (Full: dX, dW, dB) ---") - - # Upstream gradient (random for testing) - DY = CUDA.randn(Float32, M, N) - DX = CUDA.zeros(Float32, M, N) + @assert isapprox(expected_Y, Array(result.Y); rtol, atol) "Y mismatch" - # Parameters for partial gradient accumulation - GROUP_SIZE_M = 64 - TILE_M = 32 - - # Partial gradient buffers and locks - DW_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N) - DB_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N) - Locks = CUDA.zeros(Int, GROUP_SIZE_M) - - # Final gradient buffers - FINAL_DW = CUDA.zeros(Float32, N) - FINAL_DB = CUDA.zeros(Float32, N) - - # Launch backward kernels - ct.launch(layer_norm_bwd_dx_partial_dwdb, M, DX, DY, DW_partial, DB_partial, X, W, - Mean, Rstd, Locks, ct.Constant(GROUP_SIZE_M), ct.Constant(TILE_N)) - - num_tiles_n = cld(N, TILE_N) - ct.launch(layer_norm_bwd_dwdb, num_tiles_n, DW_partial, DB_partial, FINAL_DW, FINAL_DB, - ct.Constant(TILE_M), ct.Constant(TILE_N)) - - # Compute expected gradients on CPU - # dX = rstd * (W * dY - c2 - x_hat * c1) - # where c1 = mean(x_hat * W * dY), c2 = mean(W * dY) - DY_cpu = Array(DY) + # Backward verification wdy = W_cpu' .* DY_cpu - xhat = normalized c1 = sum(xhat .* wdy, dims=2) ./ N c2 = sum(wdy, dims=2) ./ N expected_DX = (wdy .- (xhat .* c1 .+ c2)) .* expected_rstd - - # dW = sum(dY * x_hat, dim=0) and dB = sum(dY, dim=0) expected_DW = vec(sum(DY_cpu .* xhat, dims=1)) expected_DB = vec(sum(DY_cpu, dims=1)) - # Verify dX - DX_cpu = Array(DX) - dx_ok = isapprox(expected_DX, DX_cpu; rtol, atol) - if dx_ok - println(" dX: PASSED") - else - max_err = maximum(abs.(expected_DX .- DX_cpu)) - println(" dX: FAILED (max error: $max_err)") - end + @assert isapprox(expected_DX, Array(result.DX); rtol, atol) "dX mismatch" + @assert isapprox(expected_DW, Array(result.FINAL_DW); rtol, atol) "dW mismatch" + @assert isapprox(expected_DB, Array(result.FINAL_DB); rtol, atol) "dB mismatch" +end - # Verify dW - FINAL_DW_cpu = Array(FINAL_DW) - dw_ok = isapprox(expected_DW, FINAL_DW_cpu; rtol, atol) - if dw_ok - println(" dW: PASSED") - else - max_err = maximum(abs.(expected_DW .- FINAL_DW_cpu)) - println(" dW: FAILED (max error: $max_err)") - end +function test_layernorm(M, N, TILE_N; TILE_M::Int=32, eps::Float32=1f-5, name=nothing) + name = something(name, "layernorm ($M x $N), tile_n=$TILE_N, tile_m=$TILE_M") + println("--- $name ---") + data = prepare(; M, N, eps) + result = run(data; TILE_N, TILE_M) + verify(data, result) + println(" fwd passed, bwd passed") +end - # Verify dB - FINAL_DB_cpu = Array(FINAL_DB) - db_ok = isapprox(expected_DB, FINAL_DB_cpu; rtol, atol) - if db_ok - println(" dB: PASSED") - else - max_err = maximum(abs.(expected_DB .- FINAL_DB_cpu)) - println(" dB: FAILED (max error: $max_err)") - end +# No run_others for layernorm - no simple reference implementation to compare against - bwd_ok = dx_ok && dw_ok && db_ok +#============================================================================= + Main +=============================================================================# + +function main() 
+ println("=== cuTile LayerNorm Examples (fwd+bwd) ===\n") - # ========================================================================= - # Summary - # ========================================================================= - println("\n=== Summary ===") - println("Forward pass: $(fwd_ok ? "PASSED" : "FAILED")") - println("Backward (dX/dW/dB): $(bwd_ok ? "PASSED" : "FAILED")") + test_layernorm(256, 256, 256) + test_layernorm(512, 512, 512) + test_layernorm(1024, 2048, 1024) - (fwd_ok && bwd_ok) || error("LayerNorm tests failed") + println("\n=== All layernorm examples completed ===") end isinteractive() || main() diff --git a/examples/layernorm.py b/examples/layernorm.py new file mode 100644 index 0000000..b65e136 --- /dev/null +++ b/examples/layernorm.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +""" +Layer Normalization example - cuTile Python +Forward and backward passes with unified prepare/run/verify pattern. +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct +from math import ceil + +#============================================================================= +# Forward Kernel +#============================================================================= + +@ct.kernel +def layernorm_fwd_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TILE_N: ct.Constant[int]): + """Forward pass: computes mean/var, normalizes input, applies affine transform.""" + bid_m = ct.bid(0) + num_tiles = ct.num_tiles(X, axis=1, shape=(1, TILE_N)) + N = X.shape[1] + + # Compute mean + mean = ct.full((1, TILE_N), 0, dtype=ct.float32) + for j in range(num_tiles): + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + mean += tx + mean = ct.sum(mean, axis=1) / N + ct.store(Mean, index=(bid_m,), tile=mean) + + # Compute variance + var = ct.full((1, TILE_N), 0, dtype=ct.float32) + for j in range(num_tiles): + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + mask = (j * TILE_N + ct.arange(TILE_N, dtype=ct.int32)) < N + centered_tx = ct.where(mask, tx - mean, 0) + var += centered_tx ** 2 + var = ct.sum(var, axis=1) / N + rstd = 1 / ct.sqrt(var + eps) + ct.store(Rstd, index=(bid_m,), tile=rstd) + + # Normalize and apply affine transformation + for j in range(num_tiles): + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + tw = ct.load(W, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) + tb = ct.load(B, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) + ty = (tx - mean) * rstd + ty = ty * tw + tb + ct.store(Y, index=(bid_m, j), tile=ty.astype(Y.dtype)) + + +#============================================================================= +# Backward Kernels +#============================================================================= + +def bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N): + """Helper to load data and compute common backward terms.""" + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + tw = ct.load(W, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) + tdy = ct.load(DY, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + xhat = (tx - mean) * rstd + wdy = tw * tdy + mask = j * TILE_N + ct.arange(TILE_N, dtype=ct.int32) < N + xhat = ct.where(mask, xhat, 0) + wdy = ct.where(mask, wdy, 0) + return tdy, xhat, wdy + + +@ct.kernel +def layernorm_bwd_dx_partial_dwdb_kernel(DX, DY, DW, DB, X, W, Mean, Rstd, Locks, TILE_N: ct.Constant[int]): + """Backward pass part 
1: computes dX and partial dW/dB with atomic accumulation.""" + bid_m = ct.bid(0) + num_tiles = ct.num_tiles(X, axis=1, shape=(1, TILE_N)) + N = X.shape[1] + GROUP_SIZE_M = DW.shape[0] + group_bid_m = bid_m % GROUP_SIZE_M + + mean = ct.load(Mean, index=(bid_m,), shape=(1,)) + rstd = ct.load(Rstd, index=(bid_m,), shape=(1,)) + + c1 = ct.full((1, TILE_N), 0, dtype=ct.float32) + c2 = ct.full((1, TILE_N), 0, dtype=ct.float32) + for j in range(num_tiles): + _, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) + c1 += xhat * wdy + c2 += wdy + c1 = ct.sum(c1, axis=1) / N + c2 = ct.sum(c2, axis=1) / N + + for j in range(num_tiles): + tdy, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) + tdx = (wdy - (xhat * c1 + c2)) * rstd + ct.store(DX, index=(bid_m, j), tile=tdx.astype(DX.dtype)) + + partial_dw = (tdy * xhat).astype(DW.dtype) + partial_db = tdy.astype(DB.dtype) + + while ct.atomic_cas(Locks, group_bid_m, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE) == 1: + pass + + partial_dw += ct.load(DW, index=(group_bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + partial_db += ct.load(DB, index=(group_bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + ct.store(DW, index=(group_bid_m, j), tile=partial_dw) + ct.store(DB, index=(group_bid_m, j), tile=partial_db) + + ct.atomic_xchg(Locks, group_bid_m, 0, memory_order=ct.MemoryOrder.RELEASE) + + +@ct.kernel +def layernorm_bwd_dwdb_kernel(DW, DB, FINAL_DW, FINAL_DB, TILE_M: ct.Constant[int], TILE_N: ct.Constant[int]): + """Backward pass part 2: Final reduction for dW and dB.""" + bid_n = ct.bid(0) + num_tiles = ct.num_tiles(DW, axis=0, shape=(TILE_M, TILE_N)) + + dw = ct.zeros((TILE_M, TILE_N), dtype=ct.float32) + db = ct.zeros((TILE_M, TILE_N), dtype=ct.float32) + for i in range(num_tiles): + dw += ct.load(DW, index=(i, bid_n), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO) + db += ct.load(DB, index=(i, bid_n), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO) + sum_dw = ct.sum(dw, axis=0) + sum_db = ct.sum(db, axis=0) + + ct.store(FINAL_DW, index=(bid_n,), tile=sum_dw.astype(FINAL_DW.dtype)) + ct.store(FINAL_DB, index=(bid_n,), tile=sum_db.astype(FINAL_DB.dtype)) + + +#============================================================================= +# Example harness +#============================================================================= + +def prepare(*, benchmark: bool = False, M: int = None, N: int = None, eps: float = 1e-5, GROUP_SIZE_M: int = 64, dtype=np.float32): + """Allocate all data for forward and backward passes.""" + if M is None: + M = 4096 if benchmark else 256 + if N is None: + N = 4096 if benchmark else 256 + return { + # Forward inputs/outputs + "X": (-2.3 + 0.5 * cp.random.randn(M, N)).astype(dtype), + "W": cp.random.randn(N).astype(dtype), + "B": cp.random.randn(N).astype(dtype), + "Y": cp.empty((M, N), dtype=dtype), + "Mean": cp.empty(M, dtype=np.float32), + "Rstd": cp.empty(M, dtype=np.float32), + # Backward inputs/outputs + "DY": (0.1 * cp.random.randn(M, N)).astype(dtype), + "DX": cp.empty((M, N), dtype=dtype), + "DW_partial": cp.empty((GROUP_SIZE_M, N), dtype=np.float32), + "DB_partial": cp.empty((GROUP_SIZE_M, N), dtype=np.float32), + "Locks": cp.empty(GROUP_SIZE_M, dtype=np.int32), + "FINAL_DW": cp.empty(N, dtype=dtype), + "FINAL_DB": cp.empty(N, dtype=dtype), + # Metadata + "eps": eps, + "M": M, + "N": N, + "GROUP_SIZE_M": GROUP_SIZE_M + } + + +def run(data, *, tile_n: int = 1024, tile_m: int = 32, nruns: int = 1, warmup: int = 0): + """Run both 
forward and backward passes with timing.""" + X, W, B, Y = data["X"], data["W"], data["B"], data["Y"] + Mean, Rstd = data["Mean"], data["Rstd"] + DY, DX = data["DY"], data["DX"] + DW_partial, DB_partial = data["DW_partial"], data["DB_partial"] + Locks = data["Locks"] + FINAL_DW, FINAL_DB = data["FINAL_DW"], data["FINAL_DB"] + eps, M, N = data["eps"], data["M"], data["N"] + GROUP_SIZE_M = data["GROUP_SIZE_M"] + + stream = cp.cuda.get_current_stream() + + def run_fwd(): + ct.launch(stream, (M,), layernorm_fwd_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) + + def run_bwd(): + DW_partial.fill(0) + DB_partial.fill(0) + Locks.fill(0) + ct.launch(stream, (M,), layernorm_bwd_dx_partial_dwdb_kernel, + (DX, DY, DW_partial, DB_partial, X, W, Mean, Rstd, Locks, tile_n)) + num_tiles_n = ceil(N / tile_n) + ct.launch(stream, (num_tiles_n,), layernorm_bwd_dwdb_kernel, + (DW_partial, DB_partial, FINAL_DW, FINAL_DB, tile_m, tile_n)) + + # Warmup + for _ in range(warmup): + run_fwd() + run_bwd() + cp.cuda.runtime.deviceSynchronize() + + # Timed forward runs + times_fwd = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + run_fwd() + end.record(stream) + end.synchronize() + times_fwd.append(cp.cuda.get_elapsed_time(start, end)) # ms + + # Timed backward runs + times_bwd = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + run_bwd() + end.record(stream) + end.synchronize() + times_bwd.append(cp.cuda.get_elapsed_time(start, end)) # ms + + return { + "Y": Y, "Mean": Mean, "Rstd": Rstd, + "DX": DX, "DW": FINAL_DW, "DB": FINAL_DB, + "times_fwd": times_fwd, "times_bwd": times_bwd + } + + +def verify(data, result): + """Verify both forward and backward results.""" + X_np = cp.asnumpy(data["X"]) + W_np = cp.asnumpy(data["W"]) + B_np = cp.asnumpy(data["B"]) + DY_np = cp.asnumpy(data["DY"]) + eps = data["eps"] + N = data["N"] + + # Forward verification + expected_mean = np.mean(X_np, axis=1, keepdims=True) + expected_var = np.mean((X_np - expected_mean) ** 2, axis=1, keepdims=True) + expected_rstd = 1.0 / np.sqrt(expected_var + eps) + xhat = (X_np - expected_mean) * expected_rstd + expected_Y = xhat * W_np + B_np + + atol, rtol = 1e-2, 1e-2 + assert np.allclose(cp.asnumpy(result["Y"]), expected_Y, rtol=rtol, atol=atol), \ + f"Y mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['Y']) - expected_Y))}" + + # Backward verification + wdy = W_np * DY_np + c1 = np.sum(xhat * wdy, axis=1, keepdims=True) / N + c2 = np.sum(wdy, axis=1, keepdims=True) / N + expected_DX = (wdy - (xhat * c1 + c2)) * expected_rstd + expected_DW = np.sum(DY_np * xhat, axis=0) + expected_DB = np.sum(DY_np, axis=0) + + assert np.allclose(cp.asnumpy(result["DX"]), expected_DX, rtol=rtol, atol=atol), \ + f"DX mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['DX']) - expected_DX))}" + assert np.allclose(cp.asnumpy(result["DW"]), expected_DW, rtol=rtol, atol=atol), \ + f"DW mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['DW']) - expected_DW))}" + assert np.allclose(cp.asnumpy(result["DB"]), expected_DB, rtol=rtol, atol=atol), \ + f"DB mismatch! 
max diff: {np.max(np.abs(cp.asnumpy(result['DB']) - expected_DB))}" + +# No run_others for layernorm - no simple reference implementation to compare against + + +#============================================================================= +# Main +#============================================================================= + +def test_layernorm(M, N, tile_n, tile_m=32, eps=1e-5, dtype=np.float32, name=None): + """Test layer normalization (fwd+bwd) with given parameters.""" + name = name or f"layernorm ({M}x{N}), tile_n={tile_n}, tile_m={tile_m}, dtype={dtype.__name__}" + print(f"--- {name} ---") + data = prepare(M=M, N=N, eps=eps, dtype=dtype) + result = run(data, tile_n=tile_n, tile_m=tile_m) + verify(data, result) + print(" fwd passed, bwd passed") + + +def main(): + print("--- cuTile Layer Normalization Examples (fwd+bwd) ---\n") + + test_layernorm(256, 256, 256) + test_layernorm(512, 512, 512) + test_layernorm(1024, 1024, 1024) + + print("\n--- All layernorm examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/matmul.jl b/examples/matmul.jl index 36d41ea..47a24e6 100644 --- a/examples/matmul.jl +++ b/examples/matmul.jl @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 using CUDA +using LinearAlgebra import cuTile as ct # 2D swizzle for better L2 cache locality @@ -62,33 +63,82 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr return nothing end -function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T - name = something(name, "matmul ($M x $K) @ ($K x $N), $T, tiles=$tm x $tn x $tk") - println("--- $name ---") +#============================================================================= + Example harness +=============================================================================# + +function prepare(; benchmark::Bool=false, + M::Int=benchmark ? 4096 : 256, + N::Int=benchmark ? 4096 : 256, + K::Int=benchmark ? 
4096 : 256, + T::DataType=Float32) + return (; + A = CUDA.rand(T, M, K), + B = CUDA.rand(T, K, N), + C = CuArray{T}(undef, M, N), + M, N, K + ) +end + +function run(data; tm::Int=64, tn::Int=64, tk::Int=64, nruns::Int=1, warmup::Int=0) + (; A, B, C, M, N, K) = data + grid = cld(M, tm) * cld(N, tn) + + CUDA.@sync for _ in 1:warmup + ct.launch(matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + end + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + push!(times, t * 1000) # ms + end + + return (; C, times) +end - A = CUDA.rand(T, M, K) - B = CUDA.rand(T, K, N) - C = CUDA.zeros(T, M, N) +function verify(data, result) + expected = Array(data.A) * Array(data.B) + @assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))" +end - # Use 1D grid - swizzle_2d converts to 2D indices - grid_m = cld(M, tm) - grid_n = cld(N, tn) - grid = grid_m * grid_n +#============================================================================= + Reference implementations for benchmarking +=============================================================================# - # Launch kernel - ct.launch(matmul_kernel, grid, A, B, C, - ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) +function run_others(data; nruns::Int=1, warmup::Int=0) + (; A, B) = data + results = Dict{String, Vector{Float64}}() - # Verify result - expected = Array(A) * Array(B) - result = Array(C) + C_gpuarrays = similar(A, size(A, 1), size(B, 2)) - if isapprox(result, expected, rtol=1e-2, atol=1e-2) - println(" passed") - else - max_diff = maximum(abs.(result - expected)) - println(" FAILED (max diff: $max_diff)") + # GPUArrays (uses cuBLAS under the hood via LinearAlgebra.mul!) 
+ CUDA.@sync for _ in 1:warmup + mul!(C_gpuarrays, A, B) + end + times_gpuarrays = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed mul!(C_gpuarrays, A, B) + push!(times_gpuarrays, t * 1000) end + results["cuBLAS"] = times_gpuarrays + + return results +end + +#============================================================================= + Main +=============================================================================# + +function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T + name = something(name, "matmul ($M x $K) @ ($K x $N), $T, tiles=$tm x $tn x $tk") + println("--- $name ---") + data = prepare(; M, N, K, T) + result = run(data; tm, tn, tk) + verify(data, result) + println(" passed") end function main() diff --git a/examples/matmul.py b/examples/matmul.py new file mode 100644 index 0000000..ab367cf --- /dev/null +++ b/examples/matmul.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Matrix Multiplication example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct +from math import ceil + +def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M): + """Get the global IDs of the current CUDA block in a 1D grid.""" + bid = ct.bid(0) + num_bid_m = ct.cdiv(M, tm) + num_bid_n = ct.cdiv(N, tn) + num_bid_in_group = GROUP_SIZE_M * num_bid_n + group_id = bid // num_bid_in_group + first_bid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M) + bid_m = first_bid_m + (bid % group_size_m) + bid_n = (bid % num_bid_in_group) // group_size_m + return bid_m, bid_n + + +@ct.kernel +def matmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): + GROUP_SIZE_M = 8 + M = A.shape[0] + N = B.shape[1] + bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) + + num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) + accumulator = ct.full((tm, tn), 0, dtype=ct.float32) + + # Convert fp32 to tf32 for tensor cores + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + + for k in range(num_tiles_k): + a = ct.load(A, index=(bidx, k), shape=(tm, tk)).astype(dtype) + b = ct.load(B, index=(k, bidy), shape=(tk, tn)).astype(dtype) + accumulator = ct.mma(a, b, accumulator) + + accumulator = ct.astype(accumulator, C.dtype) + ct.store(C, index=(bidx, bidy), tile=accumulator) + + +#============================================================================= +# Example harness +#============================================================================= + +def prepare(*, benchmark: bool = False, M: int = None, N: int = None, K: int = None, dtype=np.float32): + """Allocate and initialize data for matmul.""" + if M is None: + M = 4096 if benchmark else 256 + if N is None: + N = 4096 if benchmark else 256 + if K is None: + K = 4096 if benchmark else 256 + return { + "A": cp.random.randn(M, K).astype(dtype), + "B": cp.random.randn(K, N).astype(dtype), + "C": cp.empty((M, N), dtype=dtype), + "M": M, + "N": N, + "K": K + } + + +def run(data, *, tm: int = 64, tn: int = 64, tk: int = 64, nruns: int = 1, warmup: int = 0): + """Run matmul kernel with timing.""" + A, B, C = data["A"], data["B"], data["C"] + M, N = data["M"], data["N"] + + grid_m = ceil(M / tm) + grid_n = ceil(N / tn) + grid = (grid_m * grid_n, 1, 1) + stream = cp.cuda.get_current_stream() + + # Warmup + for _ in range(warmup): + ct.launch(stream, grid, matmul_cutile_kernel, (A, B, C, tm, tn, tk)) + cp.cuda.runtime.deviceSynchronize() + + # Timed runs + times = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + 
start.record(stream) + ct.launch(stream, grid, matmul_cutile_kernel, (A, B, C, tm, tn, tk)) + end.record(stream) + end.synchronize() + times.append(cp.cuda.get_elapsed_time(start, end)) # ms + + return {"C": C, "times": times} + + +def verify(data, result): + """Verify matmul results.""" + expected = cp.asnumpy(data["A"]) @ cp.asnumpy(data["B"]) + # TF32 has reduced precision + assert np.allclose(cp.asnumpy(result["C"]), expected, rtol=1e-1, atol=1e-1), \ + f"matmul incorrect! max diff: {np.max(np.abs(cp.asnumpy(result['C']) - expected))}" + + +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + results = {} + A, B = data["A"], data["B"] + M, N = data["M"], data["N"] + C_cupy = cp.zeros((M, N), dtype=A.dtype) + + stream = cp.cuda.get_current_stream() + + # CuPy matmul (uses cuBLAS) + for _ in range(warmup): + cp.matmul(A, B, out=C_cupy) + cp.cuda.runtime.deviceSynchronize() + + times_cupy = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + cp.matmul(A, B, out=C_cupy) + end.record(stream) + end.synchronize() + times_cupy.append(cp.cuda.get_elapsed_time(start, end)) + results["cuBLAS"] = times_cupy + + return results + + +#============================================================================= +# Main +#============================================================================= + +def test_matmul(M, N, K, tm, tn, tk, dtype=np.float32, name=None): + """Test matmul with given parameters.""" + name = name or f"matmul ({M}x{K}) @ ({K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" + print(f"--- {name} ---") + data = prepare(M=M, N=N, K=K, dtype=dtype) + result = run(data, tm=tm, tn=tn, tk=tk) + verify(data, result) + print(" passed") + + +def main(): + print("--- cuTile Matrix Multiplication Examples ---\n") + + test_matmul(256, 256, 256, 32, 32, 32) + test_matmul(512, 512, 512, 64, 64, 64) + test_matmul(256, 512, 128, 32, 32, 32) + test_matmul(1024, 1024, 1024, 64, 64, 64) + + print("\n--- All matmul examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/transpose.jl b/examples/transpose.jl index 9c5ec26..773c988 100644 --- a/examples/transpose.jl +++ b/examples/transpose.jl @@ -17,20 +17,103 @@ function transpose_kernel(x::ct.TileArray{T,2}, y::ct.TileArray{T,2}, return end -function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T - name = something(name, "transpose ($m x $n, $T, tiles=$tm x $tn)") - println("--- $name ---") +#============================================================================= + Example harness +=============================================================================# + +function prepare(; benchmark::Bool=false, + m::Int=benchmark ? 8192 : 1024, + n::Int=benchmark ? 
8192 : 512, + T::DataType=Float32) x = CUDA.rand(T, m, n) - y = CUDA.zeros(T, n, m) + return (; + x, + y = similar(x, n, m), + m, n + ) +end + +function run(data; tm::Int=64, tn::Int=64, nruns::Int=1, warmup::Int=0) + (; x, y, m, n) = data + grid = (cld(m, tm), cld(n, tn)) + + CUDA.@sync for _ in 1:warmup + ct.launch(transpose_kernel, grid, x, y, + ct.Constant(tm), ct.Constant(tn)) + end + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(transpose_kernel, grid, x, y, + ct.Constant(tm), ct.Constant(tn)) + push!(times, t * 1000) # ms + end + + return (; y, times) +end + +function verify(data, result) + @assert Array(result.y) ≈ transpose(Array(data.x)) +end - grid_x = cld(m, tm) - grid_y = cld(n, tn) +#============================================================================= + Reference implementations for benchmarking +=============================================================================# - # Launch with ct.launch - CuArrays are auto-converted to TileArray - ct.launch(transpose_kernel, (grid_x, grid_y), x, y, - ct.Constant(tm), ct.Constant(tn)) +# Simple SIMT transpose kernel (naive, no shared memory) +function simt_naive_kernel(x, y, m, n) + i = (blockIdx().x - 1) * blockDim().x + threadIdx().x + j = (blockIdx().y - 1) * blockDim().y + threadIdx().y + if i <= m && j <= n + @inbounds y[j, i] = x[i, j] + end + return +end + +function run_others(data; nruns::Int=1, warmup::Int=0) + (; x, m, n) = data + results = Dict{String, Vector{Float64}}() + + y_gpuarrays = similar(x, n, m) + y_simt = similar(x, n, m) + + # GPUArrays (permutedims) + CUDA.@sync for _ in 1:warmup + permutedims!(y_gpuarrays, x, (2, 1)) + end + times_gpuarrays = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed permutedims!(y_gpuarrays, x, (2, 1)) + push!(times_gpuarrays, t * 1000) + end + results["GPUArrays"] = times_gpuarrays + + # SIMT naive kernel + threads = (16, 16) + blocks = (cld(m, threads[1]), cld(n, threads[2])) + CUDA.@sync for _ in 1:warmup + @cuda threads=threads blocks=blocks simt_naive_kernel(x, y_simt, m, n) + end + times_simt = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed @cuda threads=threads blocks=blocks simt_naive_kernel(x, y_simt, m, n) + push!(times_simt, t * 1000) + end + results["SIMT naive"] = times_simt - @assert Array(y) ≈ transpose(Array(x)) + return results +end + +#============================================================================= + Main +=============================================================================# + +function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T + name = something(name, "transpose ($m x $n, $T, tiles=$tm x $tn)") + println("--- $name ---") + data = prepare(; m, n, T) + result = run(data; tm, tn) + verify(data, result) println("✓ passed") end diff --git a/examples/transpose.py b/examples/transpose.py new file mode 100644 index 0000000..1996a3b --- /dev/null +++ b/examples/transpose.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +Matrix Transpose example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct + +@ct.kernel +def transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct.Constant[int]): + pid_m = ct.bid(0) + pid_n = ct.bid(1) + tile = ct.load(input, index=(pid_m, pid_n), shape=(tile_m, tile_n)) + tile_t = ct.transpose(tile) + ct.store(output, index=(pid_n, pid_m), tile=tile_t) + + +#============================================================================= +# Example harness 
+#=============================================================================
+
+def prepare(*, benchmark: bool = False, M: int = None, N: int = None, dtype=np.float32):
+    """Allocate and initialize data for transpose."""
+    if M is None:
+        M = 8192 if benchmark else 1024
+    if N is None:
+        N = 8192 if benchmark else 512
+    return {
+        "input": cp.random.rand(M, N).astype(dtype),
+        "output": cp.empty((N, M), dtype=dtype),
+        "M": M,
+        "N": N
+    }
+
+
+def run(data, *, tile_m: int = 64, tile_n: int = 64, nruns: int = 1, warmup: int = 0):
+    """Run transpose kernel with timing."""
+    input_arr = data["input"]
+    output_arr = data["output"]
+    M, N = data["M"], data["N"]
+
+    grid = (ct.cdiv(M, tile_m), ct.cdiv(N, tile_n), 1)
+    stream = cp.cuda.get_current_stream()
+
+    # Warmup
+    for _ in range(warmup):
+        ct.launch(stream, grid, transpose_cutile_kernel, (input_arr, output_arr, tile_m, tile_n))
+    cp.cuda.runtime.deviceSynchronize()
+
+    # Timed runs
+    times = []
+    for _ in range(nruns):
+        start = cp.cuda.Event()
+        end = cp.cuda.Event()
+        start.record(stream)
+        ct.launch(stream, grid, transpose_cutile_kernel, (input_arr, output_arr, tile_m, tile_n))
+        end.record(stream)
+        end.synchronize()
+        times.append(cp.cuda.get_elapsed_time(start, end))  # ms
+
+    return {"output": output_arr, "times": times}
+
+
+def verify(data, result):
+    """Verify transpose results."""
+    expected = cp.asnumpy(data["input"]).T
+    assert np.allclose(cp.asnumpy(result["output"]), expected), "transpose incorrect!"
+
+
+#=============================================================================
+# Reference implementations for benchmarking
+#=============================================================================
+
+def run_others(data, *, nruns: int = 1, warmup: int = 0):
+    """Run reference implementations for comparison."""
+    results = {}
+    input_arr = data["input"]
+    M, N = data["M"], data["N"]
+    output_cupy = cp.zeros((N, M), dtype=input_arr.dtype)
+
+    stream = cp.cuda.get_current_stream()
+
+    # CuPy transpose
+    for _ in range(warmup):
+        cp.copyto(output_cupy, input_arr.T)
+    cp.cuda.runtime.deviceSynchronize()
+
+    times_cupy = []
+    for _ in range(nruns):
+        start = cp.cuda.Event()
+        end = cp.cuda.Event()
+        start.record(stream)
+        cp.copyto(output_cupy, input_arr.T)
+        end.record(stream)
+        end.synchronize()
+        times_cupy.append(cp.cuda.get_elapsed_time(start, end))
+    results["CuPy"] = times_cupy
+
+    return results
+
+
+#=============================================================================
+# Main
+#=============================================================================
+
+def test_transpose(M, N, tile_m, tile_n, dtype=np.float32, name=None):
+    """Test transpose with given parameters."""
+    name = name or f"transpose ({M}x{N}), tiles={tile_m}x{tile_n}, dtype={dtype.__name__}"
+    print(f"--- {name} ---")
+    data = prepare(M=M, N=N, dtype=dtype)
+    result = run(data, tile_m=tile_m, tile_n=tile_n)
+    verify(data, result)
+    print(" passed")
+
+
+def main():
+    print("--- cuTile Matrix Transpose Examples ---\n")
+
+    test_transpose(256, 256, 32, 32)
+    test_transpose(512, 512, 64, 64)
+    test_transpose(256, 512, 32, 64)
+    test_transpose(1024, 1024, 64, 64)
+
+    print("\n--- All transpose examples completed ---")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/vadd.jl b/examples/vadd.jl
index 9689fc9..14444e0 100644
--- a/examples/vadd.jl
+++ b/examples/vadd.jl
@@ -5,8 +5,7 @@
 using CUDA
 import cuTile as ct
 
-# 1D kernel with TileArray and constant tile size
-# TileArray carries size/stride metadata, Constant is a ghost type
+# 1D kernel
 function vec_add_kernel_1d(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.TileArray{T,1},
                            tile::ct.Constant{Int}) where {T}
     bid = ct.bid(1)
@@ -16,7 +15,7 @@ function vec_add_kernel_1d(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.Til
     return
 end
 
-# 2D kernel with TileArray and constant tile sizes
+# 2D kernel
 function vec_add_kernel_2d(a::ct.TileArray{T,2}, b::ct.TileArray{T,2}, c::ct.TileArray{T,2},
                            tile_x::ct.Constant{Int}, tile_y::ct.Constant{Int}) where {T}
     bid_x = ct.bid(1)
@@ -27,37 +26,7 @@ function vec_add_kernel_2d(a::ct.TileArray{T,2}, b::ct.TileArray{T,2}, c::ct.Til
     return
 end
 
-function test_add_1d(::Type{T}, n, tile; name=nothing) where T
-    name = something(name, "1D vec_add ($n elements, $T, tile=$tile)")
-    println("--- $name ---")
-    a, b = CUDA.rand(T, n), CUDA.rand(T, n)
-    c = CUDA.zeros(T, n)
-
-    # Launch with ct.launch - CuArrays are auto-converted to TileArray
-    # Constant parameters are ghost types - filtered out at launch time
-    ct.launch(vec_add_kernel_1d, cld(n, tile), a, b, c, ct.Constant(tile))
-
-    @assert Array(c) ≈ Array(a) + Array(b)
-    println("✓ passed")
-end
-
-function test_add_2d(::Type{T}, m, n, tile_x, tile_y; name=nothing) where T
-    name = something(name, "2D vec_add ($m x $n, $T, tiles=$tile_x x $tile_y)")
-    println("--- $name ---")
-    a, b = CUDA.rand(T, m, n), CUDA.rand(T, m, n)
-    c = CUDA.zeros(T, m, n)
-
-    # Launch with ct.launch - CuArrays are auto-converted to TileArray
-    ct.launch(vec_add_kernel_2d, (cld(m, tile_x), cld(n, tile_y)), a, b, c,
-              ct.Constant(tile_x), ct.Constant(tile_y))
-
-    @assert Array(c) ≈ Array(a) + Array(b)
-    println("✓ passed")
-end
-
-# 1D kernel using gather/scatter (explicit index-based memory access)
-# This demonstrates the gather/scatter API for cases where you need
-# explicit control over indices (e.g., for non-contiguous access patterns)
+# 1D kernel using gather/scatter
 function vec_add_kernel_1d_gather(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.TileArray{T,1},
                                   tile::ct.Constant{Int}) where {T}
     bid = ct.bid(1)
@@ -74,45 +43,169 @@ function vec_add_kernel_1d_gather(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c:
     return
 end
 
-function test_add_1d_gather(::Type{T}, n, tile; name=nothing) where T
-    name = something(name, "1D vec_add gather ($n elements, $T, tile=$tile)")
-    println("--- $name ---")
-    a, b = CUDA.rand(T, n), CUDA.rand(T, n)
-    c = CUDA.zeros(T, n)
-    ct.launch(vec_add_kernel_1d_gather, cld(n, tile), a, b, c, ct.Constant(tile))
+#=============================================================================
+# Example harness
+=============================================================================#
+
+function prepare(; benchmark::Bool=false,
+                 shape::Tuple=benchmark ? (2^27,) : (1_024_000,),
+                 use_gather::Bool=false, T::DataType=Float32)
+    a = CUDA.rand(T, shape...)
+    return (;
+        a,
+        b = CUDA.rand(T, shape...),
+        c = similar(a),
+        shape,
+        use_gather
+    )
+end
 
-    @assert Array(c) ≈ Array(a) + Array(b)
-    println("✓ passed")
+function run(data; tile::Union{Int, Tuple{Int,Int}}=1024, nruns::Int=1, warmup::Int=0)
+    (; a, b, c, shape, use_gather) = data
+
+    if length(shape) == 2
+        # 2D case
+        m, n = shape
+        tile_x, tile_y = tile isa Tuple ? tile : (tile, tile)
+        grid = (cld(m, tile_x), cld(n, tile_y))
+
+        CUDA.@sync for _ in 1:warmup
+            ct.launch(vec_add_kernel_2d, grid, a, b, c,
+                      ct.Constant(tile_x), ct.Constant(tile_y))
+        end
+
+        times = Float64[]
+        for _ in 1:nruns
+            t = CUDA.@elapsed ct.launch(vec_add_kernel_2d, grid, a, b, c,
+                                        ct.Constant(tile_x), ct.Constant(tile_y))
+            push!(times, t * 1000) # ms
+        end
+    else
+        # 1D case
+        n = shape[1]
+        tile_val = tile isa Tuple ? tile[1] : tile
+        grid = cld(n, tile_val)
+        kernel = use_gather ? vec_add_kernel_1d_gather : vec_add_kernel_1d
+
+        CUDA.@sync for _ in 1:warmup
+            ct.launch(kernel, grid, a, b, c, ct.Constant(tile_val))
+        end
+
+        times = Float64[]
+        for _ in 1:nruns
+            t = CUDA.@elapsed ct.launch(kernel, grid, a, b, c, ct.Constant(tile_val))
+            push!(times, t * 1000) # ms
+        end
+    end
+
+    return (; c, times)
+end
+
+function verify(data, result)
+    @assert Array(result.c) ≈ Array(data.a) + Array(data.b)
+end
+
+
+#=============================================================================
+# Reference implementations for benchmarking
+=============================================================================#
+
+# Simple SIMT kernel for comparison
+function simt_kernel(a, b, c, n)
+    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
+    if i <= n
+        @inbounds c[i] = a[i] + b[i]
+    end
+    return
+end
+
+function run_others(data; nruns::Int=1, warmup::Int=0)
+    (; a, b, c, shape) = data
+    results = Dict{String, Vector{Float64}}()
+
+    if length(shape) == 1
+        n = shape[1]
+        c_gpuarrays = similar(c)
+        c_simt = similar(c)
+
+        # GPUArrays (broadcasting)
+        CUDA.@sync for _ in 1:warmup
+            c_gpuarrays .= a .+ b
+        end
+        times_gpuarrays = Float64[]
+        for _ in 1:nruns
+            t = CUDA.@elapsed c_gpuarrays .= a .+ b
+            push!(times_gpuarrays, t * 1000)
+        end
+        results["GPUArrays"] = times_gpuarrays
+
+        # SIMT kernel
+        threads = 256
+        blocks = cld(n, threads)
+        CUDA.@sync for _ in 1:warmup
+            @cuda threads=threads blocks=blocks simt_kernel(a, b, c_simt, n)
+        end
+        times_simt = Float64[]
+        for _ in 1:nruns
+            t = CUDA.@elapsed @cuda threads=threads blocks=blocks simt_kernel(a, b, c_simt, n)
+            push!(times_simt, t * 1000)
+        end
+        results["SIMT"] = times_simt
+    end
+
+    return results
+end
+
+
+#=============================================================================
+# Main
+=============================================================================#
+
+function test_vadd(shape, tile; use_gather::Bool=false, T::DataType=Float32, name=nothing)
+    if name === nothing
+        if length(shape) == 2
+            name = "2D vec_add ($(shape[1]) x $(shape[2]), $T, tile=$tile)"
+        elseif use_gather
+            name = "1D vec_add gather ($(shape[1]) elements, $T, tile=$tile)"
+        else
+            name = "1D vec_add ($(shape[1]) elements, $T, tile=$tile)"
+        end
+    end
+    println("--- $name ---")
+    data = prepare(; shape, use_gather, T)
+    result = run(data; tile)
+    verify(data, result)
+    println(" passed")
 end
 
 function main()
     println("--- cuTile Vector/Matrix Addition Examples ---\n")
 
     # 1D tests with Float32
-    test_add_1d(Float32, 1_024_000, 1024)
-    test_add_1d(Float32, 2^20, 512)
+    test_vadd((1_024_000,), 1024)
+    test_vadd((2^20,), 512)
 
     # 1D tests with Float64
-    test_add_1d(Float64, 2^18, 512)
+    test_vadd((2^18,), 512; T=Float64)
 
     # 1D tests with Float16
-    test_add_1d(Float16, 1_024_000, 1024)
+    test_vadd((1_024_000,), 1024; T=Float16)
 
     # 2D tests with Float32
-    test_add_2d(Float32, 2048, 1024, 32, 32)
-    test_add_2d(Float32, 1024, 2048, 64, 64)
+    test_vadd((2048, 1024), (32, 32))
+    test_vadd((1024, 2048), (64, 64))
 
     # 2D tests with Float64
-    test_add_2d(Float64, 1024, 512, 32, 32)
+    test_vadd((1024, 512), (32, 32); T=Float64)
 
     # 2D tests with Float16
-    test_add_2d(Float16, 1024, 1024, 64, 64)
+    test_vadd((1024, 1024), (64, 64); T=Float16)
 
     # 1D gather/scatter tests with Float32
     # Uses explicit index-based memory access instead of tiled loads/stores
-    test_add_1d_gather(Float32, 1_024_000, 1024)
-    test_add_1d_gather(Float32, 2^20, 512)
+    test_vadd((1_024_000,), 1024; use_gather=true)
+    test_vadd((2^20,), 512; use_gather=true)
 
     println("\n--- All addition examples completed ---")
 end
diff --git a/examples/vadd.py b/examples/vadd.py
new file mode 100644
index 0000000..566ba94
--- /dev/null
+++ b/examples/vadd.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python3
+"""
+Vector Addition example - cuTile Python
+"""
+
+import cupy as cp
+import numpy as np
+import cuda.tile as ct
+
+# 1D kernel
+@ct.kernel
+def vadd_kernel_1d(a, b, c, tile_size: ct.Constant[int]):
+    pid = ct.bid(0)
+    tile_a = ct.load(a, index=(pid,), shape=(tile_size,))
+    tile_b = ct.load(b, index=(pid,), shape=(tile_size,))
+    result = tile_a + tile_b
+    ct.store(c, index=(pid,), tile=result)
+
+
+# 2D kernel
+@ct.kernel
+def vadd_kernel_2d(a, b, c, tile_x: ct.Constant[int], tile_y: ct.Constant[int]):
+    pid_x = ct.bid(0)
+    pid_y = ct.bid(1)
+    tile_a = ct.load(a, index=(pid_x, pid_y), shape=(tile_x, tile_y))
+    tile_b = ct.load(b, index=(pid_x, pid_y), shape=(tile_x, tile_y))
+    result = tile_a + tile_b
+    ct.store(c, index=(pid_x, pid_y), tile=result)
+
+
+# 1D kernel with gather/scatter
+@ct.kernel
+def vadd_kernel_1d_gather(a, b, c, tile_size: ct.Constant[int]):
+    pid = ct.bid(0)
+    # Create index tile for this block's elements
+    offsets = ct.arange(tile_size, dtype=ct.int32)
+    base = pid * tile_size
+    indices = base + offsets
+
+    # Gather, add, scatter
+    tile_a = ct.gather(a, indices)
+    tile_b = ct.gather(b, indices)
+    result = tile_a + tile_b
+    ct.scatter(c, indices, result)
+
+
+#=============================================================================
+# Example harness
+#=============================================================================
+
+def prepare(*, benchmark: bool = False, shape: tuple = None, use_gather: bool = False, dtype=np.float32):
+    """Allocate and initialize data for vector addition."""
+    if shape is None:
+        shape = (2**27,) if benchmark else (1_024_000,)
+    a = cp.random.rand(*shape).astype(dtype)
+    return {
+        "a": a,
+        "b": cp.random.rand(*shape).astype(dtype),
+        "c": cp.empty_like(a),
+        "shape": shape,
+        "use_gather": use_gather
+    }
+
+
+def run(data, *, tile=1024, nruns: int = 1, warmup: int = 0):
+    """Run vector addition kernel with timing."""
+    a, b, c = data["a"], data["b"], data["c"]
+    shape = data["shape"]
+    use_gather = data["use_gather"]
+
+    stream = cp.cuda.get_current_stream()
+
+    if len(shape) == 2:
+        # 2D case
+        m, n = shape
+        tile_x, tile_y = tile if isinstance(tile, tuple) else (tile, tile)
+        grid = (ct.cdiv(m, tile_x), ct.cdiv(n, tile_y), 1)
+
+        def run_kernel():
+            ct.launch(stream, grid, vadd_kernel_2d, (a, b, c, tile_x, tile_y))
+    else:
+        # 1D case
+        n = shape[0]
+        tile_val = tile[0] if isinstance(tile, tuple) else tile
+        grid = (ct.cdiv(n, tile_val), 1, 1)
+
+        if use_gather:
+            def run_kernel():
+                ct.launch(stream, grid, vadd_kernel_1d_gather, (a, b, c, tile_val))
+        else:
+            def run_kernel():
+                ct.launch(stream, grid, vadd_kernel_1d, (a, b, c, tile_val))
+
+    # Warmup
+    for _ in range(warmup):
+        run_kernel()
+    cp.cuda.runtime.deviceSynchronize()
+
+    # Timed runs
+    times = []
+    for _ in range(nruns):
+        start = cp.cuda.Event()
+        end = cp.cuda.Event()
+        start.record(stream)
+        run_kernel()
+        end.record(stream)
+        end.synchronize()
+        times.append(cp.cuda.get_elapsed_time(start, end))  # ms
+
+    return {"c": c, "times": times}
+
+
+def verify(data, result):
+    """Verify vector addition results."""
+    expected = cp.asnumpy(data["a"]) + cp.asnumpy(data["b"])
+    assert np.allclose(cp.asnumpy(result["c"]), expected), "vadd incorrect!"
+
+
+#=============================================================================
+# Reference implementations for benchmarking
+#=============================================================================
+
+def run_others(data, *, nruns: int = 1, warmup: int = 0):
+    """Run reference implementations for comparison."""
+    results = {}
+    shape = data["shape"]
+
+    if len(shape) == 1:
+        a, b = data["a"], data["b"]
+        c_cupy = cp.zeros_like(a)
+
+        stream = cp.cuda.get_current_stream()
+
+        # CuPy (broadcasting)
+        for _ in range(warmup):
+            cp.add(a, b, out=c_cupy)
+        cp.cuda.runtime.deviceSynchronize()
+
+        times_cupy = []
+        for _ in range(nruns):
+            start = cp.cuda.Event()
+            end = cp.cuda.Event()
+            start.record(stream)
+            cp.add(a, b, out=c_cupy)
+            end.record(stream)
+            end.synchronize()
+            times_cupy.append(cp.cuda.get_elapsed_time(start, end))
+        results["CuPy"] = times_cupy
+
+    return results
+
+
+#=============================================================================
+# Main
+#=============================================================================
+
+def test_vadd(shape, tile, use_gather=False, dtype=np.float32, name=None):
+    """Test vector addition with given parameters."""
+    if name is None:
+        if len(shape) == 2:
+            name = f"2D vadd ({shape[0]}x{shape[1]}), tile={tile}, dtype={dtype.__name__}"
+        elif use_gather:
+            name = f"1D vadd gather size={shape[0]}, tile={tile}, dtype={dtype.__name__}"
+        else:
+            name = f"1D vadd size={shape[0]}, tile={tile}, dtype={dtype.__name__}"
+    print(f"--- {name} ---")
+    data = prepare(shape=shape, use_gather=use_gather, dtype=dtype)
+    result = run(data, tile=tile)
+    verify(data, result)
+    print(" passed")
+
+
+def main():
+    print("--- cuTile Vector Addition Examples ---\n")
+
+    # 1D tests with float32
+    test_vadd((1_024_000,), 1024)
+    test_vadd((2**20,), 512)
+
+    # 1D tests with float64
+    test_vadd((2**18,), 512, dtype=np.float64)
+
+    # 1D tests with float16
+    test_vadd((1_024_000,), 1024, dtype=np.float16)
+
+    # 2D tests with float32
+    test_vadd((2048, 1024), (32, 32))
+    test_vadd((1024, 2048), (64, 64))
+
+    # 2D tests with float64
+    test_vadd((1024, 512), (32, 32), dtype=np.float64)
+
+    # 2D tests with float16
+    test_vadd((1024, 1024), (64, 64), dtype=np.float16)
+
+    # 1D gather/scatter tests
+    test_vadd((1_024_000,), 1024, use_gather=True)
+    test_vadd((2**20,), 512, use_gather=True)
+
+    print("\n--- All vadd examples completed ---")
+
+
+if __name__ == "__main__":
+    main()