Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ LocalPreferences.toml
CLAUDE.md
AGENTS.md
TODO.md
__pycache__
90 changes: 71 additions & 19 deletions examples/batchmatmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,89 @@ function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.T
return nothing
end

function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T
name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk")
println("--- $name ---")

# Batch-last ordering for optimal column-major access
A = CUDA.rand(T, M, K, Batch)
B = CUDA.rand(T, K, N, Batch)
C = CUDA.zeros(T, M, N, Batch)
#=============================================================================
Example harness
=============================================================================#

function prepare(; benchmark::Bool=false,
M::Int=benchmark ? 1024 : 256,
K::Int=benchmark ? 512 : 128,
N::Int=benchmark ? 2048 : 256,
Batch::Int=benchmark ? 8 : 4,
T::DataType=Float32)
return (;
A = CUDA.rand(T, M, K, Batch),
B = CUDA.rand(T, K, N, Batch),
C = CuArray{T}(undef, M, N, Batch),
M, K, N, Batch
)
end

# 3D grid: (M_tiles, N_tiles, Batch)
function run(data; tm::Int=64, tn::Int=64, tk::Int=64, nruns::Int=1, warmup::Int=0)
(; A, B, C, M, N, Batch) = data
grid = (cld(M, tm), cld(N, tn), Batch)

# Launch kernel
ct.launch(batch_matmul_kernel, grid, A, B, C,
ct.Constant(tm), ct.Constant(tn), ct.Constant(tk))
CUDA.@sync for _ in 1:warmup
ct.launch(batch_matmul_kernel, grid, A, B, C,
ct.Constant(tm), ct.Constant(tn), ct.Constant(tk))
end

# Verify result - compute batched matmul on CPU
times = Float64[]
for _ in 1:nruns
t = CUDA.@elapsed ct.launch(batch_matmul_kernel, grid, A, B, C,
ct.Constant(tm), ct.Constant(tn), ct.Constant(tk))
push!(times, t * 1000) # ms
end

return (; C, times)
end

function verify(data, result)
(; A, B, M, N, Batch) = data
A_cpu = Array(A)
B_cpu = Array(B)
expected = similar(A_cpu, M, N, Batch)
for b in 1:Batch
expected[:, :, b] = A_cpu[:, :, b] * B_cpu[:, :, b]
end
result = Array(C)
@assert isapprox(Array(result.C), expected, rtol=1e-2, atol=1e-2) "max diff: $(maximum(abs.(Array(result.C) - expected)))"
end

if isapprox(result, expected, rtol=1e-2, atol=1e-2)
println(" passed")
else
max_diff = maximum(abs.(result - expected))
println(" FAILED (max diff: $max_diff)")
#=============================================================================
Reference implementations for benchmarking
=============================================================================#

function run_others(data; nruns::Int=1, warmup::Int=0)
(; A, B, M, N, Batch) = data
results = Dict{String, Vector{Float64}}()

C_cublas = similar(A, M, N, Batch)

# cuBLAS batched gemm via CUBLAS.gemm_strided_batched!
CUDA.@sync for _ in 1:warmup
CUDA.CUBLAS.gemm_strided_batched!('N', 'N', one(eltype(A)), A, B, zero(eltype(A)), C_cublas)
end
times_cublas = Float64[]
for _ in 1:nruns
t = CUDA.@elapsed CUDA.CUBLAS.gemm_strided_batched!('N', 'N', one(eltype(A)), A, B, zero(eltype(A)), C_cublas)
push!(times_cublas, t * 1000)
end
results["cuBLAS batched"] = times_cublas

return results
end

#=============================================================================
Main
=============================================================================#

function test_batch_matmul(::Type{T}, M, K, N, Batch, tm, tn, tk; name=nothing) where T
name = something(name, "batch_matmul ($M x $K x $Batch) @ ($K x $N x $Batch), $T, tiles=$tm x $tn x $tk")
println("--- $name ---")
data = prepare(; M, K, N, Batch, T)
result = run(data; tm, tn, tk)
verify(data, result)
println(" passed")
end

function main()
Expand Down
165 changes: 165 additions & 0 deletions examples/batchmatmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#!/usr/bin/env python3
"""
Batch Matrix Multiplication example - cuTile Python
"""

import cupy as cp
import numpy as np
import cuda.tile as ct
from math import ceil

@ct.kernel
def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int]):
"""CuTile kernel for batch matrix multiplication
A has shape (Batch, M, K), B has shape (Batch, K, N) and C has shape (Batch, M, N)
Grid: (Batch, M_tiles, N_tiles)
"""
pid_batch = ct.bid(0)
bidx = ct.bid(1)
bidy = ct.bid(2)

num_k_tiles = ct.cdiv(A.shape[2], tk)
accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32)
zero_pad = ct.PaddingMode.ZERO

for k in range(num_k_tiles):
a = ct.load(A, index=(pid_batch, bidx, k), shape=(1, tm, tk), padding_mode=zero_pad)
a = ct.reshape(a, (tm, tk))

b = ct.load(B, index=(pid_batch, k, bidy), shape=(1, tk, tn), padding_mode=zero_pad)
b = ct.reshape(b, (tk, tn))

accumulator = ct.mma(a, b, acc=accumulator)

result = ct.astype(accumulator, C.dtype)
result_3d = ct.reshape(result, (1, tm, tn))
ct.store(C, index=(pid_batch, bidx, bidy), tile=result_3d)


#=============================================================================
# Example harness
#=============================================================================

def prepare(*, benchmark: bool = False, Batch: int = None, M: int = None, K: int = None, N: int = None, dtype=np.float16):
"""Allocate and initialize data for batch matmul."""
if Batch is None:
Batch = 8 if benchmark else 4
if M is None:
M = 1024 if benchmark else 256
if K is None:
K = 512 if benchmark else 128
if N is None:
N = 2048 if benchmark else 256
return {
"A": cp.random.randn(Batch, M, K).astype(dtype),
"B": cp.random.randn(Batch, K, N).astype(dtype),
"C": cp.empty((Batch, M, N), dtype=dtype),
"Batch": Batch,
"M": M,
"K": K,
"N": N
}


def run(data, *, tm: int = 64, tn: int = 64, tk: int = 64, nruns: int = 1, warmup: int = 0):
"""Run batch matmul kernel with timing."""
A, B, C = data["A"], data["B"], data["C"]
Batch, M, N = data["Batch"], data["M"], data["N"]

grid = (Batch, ceil(M / tm), ceil(N / tn))
stream = cp.cuda.get_current_stream()

# Warmup
for _ in range(warmup):
ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk))
cp.cuda.runtime.deviceSynchronize()

# Timed runs
times = []
for _ in range(nruns):
start = cp.cuda.Event()
end = cp.cuda.Event()
start.record(stream)
ct.launch(stream, grid, batchmatmul_cutile_kernel, (A, B, C, tm, tn, tk))
end.record(stream)
end.synchronize()
times.append(cp.cuda.get_elapsed_time(start, end)) # ms

return {"C": C, "times": times}


def verify(data, result):
"""Verify batch matmul results."""
A_np = cp.asnumpy(data["A"]).astype(np.float32)
B_np = cp.asnumpy(data["B"]).astype(np.float32)
C_np = cp.asnumpy(result["C"]).astype(np.float32)
Batch, M, N = data["Batch"], data["M"], data["N"]

expected = np.zeros((Batch, M, N), dtype=np.float32)
for b in range(Batch):
expected[b] = A_np[b] @ B_np[b]
assert np.allclose(C_np, expected, rtol=1e-1, atol=1e-1), \
f"batchmatmul incorrect! max diff: {np.max(np.abs(C_np - expected))}"


#=============================================================================
# Reference implementations for benchmarking
#=============================================================================

def run_others(data, *, nruns: int = 1, warmup: int = 0):
"""Run reference implementations for comparison."""
import torch

results = {}
A_cp, B_cp = data["A"], data["B"]
Batch, M, N = data["Batch"], data["M"], data["N"]

# PyTorch bmm
A_torch = torch.as_tensor(A_cp, device='cuda')
B_torch = torch.as_tensor(B_cp, device='cuda')
C_torch = torch.zeros(Batch, M, N, dtype=A_torch.dtype, device='cuda')

for _ in range(warmup):
torch.bmm(A_torch, B_torch, out=C_torch)
torch.cuda.synchronize()

times_torch = []
for _ in range(nruns):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.bmm(A_torch, B_torch, out=C_torch)
end.record()
torch.cuda.synchronize()
times_torch.append(start.elapsed_time(end))
results["PyTorch bmm"] = times_torch

return results


#=============================================================================
# Main
#=============================================================================

def test_batchmatmul(Batch, M, K, N, tm, tn, tk, dtype=np.float16, name=None):
"""Test batch matmul with given parameters."""
name = name or f"batchmatmul ({Batch}x{M}x{K}) @ ({Batch}x{K}x{N}), tiles={tm}x{tn}x{tk}, dtype={dtype.__name__}"
print(f"--- {name} ---")
data = prepare(Batch=Batch, M=M, K=K, N=N, dtype=dtype)
result = run(data, tm=tm, tn=tn, tk=tk)
verify(data, result)
print(" passed")


def main():
print("--- cuTile Batch Matrix Multiplication Examples ---\n")

test_batchmatmul(4, 256, 128, 256, 32, 32, 32, np.float32)
test_batchmatmul(4, 512, 256, 512, 64, 64, 64, np.float32)
test_batchmatmul(4, 512, 256, 1024, 128, 256, 64, np.float16)

print("\n--- All batchmatmul examples completed ---")


if __name__ == "__main__":
main()
Loading