From d12113f638e0c5d6809098ebf8a0aa3d9c9d0b08 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 13 Jan 2026 11:08:02 -0500 Subject: [PATCH 1/7] Make benchmarks reuse examples. --- .gitignore | 1 + examples/batchmatmul.jl | 52 +++--- examples/batchmatmul.py | 91 +++++++++++ examples/benchmarks.jl | 353 ++++------------------------------------ examples/benchmarks.py | 227 ++------------------------ examples/fft.jl | 3 +- examples/fft.py | 166 +++++++++++++++++++ examples/layernorm.jl | 41 +++++ examples/layernorm.py | 102 ++++++++++++ examples/matmul.jl | 42 ++--- examples/matmul.py | 98 +++++++++++ examples/transpose.jl | 25 ++- examples/transpose.py | 61 +++++++ examples/vadd.jl | 74 ++++++--- examples/vadd.py | 56 +++++++ 15 files changed, 775 insertions(+), 617 deletions(-) create mode 100644 examples/batchmatmul.py create mode 100644 examples/fft.py create mode 100644 examples/layernorm.py create mode 100644 examples/matmul.py create mode 100644 examples/transpose.py create mode 100644 examples/vadd.py diff --git a/.gitignore b/.gitignore index d2c7fe1..daf5565 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ LocalPreferences.toml CLAUDE.md AGENTS.md TODO.md +__pycache__ diff --git a/examples/batchmatmul.jl b/examples/batchmatmul.jl index 89b2ba4..0ac108d 100644 --- a/examples/batchmatmul.jl +++ b/examples/batchmatmul.jl @@ -57,37 +57,37 @@ function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.T return nothing end -function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T - name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk") - println("--- $name ---") - - # Batch-last ordering for optimal column-major access - A = CUDA.rand(T, M, K, Batch) - B = CUDA.rand(T, K, N, Batch) - C = CUDA.zeros(T, M, N, Batch) - - # 3D grid: (M_tiles, N_tiles, Batch) +# Run batch matmul - for benchmarking or programmatic use +function run_batchmatmul(; M::Int, K::Int, N::Int, Batch::Int, tm::Int, tn::Int, tk::Int, + T::DataType=Float32, + A::Union{CuArray,Nothing}=nothing, + B::Union{CuArray,Nothing}=nothing, + C::Union{CuArray,Nothing}=nothing, + validate::Bool=false) + A = something(A, CUDA.rand(T, M, K, Batch)) + B = something(B, CUDA.rand(T, K, N, Batch)) + C = something(C, CUDA.zeros(T, M, N, Batch)) grid = (cld(M, tm), cld(N, tn), Batch) - - # Launch kernel ct.launch(batch_matmul_kernel, grid, A, B, C, ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) - - # Verify result - compute batched matmul on CPU - A_cpu = Array(A) - B_cpu = Array(B) - expected = similar(A_cpu, M, N, Batch) - for b in 1:Batch - expected[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b] + if validate + A_cpu = Array(A) + B_cpu = Array(B) + expected = similar(A_cpu, M, N, Batch) + for b in 1:Batch + expected[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b] + end + result = Array(C) + @assert isapprox(result, expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(result - expected)))" end - result = Array(C) + return (; A, B, C) +end - if isapprox(result, expected, rtol=1e-2, atol=1e-2) - println(" passed") - else - max_diff = maximum(abs.(result - expected)) - println(" FAILED (max diff: $max_diff)") - end +function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T + name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk") + println("--- $name ---") + run_batchmatmul(; M, K, N, Batch, tm, tn, tk, T, validate=true) + println(" passed") end function 
main() diff --git a/examples/batchmatmul.py b/examples/batchmatmul.py new file mode 100644 index 0000000..3723bd7 --- /dev/null +++ b/examples/batchmatmul.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +""" +Batch Matrix Multiplication example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct +from math import ceil + +@ct.kernel +def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): + """CuTile kernel for batch matrix multiplication + A has shape (Batch, M, K), B has shape (Batch, K, N) and C has shape (Batch, M, N) + Grid: (Batch, M_tiles, N_tiles) + """ + pid_batch = ct.bid(0) + bidx = ct.bid(1) + bidy = ct.bid(2) + + num_k_tiles = ct.cdiv(A.shape[2], tk) + accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32) + zero_pad = ct.PaddingMode.ZERO + + for k in range(num_k_tiles): + a = ct.load(A, index=(pid_batch, bidx, k), shape=(1, tm, tk), padding_mode=zero_pad) + a = ct.reshape(a, (tm, tk)) + + b = ct.load(B, index=(pid_batch, k, bidy), shape=(1, tk, tn), padding_mode=zero_pad) + b = ct.reshape(b, (tk, tn)) + + accumulator = ct.mma(a, b, acc=accumulator) + + result = ct.astype(accumulator, C.dtype) + result_3d = ct.reshape(result, (1, tm, tn)) + ct.store(C, index=(pid_batch, bidx, bidy), tile=result_3d) + + +def run_batchmatmul(*, Batch: int = 4, M: int = 256, K: int = 128, N: int = 256, + tm: int = 128, tn: int = 256, tk: int = 64, + dtype=np.float16, A=None, B=None, C=None, validate: bool = False): + """Run batch matrix multiplication. Returns (A, B, C) arrays for benchmarking.""" + if A is None: + A = cp.random.randn(Batch, M, K).astype(dtype) + else: + Batch, M, K = A.shape + if B is None: + B = cp.random.randn(Batch, K, N).astype(dtype) + else: + Batch, K, N = B.shape + if C is None: + C = cp.zeros((Batch, M, N), dtype=dtype) + + grid = (Batch, ceil(M / tm), ceil(N / tn)) + stream = cp.cuda.get_current_stream() + ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk)) + + if validate: + cp.cuda.runtime.deviceSynchronize() + A_np = cp.asnumpy(A).astype(np.float32) + B_np = cp.asnumpy(B).astype(np.float32) + C_np = cp.asnumpy(C).astype(np.float32) + expected = np.zeros((Batch, M, N), dtype=np.float32) + for b in range(Batch): + expected[b] = A_np[b] @ B_np[b] + assert np.allclose(C_np, expected, rtol=1e-1, atol=1e-1), \ + f"batchmatmul incorrect! 
max diff: {np.max(np.abs(C_np - expected))}" + + return A, B, C + + +def test_batchmatmul(Batch, M, K, N, tm, tn, tk, dtype=np.float16, name=None): + """Test batch matmul with given parameters.""" + name = name or f"batchmatmul ({Batch}x{M}x{K}) @ ({Batch}x{K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" + print(f"--- {name} ---") + run_batchmatmul(Batch=Batch, M=M, K=K, N=N, tm=tm, tn=tn, tk=tk, dtype=dtype, validate=True) + print(" passed") + + +def main(): + print("--- cuTile Batch Matrix Multiplication Examples ---\n") + + test_batchmatmul(4, 256, 128, 256, 32, 32, 32, np.float32) + test_batchmatmul(4, 512, 256, 512, 64, 64, 64, np.float32) + test_batchmatmul(4, 512, 256, 1024, 128, 256, 64, np.float16) + + print("\n--- All batchmatmul examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/benchmarks.jl b/examples/benchmarks.jl index 95071e8..2a7854c 100644 --- a/examples/benchmarks.jl +++ b/examples/benchmarks.jl @@ -4,11 +4,16 @@ # Compares: GPUArrays (generic), SIMT (CUDA.jl), cuTile # Kernels: vadd, transpose, matmul -using CUDA +# Include example files to reuse their kernels +include("vadd.jl") +include("transpose.jl") +include("matmul.jl") +include("batchmatmul.jl") +include("layernorm.jl") +include("fft.jl") + using LinearAlgebra using CUDA: GPUArrays -using FFTW -import cuTile as ct #============================================================================= Configuration @@ -86,7 +91,7 @@ end Vector Addition =============================================================================# -# SIMT kernel +# SIMT kernel (benchmark-specific) function vadd_simt_kernel!(a, b, c) i = (blockIdx().x - 1) * blockDim().x + threadIdx().x if i <= length(c) @@ -95,15 +100,7 @@ function vadd_simt_kernel!(a, b, c) return end -# cuTile kernel -function vadd_cutile_kernel(a, b, c, tile_size::ct.Constant{Int}) - pid = ct.bid(1) - tile_a = ct.load(a, pid, (tile_size[],)) - tile_b = ct.load(b, pid, (tile_size[],)) - result = tile_a + tile_b - ct.store(c, pid, result) - return -end +# cuTile kernel: use vec_add_kernel_1d from vadd.jl function benchmark_vadd() println("\nBenchmarking Vector Addition...") @@ -136,9 +133,9 @@ function benchmark_vadd() min_t, mean_t = benchmark_kernel(simt_f) push!(results, BenchmarkResult("SIMT (CUDA.jl)", min_t, mean_t)) - # cuTile + # cuTile (uses vec_add_kernel_1d from vadd.jl) grid = (cld(VADD_SIZE, VADD_TILE), 1, 1) - cutile_f = () -> ct.launch(vadd_cutile_kernel, grid, a, b, c, ct.Constant(VADD_TILE)) + cutile_f = () -> ct.launch(vec_add_kernel_1d, grid, a, b, c, ct.Constant(VADD_TILE)) cutile_f() CUDA.synchronize() @assert Array(c) ≈ expected "cuTile incorrect!" 
@@ -167,7 +164,7 @@ function transpose_simt_naive_kernel!(input, output, M, N) return end -# SIMT shared memory kernel +# SIMT shared memory kernel (benchmark-specific) function transpose_simt_shared_kernel!(input, output, M, N) TILE = 32 tile = CuStaticSharedArray(Float32, (TILE+1, TILE)) @@ -189,15 +186,7 @@ function transpose_simt_shared_kernel!(input, output, M, N) return end -# cuTile kernel -function transpose_cutile_kernel(input, output, tile_m::ct.Constant{Int}, tile_n::ct.Constant{Int}) - pid_m = ct.bid(1) - pid_n = ct.bid(2) - tile = ct.load(input, (pid_m, pid_n), (tile_m[], tile_n[])) - tile_t = ct.transpose(tile) - ct.store(output, (pid_n, pid_m), tile_t) - return -end +# cuTile kernel: use transpose_kernel from transpose.jl function benchmark_transpose() println("\nBenchmarking Matrix Transpose...") @@ -240,10 +229,10 @@ function benchmark_transpose() min_t, mean_t = benchmark_kernel(simt_shared_f) push!(results, BenchmarkResult("SIMT shared", min_t, mean_t)) - # cuTile + # cuTile (uses transpose_kernel from transpose.jl) fill!(output, 0) grid = (cld(M, TRANSPOSE_TILE_M), cld(N, TRANSPOSE_TILE_N), 1) - cutile_f = () -> ct.launch(transpose_cutile_kernel, grid, input, output, + cutile_f = () -> ct.launch(transpose_kernel, grid, input, output, ct.Constant(TRANSPOSE_TILE_M), ct.Constant(TRANSPOSE_TILE_N)) cutile_f() CUDA.synchronize() @@ -263,47 +252,7 @@ end Matrix Multiplication =============================================================================# -# 2D swizzle for better L2 cache locality (using 0-indexed block IDs) -@inline function swizzle_2d(M, N, tm, tn, GROUP_SIZE_M, bid) - num_bid_m = cld(M, Int32(tm)) - num_bid_n = cld(N, Int32(tn)) - num_bid_in_group = Int32(GROUP_SIZE_M) * num_bid_n - group_id = fld(bid, num_bid_in_group) - first_bid_m = group_id * Int32(GROUP_SIZE_M) - group_size_m = min(num_bid_m - first_bid_m, Int32(GROUP_SIZE_M)) - bid_m = first_bid_m + rem(bid, group_size_m) - bid_n = fld(rem(bid, num_bid_in_group), group_size_m) - return bid_m, bid_n -end - -# cuTile matmul kernel with TF32 tensor cores -function matmul_cutile_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArray{T,2}, - tm::ct.Constant{Int}, tn::ct.Constant{Int}, tk::ct.Constant{Int}) where {T} - bid = ct.bid(1) - M = A.sizes[1] - N = B.sizes[2] - bid_m_0, bid_n_0 = swizzle_2d(M, N, tm[], tn[], 8, bid - Int32(1)) - bid_m = bid_m_0 + Int32(1) - bid_n = bid_n_0 + Int32(1) - - num_k = ct.num_tiles(A, 2, (tm[], tk[])) - acc = ct.full((tm[], tn[]), zero(Float32), Float32) - - # Use TF32 for tensor cores - dtype = T === Float32 ? 
ct.TFloat32 : T - - k = Int32(1) - while k <= num_k - a = ct.astype(ct.load(A, (bid_m, k), (tm[], tk[])), dtype) - b = ct.astype(ct.load(B, (k, bid_n), (tk[], tn[])), dtype) - acc = muladd(a, b, acc) - k += Int32(1) - end - - result = ct.astype(acc, T) - ct.store(C, (bid_m, bid_n), result) - return nothing -end +# cuTile kernel: use matmul_kernel and swizzle_2d from matmul.jl function benchmark_matmul() println("\nBenchmarking Matrix Multiplication...") @@ -338,12 +287,12 @@ function benchmark_matmul() min_t, mean_t = benchmark_kernel(cublas_f) push!(results, BenchmarkResult("cuBLAS", min_t, mean_t)) - # cuTile + # cuTile (uses matmul_kernel from matmul.jl) fill!(C, 0) grid_m = cld(M, MATMUL_TM) grid_n = cld(N, MATMUL_TN) grid = (grid_m * grid_n, 1, 1) - cutile_f = () -> ct.launch(matmul_cutile_kernel, grid, A, B, C, + cutile_f = () -> ct.launch(matmul_kernel, grid, A, B, C, ct.Constant(MATMUL_TM), ct.Constant(MATMUL_TN), ct.Constant(MATMUL_TK)) cutile_f() CUDA.synchronize() @@ -384,7 +333,7 @@ const FFT_SIZE = 512 const FFT_FACTORS = (8, 8, 8) const FFT_ATOM_PACKING_DIM = 2 -# SIMT naive kernel (2-pass: compute mean/var, then normalize) +# SIMT naive kernel (benchmark-specific, 2-pass: compute mean/var, then normalize) function layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, eps) m = blockIdx().x @@ -414,55 +363,7 @@ function layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, eps) return end -# cuTile kernel (from layernorm.jl) -function layernorm_cutile_kernel(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1}, - B::ct.TileArray{Float32, 1}, Y::ct.TileArray{Float32, 2}, - Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1}, - eps::ct.Constant{Float32}, TILE_N::ct.Constant{Int}) - bid_m = ct.bid(1) - num_tiles = ct.num_tiles(X, 2, (1, TILE_N[])) - N = X.sizes[2] - - # Compute mean - mean = ct.full((1, TILE_N[]), 0.0f0, Float32) - j = Int32(1) - while j <= num_tiles - tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero) - mean = mean .+ tx - j += Int32(1) - end - mean = ct.reduce_sum(mean, 2) / N - ct.store(Mean, bid_m, mean) - - # Compute variance - var = ct.full((1, TILE_N[]), 0.0f0, Float32) - j = Int32(1) - while j <= num_tiles - tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero) - # Mask for valid elements - mask = ct.broadcast_to(((j - Int32(1)) * Int32(TILE_N[]) .+ ct.arange((TILE_N[],), Int32)) .<= N, (1, TILE_N[])) - centered_tx = ct.where(mask, tx .- mean, ct.full((1, TILE_N[]), 0.0f0, Float32)) - var = var .+ (centered_tx .^ 2.0f0) - j += Int32(1) - end - var = ct.reduce_sum(var, 2) / N - rstd = 1.0f0 ./ sqrt.(var .+ eps[]) - ct.store(Rstd, bid_m, rstd) - - # Normalize and apply affine transformation - j = Int32(1) - while j <= num_tiles - tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero) - tw = ct.load(W, j, (TILE_N[],); padding_mode=ct.PaddingMode.Zero) - tb = ct.load(B, j, (TILE_N[],); padding_mode=ct.PaddingMode.Zero) - ty = (tx .- mean) .* rstd - ty = ty .* tw .+ tb - ct.store(Y, (bid_m, j), ty) - j += Int32(1) - end - - return -end +# cuTile kernel: use layer_norm_fwd from layernorm.jl function benchmark_layernorm() println("\nBenchmarking Layer Normalization...") @@ -497,9 +398,9 @@ function benchmark_layernorm() min_t, mean_t = benchmark_kernel(simt_f) push!(results, BenchmarkResult("SIMT naive", min_t, mean_t)) - # cuTile + # cuTile (uses layer_norm_fwd from layernorm.jl) fill!(Y, 0); fill!(Mean, 0); fill!(Rstd, 0) - cutile_f = () -> ct.launch(layernorm_cutile_kernel, M, X, W, 
B, Y, Mean, Rstd, + cutile_f = () -> ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, ct.Constant(LAYERNORM_EPS), ct.Constant(LAYERNORM_TILE_N)) cutile_f() CUDA.synchronize() @@ -519,45 +420,7 @@ end Batch Matrix Multiplication =============================================================================# -# Batch matmul kernel (3D arrays with batch-last ordering) -# A: (M, K, Batch), B: (K, N, Batch), C: (M, N, Batch) -function batchmatmul_cutile_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.TileArray{T,3}, - tm::ct.Constant{Int}, tn::ct.Constant{Int}, - tk::ct.Constant{Int}) where {T} - bid_m = ct.bid(1) - bid_n = ct.bid(2) - pid_batch = ct.bid(3) - - K = A.sizes[2] - num_k = cld(K, Int32(tk[])) - - acc = ct.full((tm[], tn[]), zero(Float32), Float32) - - k = Int32(1) - while k <= num_k - a = ct.load(A, (bid_m, k, pid_batch), (tm[], tk[], 1); - padding_mode=ct.PaddingMode.Zero) - b = ct.load(B, (k, bid_n, pid_batch), (tk[], tn[], 1); - padding_mode=ct.PaddingMode.Zero) - - a_2d = ct.reshape(a, (tm[], tk[])) - b_2d = ct.reshape(b, (tk[], tn[])) - - if T === Float32 - a_2d = convert(ct.Tile{ct.TFloat32}, a_2d) - b_2d = convert(ct.Tile{ct.TFloat32}, b_2d) - end - - acc = muladd(a_2d, b_2d, acc) - k += Int32(1) - end - - result = convert(ct.Tile{T}, acc) - result_3d = ct.reshape(result, (tm[], tn[], 1)) - ct.store(C, (bid_m, bid_n, pid_batch), result_3d) - - return nothing -end +# cuTile kernel: use batch_matmul_kernel from batchmatmul.jl function benchmark_batchmatmul() println("\nBenchmarking Batch Matrix Multiplication...") @@ -593,10 +456,10 @@ function benchmark_batchmatmul() min_t, mean_t = benchmark_kernel(cublas_f) push!(results, BenchmarkResult("cuBLAS (loop)", min_t, mean_t)) - # cuTile + # cuTile (uses batch_matmul_kernel from batchmatmul.jl) fill!(C, 0) grid = (cld(M, BATCHMATMUL_TM), cld(N, BATCHMATMUL_TN), Batch) - cutile_f = () -> ct.launch(batchmatmul_cutile_kernel, grid, A, B, C, + cutile_f = () -> ct.launch(batch_matmul_kernel, grid, A, B, C, ct.Constant(BATCHMATMUL_TM), ct.Constant(BATCHMATMUL_TN), ct.Constant(BATCHMATMUL_TK)) cutile_f() @@ -616,165 +479,7 @@ end FFT (3-stage Cooley-Tukey) - Column-Major Version =============================================================================# -# FFT kernel - 3-stage Cooley-Tukey decomposition (column-major) -# Uses swapped dimensions and right-multiply for column-major compatibility. -# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving. 
-function fft_kernel( - x_packed_in::ct.TileArray{Float32, 3}, - y_packed_out::ct.TileArray{Float32, 3}, - W0::ct.TileArray{Float32, 3}, - W1::ct.TileArray{Float32, 3}, - W2::ct.TileArray{Float32, 3}, - T0::ct.TileArray{Float32, 3}, - T1::ct.TileArray{Float32, 3}, - n_const::ct.Constant{Int}, - f0_const::ct.Constant{Int}, - f1_const::ct.Constant{Int}, - f2_const::ct.Constant{Int}, - f0f1_const::ct.Constant{Int}, - f1f2_const::ct.Constant{Int}, - f0f2_const::ct.Constant{Int}, - bs_const::ct.Constant{Int}, - d_const::ct.Constant{Int}, - n2d_const::ct.Constant{Int} -) - N = n_const[] - F0 = f0_const[] - F1 = f1_const[] - F2 = f2_const[] - F0F1 = f0f1_const[] - F1F2 = f1f2_const[] - F0F2 = f0f2_const[] - BS = bs_const[] - D = d_const[] - N2D = n2d_const[] - - bid = ct.bid(1) - - # Load input (D, BS, N2D) and reshape to (2, BS, N) - X_ri = ct.reshape(ct.load(x_packed_in, (1, bid, 1), (D, BS, N2D)), (2, BS, N)) - X_r = ct.reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0)) - X_i = ct.reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0)) - - # Load DFT matrices - W0_ri = ct.reshape(ct.load(W0, (1, 1, 1), (F0, F0, 2)), (F0, F0, 2)) - W0_r = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0)) - W0_i = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0)) - - W1_ri = ct.reshape(ct.load(W1, (1, 1, 1), (F1, F1, 2)), (F1, F1, 2)) - W1_r = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1)) - W1_i = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1)) - - W2_ri = ct.reshape(ct.load(W2, (1, 1, 1), (F2, F2, 2)), (F2, F2, 2)) - W2_r = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2)) - W2_i = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2)) - - # Load twiddle factors (column-major layout) - T0_ri = ct.reshape(ct.load(T0, (1, 1, 1), (F1F2, F0, 2)), (F1F2, F0, 2)) - T0_r = ct.reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N)) - T0_i = ct.reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N)) - - T1_ri = ct.reshape(ct.load(T1, (1, 1, 1), (F0F2, F1, 2)), (F0F2, F1, 2)) - T1_r = ct.reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1)) - T1_i = ct.reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1)) - - # Stage 0: F0-point DFT via right-multiply - X_r_ = X_r * W0_r - X_i * W0_i - X_i_ = X_r * W0_i + X_i * W0_r - - # Twiddle & Permute 0 - X_r_flat = ct.reshape(X_r_, (BS, N)) - X_i_flat = ct.reshape(X_i_, (BS, N)) - X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat - X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat - - X_r3 = ct.reshape(X_r2, (BS, F2, F1, F0)) - X_i3 = ct.reshape(X_i2, (BS, F2, F1, F0)) - X_r4 = ct.permute(X_r3, (1, 2, 4, 3)) - X_i4 = ct.permute(X_i3, (1, 2, 4, 3)) - X_r5 = ct.reshape(X_r4, (BS, F0F2, F1)) - X_i5 = ct.reshape(X_i4, (BS, F0F2, F1)) - - # Stage 1: F1-point DFT - X_r6 = X_r5 * W1_r - X_i5 * W1_i - X_i6 = X_r5 * W1_i + X_i5 * W1_r - - # Twiddle & Permute 1 - X_r_flat2 = ct.reshape(X_r6, (BS, N)) - X_i_flat2 = ct.reshape(X_i6, (BS, N)) - X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2 - X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2 - - X_r8 = ct.reshape(X_r7, (BS, F2, F0, F1)) - X_i8 = ct.reshape(X_i7, (BS, F2, F0, F1)) - X_r9 = ct.permute(X_r8, (1, 3, 4, 2)) - X_i9 = ct.permute(X_i8, (1, 3, 4, 2)) - X_r10 = ct.reshape(X_r9, (BS, F0F1, F2)) - X_i10 = 
ct.reshape(X_i9, (BS, F0F1, F2)) - - # Stage 2: F2-point DFT - X_r11 = X_r10 * W2_r - X_i10 * W2_i - X_i11 = X_r10 * W2_i + X_i10 * W2_r - - # Final output - X_r_final = ct.reshape(X_r11, (1, BS, N)) - X_i_final = ct.reshape(X_i11, (1, BS, N)) - - # Concatenate and Store - Y_ri = ct.reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D)) - ct.store(y_packed_out, (1, bid, 1), Y_ri) - - return -end - -# Helper: Generate DFT matrix -function fft_dft_matrix(size::Int) - W = zeros(ComplexF32, size, size) - for i in 0:size-1, j in 0:size-1 - W[i+1, j+1] = exp(-2π * im * i * j / size) - end - result = zeros(Float32, size, size, 2) - result[:, :, 1] = Float32.(real.(W)) - result[:, :, 2] = Float32.(imag.(W)) - return result -end - -# Twiddle factors T0 for column-major layout (F1F2, F0) -function fft_make_twiddles_T0(F0::Int, F1F2::Int, N::Int) - T0 = zeros(Float32, F1F2, F0, 2) - for j in 0:F1F2-1, i in 0:F0-1 - val = exp(-2π * im * i * j / N) - T0[j+1, i+1, 1] = Float32(real(val)) - T0[j+1, i+1, 2] = Float32(imag(val)) - end - return T0 -end - -# Twiddle factors T1 for column-major layout (F0F2, F1) -function fft_make_twiddles_T1(F0::Int, F1::Int, F2::Int) - F0F2 = F0 * F2 - F1F2 = F1 * F2 - T1 = zeros(Float32, F0F2, F1, 2) - for k in 0:F0F2-1, j in 0:F1-1 - f2 = k % F2 - val = exp(-2π * im * j * f2 / F1F2) - T1[k+1, j+1, 1] = Float32(real(val)) - T1[k+1, j+1, 2] = Float32(imag(val)) - end - return T1 -end - -function fft_make_twiddles(factors::NTuple{3, Int}) - F0, F1, F2 = factors - N = F0 * F1 * F2 - F1F2 = F1 * F2 - W0 = fft_dft_matrix(F0) - W1 = fft_dft_matrix(F1) - W2 = fft_dft_matrix(F2) - T0 = fft_make_twiddles_T0(F0, F1F2, N) - T1 = fft_make_twiddles_T1(F0, F1, F2) - return (W0, W1, W2, T0, T1) -end +# cuTile kernel: use fft_kernel and make_twiddles from fft.jl function benchmark_fft() println("\nBenchmarking FFT...") @@ -792,8 +497,8 @@ function benchmark_fft() results = BenchmarkResult[] - # Pre-compute twiddles (one-time CPU cost) - W0, W1, W2, T0, T1 = fft_make_twiddles(FFT_FACTORS) + # Pre-compute twiddles (one-time CPU cost, uses make_twiddles from fft.jl) + W0, W1, W2, T0, T1 = make_twiddles(FFT_FACTORS) W0_gpu, W1_gpu, W2_gpu = CuArray(W0), CuArray(W1), CuArray(W2) T0_gpu, T1_gpu = CuArray(T0), CuArray(T1) diff --git a/examples/benchmarks.py b/examples/benchmarks.py index fa8ac43..482fd18 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -12,6 +12,14 @@ import math from math import ceil, log2 +# Import kernels from example files +from vadd import vadd_cutile_kernel +from transpose import transpose_cutile_kernel +from matmul import matmul_cutile_kernel, swizzle_2d +from layernorm import layernorm_cutile_kernel +from batchmatmul import batchmatmul_cutile_kernel +from fft import fft_kernel, fft_make_twiddles + #============================================================================= # Configuration #============================================================================= @@ -123,14 +131,7 @@ def print_table(title: str, results: list, extra_col=None): # Vector Addition #============================================================================= -@ct.kernel -def vadd_cutile_kernel(a, b, c, tile_size: ct.Constant[int]): - pid = ct.bid(0) - tile_a = ct.load(a, index=(pid,), shape=(tile_size,)) - tile_b = ct.load(b, index=(pid,), shape=(tile_size,)) - result = tile_a + tile_b - ct.store(c, index=(pid,), tile=result) - +# cuTile kernel: use vadd_cutile_kernel from vadd.py def benchmark_vadd(): print("\nBenchmarking Vector Addition...") @@ -195,14 +196,7 @@ 
def cutile_vadd(): # Matrix Transpose #============================================================================= -@ct.kernel -def transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct.Constant[int]): - pid_m = ct.bid(0) - pid_n = ct.bid(1) - tile = ct.load(input, index=(pid_m, pid_n), shape=(tile_m, tile_n)) - tile_t = ct.transpose(tile) - ct.store(output, index=(pid_n, pid_m), tile=tile_t) - +# cuTile kernel: use transpose_cutile_kernel from transpose.py def benchmark_transpose(): print("\nBenchmarking Matrix Transpose...") @@ -268,41 +262,7 @@ def cutile_transpose(): # Matrix Multiplication #============================================================================= -def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M): - """Get the global IDs of the current CUDA block in a 1D grid.""" - bid = ct.bid(0) - num_bid_m = ct.cdiv(M, tm) - num_bid_n = ct.cdiv(N, tn) - num_bid_in_group = GROUP_SIZE_M * num_bid_n - group_id = bid // num_bid_in_group - first_bid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M) - bid_m = first_bid_m + (bid % group_size_m) - bid_n = (bid % num_bid_in_group) // group_size_m - return bid_m, bid_n - - -@ct.kernel -def matmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): - GROUP_SIZE_M = 8 - M = A.shape[0] - N = B.shape[1] - bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) - - num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) - accumulator = ct.full((tm, tn), 0, dtype=ct.float32) - - # Convert fp32 to tf32 for tensor cores - dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype - - for k in range(num_tiles_k): - a = ct.load(A, index=(bidx, k), shape=(tm, tk)).astype(dtype) - b = ct.load(B, index=(k, bidy), shape=(tk, tn)).astype(dtype) - accumulator = ct.mma(a, b, accumulator) - - accumulator = ct.astype(accumulator, C.dtype) - ct.store(C, index=(bidx, bidy), tile=accumulator) - +# cuTile kernel: use matmul_cutile_kernel and swizzle_2d from matmul.py def benchmark_matmul(): print("\nBenchmarking Matrix Multiplication...") @@ -392,41 +352,7 @@ def cutile_matmul(): BATCHMATMUL_TN = 256 BATCHMATMUL_TK = 64 - -@ct.kernel -def layernorm_cutile_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TILE_N: ct.Constant[int]): - bid_m = ct.bid(0) - num_tiles = ct.num_tiles(X, axis=1, shape=(1, TILE_N)) - N = X.shape[1] - - # Compute mean - mean = ct.full((1, TILE_N), 0, dtype=ct.float32) - for j in range(num_tiles): - tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) - mean += tx - mean = ct.sum(mean, axis=1) / N - ct.store(Mean, index=(bid_m,), tile=mean) - - # Compute variance - var = ct.full((1, TILE_N), 0, dtype=ct.float32) - for j in range(num_tiles): - tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) - mask = (j * TILE_N + ct.arange(TILE_N, dtype=ct.int32)) < N - centered_tx = ct.where(mask, tx - mean, 0) - var += centered_tx ** 2 - var = ct.sum(var, axis=1) / N - rstd = 1 / ct.sqrt(var + eps) - ct.store(Rstd, index=(bid_m,), tile=rstd) - - # Normalize and apply affine transformation - for j in range(num_tiles): - tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) - tw = ct.load(W, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) - tb = ct.load(B, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) - ty = (tx - mean) * rstd - ty = ty * tw + tb - ct.store(Y, index=(bid_m, j), tile=ty.astype(Y.dtype)) - +# cuTile kernel: use 
layernorm_cutile_kernel from layernorm.py def benchmark_layernorm(): print("\nBenchmarking Layer Normalization...") @@ -498,33 +424,7 @@ def cutile_layernorm(): # Batch Matrix Multiplication #============================================================================= -@ct.kernel -def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): - """CuTile kernel for batch matrix multiplication - A has shape (Batch, M, K), B has shape (Batch, K, N) and C has shape (Batch, M, N) - Grid: (Batch, M_tiles, N_tiles) - """ - pid_batch = ct.bid(0) - bidx = ct.bid(1) - bidy = ct.bid(2) - - num_k_tiles = ct.cdiv(A.shape[2], tk) - accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32) - zero_pad = ct.PaddingMode.ZERO - - for k in range(num_k_tiles): - a = ct.load(A, index=(pid_batch, bidx, k), shape=(1, tm, tk), padding_mode=zero_pad) - a = ct.reshape(a, (tm, tk)) - - b = ct.load(B, index=(pid_batch, k, bidy), shape=(1, tk, tn), padding_mode=zero_pad) - b = ct.reshape(b, (tk, tn)) - - accumulator = ct.mma(a, b, acc=accumulator) - - result = ct.astype(accumulator, C.dtype) - result_3d = ct.reshape(result, (1, tm, tn)) - ct.store(C, index=(pid_batch, bidx, bidy), tile=result_3d) - +# cuTile kernel: use batchmatmul_cutile_kernel from batchmatmul.py def benchmark_batchmatmul(): print("\nBenchmarking Batch Matrix Multiplication...") @@ -583,106 +483,7 @@ def cutile_bmm(): # FFT (3-stage Cooley-Tukey) #============================================================================= -@ct.kernel -def fft_kernel(x_packed_in, y_packed_out, - W0, W1, W2, T0, T1, - N: ct.Constant[int], F0: ct.Constant[int], F1: ct.Constant[int], F2: ct.Constant[int], - BS: ct.Constant[int], D: ct.Constant[int]): - """cuTile kernel for 3-stage Cooley-Tukey FFT.""" - F0F1 = F0 * F1 - F1F2 = F1 * F2 - F0F2 = F0 * F2 - - bid = ct.bid(0) - - # Load input, reshape to separate real/imag - X_ri = ct.reshape(ct.load(x_packed_in, index=(bid, 0, 0), shape=(BS, N * 2 // D, D)), (BS, N, 2)) - X_r = ct.reshape(ct.extract(X_ri, index=(0, 0, 0), shape=(BS, N, 1)), (BS, F0, F1, F2)) - X_i = ct.reshape(ct.extract(X_ri, index=(0, 0, 1), shape=(BS, N, 1)), (BS, F0, F1, F2)) - - # Load W matrices (rotation matrices) - W0_ri = ct.reshape(ct.load(W0, index=(0, 0, 0), shape=(F0, F0, 2)), (F0, F0, 2)) - W0_r = ct.reshape(ct.extract(W0_ri, index=(0, 0, 0), shape=(F0, F0, 1)), (1, F0, F0)) - W0_i = ct.reshape(ct.extract(W0_ri, index=(0, 0, 1), shape=(F0, F0, 1)), (1, F0, F0)) - - W1_ri = ct.reshape(ct.load(W1, index=(0, 0, 0), shape=(F1, F1, 2)), (F1, F1, 2)) - W1_r = ct.reshape(ct.extract(W1_ri, index=(0, 0, 0), shape=(F1, F1, 1)), (1, F1, F1)) - W1_i = ct.reshape(ct.extract(W1_ri, index=(0, 0, 1), shape=(F1, F1, 1)), (1, F1, F1)) - - W2_ri = ct.reshape(ct.load(W2, index=(0, 0, 0), shape=(F2, F2, 2)), (F2, F2, 2)) - W2_r = ct.reshape(ct.extract(W2_ri, index=(0, 0, 0), shape=(F2, F2, 1)), (1, F2, F2)) - W2_i = ct.reshape(ct.extract(W2_ri, index=(0, 0, 1), shape=(F2, F2, 1)), (1, F2, F2)) - - # Load T matrices (twiddle factors) - T0_ri = ct.reshape(ct.load(T0, index=(0, 0, 0), shape=(F0, F1F2, 2)), (F0, F1F2, 2)) - T0_r = ct.reshape(ct.extract(T0_ri, index=(0, 0, 0), shape=(F0, F1F2, 1)), (N, 1)) - T0_i = ct.reshape(ct.extract(T0_ri, index=(0, 0, 1), shape=(F0, F1F2, 1)), (N, 1)) - - T1_ri = ct.reshape(ct.load(T1, index=(0, 0, 0), shape=(F1, F2, 2)), (F1, F2, 2)) - T1_r = ct.reshape(ct.extract(T1_ri, index=(0, 0, 0), shape=(F1, F2, 1)), (F1F2, 1)) - T1_i = ct.reshape(ct.extract(T1_ri, index=(0, 0, 1), shape=(F1, F2, 
1)), (F1F2, 1)) - - # CT0: Contract over F0 dimension - X_r = ct.reshape(X_r, (BS, F0, F1F2)) - X_i = ct.reshape(X_i, (BS, F0, F1F2)) - X_r_ = ct.reshape(ct.matmul(W0_r, X_r) - ct.matmul(W0_i, X_i), (BS, N, 1)) - X_i_ = ct.reshape(ct.matmul(W0_i, X_r) + ct.matmul(W0_r, X_i), (BS, N, 1)) - - # Twiddle & Permute 0 - X_r = T0_r * X_r_ - T0_i * X_i_ - X_i = T0_i * X_r_ + T0_r * X_i_ - X_r = ct.permute(ct.reshape(X_r, (BS, F0, F1, F2)), (0, 2, 3, 1)) - X_i = ct.permute(ct.reshape(X_i, (BS, F0, F1, F2)), (0, 2, 3, 1)) - - # CT1: Contract over F1 dimension - X_r = ct.reshape(X_r, (BS, F1, F0F2)) - X_i = ct.reshape(X_i, (BS, F1, F0F2)) - X_r_ = ct.reshape(ct.matmul(W1_r, X_r) - ct.matmul(W1_i, X_i), (BS, F1F2, F0)) - X_i_ = ct.reshape(ct.matmul(W1_i, X_r) + ct.matmul(W1_r, X_i), (BS, F1F2, F0)) - - # Twiddle & Permute 1 - X_r = T1_r * X_r_ - T1_i * X_i_ - X_i = T1_i * X_r_ + T1_r * X_i_ - X_r = ct.permute(ct.reshape(X_r, (BS, F1, F2, F0)), (0, 2, 3, 1)) - X_i = ct.permute(ct.reshape(X_i, (BS, F1, F2, F0)), (0, 2, 3, 1)) - - # CT2: Contract over F2 dimension - X_r = ct.reshape(X_r, (BS, F2, F0F1)) - X_i = ct.reshape(X_i, (BS, F2, F0F1)) - X_r_ = ct.matmul(W2_r, X_r) - ct.matmul(W2_i, X_i) - X_i_ = ct.matmul(W2_i, X_r) + ct.matmul(W2_r, X_i) - - # Final Permutation - X_r = ct.permute(ct.reshape(X_r_, (BS, F2, F0, F1)), (0, 1, 3, 2)) - X_i = ct.permute(ct.reshape(X_i_, (BS, F2, F0, F1)), (0, 1, 3, 2)) - X_r = ct.reshape(X_r, (BS, N, 1)) - X_i = ct.reshape(X_i, (BS, N, 1)) - - # Concatenate and Store - Y_ri = ct.reshape(ct.cat((X_r, X_i), axis=-1), (BS, N * 2 // D, D)) - ct.store(y_packed_out, index=(bid, 0, 0), tile=Y_ri) - - -def fft_twiddles(rows: int, cols: int, factor: int, device, precision): - """Generate DFT twiddle factors.""" - I, J = torch.meshgrid(torch.arange(rows, device=device), - torch.arange(cols, device=device), indexing='ij') - W_complex = torch.exp(-2 * math.pi * 1j * (I * J) / factor) - return torch.view_as_real(W_complex).to(precision).contiguous() - - -def fft_make_twiddles(factors, precision, device): - """Generate W and T matrices for FFT.""" - F0, F1, F2 = factors - N = F0 * F1 * F2 - F1F2 = F1 * F2 - W0 = fft_twiddles(F0, F0, F0, device, precision) - W1 = fft_twiddles(F1, F1, F1, device, precision) - W2 = fft_twiddles(F2, F2, F2, device, precision) - T0 = fft_twiddles(F0, F1F2, N, device, precision) - T1 = fft_twiddles(F1, F2, F1F2, device, precision) - return (W0, W1, W2, T0, T1) - +# cuTile kernel: use fft_kernel and fft_make_twiddles from fft.py def benchmark_fft(): print("\nBenchmarking FFT...") diff --git a/examples/fft.jl b/examples/fft.jl index 2e5495a..e7712d5 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -304,5 +304,4 @@ function main() println("\n--- cuTile FFT example execution complete ---") end -# Run validation -main() +isinteractive() || main() diff --git a/examples/fft.py b/examples/fft.py new file mode 100644 index 0000000..7977f7b --- /dev/null +++ b/examples/fft.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +""" +FFT (3-stage Cooley-Tukey) example - cuTile Python +""" + +import torch +import math +import cuda.tile as ct + +@ct.kernel +def fft_kernel(x_packed_in, y_packed_out, + W0, W1, W2, T0, T1, + N: ct.Constant[int], F0: ct.Constant[int], F1: ct.Constant[int], F2: ct.Constant[int], + BS: ct.Constant[int], D: ct.Constant[int]): + """cuTile kernel for 3-stage Cooley-Tukey FFT.""" + F0F1 = F0 * F1 + F1F2 = F1 * F2 + F0F2 = F0 * F2 + + bid = ct.bid(0) + + # Load input, reshape to separate real/imag + X_ri = ct.reshape(ct.load(x_packed_in, index=(bid, 
0, 0), shape=(BS, N * 2 // D, D)), (BS, N, 2)) + X_r = ct.reshape(ct.extract(X_ri, index=(0, 0, 0), shape=(BS, N, 1)), (BS, F0, F1, F2)) + X_i = ct.reshape(ct.extract(X_ri, index=(0, 0, 1), shape=(BS, N, 1)), (BS, F0, F1, F2)) + + # Load W matrices (rotation matrices) + W0_ri = ct.reshape(ct.load(W0, index=(0, 0, 0), shape=(F0, F0, 2)), (F0, F0, 2)) + W0_r = ct.reshape(ct.extract(W0_ri, index=(0, 0, 0), shape=(F0, F0, 1)), (1, F0, F0)) + W0_i = ct.reshape(ct.extract(W0_ri, index=(0, 0, 1), shape=(F0, F0, 1)), (1, F0, F0)) + + W1_ri = ct.reshape(ct.load(W1, index=(0, 0, 0), shape=(F1, F1, 2)), (F1, F1, 2)) + W1_r = ct.reshape(ct.extract(W1_ri, index=(0, 0, 0), shape=(F1, F1, 1)), (1, F1, F1)) + W1_i = ct.reshape(ct.extract(W1_ri, index=(0, 0, 1), shape=(F1, F1, 1)), (1, F1, F1)) + + W2_ri = ct.reshape(ct.load(W2, index=(0, 0, 0), shape=(F2, F2, 2)), (F2, F2, 2)) + W2_r = ct.reshape(ct.extract(W2_ri, index=(0, 0, 0), shape=(F2, F2, 1)), (1, F2, F2)) + W2_i = ct.reshape(ct.extract(W2_ri, index=(0, 0, 1), shape=(F2, F2, 1)), (1, F2, F2)) + + # Load T matrices (twiddle factors) + T0_ri = ct.reshape(ct.load(T0, index=(0, 0, 0), shape=(F0, F1F2, 2)), (F0, F1F2, 2)) + T0_r = ct.reshape(ct.extract(T0_ri, index=(0, 0, 0), shape=(F0, F1F2, 1)), (N, 1)) + T0_i = ct.reshape(ct.extract(T0_ri, index=(0, 0, 1), shape=(F0, F1F2, 1)), (N, 1)) + + T1_ri = ct.reshape(ct.load(T1, index=(0, 0, 0), shape=(F1, F2, 2)), (F1, F2, 2)) + T1_r = ct.reshape(ct.extract(T1_ri, index=(0, 0, 0), shape=(F1, F2, 1)), (F1F2, 1)) + T1_i = ct.reshape(ct.extract(T1_ri, index=(0, 0, 1), shape=(F1, F2, 1)), (F1F2, 1)) + + # CT0: Contract over F0 dimension + X_r = ct.reshape(X_r, (BS, F0, F1F2)) + X_i = ct.reshape(X_i, (BS, F0, F1F2)) + X_r_ = ct.reshape(ct.matmul(W0_r, X_r) - ct.matmul(W0_i, X_i), (BS, N, 1)) + X_i_ = ct.reshape(ct.matmul(W0_i, X_r) + ct.matmul(W0_r, X_i), (BS, N, 1)) + + # Twiddle & Permute 0 + X_r = T0_r * X_r_ - T0_i * X_i_ + X_i = T0_i * X_r_ + T0_r * X_i_ + X_r = ct.permute(ct.reshape(X_r, (BS, F0, F1, F2)), (0, 2, 3, 1)) + X_i = ct.permute(ct.reshape(X_i, (BS, F0, F1, F2)), (0, 2, 3, 1)) + + # CT1: Contract over F1 dimension + X_r = ct.reshape(X_r, (BS, F1, F0F2)) + X_i = ct.reshape(X_i, (BS, F1, F0F2)) + X_r_ = ct.reshape(ct.matmul(W1_r, X_r) - ct.matmul(W1_i, X_i), (BS, F1F2, F0)) + X_i_ = ct.reshape(ct.matmul(W1_i, X_r) + ct.matmul(W1_r, X_i), (BS, F1F2, F0)) + + # Twiddle & Permute 1 + X_r = T1_r * X_r_ - T1_i * X_i_ + X_i = T1_i * X_r_ + T1_r * X_i_ + X_r = ct.permute(ct.reshape(X_r, (BS, F1, F2, F0)), (0, 2, 3, 1)) + X_i = ct.permute(ct.reshape(X_i, (BS, F1, F2, F0)), (0, 2, 3, 1)) + + # CT2: Contract over F2 dimension + X_r = ct.reshape(X_r, (BS, F2, F0F1)) + X_i = ct.reshape(X_i, (BS, F2, F0F1)) + X_r_ = ct.matmul(W2_r, X_r) - ct.matmul(W2_i, X_i) + X_i_ = ct.matmul(W2_i, X_r) + ct.matmul(W2_r, X_i) + + # Final Permutation + X_r = ct.permute(ct.reshape(X_r_, (BS, F2, F0, F1)), (0, 1, 3, 2)) + X_i = ct.permute(ct.reshape(X_i_, (BS, F2, F0, F1)), (0, 1, 3, 2)) + X_r = ct.reshape(X_r, (BS, N, 1)) + X_i = ct.reshape(X_i, (BS, N, 1)) + + # Concatenate and Store + Y_ri = ct.reshape(ct.cat((X_r, X_i), axis=-1), (BS, N * 2 // D, D)) + ct.store(y_packed_out, index=(bid, 0, 0), tile=Y_ri) + + +def fft_twiddles(rows: int, cols: int, factor: int, device, precision): + """Generate DFT twiddle factors.""" + I, J = torch.meshgrid(torch.arange(rows, device=device), + torch.arange(cols, device=device), indexing='ij') + W_complex = torch.exp(-2 * math.pi * 1j * (I * J) / factor) + return 
torch.view_as_real(W_complex).to(precision).contiguous() + + +def fft_make_twiddles(factors, precision, device): + """Generate W and T matrices for FFT.""" + F0, F1, F2 = factors + N = F0 * F1 * F2 + F1F2 = F1 * F2 + W0 = fft_twiddles(F0, F0, F0, device, precision) + W1 = fft_twiddles(F1, F1, F1, device, precision) + W2 = fft_twiddles(F2, F2, F2, device, precision) + T0 = fft_twiddles(F0, F1F2, N, device, precision) + T1 = fft_twiddles(F1, F2, F1F2, device, precision) + return (W0, W1, W2, T0, T1) + + +def run_fft(*, batch: int = 64, size: int = 512, factors: tuple = (8, 8, 8), + atom_packing_dim: int = 2, input=None, validate: bool = False): + """Run FFT. Returns (input, output) tensors for benchmarking.""" + F0, F1, F2 = factors + N = F0 * F1 * F2 + assert size == N, f"size ({size}) must equal product of factors ({N})" + D = atom_packing_dim + + if input is None: + input = torch.randn(batch, N, dtype=torch.complex64, device='cuda') + else: + batch = input.shape[0] + N = input.shape[1] + + # Pre-compute twiddles + W0, W1, W2, T0, T1 = fft_make_twiddles(factors, input.real.dtype, input.device) + + # Pack input + x_ri = torch.view_as_real(input) + x_packed = x_ri.reshape(batch, N * 2 // D, D).contiguous() + y_packed = torch.empty_like(x_packed) + + grid = (batch, 1, 1) + ct.launch(torch.cuda.current_stream(), grid, fft_kernel, + (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + + output = torch.view_as_complex(y_packed.reshape(batch, N, 2)) + + if validate: + torch.cuda.synchronize() + reference = torch.fft.fft(input, dim=-1) + assert torch.allclose(output, reference, rtol=1e-3, atol=1e-3), \ + f"FFT incorrect! max diff: {torch.max(torch.abs(output - reference))}" + + return input, output + + +def test_fft(batch, size, factors, name=None): + """Test FFT with given parameters.""" + name = name or f"fft batch={batch}, size={size}, factors={factors}" + print(f"--- {name} ---") + run_fft(batch=batch, size=size, factors=factors, validate=True) + print(" passed") + + +def main(): + print("--- cuTile FFT Examples ---\n") + + test_fft(64, 512, (8, 8, 8)) + test_fft(32, 512, (8, 8, 8)) + + print("\n--- All FFT examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/layernorm.jl b/examples/layernorm.jl index 7e2e11a..be5e1fa 100644 --- a/examples/layernorm.jl +++ b/examples/layernorm.jl @@ -272,6 +272,47 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa return end +#============================================================================= + Run Functions (for benchmarking or programmatic use) +=============================================================================# + +# Run layernorm forward pass +function run_layernorm_fwd(; M::Int, N::Int, TILE_N::Int, eps::Float32=1f-5, + X::Union{CuArray,Nothing}=nothing, + W::Union{CuArray,Nothing}=nothing, + B::Union{CuArray,Nothing}=nothing, + Y::Union{CuArray,Nothing}=nothing, + Mean::Union{CuArray,Nothing}=nothing, + Rstd::Union{CuArray,Nothing}=nothing, + validate::Bool=false) + X = something(X, -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N)) + W = something(W, CUDA.randn(Float32, N)) + B = something(B, CUDA.randn(Float32, N)) + Y = something(Y, CUDA.zeros(Float32, M, N)) + Mean = something(Mean, CUDA.zeros(Float32, M)) + Rstd = something(Rstd, CUDA.zeros(Float32, M)) + + ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, + ct.Constant(eps), ct.Constant(TILE_N)) + + if validate + X_cpu = Array(X) + W_cpu = Array(W) + B_cpu = Array(B) + expected_mean = 
vec(sum(X_cpu, dims=2) ./ N) + expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) + expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) + normalized = (X_cpu .- expected_mean) .* expected_rstd + expected_Y = normalized .* W_cpu' .+ B_cpu' + + atol, rtol = 1f-2, 1f-2 + @assert isapprox(expected_mean, Array(Mean); rtol, atol) "Mean mismatch" + @assert isapprox(expected_rstd, Array(Rstd); rtol, atol) "Rstd mismatch" + @assert isapprox(expected_Y, Array(Y); rtol, atol) "Y mismatch" + end + return (; X, W, B, Y, Mean, Rstd) +end + #============================================================================= Test / Validation =============================================================================# diff --git a/examples/layernorm.py b/examples/layernorm.py new file mode 100644 index 0000000..2007eb9 --- /dev/null +++ b/examples/layernorm.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Layer Normalization example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct + +@ct.kernel +def layernorm_cutile_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TILE_N: ct.Constant[int]): + bid_m = ct.bid(0) + num_tiles = ct.num_tiles(X, axis=1, shape=(1, TILE_N)) + N = X.shape[1] + + # Compute mean + mean = ct.full((1, TILE_N), 0, dtype=ct.float32) + for j in range(num_tiles): + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + mean += tx + mean = ct.sum(mean, axis=1) / N + ct.store(Mean, index=(bid_m,), tile=mean) + + # Compute variance + var = ct.full((1, TILE_N), 0, dtype=ct.float32) + for j in range(num_tiles): + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + mask = (j * TILE_N + ct.arange(TILE_N, dtype=ct.int32)) < N + centered_tx = ct.where(mask, tx - mean, 0) + var += centered_tx ** 2 + var = ct.sum(var, axis=1) / N + rstd = 1 / ct.sqrt(var + eps) + ct.store(Rstd, index=(bid_m,), tile=rstd) + + # Normalize and apply affine transformation + for j in range(num_tiles): + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + tw = ct.load(W, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) + tb = ct.load(B, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) + ty = (tx - mean) * rstd + ty = ty * tw + tb + ct.store(Y, index=(bid_m, j), tile=ty.astype(Y.dtype)) + + +def run_layernorm(*, M: int = 1024, N: int = 1024, tile_n: int = 1024, eps: float = 1e-5, + X=None, W=None, B=None, Y=None, Mean=None, Rstd=None, + dtype=np.float32, validate: bool = False): + """Run layer normalization. 
Returns (X, W, B, Y, Mean, Rstd) arrays for benchmarking.""" + if X is None: + X = (-2.3 + 0.5 * cp.random.randn(M, N)).astype(dtype) + else: + M, N = X.shape + if W is None: + W = cp.random.randn(N).astype(dtype) + if B is None: + B = cp.random.randn(N).astype(dtype) + if Y is None: + Y = cp.zeros((M, N), dtype=dtype) + if Mean is None: + Mean = cp.zeros(M, dtype=np.float32) + if Rstd is None: + Rstd = cp.zeros(M, dtype=np.float32) + + stream = cp.cuda.get_current_stream() + ct.launch(stream, (M,), layernorm_cutile_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) + + if validate: + cp.cuda.runtime.deviceSynchronize() + X_np = cp.asnumpy(X) + W_np = cp.asnumpy(W) + B_np = cp.asnumpy(B) + expected_mean = np.mean(X_np, axis=1, keepdims=True) + expected_var = np.mean((X_np - expected_mean) ** 2, axis=1, keepdims=True) + expected_rstd = 1.0 / np.sqrt(expected_var + eps) + normalized = (X_np - expected_mean) * expected_rstd + expected_Y = normalized * W_np + B_np + assert np.allclose(cp.asnumpy(Y), expected_Y, rtol=1e-2, atol=1e-2), \ + f"layernorm incorrect! max diff: {np.max(np.abs(cp.asnumpy(Y) - expected_Y))}" + + return X, W, B, Y, Mean, Rstd + + +def test_layernorm(M, N, tile_n, eps=1e-5, dtype=np.float32, name=None): + """Test layer normalization with given parameters.""" + name = name or f"layernorm ({M}x{N}), tile={tile_n}, dtype={dtype.__name__}" + print(f"--- {name} ---") + run_layernorm(M=M, N=N, tile_n=tile_n, eps=eps, dtype=dtype, validate=True) + print(" passed") + + +def main(): + print("--- cuTile Layer Normalization Examples ---\n") + + test_layernorm(256, 256, 256) + test_layernorm(512, 512, 512) + test_layernorm(1024, 1024, 1024) + + print("\n--- All layernorm examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/matmul.jl b/examples/matmul.jl index 36d41ea..247c929 100644 --- a/examples/matmul.jl +++ b/examples/matmul.jl @@ -62,33 +62,33 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr return nothing end -function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T - name = something(name, "matmul ($M x $K) @ ($K x $N), $T, tiles=$tm x $tn x $tk") - println("--- $name ---") - - A = CUDA.rand(T, M, K) - B = CUDA.rand(T, K, N) - C = CUDA.zeros(T, M, N) - - # Use 1D grid - swizzle_2d converts to 2D indices +# Run matmul - for benchmarking or programmatic use +function run_matmul(; M::Int, N::Int, K::Int, tm::Int, tn::Int, tk::Int, T::DataType=Float32, + A::Union{CuArray,Nothing}=nothing, + B::Union{CuArray,Nothing}=nothing, + C::Union{CuArray,Nothing}=nothing, + validate::Bool=false) + A = something(A, CUDA.rand(T, M, K)) + B = something(B, CUDA.rand(T, K, N)) + C = something(C, CUDA.zeros(T, M, N)) grid_m = cld(M, tm) grid_n = cld(N, tn) grid = grid_m * grid_n - - # Launch kernel ct.launch(matmul_kernel, grid, A, B, C, ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) - - # Verify result - expected = Array(A) * Array(B) - result = Array(C) - - if isapprox(result, expected, rtol=1e-2, atol=1e-2) - println(" passed") - else - max_diff = maximum(abs.(result - expected)) - println(" FAILED (max diff: $max_diff)") + if validate + expected = Array(A) * Array(B) + result = Array(C) + @assert isapprox(result, expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(result - expected)))" end + return (; A, B, C) +end + +function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T + name = something(name, "matmul ($M x $K) @ ($K x $N), $T, tiles=$tm x $tn x $tk") + println("--- $name ---") 
+ run_matmul(; M, N, K, tm, tn, tk, T, validate=true) + println(" passed") end function main() diff --git a/examples/matmul.py b/examples/matmul.py new file mode 100644 index 0000000..9eab93c --- /dev/null +++ b/examples/matmul.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" +Matrix Multiplication example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct +from math import ceil + +def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M): + """Get the global IDs of the current CUDA block in a 1D grid.""" + bid = ct.bid(0) + num_bid_m = ct.cdiv(M, tm) + num_bid_n = ct.cdiv(N, tn) + num_bid_in_group = GROUP_SIZE_M * num_bid_n + group_id = bid // num_bid_in_group + first_bid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M) + bid_m = first_bid_m + (bid % group_size_m) + bid_n = (bid % num_bid_in_group) // group_size_m + return bid_m, bid_n + + +@ct.kernel +def matmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]): + GROUP_SIZE_M = 8 + M = A.shape[0] + N = B.shape[1] + bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) + + num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) + accumulator = ct.full((tm, tn), 0, dtype=ct.float32) + + # Convert fp32 to tf32 for tensor cores + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + + for k in range(num_tiles_k): + a = ct.load(A, index=(bidx, k), shape=(tm, tk)).astype(dtype) + b = ct.load(B, index=(k, bidy), shape=(tk, tn)).astype(dtype) + accumulator = ct.mma(a, b, accumulator) + + accumulator = ct.astype(accumulator, C.dtype) + ct.store(C, index=(bidx, bidy), tile=accumulator) + + +def run_matmul(*, M: int = 1024, N: int = 1024, K: int = 1024, + tm: int = 64, tn: int = 64, tk: int = 64, + dtype=np.float32, A=None, B=None, C=None, validate: bool = False): + """Run matrix multiplication. Returns (A, B, C) arrays for benchmarking.""" + if A is None: + A = cp.random.randn(M, K).astype(dtype) + else: + M, K = A.shape + if B is None: + B = cp.random.randn(K, N).astype(dtype) + else: + K, N = B.shape + if C is None: + C = cp.zeros((M, N), dtype=dtype) + + grid_m = ceil(M / tm) + grid_n = ceil(N / tn) + grid = (grid_m * grid_n, 1, 1) + stream = cp.cuda.get_current_stream() + ct.launch(stream, grid, matmul_cutile_kernel, (A, B, C, tm, tn, tk)) + + if validate: + cp.cuda.runtime.deviceSynchronize() + expected = cp.asnumpy(A) @ cp.asnumpy(B) + # TF32 has reduced precision + assert np.allclose(cp.asnumpy(C), expected, rtol=1e-1, atol=1e-1), \ + f"matmul incorrect! 
max diff: {np.max(np.abs(cp.asnumpy(C) - expected))}" + + return A, B, C + + +def test_matmul(M, N, K, tm, tn, tk, dtype=np.float32, name=None): + """Test matmul with given parameters.""" + name = name or f"matmul ({M}x{K}) @ ({K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" + print(f"--- {name} ---") + run_matmul(M=M, N=N, K=K, tm=tm, tn=tn, tk=tk, dtype=dtype, validate=True) + print(" passed") + + +def main(): + print("--- cuTile Matrix Multiplication Examples ---\n") + + test_matmul(256, 256, 256, 32, 32, 32) + test_matmul(512, 512, 512, 64, 64, 64) + test_matmul(256, 512, 128, 32, 32, 32) + test_matmul(1024, 1024, 1024, 64, 64, 64) + + print("\n--- All matmul examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/transpose.jl b/examples/transpose.jl index 9c5ec26..02a03b9 100644 --- a/examples/transpose.jl +++ b/examples/transpose.jl @@ -17,20 +17,27 @@ function transpose_kernel(x::ct.TileArray{T,2}, y::ct.TileArray{T,2}, return end -function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T - name = something(name, "transpose ($m x $n, $T, tiles=$tm x $tn)") - println("--- $name ---") - x = CUDA.rand(T, m, n) - y = CUDA.zeros(T, n, m) - +# Run transpose - for benchmarking or programmatic use +function run_transpose(; m::Int, n::Int, tm::Int, tn::Int, T::DataType=Float32, + x::Union{CuArray,Nothing}=nothing, + y::Union{CuArray,Nothing}=nothing, + validate::Bool=false) + x = something(x, CUDA.rand(T, m, n)) + y = something(y, CUDA.zeros(T, n, m)) grid_x = cld(m, tm) grid_y = cld(n, tn) - - # Launch with ct.launch - CuArrays are auto-converted to TileArray ct.launch(transpose_kernel, (grid_x, grid_y), x, y, ct.Constant(tm), ct.Constant(tn)) + if validate + @assert Array(y) ≈ transpose(Array(x)) + end + return (; x, y) +end - @assert Array(y) ≈ transpose(Array(x)) +function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T + name = something(name, "transpose ($m x $n, $T, tiles=$tm x $tn)") + println("--- $name ---") + run_transpose(; m, n, tm, tn, T, validate=true) println("✓ passed") end diff --git a/examples/transpose.py b/examples/transpose.py new file mode 100644 index 0000000..db0f7ab --- /dev/null +++ b/examples/transpose.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +""" +Matrix Transpose example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct + +@ct.kernel +def transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct.Constant[int]): + pid_m = ct.bid(0) + pid_n = ct.bid(1) + tile = ct.load(input, index=(pid_m, pid_n), shape=(tile_m, tile_n)) + tile_t = ct.transpose(tile) + ct.store(output, index=(pid_n, pid_m), tile=tile_t) + + +def run_transpose(*, M: int = 1024, N: int = 1024, tile_m: int = 64, tile_n: int = 64, + dtype=np.float32, input=None, output=None, validate: bool = False): + """Run matrix transpose. Returns (input, output) arrays for benchmarking.""" + if input is None: + input = cp.random.rand(M, N).astype(dtype) + else: + M, N = input.shape + if output is None: + output = cp.zeros((N, M), dtype=dtype) + + grid = (ct.cdiv(M, tile_m), ct.cdiv(N, tile_n), 1) + stream = cp.cuda.get_current_stream() + ct.launch(stream, grid, transpose_cutile_kernel, (input, output, tile_m, tile_n)) + + if validate: + cp.cuda.runtime.deviceSynchronize() + expected = cp.asnumpy(input).T + assert np.allclose(cp.asnumpy(output), expected), "transpose incorrect!" 
+ + return input, output + + +def test_transpose(M, N, tile_m, tile_n, dtype=np.float32, name=None): + """Test transpose with given parameters.""" + name = name or f"transpose ({M}x{N}), tiles={tile_m}x{tile_n}, dtype={dtype.__name__}" + print(f"--- {name} ---") + run_transpose(M=M, N=N, tile_m=tile_m, tile_n=tile_n, dtype=dtype, validate=True) + print(" passed") + + +def main(): + print("--- cuTile Matrix Transpose Examples ---\n") + + test_transpose(256, 256, 32, 32) + test_transpose(512, 512, 64, 64) + test_transpose(256, 512, 32, 64) + test_transpose(1024, 1024, 64, 64) + + print("\n--- All transpose examples completed ---") + + +if __name__ == "__main__": + main() diff --git a/examples/vadd.jl b/examples/vadd.jl index 9689fc9..fcc04b3 100644 --- a/examples/vadd.jl +++ b/examples/vadd.jl @@ -27,31 +27,50 @@ function vec_add_kernel_2d(a::ct.TileArray{T,2}, b::ct.TileArray{T,2}, c::ct.Til return end +# Run 1D vector addition - for benchmarking or programmatic use +function run_vadd_1d(; n::Int, tile::Int, T::DataType=Float32, + a::Union{CuArray,Nothing}=nothing, + b::Union{CuArray,Nothing}=nothing, + c::Union{CuArray,Nothing}=nothing, + validate::Bool=false) + a = something(a, CUDA.rand(T, n)) + b = something(b, CUDA.rand(T, n)) + c = something(c, CUDA.zeros(T, n)) + ct.launch(vec_add_kernel_1d, cld(n, tile), a, b, c, ct.Constant(tile)) + if validate + @assert Array(c) ≈ Array(a) + Array(b) + end + return (; a, b, c) +end + function test_add_1d(::Type{T}, n, tile; name=nothing) where T name = something(name, "1D vec_add ($n elements, $T, tile=$tile)") println("--- $name ---") - a, b = CUDA.rand(T, n), CUDA.rand(T, n) - c = CUDA.zeros(T, n) - - # Launch with ct.launch - CuArrays are auto-converted to TileArray - # Constant parameters are ghost types - filtered out at launch time - ct.launch(vec_add_kernel_1d, cld(n, tile), a, b, c, ct.Constant(tile)) - - @assert Array(c) ≈ Array(a) + Array(b) + run_vadd_1d(; n, tile, T, validate=true) println("✓ passed") end -function test_add_2d(::Type{T}, m, n, tile_x, tile_y; name=nothing) where T - name = something(name, "2D vec_add ($m x $n, $T, tiles=$tile_x x $tile_y)") - println("--- $name ---") - a, b = CUDA.rand(T, m, n), CUDA.rand(T, m, n) - c = CUDA.zeros(T, m, n) - - # Launch with ct.launch - CuArrays are auto-converted to TileArray +# Run 2D matrix addition - for benchmarking or programmatic use +function run_vadd_2d(; m::Int, n::Int, tile_x::Int, tile_y::Int, T::DataType=Float32, + a::Union{CuArray,Nothing}=nothing, + b::Union{CuArray,Nothing}=nothing, + c::Union{CuArray,Nothing}=nothing, + validate::Bool=false) + a = something(a, CUDA.rand(T, m, n)) + b = something(b, CUDA.rand(T, m, n)) + c = something(c, CUDA.zeros(T, m, n)) ct.launch(vec_add_kernel_2d, (cld(m, tile_x), cld(n, tile_y)), a, b, c, ct.Constant(tile_x), ct.Constant(tile_y)) + if validate + @assert Array(c) ≈ Array(a) + Array(b) + end + return (; a, b, c) +end - @assert Array(c) ≈ Array(a) + Array(b) +function test_add_2d(::Type{T}, m, n, tile_x, tile_y; name=nothing) where T + name = something(name, "2D vec_add ($m x $n, $T, tiles=$tile_x x $tile_y)") + println("--- $name ---") + run_vadd_2d(; m, n, tile_x, tile_y, T, validate=true) println("✓ passed") end @@ -74,15 +93,26 @@ function vec_add_kernel_1d_gather(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c: return end +# Run 1D gather/scatter vector addition - for benchmarking or programmatic use +function run_vadd_1d_gather(; n::Int, tile::Int, T::DataType=Float32, + a::Union{CuArray,Nothing}=nothing, + 
b::Union{CuArray,Nothing}=nothing, + c::Union{CuArray,Nothing}=nothing, + validate::Bool=false) + a = something(a, CUDA.rand(T, n)) + b = something(b, CUDA.rand(T, n)) + c = something(c, CUDA.zeros(T, n)) + ct.launch(vec_add_kernel_1d_gather, cld(n, tile), a, b, c, ct.Constant(tile)) + if validate + @assert Array(c) ≈ Array(a) + Array(b) + end + return (; a, b, c) +end + function test_add_1d_gather(::Type{T}, n, tile; name=nothing) where T name = something(name, "1D vec_add gather ($n elements, $T, tile=$tile)") println("--- $name ---") - a, b = CUDA.rand(T, n), CUDA.rand(T, n) - c = CUDA.zeros(T, n) - - ct.launch(vec_add_kernel_1d_gather, cld(n, tile), a, b, c, ct.Constant(tile)) - - @assert Array(c) ≈ Array(a) + Array(b) + run_vadd_1d_gather(; n, tile, T, validate=true) println("✓ passed") end diff --git a/examples/vadd.py b/examples/vadd.py new file mode 100644 index 0000000..8d997bb --- /dev/null +++ b/examples/vadd.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +""" +Vector Addition example - cuTile Python +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct + +@ct.kernel +def vadd_cutile_kernel(a, b, c, tile_size: ct.Constant[int]): + pid = ct.bid(0) + tile_a = ct.load(a, index=(pid,), shape=(tile_size,)) + tile_b = ct.load(b, index=(pid,), shape=(tile_size,)) + result = tile_a + tile_b + ct.store(c, index=(pid,), tile=result) + + +def run_vadd(*, size: int = 2**20, tile: int = 1024, dtype=np.float32, validate: bool = False): + """Run vector addition. Returns (a, b, c) arrays for benchmarking.""" + a = cp.random.rand(size).astype(dtype) + b = cp.random.rand(size).astype(dtype) + c = cp.zeros(size, dtype=dtype) + + grid = (ct.cdiv(size, tile), 1, 1) + stream = cp.cuda.get_current_stream() + ct.launch(stream, grid, vadd_cutile_kernel, (a, b, c, tile)) + + if validate: + cp.cuda.runtime.deviceSynchronize() + expected = cp.asnumpy(a) + cp.asnumpy(b) + assert np.allclose(cp.asnumpy(c), expected), "vadd incorrect!" + + return a, b, c + + +def test_vadd(size, tile, dtype=np.float32, name=None): + """Test vector addition with given parameters.""" + name = name or f"vadd size={size}, tile={tile}, dtype={dtype.__name__}" + print(f"--- {name} ---") + run_vadd(size=size, tile=tile, dtype=dtype, validate=True) + print(" passed") + + +def main(): + print("--- cuTile Vector Addition Examples ---\n") + + test_vadd(1_024_000, 1024) + test_vadd(2**20, 512) + test_vadd(2**20, 1024) + + print("\n--- All vadd examples completed ---") + + +if __name__ == "__main__": + main() From 9d2c7c85637447766393a2d7777a89cab050c4ad Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 14 Jan 2026 08:59:17 -0500 Subject: [PATCH 2/7] Use unified structure. 
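
Every example now exports the same three entry points: a *_prepare
function that allocates and initializes device buffers, a *_run function
that launches the kernel (with optional warmup iterations and nruns timed
repetitions, returning the outputs plus per-run times in milliseconds),
and a *_verify function that asserts the result against a CPU reference.
benchmarks.jl and benchmarks.py then consume these instead of duplicating
setup and launch code. As a minimal sketch of the contract, using matmul
from this patch (the sizes and tile parameters below are illustrative
only, not tuned):

    data   = matmul_prepare(; M=1024, N=1024, K=1024, T=Float32)
    result = matmul_run(data; tm=64, tn=64, tk=64, nruns=10, warmup=3)
    matmul_verify(data, result)                      # asserts against A*B on the CPU
    println("best: ", minimum(result.times), " ms")  # result.times holds per-run timings

The Python examples follow the same shape, passing dicts where the Julia
code uses named tuples.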
--- examples/batchmatmul.jl | 70 ++++++---- examples/batchmatmul.py | 85 +++++++++---- examples/benchmarks.jl | 174 ++++++++----------------- examples/benchmarks.py | 217 ++++++++++++------------------- examples/fft.jl | 151 ++++++++++++++-------- examples/fft.py | 84 +++++++++--- examples/layernorm.jl | 276 +++++++++++++++++++--------------------- examples/layernorm.py | 97 +++++++++----- examples/matmul.jl | 62 ++++++--- examples/matmul.py | 74 +++++++---- examples/transpose.jl | 55 +++++--- examples/transpose.py | 64 +++++++--- examples/vadd.jl | 188 ++++++++++++++++++--------- examples/vadd.py | 65 +++++++--- 14 files changed, 957 insertions(+), 705 deletions(-) diff --git a/examples/batchmatmul.jl b/examples/batchmatmul.jl index 0ac108d..f7ac0f0 100644 --- a/examples/batchmatmul.jl +++ b/examples/batchmatmul.jl @@ -57,39 +57,63 @@ function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.T return nothing end -# Run batch matmul - for benchmarking or programmatic use -function run_batchmatmul(; M::Int, K::Int, N::Int, Batch::Int, tm::Int, tn::Int, tk::Int, - T::DataType=Float32, - A::Union{CuArray,Nothing}=nothing, - B::Union{CuArray,Nothing}=nothing, - C::Union{CuArray,Nothing}=nothing, - validate::Bool=false) - A = something(A, CUDA.rand(T, M, K, Batch)) - B = something(B, CUDA.rand(T, K, N, Batch)) - C = something(C, CUDA.zeros(T, M, N, Batch)) +#============================================================================= + Batch Matmul - prepare/run/verify pattern +=============================================================================# + +function batchmatmul_prepare(; M::Int, K::Int, N::Int, Batch::Int, T::DataType=Float32) + return (; + A = CUDA.rand(T, M, K, Batch), + B = CUDA.rand(T, K, N, Batch), + C = CUDA.zeros(T, M, N, Batch), + M, K, N, Batch + ) +end + +function batchmatmul_run(data; tm::Int, tn::Int, tk::Int, nruns::Int=1, warmup::Int=0) + (; A, B, C, M, N, Batch) = data grid = (cld(M, tm), cld(N, tn), Batch) - ct.launch(batch_matmul_kernel, grid, A, B, C, - ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) - if validate - A_cpu = Array(A) - B_cpu = Array(B) - expected = similar(A_cpu, M, N, Batch) - for b in 1:Batch - expected[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b] - end - result = Array(C) - @assert isapprox(result, expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(result - expected)))" + + for _ in 1:warmup + ct.launch(batch_matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(batch_matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + push!(times, t * 1000) # ms end - return (; A, B, C) + + return (; C, times) +end + +function batchmatmul_verify(data, result) + (; A, B, M, N, Batch) = data + A_cpu = Array(A) + B_cpu = Array(B) + expected = similar(A_cpu, M, N, Batch) + for b in 1:Batch + expected[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b] + end + @assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))" end function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk") println("--- $name ---") - run_batchmatmul(; M, K, N, Batch, tm, tn, tk, T, validate=true) + data = batchmatmul_prepare(; M, K, N, Batch, T) + result = batchmatmul_run(data; tm, tn, tk) + batchmatmul_verify(data, 
result) println(" passed") end +#============================================================================= + Main +=============================================================================# + function main() println("--- cuTile Batch Matrix Multiplication Examples ---\n") diff --git a/examples/batchmatmul.py b/examples/batchmatmul.py index 3723bd7..01a8d8a 100644 --- a/examples/batchmatmul.py +++ b/examples/batchmatmul.py @@ -36,44 +36,75 @@ def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int ct.store(C, index=(pid_batch, bidx, bidy), tile=result_3d) -def run_batchmatmul(*, Batch: int = 4, M: int = 256, K: int = 128, N: int = 256, - tm: int = 128, tn: int = 256, tk: int = 64, - dtype=np.float16, A=None, B=None, C=None, validate: bool = False): - """Run batch matrix multiplication. Returns (A, B, C) arrays for benchmarking.""" - if A is None: - A = cp.random.randn(Batch, M, K).astype(dtype) - else: - Batch, M, K = A.shape - if B is None: - B = cp.random.randn(Batch, K, N).astype(dtype) - else: - Batch, K, N = B.shape - if C is None: - C = cp.zeros((Batch, M, N), dtype=dtype) +#============================================================================= +# prepare/run/verify pattern +#============================================================================= + +def batchmatmul_prepare(*, Batch: int, M: int, K: int, N: int, dtype=np.float16): + """Allocate and initialize data for batch matmul.""" + return { + "A": cp.random.randn(Batch, M, K).astype(dtype), + "B": cp.random.randn(Batch, K, N).astype(dtype), + "C": cp.zeros((Batch, M, N), dtype=dtype), + "Batch": Batch, + "M": M, + "K": K, + "N": N + } + + +def batchmatmul_run(data, *, tm: int, tn: int, tk: int, nruns: int = 1, warmup: int = 0): + """Run batch matmul kernel with timing.""" + A, B, C = data["A"], data["B"], data["C"] + Batch, M, N = data["Batch"], data["M"], data["N"] grid = (Batch, ceil(M / tm), ceil(N / tn)) stream = cp.cuda.get_current_stream() - ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk)) - if validate: - cp.cuda.runtime.deviceSynchronize() - A_np = cp.asnumpy(A).astype(np.float32) - B_np = cp.asnumpy(B).astype(np.float32) - C_np = cp.asnumpy(C).astype(np.float32) - expected = np.zeros((Batch, M, N), dtype=np.float32) - for b in range(Batch): - expected[b] = A_np[b] @ B_np[b] - assert np.allclose(C_np, expected, rtol=1e-1, atol=1e-1), \ - f"batchmatmul incorrect! max diff: {np.max(np.abs(C_np - expected))}" + # Warmup + for _ in range(warmup): + ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk)) + cp.cuda.runtime.deviceSynchronize() - return A, B, C + # Timed runs + times = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk)) + end.record(stream) + end.synchronize() + times.append(cp.cuda.get_elapsed_time(start, end)) # ms + return {"C": C, "times": times} + + +def batchmatmul_verify(data, result): + """Verify batch matmul results.""" + A_np = cp.asnumpy(data["A"]).astype(np.float32) + B_np = cp.asnumpy(data["B"]).astype(np.float32) + C_np = cp.asnumpy(result["C"]).astype(np.float32) + Batch, M, N = data["Batch"], data["M"], data["N"] + + expected = np.zeros((Batch, M, N), dtype=np.float32) + for b in range(Batch): + expected[b] = A_np[b] @ B_np[b] + assert np.allclose(C_np, expected, rtol=1e-1, atol=1e-1), \ + f"batchmatmul incorrect! 
max diff: {np.max(np.abs(C_np - expected))}" + + +#============================================================================= +# Test function +#============================================================================= def test_batchmatmul(Batch, M, K, N, tm, tn, tk, dtype=np.float16, name=None): """Test batch matmul with given parameters.""" name = name or f"batchmatmul ({Batch}x{M}x{K}) @ ({Batch}x{K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" print(f"--- {name} ---") - run_batchmatmul(Batch=Batch, M=M, K=K, N=N, tm=tm, tn=tn, tk=tk, dtype=dtype, validate=True) + data = batchmatmul_prepare(Batch=Batch, M=M, K=K, N=N, dtype=dtype) + result = batchmatmul_run(data, tm=tm, tn=tn, tk=tk) + batchmatmul_verify(data, result) print(" passed") diff --git a/examples/benchmarks.jl b/examples/benchmarks.jl index 2a7854c..ca9840a 100644 --- a/examples/benchmarks.jl +++ b/examples/benchmarks.jl @@ -100,15 +100,13 @@ function vadd_simt_kernel!(a, b, c) return end -# cuTile kernel: use vec_add_kernel_1d from vadd.jl - function benchmark_vadd() println("\nBenchmarking Vector Addition...") println(" Size: $VADD_SIZE elements ($(VADD_SIZE * 4 / 1e6) MB)") - a = CUDA.rand(Float32, VADD_SIZE) - b = CUDA.rand(Float32, VADD_SIZE) - c = similar(a) + # Prepare data once (using vadd.jl's prepare function) + data = vadd_1d_prepare(; n=VADD_SIZE, T=Float32) + (; a, b, c) = data expected = Array(a) .+ Array(b) results = BenchmarkResult[] @@ -133,13 +131,10 @@ function benchmark_vadd() min_t, mean_t = benchmark_kernel(simt_f) push!(results, BenchmarkResult("SIMT (CUDA.jl)", min_t, mean_t)) - # cuTile (uses vec_add_kernel_1d from vadd.jl) - grid = (cld(VADD_SIZE, VADD_TILE), 1, 1) - cutile_f = () -> ct.launch(vec_add_kernel_1d, grid, a, b, c, ct.Constant(VADD_TILE)) - cutile_f() - CUDA.synchronize() - @assert Array(c) ≈ expected "cuTile incorrect!" - min_t, mean_t = benchmark_kernel(cutile_f) + # cuTile (using vadd.jl's run/verify functions) + result = vadd_1d_run(data; tile=VADD_TILE, nruns=NRUNS, warmup=WARMUP) + vadd_1d_verify(data, result) + min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) # Calculate bandwidth @@ -186,58 +181,52 @@ function transpose_simt_shared_kernel!(input, output, M, N) return end -# cuTile kernel: use transpose_kernel from transpose.jl - function benchmark_transpose() println("\nBenchmarking Matrix Transpose...") M, N = TRANSPOSE_DIM, TRANSPOSE_DIM println(" Size: $(M)x$(N) ($(M * N * 4 / 1e6) MB)") - input = CUDA.rand(Float32, M, N) - output = CUDA.zeros(Float32, N, M) - expected = Array(permutedims(input, (2, 1))) + # Prepare data once (using transpose.jl's prepare function) + data = transpose_prepare(; m=M, n=N, T=Float32) + (; x, y) = data + expected = Array(permutedims(x, (2, 1))) results = BenchmarkResult[] # GPUArrays (permutedims) - gpuarrays_f = () -> permutedims!(output, input, (2, 1)) + gpuarrays_f = () -> permutedims!(y, x, (2, 1)) gpuarrays_f() CUDA.synchronize() - @assert Array(output) ≈ expected "GPUArrays incorrect!" + @assert Array(y) ≈ expected "GPUArrays incorrect!" 
min_t, mean_t = benchmark_kernel(gpuarrays_f) push!(results, BenchmarkResult("GPUArrays", min_t, mean_t)) # SIMT naive - fill!(output, 0) + fill!(y, 0) threads_naive = (16, 16) blocks_naive = (cld(M, 16), cld(N, 16)) - simt_naive_f = () -> @cuda threads=threads_naive blocks=blocks_naive transpose_simt_naive_kernel!(input, output, M, N) + simt_naive_f = () -> @cuda threads=threads_naive blocks=blocks_naive transpose_simt_naive_kernel!(x, y, M, N) simt_naive_f() CUDA.synchronize() - @assert Array(output) ≈ expected "SIMT naive incorrect!" + @assert Array(y) ≈ expected "SIMT naive incorrect!" min_t, mean_t = benchmark_kernel(simt_naive_f) push!(results, BenchmarkResult("SIMT naive", min_t, mean_t)) # SIMT shared - fill!(output, 0) + fill!(y, 0) threads_shared = (32, 32) blocks_shared = (cld(M, 32), cld(N, 32)) - simt_shared_f = () -> @cuda threads=threads_shared blocks=blocks_shared transpose_simt_shared_kernel!(input, output, M, N) + simt_shared_f = () -> @cuda threads=threads_shared blocks=blocks_shared transpose_simt_shared_kernel!(x, y, M, N) simt_shared_f() CUDA.synchronize() - @assert Array(output) ≈ expected "SIMT shared incorrect!" + @assert Array(y) ≈ expected "SIMT shared incorrect!" min_t, mean_t = benchmark_kernel(simt_shared_f) push!(results, BenchmarkResult("SIMT shared", min_t, mean_t)) - # cuTile (uses transpose_kernel from transpose.jl) - fill!(output, 0) - grid = (cld(M, TRANSPOSE_TILE_M), cld(N, TRANSPOSE_TILE_N), 1) - cutile_f = () -> ct.launch(transpose_kernel, grid, input, output, - ct.Constant(TRANSPOSE_TILE_M), ct.Constant(TRANSPOSE_TILE_N)) - cutile_f() - CUDA.synchronize() - @assert Array(output) ≈ expected "cuTile incorrect!" - min_t, mean_t = benchmark_kernel(cutile_f) + # cuTile (using transpose.jl's run/verify functions) + result = transpose_run(data; tm=TRANSPOSE_TILE_M, tn=TRANSPOSE_TILE_N, nruns=NRUNS, warmup=WARMUP) + transpose_verify(data, result) + min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) # Calculate bandwidth @@ -252,16 +241,14 @@ end Matrix Multiplication =============================================================================# -# cuTile kernel: use matmul_kernel and swizzle_2d from matmul.jl - function benchmark_matmul() println("\nBenchmarking Matrix Multiplication...") M, N, K = MATMUL_DIM, MATMUL_DIM, MATMUL_DIM println(" Size: $(M)x$(K) * $(K)x$(N)") - A = CUDA.rand(Float32, M, K) - B = CUDA.rand(Float32, K, N) - C = CUDA.zeros(Float32, M, N) + # Prepare data once (using matmul.jl's prepare function) + data = matmul_prepare(; M, K, N, T=Float32) + (; A, B, C) = data # Reference result (cuBLAS) C_ref = similar(C) @@ -287,17 +274,10 @@ function benchmark_matmul() min_t, mean_t = benchmark_kernel(cublas_f) push!(results, BenchmarkResult("cuBLAS", min_t, mean_t)) - # cuTile (uses matmul_kernel from matmul.jl) - fill!(C, 0) - grid_m = cld(M, MATMUL_TM) - grid_n = cld(N, MATMUL_TN) - grid = (grid_m * grid_n, 1, 1) - cutile_f = () -> ct.launch(matmul_kernel, grid, A, B, C, - ct.Constant(MATMUL_TM), ct.Constant(MATMUL_TN), ct.Constant(MATMUL_TK)) - cutile_f() - CUDA.synchronize() - @assert isapprox(Array(C), Array(C_ref), rtol=1e-2, atol=1e-2) "cuTile incorrect!" 
-    min_t, mean_t = benchmark_kernel(cutile_f)
+    # cuTile (using matmul.jl's run/verify functions)
+    result = matmul_run(data; tm=MATMUL_TM, tn=MATMUL_TN, tk=MATMUL_TK, nruns=NRUNS, warmup=WARMUP)
+    matmul_verify(data, result)
+    min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times)
     push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t))
 
     # Calculate TFLOPS
@@ -363,19 +343,14 @@ function layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, eps)
     return
 end
 
-# cuTile kernel: use layer_norm_fwd from layernorm.jl
-
 function benchmark_layernorm()
     println("\nBenchmarking Layer Normalization...")
     M, N = LAYERNORM_M, LAYERNORM_N
     println("  Size: $(M)x$(N) ($(M * N * 4 / 1e6) MB)")
 
-    X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N)
-    W = CUDA.randn(Float32, N)
-    B = CUDA.randn(Float32, N)
-    Y = CUDA.zeros(Float32, M, N)
-    Mean = CUDA.zeros(Float32, M)
-    Rstd = CUDA.zeros(Float32, M)
+    # Prepare data once (using layernorm.jl's prepare function)
+    data = layernorm_fwd_prepare(; M, N, eps=LAYERNORM_EPS)
+    (; X, W, B, Y, Mean, Rstd) = data
 
     # Reference result
     X_cpu = Array(X)
@@ -398,14 +373,10 @@ function benchmark_layernorm()
     min_t, mean_t = benchmark_kernel(simt_f)
     push!(results, BenchmarkResult("SIMT naive", min_t, mean_t))
 
-    # cuTile (uses layer_norm_fwd from layernorm.jl)
-    fill!(Y, 0); fill!(Mean, 0); fill!(Rstd, 0)
-    cutile_f = () -> ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd,
-                               ct.Constant(LAYERNORM_EPS), ct.Constant(LAYERNORM_TILE_N))
-    cutile_f()
-    CUDA.synchronize()
-    @assert isapprox(Array(Y), expected_Y, rtol=1e-2, atol=1e-2) "cuTile incorrect!"
-    min_t, mean_t = benchmark_kernel(cutile_f)
+    # cuTile (using layernorm.jl's run/verify functions)
+    result = layernorm_fwd_run(data; TILE_N=LAYERNORM_TILE_N, nruns=NRUNS, warmup=WARMUP)
+    layernorm_fwd_verify(data, result)
+    min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times)
     push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t))
 
     # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y)
@@ -420,17 +391,14 @@ end
    Batch Matrix Multiplication
 =============================================================================#
 
-# cuTile kernel: use batch_matmul_kernel from batchmatmul.jl
-
 function benchmark_batchmatmul()
     println("\nBenchmarking Batch Matrix Multiplication...")
     Batch, M, K, N = BATCHMATMUL_BATCH, BATCHMATMUL_M, BATCHMATMUL_K, BATCHMATMUL_N
     println("  Size: ($M x $K x $Batch) @ ($K x $N x $Batch), Float16")
 
-    # Batch-last ordering for optimal column-major access
-    A = CUDA.rand(Float16, M, K, Batch)
-    B = CUDA.rand(Float16, K, N, Batch)
-    C = CUDA.zeros(Float16, M, N, Batch)
+    # Prepare data once (using batchmatmul.jl's prepare function)
+    data = batchmatmul_prepare(; M, K, N, Batch, T=Float16)
+    (; A, B, C) = data
 
     # Reference result (batched matmul on CPU)
     A_cpu = Float32.(Array(A))
@@ -456,16 +424,11 @@ function benchmark_batchmatmul()
     min_t, mean_t = benchmark_kernel(cublas_f)
     push!(results, BenchmarkResult("cuBLAS (loop)", min_t, mean_t))
 
-    # cuTile (uses batch_matmul_kernel from batchmatmul.jl)
-    fill!(C, 0)
-    grid = (cld(M, BATCHMATMUL_TM), cld(N, BATCHMATMUL_TN), Batch)
-    cutile_f = () -> ct.launch(batch_matmul_kernel, grid, A, B, C,
-                               ct.Constant(BATCHMATMUL_TM), ct.Constant(BATCHMATMUL_TN),
-                               ct.Constant(BATCHMATMUL_TK))
-    cutile_f()
-    CUDA.synchronize()
-    @assert isapprox(Float32.(Array(C)), C_ref, rtol=1e-1, atol=1e-1) "cuTile incorrect!"
- min_t, mean_t = benchmark_kernel(cutile_f) + # cuTile (using batchmatmul.jl's run/verify functions) + result = batchmatmul_run(data; tm=BATCHMATMUL_TM, tn=BATCHMATMUL_TN, tk=BATCHMATMUL_TK, + nruns=NRUNS, warmup=WARMUP) + batchmatmul_verify(data, result) + min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) # Calculate TFLOPS @@ -479,55 +442,20 @@ end FFT (3-stage Cooley-Tukey) - Column-Major Version =============================================================================# -# cuTile kernel: use fft_kernel and make_twiddles from fft.jl - function benchmark_fft() println("\nBenchmarking FFT...") BS, N = FFT_BATCH, FFT_SIZE - F0, F1, F2 = FFT_FACTORS - D = FFT_ATOM_PACKING_DIM println(" Size: $BS batches × $N FFT ($(BS * N * 8 / 1e6) MB)") - # Create complex input - CUDA.seed!(42) - input = CUDA.randn(ComplexF32, BS, N) - - # Reference result (FFTW) - reference = FFTW.fft(Array(input), 2) + # Prepare data once (using fft.jl's prepare function) + data = fft_prepare(; batch=BS, n=N, factors=FFT_FACTORS, atom_packing_dim=FFT_ATOM_PACKING_DIM) results = BenchmarkResult[] - # Pre-compute twiddles (one-time CPU cost, uses make_twiddles from fft.jl) - W0, W1, W2, T0, T1 = make_twiddles(FFT_FACTORS) - W0_gpu, W1_gpu, W2_gpu = CuArray(W0), CuArray(W1), CuArray(W2) - T0_gpu, T1_gpu = CuArray(T0), CuArray(T1) - - # Pre-pack input (zero-copy) - N2D = N * 2 ÷ D - x_packed = reinterpret(reshape, Float32, input) - y_packed = CUDA.zeros(Float32, D, BS, N2D) - - # Kernel launch parameters - F0F1, F1F2, F0F2 = F0 * F1, F1 * F2, F0 * F2 - grid = (BS, 1, 1) - - # Kernel-only timing function - cutile_kernel_f = () -> ct.launch(fft_kernel, grid, - x_packed, y_packed, - W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, - ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), - ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), - ct.Constant(BS), ct.Constant(D), ct.Constant(N2D)) - - # Verify correctness - cutile_kernel_f() - CUDA.synchronize() - y_complex = reinterpret(reshape, ComplexF32, y_packed) - output = copy(y_complex) - @assert isapprox(Array(output), reference, rtol=1e-3) "cuTile FFT incorrect!" 
- - # Benchmark kernel only - min_t, mean_t = benchmark_kernel(cutile_kernel_f) + # cuTile (using fft.jl's run/verify functions) + result = fft_run(data; nruns=NRUNS, warmup=WARMUP) + fft_verify(data, result) + min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) # Performance metric: GFLOPS (5 * N * log2(N) per complex FFT) diff --git a/examples/benchmarks.py b/examples/benchmarks.py index 482fd18..b27addd 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -9,16 +9,15 @@ import numpy as np import torch import cuda.tile as ct -import math from math import ceil, log2 -# Import kernels from example files -from vadd import vadd_cutile_kernel -from transpose import transpose_cutile_kernel -from matmul import matmul_cutile_kernel, swizzle_2d -from layernorm import layernorm_cutile_kernel -from batchmatmul import batchmatmul_cutile_kernel -from fft import fft_kernel, fft_make_twiddles +# Import prepare/run/verify functions from example files +from vadd import vadd_prepare, vadd_run, vadd_verify +from transpose import transpose_prepare, transpose_run, transpose_verify +from matmul import matmul_prepare, matmul_run, matmul_verify +from layernorm import layernorm_prepare, layernorm_run, layernorm_verify +from batchmatmul import batchmatmul_prepare, batchmatmul_run, batchmatmul_verify +from fft import fft_prepare, fft_run, fft_verify #============================================================================= # Configuration @@ -36,7 +35,6 @@ FFT_BATCH = 64 FFT_SIZE = 512 FFT_FACTORS = (8, 8, 8) -FFT_ATOM_PACKING_DIM = 2 # Tile sizes VADD_TILE = 1024 @@ -46,6 +44,21 @@ MATMUL_TN = 64 MATMUL_TK = 64 +# Layer norm sizes +LAYERNORM_M = 4096 +LAYERNORM_N = 4096 +LAYERNORM_TILE_N = 1024 +LAYERNORM_EPS = 1e-5 + +# Batch matmul sizes +BATCHMATMUL_BATCH = 8 +BATCHMATMUL_M = 1024 +BATCHMATMUL_K = 512 +BATCHMATMUL_N = 2048 +BATCHMATMUL_TM = 128 +BATCHMATMUL_TN = 256 +BATCHMATMUL_TK = 64 + #============================================================================= # Benchmark Utilities #============================================================================= @@ -131,13 +144,11 @@ def print_table(title: str, results: list, extra_col=None): # Vector Addition #============================================================================= -# cuTile kernel: use vadd_cutile_kernel from vadd.py - def benchmark_vadd(): print("\nBenchmarking Vector Addition...") print(f" Size: {VADD_SIZE} elements ({VADD_SIZE * 4 / 1e6} MB)") - # CuPy arrays + # CuPy arrays for CuPy/PyTorch benchmarks a_cp = cp.random.rand(VADD_SIZE).astype(np.float32) b_cp = cp.random.rand(VADD_SIZE).astype(np.float32) c_cp = cp.zeros(VADD_SIZE, dtype=np.float32) @@ -170,18 +181,16 @@ def torch_vadd(): min_t, mean_t = benchmark_torch(torch_vadd) results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - # cuTile - grid = (ct.cdiv(VADD_SIZE, VADD_TILE), 1, 1) - stream = cp.cuda.get_current_stream() - c_cp.fill(0) - - def cutile_vadd(): - ct.launch(stream, grid, vadd_cutile_kernel, (a_cp, b_cp, c_cp, VADD_TILE)) + # cuTile - use prepare/run/verify pattern + data = vadd_prepare(n=VADD_SIZE, dtype=np.float32) + # Copy expected data for apples-to-apples comparison + data["a"] = a_cp + data["b"] = b_cp + data["c"] = cp.zeros(VADD_SIZE, dtype=np.float32) - cutile_vadd() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(c_cp), expected), "cuTile incorrect!" 
- min_t, mean_t = benchmark_cupy(cutile_vadd) + result = vadd_run(data, tile=VADD_TILE, nruns=NRUNS, warmup=WARMUP) + vadd_verify(data, result) + min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) # Calculate bandwidth @@ -196,8 +205,6 @@ def cutile_vadd(): # Matrix Transpose #============================================================================= -# cuTile kernel: use transpose_cutile_kernel from transpose.py - def benchmark_transpose(): print("\nBenchmarking Matrix Transpose...") M, N = TRANSPOSE_DIM, TRANSPOSE_DIM @@ -235,19 +242,16 @@ def torch_transpose(): min_t, mean_t = benchmark_torch(torch_transpose) results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - # cuTile - output_cp.fill(0) - grid = (ct.cdiv(M, TRANSPOSE_TILE_M), ct.cdiv(N, TRANSPOSE_TILE_N), 1) - stream = cp.cuda.get_current_stream() - - def cutile_transpose(): - ct.launch(stream, grid, transpose_cutile_kernel, - (input_cp, output_cp, TRANSPOSE_TILE_M, TRANSPOSE_TILE_N)) + # cuTile - use prepare/run/verify pattern + data = transpose_prepare(M=M, N=N, dtype=np.float32) + # Copy input for apples-to-apples comparison + data["input"] = input_cp + data["output"] = cp.zeros((N, M), dtype=np.float32) - cutile_transpose() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(output_cp), expected), "cuTile incorrect!" - min_t, mean_t = benchmark_cupy(cutile_transpose) + result = transpose_run(data, tile_m=TRANSPOSE_TILE_M, tile_n=TRANSPOSE_TILE_N, + nruns=NRUNS, warmup=WARMUP) + transpose_verify(data, result) + min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) # Calculate bandwidth @@ -262,8 +266,6 @@ def cutile_transpose(): # Matrix Multiplication #============================================================================= -# cuTile kernel: use matmul_cutile_kernel and swizzle_2d from matmul.py - def benchmark_matmul(): print("\nBenchmarking Matrix Multiplication...") M, N, K = MATMUL_DIM, MATMUL_DIM, MATMUL_DIM @@ -309,22 +311,17 @@ def cupy_matmul(): min_t, mean_t = benchmark_cupy(cupy_matmul) results.append(BenchmarkResult("CuPy (cuBLAS)", min_t, mean_t)) - # cuTile - C_cp.fill(0) - grid_m = ceil(M / MATMUL_TM) - grid_n = ceil(N / MATMUL_TN) - grid = (grid_m * grid_n, 1, 1) - stream = cp.cuda.get_current_stream() - - def cutile_matmul(): - ct.launch(stream, grid, matmul_cutile_kernel, - (A_cp, B_cp, C_cp, MATMUL_TM, MATMUL_TN, MATMUL_TK)) - - cutile_matmul() - cp.cuda.runtime.deviceSynchronize() - # TF32 has reduced precision compared to FP32 cuBLAS - assert np.allclose(cp.asnumpy(C_cp), C_ref, rtol=1e-1, atol=1e-1), "cuTile incorrect!" 
- min_t, mean_t = benchmark_cupy(cutile_matmul) + # cuTile - use prepare/run/verify pattern + data = matmul_prepare(M=M, N=N, K=K, dtype=np.float32) + # Copy input for apples-to-apples comparison + data["A"] = A_cp + data["B"] = B_cp + data["C"] = cp.zeros((M, N), dtype=np.float32) + + result = matmul_run(data, tm=MATMUL_TM, tn=MATMUL_TN, tk=MATMUL_TK, + nruns=NRUNS, warmup=WARMUP) + matmul_verify(data, result) + min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) # Calculate TFLOPS @@ -338,34 +335,18 @@ def cutile_matmul(): # Layer Normalization #============================================================================= -LAYERNORM_M = 4096 -LAYERNORM_N = 4096 -LAYERNORM_TILE_N = 1024 -LAYERNORM_EPS = 1e-5 - -# Batch matmul sizes -BATCHMATMUL_BATCH = 8 -BATCHMATMUL_M = 1024 -BATCHMATMUL_K = 512 -BATCHMATMUL_N = 2048 -BATCHMATMUL_TM = 128 -BATCHMATMUL_TN = 256 -BATCHMATMUL_TK = 64 - -# cuTile kernel: use layernorm_cutile_kernel from layernorm.py - def benchmark_layernorm(): print("\nBenchmarking Layer Normalization...") M, N = LAYERNORM_M, LAYERNORM_N print(f" Size: {M}x{N} ({M * N * 4 / 1e6} MB)") - # CuPy arrays - X_cp = -2.3 + 0.5 * cp.random.randn(M, N).astype(np.float32) - W_cp = cp.random.randn(N).astype(np.float32) - B_cp = cp.random.randn(N).astype(np.float32) - Y_cp = cp.zeros((M, N), dtype=np.float32) - Mean_cp = cp.zeros(M, dtype=np.float32) - Rstd_cp = cp.zeros(M, dtype=np.float32) + # cuTile - prepare data + data = layernorm_prepare(M=M, N=N, eps=LAYERNORM_EPS, dtype=np.float32) + + # Extract CuPy/NumPy arrays for other benchmarks + X_cp = data["X"] + W_cp = data["W"] + B_cp = data["B"] # PyTorch tensors X_torch = torch.as_tensor(X_cp, device='cuda') @@ -396,20 +377,10 @@ def torch_layernorm(): min_t, mean_t = benchmark_torch(torch_layernorm) results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - # cuTile - Y_cp.fill(0) - Mean_cp.fill(0) - Rstd_cp.fill(0) - stream = cp.cuda.get_current_stream() - - def cutile_layernorm(): - ct.launch(stream, (M,), layernorm_cutile_kernel, - (X_cp, W_cp, B_cp, Y_cp, Mean_cp, Rstd_cp, LAYERNORM_EPS, LAYERNORM_TILE_N)) - - cutile_layernorm() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(Y_cp), expected_Y, rtol=1e-2, atol=1e-2), "cuTile incorrect!" 
- min_t, mean_t = benchmark_cupy(cutile_layernorm) + # cuTile - use prepare/run/verify pattern + result = layernorm_run(data, tile_n=LAYERNORM_TILE_N, nruns=NRUNS, warmup=WARMUP) + layernorm_verify(data, result) + min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y) @@ -424,8 +395,6 @@ def cutile_layernorm(): # Batch Matrix Multiplication #============================================================================= -# cuTile kernel: use batchmatmul_cutile_kernel from batchmatmul.py - def benchmark_batchmatmul(): print("\nBenchmarking Batch Matrix Multiplication...") Batch, M, K, N = BATCHMATMUL_BATCH, BATCHMATMUL_M, BATCHMATMUL_K, BATCHMATMUL_N @@ -439,7 +408,6 @@ def benchmark_batchmatmul(): # CuPy arrays (from same data) A_cp = cp.asarray(A_torch) B_cp = cp.asarray(B_torch) - C_cp = cp.zeros((Batch, M, N), dtype=np.float16) # Reference result (PyTorch bmm in fp32 for accuracy) C_ref = torch.bmm(A_torch.float(), B_torch.float()).cpu().numpy() @@ -457,19 +425,17 @@ def torch_bmm(): min_t, mean_t = benchmark_torch(torch_bmm) results.append(BenchmarkResult("PyTorch bmm", min_t, mean_t)) - # cuTile - C_cp.fill(0) - grid = (Batch, ceil(M / BATCHMATMUL_TM), ceil(N / BATCHMATMUL_TN)) - stream = cp.cuda.get_current_stream() - - def cutile_bmm(): - ct.launch(stream, grid, batchmatmul_cutile_kernel, - (A_cp, B_cp, C_cp, BATCHMATMUL_TM, BATCHMATMUL_TN, BATCHMATMUL_TK)) - - cutile_bmm() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(C_cp).astype(np.float32), C_ref, rtol=1e-1, atol=1e-1), "cuTile incorrect!" - min_t, mean_t = benchmark_cupy(cutile_bmm) + # cuTile - use prepare/run/verify pattern + data = batchmatmul_prepare(Batch=Batch, M=M, K=K, N=N, dtype=np.float16) + # Copy input for apples-to-apples comparison + data["A"] = A_cp + data["B"] = B_cp + data["C"] = cp.zeros((Batch, M, N), dtype=np.float16) + + result = batchmatmul_run(data, tm=BATCHMATMUL_TM, tn=BATCHMATMUL_TN, tk=BATCHMATMUL_TK, + nruns=NRUNS, warmup=WARMUP) + batchmatmul_verify(data, result) + min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) # Calculate TFLOPS @@ -483,43 +449,24 @@ def cutile_bmm(): # FFT (3-stage Cooley-Tukey) #============================================================================= -# cuTile kernel: use fft_kernel and fft_make_twiddles from fft.py - def benchmark_fft(): print("\nBenchmarking FFT...") BS, N = FFT_BATCH, FFT_SIZE - F0, F1, F2 = FFT_FACTORS - D = FFT_ATOM_PACKING_DIM print(f" Size: {BS} batches × {N} FFT ({BS * N * 8 / 1e6} MB)") - # PyTorch complex input - input_torch = torch.randn(BS, N, dtype=torch.complex64, device='cuda') + # cuTile - use prepare/run/verify pattern + data = fft_prepare(batch=BS, size=N, factors=FFT_FACTORS) - # Reference result - reference = torch.fft.fft(input_torch, dim=-1) + # Reference result using torch + reference = torch.fft.fft(data["input"], dim=-1) torch.cuda.synchronize() results = [] - # Pre-compute everything outside timing loop - x_ri = torch.view_as_real(input_torch) - x_packed = x_ri.reshape(BS, N * 2 // D, D).contiguous() - W0, W1, W2, T0, T1 = fft_make_twiddles(FFT_FACTORS, input_torch.real.dtype, input_torch.device) - y_packed = torch.empty_like(x_packed) - grid = (BS, 1, 1) - - # Kernel launch function - def fft_launch(): - ct.launch(torch.cuda.current_stream(), grid, 
fft_kernel, - (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D)) - - # Verify correctness - fft_launch() - torch.cuda.synchronize() - output = torch.view_as_complex(y_packed.reshape(BS, N, 2)) - assert torch.allclose(output, reference, rtol=1e-3, atol=1e-3), "cuTile FFT incorrect!" - - min_t, mean_t = benchmark_torch(fft_launch) + # cuTile FFT + result = fft_run(data, nruns=NRUNS, warmup=WARMUP) + fft_verify(data, result) + min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) # Calculate GFLOPS (5 * N * log2(N) ops per complex FFT) diff --git a/examples/fft.jl b/examples/fft.jl index e7712d5..899f62b 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -211,53 +211,116 @@ function make_twiddles(factors::NTuple{3, Int}) return (W0, W1, W2, T0, T1) end -# Main FFT function -function cutile_fft(x::CuMatrix{ComplexF32}, factors::NTuple{3, Int}; atom_packing_dim::Int=2) - BS = size(x, 1) - N = size(x, 2) - F0, F1, F2 = factors +#============================================================================= + FFT - prepare/run/verify pattern +=============================================================================# - @assert F0 * F1 * F2 == N "Factors must multiply to N" - @assert (N * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim" +function fft_prepare(; batch::Int, n::Int, factors::NTuple{3,Int}, atom_packing_dim::Int=2) + @assert factors[1] * factors[2] * factors[3] == n "Factors must multiply to N" + @assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim" - D = atom_packing_dim + CUDA.seed!(42) + input = CUDA.randn(ComplexF32, batch, n) - # Generate W and T matrices (CPU, one-time cost) + # Pre-compute twiddles (one-time CPU cost) W0, W1, W2, T0, T1 = make_twiddles(factors) - - # Upload to GPU W0_gpu = CuArray(W0) W1_gpu = CuArray(W1) W2_gpu = CuArray(W2) T0_gpu = CuArray(T0) T1_gpu = CuArray(T1) - # Pack input: complex (BS, N) → real (D, BS, N2D) - zero-copy view - N2D = N * 2 ÷ D - x_packed = reinterpret(reshape, Float32, x) # (2, BS, N) = (D, BS, N2D) + # Pack input + D = atom_packing_dim + N2D = n * 2 ÷ D + x_packed = reinterpret(reshape, Float32, input) + y_packed = CUDA.zeros(Float32, D, batch, N2D) + + return (; + input, x_packed, y_packed, + W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + factors, batch, n, D, N2D + ) +end - # Allocate output - y_packed = CUDA.zeros(Float32, D, BS, N2D) +function fft_run(data; nruns::Int=1, warmup::Int=0) + (; x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + factors, batch, n, D, N2D) = data - # Launch kernel + F0, F1, F2 = factors F0F1 = F0 * F1 F1F2 = F1 * F2 F0F2 = F0 * F2 - grid = (BS, 1, 1) - ct.launch(fft_kernel, grid, - x_packed, y_packed, - W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, - ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), - ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), - ct.Constant(BS), ct.Constant(D), ct.Constant(N2D)) - - # Unpack output: real (D, BS, N2D) → complex (BS, N) - zero-copy view + grid = (batch, 1, 1) + + for _ in 1:warmup + ct.launch(fft_kernel, grid, + x_packed, y_packed, + W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), + ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), + ct.Constant(batch), ct.Constant(D), ct.Constant(N2D)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(fft_kernel, grid, + x_packed, 
y_packed, + W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), + ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), + ct.Constant(batch), ct.Constant(D), ct.Constant(N2D)) + push!(times, t * 1000) # ms + end + + # Unpack output y_complex = reinterpret(reshape, ComplexF32, y_packed) + output = copy(y_complex) + + return (; output, times) +end - return copy(y_complex) +function fft_verify(data, result) + reference = FFTW.fft(Array(data.input), 2) + @assert isapprox(Array(result.output), reference, rtol=1e-4) end -# Validation and example +#============================================================================= + Legacy wrapper for backward compatibility +=============================================================================# + +function cutile_fft(x::CuMatrix{ComplexF32}, factors::NTuple{3, Int}; atom_packing_dim::Int=2) + BS = size(x, 1) + N = size(x, 2) + + # Create data structure similar to prepare + D = atom_packing_dim + N2D = N * 2 ÷ D + W0, W1, W2, T0, T1 = make_twiddles(factors) + W0_gpu = CuArray(W0) + W1_gpu = CuArray(W1) + W2_gpu = CuArray(W2) + T0_gpu = CuArray(T0) + T1_gpu = CuArray(T1) + + x_packed = reinterpret(reshape, Float32, x) + y_packed = CUDA.zeros(Float32, D, BS, N2D) + + data = (; + input=x, x_packed, y_packed, + W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, + factors, batch=BS, n=N, D, N2D + ) + + result = fft_run(data) + return result.output +end + +#============================================================================= + Main +=============================================================================# + function main() println("--- Running cuTile FFT Example ---") @@ -273,33 +336,15 @@ function main() println(" FFT Factors: $FFT_FACTORS") println(" Atom Packing Dim: $ATOM_PACKING_DIM") - # Create sample input - CUDA.seed!(42) - input_complex = CUDA.randn(ComplexF32, BATCH_SIZE, FFT_SIZE) - - println("\nInput data shape: $(size(input_complex)), dtype: $(eltype(input_complex))") - - # Perform FFT using cuTile kernel - output_cutile = cutile_fft(input_complex, FFT_FACTORS; atom_packing_dim=ATOM_PACKING_DIM) - - println("cuTile FFT Output shape: $(size(output_cutile)), dtype: $(eltype(output_cutile))") + # Use prepare/run/verify pattern + data = fft_prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM) + println("\nInput data shape: $(size(data.input)), dtype: $(eltype(data.input))") - # Verify against reference (FFTW) - input_cpu = Array(input_complex) - reference_output = FFTW.fft(input_cpu, 2) + result = fft_run(data) + println("cuTile FFT Output shape: $(size(result.output)), dtype: $(eltype(result.output))") - output_cpu = Array(output_cutile) - - if isapprox(output_cpu, reference_output, rtol=1e-4) - println("\n✓ Correctness check PASSED") - else - max_diff = maximum(abs.(output_cpu .- reference_output)) - println("\n✗ Correctness check FAILED - max difference: $max_diff") - println("\nExpected (first 4):") - println(reference_output[1, 1:4]) - println("\nGot (first 4):") - println(output_cpu[1, 1:4]) - end + fft_verify(data, result) + println("\n✓ Correctness check PASSED") println("\n--- cuTile FFT example execution complete ---") end diff --git a/examples/fft.py b/examples/fft.py index 7977f7b..a1719ab 100644 --- a/examples/fft.py +++ b/examples/fft.py @@ -87,6 +87,10 @@ def fft_kernel(x_packed_in, y_packed_out, ct.store(y_packed_out, index=(bid, 0, 0), tile=Y_ri) 
+#============================================================================= +# Helper functions +#============================================================================= + def fft_twiddles(rows: int, cols: int, factor: int, device, precision): """Generate DFT twiddle factors.""" I, J = torch.meshgrid(torch.arange(rows, device=device), @@ -108,48 +112,90 @@ def fft_make_twiddles(factors, precision, device): return (W0, W1, W2, T0, T1) -def run_fft(*, batch: int = 64, size: int = 512, factors: tuple = (8, 8, 8), - atom_packing_dim: int = 2, input=None, validate: bool = False): - """Run FFT. Returns (input, output) tensors for benchmarking.""" +#============================================================================= +# prepare/run/verify pattern +#============================================================================= + +def fft_prepare(*, batch: int, size: int, factors: tuple, atom_packing_dim: int = 2): + """Allocate and initialize data for FFT.""" F0, F1, F2 = factors N = F0 * F1 * F2 assert size == N, f"size ({size}) must equal product of factors ({N})" D = atom_packing_dim - if input is None: - input = torch.randn(batch, N, dtype=torch.complex64, device='cuda') - else: - batch = input.shape[0] - N = input.shape[1] + input_data = torch.randn(batch, N, dtype=torch.complex64, device='cuda') # Pre-compute twiddles - W0, W1, W2, T0, T1 = fft_make_twiddles(factors, input.real.dtype, input.device) + W0, W1, W2, T0, T1 = fft_make_twiddles(factors, input_data.real.dtype, input_data.device) # Pack input - x_ri = torch.view_as_real(input) + x_ri = torch.view_as_real(input_data) x_packed = x_ri.reshape(batch, N * 2 // D, D).contiguous() y_packed = torch.empty_like(x_packed) + return { + "input": input_data, + "x_packed": x_packed, + "y_packed": y_packed, + "W0": W0, "W1": W1, "W2": W2, "T0": T0, "T1": T1, + "factors": factors, + "batch": batch, + "N": N, + "D": D + } + + +def fft_run(data, *, nruns: int = 1, warmup: int = 0): + """Run FFT kernel with timing.""" + x_packed = data["x_packed"] + y_packed = data["y_packed"] + W0, W1, W2, T0, T1 = data["W0"], data["W1"], data["W2"], data["T0"], data["T1"] + F0, F1, F2 = data["factors"] + batch, N, D = data["batch"], data["N"], data["D"] + grid = (batch, 1, 1) - ct.launch(torch.cuda.current_stream(), grid, fft_kernel, - (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + + # Warmup + for _ in range(warmup): + ct.launch(torch.cuda.current_stream(), grid, fft_kernel, + (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + torch.cuda.synchronize() + + # Timed runs + times = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + ct.launch(torch.cuda.current_stream(), grid, fft_kernel, + (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms output = torch.view_as_complex(y_packed.reshape(batch, N, 2)) - if validate: - torch.cuda.synchronize() - reference = torch.fft.fft(input, dim=-1) - assert torch.allclose(output, reference, rtol=1e-3, atol=1e-3), \ - f"FFT incorrect! max diff: {torch.max(torch.abs(output - reference))}" + return {"output": output, "times": times} + + +def fft_verify(data, result): + """Verify FFT results.""" + reference = torch.fft.fft(data["input"], dim=-1) + assert torch.allclose(result["output"], reference, rtol=1e-3, atol=1e-3), \ + f"FFT incorrect! 
max diff: {torch.max(torch.abs(result['output'] - reference))}" - return input, output +#============================================================================= +# Test function +#============================================================================= def test_fft(batch, size, factors, name=None): """Test FFT with given parameters.""" name = name or f"fft batch={batch}, size={size}, factors={factors}" print(f"--- {name} ---") - run_fft(batch=batch, size=size, factors=factors, validate=True) + data = fft_prepare(batch=batch, size=size, factors=factors) + result = fft_run(data) + fft_verify(data, result) print(" passed") diff --git a/examples/layernorm.jl b/examples/layernorm.jl index be5e1fa..29aa2d6 100644 --- a/examples/layernorm.jl +++ b/examples/layernorm.jl @@ -273,190 +273,180 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa end #============================================================================= - Run Functions (for benchmarking or programmatic use) + Forward Pass - prepare/run/verify pattern =============================================================================# -# Run layernorm forward pass -function run_layernorm_fwd(; M::Int, N::Int, TILE_N::Int, eps::Float32=1f-5, - X::Union{CuArray,Nothing}=nothing, - W::Union{CuArray,Nothing}=nothing, - B::Union{CuArray,Nothing}=nothing, - Y::Union{CuArray,Nothing}=nothing, - Mean::Union{CuArray,Nothing}=nothing, - Rstd::Union{CuArray,Nothing}=nothing, - validate::Bool=false) - X = something(X, -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N)) - W = something(W, CUDA.randn(Float32, N)) - B = something(B, CUDA.randn(Float32, N)) - Y = something(Y, CUDA.zeros(Float32, M, N)) - Mean = something(Mean, CUDA.zeros(Float32, M)) - Rstd = something(Rstd, CUDA.zeros(Float32, M)) - - ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, - ct.Constant(eps), ct.Constant(TILE_N)) - - if validate - X_cpu = Array(X) - W_cpu = Array(W) - B_cpu = Array(B) - expected_mean = vec(sum(X_cpu, dims=2) ./ N) - expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) - expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) - normalized = (X_cpu .- expected_mean) .* expected_rstd - expected_Y = normalized .* W_cpu' .+ B_cpu' - - atol, rtol = 1f-2, 1f-2 - @assert isapprox(expected_mean, Array(Mean); rtol, atol) "Mean mismatch" - @assert isapprox(expected_rstd, Array(Rstd); rtol, atol) "Rstd mismatch" - @assert isapprox(expected_Y, Array(Y); rtol, atol) "Y mismatch" - end - return (; X, W, B, Y, Mean, Rstd) +function layernorm_fwd_prepare(; M::Int, N::Int, eps::Float32=1f-5) + return (; + X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N), + W = CUDA.randn(Float32, N), + B = CUDA.randn(Float32, N), + Y = CUDA.zeros(Float32, M, N), + Mean = CUDA.zeros(Float32, M), + Rstd = CUDA.zeros(Float32, M), + M, N, eps + ) end -#============================================================================= - Test / Validation -=============================================================================# - -function main() - println("=== cuTile LayerNorm Sample ===\n") - - M, N = 1024, 2048 - TILE_N = 1024 - eps = 1f-5 - - println("Input shape: ($M, $N), dtype: Float32, eps: $eps") +function layernorm_fwd_run(data; TILE_N::Int, nruns::Int=1, warmup::Int=0) + (; X, W, B, Y, Mean, Rstd, M, eps) = data - # Input data - X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N) - W = CUDA.randn(Float32, N) - B = CUDA.randn(Float32, N) + for _ in 1:warmup + ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, + ct.Constant(eps), 
ct.Constant(TILE_N)) + end + CUDA.synchronize() - # Output buffers for forward pass - Y = CUDA.zeros(Float32, M, N) - Mean = CUDA.zeros(Float32, M) - Rstd = CUDA.zeros(Float32, M) + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, + ct.Constant(eps), ct.Constant(TILE_N)) + push!(times, t * 1000) # ms + end - # ========================================================================= - # Forward Pass - # ========================================================================= - println("\n--- Forward Pass ---") - ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, - ct.Constant(eps), ct.Constant(TILE_N)) + return (; Y, Mean, Rstd, times) +end - # Compute expected values on CPU +function layernorm_fwd_verify(data, result) + (; X, W, B, N, eps) = data X_cpu = Array(X) W_cpu = Array(W) B_cpu = Array(B) - expected_mean = vec(sum(X_cpu, dims=2) ./ N) expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) normalized = (X_cpu .- expected_mean) .* expected_rstd expected_Y = normalized .* W_cpu' .+ B_cpu' - # Verify forward pass results - Mean_cpu = Array(Mean) - Rstd_cpu = Array(Rstd) - Y_cpu = Array(Y) - atol, rtol = 1f-2, 1f-2 - fwd_ok = isapprox(expected_mean, Mean_cpu; rtol, atol) && - isapprox(expected_rstd, Rstd_cpu; rtol, atol) && - isapprox(expected_Y, Y_cpu; rtol, atol) - - if fwd_ok - println("Forward pass: PASSED") - else - println("Forward pass: FAILED") - isapprox(expected_mean, Mean_cpu; rtol, atol) || println(" Mean max error: $(maximum(abs.(expected_mean .- Mean_cpu)))") - isapprox(expected_rstd, Rstd_cpu; rtol, atol) || println(" Rstd max error: $(maximum(abs.(expected_rstd .- Rstd_cpu)))") - isapprox(expected_Y, Y_cpu; rtol, atol) || println(" Y max error: $(maximum(abs.(expected_Y .- Y_cpu)))") - end + @assert isapprox(expected_mean, Array(result.Mean); rtol, atol) "Mean mismatch" + @assert isapprox(expected_rstd, Array(result.Rstd); rtol, atol) "Rstd mismatch" + @assert isapprox(expected_Y, Array(result.Y); rtol, atol) "Y mismatch" +end - # ========================================================================= - # Backward Pass (Full: dX, dW, dB) - # ========================================================================= - println("\n--- Backward Pass (Full: dX, dW, dB) ---") +#============================================================================= + Backward Pass - prepare/run/verify pattern +=============================================================================# - # Upstream gradient (random for testing) - DY = CUDA.randn(Float32, M, N) - DX = CUDA.zeros(Float32, M, N) +function layernorm_bwd_prepare(fwd_data, fwd_result; GROUP_SIZE_M::Int=64) + (; X, W, M, N) = fwd_data + (; Mean, Rstd) = fwd_result + return (; + X, W, Mean, Rstd, + DY = CUDA.randn(Float32, M, N), + DX = CUDA.zeros(Float32, M, N), + DW_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N), + DB_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N), + Locks = CUDA.zeros(Int, GROUP_SIZE_M), + FINAL_DW = CUDA.zeros(Float32, N), + FINAL_DB = CUDA.zeros(Float32, N), + M, N, GROUP_SIZE_M + ) +end - # Parameters for partial gradient accumulation - GROUP_SIZE_M = 64 - TILE_M = 32 +function layernorm_bwd_run(data; TILE_N::Int, TILE_M::Int=32, nruns::Int=1, warmup::Int=0) + (; X, W, Mean, Rstd, DY, DX, DW_partial, DB_partial, Locks, FINAL_DW, FINAL_DB, M, N, GROUP_SIZE_M) = data + + for _ in 1:warmup + # Reset partial buffers + fill!(DW_partial, 0) + fill!(DB_partial, 0) + fill!(Locks, 0) 
+ ct.launch(layer_norm_bwd_dx_partial_dwdb, M, DX, DY, DW_partial, DB_partial, X, W, + Mean, Rstd, Locks, ct.Constant(GROUP_SIZE_M), ct.Constant(TILE_N)) + num_tiles_n = cld(N, TILE_N) + ct.launch(layer_norm_bwd_dwdb, num_tiles_n, DW_partial, DB_partial, FINAL_DW, FINAL_DB, + ct.Constant(TILE_M), ct.Constant(TILE_N)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + # Reset partial buffers + fill!(DW_partial, 0) + fill!(DB_partial, 0) + fill!(Locks, 0) + t = CUDA.@elapsed begin + ct.launch(layer_norm_bwd_dx_partial_dwdb, M, DX, DY, DW_partial, DB_partial, X, W, + Mean, Rstd, Locks, ct.Constant(GROUP_SIZE_M), ct.Constant(TILE_N)) + num_tiles_n = cld(N, TILE_N) + ct.launch(layer_norm_bwd_dwdb, num_tiles_n, DW_partial, DB_partial, FINAL_DW, FINAL_DB, + ct.Constant(TILE_M), ct.Constant(TILE_N)) + end + push!(times, t * 1000) # ms + end - # Partial gradient buffers and locks - DW_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N) - DB_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N) - Locks = CUDA.zeros(Int, GROUP_SIZE_M) + return (; DX, FINAL_DW, FINAL_DB, times) +end - # Final gradient buffers - FINAL_DW = CUDA.zeros(Float32, N) - FINAL_DB = CUDA.zeros(Float32, N) +function layernorm_bwd_verify(fwd_data, bwd_data, bwd_result) + (; X, W, N, eps) = fwd_data + (; DY, Mean, Rstd) = bwd_data - # Launch backward kernels - ct.launch(layer_norm_bwd_dx_partial_dwdb, M, DX, DY, DW_partial, DB_partial, X, W, - Mean, Rstd, Locks, ct.Constant(GROUP_SIZE_M), ct.Constant(TILE_N)) + X_cpu = Array(X) + W_cpu = Array(W) + DY_cpu = Array(DY) - num_tiles_n = cld(N, TILE_N) - ct.launch(layer_norm_bwd_dwdb, num_tiles_n, DW_partial, DB_partial, FINAL_DW, FINAL_DB, - ct.Constant(TILE_M), ct.Constant(TILE_N)) + # Compute expected values + expected_mean = vec(sum(X_cpu, dims=2) ./ N) + expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) + expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) + xhat = (X_cpu .- expected_mean) .* expected_rstd - # Compute expected gradients on CPU - # dX = rstd * (W * dY - c2 - x_hat * c1) - # where c1 = mean(x_hat * W * dY), c2 = mean(W * dY) - DY_cpu = Array(DY) wdy = W_cpu' .* DY_cpu - xhat = normalized c1 = sum(xhat .* wdy, dims=2) ./ N c2 = sum(wdy, dims=2) ./ N expected_DX = (wdy .- (xhat .* c1 .+ c2)) .* expected_rstd - - # dW = sum(dY * x_hat, dim=0) and dB = sum(dY, dim=0) expected_DW = vec(sum(DY_cpu .* xhat, dims=1)) expected_DB = vec(sum(DY_cpu, dims=1)) - # Verify dX - DX_cpu = Array(DX) - dx_ok = isapprox(expected_DX, DX_cpu; rtol, atol) - if dx_ok - println(" dX: PASSED") - else - max_err = maximum(abs.(expected_DX .- DX_cpu)) - println(" dX: FAILED (max error: $max_err)") - end + atol, rtol = 1f-2, 1f-2 + @assert isapprox(expected_DX, Array(bwd_result.DX); rtol, atol) "dX mismatch" + @assert isapprox(expected_DW, Array(bwd_result.FINAL_DW); rtol, atol) "dW mismatch" + @assert isapprox(expected_DB, Array(bwd_result.FINAL_DB); rtol, atol) "dB mismatch" +end - # Verify dW - FINAL_DW_cpu = Array(FINAL_DW) - dw_ok = isapprox(expected_DW, FINAL_DW_cpu; rtol, atol) - if dw_ok - println(" dW: PASSED") - else - max_err = maximum(abs.(expected_DW .- FINAL_DW_cpu)) - println(" dW: FAILED (max error: $max_err)") - end +#============================================================================= + Main +=============================================================================# - # Verify dB - FINAL_DB_cpu = Array(FINAL_DB) - db_ok = isapprox(expected_DB, FINAL_DB_cpu; rtol, atol) - if db_ok - println(" dB: PASSED") - else - max_err = 
maximum(abs.(expected_DB .- FINAL_DB_cpu)) - println(" dB: FAILED (max error: $max_err)") - end +function main() + println("=== cuTile LayerNorm Sample ===\n") - bwd_ok = dx_ok && dw_ok && db_ok + M, N = 1024, 2048 + TILE_N = 1024 + eps = 1f-5 + + println("Input shape: ($M, $N), dtype: Float32, eps: $eps") + + # ========================================================================= + # Forward Pass + # ========================================================================= + println("\n--- Forward Pass ---") + fwd_data = layernorm_fwd_prepare(; M, N, eps) + fwd_result = layernorm_fwd_run(fwd_data; TILE_N) + layernorm_fwd_verify(fwd_data, fwd_result) + println("Forward pass: PASSED") + + # ========================================================================= + # Backward Pass (Full: dX, dW, dB) + # ========================================================================= + println("\n--- Backward Pass (Full: dX, dW, dB) ---") + GROUP_SIZE_M = 64 + TILE_M = 32 + bwd_data = layernorm_bwd_prepare(fwd_data, fwd_result; GROUP_SIZE_M) + bwd_result = layernorm_bwd_run(bwd_data; TILE_N, TILE_M) + layernorm_bwd_verify(fwd_data, bwd_data, bwd_result) + println(" dX: PASSED") + println(" dW: PASSED") + println(" dB: PASSED") # ========================================================================= # Summary # ========================================================================= println("\n=== Summary ===") - println("Forward pass: $(fwd_ok ? "PASSED" : "FAILED")") - println("Backward (dX/dW/dB): $(bwd_ok ? "PASSED" : "FAILED")") - - (fwd_ok && bwd_ok) || error("LayerNorm tests failed") + println("Forward pass: PASSED") + println("Backward (dX/dW/dB): PASSED") end isinteractive() || main() diff --git a/examples/layernorm.py b/examples/layernorm.py index 2007eb9..f833f4c 100644 --- a/examples/layernorm.py +++ b/examples/layernorm.py @@ -42,49 +42,80 @@ def layernorm_cutile_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TIL ct.store(Y, index=(bid_m, j), tile=ty.astype(Y.dtype)) -def run_layernorm(*, M: int = 1024, N: int = 1024, tile_n: int = 1024, eps: float = 1e-5, - X=None, W=None, B=None, Y=None, Mean=None, Rstd=None, - dtype=np.float32, validate: bool = False): - """Run layer normalization. 
Returns (X, W, B, Y, Mean, Rstd) arrays for benchmarking.""" - if X is None: - X = (-2.3 + 0.5 * cp.random.randn(M, N)).astype(dtype) - else: - M, N = X.shape - if W is None: - W = cp.random.randn(N).astype(dtype) - if B is None: - B = cp.random.randn(N).astype(dtype) - if Y is None: - Y = cp.zeros((M, N), dtype=dtype) - if Mean is None: - Mean = cp.zeros(M, dtype=np.float32) - if Rstd is None: - Rstd = cp.zeros(M, dtype=np.float32) +#============================================================================= +# prepare/run/verify pattern +#============================================================================= + +def layernorm_prepare(*, M: int, N: int, eps: float = 1e-5, dtype=np.float32): + """Allocate and initialize data for layer normalization.""" + return { + "X": (-2.3 + 0.5 * cp.random.randn(M, N)).astype(dtype), + "W": cp.random.randn(N).astype(dtype), + "B": cp.random.randn(N).astype(dtype), + "Y": cp.zeros((M, N), dtype=dtype), + "Mean": cp.zeros(M, dtype=np.float32), + "Rstd": cp.zeros(M, dtype=np.float32), + "eps": eps, + "M": M, + "N": N + } + + +def layernorm_run(data, *, tile_n: int, nruns: int = 1, warmup: int = 0): + """Run layer normalization kernel with timing.""" + X, W, B, Y = data["X"], data["W"], data["B"], data["Y"] + Mean, Rstd = data["Mean"], data["Rstd"] + eps, M = data["eps"], data["M"] stream = cp.cuda.get_current_stream() - ct.launch(stream, (M,), layernorm_cutile_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) - if validate: - cp.cuda.runtime.deviceSynchronize() - X_np = cp.asnumpy(X) - W_np = cp.asnumpy(W) - B_np = cp.asnumpy(B) - expected_mean = np.mean(X_np, axis=1, keepdims=True) - expected_var = np.mean((X_np - expected_mean) ** 2, axis=1, keepdims=True) - expected_rstd = 1.0 / np.sqrt(expected_var + eps) - normalized = (X_np - expected_mean) * expected_rstd - expected_Y = normalized * W_np + B_np - assert np.allclose(cp.asnumpy(Y), expected_Y, rtol=1e-2, atol=1e-2), \ - f"layernorm incorrect! max diff: {np.max(np.abs(cp.asnumpy(Y) - expected_Y))}" + # Warmup + for _ in range(warmup): + ct.launch(stream, (M,), layernorm_cutile_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) + cp.cuda.runtime.deviceSynchronize() - return X, W, B, Y, Mean, Rstd + # Timed runs + times = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + ct.launch(stream, (M,), layernorm_cutile_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) + end.record(stream) + end.synchronize() + times.append(cp.cuda.get_elapsed_time(start, end)) # ms + return {"Y": Y, "Mean": Mean, "Rstd": Rstd, "times": times} + + +def layernorm_verify(data, result): + """Verify layer normalization results.""" + X_np = cp.asnumpy(data["X"]) + W_np = cp.asnumpy(data["W"]) + B_np = cp.asnumpy(data["B"]) + eps = data["eps"] + + expected_mean = np.mean(X_np, axis=1, keepdims=True) + expected_var = np.mean((X_np - expected_mean) ** 2, axis=1, keepdims=True) + expected_rstd = 1.0 / np.sqrt(expected_var + eps) + normalized = (X_np - expected_mean) * expected_rstd + expected_Y = normalized * W_np + B_np + + assert np.allclose(cp.asnumpy(result["Y"]), expected_Y, rtol=1e-2, atol=1e-2), \ + f"layernorm incorrect! 
max diff: {np.max(np.abs(cp.asnumpy(result['Y']) - expected_Y))}" + + +#============================================================================= +# Test function +#============================================================================= def test_layernorm(M, N, tile_n, eps=1e-5, dtype=np.float32, name=None): """Test layer normalization with given parameters.""" name = name or f"layernorm ({M}x{N}), tile={tile_n}, dtype={dtype.__name__}" print(f"--- {name} ---") - run_layernorm(M=M, N=N, tile_n=tile_n, eps=eps, dtype=dtype, validate=True) + data = layernorm_prepare(M=M, N=N, eps=eps, dtype=dtype) + result = layernorm_run(data, tile_n=tile_n) + layernorm_verify(data, result) print(" passed") diff --git a/examples/matmul.jl b/examples/matmul.jl index 247c929..dc25da5 100644 --- a/examples/matmul.jl +++ b/examples/matmul.jl @@ -62,35 +62,57 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr return nothing end -# Run matmul - for benchmarking or programmatic use -function run_matmul(; M::Int, N::Int, K::Int, tm::Int, tn::Int, tk::Int, T::DataType=Float32, - A::Union{CuArray,Nothing}=nothing, - B::Union{CuArray,Nothing}=nothing, - C::Union{CuArray,Nothing}=nothing, - validate::Bool=false) - A = something(A, CUDA.rand(T, M, K)) - B = something(B, CUDA.rand(T, K, N)) - C = something(C, CUDA.zeros(T, M, N)) - grid_m = cld(M, tm) - grid_n = cld(N, tn) - grid = grid_m * grid_n - ct.launch(matmul_kernel, grid, A, B, C, - ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) - if validate - expected = Array(A) * Array(B) - result = Array(C) - @assert isapprox(result, expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(result - expected)))" +#============================================================================= + Matmul - prepare/run/verify pattern +=============================================================================# + +function matmul_prepare(; M::Int, N::Int, K::Int, T::DataType=Float32) + return (; + A = CUDA.rand(T, M, K), + B = CUDA.rand(T, K, N), + C = CUDA.zeros(T, M, N), + M, N, K + ) +end + +function matmul_run(data; tm::Int, tn::Int, tk::Int, nruns::Int=1, warmup::Int=0) + (; A, B, C, M, N, K) = data + grid = cld(M, tm) * cld(N, tn) + + for _ in 1:warmup + ct.launch(matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(matmul_kernel, grid, A, B, C, + ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) + push!(times, t * 1000) # ms end - return (; A, B, C) + + return (; C, times) +end + +function matmul_verify(data, result) + expected = Array(data.A) * Array(data.B) + @assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))" end function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T name = something(name, "matmul ($M x $K) @ ($K x $N), $T, tiles=$tm x $tn x $tk") println("--- $name ---") - run_matmul(; M, N, K, tm, tn, tk, T, validate=true) + data = matmul_prepare(; M, N, K, T) + result = matmul_run(data; tm, tn, tk) + matmul_verify(data, result) println(" passed") end +#============================================================================= + Main +=============================================================================# + function main() println("--- cuTile Matrix Multiplication Examples ---\n") diff --git a/examples/matmul.py b/examples/matmul.py index 9eab93c..07fdb7a 100644 --- a/examples/matmul.py +++ 
b/examples/matmul.py @@ -44,42 +44,70 @@ def matmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk ct.store(C, index=(bidx, bidy), tile=accumulator) -def run_matmul(*, M: int = 1024, N: int = 1024, K: int = 1024, - tm: int = 64, tn: int = 64, tk: int = 64, - dtype=np.float32, A=None, B=None, C=None, validate: bool = False): - """Run matrix multiplication. Returns (A, B, C) arrays for benchmarking.""" - if A is None: - A = cp.random.randn(M, K).astype(dtype) - else: - M, K = A.shape - if B is None: - B = cp.random.randn(K, N).astype(dtype) - else: - K, N = B.shape - if C is None: - C = cp.zeros((M, N), dtype=dtype) +#============================================================================= +# prepare/run/verify pattern +#============================================================================= + +def matmul_prepare(*, M: int, N: int, K: int, dtype=np.float32): + """Allocate and initialize data for matmul.""" + return { + "A": cp.random.randn(M, K).astype(dtype), + "B": cp.random.randn(K, N).astype(dtype), + "C": cp.zeros((M, N), dtype=dtype), + "M": M, + "N": N, + "K": K + } + + +def matmul_run(data, *, tm: int, tn: int, tk: int, nruns: int = 1, warmup: int = 0): + """Run matmul kernel with timing.""" + A, B, C = data["A"], data["B"], data["C"] + M, N = data["M"], data["N"] grid_m = ceil(M / tm) grid_n = ceil(N / tn) grid = (grid_m * grid_n, 1, 1) stream = cp.cuda.get_current_stream() - ct.launch(stream, grid, matmul_cutile_kernel, (A, B, C, tm, tn, tk)) - if validate: - cp.cuda.runtime.deviceSynchronize() - expected = cp.asnumpy(A) @ cp.asnumpy(B) - # TF32 has reduced precision - assert np.allclose(cp.asnumpy(C), expected, rtol=1e-1, atol=1e-1), \ - f"matmul incorrect! max diff: {np.max(np.abs(cp.asnumpy(C) - expected))}" + # Warmup + for _ in range(warmup): + ct.launch(stream, grid, matmul_cutile_kernel, (A, B, C, tm, tn, tk)) + cp.cuda.runtime.deviceSynchronize() - return A, B, C + # Timed runs + times = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + ct.launch(stream, grid, matmul_cutile_kernel, (A, B, C, tm, tn, tk)) + end.record(stream) + end.synchronize() + times.append(cp.cuda.get_elapsed_time(start, end)) # ms + return {"C": C, "times": times} + + +def matmul_verify(data, result): + """Verify matmul results.""" + expected = cp.asnumpy(data["A"]) @ cp.asnumpy(data["B"]) + # TF32 has reduced precision + assert np.allclose(cp.asnumpy(result["C"]), expected, rtol=1e-1, atol=1e-1), \ + f"matmul incorrect! 
max diff: {np.max(np.abs(cp.asnumpy(result['C']) - expected))}" + + +#============================================================================= +# Test function +#============================================================================= def test_matmul(M, N, K, tm, tn, tk, dtype=np.float32, name=None): """Test matmul with given parameters.""" name = name or f"matmul ({M}x{K}) @ ({K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" print(f"--- {name} ---") - run_matmul(M=M, N=N, K=K, tm=tm, tn=tn, tk=tk, dtype=dtype, validate=True) + data = matmul_prepare(M=M, N=N, K=K, dtype=dtype) + result = matmul_run(data, tm=tm, tn=tn, tk=tk) + matmul_verify(data, result) print(" passed") diff --git a/examples/transpose.jl b/examples/transpose.jl index 02a03b9..d99b3a1 100644 --- a/examples/transpose.jl +++ b/examples/transpose.jl @@ -17,30 +17,55 @@ function transpose_kernel(x::ct.TileArray{T,2}, y::ct.TileArray{T,2}, return end -# Run transpose - for benchmarking or programmatic use -function run_transpose(; m::Int, n::Int, tm::Int, tn::Int, T::DataType=Float32, - x::Union{CuArray,Nothing}=nothing, - y::Union{CuArray,Nothing}=nothing, - validate::Bool=false) - x = something(x, CUDA.rand(T, m, n)) - y = something(y, CUDA.zeros(T, n, m)) - grid_x = cld(m, tm) - grid_y = cld(n, tn) - ct.launch(transpose_kernel, (grid_x, grid_y), x, y, - ct.Constant(tm), ct.Constant(tn)) - if validate - @assert Array(y) ≈ transpose(Array(x)) +#============================================================================= + Transpose - prepare/run/verify pattern +=============================================================================# + +function transpose_prepare(; m::Int, n::Int, T::DataType=Float32) + return (; + x = CUDA.rand(T, m, n), + y = CUDA.zeros(T, n, m), + m, n + ) +end + +function transpose_run(data; tm::Int, tn::Int, nruns::Int=1, warmup::Int=0) + (; x, y, m, n) = data + grid = (cld(m, tm), cld(n, tn)) + + for _ in 1:warmup + ct.launch(transpose_kernel, grid, x, y, + ct.Constant(tm), ct.Constant(tn)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(transpose_kernel, grid, x, y, + ct.Constant(tm), ct.Constant(tn)) + push!(times, t * 1000) # ms end - return (; x, y) + + return (; y, times) +end + +function transpose_verify(data, result) + @assert Array(result.y) ≈ transpose(Array(data.x)) end function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T name = something(name, "transpose ($m x $n, $T, tiles=$tm x $tn)") println("--- $name ---") - run_transpose(; m, n, tm, tn, T, validate=true) + data = transpose_prepare(; m, n, T) + result = transpose_run(data; tm, tn) + transpose_verify(data, result) println("✓ passed") end +#============================================================================= + Main +=============================================================================# + function main() println("--- cuTile Matrix Transposition Examples ---\n") diff --git a/examples/transpose.py b/examples/transpose.py index db0f7ab..eea33d7 100644 --- a/examples/transpose.py +++ b/examples/transpose.py @@ -16,33 +16,65 @@ def transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct. ct.store(output, index=(pid_n, pid_m), tile=tile_t) -def run_transpose(*, M: int = 1024, N: int = 1024, tile_m: int = 64, tile_n: int = 64, - dtype=np.float32, input=None, output=None, validate: bool = False): - """Run matrix transpose. 
Returns (input, output) arrays for benchmarking.""" - if input is None: - input = cp.random.rand(M, N).astype(dtype) - else: - M, N = input.shape - if output is None: - output = cp.zeros((N, M), dtype=dtype) +#============================================================================= +# prepare/run/verify pattern +#============================================================================= + +def transpose_prepare(*, M: int, N: int, dtype=np.float32): + """Allocate and initialize data for transpose.""" + return { + "input": cp.random.rand(M, N).astype(dtype), + "output": cp.zeros((N, M), dtype=dtype), + "M": M, + "N": N + } + + +def transpose_run(data, *, tile_m: int, tile_n: int, nruns: int = 1, warmup: int = 0): + """Run transpose kernel with timing.""" + input_arr = data["input"] + output_arr = data["output"] + M, N = data["M"], data["N"] grid = (ct.cdiv(M, tile_m), ct.cdiv(N, tile_n), 1) stream = cp.cuda.get_current_stream() - ct.launch(stream, grid, transpose_cutile_kernel, (input, output, tile_m, tile_n)) - if validate: - cp.cuda.runtime.deviceSynchronize() - expected = cp.asnumpy(input).T - assert np.allclose(cp.asnumpy(output), expected), "transpose incorrect!" + # Warmup + for _ in range(warmup): + ct.launch(stream, grid, transpose_cutile_kernel, (input_arr, output_arr, tile_m, tile_n)) + cp.cuda.runtime.deviceSynchronize() + + # Timed runs + times = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + ct.launch(stream, grid, transpose_cutile_kernel, (input_arr, output_arr, tile_m, tile_n)) + end.record(stream) + end.synchronize() + times.append(cp.cuda.get_elapsed_time(start, end)) # ms + + return {"output": output_arr, "times": times} + + +def transpose_verify(data, result): + """Verify transpose results.""" + expected = cp.asnumpy(data["input"]).T + assert np.allclose(cp.asnumpy(result["output"]), expected), "transpose incorrect!" 
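+
+# Typical use of the harness (sketch; sizes, tiles, and run counts are
+# illustrative, not tuned values):
+#
+#   data = transpose_prepare(M=1024, N=1024)
+#   result = transpose_run(data, tile_m=64, tile_n=64, nruns=10, warmup=3)
+#   transpose_verify(data, result)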
- return input, output +#============================================================================= +# Test function +#============================================================================= def test_transpose(M, N, tile_m, tile_n, dtype=np.float32, name=None): """Test transpose with given parameters.""" name = name or f"transpose ({M}x{N}), tiles={tile_m}x{tile_n}, dtype={dtype.__name__}" print(f"--- {name} ---") - run_transpose(M=M, N=N, tile_m=tile_m, tile_n=tile_n, dtype=dtype, validate=True) + data = transpose_prepare(M=M, N=N, dtype=dtype) + result = transpose_run(data, tile_m=tile_m, tile_n=tile_n) + transpose_verify(data, result) print(" passed") diff --git a/examples/vadd.jl b/examples/vadd.jl index fcc04b3..2f14002 100644 --- a/examples/vadd.jl +++ b/examples/vadd.jl @@ -27,95 +27,163 @@ function vec_add_kernel_2d(a::ct.TileArray{T,2}, b::ct.TileArray{T,2}, c::ct.Til return end -# Run 1D vector addition - for benchmarking or programmatic use -function run_vadd_1d(; n::Int, tile::Int, T::DataType=Float32, - a::Union{CuArray,Nothing}=nothing, - b::Union{CuArray,Nothing}=nothing, - c::Union{CuArray,Nothing}=nothing, - validate::Bool=false) - a = something(a, CUDA.rand(T, n)) - b = something(b, CUDA.rand(T, n)) - c = something(c, CUDA.zeros(T, n)) - ct.launch(vec_add_kernel_1d, cld(n, tile), a, b, c, ct.Constant(tile)) - if validate - @assert Array(c) ≈ Array(a) + Array(b) +# 1D kernel using gather/scatter (explicit index-based memory access) +# This demonstrates the gather/scatter API for cases where you need +# explicit control over indices (e.g., for non-contiguous access patterns) +function vec_add_kernel_1d_gather(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.TileArray{T,1}, + tile::ct.Constant{Int}) where {T} + bid = ct.bid(1) + # Create index tile for this block's elements + offsets = ct.arange((tile[],), Int32) + base = ct.Tile((bid - Int32(1)) * Int32(tile[])) + indices = ct.broadcast_to(base, (tile[],)) .+ offsets + + # Gather, add, scatter + a_tile = ct.gather(a, indices) + b_tile = ct.gather(b, indices) + sum_tile = a_tile + b_tile + ct.scatter(c, indices, sum_tile) + return +end + +#============================================================================= + 1D Vector Addition - prepare/run/verify pattern +=============================================================================# + +function vadd_1d_prepare(; n::Int, T::DataType=Float32) + return (; + a = CUDA.rand(T, n), + b = CUDA.rand(T, n), + c = CUDA.zeros(T, n), + n + ) +end + +function vadd_1d_run(data; tile::Int, nruns::Int=1, warmup::Int=0) + (; a, b, c, n) = data + grid = cld(n, tile) + + for _ in 1:warmup + ct.launch(vec_add_kernel_1d, grid, a, b, c, ct.Constant(tile)) end - return (; a, b, c) + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(vec_add_kernel_1d, grid, a, b, c, ct.Constant(tile)) + push!(times, t * 1000) # ms + end + + return (; c, times) +end + +function vadd_1d_verify(data, result) + @assert Array(result.c) ≈ Array(data.a) + Array(data.b) end function test_add_1d(::Type{T}, n, tile; name=nothing) where T name = something(name, "1D vec_add ($n elements, $T, tile=$tile)") println("--- $name ---") - run_vadd_1d(; n, tile, T, validate=true) + data = vadd_1d_prepare(; n, T) + result = vadd_1d_run(data; tile) + vadd_1d_verify(data, result) println("✓ passed") end -# Run 2D matrix addition - for benchmarking or programmatic use -function run_vadd_2d(; m::Int, n::Int, tile_x::Int, tile_y::Int, T::DataType=Float32, - 
a::Union{CuArray,Nothing}=nothing, - b::Union{CuArray,Nothing}=nothing, - c::Union{CuArray,Nothing}=nothing, - validate::Bool=false) - a = something(a, CUDA.rand(T, m, n)) - b = something(b, CUDA.rand(T, m, n)) - c = something(c, CUDA.zeros(T, m, n)) - ct.launch(vec_add_kernel_2d, (cld(m, tile_x), cld(n, tile_y)), a, b, c, - ct.Constant(tile_x), ct.Constant(tile_y)) - if validate - @assert Array(c) ≈ Array(a) + Array(b) +#============================================================================= + 2D Matrix Addition - prepare/run/verify pattern +=============================================================================# + +function vadd_2d_prepare(; m::Int, n::Int, T::DataType=Float32) + return (; + a = CUDA.rand(T, m, n), + b = CUDA.rand(T, m, n), + c = CUDA.zeros(T, m, n), + m, n + ) +end + +function vadd_2d_run(data; tile_x::Int, tile_y::Int, nruns::Int=1, warmup::Int=0) + (; a, b, c, m, n) = data + grid = (cld(m, tile_x), cld(n, tile_y)) + + for _ in 1:warmup + ct.launch(vec_add_kernel_2d, grid, a, b, c, + ct.Constant(tile_x), ct.Constant(tile_y)) end - return (; a, b, c) + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(vec_add_kernel_2d, grid, a, b, c, + ct.Constant(tile_x), ct.Constant(tile_y)) + push!(times, t * 1000) # ms + end + + return (; c, times) +end + +function vadd_2d_verify(data, result) + @assert Array(result.c) ≈ Array(data.a) + Array(data.b) end function test_add_2d(::Type{T}, m, n, tile_x, tile_y; name=nothing) where T name = something(name, "2D vec_add ($m x $n, $T, tiles=$tile_x x $tile_y)") println("--- $name ---") - run_vadd_2d(; m, n, tile_x, tile_y, T, validate=true) + data = vadd_2d_prepare(; m, n, T) + result = vadd_2d_run(data; tile_x, tile_y) + vadd_2d_verify(data, result) println("✓ passed") end -# 1D kernel using gather/scatter (explicit index-based memory access) -# This demonstrates the gather/scatter API for cases where you need -# explicit control over indices (e.g., for non-contiguous access patterns) -function vec_add_kernel_1d_gather(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.TileArray{T,1}, - tile::ct.Constant{Int}) where {T} - bid = ct.bid(1) - # Create index tile for this block's elements - offsets = ct.arange((tile[],), Int32) - base = ct.Tile((bid - Int32(1)) * Int32(tile[])) - indices = ct.broadcast_to(base, (tile[],)) .+ offsets - - # Gather, add, scatter - a_tile = ct.gather(a, indices) - b_tile = ct.gather(b, indices) - sum_tile = a_tile + b_tile - ct.scatter(c, indices, sum_tile) - return +#============================================================================= + 1D Gather/Scatter Vector Addition - prepare/run/verify pattern +=============================================================================# + +function vadd_1d_gather_prepare(; n::Int, T::DataType=Float32) + return (; + a = CUDA.rand(T, n), + b = CUDA.rand(T, n), + c = CUDA.zeros(T, n), + n + ) end -# Run 1D gather/scatter vector addition - for benchmarking or programmatic use -function run_vadd_1d_gather(; n::Int, tile::Int, T::DataType=Float32, - a::Union{CuArray,Nothing}=nothing, - b::Union{CuArray,Nothing}=nothing, - c::Union{CuArray,Nothing}=nothing, - validate::Bool=false) - a = something(a, CUDA.rand(T, n)) - b = something(b, CUDA.rand(T, n)) - c = something(c, CUDA.zeros(T, n)) - ct.launch(vec_add_kernel_1d_gather, cld(n, tile), a, b, c, ct.Constant(tile)) - if validate - @assert Array(c) ≈ Array(a) + Array(b) +function vadd_1d_gather_run(data; tile::Int, nruns::Int=1, warmup::Int=0) + (; a, b, c, n) = data + 
grid = cld(n, tile) + + for _ in 1:warmup + ct.launch(vec_add_kernel_1d_gather, grid, a, b, c, ct.Constant(tile)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(vec_add_kernel_1d_gather, grid, a, b, c, ct.Constant(tile)) + push!(times, t * 1000) # ms end - return (; a, b, c) + + return (; c, times) +end + +function vadd_1d_gather_verify(data, result) + @assert Array(result.c) ≈ Array(data.a) + Array(data.b) end function test_add_1d_gather(::Type{T}, n, tile; name=nothing) where T name = something(name, "1D vec_add gather ($n elements, $T, tile=$tile)") println("--- $name ---") - run_vadd_1d_gather(; n, tile, T, validate=true) + data = vadd_1d_gather_prepare(; n, T) + result = vadd_1d_gather_run(data; tile) + vadd_1d_gather_verify(data, result) println("✓ passed") end +#============================================================================= + Main +=============================================================================# + function main() println("--- cuTile Vector/Matrix Addition Examples ---\n") diff --git a/examples/vadd.py b/examples/vadd.py index 8d997bb..77d77e8 100644 --- a/examples/vadd.py +++ b/examples/vadd.py @@ -16,29 +16,64 @@ def vadd_cutile_kernel(a, b, c, tile_size: ct.Constant[int]): ct.store(c, index=(pid,), tile=result) -def run_vadd(*, size: int = 2**20, tile: int = 1024, dtype=np.float32, validate: bool = False): - """Run vector addition. Returns (a, b, c) arrays for benchmarking.""" - a = cp.random.rand(size).astype(dtype) - b = cp.random.rand(size).astype(dtype) - c = cp.zeros(size, dtype=dtype) +#============================================================================= +# prepare/run/verify pattern +#============================================================================= - grid = (ct.cdiv(size, tile), 1, 1) +def vadd_prepare(*, n: int, dtype=np.float32): + """Allocate and initialize data for vector addition.""" + return { + "a": cp.random.rand(n).astype(dtype), + "b": cp.random.rand(n).astype(dtype), + "c": cp.zeros(n, dtype=dtype), + "n": n + } + + +def vadd_run(data, *, tile: int, nruns: int = 1, warmup: int = 0): + """Run vector addition kernel with timing.""" + a, b, c = data["a"], data["b"], data["c"] + n = data["n"] + + grid = (ct.cdiv(n, tile), 1, 1) stream = cp.cuda.get_current_stream() - ct.launch(stream, grid, vadd_cutile_kernel, (a, b, c, tile)) - if validate: - cp.cuda.runtime.deviceSynchronize() - expected = cp.asnumpy(a) + cp.asnumpy(b) - assert np.allclose(cp.asnumpy(c), expected), "vadd incorrect!" + # Warmup + for _ in range(warmup): + ct.launch(stream, grid, vadd_cutile_kernel, (a, b, c, tile)) + cp.cuda.runtime.deviceSynchronize() + + # Timed runs + times = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + ct.launch(stream, grid, vadd_cutile_kernel, (a, b, c, tile)) + end.record(stream) + end.synchronize() + times.append(cp.cuda.get_elapsed_time(start, end)) # ms + + return {"c": c, "times": times} + + +def vadd_verify(data, result): + """Verify vector addition results.""" + expected = cp.asnumpy(data["a"]) + cp.asnumpy(data["b"]) + assert np.allclose(cp.asnumpy(result["c"]), expected), "vadd incorrect!" 
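+
+# Typical use of the harness (sketch; size, tile, and run counts are
+# illustrative):
+#
+#   data = vadd_prepare(n=2**20)
+#   result = vadd_run(data, tile=1024, nruns=10, warmup=3)
+#   vadd_verify(data, result)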
-    return a, b, c
+#=============================================================================
+# Test function
+#=============================================================================

-def test_vadd(size, tile, dtype=np.float32, name=None):
+def test_vadd(n, tile, dtype=np.float32, name=None):
     """Test vector addition with given parameters."""
-    name = name or f"vadd size={size}, tile={tile}, dtype={dtype.__name__}"
+    name = name or f"vadd size={n}, tile={tile}, dtype={dtype.__name__}"
     print(f"--- {name} ---")
-    run_vadd(size=size, tile=tile, dtype=dtype, validate=True)
+    data = vadd_prepare(n=n, dtype=dtype)
+    result = vadd_run(data, tile=tile)
+    vadd_verify(data, result)
     print(" passed")

From 6b01aa9345ac035664b4cebe9aa764e5792cb03c Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Wed, 14 Jan 2026 09:05:50 -0500
Subject: [PATCH 3/7] Unify example harnesses.

---
 examples/batchmatmul.jl | 10 +++++-----
 examples/batchmatmul.py |  4 ++--
 examples/fft.jl         | 33 +--------------------------------
 examples/fft.py         |  9 ++-------
 examples/matmul.jl      | 10 +++++-----
 examples/matmul.py      |  4 ++--
 examples/transpose.jl   | 10 +++++-----
 examples/transpose.py   |  4 ++--
 8 files changed, 24 insertions(+), 60 deletions(-)

diff --git a/examples/batchmatmul.jl b/examples/batchmatmul.jl
index f7ac0f0..fa81d18 100644
--- a/examples/batchmatmul.jl
+++ b/examples/batchmatmul.jl
@@ -58,7 +58,7 @@ end

 #=============================================================================
-    Batch Matmul - prepare/run/verify pattern
+    Example harness
 =============================================================================#

 function batchmatmul_prepare(; M::Int, K::Int, N::Int, Batch::Int, T::DataType=Float32)
@@ -101,6 +101,10 @@ function batchmatmul_verify(data, result)
     @assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))"
 end

+#=============================================================================
+    Main
+=============================================================================#
+
 function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T
     name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk")
     println("--- $name ---")
@@ -110,10 +114,6 @@ function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing)
     println(" passed")
 end

-#=============================================================================
-    Main
-=============================================================================#
-
 function main()
     println("--- cuTile Batch Matrix Multiplication Examples ---\n")

diff --git a/examples/batchmatmul.py b/examples/batchmatmul.py
index 01a8d8a..075cbb2 100644
--- a/examples/batchmatmul.py
+++ b/examples/batchmatmul.py
@@ -37,7 +37,7 @@ def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int


 #=============================================================================
-# prepare/run/verify pattern
+# Example harness
 #=============================================================================

 def batchmatmul_prepare(*, Batch: int, M: int, K: int, N: int, dtype=np.float16):
@@ -95,7 +95,7 @@ def batchmatmul_verify(data, result):


 #=============================================================================
-# Test function
+# Main
 #=============================================================================

 def test_batchmatmul(Batch, M, K, N, tm, tn, tk, 
dtype=np.float16, name=None): diff --git a/examples/fft.jl b/examples/fft.jl index 899f62b..c07dff1 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -212,7 +212,7 @@ function make_twiddles(factors::NTuple{3, Int}) end #============================================================================= - FFT - prepare/run/verify pattern + Example harness =============================================================================# function fft_prepare(; batch::Int, n::Int, factors::NTuple{3,Int}, atom_packing_dim::Int=2) @@ -286,37 +286,6 @@ function fft_verify(data, result) @assert isapprox(Array(result.output), reference, rtol=1e-4) end -#============================================================================= - Legacy wrapper for backward compatibility -=============================================================================# - -function cutile_fft(x::CuMatrix{ComplexF32}, factors::NTuple{3, Int}; atom_packing_dim::Int=2) - BS = size(x, 1) - N = size(x, 2) - - # Create data structure similar to prepare - D = atom_packing_dim - N2D = N * 2 ÷ D - W0, W1, W2, T0, T1 = make_twiddles(factors) - W0_gpu = CuArray(W0) - W1_gpu = CuArray(W1) - W2_gpu = CuArray(W2) - T0_gpu = CuArray(T0) - T1_gpu = CuArray(T1) - - x_packed = reinterpret(reshape, Float32, x) - y_packed = CUDA.zeros(Float32, D, BS, N2D) - - data = (; - input=x, x_packed, y_packed, - W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, - factors, batch=BS, n=N, D, N2D - ) - - result = fft_run(data) - return result.output -end - #============================================================================= Main =============================================================================# diff --git a/examples/fft.py b/examples/fft.py index a1719ab..0e35e6d 100644 --- a/examples/fft.py +++ b/examples/fft.py @@ -86,11 +86,6 @@ def fft_kernel(x_packed_in, y_packed_out, Y_ri = ct.reshape(ct.cat((X_r, X_i), axis=-1), (BS, N * 2 // D, D)) ct.store(y_packed_out, index=(bid, 0, 0), tile=Y_ri) - -#============================================================================= -# Helper functions -#============================================================================= - def fft_twiddles(rows: int, cols: int, factor: int, device, precision): """Generate DFT twiddle factors.""" I, J = torch.meshgrid(torch.arange(rows, device=device), @@ -113,7 +108,7 @@ def fft_make_twiddles(factors, precision, device): #============================================================================= -# prepare/run/verify pattern +# Example harness #============================================================================= def fft_prepare(*, batch: int, size: int, factors: tuple, atom_packing_dim: int = 2): @@ -186,7 +181,7 @@ def fft_verify(data, result): #============================================================================= -# Test function +# Main #============================================================================= def test_fft(batch, size, factors, name=None): diff --git a/examples/matmul.jl b/examples/matmul.jl index dc25da5..ef14c66 100644 --- a/examples/matmul.jl +++ b/examples/matmul.jl @@ -63,7 +63,7 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr end #============================================================================= - Matmul - prepare/run/verify pattern + Example harness =============================================================================# function matmul_prepare(; M::Int, N::Int, K::Int, T::DataType=Float32) @@ -100,6 +100,10 @@ function matmul_verify(data, result) @assert 
isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))" end +#============================================================================= + Main +=============================================================================# + function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T name = something(name, "matmul ($M x $K) @ ($K x $N), $T, tiles=$tm x $tn x $tk") println("--- $name ---") @@ -109,10 +113,6 @@ function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T println(" passed") end -#============================================================================= - Main -=============================================================================# - function main() println("--- cuTile Matrix Multiplication Examples ---\n") diff --git a/examples/matmul.py b/examples/matmul.py index 07fdb7a..a0ecbf4 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -45,7 +45,7 @@ def matmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk #============================================================================= -# prepare/run/verify pattern +# Example harness #============================================================================= def matmul_prepare(*, M: int, N: int, K: int, dtype=np.float32): @@ -98,7 +98,7 @@ def matmul_verify(data, result): #============================================================================= -# Test function +# Main #============================================================================= def test_matmul(M, N, K, tm, tn, tk, dtype=np.float32, name=None): diff --git a/examples/transpose.jl b/examples/transpose.jl index d99b3a1..f199675 100644 --- a/examples/transpose.jl +++ b/examples/transpose.jl @@ -18,7 +18,7 @@ function transpose_kernel(x::ct.TileArray{T,2}, y::ct.TileArray{T,2}, end #============================================================================= - Transpose - prepare/run/verify pattern + Example harness =============================================================================# function transpose_prepare(; m::Int, n::Int, T::DataType=Float32) @@ -53,6 +53,10 @@ function transpose_verify(data, result) @assert Array(result.y) ≈ transpose(Array(data.x)) end +#============================================================================= + Main +=============================================================================# + function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T name = something(name, "transpose ($m x $n, $T, tiles=$tm x $tn)") println("--- $name ---") @@ -62,10 +66,6 @@ function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T println("✓ passed") end -#============================================================================= - Main -=============================================================================# - function main() println("--- cuTile Matrix Transposition Examples ---\n") diff --git a/examples/transpose.py b/examples/transpose.py index eea33d7..52e4113 100644 --- a/examples/transpose.py +++ b/examples/transpose.py @@ -17,7 +17,7 @@ def transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct. 
#============================================================================= -# prepare/run/verify pattern +# Example harness #============================================================================= def transpose_prepare(*, M: int, N: int, dtype=np.float32): @@ -65,7 +65,7 @@ def transpose_verify(data, result): #============================================================================= -# Test function +# Main #============================================================================= def test_transpose(M, N, tile_m, tile_n, dtype=np.float32, name=None): From d6783884e995e87de0156b94430c3cbb1fac534a Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 14 Jan 2026 09:38:28 -0500 Subject: [PATCH 4/7] Merge layernorm. --- examples/benchmarks.jl | 28 ++++-- examples/benchmarks.py | 20 ++++- examples/layernorm.jl | 177 +++++++++++++----------------------- examples/layernorm.py | 198 +++++++++++++++++++++++++++++++++++------ 4 files changed, 270 insertions(+), 153 deletions(-) diff --git a/examples/benchmarks.jl b/examples/benchmarks.jl index ca9840a..72272f3 100644 --- a/examples/benchmarks.jl +++ b/examples/benchmarks.jl @@ -348,8 +348,8 @@ function benchmark_layernorm() M, N = LAYERNORM_M, LAYERNORM_N println(" Size: $(M)x$(N) ($(M * N * 4 / 1e6) MB)") - # Prepare data once (using layernorm.jl's prepare function) - data = layernorm_fwd_prepare(; M, N, T=Float32, eps=LAYERNORM_EPS) + # Prepare data once (using layernorm.jl's unified prepare function) + data = layernorm_prepare(; M, N, eps=LAYERNORM_EPS) (; X, W, B, Y, Mean, Rstd) = data # Reference result @@ -373,17 +373,29 @@ function benchmark_layernorm() min_t, mean_t = benchmark_kernel(simt_f) push!(results, BenchmarkResult("SIMT naive", min_t, mean_t)) - # cuTile (using layernorm.jl's run/verify functions) - result = layernorm_fwd_run(data; tile_n=LAYERNORM_TILE_N, nruns=NRUNS, warmup=WARMUP) - layernorm_fwd_verify(data, result) - min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) + # cuTile (using layernorm.jl's unified run/verify functions) + result = layernorm_run(data; TILE_N=LAYERNORM_TILE_N, nruns=NRUNS, warmup=WARMUP) + layernorm_verify(data, result) + + # Forward pass timing + min_t_fwd = minimum(result.times_fwd) + mean_t_fwd = sum(result.times_fwd) / length(result.times_fwd) + push!(results, BenchmarkResult("cuTile Fwd", min_t_fwd, mean_t_fwd)) + + # Backward pass timing + min_t_bwd = minimum(result.times_bwd) + mean_t_bwd = sum(result.times_bwd) / length(result.times_bwd) # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y) bytes = (3 * M * N + N + N + M * N) * sizeof(Float32) bandwidths = [string(round(bytes / (r.min_ms / 1000) / 1e9, digits=1), " GB/s") for r in results] - print_table("Layer Normalization (Float32)", results; extra_col=("Bandwidth", bandwidths)) + print_table("Layer Normalization Forward (Float32)", results; extra_col=("Bandwidth", bandwidths)) + + # Print backward results separately + bwd_results = [BenchmarkResult("cuTile Bwd", min_t_bwd, mean_t_bwd)] + print_table("Layer Normalization Backward (Float32)", bwd_results) + return results end diff --git a/examples/benchmarks.py b/examples/benchmarks.py index b27addd..9b865d6 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -377,17 +377,29 @@ def torch_layernorm(): min_t, mean_t = benchmark_torch(torch_layernorm) results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - # cuTile - use prepare/run/verify pattern + # 
cuTile - use prepare/run/verify pattern (unified fwd+bwd) result = layernorm_run(data, tile_n=LAYERNORM_TILE_N, nruns=NRUNS, warmup=WARMUP) layernorm_verify(data, result) - min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) + + # Forward pass timing + min_t_fwd = min(result["times_fwd"]) + mean_t_fwd = sum(result["times_fwd"]) / len(result["times_fwd"]) + results.append(BenchmarkResult("cuTile Fwd", min_t_fwd, mean_t_fwd)) + + # Backward pass timing + min_t_bwd = min(result["times_bwd"]) + mean_t_bwd = sum(result["times_bwd"]) / len(result["times_bwd"]) # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y) bytes_transferred = (3 * M * N + N + N + M * N) * 4 bandwidths = [f"{bytes_transferred / (r.min_ms / 1000) / 1e9:.1f} GB/s" for r in results] - print_table("Layer Normalization (Float32)", results, extra_col=("Bandwidth", bandwidths)) + print_table("Layer Normalization Forward (Float32)", results, extra_col=("Bandwidth", bandwidths)) + + # Print backward results separately + bwd_results = [BenchmarkResult("cuTile Bwd", min_t_bwd, mean_t_bwd)] + print_table("Layer Normalization Backward (Float32)", bwd_results) + return results diff --git a/examples/layernorm.jl b/examples/layernorm.jl index 29aa2d6..f368509 100644 --- a/examples/layernorm.jl +++ b/examples/layernorm.jl @@ -273,82 +273,41 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa end #============================================================================= - Forward Pass - prepare/run/verify pattern + Unified prepare/run/verify pattern (fwd + bwd) =============================================================================# -function layernorm_fwd_prepare(; M::Int, N::Int, eps::Float32=1f-5) +function layernorm_prepare(; M::Int, N::Int, eps::Float32=1f-5, GROUP_SIZE_M::Int=64) return (; + # Forward inputs/outputs X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N), W = CUDA.randn(Float32, N), B = CUDA.randn(Float32, N), Y = CUDA.zeros(Float32, M, N), Mean = CUDA.zeros(Float32, M), Rstd = CUDA.zeros(Float32, M), - M, N, eps - ) -end - -function layernorm_fwd_run(data; TILE_N::Int, nruns::Int=1, warmup::Int=0) - (; X, W, B, Y, Mean, Rstd, M, eps) = data - - for _ in 1:warmup - ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, - ct.Constant(eps), ct.Constant(TILE_N)) - end - CUDA.synchronize() - - times = Float64[] - for _ in 1:nruns - t = CUDA.@elapsed ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, - ct.Constant(eps), ct.Constant(TILE_N)) - push!(times, t * 1000) # ms - end - - return (; Y, Mean, Rstd, times) -end - -function layernorm_fwd_verify(data, result) - (; X, W, B, N, eps) = data - X_cpu = Array(X) - W_cpu = Array(W) - B_cpu = Array(B) - expected_mean = vec(sum(X_cpu, dims=2) ./ N) - expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) - expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) - normalized = (X_cpu .- expected_mean) .* expected_rstd - expected_Y = normalized .* W_cpu' .+ B_cpu' - - atol, rtol = 1f-2, 1f-2 - @assert isapprox(expected_mean, Array(result.Mean); rtol, atol) "Mean mismatch" - @assert isapprox(expected_rstd, Array(result.Rstd); rtol, atol) "Rstd mismatch" - @assert isapprox(expected_Y, Array(result.Y); rtol, atol) "Y mismatch" -end - -#============================================================================= - Backward Pass - prepare/run/verify pattern 
-=============================================================================# - -function layernorm_bwd_prepare(fwd_data, fwd_result; GROUP_SIZE_M::Int=64) - (; X, W, M, N) = fwd_data - (; Mean, Rstd) = fwd_result - return (; - X, W, Mean, Rstd, - DY = CUDA.randn(Float32, M, N), + # Backward inputs/outputs + DY = 0.1f0 .* CUDA.randn(Float32, M, N), DX = CUDA.zeros(Float32, M, N), DW_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N), DB_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N), Locks = CUDA.zeros(Int, GROUP_SIZE_M), FINAL_DW = CUDA.zeros(Float32, N), FINAL_DB = CUDA.zeros(Float32, N), - M, N, GROUP_SIZE_M + # Metadata + M, N, eps, GROUP_SIZE_M ) end -function layernorm_bwd_run(data; TILE_N::Int, TILE_M::Int=32, nruns::Int=1, warmup::Int=0) - (; X, W, Mean, Rstd, DY, DX, DW_partial, DB_partial, Locks, FINAL_DW, FINAL_DB, M, N, GROUP_SIZE_M) = data +function layernorm_run(data; TILE_N::Int, TILE_M::Int=32, nruns::Int=1, warmup::Int=0) + (; X, W, B, Y, Mean, Rstd, DY, DX, DW_partial, DB_partial, Locks, FINAL_DW, FINAL_DB, + M, N, eps, GROUP_SIZE_M) = data - for _ in 1:warmup - # Reset partial buffers + function run_fwd() + ct.launch(layer_norm_fwd, M, X, W, B, Y, Mean, Rstd, + ct.Constant(eps), ct.Constant(TILE_N)) + end + + function run_bwd() fill!(DW_partial, 0) fill!(DB_partial, 0) fill!(Locks, 0) @@ -358,41 +317,50 @@ function layernorm_bwd_run(data; TILE_N::Int, TILE_M::Int=32, nruns::Int=1, warm ct.launch(layer_norm_bwd_dwdb, num_tiles_n, DW_partial, DB_partial, FINAL_DW, FINAL_DB, ct.Constant(TILE_M), ct.Constant(TILE_N)) end + + # Warmup + for _ in 1:warmup + run_fwd() + run_bwd() + end CUDA.synchronize() - times = Float64[] + # Timed forward runs + times_fwd = Float64[] for _ in 1:nruns - # Reset partial buffers - fill!(DW_partial, 0) - fill!(DB_partial, 0) - fill!(Locks, 0) - t = CUDA.@elapsed begin - ct.launch(layer_norm_bwd_dx_partial_dwdb, M, DX, DY, DW_partial, DB_partial, X, W, - Mean, Rstd, Locks, ct.Constant(GROUP_SIZE_M), ct.Constant(TILE_N)) - num_tiles_n = cld(N, TILE_N) - ct.launch(layer_norm_bwd_dwdb, num_tiles_n, DW_partial, DB_partial, FINAL_DW, FINAL_DB, - ct.Constant(TILE_M), ct.Constant(TILE_N)) - end - push!(times, t * 1000) # ms + t = CUDA.@elapsed run_fwd() + push!(times_fwd, t * 1000) # ms + end + + # Timed backward runs + times_bwd = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed run_bwd() + push!(times_bwd, t * 1000) # ms end - return (; DX, FINAL_DW, FINAL_DB, times) + return (; Y, Mean, Rstd, DX, FINAL_DW, FINAL_DB, times_fwd, times_bwd) end -function layernorm_bwd_verify(fwd_data, bwd_data, bwd_result) - (; X, W, N, eps) = fwd_data - (; DY, Mean, Rstd) = bwd_data +function layernorm_verify(data, result) + (; X, W, B, DY, N, eps) = data X_cpu = Array(X) W_cpu = Array(W) + B_cpu = Array(B) DY_cpu = Array(DY) - # Compute expected values + # Forward verification expected_mean = vec(sum(X_cpu, dims=2) ./ N) expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) xhat = (X_cpu .- expected_mean) .* expected_rstd + expected_Y = xhat .* W_cpu' .+ B_cpu' + atol, rtol = 1f-2, 1f-2 + @assert isapprox(expected_Y, Array(result.Y); rtol, atol) "Y mismatch" + + # Backward verification wdy = W_cpu' .* DY_cpu c1 = sum(xhat .* wdy, dims=2) ./ N c2 = sum(wdy, dims=2) ./ N @@ -400,10 +368,18 @@ function layernorm_bwd_verify(fwd_data, bwd_data, bwd_result) expected_DW = vec(sum(DY_cpu .* xhat, dims=1)) expected_DB = vec(sum(DY_cpu, dims=1)) - atol, rtol = 1f-2, 1f-2 - @assert isapprox(expected_DX, 
Array(bwd_result.DX); rtol, atol) "dX mismatch" - @assert isapprox(expected_DW, Array(bwd_result.FINAL_DW); rtol, atol) "dW mismatch" - @assert isapprox(expected_DB, Array(bwd_result.FINAL_DB); rtol, atol) "dB mismatch" + @assert isapprox(expected_DX, Array(result.DX); rtol, atol) "dX mismatch" + @assert isapprox(expected_DW, Array(result.FINAL_DW); rtol, atol) "dW mismatch" + @assert isapprox(expected_DB, Array(result.FINAL_DB); rtol, atol) "dB mismatch" +end + +function test_layernorm(M, N, TILE_N; TILE_M::Int=32, eps::Float32=1f-5, name=nothing) + name = something(name, "layernorm ($M x $N), tile_n=$TILE_N, tile_m=$TILE_M") + println("--- $name ---") + data = layernorm_prepare(; M, N, eps) + result = layernorm_run(data; TILE_N, TILE_M) + layernorm_verify(data, result) + println(" fwd passed, bwd passed") end #============================================================================= @@ -411,42 +387,13 @@ end =============================================================================# function main() - println("=== cuTile LayerNorm Sample ===\n") - - M, N = 1024, 2048 - TILE_N = 1024 - eps = 1f-5 - - println("Input shape: ($M, $N), dtype: Float32, eps: $eps") - - # ========================================================================= - # Forward Pass - # ========================================================================= - println("\n--- Forward Pass ---") - fwd_data = layernorm_fwd_prepare(; M, N, eps) - fwd_result = layernorm_fwd_run(fwd_data; TILE_N) - layernorm_fwd_verify(fwd_data, fwd_result) - println("Forward pass: PASSED") - - # ========================================================================= - # Backward Pass (Full: dX, dW, dB) - # ========================================================================= - println("\n--- Backward Pass (Full: dX, dW, dB) ---") - GROUP_SIZE_M = 64 - TILE_M = 32 - bwd_data = layernorm_bwd_prepare(fwd_data, fwd_result; GROUP_SIZE_M) - bwd_result = layernorm_bwd_run(bwd_data; TILE_N, TILE_M) - layernorm_bwd_verify(fwd_data, bwd_data, bwd_result) - println(" dX: PASSED") - println(" dW: PASSED") - println(" dB: PASSED") - - # ========================================================================= - # Summary - # ========================================================================= - println("\n=== Summary ===") - println("Forward pass: PASSED") - println("Backward (dX/dW/dB): PASSED") + println("=== cuTile LayerNorm Examples (fwd+bwd) ===\n") + + test_layernorm(256, 256, 256) + test_layernorm(512, 512, 512) + test_layernorm(1024, 2048, 1024) + + println("\n=== All layernorm examples completed ===") end isinteractive() || main() diff --git a/examples/layernorm.py b/examples/layernorm.py index f833f4c..9c2f6ca 100644 --- a/examples/layernorm.py +++ b/examples/layernorm.py @@ -1,14 +1,21 @@ #!/usr/bin/env python3 """ Layer Normalization example - cuTile Python +Forward and backward passes with unified prepare/run/verify pattern. 
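+
+Typical use (sketch; shapes and tile size are illustrative):
+
+    data = layernorm_prepare(M=1024, N=2048)
+    result = layernorm_run(data, tile_n=1024)
+    layernorm_verify(data, result)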
""" import cupy as cp import numpy as np import cuda.tile as ct +from math import ceil + +#============================================================================= +# Forward Kernel +#============================================================================= @ct.kernel -def layernorm_cutile_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TILE_N: ct.Constant[int]): +def layernorm_fwd_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TILE_N: ct.Constant[int]): + """Forward pass: computes mean/var, normalizes input, applies affine transform.""" bid_m = ct.bid(0) num_tiles = ct.num_tiles(X, axis=1, shape=(1, TILE_N)) N = X.shape[1] @@ -43,84 +50,223 @@ def layernorm_cutile_kernel(X, W, B, Y, Mean, Rstd, eps: ct.Constant[float], TIL #============================================================================= -# prepare/run/verify pattern +# Backward Kernels #============================================================================= -def layernorm_prepare(*, M: int, N: int, eps: float = 1e-5, dtype=np.float32): - """Allocate and initialize data for layer normalization.""" +def bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N): + """Helper to load data and compute common backward terms.""" + tx = ct.load(X, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + tw = ct.load(W, index=(j,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO) + tdy = ct.load(DY, index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + xhat = (tx - mean) * rstd + wdy = tw * tdy + mask = j * TILE_N + ct.arange(TILE_N, dtype=ct.int32) < N + xhat = ct.where(mask, xhat, 0) + wdy = ct.where(mask, wdy, 0) + return tdy, xhat, wdy + + +@ct.kernel +def layernorm_bwd_dx_partial_dwdb_kernel(DX, DY, DW, DB, X, W, Mean, Rstd, Locks, TILE_N: ct.Constant[int]): + """Backward pass part 1: computes dX and partial dW/dB with atomic accumulation.""" + bid_m = ct.bid(0) + num_tiles = ct.num_tiles(X, axis=1, shape=(1, TILE_N)) + N = X.shape[1] + GROUP_SIZE_M = DW.shape[0] + group_bid_m = bid_m % GROUP_SIZE_M + + mean = ct.load(Mean, index=(bid_m,), shape=(1,)) + rstd = ct.load(Rstd, index=(bid_m,), shape=(1,)) + + c1 = ct.full((1, TILE_N), 0, dtype=ct.float32) + c2 = ct.full((1, TILE_N), 0, dtype=ct.float32) + for j in range(num_tiles): + _, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) + c1 += xhat * wdy + c2 += wdy + c1 = ct.sum(c1, axis=1) / N + c2 = ct.sum(c2, axis=1) / N + + for j in range(num_tiles): + tdy, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) + tdx = (wdy - (xhat * c1 + c2)) * rstd + ct.store(DX, index=(bid_m, j), tile=tdx.astype(DX.dtype)) + + partial_dw = (tdy * xhat).astype(DW.dtype) + partial_db = tdy.astype(DB.dtype) + + while ct.atomic_cas(Locks, group_bid_m, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE) == 1: + pass + + partial_dw += ct.load(DW, index=(group_bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + partial_db += ct.load(DB, index=(group_bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.ZERO) + ct.store(DW, index=(group_bid_m, j), tile=partial_dw) + ct.store(DB, index=(group_bid_m, j), tile=partial_db) + + ct.atomic_xchg(Locks, group_bid_m, 0, memory_order=ct.MemoryOrder.RELEASE) + + +@ct.kernel +def layernorm_bwd_dwdb_kernel(DW, DB, FINAL_DW, FINAL_DB, TILE_M: ct.Constant[int], TILE_N: ct.Constant[int]): + """Backward pass part 2: Final reduction for dW and dB.""" + bid_n = ct.bid(0) + num_tiles = ct.num_tiles(DW, axis=0, shape=(TILE_M, TILE_N)) + + dw = ct.zeros((TILE_M, TILE_N), 
dtype=ct.float32) + db = ct.zeros((TILE_M, TILE_N), dtype=ct.float32) + for i in range(num_tiles): + dw += ct.load(DW, index=(i, bid_n), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO) + db += ct.load(DB, index=(i, bid_n), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO) + sum_dw = ct.sum(dw, axis=0) + sum_db = ct.sum(db, axis=0) + + ct.store(FINAL_DW, index=(bid_n,), tile=sum_dw.astype(FINAL_DW.dtype)) + ct.store(FINAL_DB, index=(bid_n,), tile=sum_db.astype(FINAL_DB.dtype)) + + +#============================================================================= +# Example harness +#============================================================================= + +def layernorm_prepare(*, M: int, N: int, eps: float = 1e-5, GROUP_SIZE_M: int = 64, dtype=np.float32): + """Allocate all data for forward and backward passes.""" return { + # Forward inputs/outputs "X": (-2.3 + 0.5 * cp.random.randn(M, N)).astype(dtype), "W": cp.random.randn(N).astype(dtype), "B": cp.random.randn(N).astype(dtype), "Y": cp.zeros((M, N), dtype=dtype), "Mean": cp.zeros(M, dtype=np.float32), "Rstd": cp.zeros(M, dtype=np.float32), + # Backward inputs/outputs + "DY": (0.1 * cp.random.randn(M, N)).astype(dtype), + "DX": cp.zeros((M, N), dtype=dtype), + "DW_partial": cp.zeros((GROUP_SIZE_M, N), dtype=np.float32), + "DB_partial": cp.zeros((GROUP_SIZE_M, N), dtype=np.float32), + "Locks": cp.zeros(GROUP_SIZE_M, dtype=np.int32), + "FINAL_DW": cp.zeros(N, dtype=dtype), + "FINAL_DB": cp.zeros(N, dtype=dtype), + # Metadata "eps": eps, "M": M, - "N": N + "N": N, + "GROUP_SIZE_M": GROUP_SIZE_M } -def layernorm_run(data, *, tile_n: int, nruns: int = 1, warmup: int = 0): - """Run layer normalization kernel with timing.""" +def layernorm_run(data, *, tile_n: int, tile_m: int = 32, nruns: int = 1, warmup: int = 0): + """Run both forward and backward passes with timing.""" X, W, B, Y = data["X"], data["W"], data["B"], data["Y"] Mean, Rstd = data["Mean"], data["Rstd"] - eps, M = data["eps"], data["M"] + DY, DX = data["DY"], data["DX"] + DW_partial, DB_partial = data["DW_partial"], data["DB_partial"] + Locks = data["Locks"] + FINAL_DW, FINAL_DB = data["FINAL_DW"], data["FINAL_DB"] + eps, M, N = data["eps"], data["M"], data["N"] + GROUP_SIZE_M = data["GROUP_SIZE_M"] stream = cp.cuda.get_current_stream() + def run_fwd(): + ct.launch(stream, (M,), layernorm_fwd_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) + + def run_bwd(): + DW_partial.fill(0) + DB_partial.fill(0) + Locks.fill(0) + ct.launch(stream, (M,), layernorm_bwd_dx_partial_dwdb_kernel, + (DX, DY, DW_partial, DB_partial, X, W, Mean, Rstd, Locks, tile_n)) + num_tiles_n = ceil(N / tile_n) + ct.launch(stream, (num_tiles_n,), layernorm_bwd_dwdb_kernel, + (DW_partial, DB_partial, FINAL_DW, FINAL_DB, tile_m, tile_n)) + # Warmup for _ in range(warmup): - ct.launch(stream, (M,), layernorm_cutile_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) + run_fwd() + run_bwd() cp.cuda.runtime.deviceSynchronize() - # Timed runs - times = [] + # Timed forward runs + times_fwd = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + run_fwd() + end.record(stream) + end.synchronize() + times_fwd.append(cp.cuda.get_elapsed_time(start, end)) # ms + + # Timed backward runs + times_bwd = [] for _ in range(nruns): start = cp.cuda.Event() end = cp.cuda.Event() start.record(stream) - ct.launch(stream, (M,), layernorm_cutile_kernel, (X, W, B, Y, Mean, Rstd, eps, tile_n)) + run_bwd() end.record(stream) end.synchronize() - 
times.append(cp.cuda.get_elapsed_time(start, end)) # ms + times_bwd.append(cp.cuda.get_elapsed_time(start, end)) # ms - return {"Y": Y, "Mean": Mean, "Rstd": Rstd, "times": times} + return { + "Y": Y, "Mean": Mean, "Rstd": Rstd, + "DX": DX, "DW": FINAL_DW, "DB": FINAL_DB, + "times_fwd": times_fwd, "times_bwd": times_bwd + } def layernorm_verify(data, result): - """Verify layer normalization results.""" + """Verify both forward and backward results.""" X_np = cp.asnumpy(data["X"]) W_np = cp.asnumpy(data["W"]) B_np = cp.asnumpy(data["B"]) + DY_np = cp.asnumpy(data["DY"]) eps = data["eps"] + N = data["N"] + # Forward verification expected_mean = np.mean(X_np, axis=1, keepdims=True) expected_var = np.mean((X_np - expected_mean) ** 2, axis=1, keepdims=True) expected_rstd = 1.0 / np.sqrt(expected_var + eps) - normalized = (X_np - expected_mean) * expected_rstd - expected_Y = normalized * W_np + B_np + xhat = (X_np - expected_mean) * expected_rstd + expected_Y = xhat * W_np + B_np + + atol, rtol = 1e-2, 1e-2 + assert np.allclose(cp.asnumpy(result["Y"]), expected_Y, rtol=rtol, atol=atol), \ + f"Y mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['Y']) - expected_Y))}" + + # Backward verification + wdy = W_np * DY_np + c1 = np.sum(xhat * wdy, axis=1, keepdims=True) / N + c2 = np.sum(wdy, axis=1, keepdims=True) / N + expected_DX = (wdy - (xhat * c1 + c2)) * expected_rstd + expected_DW = np.sum(DY_np * xhat, axis=0) + expected_DB = np.sum(DY_np, axis=0) - assert np.allclose(cp.asnumpy(result["Y"]), expected_Y, rtol=1e-2, atol=1e-2), \ - f"layernorm incorrect! max diff: {np.max(np.abs(cp.asnumpy(result['Y']) - expected_Y))}" + assert np.allclose(cp.asnumpy(result["DX"]), expected_DX, rtol=rtol, atol=atol), \ + f"DX mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['DX']) - expected_DX))}" + assert np.allclose(cp.asnumpy(result["DW"]), expected_DW, rtol=rtol, atol=atol), \ + f"DW mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['DW']) - expected_DW))}" + assert np.allclose(cp.asnumpy(result["DB"]), expected_DB, rtol=rtol, atol=atol), \ + f"DB mismatch! max diff: {np.max(np.abs(cp.asnumpy(result['DB']) - expected_DB))}" #============================================================================= -# Test function +# Main #============================================================================= -def test_layernorm(M, N, tile_n, eps=1e-5, dtype=np.float32, name=None): - """Test layer normalization with given parameters.""" - name = name or f"layernorm ({M}x{N}), tile={tile_n}, dtype={dtype.__name__}" +def test_layernorm(M, N, tile_n, tile_m=32, eps=1e-5, dtype=np.float32, name=None): + """Test layer normalization (fwd+bwd) with given parameters.""" + name = name or f"layernorm ({M}x{N}), tile_n={tile_n}, tile_m={tile_m}, dtype={dtype.__name__}" print(f"--- {name} ---") data = layernorm_prepare(M=M, N=N, eps=eps, dtype=dtype) - result = layernorm_run(data, tile_n=tile_n) + result = layernorm_run(data, tile_n=tile_n, tile_m=tile_m) layernorm_verify(data, result) - print(" passed") + print(" fwd passed, bwd passed") def main(): - print("--- cuTile Layer Normalization Examples ---\n") + print("--- cuTile Layer Normalization Examples (fwd+bwd) ---\n") test_layernorm(256, 256, 256) test_layernorm(512, 512, 512) From fc65921dcfd641bbae9ed653e9829595aa7e9a55 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 14 Jan 2026 10:00:35 -0500 Subject: [PATCH 5/7] Unify vadd. 
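
The separate 1D/2D/gather harnesses collapse into a single
vadd_prepare/vadd_run/vadd_verify triple that takes the array shape as a
parameter. Sketch of the resulting call sequence (element count, tile
size, and run counts are illustrative):

    data = vadd_prepare(; shape=(2^20,), T=Float32)
    result = vadd_run(data; tile=1024, nruns=10, warmup=3)
    vadd_verify(data, result)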
--- examples/benchmarks.jl | 6 +- examples/benchmarks.py | 2 +- examples/vadd.jl | 203 +++++++++++++++-------------------------- examples/vadd.py | 114 +++++++++++++++++++---- 4 files changed, 172 insertions(+), 153 deletions(-) diff --git a/examples/benchmarks.jl b/examples/benchmarks.jl index 72272f3..703f1b5 100644 --- a/examples/benchmarks.jl +++ b/examples/benchmarks.jl @@ -105,7 +105,7 @@ function benchmark_vadd() println(" Size: $VADD_SIZE elements ($(VADD_SIZE * 4 / 1e6) MB)") # Prepare data once (using vadd.jl's prepare function) - data = vadd_1d_prepare(; n=VADD_SIZE, T=Float32) + data = vadd_prepare(; shape=(VADD_SIZE,), T=Float32) (; a, b, c) = data expected = Array(a) .+ Array(b) @@ -132,8 +132,8 @@ function benchmark_vadd() push!(results, BenchmarkResult("SIMT (CUDA.jl)", min_t, mean_t)) # cuTile (using vadd.jl's run/verify functions) - result = vadd_1d_run(data; tile=VADD_TILE, nruns=NRUNS, warmup=WARMUP) - vadd_1d_verify(data, result) + result = vadd_run(data; tile=VADD_TILE, nruns=NRUNS, warmup=WARMUP) + vadd_verify(data, result) min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) diff --git a/examples/benchmarks.py b/examples/benchmarks.py index 9b865d6..fd18ca2 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -182,7 +182,7 @@ def torch_vadd(): results.append(BenchmarkResult("PyTorch", min_t, mean_t)) # cuTile - use prepare/run/verify pattern - data = vadd_prepare(n=VADD_SIZE, dtype=np.float32) + data = vadd_prepare(shape=(VADD_SIZE,), dtype=np.float32) # Copy expected data for apples-to-apples comparison data["a"] = a_cp data["b"] = b_cp diff --git a/examples/vadd.jl b/examples/vadd.jl index 2f14002..f4c4e28 100644 --- a/examples/vadd.jl +++ b/examples/vadd.jl @@ -5,8 +5,7 @@ using CUDA import cuTile as ct -# 1D kernel with TileArray and constant tile size -# TileArray carries size/stride metadata, Constant is a ghost type +# 1D kernel function vec_add_kernel_1d(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.TileArray{T,1}, tile::ct.Constant{Int}) where {T} bid = ct.bid(1) @@ -16,7 +15,7 @@ function vec_add_kernel_1d(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.Til return end -# 2D kernel with TileArray and constant tile sizes +# 2D kernel function vec_add_kernel_2d(a::ct.TileArray{T,2}, b::ct.TileArray{T,2}, c::ct.TileArray{T,2}, tile_x::ct.Constant{Int}, tile_y::ct.Constant{Int}) where {T} bid_x = ct.bid(1) @@ -27,9 +26,7 @@ function vec_add_kernel_2d(a::ct.TileArray{T,2}, b::ct.TileArray{T,2}, c::ct.Til return end -# 1D kernel using gather/scatter (explicit index-based memory access) -# This demonstrates the gather/scatter API for cases where you need -# explicit control over indices (e.g., for non-contiguous access patterns) +# 1D kernel using gather/scatter function vec_add_kernel_1d_gather(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c::ct.TileArray{T,1}, tile::ct.Constant{Int}) where {T} bid = ct.bid(1) @@ -46,171 +43,117 @@ function vec_add_kernel_1d_gather(a::ct.TileArray{T,1}, b::ct.TileArray{T,1}, c: return end -#============================================================================= - 1D Vector Addition - prepare/run/verify pattern -=============================================================================# - -function vadd_1d_prepare(; n::Int, T::DataType=Float32) - return (; - a = CUDA.rand(T, n), - b = CUDA.rand(T, n), - c = CUDA.zeros(T, n), - n - ) -end - -function vadd_1d_run(data; tile::Int, nruns::Int=1, warmup::Int=0) - (; a, b, 
c, n) = data - grid = cld(n, tile) - - for _ in 1:warmup - ct.launch(vec_add_kernel_1d, grid, a, b, c, ct.Constant(tile)) - end - CUDA.synchronize() - - times = Float64[] - for _ in 1:nruns - t = CUDA.@elapsed ct.launch(vec_add_kernel_1d, grid, a, b, c, ct.Constant(tile)) - push!(times, t * 1000) # ms - end - - return (; c, times) -end - -function vadd_1d_verify(data, result) - @assert Array(result.c) ≈ Array(data.a) + Array(data.b) -end - -function test_add_1d(::Type{T}, n, tile; name=nothing) where T - name = something(name, "1D vec_add ($n elements, $T, tile=$tile)") - println("--- $name ---") - data = vadd_1d_prepare(; n, T) - result = vadd_1d_run(data; tile) - vadd_1d_verify(data, result) - println("✓ passed") -end #============================================================================= - 2D Matrix Addition - prepare/run/verify pattern +# Example harness =============================================================================# -function vadd_2d_prepare(; m::Int, n::Int, T::DataType=Float32) +function vadd_prepare(; shape::Tuple, use_gather::Bool=false, T::DataType=Float32) return (; - a = CUDA.rand(T, m, n), - b = CUDA.rand(T, m, n), - c = CUDA.zeros(T, m, n), - m, n + a = CUDA.rand(T, shape...), + b = CUDA.rand(T, shape...), + c = CUDA.zeros(T, shape...), + shape, + use_gather ) end -function vadd_2d_run(data; tile_x::Int, tile_y::Int, nruns::Int=1, warmup::Int=0) - (; a, b, c, m, n) = data - grid = (cld(m, tile_x), cld(n, tile_y)) - - for _ in 1:warmup - ct.launch(vec_add_kernel_2d, grid, a, b, c, - ct.Constant(tile_x), ct.Constant(tile_y)) - end - CUDA.synchronize() - - times = Float64[] - for _ in 1:nruns - t = CUDA.@elapsed ct.launch(vec_add_kernel_2d, grid, a, b, c, - ct.Constant(tile_x), ct.Constant(tile_y)) - push!(times, t * 1000) # ms +function vadd_run(data; tile::Union{Int, Tuple{Int,Int}}, nruns::Int=1, warmup::Int=0) + (; a, b, c, shape, use_gather) = data + + if length(shape) == 2 + # 2D case + m, n = shape + tile_x, tile_y = tile isa Tuple ? tile : (tile, tile) + grid = (cld(m, tile_x), cld(n, tile_y)) + + for _ in 1:warmup + ct.launch(vec_add_kernel_2d, grid, a, b, c, + ct.Constant(tile_x), ct.Constant(tile_y)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(vec_add_kernel_2d, grid, a, b, c, + ct.Constant(tile_x), ct.Constant(tile_y)) + push!(times, t * 1000) # ms + end + else + # 1D case + n = shape[1] + tile_val = tile isa Tuple ? tile[1] : tile + grid = cld(n, tile_val) + kernel = use_gather ? 
vec_add_kernel_1d_gather : vec_add_kernel_1d + + for _ in 1:warmup + ct.launch(kernel, grid, a, b, c, ct.Constant(tile_val)) + end + CUDA.synchronize() + + times = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed ct.launch(kernel, grid, a, b, c, ct.Constant(tile_val)) + push!(times, t * 1000) # ms + end end return (; c, times) end -function vadd_2d_verify(data, result) +function vadd_verify(data, result) @assert Array(result.c) ≈ Array(data.a) + Array(data.b) end -function test_add_2d(::Type{T}, m, n, tile_x, tile_y; name=nothing) where T - name = something(name, "2D vec_add ($m x $n, $T, tiles=$tile_x x $tile_y)") - println("--- $name ---") - data = vadd_2d_prepare(; m, n, T) - result = vadd_2d_run(data; tile_x, tile_y) - vadd_2d_verify(data, result) - println("✓ passed") -end #============================================================================= - 1D Gather/Scatter Vector Addition - prepare/run/verify pattern +# Main =============================================================================# -function vadd_1d_gather_prepare(; n::Int, T::DataType=Float32) - return (; - a = CUDA.rand(T, n), - b = CUDA.rand(T, n), - c = CUDA.zeros(T, n), - n - ) -end - -function vadd_1d_gather_run(data; tile::Int, nruns::Int=1, warmup::Int=0) - (; a, b, c, n) = data - grid = cld(n, tile) - - for _ in 1:warmup - ct.launch(vec_add_kernel_1d_gather, grid, a, b, c, ct.Constant(tile)) - end - CUDA.synchronize() - - times = Float64[] - for _ in 1:nruns - t = CUDA.@elapsed ct.launch(vec_add_kernel_1d_gather, grid, a, b, c, ct.Constant(tile)) - push!(times, t * 1000) # ms +function test_vadd(shape, tile; use_gather::Bool=false, T::DataType=Float32, name=nothing) + if name === nothing + if length(shape) == 2 + name = "2D vec_add ($(shape[1]) x $(shape[2]), $T, tile=$tile)" + elseif use_gather + name = "1D vec_add gather ($(shape[1]) elements, $T, tile=$tile)" + else + name = "1D vec_add ($(shape[1]) elements, $T, tile=$tile)" + end end - - return (; c, times) -end - -function vadd_1d_gather_verify(data, result) - @assert Array(result.c) ≈ Array(data.a) + Array(data.b) -end - -function test_add_1d_gather(::Type{T}, n, tile; name=nothing) where T - name = something(name, "1D vec_add gather ($n elements, $T, tile=$tile)") println("--- $name ---") - data = vadd_1d_gather_prepare(; n, T) - result = vadd_1d_gather_run(data; tile) - vadd_1d_gather_verify(data, result) - println("✓ passed") + data = vadd_prepare(; shape, use_gather, T) + result = vadd_run(data; tile) + vadd_verify(data, result) + println(" passed") end -#============================================================================= - Main -=============================================================================# - function main() println("--- cuTile Vector/Matrix Addition Examples ---\n") # 1D tests with Float32 - test_add_1d(Float32, 1_024_000, 1024) - test_add_1d(Float32, 2^20, 512) + test_vadd((1_024_000,), 1024) + test_vadd((2^20,), 512) # 1D tests with Float64 - test_add_1d(Float64, 2^18, 512) + test_vadd((2^18,), 512; T=Float64) # 1D tests with Float16 - test_add_1d(Float16, 1_024_000, 1024) + test_vadd((1_024_000,), 1024; T=Float16) # 2D tests with Float32 - test_add_2d(Float32, 2048, 1024, 32, 32) - test_add_2d(Float32, 1024, 2048, 64, 64) + test_vadd((2048, 1024), (32, 32)) + test_vadd((1024, 2048), (64, 64)) # 2D tests with Float64 - test_add_2d(Float64, 1024, 512, 32, 32) + test_vadd((1024, 512), (32, 32); T=Float64) # 2D tests with Float16 - test_add_2d(Float16, 1024, 1024, 64, 64) + test_vadd((1024, 1024), (64, 64); T=Float16) # 
1D gather/scatter tests with Float32 # Uses explicit index-based memory access instead of tiled loads/stores - test_add_1d_gather(Float32, 1_024_000, 1024) - test_add_1d_gather(Float32, 2^20, 512) + test_vadd((1_024_000,), 1024; use_gather=true) + test_vadd((2^20,), 512; use_gather=true) println("\n--- All addition examples completed ---") end diff --git a/examples/vadd.py b/examples/vadd.py index 77d77e8..38d0ad0 100644 --- a/examples/vadd.py +++ b/examples/vadd.py @@ -7,8 +7,9 @@ import numpy as np import cuda.tile as ct +# 1D kernel @ct.kernel -def vadd_cutile_kernel(a, b, c, tile_size: ct.Constant[int]): +def vadd_kernel_1d(a, b, c, tile_size: ct.Constant[int]): pid = ct.bid(0) tile_a = ct.load(a, index=(pid,), shape=(tile_size,)) tile_b = ct.load(b, index=(pid,), shape=(tile_size,)) @@ -16,31 +17,80 @@ def vadd_cutile_kernel(a, b, c, tile_size: ct.Constant[int]): ct.store(c, index=(pid,), tile=result) +# 2D kernel +@ct.kernel +def vadd_kernel_2d(a, b, c, tile_x: ct.Constant[int], tile_y: ct.Constant[int]): + pid_x = ct.bid(0) + pid_y = ct.bid(1) + tile_a = ct.load(a, index=(pid_x, pid_y), shape=(tile_x, tile_y)) + tile_b = ct.load(b, index=(pid_x, pid_y), shape=(tile_x, tile_y)) + result = tile_a + tile_b + ct.store(c, index=(pid_x, pid_y), tile=result) + + +# 1D kernel with gather/scatter +@ct.kernel +def vadd_kernel_1d_gather(a, b, c, tile_size: ct.Constant[int]): + pid = ct.bid(0) + # Create index tile for this block's elements + offsets = ct.arange(tile_size, dtype=ct.int32) + base = pid * tile_size + indices = base + offsets + + # Gather, add, scatter + tile_a = ct.gather(a, indices) + tile_b = ct.gather(b, indices) + result = tile_a + tile_b + ct.scatter(c, indices, result) + + #============================================================================= -# prepare/run/verify pattern +# Example harness #============================================================================= -def vadd_prepare(*, n: int, dtype=np.float32): +def vadd_prepare(*, shape: tuple, use_gather: bool = False, dtype=np.float32): """Allocate and initialize data for vector addition.""" return { - "a": cp.random.rand(n).astype(dtype), - "b": cp.random.rand(n).astype(dtype), - "c": cp.zeros(n, dtype=dtype), - "n": n + "a": cp.random.rand(*shape).astype(dtype), + "b": cp.random.rand(*shape).astype(dtype), + "c": cp.zeros(shape, dtype=dtype), + "shape": shape, + "use_gather": use_gather } -def vadd_run(data, *, tile: int, nruns: int = 1, warmup: int = 0): +def vadd_run(data, *, tile, nruns: int = 1, warmup: int = 0): """Run vector addition kernel with timing.""" a, b, c = data["a"], data["b"], data["c"] - n = data["n"] + shape = data["shape"] + use_gather = data["use_gather"] - grid = (ct.cdiv(n, tile), 1, 1) stream = cp.cuda.get_current_stream() + if len(shape) == 2: + # 2D case + m, n = shape + tile_x, tile_y = tile if isinstance(tile, tuple) else (tile, tile) + grid = (ct.cdiv(m, tile_x), ct.cdiv(n, tile_y), 1) + + def run_kernel(): + ct.launch(stream, grid, vadd_kernel_2d, (a, b, c, tile_x, tile_y)) + else: + # 1D case + n = shape[0] + tile_val = tile[0] if isinstance(tile, tuple) else tile + grid = (ct.cdiv(n, tile_val), 1, 1) + + if use_gather: + def run_kernel(): + ct.launch(stream, grid, vadd_kernel_1d_gather, (a, b, c, tile_val)) + else: + def run_kernel(): + ct.launch(stream, grid, vadd_kernel_1d, (a, b, c, tile_val)) + # Warmup for _ in range(warmup): - ct.launch(stream, grid, vadd_cutile_kernel, (a, b, c, tile)) + run_kernel() cp.cuda.runtime.deviceSynchronize() # Timed runs @@ -49,7 +99,7 @@ 
def vadd_run(data, *, tile: int, nruns: int = 1, warmup: int = 0): start = cp.cuda.Event() end = cp.cuda.Event() start.record(stream) - ct.launch(stream, grid, vadd_cutile_kernel, (a, b, c, tile)) + run_kernel() end.record(stream) end.synchronize() times.append(cp.cuda.get_elapsed_time(start, end)) # ms @@ -64,14 +114,20 @@ def vadd_verify(data, result): #============================================================================= -# Test function +# Main #============================================================================= -def test_vadd(n, tile, dtype=np.float32, name=None): +def test_vadd(shape, tile, use_gather=False, dtype=np.float32, name=None): """Test vector addition with given parameters.""" - name = name or f"vadd size={n}, tile={tile}, dtype={dtype.__name__}" + if name is None: + if len(shape) == 2: + name = f"2D vadd ({shape[0]}x{shape[1]}), tile={tile}, dtype={dtype.__name__}" + elif use_gather: + name = f"1D vadd gather size={shape[0]}, tile={tile}, dtype={dtype.__name__}" + else: + name = f"1D vadd size={shape[0]}, tile={tile}, dtype={dtype.__name__}" print(f"--- {name} ---") - data = vadd_prepare(n=n, dtype=dtype) + data = vadd_prepare(shape=shape, use_gather=use_gather, dtype=dtype) result = vadd_run(data, tile=tile) vadd_verify(data, result) print(" passed") @@ -80,9 +136,29 @@ def test_vadd(n, tile, dtype=np.float32, name=None): def main(): print("--- cuTile Vector Addition Examples ---\n") - test_vadd(1_024_000, 1024) - test_vadd(2**20, 512) - test_vadd(2**20, 1024) + # 1D tests with float32 + test_vadd((1_024_000,), 1024) + test_vadd((2**20,), 512) + + # 1D tests with float64 + test_vadd((2**18,), 512, dtype=np.float64) + + # 1D tests with float16 + test_vadd((1_024_000,), 1024, dtype=np.float16) + + # 2D tests with float32 + test_vadd((2048, 1024), (32, 32)) + test_vadd((1024, 2048), (64, 64)) + + # 2D tests with float64 + test_vadd((1024, 512), (32, 32), dtype=np.float64) + + # 2D tests with float16 + test_vadd((1024, 1024), (64, 64), dtype=np.float16) + + # 1D gather/scatter tests + test_vadd((1_024_000,), 1024, use_gather=True) + test_vadd((2**20,), 512, use_gather=True) print("\n--- All vadd examples completed ---") From 45c8c5f467c97b41c4d5c9dabf2e82a178ab2c07 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 14 Jan 2026 10:59:07 -0500 Subject: [PATCH 6/7] More refactoring --- examples/batchmatmul.jl | 44 +++- examples/batchmatmul.py | 55 ++++- examples/benchmarks.jl | 509 +++++--------------------------------- examples/benchmarks.py | 533 ++++++---------------------------------- examples/fft.jl | 44 +++- examples/fft.py | 46 +++- examples/layernorm.jl | 22 +- examples/layernorm.py | 18 +- examples/matmul.jl | 44 +++- examples/matmul.py | 50 +++- examples/transpose.jl | 66 ++++- examples/transpose.py | 48 +++- examples/vadd.jl | 71 +++++- examples/vadd.py | 48 +++- 14 files changed, 614 insertions(+), 984 deletions(-) diff --git a/examples/batchmatmul.jl b/examples/batchmatmul.jl index fa81d18..dd2b20e 100644 --- a/examples/batchmatmul.jl +++ b/examples/batchmatmul.jl @@ -61,7 +61,12 @@ end Example harness =============================================================================# -function batchmatmul_prepare(; M::Int, K::Int, N::Int, Batch::Int, T::DataType=Float32) +function prepare(; benchmark::Bool=false, + M::Int=benchmark ? 1024 : 256, + K::Int=benchmark ? 512 : 128, + N::Int=benchmark ? 2048 : 256, + Batch::Int=benchmark ? 
8 : 4, + T::DataType=Float32) return (; A = CUDA.rand(T, M, K, Batch), B = CUDA.rand(T, K, N, Batch), @@ -70,15 +75,14 @@ function batchmatmul_prepare(; M::Int, K::Int, N::Int, Batch::Int, T::DataType=F ) end -function batchmatmul_run(data; tm::Int, tn::Int, tk::Int, nruns::Int=1, warmup::Int=0) +function run(data; tm::Int=64, tn::Int=64, tk::Int=64, nruns::Int=1, warmup::Int=0) (; A, B, C, M, N, Batch) = data grid = (cld(M, tm), cld(N, tn), Batch) - for _ in 1:warmup + CUDA.@sync for _ in 1:warmup ct.launch(batch_matmul_kernel, grid, A, B, C, ct.Constant(tm), ct.Constant(tn), ct.Constant(tk)) end - CUDA.synchronize() times = Float64[] for _ in 1:nruns @@ -90,7 +94,7 @@ function batchmatmul_run(data; tm::Int, tn::Int, tk::Int, nruns::Int=1, warmup:: return (; C, times) end -function batchmatmul_verify(data, result) +function verify(data, result) (; A, B, M, N, Batch) = data A_cpu = Array(A) B_cpu = Array(B) @@ -101,6 +105,30 @@ function batchmatmul_verify(data, result) @assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))" end +#============================================================================= + Reference implementations for benchmarking +=============================================================================# + +function run_others(data; nruns::Int=1, warmup::Int=0) + (; A, B, M, N, Batch) = data + results = Dict{String, Vector{Float64}}() + + C_cublas = similar(A, M, N, Batch) + + # cuBLAS batched gemm via CUBLAS.gemm_strided_batched! + CUDA.@sync for _ in 1:warmup + CUDA.CUBLAS.gemm_strided_batched!('N', 'N', one(eltype(A)), A, B, zero(eltype(A)), C_cublas) + end + times_cublas = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed CUDA.CUBLAS.gemm_strided_batched!('N', 'N', one(eltype(A)), A, B, zero(eltype(A)), C_cublas) + push!(times_cublas, t * 1000) + end + results["cuBLAS batched"] = times_cublas + + return results +end + #============================================================================= Main =============================================================================# @@ -108,9 +136,9 @@ end function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk") println("--- $name ---") - data = batchmatmul_prepare(; M, K, N, Batch, T) - result = batchmatmul_run(data; tm, tn, tk) - batchmatmul_verify(data, result) + data = prepare(; M, K, N, Batch, T) + result = run(data; tm, tn, tk) + verify(data, result) println(" passed") end diff --git a/examples/batchmatmul.py b/examples/batchmatmul.py index 075cbb2..95b93db 100644 --- a/examples/batchmatmul.py +++ b/examples/batchmatmul.py @@ -40,8 +40,16 @@ def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int # Example harness #============================================================================= -def batchmatmul_prepare(*, Batch: int, M: int, K: int, N: int, dtype=np.float16): +def prepare(*, benchmark: bool = False, Batch: int = None, M: int = None, K: int = None, N: int = None, dtype=np.float16): """Allocate and initialize data for batch matmul.""" + if Batch is None: + Batch = 8 if benchmark else 4 + if M is None: + M = 1024 if benchmark else 256 + if K is None: + K = 512 if benchmark else 128 + if N is None: + N = 2048 if benchmark else 256 return { "A": cp.random.randn(Batch, M, K).astype(dtype), "B": cp.random.randn(Batch, K, N).astype(dtype), @@ -53,7 +61,7 @@ def 
batchmatmul_prepare(*, Batch: int, M: int, K: int, N: int, dtype=np.float16) } -def batchmatmul_run(data, *, tm: int, tn: int, tk: int, nruns: int = 1, warmup: int = 0): +def run(data, *, tm: int = 64, tn: int = 64, tk: int = 64, nruns: int = 1, warmup: int = 0): """Run batch matmul kernel with timing.""" A, B, C = data["A"], data["B"], data["C"] Batch, M, N = data["Batch"], data["M"], data["N"] @@ -80,7 +88,7 @@ def batchmatmul_run(data, *, tm: int, tn: int, tk: int, nruns: int = 1, warmup: return {"C": C, "times": times} -def batchmatmul_verify(data, result): +def verify(data, result): """Verify batch matmul results.""" A_np = cp.asnumpy(data["A"]).astype(np.float32) B_np = cp.asnumpy(data["B"]).astype(np.float32) @@ -94,6 +102,41 @@ def batchmatmul_verify(data, result): f"batchmatmul incorrect! max diff: {np.max(np.abs(C_np - expected))}" +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + import torch + + results = {} + A_cp, B_cp = data["A"], data["B"] + Batch, M, N = data["Batch"], data["M"], data["N"] + + # PyTorch bmm + A_torch = torch.as_tensor(A_cp, device='cuda') + B_torch = torch.as_tensor(B_cp, device='cuda') + C_torch = torch.zeros(Batch, M, N, dtype=A_torch.dtype, device='cuda') + + for _ in range(warmup): + torch.bmm(A_torch, B_torch, out=C_torch) + torch.cuda.synchronize() + + times_torch = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.bmm(A_torch, B_torch, out=C_torch) + end.record() + torch.cuda.synchronize() + times_torch.append(start.elapsed_time(end)) + results["PyTorch bmm"] = times_torch + + return results + + #============================================================================= # Main #============================================================================= @@ -102,9 +145,9 @@ def test_batchmatmul(Batch, M, K, N, tm, tn, tk, dtype=np.float16, name=None): """Test batch matmul with given parameters.""" name = name or f"batchmatmul ({Batch}x{M}x{K}) @ ({Batch}x{K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" print(f"--- {name} ---") - data = batchmatmul_prepare(Batch=Batch, M=M, K=K, N=N, dtype=dtype) - result = batchmatmul_run(data, tm=tm, tn=tn, tk=tk) - batchmatmul_verify(data, result) + data = prepare(Batch=Batch, M=M, K=K, N=N, dtype=dtype) + result = run(data, tm=tm, tn=tn, tk=tk) + verify(data, result) print(" passed") diff --git a/examples/benchmarks.jl b/examples/benchmarks.jl index 703f1b5..e626cce 100644 --- a/examples/benchmarks.jl +++ b/examples/benchmarks.jl @@ -1,19 +1,9 @@ # EXCLUDE FROM TESTING # -# Comprehensive benchmarks for cuTile.jl -# Compares: GPUArrays (generic), SIMT (CUDA.jl), cuTile -# Kernels: vadd, transpose, matmul +# Generic benchmark runner for cuTile.jl examples +# Discovers and benchmarks all examples in the examples/ directory -# Include example files to reuse their kernels -include("vadd.jl") -include("transpose.jl") -include("matmul.jl") -include("batchmatmul.jl") -include("layernorm.jl") -include("fft.jl") - -using LinearAlgebra -using CUDA: GPUArrays +using CUDA #============================================================================= Configuration @@ -22,19 +12,6 @@ using CUDA: GPUArrays const NRUNS = 10 const WARMUP = 3 -# Data sizes - 
large enough to saturate GPU and minimize launch overhead -const VADD_SIZE = 2^27 # 512 MB (128M elements) -const TRANSPOSE_DIM = 8192 # 8192x8192 = 268 MB -const MATMUL_DIM = 4096 # 4096x4096x4096 - -# Tile sizes -const VADD_TILE = 1024 -const TRANSPOSE_TILE_M = 64 -const TRANSPOSE_TILE_N = 64 -const MATMUL_TM = 64 -const MATMUL_TN = 64 -const MATMUL_TK = 64 - #============================================================================= Benchmark Utilities =============================================================================# @@ -45,437 +22,70 @@ struct BenchmarkResult mean_ms::Float64 end -function benchmark_kernel(f, nruns::Int=NRUNS, warmup::Int=WARMUP) - # Warmup - for _ in 1:warmup - f() - end - CUDA.synchronize() - - # Benchmark - times = Float64[] - for _ in 1:nruns - t = CUDA.@elapsed f() - push!(times, t * 1000) # Convert to ms - end - - return minimum(times), sum(times) / length(times) -end - -function print_table(title::String, results::Vector{BenchmarkResult}; extra_col=nothing) +function print_table(title::String, results::Vector{BenchmarkResult}) println() println("=" ^ 60) println(" ", title) println("=" ^ 60) - - if extra_col !== nothing - println(rpad("Implementation", 20), rpad("Min (ms)", 12), rpad("Mean (ms)", 12), extra_col[1]) - println("-" ^ 60) - for (i, r) in enumerate(results) - extra = extra_col[2][i] - println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12), - rpad(round(r.mean_ms, digits=3), 12), extra) - end - else - println(rpad("Implementation", 20), rpad("Min (ms)", 12), "Mean (ms)") - println("-" ^ 60) - for r in results - println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12), - round(r.mean_ms, digits=3)) - end - end + println(rpad("Implementation", 20), rpad("Min (ms)", 12), "Mean (ms)") println("-" ^ 60) -end - -#============================================================================= - Vector Addition -=============================================================================# - -# SIMT kernel (benchmark-specific) -function vadd_simt_kernel!(a, b, c) - i = (blockIdx().x - 1) * blockDim().x + threadIdx().x - if i <= length(c) - @inbounds c[i] = a[i] + b[i] + for r in results + println(rpad(r.name, 20), rpad(round(r.min_ms, digits=3), 12), + round(r.mean_ms, digits=3)) end - return -end - -function benchmark_vadd() - println("\nBenchmarking Vector Addition...") - println(" Size: $VADD_SIZE elements ($(VADD_SIZE * 4 / 1e6) MB)") - - # Prepare data once (using vadd.jl's prepare function) - data = vadd_prepare(; shape=(VADD_SIZE,), T=Float32) - (; a, b, c) = data - expected = Array(a) .+ Array(b) - - results = BenchmarkResult[] - - # GPUArrays (broadcast) - gpuarrays_f = () -> begin - c .= a .+ b - end - gpuarrays_f() - CUDA.synchronize() - @assert Array(c) ≈ expected "GPUArrays incorrect!" - min_t, mean_t = benchmark_kernel(gpuarrays_f) - push!(results, BenchmarkResult("GPUArrays", min_t, mean_t)) - - # SIMT - threads = 1024 - blocks = cld(VADD_SIZE, threads) - simt_f = () -> @cuda threads=threads blocks=blocks vadd_simt_kernel!(a, b, c) - simt_f() - CUDA.synchronize() - @assert Array(c) ≈ expected "SIMT incorrect!" 
- min_t, mean_t = benchmark_kernel(simt_f) - push!(results, BenchmarkResult("SIMT (CUDA.jl)", min_t, mean_t)) - - # cuTile (using vadd.jl's run/verify functions) - result = vadd_run(data; tile=VADD_TILE, nruns=NRUNS, warmup=WARMUP) - vadd_verify(data, result) - min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate bandwidth - bytes = 3 * VADD_SIZE * sizeof(Float32) # 2 reads + 1 write - bandwidths = [string(round(bytes / (r.min_ms / 1000) / 1e9, digits=1), " GB/s") for r in results] - - print_table("Vector Addition (Float32)", results; extra_col=("Bandwidth", bandwidths)) - return results -end - -#============================================================================= - Matrix Transpose -=============================================================================# - -# SIMT naive kernel -function transpose_simt_naive_kernel!(input, output, M, N) - i = (blockIdx().x - 1) * blockDim().x + threadIdx().x - j = (blockIdx().y - 1) * blockDim().y + threadIdx().y - if i <= M && j <= N - @inbounds output[j, i] = input[i, j] - end - return -end - -# SIMT shared memory kernel (benchmark-specific) -function transpose_simt_shared_kernel!(input, output, M, N) - TILE = 32 - tile = CuStaticSharedArray(Float32, (TILE+1, TILE)) - - x = (blockIdx().x - 1) * TILE + threadIdx().x - y = (blockIdx().y - 1) * TILE + threadIdx().y - - if x <= M && y <= N - @inbounds tile[threadIdx().x, threadIdx().y] = input[x, y] - end - sync_threads() - - x = (blockIdx().y - 1) * TILE + threadIdx().x - y = (blockIdx().x - 1) * TILE + threadIdx().y - - if x <= N && y <= M - @inbounds output[x, y] = tile[threadIdx().y, threadIdx().x] - end - return -end - -function benchmark_transpose() - println("\nBenchmarking Matrix Transpose...") - M, N = TRANSPOSE_DIM, TRANSPOSE_DIM - println(" Size: $(M)x$(N) ($(M * N * 4 / 1e6) MB)") - - # Prepare data once (using transpose.jl's prepare function) - data = transpose_prepare(; m=M, n=N, T=Float32) - (; x, y) = data - expected = Array(permutedims(x, (2, 1))) - - results = BenchmarkResult[] - - # GPUArrays (permutedims) - gpuarrays_f = () -> permutedims!(y, x, (2, 1)) - gpuarrays_f() - CUDA.synchronize() - @assert Array(y) ≈ expected "GPUArrays incorrect!" - min_t, mean_t = benchmark_kernel(gpuarrays_f) - push!(results, BenchmarkResult("GPUArrays", min_t, mean_t)) - - # SIMT naive - fill!(y, 0) - threads_naive = (16, 16) - blocks_naive = (cld(M, 16), cld(N, 16)) - simt_naive_f = () -> @cuda threads=threads_naive blocks=blocks_naive transpose_simt_naive_kernel!(x, y, M, N) - simt_naive_f() - CUDA.synchronize() - @assert Array(y) ≈ expected "SIMT naive incorrect!" - min_t, mean_t = benchmark_kernel(simt_naive_f) - push!(results, BenchmarkResult("SIMT naive", min_t, mean_t)) - - # SIMT shared - fill!(y, 0) - threads_shared = (32, 32) - blocks_shared = (cld(M, 32), cld(N, 32)) - simt_shared_f = () -> @cuda threads=threads_shared blocks=blocks_shared transpose_simt_shared_kernel!(x, y, M, N) - simt_shared_f() - CUDA.synchronize() - @assert Array(y) ≈ expected "SIMT shared incorrect!" 
- min_t, mean_t = benchmark_kernel(simt_shared_f) - push!(results, BenchmarkResult("SIMT shared", min_t, mean_t)) - - # cuTile (using transpose.jl's run/verify functions) - result = transpose_run(data; tm=TRANSPOSE_TILE_M, tn=TRANSPOSE_TILE_N, nruns=NRUNS, warmup=WARMUP) - transpose_verify(data, result) - min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate bandwidth - bytes = 2 * M * N * sizeof(Float32) # read + write - bandwidths = [string(round(bytes / (r.min_ms / 1000) / 1e9, digits=1), " GB/s") for r in results] - - print_table("Matrix Transpose (Float32)", results; extra_col=("Bandwidth", bandwidths)) - return results -end - -#============================================================================= - Matrix Multiplication -=============================================================================# - -function benchmark_matmul() - println("\nBenchmarking Matrix Multiplication...") - M, N, K = MATMUL_DIM, MATMUL_DIM, MATMUL_DIM - println(" Size: $(M)x$(K) * $(K)x$(N)") - - # Prepare data once (using matmul.jl's prepare function) - data = matmul_prepare(; M, K, N, T=Float32) - (; A, B, C) = data - - # Reference result (cuBLAS) - C_ref = similar(C) - mul!(C_ref, A, B) - CUDA.synchronize() - - results = BenchmarkResult[] - flops = 2.0 * M * N * K - - # GPUArrays (generic matmul) - gpuarrays_f = () -> GPUArrays.generic_matmatmul!(C, A, B, one(Float32), zero(Float32)) - gpuarrays_f() - CUDA.synchronize() - @assert isapprox(Array(C), Array(C_ref), rtol=1e-2, atol=1e-2) "GPUArrays incorrect!" - min_t, mean_t = benchmark_kernel(gpuarrays_f) - push!(results, BenchmarkResult("GPUArrays", min_t, mean_t)) - - # cuBLAS - fill!(C, 0) - cublas_f = () -> mul!(C, A, B) - cublas_f() - CUDA.synchronize() - min_t, mean_t = benchmark_kernel(cublas_f) - push!(results, BenchmarkResult("cuBLAS", min_t, mean_t)) - - # cuTile (using matmul.jl's run/verify functions) - result = matmul_run(data; tm=MATMUL_TM, tn=MATMUL_TN, tk=MATMUL_TK, nruns=NRUNS, warmup=WARMUP) - matmul_verify(data, result) - min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate TFLOPS - tflops_vals = [string(round(flops / (r.min_ms * 1e-3) / 1e12, digits=2), " TFLOPS") for r in results] - - print_table("Matrix Multiplication (Float32, TF32 cores)", results; extra_col=("Performance", tflops_vals)) - return results + println("-" ^ 60) end #============================================================================= - Layer Normalization + Benchmark Discovery & Execution =============================================================================# -const LAYERNORM_M = 4096 -const LAYERNORM_N = 4096 -const LAYERNORM_TILE_N = 1024 -const LAYERNORM_EPS = 1f-5 - -# Batch matmul sizes -const BATCHMATMUL_BATCH = 8 -const BATCHMATMUL_M = 1024 -const BATCHMATMUL_K = 512 -const BATCHMATMUL_N = 2048 -const BATCHMATMUL_TM = 128 -const BATCHMATMUL_TN = 256 -const BATCHMATMUL_TK = 64 - -# FFT sizes -# Tile size is (D, BS, N2D), limited by tileiras compiler. -# Current kernel loads all batches per block, limiting scalability. 
-const FFT_BATCH = 64 -const FFT_SIZE = 512 -const FFT_FACTORS = (8, 8, 8) -const FFT_ATOM_PACKING_DIM = 2 - -# SIMT naive kernel (benchmark-specific, 2-pass: compute mean/var, then normalize) -function layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, eps) - m = blockIdx().x - - # First pass: compute mean - mean_acc = 0.0f0 - for i in 1:N - @inbounds mean_acc += X[m, i] +function discover_benchmarks() + examples = String[] + for file in readdir(@__DIR__) + endswith(file, ".jl") || continue + file == "benchmarks.jl" && continue + name = replace(file, ".jl" => "") + push!(examples, name) end - mean = mean_acc / N - @inbounds Mean[m] = mean - - # Second pass: compute variance - var_acc = 0.0f0 - for i in 1:N - @inbounds diff = X[m, i] - mean - var_acc += diff * diff - end - var = var_acc / N - rstd = 1.0f0 / sqrt(var + eps) - @inbounds Rstd[m] = rstd - - # Third pass: normalize and apply affine - for i in 1:N - @inbounds Y[m, i] = (X[m, i] - mean) * rstd * W[i] + B[i] - end - - return + return sort(examples) end -function benchmark_layernorm() - println("\nBenchmarking Layer Normalization...") - M, N = LAYERNORM_M, LAYERNORM_N - println(" Size: $(M)x$(N) ($(M * N * 4 / 1e6) MB)") - - # Prepare data once (using layernorm.jl's unified prepare function) - data = layernorm_prepare(; M, N, eps=LAYERNORM_EPS) - (; X, W, B, Y, Mean, Rstd) = data - - # Reference result - X_cpu = Array(X) - W_cpu = Array(W) - B_cpu = Array(B) - expected_mean = vec(sum(X_cpu, dims=2) ./ N) - expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) - expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ LAYERNORM_EPS) - normalized = (X_cpu .- expected_mean) .* expected_rstd - expected_Y = normalized .* W_cpu' .+ B_cpu' - - results = BenchmarkResult[] - - # SIMT naive (single thread per row) - fill!(Y, 0); fill!(Mean, 0); fill!(Rstd, 0) - simt_f = () -> @cuda threads=1 blocks=M layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, LAYERNORM_EPS) - simt_f() - CUDA.synchronize() - @assert isapprox(Array(Y), expected_Y, rtol=1e-2, atol=1e-2) "SIMT incorrect!" 
- min_t, mean_t = benchmark_kernel(simt_f) - push!(results, BenchmarkResult("SIMT naive", min_t, mean_t)) - - # cuTile (using layernorm.jl's unified run/verify functions) - result = layernorm_run(data; TILE_N=LAYERNORM_TILE_N, nruns=NRUNS, warmup=WARMUP) - layernorm_verify(data, result) - - # Forward pass timing - min_t_fwd = minimum(result.times_fwd) - mean_t_fwd = sum(result.times_fwd) / length(result.times_fwd) - push!(results, BenchmarkResult("cuTile Fwd", min_t_fwd, mean_t_fwd)) - - # Backward pass timing - min_t_bwd = minimum(result.times_bwd) - mean_t_bwd = sum(result.times_bwd) / length(result.times_bwd) - - # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y) - bytes = (3 * M * N + N + N + M * N) * sizeof(Float32) - bandwidths = [string(round(bytes / (r.min_ms / 1000) / 1e9, digits=1), " GB/s") for r in results] - - print_table("Layer Normalization Forward (Float32)", results; extra_col=("Bandwidth", bandwidths)) +function run_benchmark(name::String) + file = joinpath(@__DIR__, name * ".jl") - # Print backward results separately - bwd_results = [BenchmarkResult("cuTile Bwd", min_t_bwd, mean_t_bwd)] - print_table("Layer Normalization Backward (Float32)", bwd_results) + # Include file in anonymous module to avoid polluting namespace + mod = Module() + Base.include(mod, file) - return results -end - -#============================================================================= - Batch Matrix Multiplication -=============================================================================# + # Check required functions exist (unprefixed) + isdefined(mod, :prepare) || return nothing + isdefined(mod, :run) || return nothing -function benchmark_batchmatmul() - println("\nBenchmarking Batch Matrix Multiplication...") - Batch, M, K, N = BATCHMATMUL_BATCH, BATCHMATMUL_M, BATCHMATMUL_K, BATCHMATMUL_N - println(" Size: ($M x $K x $Batch) @ ($K x $N x $Batch), Float16") + # Prepare data with benchmark=true for larger sizes + data = mod.prepare(; benchmark=true) - # Prepare data once (using batchmatmul.jl's prepare function) - data = batchmatmul_prepare(; M, K, N, Batch, T=Float16) - (; A, B, C) = data + # Run cuTile + result = mod.run(data; nruns=NRUNS, warmup=WARMUP) - # Reference result (batched matmul on CPU) - A_cpu = Float32.(Array(A)) - B_cpu = Float32.(Array(B)) - C_ref = zeros(Float32, M, N, Batch) - for b in 1:Batch - C_ref[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b] + # Extract times (handle times_fwd/times_bwd for layernorm) + if hasproperty(result, :times) + results = Dict{String, Vector{Float64}}("cuTile" => result.times) + elseif hasproperty(result, :times_fwd) + results = Dict{String, Vector{Float64}}( + "cuTile Fwd" => result.times_fwd, + "cuTile Bwd" => result.times_bwd + ) + else + return nothing end - results = BenchmarkResult[] - flops = 2.0 * Batch * M * N * K - - # cuBLAS batched gemm (via loop) - fill!(C, 0) - cublas_f = () -> begin - for b in 1:Batch - mul!(view(C, :, :, b), view(A, :, :, b), view(B, :, :, b)) - end + # Run others if available + if isdefined(mod, :run_others) + others = mod.run_others(data; nruns=NRUNS, warmup=WARMUP) + merge!(results, others) end - cublas_f() - CUDA.synchronize() - @assert isapprox(Float32.(Array(C)), C_ref, rtol=1e-1, atol=1e-1) "cuBLAS incorrect!" 
- min_t, mean_t = benchmark_kernel(cublas_f) - push!(results, BenchmarkResult("cuBLAS (loop)", min_t, mean_t)) - # cuTile (using batchmatmul.jl's run/verify functions) - result = batchmatmul_run(data; tm=BATCHMATMUL_TM, tn=BATCHMATMUL_TN, tk=BATCHMATMUL_TK, - nruns=NRUNS, warmup=WARMUP) - batchmatmul_verify(data, result) - min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Calculate TFLOPS - tflops_vals = [string(round(flops / (r.min_ms * 1e-3) / 1e12, digits=2), " TFLOPS") for r in results] - - print_table("Batch Matrix Multiplication (Float16)", results; extra_col=("Performance", tflops_vals)) - return results -end - -#============================================================================= - FFT (3-stage Cooley-Tukey) - Column-Major Version -=============================================================================# - -function benchmark_fft() - println("\nBenchmarking FFT...") - BS, N = FFT_BATCH, FFT_SIZE - println(" Size: $BS batches × $N FFT ($(BS * N * 8 / 1e6) MB)") - - # Prepare data once (using fft.jl's prepare function) - data = fft_prepare(; batch=BS, n=N, factors=FFT_FACTORS, atom_packing_dim=FFT_ATOM_PACKING_DIM) - - results = BenchmarkResult[] - - # cuTile (using fft.jl's run/verify functions) - result = fft_run(data; nruns=NRUNS, warmup=WARMUP) - fft_verify(data, result) - min_t, mean_t = minimum(result.times), sum(result.times) / length(result.times) - push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t)) - - # Performance metric: GFLOPS (5 * N * log2(N) per complex FFT) - flops_per_fft = 5.0 * N * log2(N) - total_flops = BS * flops_per_fft - gflops = [string(round(total_flops / (r.min_ms * 1e-3) / 1e9, digits=1), " GFLOPS") for r in results] - - print_table("FFT (ComplexF32)", results; extra_col=("Performance", gflops)) return results end @@ -485,20 +95,35 @@ end function main() println("=" ^ 60) - println(" cuTile.jl Comprehensive Benchmarks") + println(" cuTile.jl Benchmarks") println("=" ^ 60) println() println("Configuration:") println(" Runs: $NRUNS (+ $WARMUP warmup)") println(" GPU: ", CUDA.name(CUDA.device())) - println() - vadd_results = benchmark_vadd() - transpose_results = benchmark_transpose() - matmul_results = benchmark_matmul() - layernorm_results = benchmark_layernorm() - batchmatmul_results = benchmark_batchmatmul() - fft_results = benchmark_fft() + for name in discover_benchmarks() + println("\nBenchmarking $name...") + + results = run_benchmark(name) + if results === nothing + println(" (skipped - no prepare/run functions)") + continue + end + + # Convert to BenchmarkResult for printing + benchmark_results = BenchmarkResult[] + for (impl_name, times) in results + min_t = minimum(times) + mean_t = sum(times) / length(times) + push!(benchmark_results, BenchmarkResult(impl_name, min_t, mean_t)) + end + + # Sort by min time + sort!(benchmark_results, by=r -> r.min_ms) + + print_table(name, benchmark_results) + end println() println("=" ^ 60) diff --git a/examples/benchmarks.py b/examples/benchmarks.py index fd18ca2..5ff588a 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -1,23 +1,12 @@ #!/usr/bin/env python3 -""" -Comprehensive benchmarks for cuTile Python -Compares: CuPy, PyTorch, cuTile -Kernels: vadd, transpose, matmul -""" +# EXCLUDE FROM TESTING +# +# Generic benchmark runner for cuTile Python examples +# Discovers and benchmarks all examples in the examples/ directory +import os +import importlib.util import cupy as cp 
-import numpy as np -import torch -import cuda.tile as ct -from math import ceil, log2 - -# Import prepare/run/verify functions from example files -from vadd import vadd_prepare, vadd_run, vadd_verify -from transpose import transpose_prepare, transpose_run, transpose_verify -from matmul import matmul_prepare, matmul_run, matmul_verify -from layernorm import layernorm_prepare, layernorm_run, layernorm_verify -from batchmatmul import batchmatmul_prepare, batchmatmul_run, batchmatmul_verify -from fft import fft_prepare, fft_run, fft_verify #============================================================================= # Configuration @@ -26,39 +15,6 @@ NRUNS = 10 WARMUP = 3 -# Data sizes - large enough to saturate GPU and minimize launch overhead -VADD_SIZE = 2**27 # 512 MB (128M elements) -TRANSPOSE_DIM = 8192 # 8192x8192 = 268 MB -MATMUL_DIM = 4096 # 4096x4096x4096 - -# FFT sizes - must match Julia configuration -FFT_BATCH = 64 -FFT_SIZE = 512 -FFT_FACTORS = (8, 8, 8) - -# Tile sizes -VADD_TILE = 1024 -TRANSPOSE_TILE_M = 64 -TRANSPOSE_TILE_N = 64 -MATMUL_TM = 64 -MATMUL_TN = 64 -MATMUL_TK = 64 - -# Layer norm sizes -LAYERNORM_M = 4096 -LAYERNORM_N = 4096 -LAYERNORM_TILE_N = 1024 -LAYERNORM_EPS = 1e-5 - -# Batch matmul sizes -BATCHMATMUL_BATCH = 8 -BATCHMATMUL_M = 1024 -BATCHMATMUL_K = 512 -BATCHMATMUL_N = 2048 -BATCHMATMUL_TM = 128 -BATCHMATMUL_TN = 256 -BATCHMATMUL_TK = 64 - #============================================================================= # Benchmark Utilities #============================================================================= @@ -70,423 +26,76 @@ def __init__(self, name: str, min_ms: float, mean_ms: float): self.mean_ms = mean_ms -def benchmark_cupy(f, nruns: int = NRUNS, warmup: int = WARMUP): - """Benchmark a CuPy function using CUDA events.""" - stream = cp.cuda.get_current_stream() - - # Warmup - for _ in range(warmup): - f() - cp.cuda.runtime.deviceSynchronize() - - # Benchmark - times = [] - for _ in range(nruns): - start = cp.cuda.Event() - end = cp.cuda.Event() - - start.record(stream) - f() - end.record(stream) - end.synchronize() - - elapsed_ms = cp.cuda.get_elapsed_time(start, end) - times.append(elapsed_ms) - - return min(times), sum(times) / len(times) - - -def benchmark_torch(f, nruns: int = NRUNS, warmup: int = WARMUP): - """Benchmark a PyTorch function using CUDA events.""" - # Warmup - for _ in range(warmup): - f() - torch.cuda.synchronize() - - # Benchmark - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - times = [] - for _ in range(nruns): - start_event.record() - f() - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - times.append(elapsed_ms) - - return min(times), sum(times) / len(times) - - -def print_table(title: str, results: list, extra_col=None): +def print_table(title: str, results: list): """Print formatted benchmark results table.""" print() print("=" * 60) print(f" {title}") print("=" * 60) - - if extra_col: - print(f"{'Implementation':<20}{'Min (ms)':<12}{'Mean (ms)':<12}{extra_col[0]}") - print("-" * 60) - for i, r in enumerate(results): - print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:<12.3f}{extra_col[1][i]}") - else: - print(f"{'Implementation':<20}{'Min (ms)':<12}Mean (ms)") - print("-" * 60) - for r in results: - print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:.3f}") + print(f"{'Implementation':<20}{'Min (ms)':<12}Mean (ms)") + print("-" * 60) + for r in results: + print(f"{r.name:<20}{r.min_ms:<12.3f}{r.mean_ms:.3f}") 
print("-" * 60) #============================================================================= -# Vector Addition -#============================================================================= - -def benchmark_vadd(): - print("\nBenchmarking Vector Addition...") - print(f" Size: {VADD_SIZE} elements ({VADD_SIZE * 4 / 1e6} MB)") - - # CuPy arrays for CuPy/PyTorch benchmarks - a_cp = cp.random.rand(VADD_SIZE).astype(np.float32) - b_cp = cp.random.rand(VADD_SIZE).astype(np.float32) - c_cp = cp.zeros(VADD_SIZE, dtype=np.float32) - - # PyTorch tensors (from same data) - a_torch = torch.as_tensor(a_cp, device='cuda') - b_torch = torch.as_tensor(b_cp, device='cuda') - c_torch = torch.zeros(VADD_SIZE, dtype=torch.float32, device='cuda') - - expected = cp.asnumpy(a_cp) + cp.asnumpy(b_cp) - results = [] - - # CuPy - def cupy_vadd(): - cp.add(a_cp, b_cp, out=c_cp) - - cupy_vadd() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(c_cp), expected), "CuPy incorrect!" - min_t, mean_t = benchmark_cupy(cupy_vadd) - results.append(BenchmarkResult("CuPy", min_t, mean_t)) - - # PyTorch - def torch_vadd(): - torch.add(a_torch, b_torch, out=c_torch) - - torch_vadd() - torch.cuda.synchronize() - assert np.allclose(c_torch.cpu().numpy(), expected), "PyTorch incorrect!" - min_t, mean_t = benchmark_torch(torch_vadd) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # cuTile - use prepare/run/verify pattern - data = vadd_prepare(shape=(VADD_SIZE,), dtype=np.float32) - # Copy expected data for apples-to-apples comparison - data["a"] = a_cp - data["b"] = b_cp - data["c"] = cp.zeros(VADD_SIZE, dtype=np.float32) - - result = vadd_run(data, tile=VADD_TILE, nruns=NRUNS, warmup=WARMUP) - vadd_verify(data, result) - min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate bandwidth - bytes_transferred = 3 * VADD_SIZE * 4 # 2 reads + 1 write, float32 - bandwidths = [f"{bytes_transferred / (r.min_ms / 1000) / 1e9:.1f} GB/s" for r in results] - - print_table("Vector Addition (Float32)", results, extra_col=("Bandwidth", bandwidths)) - return results - - -#============================================================================= -# Matrix Transpose -#============================================================================= - -def benchmark_transpose(): - print("\nBenchmarking Matrix Transpose...") - M, N = TRANSPOSE_DIM, TRANSPOSE_DIM - print(f" Size: {M}x{N} ({M * N * 4 / 1e6} MB)") - - # CuPy arrays - input_cp = cp.random.rand(M, N).astype(np.float32) - output_cp = cp.zeros((N, M), dtype=np.float32) - - # PyTorch tensors - input_torch = torch.as_tensor(input_cp, device='cuda') - output_torch = torch.zeros(N, M, dtype=torch.float32, device='cuda') - - expected = cp.asnumpy(input_cp).T - results = [] - - # CuPy - def cupy_transpose(): - cp.copyto(output_cp, input_cp.T) - - cupy_transpose() - cp.cuda.runtime.deviceSynchronize() - assert np.allclose(cp.asnumpy(output_cp), expected), "CuPy incorrect!" - min_t, mean_t = benchmark_cupy(cupy_transpose) - results.append(BenchmarkResult("CuPy", min_t, mean_t)) - - # PyTorch - output_torch.fill_(0) - def torch_transpose(): - output_torch.copy_(input_torch.T) - - torch_transpose() - torch.cuda.synchronize() - assert np.allclose(output_torch.cpu().numpy(), expected), "PyTorch incorrect!" 
- min_t, mean_t = benchmark_torch(torch_transpose) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # cuTile - use prepare/run/verify pattern - data = transpose_prepare(M=M, N=N, dtype=np.float32) - # Copy input for apples-to-apples comparison - data["input"] = input_cp - data["output"] = cp.zeros((N, M), dtype=np.float32) - - result = transpose_run(data, tile_m=TRANSPOSE_TILE_M, tile_n=TRANSPOSE_TILE_N, - nruns=NRUNS, warmup=WARMUP) - transpose_verify(data, result) - min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate bandwidth - bytes_transferred = 2 * M * N * 4 # read + write, float32 - bandwidths = [f"{bytes_transferred / (r.min_ms / 1000) / 1e9:.1f} GB/s" for r in results] - - print_table("Matrix Transpose (Float32)", results, extra_col=("Bandwidth", bandwidths)) - return results - - -#============================================================================= -# Matrix Multiplication -#============================================================================= - -def benchmark_matmul(): - print("\nBenchmarking Matrix Multiplication...") - M, N, K = MATMUL_DIM, MATMUL_DIM, MATMUL_DIM - print(f" Size: {M}x{K} * {K}x{N}") - - # CuPy arrays (used for cuTile and cuBLAS) - A_cp = cp.random.randn(M, K, dtype=np.float32) - B_cp = cp.random.randn(K, N, dtype=np.float32) - C_cp = cp.zeros((M, N), dtype=np.float32) - - # PyTorch tensors (from same data for fair comparison) - torch.set_float32_matmul_precision("high") # Enable TF32 - A_torch = torch.as_tensor(A_cp, device='cuda') - B_torch = torch.as_tensor(B_cp, device='cuda') - C_torch = torch.zeros(M, N, dtype=torch.float32, device='cuda') - - # Compute reference using CuPy (cuBLAS) for correctness checks - # This avoids TF32 precision differences between PyTorch and CuPy - C_ref_cp = cp.matmul(A_cp, B_cp) - cp.cuda.runtime.deviceSynchronize() - C_ref = cp.asnumpy(C_ref_cp) - - results = [] - flops = 2.0 * M * N * K - - # PyTorch - def torch_matmul(): - torch.matmul(A_torch, B_torch, out=C_torch) - - torch_matmul() - torch.cuda.synchronize() - # PyTorch TF32 vs CuPy cuBLAS may differ, use relaxed tolerance - assert np.allclose(C_torch.cpu().numpy(), C_ref, rtol=1e-1, atol=1e-1), "PyTorch incorrect!" 
- min_t, mean_t = benchmark_torch(torch_matmul) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # CuPy (uses cuBLAS) - this is the reference - def cupy_matmul(): - cp.matmul(A_cp, B_cp, out=C_cp) - - cupy_matmul() - cp.cuda.runtime.deviceSynchronize() - min_t, mean_t = benchmark_cupy(cupy_matmul) - results.append(BenchmarkResult("CuPy (cuBLAS)", min_t, mean_t)) - - # cuTile - use prepare/run/verify pattern - data = matmul_prepare(M=M, N=N, K=K, dtype=np.float32) - # Copy input for apples-to-apples comparison - data["A"] = A_cp - data["B"] = B_cp - data["C"] = cp.zeros((M, N), dtype=np.float32) - - result = matmul_run(data, tm=MATMUL_TM, tn=MATMUL_TN, tk=MATMUL_TK, - nruns=NRUNS, warmup=WARMUP) - matmul_verify(data, result) - min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) - - # Calculate TFLOPS - tflops_vals = [f"{flops / (r.min_ms * 1e-3) / 1e12:.2f} TFLOPS" for r in results] - - print_table("Matrix Multiplication (Float32, TF32 cores)", results, extra_col=("Performance", tflops_vals)) - return results - - -#============================================================================= -# Layer Normalization -#============================================================================= - -def benchmark_layernorm(): - print("\nBenchmarking Layer Normalization...") - M, N = LAYERNORM_M, LAYERNORM_N - print(f" Size: {M}x{N} ({M * N * 4 / 1e6} MB)") - - # cuTile - prepare data - data = layernorm_prepare(M=M, N=N, eps=LAYERNORM_EPS, dtype=np.float32) - - # Extract CuPy/NumPy arrays for other benchmarks - X_cp = data["X"] - W_cp = data["W"] - B_cp = data["B"] - - # PyTorch tensors - X_torch = torch.as_tensor(X_cp, device='cuda') - W_torch = torch.as_tensor(W_cp, device='cuda') - B_torch = torch.as_tensor(B_cp, device='cuda') - Y_torch = torch.zeros(M, N, dtype=torch.float32, device='cuda') - - # Reference result - X_np = cp.asnumpy(X_cp) - W_np = cp.asnumpy(W_cp) - B_np = cp.asnumpy(B_cp) - expected_mean = np.mean(X_np, axis=1, keepdims=True) - expected_var = np.mean((X_np - expected_mean) ** 2, axis=1, keepdims=True) - expected_rstd = 1.0 / np.sqrt(expected_var + LAYERNORM_EPS) - normalized = (X_np - expected_mean) * expected_rstd - expected_Y = normalized * W_np + B_np - - results = [] - - # PyTorch F.layer_norm - def torch_layernorm(): - nonlocal Y_torch - Y_torch = torch.nn.functional.layer_norm(X_torch, (N,), W_torch, B_torch, LAYERNORM_EPS) - - torch_layernorm() - torch.cuda.synchronize() - assert np.allclose(Y_torch.cpu().numpy(), expected_Y, rtol=1e-2, atol=1e-2), "PyTorch incorrect!" 
- min_t, mean_t = benchmark_torch(torch_layernorm) - results.append(BenchmarkResult("PyTorch", min_t, mean_t)) - - # cuTile - use prepare/run/verify pattern (unified fwd+bwd) - result = layernorm_run(data, tile_n=LAYERNORM_TILE_N, nruns=NRUNS, warmup=WARMUP) - layernorm_verify(data, result) - - # Forward pass timing - min_t_fwd = min(result["times_fwd"]) - mean_t_fwd = sum(result["times_fwd"]) / len(result["times_fwd"]) - results.append(BenchmarkResult("cuTile Fwd", min_t_fwd, mean_t_fwd)) - - # Backward pass timing - min_t_bwd = min(result["times_bwd"]) - mean_t_bwd = sum(result["times_bwd"]) / len(result["times_bwd"]) - - # Calculate bandwidth (rough estimate: 3 reads of X + W + B, 1 write of Y) - bytes_transferred = (3 * M * N + N + N + M * N) * 4 - bandwidths = [f"{bytes_transferred / (r.min_ms / 1000) / 1e9:.1f} GB/s" for r in results] - - print_table("Layer Normalization Forward (Float32)", results, extra_col=("Bandwidth", bandwidths)) - - # Print backward results separately - bwd_results = [BenchmarkResult("cuTile Bwd", min_t_bwd, mean_t_bwd)] - print_table("Layer Normalization Backward (Float32)", bwd_results) - - return results - - -#============================================================================= -# Batch Matrix Multiplication +# Benchmark Discovery & Execution #============================================================================= -def benchmark_batchmatmul(): - print("\nBenchmarking Batch Matrix Multiplication...") - Batch, M, K, N = BATCHMATMUL_BATCH, BATCHMATMUL_M, BATCHMATMUL_K, BATCHMATMUL_N - print(f" Size: ({Batch} x {M} x {K}) @ ({Batch} x {K} x {N}), Float16") - - # PyTorch tensors - A_torch = torch.randn(Batch, M, K, dtype=torch.float16, device='cuda') - B_torch = torch.randn(Batch, K, N, dtype=torch.float16, device='cuda') - C_torch = torch.zeros(Batch, M, N, dtype=torch.float16, device='cuda') - - # CuPy arrays (from same data) - A_cp = cp.asarray(A_torch) - B_cp = cp.asarray(B_torch) - - # Reference result (PyTorch bmm in fp32 for accuracy) - C_ref = torch.bmm(A_torch.float(), B_torch.float()).cpu().numpy() - - results = [] - flops = 2.0 * Batch * M * N * K - - # PyTorch bmm - def torch_bmm(): - torch.bmm(A_torch, B_torch, out=C_torch) - - torch_bmm() - torch.cuda.synchronize() - assert np.allclose(C_torch.float().cpu().numpy(), C_ref, rtol=1e-1, atol=1e-1), "PyTorch incorrect!" 
- min_t, mean_t = benchmark_torch(torch_bmm) - results.append(BenchmarkResult("PyTorch bmm", min_t, mean_t)) +def discover_benchmarks(): + """Discover all benchmark-enabled examples in the examples directory.""" + examples = [] + examples_dir = os.path.dirname(__file__) + for file in sorted(os.listdir(examples_dir)): + if not file.endswith(".py"): + continue + if file == "benchmarks.py": + continue + name = file.replace(".py", "") + examples.append(name) + return examples - # cuTile - use prepare/run/verify pattern - data = batchmatmul_prepare(Batch=Batch, M=M, K=K, N=N, dtype=np.float16) - # Copy input for apples-to-apples comparison - data["A"] = A_cp - data["B"] = B_cp - data["C"] = cp.zeros((Batch, M, N), dtype=np.float16) - result = batchmatmul_run(data, tm=BATCHMATMUL_TM, tn=BATCHMATMUL_TN, tk=BATCHMATMUL_TK, - nruns=NRUNS, warmup=WARMUP) - batchmatmul_verify(data, result) - min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) +def run_benchmark(name: str): + """Load and run benchmark for a given example.""" + examples_dir = os.path.dirname(__file__) + file_path = os.path.join(examples_dir, f"{name}.py") - # Calculate TFLOPS - tflops_vals = [f"{flops / (r.min_ms * 1e-3) / 1e12:.2f} TFLOPS" for r in results] + # Import module dynamically + spec = importlib.util.spec_from_file_location(name, file_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) - print_table("Batch Matrix Multiplication (Float16)", results, extra_col=("Performance", tflops_vals)) - return results - - -#============================================================================= -# FFT (3-stage Cooley-Tukey) -#============================================================================= - -def benchmark_fft(): - print("\nBenchmarking FFT...") - BS, N = FFT_BATCH, FFT_SIZE - print(f" Size: {BS} batches × {N} FFT ({BS * N * 8 / 1e6} MB)") - - # cuTile - use prepare/run/verify pattern - data = fft_prepare(batch=BS, size=N, factors=FFT_FACTORS) + # Check required functions exist (unprefixed) + prepare_fn = getattr(mod, "prepare", None) + run_fn = getattr(mod, "run", None) + if not prepare_fn or not run_fn: + return None - # Reference result using torch - reference = torch.fft.fft(data["input"], dim=-1) - torch.cuda.synchronize() + # Prepare data with benchmark=True for larger sizes + data = prepare_fn(benchmark=True) - results = [] + # Run cuTile + result = run_fn(data, nruns=NRUNS, warmup=WARMUP) - # cuTile FFT - result = fft_run(data, nruns=NRUNS, warmup=WARMUP) - fft_verify(data, result) - min_t, mean_t = min(result["times"]), sum(result["times"]) / len(result["times"]) - results.append(BenchmarkResult("cuTile Python", min_t, mean_t)) + # Extract times (handle times_fwd/times_bwd for layernorm) + if "times" in result: + results = {"cuTile": result["times"]} + elif "times_fwd" in result: + results = { + "cuTile Fwd": result["times_fwd"], + "cuTile Bwd": result["times_bwd"] + } + else: + return None - # Calculate GFLOPS (5 * N * log2(N) ops per complex FFT) - flops_per_fft = 5.0 * N * log2(N) - total_flops = BS * flops_per_fft - gflops = [f"{total_flops / (r.min_ms * 1e-3) / 1e9:.1f} GFLOPS" for r in results] + # Run others if available + run_others_fn = getattr(mod, "run_others", None) + if run_others_fn: + others = run_others_fn(data, nruns=NRUNS, warmup=WARMUP) + results.update(others) - print_table("FFT (ComplexF32)", results, extra_col=("Performance", gflops)) return results @@ -495,21 
+104,35 @@ def benchmark_fft():
 #=============================================================================
 
 def main():
+    import torch  # For GPU name
+
     print("=" * 60)
-    print(" cuTile Python Comprehensive Benchmarks")
+    print(" cuTile Python Benchmarks")
     print("=" * 60)
     print()
     print("Configuration:")
     print(f"  Runs: {NRUNS} (+ {WARMUP} warmup)")
     print(f"  GPU: {torch.cuda.get_device_name()}")
-    print()
 
-    vadd_results = benchmark_vadd()
-    transpose_results = benchmark_transpose()
-    matmul_results = benchmark_matmul()
-    layernorm_results = benchmark_layernorm()
-    batchmatmul_results = benchmark_batchmatmul()
-    fft_results = benchmark_fft()
+    for name in discover_benchmarks():
+        print(f"\nBenchmarking {name}...")
+
+        results = run_benchmark(name)
+        if results is None:
+            print("  (skipped - no prepare/run functions)")
+            continue
+
+        # Convert to BenchmarkResult for printing
+        benchmark_results = []
+        for impl_name, times in results.items():
+            min_t = min(times)
+            mean_t = sum(times) / len(times)
+            benchmark_results.append(BenchmarkResult(impl_name, min_t, mean_t))
+
+        # Sort by min time
+        benchmark_results.sort(key=lambda r: r.min_ms)
+
+        print_table(name, benchmark_results)
 
     print()
     print("=" * 60)
diff --git a/examples/fft.jl b/examples/fft.jl
index c07dff1..24c8648 100644
--- a/examples/fft.jl
+++ b/examples/fft.jl
@@ -215,7 +215,11 @@ end
     Example harness
 =============================================================================#
 
-function fft_prepare(; batch::Int, n::Int, factors::NTuple{3,Int}, atom_packing_dim::Int=2)
+function prepare(; benchmark::Bool=false,
+                 batch::Int=benchmark ? 64 : 2,
+                 n::Int=benchmark ? 512 : 8,
+                 factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2),
+                 atom_packing_dim::Int=2)
     @assert factors[1] * factors[2] * factors[3] == n "Factors must multiply to N"
     @assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim"
 
@@ -243,7 +247,7 @@ function fft_prepare(; batch::Int, n::Int, factors::NTuple{3,Int}, atom_packing_
     )
 end
 
-function fft_run(data; nruns::Int=1, warmup::Int=0)
+function run(data; nruns::Int=1, warmup::Int=0)
     (; x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
        factors, batch, n, D, N2D) = data
 
@@ -253,7 +257,7 @@ function fft_run(data; nruns::Int=1, warmup::Int=0)
     F0F2 = F0 * F2
     grid = (batch, 1, 1)
 
-    for _ in 1:warmup
+    CUDA.@sync for _ in 1:warmup
         ct.launch(fft_kernel, grid,
                   x_packed, y_packed,
                   W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
@@ -261,7 +265,6 @@ function fft_run(data; nruns::Int=1, warmup::Int=0)
                   ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2),
                   ct.Constant(batch), ct.Constant(D), ct.Constant(N2D))
     end
-    CUDA.synchronize()
 
     times = Float64[]
     for _ in 1:nruns
@@ -281,11 +284,34 @@ function fft_run(data; nruns::Int=1, warmup::Int=0)
     return (; output, times)
 end
 
-function fft_verify(data, result)
+function verify(data, result)
     reference = FFTW.fft(Array(data.input), 2)
     @assert isapprox(Array(result.output), reference, rtol=1e-4)
 end
 
+#=============================================================================
+    Reference implementations for benchmarking
+=============================================================================#
+
+function run_others(data; nruns::Int=1, warmup::Int=0)
+    (; input) = data
+    results = Dict{String, Vector{Float64}}()
+
+    # cuFFT via CUDA.CUFFT; copy the input each run since fft! works in place
+    CUDA.@sync for _ in 1:warmup
+        CUDA.CUFFT.fft!(copy(input), 2)
+    end
+    times_cufft = Float64[]
+    for _ in 1:nruns
+        input_copy = copy(input)
+        t = CUDA.@elapsed 
CUDA.CUFFT.fft!(input_copy, 2) + push!(times_cufft, t * 1000) + end + results["cuFFT"] = times_cufft + + return results +end + #============================================================================= Main =============================================================================# @@ -306,13 +334,13 @@ function main() println(" Atom Packing Dim: $ATOM_PACKING_DIM") # Use prepare/run/verify pattern - data = fft_prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM) + data = prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM) println("\nInput data shape: $(size(data.input)), dtype: $(eltype(data.input))") - result = fft_run(data) + result = run(data) println("cuTile FFT Output shape: $(size(result.output)), dtype: $(eltype(result.output))") - fft_verify(data, result) + verify(data, result) println("\n✓ Correctness check PASSED") println("\n--- cuTile FFT example execution complete ---") diff --git a/examples/fft.py b/examples/fft.py index 0e35e6d..8c60ab6 100644 --- a/examples/fft.py +++ b/examples/fft.py @@ -111,10 +111,16 @@ def fft_make_twiddles(factors, precision, device): # Example harness #============================================================================= -def fft_prepare(*, batch: int, size: int, factors: tuple, atom_packing_dim: int = 2): +def prepare(*, benchmark: bool = False, batch: int = None, size: int = None, factors: tuple = None, atom_packing_dim: int = 2): """Allocate and initialize data for FFT.""" + if batch is None: + batch = 64 if benchmark else 2 + if factors is None: + factors = (8, 8, 8) if benchmark else (2, 2, 2) F0, F1, F2 = factors N = F0 * F1 * F2 + if size is None: + size = N assert size == N, f"size ({size}) must equal product of factors ({N})" D = atom_packing_dim @@ -140,7 +146,7 @@ def fft_prepare(*, batch: int, size: int, factors: tuple, atom_packing_dim: int } -def fft_run(data, *, nruns: int = 1, warmup: int = 0): +def run(data, *, nruns: int = 1, warmup: int = 0): """Run FFT kernel with timing.""" x_packed = data["x_packed"] y_packed = data["y_packed"] @@ -173,13 +179,41 @@ def fft_run(data, *, nruns: int = 1, warmup: int = 0): return {"output": output, "times": times} -def fft_verify(data, result): +def verify(data, result): """Verify FFT results.""" reference = torch.fft.fft(data["input"], dim=-1) assert torch.allclose(result["output"], reference, rtol=1e-3, atol=1e-3), \ f"FFT incorrect! 
max diff: {torch.max(torch.abs(result['output'] - reference))}" +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + results = {} + input_data = data["input"] + + # PyTorch FFT (uses cuFFT) + for _ in range(warmup): + torch.fft.fft(input_data, dim=-1) + torch.cuda.synchronize() + + times_torch = [] + for _ in range(nruns): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.fft.fft(input_data, dim=-1) + end.record() + torch.cuda.synchronize() + times_torch.append(start.elapsed_time(end)) + results["cuFFT"] = times_torch + + return results + + #============================================================================= # Main #============================================================================= @@ -188,9 +222,9 @@ def test_fft(batch, size, factors, name=None): """Test FFT with given parameters.""" name = name or f"fft batch={batch}, size={size}, factors={factors}" print(f"--- {name} ---") - data = fft_prepare(batch=batch, size=size, factors=factors) - result = fft_run(data) - fft_verify(data, result) + data = prepare(batch=batch, size=size, factors=factors) + result = run(data) + verify(data, result) print(" passed") diff --git a/examples/layernorm.jl b/examples/layernorm.jl index f368509..dba56ea 100644 --- a/examples/layernorm.jl +++ b/examples/layernorm.jl @@ -273,10 +273,13 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa end #============================================================================= - Unified prepare/run/verify pattern (fwd + bwd) + Example harness =============================================================================# -function layernorm_prepare(; M::Int, N::Int, eps::Float32=1f-5, GROUP_SIZE_M::Int=64) +function prepare(; benchmark::Bool=false, + M::Int=benchmark ? 4096 : 256, + N::Int=benchmark ? 
4096 : 256, + eps::Float32=1f-5, GROUP_SIZE_M::Int=64) return (; # Forward inputs/outputs X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N), @@ -298,7 +301,7 @@ function layernorm_prepare(; M::Int, N::Int, eps::Float32=1f-5, GROUP_SIZE_M::In ) end -function layernorm_run(data; TILE_N::Int, TILE_M::Int=32, nruns::Int=1, warmup::Int=0) +function run(data; TILE_N::Int=1024, TILE_M::Int=32, nruns::Int=1, warmup::Int=0) (; X, W, B, Y, Mean, Rstd, DY, DX, DW_partial, DB_partial, Locks, FINAL_DW, FINAL_DB, M, N, eps, GROUP_SIZE_M) = data @@ -319,11 +322,10 @@ function layernorm_run(data; TILE_N::Int, TILE_M::Int=32, nruns::Int=1, warmup:: end # Warmup - for _ in 1:warmup + CUDA.@sync for _ in 1:warmup run_fwd() run_bwd() end - CUDA.synchronize() # Timed forward runs times_fwd = Float64[] @@ -342,7 +344,7 @@ function layernorm_run(data; TILE_N::Int, TILE_M::Int=32, nruns::Int=1, warmup:: return (; Y, Mean, Rstd, DX, FINAL_DW, FINAL_DB, times_fwd, times_bwd) end -function layernorm_verify(data, result) +function verify(data, result) (; X, W, B, DY, N, eps) = data X_cpu = Array(X) @@ -376,12 +378,14 @@ end function test_layernorm(M, N, TILE_N; TILE_M::Int=32, eps::Float32=1f-5, name=nothing) name = something(name, "layernorm ($M x $N), tile_n=$TILE_N, tile_m=$TILE_M") println("--- $name ---") - data = layernorm_prepare(; M, N, eps) - result = layernorm_run(data; TILE_N, TILE_M) - layernorm_verify(data, result) + data = prepare(; M, N, eps) + result = run(data; TILE_N, TILE_M) + verify(data, result) println(" fwd passed, bwd passed") end +# No run_others for layernorm - no simple reference implementation to compare against + #============================================================================= Main =============================================================================# diff --git a/examples/layernorm.py b/examples/layernorm.py index 9c2f6ca..bf9e65a 100644 --- a/examples/layernorm.py +++ b/examples/layernorm.py @@ -128,8 +128,12 @@ def layernorm_bwd_dwdb_kernel(DW, DB, FINAL_DW, FINAL_DB, TILE_M: ct.Constant[in # Example harness #============================================================================= -def layernorm_prepare(*, M: int, N: int, eps: float = 1e-5, GROUP_SIZE_M: int = 64, dtype=np.float32): +def prepare(*, benchmark: bool = False, M: int = None, N: int = None, eps: float = 1e-5, GROUP_SIZE_M: int = 64, dtype=np.float32): """Allocate all data for forward and backward passes.""" + if M is None: + M = 4096 if benchmark else 256 + if N is None: + N = 4096 if benchmark else 256 return { # Forward inputs/outputs "X": (-2.3 + 0.5 * cp.random.randn(M, N)).astype(dtype), @@ -154,7 +158,7 @@ def layernorm_prepare(*, M: int, N: int, eps: float = 1e-5, GROUP_SIZE_M: int = } -def layernorm_run(data, *, tile_n: int, tile_m: int = 32, nruns: int = 1, warmup: int = 0): +def run(data, *, tile_n: int = 1024, tile_m: int = 32, nruns: int = 1, warmup: int = 0): """Run both forward and backward passes with timing.""" X, W, B, Y = data["X"], data["W"], data["B"], data["Y"] Mean, Rstd = data["Mean"], data["Rstd"] @@ -215,7 +219,7 @@ def run_bwd(): } -def layernorm_verify(data, result): +def verify(data, result): """Verify both forward and backward results.""" X_np = cp.asnumpy(data["X"]) W_np = cp.asnumpy(data["W"]) @@ -250,6 +254,8 @@ def layernorm_verify(data, result): assert np.allclose(cp.asnumpy(result["DB"]), expected_DB, rtol=rtol, atol=atol), \ f"DB mismatch! 
max diff: {np.max(np.abs(cp.asnumpy(result['DB']) - expected_DB))}"
 
+# No run_others for layernorm - no simple reference implementation to compare against
+
 
 #=============================================================================
 # Main
@@ -259,9 +265,9 @@ def test_layernorm(M, N, tile_n, tile_m=32, eps=1e-5, dtype=np.float32, name=Non
     """Test layer normalization (fwd+bwd) with given parameters."""
     name = name or f"layernorm ({M}x{N}), tile_n={tile_n}, tile_m={tile_m}, dtype={dtype.__name__}"
     print(f"--- {name} ---")
-    data = layernorm_prepare(M=M, N=N, eps=eps, dtype=dtype)
-    result = layernorm_run(data, tile_n=tile_n, tile_m=tile_m)
-    layernorm_verify(data, result)
+    data = prepare(M=M, N=N, eps=eps, dtype=dtype)
+    result = run(data, tile_n=tile_n, tile_m=tile_m)
+    verify(data, result)
     print(" fwd passed, bwd passed")
 
 
diff --git a/examples/matmul.jl b/examples/matmul.jl
index ef14c66..c1f04d8 100644
--- a/examples/matmul.jl
+++ b/examples/matmul.jl
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 using CUDA
+using LinearAlgebra
 import cuTile as ct
 
 # 2D swizzle for better L2 cache locality
@@ -66,7 +67,11 @@
     Example harness
 =============================================================================#
 
-function matmul_prepare(; M::Int, N::Int, K::Int, T::DataType=Float32)
+function prepare(; benchmark::Bool=false,
+                 M::Int=benchmark ? 4096 : 256,
+                 N::Int=benchmark ? 4096 : 256,
+                 K::Int=benchmark ? 4096 : 256,
+                 T::DataType=Float32)
     return (;
         A = CUDA.rand(T, M, K),
         B = CUDA.rand(T, K, N),
@@ -75,15 +80,14 @@
-function matmul_run(data; tm::Int, tn::Int, tk::Int, nruns::Int=1, warmup::Int=0)
+function run(data; tm::Int=64, tn::Int=64, tk::Int=64, nruns::Int=1, warmup::Int=0)
     (; A, B, C, M, N, K) = data
     grid = cld(M, tm) * cld(N, tn)
 
-    for _ in 1:warmup
+    CUDA.@sync for _ in 1:warmup
         ct.launch(matmul_kernel, grid, A, B, C,
                   ct.Constant(tm), ct.Constant(tn), ct.Constant(tk))
     end
-    CUDA.synchronize()
 
     times = Float64[]
     for _ in 1:nruns
@@ -95,11 +99,35 @@ function matmul_run(data; tm::Int, tn::Int, tk::Int, nruns::Int=1, warmup::Int=0
     return (; C, times)
 end
 
-function matmul_verify(data, result)
+function verify(data, result)
     expected = Array(data.A) * Array(data.B)
     @assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))"
 end
 
+#=============================================================================
+    Reference implementations for benchmarking
+=============================================================================#
+
+function run_others(data; nruns::Int=1, warmup::Int=0)
+    (; A, B) = data
+    results = Dict{String, Vector{Float64}}()
+
+    C_gpuarrays = similar(A, size(A, 1), size(B, 2))
+
+    # cuBLAS, via LinearAlgebra.mul! (CUDA.jl dispatches mul! on CuArrays to CUBLAS)
+ CUDA.@sync for _ in 1:warmup + mul!(C_gpuarrays, A, B) + end + times_gpuarrays = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed mul!(C_gpuarrays, A, B) + push!(times_gpuarrays, t * 1000) + end + results["cuBLAS"] = times_gpuarrays + + return results +end + #============================================================================= Main =============================================================================# @@ -107,9 +135,9 @@ end function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T name = something(name, "matmul ($M x $K) @ ($K x $N), $T, tiles=$tm x $tn x $tk") println("--- $name ---") - data = matmul_prepare(; M, N, K, T) - result = matmul_run(data; tm, tn, tk) - matmul_verify(data, result) + data = prepare(; M, N, K, T) + result = run(data; tm, tn, tk) + verify(data, result) println(" passed") end diff --git a/examples/matmul.py b/examples/matmul.py index a0ecbf4..26766a4 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -48,8 +48,14 @@ def matmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk # Example harness #============================================================================= -def matmul_prepare(*, M: int, N: int, K: int, dtype=np.float32): +def prepare(*, benchmark: bool = False, M: int = None, N: int = None, K: int = None, dtype=np.float32): """Allocate and initialize data for matmul.""" + if M is None: + M = 4096 if benchmark else 256 + if N is None: + N = 4096 if benchmark else 256 + if K is None: + K = 4096 if benchmark else 256 return { "A": cp.random.randn(M, K).astype(dtype), "B": cp.random.randn(K, N).astype(dtype), @@ -60,7 +66,7 @@ def matmul_prepare(*, M: int, N: int, K: int, dtype=np.float32): } -def matmul_run(data, *, tm: int, tn: int, tk: int, nruns: int = 1, warmup: int = 0): +def run(data, *, tm: int = 64, tn: int = 64, tk: int = 64, nruns: int = 1, warmup: int = 0): """Run matmul kernel with timing.""" A, B, C = data["A"], data["B"], data["C"] M, N = data["M"], data["N"] @@ -89,7 +95,7 @@ def matmul_run(data, *, tm: int, tn: int, tk: int, nruns: int = 1, warmup: int = return {"C": C, "times": times} -def matmul_verify(data, result): +def verify(data, result): """Verify matmul results.""" expected = cp.asnumpy(data["A"]) @ cp.asnumpy(data["B"]) # TF32 has reduced precision @@ -97,6 +103,38 @@ def matmul_verify(data, result): f"matmul incorrect! 
max diff: {np.max(np.abs(cp.asnumpy(result['C']) - expected))}" +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + results = {} + A, B = data["A"], data["B"] + M, N = data["M"], data["N"] + C_cupy = cp.zeros((M, N), dtype=A.dtype) + + stream = cp.cuda.get_current_stream() + + # CuPy matmul (uses cuBLAS) + for _ in range(warmup): + cp.matmul(A, B, out=C_cupy) + cp.cuda.runtime.deviceSynchronize() + + times_cupy = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + cp.matmul(A, B, out=C_cupy) + end.record(stream) + end.synchronize() + times_cupy.append(cp.cuda.get_elapsed_time(start, end)) + results["cuBLAS"] = times_cupy + + return results + + #============================================================================= # Main #============================================================================= @@ -105,9 +143,9 @@ def test_matmul(M, N, K, tm, tn, tk, dtype=np.float32, name=None): """Test matmul with given parameters.""" name = name or f"matmul ({M}x{K}) @ ({K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}" print(f"--- {name} ---") - data = matmul_prepare(M=M, N=N, K=K, dtype=dtype) - result = matmul_run(data, tm=tm, tn=tn, tk=tk) - matmul_verify(data, result) + data = prepare(M=M, N=N, K=K, dtype=dtype) + result = run(data, tm=tm, tn=tn, tk=tk) + verify(data, result) print(" passed") diff --git a/examples/transpose.jl b/examples/transpose.jl index f199675..eb1beb7 100644 --- a/examples/transpose.jl +++ b/examples/transpose.jl @@ -21,7 +21,10 @@ end Example harness =============================================================================# -function transpose_prepare(; m::Int, n::Int, T::DataType=Float32) +function prepare(; benchmark::Bool=false, + m::Int=benchmark ? 8192 : 1024, + n::Int=benchmark ? 
8192 : 512, + T::DataType=Float32) return (; x = CUDA.rand(T, m, n), y = CUDA.zeros(T, n, m), @@ -29,15 +32,14 @@ function transpose_prepare(; m::Int, n::Int, T::DataType=Float32) ) end -function transpose_run(data; tm::Int, tn::Int, nruns::Int=1, warmup::Int=0) +function run(data; tm::Int=64, tn::Int=64, nruns::Int=1, warmup::Int=0) (; x, y, m, n) = data grid = (cld(m, tm), cld(n, tn)) - for _ in 1:warmup + CUDA.@sync for _ in 1:warmup ct.launch(transpose_kernel, grid, x, y, ct.Constant(tm), ct.Constant(tn)) end - CUDA.synchronize() times = Float64[] for _ in 1:nruns @@ -49,10 +51,58 @@ function transpose_run(data; tm::Int, tn::Int, nruns::Int=1, warmup::Int=0) return (; y, times) end -function transpose_verify(data, result) +function verify(data, result) @assert Array(result.y) ≈ transpose(Array(data.x)) end +#============================================================================= + Reference implementations for benchmarking +=============================================================================# + +# Simple SIMT transpose kernel (naive, no shared memory) +function simt_naive_kernel(x, y, m, n) + i = (blockIdx().x - 1) * blockDim().x + threadIdx().x + j = (blockIdx().y - 1) * blockDim().y + threadIdx().y + if i <= m && j <= n + @inbounds y[j, i] = x[i, j] + end + return +end + +function run_others(data; nruns::Int=1, warmup::Int=0) + (; x, m, n) = data + results = Dict{String, Vector{Float64}}() + + y_gpuarrays = similar(x, n, m) + y_simt = similar(x, n, m) + + # GPUArrays (permutedims) + CUDA.@sync for _ in 1:warmup + permutedims!(y_gpuarrays, x, (2, 1)) + end + times_gpuarrays = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed permutedims!(y_gpuarrays, x, (2, 1)) + push!(times_gpuarrays, t * 1000) + end + results["GPUArrays"] = times_gpuarrays + + # SIMT naive kernel + threads = (16, 16) + blocks = (cld(m, threads[1]), cld(n, threads[2])) + CUDA.@sync for _ in 1:warmup + @cuda threads=threads blocks=blocks simt_naive_kernel(x, y_simt, m, n) + end + times_simt = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed @cuda threads=threads blocks=blocks simt_naive_kernel(x, y_simt, m, n) + push!(times_simt, t * 1000) + end + results["SIMT naive"] = times_simt + + return results +end + #============================================================================= Main =============================================================================# @@ -60,9 +110,9 @@ end function test_transpose(::Type{T}, m, n, tm, tn; name=nothing) where T name = something(name, "transpose ($m x $n, $T, tiles=$tm x $tn)") println("--- $name ---") - data = transpose_prepare(; m, n, T) - result = transpose_run(data; tm, tn) - transpose_verify(data, result) + data = prepare(; m, n, T) + result = run(data; tm, tn) + verify(data, result) println("✓ passed") end diff --git a/examples/transpose.py b/examples/transpose.py index 52e4113..61ba891 100644 --- a/examples/transpose.py +++ b/examples/transpose.py @@ -20,8 +20,12 @@ def transpose_cutile_kernel(input, output, tile_m: ct.Constant[int], tile_n: ct. 
# Example harness #============================================================================= -def transpose_prepare(*, M: int, N: int, dtype=np.float32): +def prepare(*, benchmark: bool = False, M: int = None, N: int = None, dtype=np.float32): """Allocate and initialize data for transpose.""" + if M is None: + M = 8192 if benchmark else 1024 + if N is None: + N = 8192 if benchmark else 512 return { "input": cp.random.rand(M, N).astype(dtype), "output": cp.zeros((N, M), dtype=dtype), @@ -30,7 +34,7 @@ def transpose_prepare(*, M: int, N: int, dtype=np.float32): } -def transpose_run(data, *, tile_m: int, tile_n: int, nruns: int = 1, warmup: int = 0): +def run(data, *, tile_m: int = 64, tile_n: int = 64, nruns: int = 1, warmup: int = 0): """Run transpose kernel with timing.""" input_arr = data["input"] output_arr = data["output"] @@ -58,12 +62,44 @@ def transpose_run(data, *, tile_m: int, tile_n: int, nruns: int = 1, warmup: int return {"output": output_arr, "times": times} -def transpose_verify(data, result): +def verify(data, result): """Verify transpose results.""" expected = cp.asnumpy(data["input"]).T assert np.allclose(cp.asnumpy(result["output"]), expected), "transpose incorrect!" +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + results = {} + input_arr = data["input"] + M, N = data["M"], data["N"] + output_cupy = cp.zeros((N, M), dtype=input_arr.dtype) + + stream = cp.cuda.get_current_stream() + + # CuPy transpose + for _ in range(warmup): + cp.copyto(output_cupy, input_arr.T) + cp.cuda.runtime.deviceSynchronize() + + times_cupy = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + cp.copyto(output_cupy, input_arr.T) + end.record(stream) + end.synchronize() + times_cupy.append(cp.cuda.get_elapsed_time(start, end)) + results["CuPy"] = times_cupy + + return results + + #============================================================================= # Main #============================================================================= @@ -72,9 +108,9 @@ def test_transpose(M, N, tile_m, tile_n, dtype=np.float32, name=None): """Test transpose with given parameters.""" name = name or f"transpose ({M}x{N}), tiles={tile_m}x{tile_n}, dtype={dtype.__name__}" print(f"--- {name} ---") - data = transpose_prepare(M=M, N=N, dtype=dtype) - result = transpose_run(data, tile_m=tile_m, tile_n=tile_n) - transpose_verify(data, result) + data = prepare(M=M, N=N, dtype=dtype) + result = run(data, tile_m=tile_m, tile_n=tile_n) + verify(data, result) print(" passed") diff --git a/examples/vadd.jl b/examples/vadd.jl index f4c4e28..f523677 100644 --- a/examples/vadd.jl +++ b/examples/vadd.jl @@ -48,7 +48,9 @@ end # Example harness =============================================================================# -function vadd_prepare(; shape::Tuple, use_gather::Bool=false, T::DataType=Float32) +function prepare(; benchmark::Bool=false, + shape::Tuple=benchmark ? 
(2^27,) : (1_024_000,), + use_gather::Bool=false, T::DataType=Float32) return (; a = CUDA.rand(T, shape...), b = CUDA.rand(T, shape...), @@ -58,7 +60,7 @@ function vadd_prepare(; shape::Tuple, use_gather::Bool=false, T::DataType=Float3 ) end -function vadd_run(data; tile::Union{Int, Tuple{Int,Int}}, nruns::Int=1, warmup::Int=0) +function run(data; tile::Union{Int, Tuple{Int,Int}}=1024, nruns::Int=1, warmup::Int=0) (; a, b, c, shape, use_gather) = data if length(shape) == 2 @@ -67,11 +69,10 @@ function vadd_run(data; tile::Union{Int, Tuple{Int,Int}}, nruns::Int=1, warmup:: tile_x, tile_y = tile isa Tuple ? tile : (tile, tile) grid = (cld(m, tile_x), cld(n, tile_y)) - for _ in 1:warmup + CUDA.@sync for _ in 1:warmup ct.launch(vec_add_kernel_2d, grid, a, b, c, ct.Constant(tile_x), ct.Constant(tile_y)) end - CUDA.synchronize() times = Float64[] for _ in 1:nruns @@ -86,10 +87,9 @@ function vadd_run(data; tile::Union{Int, Tuple{Int,Int}}, nruns::Int=1, warmup:: grid = cld(n, tile_val) kernel = use_gather ? vec_add_kernel_1d_gather : vec_add_kernel_1d - for _ in 1:warmup + CUDA.@sync for _ in 1:warmup ct.launch(kernel, grid, a, b, c, ct.Constant(tile_val)) end - CUDA.synchronize() times = Float64[] for _ in 1:nruns @@ -101,11 +101,62 @@ function vadd_run(data; tile::Union{Int, Tuple{Int,Int}}, nruns::Int=1, warmup:: return (; c, times) end -function vadd_verify(data, result) +function verify(data, result) @assert Array(result.c) ≈ Array(data.a) + Array(data.b) end +#============================================================================= +# Reference implementations for benchmarking +=============================================================================# + +# Simple SIMT kernel for comparison +function simt_kernel(a, b, c, n) + i = (blockIdx().x - 1) * blockDim().x + threadIdx().x + if i <= n + @inbounds c[i] = a[i] + b[i] + end + return +end + +function run_others(data; nruns::Int=1, warmup::Int=0) + (; a, b, c, shape) = data + results = Dict{String, Vector{Float64}}() + + if length(shape) == 1 + n = shape[1] + c_gpuarrays = similar(c) + c_simt = similar(c) + + # GPUArrays (broadcasting) + CUDA.@sync for _ in 1:warmup + c_gpuarrays .= a .+ b + end + times_gpuarrays = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed c_gpuarrays .= a .+ b + push!(times_gpuarrays, t * 1000) + end + results["GPUArrays"] = times_gpuarrays + + # SIMT kernel + threads = 256 + blocks = cld(n, threads) + CUDA.@sync for _ in 1:warmup + @cuda threads=threads blocks=blocks simt_kernel(a, b, c_simt, n) + end + times_simt = Float64[] + for _ in 1:nruns + t = CUDA.@elapsed @cuda threads=threads blocks=blocks simt_kernel(a, b, c_simt, n) + push!(times_simt, t * 1000) + end + results["SIMT"] = times_simt + end + + return results +end + + #============================================================================= # Main =============================================================================# @@ -121,9 +172,9 @@ function test_vadd(shape, tile; use_gather::Bool=false, T::DataType=Float32, nam end end println("--- $name ---") - data = vadd_prepare(; shape, use_gather, T) - result = vadd_run(data; tile) - vadd_verify(data, result) + data = prepare(; shape, use_gather, T) + result = run(data; tile) + verify(data, result) println(" passed") end diff --git a/examples/vadd.py b/examples/vadd.py index 38d0ad0..ccbc058 100644 --- a/examples/vadd.py +++ b/examples/vadd.py @@ -48,8 +48,10 @@ def vadd_kernel_1d_gather(a, b, c, tile_size: ct.Constant[int]): # Example harness 
#============================================================================= -def vadd_prepare(*, shape: tuple, use_gather: bool = False, dtype=np.float32): +def prepare(*, benchmark: bool = False, shape: tuple = None, use_gather: bool = False, dtype=np.float32): """Allocate and initialize data for vector addition.""" + if shape is None: + shape = (2**27,) if benchmark else (1_024_000,) return { "a": cp.random.rand(*shape).astype(dtype), "b": cp.random.rand(*shape).astype(dtype), @@ -59,7 +61,7 @@ def vadd_prepare(*, shape: tuple, use_gather: bool = False, dtype=np.float32): } -def vadd_run(data, *, tile, nruns: int = 1, warmup: int = 0): +def run(data, *, tile=1024, nruns: int = 1, warmup: int = 0): """Run vector addition kernel with timing.""" a, b, c = data["a"], data["b"], data["c"] shape = data["shape"] @@ -107,12 +109,46 @@ def run_kernel(): return {"c": c, "times": times} -def vadd_verify(data, result): +def verify(data, result): """Verify vector addition results.""" expected = cp.asnumpy(data["a"]) + cp.asnumpy(data["b"]) assert np.allclose(cp.asnumpy(result["c"]), expected), "vadd incorrect!" +#============================================================================= +# Reference implementations for benchmarking +#============================================================================= + +def run_others(data, *, nruns: int = 1, warmup: int = 0): + """Run reference implementations for comparison.""" + results = {} + shape = data["shape"] + + if len(shape) == 1: + a, b = data["a"], data["b"] + c_cupy = cp.zeros_like(a) + + stream = cp.cuda.get_current_stream() + + # CuPy (broadcasting) + for _ in range(warmup): + cp.add(a, b, out=c_cupy) + cp.cuda.runtime.deviceSynchronize() + + times_cupy = [] + for _ in range(nruns): + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + cp.add(a, b, out=c_cupy) + end.record(stream) + end.synchronize() + times_cupy.append(cp.cuda.get_elapsed_time(start, end)) + results["CuPy"] = times_cupy + + return results + + #============================================================================= # Main #============================================================================= @@ -127,9 +163,9 @@ def test_vadd(shape, tile, use_gather=False, dtype=np.float32, name=None): else: name = f"1D vadd size={shape[0]}, tile={tile}, dtype={dtype.__name__}" print(f"--- {name} ---") - data = vadd_prepare(shape=shape, use_gather=use_gather, dtype=dtype) - result = vadd_run(data, tile=tile) - vadd_verify(data, result) + data = prepare(shape=shape, use_gather=use_gather, dtype=dtype) + result = run(data, tile=tile) + verify(data, result) print(" passed") From 5f580a1f485d4720b4594cded0fa3e4f008f3efc Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 14 Jan 2026 11:21:11 -0500 Subject: [PATCH 7/7] Don't allocate zeros when we don't need to. 
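
prepare() used to zero-fill every output buffer, but the kernels fully
overwrite C, Y, DX, y_packed, and friends anyway, so allocate those
uninitialized and skip the fill kernel:

    C = CuArray{T}(undef, M, N)        # was: CUDA.zeros(T, M, N)
    C = cp.empty((M, N), dtype=dtype)  # was: cp.zeros((M, N), dtype=dtype)

The exception is layernorm's Locks buffer: the backward pass uses it as a
spinlock and treats 0 as "unlocked", so it keeps its zero initialization.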
---
 examples/batchmatmul.jl |  2 +-
 examples/batchmatmul.py |  2 +-
 examples/fft.jl         |  2 +-
 examples/layernorm.jl   | 18 +++++++++---------
 examples/layernorm.py   | 18 +++++++++---------
 examples/matmul.jl      |  2 +-
 examples/matmul.py      |  2 +-
 examples/transpose.jl   |  5 +++--
 examples/transpose.py   |  2 +-
 examples/vadd.jl        |  5 +++--
 examples/vadd.py        |  5 +++--
 11 files changed, 33 insertions(+), 30 deletions(-)

diff --git a/examples/batchmatmul.jl b/examples/batchmatmul.jl
index dd2b20e..a9c2186 100644
--- a/examples/batchmatmul.jl
+++ b/examples/batchmatmul.jl
@@ -70,7 +70,7 @@ function prepare(; benchmark::Bool=false,
     return (;
         A = CUDA.rand(T, M, K, Batch),
         B = CUDA.rand(T, K, N, Batch),
-        C = CUDA.zeros(T, M, N, Batch),
+        C = CuArray{T}(undef, M, N, Batch),
         M, K, N, Batch
     )
 end
diff --git a/examples/batchmatmul.py b/examples/batchmatmul.py
index 95b93db..07bce7e 100644
--- a/examples/batchmatmul.py
+++ b/examples/batchmatmul.py
@@ -53,7 +53,7 @@ def prepare(*, benchmark: bool = False, Batch: int = None, M: int = None, K: int
     return {
         "A": cp.random.randn(Batch, M, K).astype(dtype),
         "B": cp.random.randn(Batch, K, N).astype(dtype),
-        "C": cp.zeros((Batch, M, N), dtype=dtype),
+        "C": cp.empty((Batch, M, N), dtype=dtype),
         "Batch": Batch,
         "M": M,
         "K": K,
diff --git a/examples/fft.jl b/examples/fft.jl
index 24c8648..be1c87b 100644
--- a/examples/fft.jl
+++ b/examples/fft.jl
@@ -238,7 +238,7 @@ function prepare(; benchmark::Bool=false,
     D = atom_packing_dim
     N2D = n * 2 ÷ D
     x_packed = reinterpret(reshape, Float32, input)
-    y_packed = CUDA.zeros(Float32, D, batch, N2D)
+    y_packed = CuArray{Float32}(undef, D, batch, N2D)
 
     return (;
         input, x_packed, y_packed,
diff --git a/examples/layernorm.jl b/examples/layernorm.jl
index dba56ea..659ec6c 100644
--- a/examples/layernorm.jl
+++ b/examples/layernorm.jl
@@ -285,17 +285,17 @@ function prepare(; benchmark::Bool=false,
         X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N),
         W = CUDA.randn(Float32, N),
         B = CUDA.randn(Float32, N),
-        Y = CUDA.zeros(Float32, M, N),
-        Mean = CUDA.zeros(Float32, M),
-        Rstd = CUDA.zeros(Float32, M),
+        Y = CuArray{Float32}(undef, M, N),
+        Mean = CuArray{Float32}(undef, M),
+        Rstd = CuArray{Float32}(undef, M),
         # Backward inputs/outputs
         DY = 0.1f0 .* CUDA.randn(Float32, M, N),
-        DX = CUDA.zeros(Float32, M, N),
-        DW_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N),
-        DB_partial = CUDA.zeros(Float32, GROUP_SIZE_M, N),
-        Locks = CUDA.zeros(Int, GROUP_SIZE_M),
-        FINAL_DW = CUDA.zeros(Float32, N),
-        FINAL_DB = CUDA.zeros(Float32, N),
+        DX = CuArray{Float32}(undef, M, N),
+        DW_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
+        DB_partial = CuArray{Float32}(undef, GROUP_SIZE_M, N),
+        Locks = CUDA.zeros(Int, GROUP_SIZE_M),  # keep zeroed: the bwd spinlock treats 0 as unlocked
+        FINAL_DW = CuArray{Float32}(undef, N),
+        FINAL_DB = CuArray{Float32}(undef, N),
         # Metadata
         M, N, eps, GROUP_SIZE_M
     )
diff --git a/examples/layernorm.py b/examples/layernorm.py
index bf9e65a..b65e136 100644
--- a/examples/layernorm.py
+++ b/examples/layernorm.py
@@ -139,17 +139,17 @@ def prepare(*, benchmark: bool = False, M: int = None, N: int = None, eps: float
         "X": (-2.3 + 0.5 * cp.random.randn(M, N)).astype(dtype),
         "W": cp.random.randn(N).astype(dtype),
         "B": cp.random.randn(N).astype(dtype),
-        "Y": cp.zeros((M, N), dtype=dtype),
-        "Mean": cp.zeros(M, dtype=np.float32),
-        "Rstd": cp.zeros(M, dtype=np.float32),
+        "Y": cp.empty((M, N), dtype=dtype),
+        "Mean": cp.empty(M, dtype=np.float32),
+        "Rstd": cp.empty(M, dtype=np.float32),
         # Backward inputs/outputs
         "DY": (0.1 * cp.random.randn(M, N)).astype(dtype),
-        "DX": cp.zeros((M, N), dtype=dtype),
-        "DW_partial": cp.zeros((GROUP_SIZE_M, N), dtype=np.float32),
-        "DB_partial": cp.zeros((GROUP_SIZE_M, N), dtype=np.float32),
-        "Locks": cp.zeros(GROUP_SIZE_M, dtype=np.int32),
-        "FINAL_DW": cp.zeros(N, dtype=dtype),
-        "FINAL_DB": cp.zeros(N, dtype=dtype),
+        "DX": cp.empty((M, N), dtype=dtype),
+        "DW_partial": cp.empty((GROUP_SIZE_M, N), dtype=np.float32),
+        "DB_partial": cp.empty((GROUP_SIZE_M, N), dtype=np.float32),
+        "Locks": cp.zeros(GROUP_SIZE_M, dtype=np.int32),  # keep zeroed: the bwd spinlock treats 0 as unlocked
+        "FINAL_DW": cp.empty(N, dtype=dtype),
+        "FINAL_DB": cp.empty(N, dtype=dtype),
         # Metadata
         "eps": eps,
         "M": M,
diff --git a/examples/matmul.jl b/examples/matmul.jl
index c1f04d8..47a24e6 100644
--- a/examples/matmul.jl
+++ b/examples/matmul.jl
@@ -75,7 +75,7 @@ function prepare(; benchmark::Bool=false,
     return (;
         A = CUDA.rand(T, M, K),
         B = CUDA.rand(T, K, N),
-        C = CUDA.zeros(T, M, N),
+        C = CuArray{T}(undef, M, N),
         M, N, K
     )
 end
diff --git a/examples/matmul.py b/examples/matmul.py
index 26766a4..ab367cf 100644
--- a/examples/matmul.py
+++ b/examples/matmul.py
@@ -59,7 +59,7 @@ def prepare(*, benchmark: bool = False, M: int = None, N: int = None, K: int = N
     return {
         "A": cp.random.randn(M, K).astype(dtype),
         "B": cp.random.randn(K, N).astype(dtype),
-        "C": cp.zeros((M, N), dtype=dtype),
+        "C": cp.empty((M, N), dtype=dtype),
         "M": M,
         "N": N,
         "K": K
diff --git a/examples/transpose.jl b/examples/transpose.jl
index eb1beb7..773c988 100644
--- a/examples/transpose.jl
+++ b/examples/transpose.jl
@@ -25,9 +25,10 @@ function prepare(; benchmark::Bool=false,
                  m::Int=benchmark ? 8192 : 1024,
                  n::Int=benchmark ? 8192 : 512,
                  T::DataType=Float32)
+    x = CUDA.rand(T, m, n)
     return (;
-        x = CUDA.rand(T, m, n),
-        y = CUDA.zeros(T, n, m),
+        x,
+        y = similar(x, n, m),
         m, n
     )
 end
diff --git a/examples/transpose.py b/examples/transpose.py
index 61ba891..1996a3b 100644
--- a/examples/transpose.py
+++ b/examples/transpose.py
@@ -28,7 +28,7 @@ def prepare(*, benchmark: bool = False, M: int = None, N: int = None, dtype=np.f
         N = 8192 if benchmark else 512
     return {
         "input": cp.random.rand(M, N).astype(dtype),
-        "output": cp.zeros((N, M), dtype=dtype),
+        "output": cp.empty((N, M), dtype=dtype),
         "M": M,
         "N": N
     }
diff --git a/examples/vadd.jl b/examples/vadd.jl
index f523677..14444e0 100644
--- a/examples/vadd.jl
+++ b/examples/vadd.jl
@@ -51,10 +51,11 @@ end
 function prepare(; benchmark::Bool=false,
                  shape::Tuple=benchmark ? (2^27,) : (1_024_000,),
                  use_gather::Bool=false, T::DataType=Float32)
+    a = CUDA.rand(T, shape...)
     return (;
-        a = CUDA.rand(T, shape...),
+        a,
         b = CUDA.rand(T, shape...),
-        c = CUDA.zeros(T, shape...),
+        c = similar(a),
         shape,
         use_gather
     )
diff --git a/examples/vadd.py b/examples/vadd.py
index ccbc058..566ba94 100644
--- a/examples/vadd.py
+++ b/examples/vadd.py
@@ -52,10 +52,11 @@ def prepare(*, benchmark: bool = False, shape: tuple = None, use_gather: bool =
     """Allocate and initialize data for vector addition."""
    if shape is None:
         shape = (2**27,) if benchmark else (1_024_000,)
+    a = cp.random.rand(*shape).astype(dtype)
     return {
-        "a": cp.random.rand(*shape).astype(dtype),
+        "a": a,
         "b": cp.random.rand(*shape).astype(dtype),
-        "c": cp.zeros(shape, dtype=dtype),
+        "c": cp.empty_like(a),
         "shape": shape,
         "use_gather": use_gather
     }
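
With the series applied, every example exposes the same prepare/run/verify
(plus optional run_others) harness, so a smoke test from the Julia REPL looks
like this (a sketch, assuming a working CUDA device and a local checkout of
the examples):

    include("examples/matmul.jl")
    data   = prepare(benchmark=true)        # 4096x4096x4096, Float32
    result = run(data; nruns=10, warmup=3)  # per-run times in milliseconds
    verify(data, result)
    run_others(data; nruns=10, warmup=3)    # Dict("cuBLAS" => times)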