diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bf1404c..fede161 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,30 +12,39 @@ jobs: runs-on: [self-hosted, amd-gpu] container: - image: rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1 + image: rocm/dev-ubuntu-24.04:7.2-complete options: --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - submodules: recursive - + - name: Install system dependencies + run: | + apt-get update + apt-get install -y git python3.12 python3.12-venv python3-pip python3.12-dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 + # Remove PEP 668 externally-managed marker so pip works in this disposable container since we're not using a virtual environment + rm -f /usr/lib/python3.12/EXTERNALLY-MANAGED + - name: Set up environment run: | echo "Setting up ROCm environment..." export ROCM_PATH=/opt/rocm export PATH=$ROCM_PATH/bin:$PATH - - - name: Install system dependencies + + - name: Install PyTorch with ROCm support run: | - apt-get update - apt-get install -y git - - - name: Install Python dependencies + pip3 install torch --index-url https://download.pytorch.org/whl/rocm7.1 + + - name: Install Triton run: | - python3 -m pip install --upgrade pip pip3 install -U triton + + - name: Checkout tritonBLAS code + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Install tritonBLAS + run: | pip3 install -e . - name: Verify installation diff --git a/include/tritonblas/__init__.py b/include/tritonblas/__init__.py index 6f9365f..2a2c8b3 100644 --- a/include/tritonblas/__init__.py +++ b/include/tritonblas/__init__.py @@ -1,4 +1,5 @@ from .matmul import matmul, matmul_a8w8 from .matmul import matmul_lt, matmul_a8w8_lt from .matmul import matmul_fp4 +from .matmul import addmm from .origami import OrigamiMatmulSelector diff --git a/include/tritonblas/kernels/persistent_gemm.py b/include/tritonblas/kernels/persistent_gemm.py index d0fdb09..106e92e 100644 --- a/include/tritonblas/kernels/persistent_gemm.py +++ b/include/tritonblas/kernels/persistent_gemm.py @@ -13,7 +13,8 @@ import torch from tritonblas.kernels.stages import ( - ScheduleContext, + ScheduleContext, + make_schedule_context, GemmContext, make_input_view, make_output_view, @@ -52,7 +53,7 @@ def persistent_matmul( BIAS: tl.constexpr, EVEN_K: tl.constexpr, QUANTIZED: tl.constexpr = False, - ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32, + ALLOW_TF32: tl.constexpr = True, ): """ Persistent GEMM kernel using GemmContext aggregate. 
@@ -85,7 +86,7 @@ def persistent_matmul( # CREATE EPILOGUE VIEWS (optional scale and bias) # ════════════════════════════════════════════════════════════════════════ scale_view = make_scale_view(A_scale_ptr, B_scale_ptr, M, N) if A_scale_ptr is not None else None - bias_view = make_bias_view(bias_ptr, M, stride_bias) if BIAS else None + bias_view = make_bias_view(bias_ptr, N, stride_bias) if BIAS else None # ════════════════════════════════════════════════════════════════════════ # CONSTRUCT GEMM CONTEXT TO MANAGE MATH RELEVANT CONTEXT @@ -101,7 +102,7 @@ def persistent_matmul( # ════════════════════════════════════════════════════════════════════════ # CREATE SCHEDULE CONTEXT FROM GEMM CONTEXT TO MANAGE OUTER LOOP ITERATION # ════════════════════════════════════════════════════════════════════════ - sched = ScheduleContext(M, N, K, ctx) + sched = make_schedule_context(M, N, K, ctx) # ════════════════════════════════════════════════════════════════════════ # PERSISTENT LOOP: Process multiple tiles per workgroup diff --git a/include/tritonblas/kernels/stages/__init__.py b/include/tritonblas/kernels/stages/__init__.py index eb362b3..f19c268 100644 --- a/include/tritonblas/kernels/stages/__init__.py +++ b/include/tritonblas/kernels/stages/__init__.py @@ -68,7 +68,9 @@ def kernel(A, B, C, A_scale_ptr, B_scale_ptr, bias_ptr, M, N, K, # Core aggregates from .tile import Tile from .gemm_context import GemmContext -from .schedule import ScheduleContext +from .schedule import ( + ScheduleContext, make_schedule_context, +) from .matrix_view import ( InputView, OutputView, ScaleView, BiasView, make_input_view, make_tensor_view, make_output_view, diff --git a/include/tritonblas/kernels/stages/matrix_view.py b/include/tritonblas/kernels/stages/matrix_view.py index b221898..5b777ae 100644 --- a/include/tritonblas/kernels/stages/matrix_view.py +++ b/include/tritonblas/kernels/stages/matrix_view.py @@ -181,13 +181,13 @@ class BiasView: stride: Stride for bias vector (default: 1) """ ptr: tl.tensor - M: tl.tensor + N: tl.tensor stride: tl.tensor @triton.constexpr_function - def __init__(self, ptr, M, stride): + def __init__(self, ptr, N, stride): self.ptr = ptr - self.M = M + self.N = N self.stride = stride @triton.jit @@ -202,9 +202,9 @@ def apply(self, acc, tile: Tile): Returns: Accumulator with bias added """ - rm, _ = tile.indices() - bias_vector = tl.load(self.ptr + rm * self.stride, mask=rm < self.M, other=0.0) - acc = acc + bias_vector[:, None] + _, rn = tile.indices() + bias_vector = tl.load(self.ptr + rn * self.stride, mask=rn < self.N, other=0.0) + acc = acc + bias_vector[None, :] return acc @@ -317,7 +317,7 @@ def load(self, tile: Tile, boundary: tl.constexpr = False, cache_modifier: tl.co # ============================================================================= @triton.jit -def make_input_view(ptr, rows, cols, stride_row, stride_col) -> InputView: +def make_input_view(ptr, rows, cols, stride_row, stride_col): """ Create an InputView with automatic stride type coercion. 
@@ -327,7 +327,7 @@ def make_input_view(ptr, rows, cols, stride_row, stride_col) -> InputView: Args: ptr: Base pointer to matrix data - rows: Number of rows (first dimension) - must be a tensor + rows: Number of rows (first dimension) cols: Number of columns (second dimension) stride_row: Stride when moving along rows stride_col: Stride when moving along columns @@ -347,23 +347,25 @@ def make_input_view(ptr, rows, cols, stride_row, stride_col) -> InputView: # TYPE PROMOTION TRICK # ═══════════════════════════════════════════════════════════════════════ # Triton aggregates require strongly-typed fields (tl.tensor). However, - # strides can be either Python ints (stride=1 for contiguous dimensions) - # or Triton tensors (stride>1 from kernel params). + # dimensions and strides can be either Python ints or Triton tensors, + # especially under torch.compile which may pass ints during tracing. # - # The pattern `stride + 0 * rows` promotes any int to a tensor: - # - 0 * rows produces a tensor with value 0 (since rows is a tensor) - # - stride + (tensor 0) = tensor with stride's value + # The pattern `value + 0 * stride_row` promotes any int to a tensor: + # - 0 * stride_row produces a tensor with value 0 (since stride_row is a tensor) + # - value + (tensor 0) = tensor with value # # This has ZERO runtime cost - the compiler constant-folds 0*x and x+0. # ═══════════════════════════════════════════════════════════════════════ + rows_t = rows + 0 * rows + cols_t = cols + 0 * rows stride_row_t = stride_row + 0 * rows stride_col_t = stride_col + 0 * rows - return InputView(ptr, rows, cols, stride_row_t, stride_col_t) + return InputView(ptr, rows_t, cols_t, stride_row_t, stride_col_t) @triton.jit -def make_output_view(ptr, rows, cols, stride_row, stride_col) -> OutputView: +def make_output_view(ptr, rows, cols, stride_row, stride_col): """ Create an OutputView with automatic stride type coercion. @@ -372,7 +374,7 @@ def make_output_view(ptr, rows, cols, stride_row, stride_col) -> OutputView: Args: ptr: Base pointer to matrix data - rows: Number of rows (first dimension) - must be a tensor + rows: Number of rows (first dimension) cols: Number of columns (second dimension) stride_row: Stride when moving along rows stride_col: Stride when moving along columns @@ -388,10 +390,12 @@ def make_output_view(ptr, rows, cols, stride_row, stride_col) -> OutputView: # ═══════════════════════════════════════════════════════════════════════ # TYPE PROMOTION TRICK - See make_input_view() for detailed explanation # ═══════════════════════════════════════════════════════════════════════ + rows_t = rows + 0 * rows + cols_t = cols + 0 * rows stride_row_t = stride_row + 0 * rows stride_col_t = stride_col + 0 * rows - return OutputView(ptr, rows, cols, stride_row_t, stride_col_t) + return OutputView(ptr, rows_t, cols_t, stride_row_t, stride_col_t) # Alias for backward compatibility @@ -399,7 +403,7 @@ def make_output_view(ptr, rows, cols, stride_row, stride_col) -> OutputView: @triton.jit -def make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1) -> ScaleView: +def make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1): """ Create a ScaleView for quantized GEMM epilogue. @@ -430,15 +434,15 @@ def make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1) -> S @triton.jit -def make_bias_view(bias_ptr, M, stride=1) -> BiasView: +def make_bias_view(bias_ptr, N, stride=1): """ Create a BiasView for GEMM epilogue. Stores bias vector pointer with automatic stride type coercion. 
Args: - bias_ptr: Pointer to bias vector (length M) - M: Number of rows (for bounds checking) - must be a tensor + bias_ptr: Pointer to bias vector (length N) + N: Number of columns (for bounds checking) stride: Stride for bias vector (default: 1) Returns: @@ -446,13 +450,14 @@ def make_bias_view(bias_ptr, M, stride=1) -> BiasView: Example:: - bias_view = make_bias_view(bias_ptr, M, stride_bias) + bias_view = make_bias_view(bias_ptr, N, stride_bias) tensorC.store(acc, out_tile, bias=bias_view) """ # Type promotion for stride - stride_t = stride + 0 * M + stride_t = stride + 0 * N + N_t = N + 0 * N - return BiasView(bias_ptr, M, stride_t) + return BiasView(bias_ptr, N_t, stride_t) # ============================================================================= diff --git a/include/tritonblas/kernels/stages/schedule.py b/include/tritonblas/kernels/stages/schedule.py index a6586b3..4aa8dd0 100644 --- a/include/tritonblas/kernels/stages/schedule.py +++ b/include/tritonblas/kernels/stages/schedule.py @@ -255,3 +255,19 @@ def total_tiles(self): num_pid_m = tl.cdiv(self.M, self.ctx.block_m) num_pid_n = tl.cdiv(self.N, self.ctx.block_n) return num_pid_m * num_pid_n + + +@triton.jit +def make_schedule_context(M, N, K, ctx: GemmContext, streamk_tiles=0): + """ + Create a ScheduleContext from a GemmContext. + + Args: + M, N, K: Problem dimensions + ctx: GemmContext with block sizes and scheduling parameters + streamk_tiles: Number of tiles for Stream-K (0 = persistent only) + """ + M_t = M + 0 * M + N_t = N + 0 * M + K_t = K + 0 * M + return ScheduleContext(M_t, N_t, K_t, ctx, streamk_tiles) diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 9b23986..b924c07 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -1,12 +1,15 @@ -import torch -import triton -import random import functools +import random import time +from typing import Any, Dict, Optional, Tuple + +import torch +from torch.library import triton_op, wrap_triton +import triton + from .kernels import persistent_matmul, streamk_matmul from .kernels.fp4_matmul import fp4_matmul from .origami import OrigamiMatmulSelector -from typing import Dict, Tuple, Optional _tensor_cache = {} current_device_index = torch.cuda.current_device() @@ -22,7 +25,7 @@ # Function will behave like an LRU-Cache of heuristic results # Saves several microseconds for previously seen problems by not rerunning the heuristic unnecessarily -@functools.lru_cache(maxsize=1024) +#@functools.lru_cache(maxsize=1024) def _make_matmul_selector( M: int, N: int, @@ -52,6 +55,7 @@ def persistent_matmul_lt( b: torch.Tensor, c: torch.Tensor, selector, + bias: Optional[torch.Tensor] = None, a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, quantized: bool = False, @@ -92,13 +96,14 @@ def persistent_matmul_lt( chunk_size = min(chunk_size, total_programs // num_xcds) # TODO: Support other matmul algs. - kk = persistent_matmul[(grids,)]( + #kk = persistent_matmul[(grids,)]( + kk = wrap_triton(persistent_matmul)[(grids,)]( a, b, c, a_scale if quantized else None, # A_scale_ptr b_scale if quantized else None, # B_scale_ptr - None, # TODO: Enable bias. + bias if bias is not None else None, M, N, K, @@ -106,7 +111,7 @@ def persistent_matmul_lt( b.stride(1), c.stride(0), c.stride(1), - 0, # TODO: Enable bias stride. 
+ bias.stride(0) if bias is not None else 0, stride_ak=a.stride(1), stride_bk=b.stride(0), BLOCK_SIZE_M=BLK_M, @@ -116,7 +121,7 @@ def persistent_matmul_lt( NUM_SMS=total_programs, NUM_XCDS=num_xcds, CHUNK_SIZE=chunk_size, - BIAS=False, + BIAS=bias is not None, EVEN_K=even_k, CACHE_MODIFIER_A=CACHE_MODIFIER_A, CACHE_MODIFIER_B=CACHE_MODIFIER_B, @@ -126,6 +131,7 @@ def persistent_matmul_lt( waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, ) return c @@ -135,6 +141,7 @@ def streamk_matmul_lt( b: torch.Tensor, c: torch.Tensor, selector, + bias: Optional[torch.Tensor] = None, sk_grid: Optional[int] = None, a_scale: Optional[torch.Tensor] = None, b_scale: Optional[torch.Tensor] = None, @@ -192,13 +199,14 @@ def streamk_matmul_lt( chunk_size = gsize_m * gsize_m chunk_size = min(chunk_size, grids // num_xcds) - kk = streamk_matmul[(grids,)]( + #kk = streamk_matmul[(grids,)]( + kk = wrap_triton(streamk_matmul)[(grids,)]( a, b, c, a_scale if quantized else None, # A_scale_ptr b_scale if quantized else None, # B_scale_ptr - None, # TODO: Enable bias. + bias if bias is not None else None, P, locks, M, @@ -208,7 +216,7 @@ def streamk_matmul_lt( b.stride(1), c.stride(0), c.stride(1), - 0, # TODO: Enable bias stride. + bias.stride(0) if bias is not None else None, stride_ak=a.stride(1), stride_bk=b.stride(0), BLOCK_SIZE_M=BLK_M, @@ -219,7 +227,7 @@ def streamk_matmul_lt( NUM_XCDS=num_xcds, CHUNK_SIZE=chunk_size, STREAMK_TILES=total_tiles_streamk, - BIAS=False, + BIAS=bias is not None, EVEN_K=even_k, CACHE_MODIFIER_A=CACHE_MODIFIER_A, CACHE_MODIFIER_B=CACHE_MODIFIER_B, @@ -253,22 +261,118 @@ def matmul_a8w8_lt( else: return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True) -def matmul( + +@triton_op("tritonblas::_matmul", mutates_args={}) +def _matmul( a: torch.Tensor, b: torch.Tensor, - c: torch.Tensor, - enable_streamk=False, - sk_grid=None, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None, +) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" + M, K = a.shape + _, N = b.shape + + # Allocate an output tensor + out = a.new_empty(M, N) + + # Query Origami for solution + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk) + if enable_streamk: + return streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid) + else: + return persistent_matmul_lt(a, b, out, selector) + + +def _setup_context_matmul_backwards( + ctx: Any, + inputs: tuple[Any, ...], + output: Any ): - assert a.shape[1] == b.shape[0], "Incompatible Dimensions" + a, b, enable_streamk, sk_grid = inputs + ctx.save_for_backward(a, b) + ctx.enable_streamk = enable_streamk + ctx.sk_grid = sk_grid + + +def _matmul_backwards( + ctx: Any, + grad_output: torch.Tensor +): + a, b = ctx.saved_tensors + enable_streamk = ctx.enable_streamk + sk_grid = ctx.sk_grid + + # Make grad_output contiguous + grad_output_cont = grad_output.contiguous() + + # grad_a = grad_output @ b^T + b_t = b.T.contiguous() + grad_a = matmul(grad_output_cont, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # grad_b = a^T @ grad_output + a_t = a.T.contiguous() + grad_b = matmul(a_t, grad_output_cont, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # tuple[a, b, enable_streamk, sk_grid] + # First 2 must be in the order that matches matmul()'s forward args + # Last 2 are not part of the gradient and so are None + return grad_a, grad_b, None, None + + 
+_matmul.register_autograd(_matmul_backwards, + setup_context=_setup_context_matmul_backwards) + + +@triton_op("tritonblas::_matmul_out", mutates_args={'out'}) +def _matmul_out( + a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None, +) -> None: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" M, K = a.shape _, N = b.shape - selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk) + # Query Origami for solution + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk) + if enable_streamk: - return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid) + streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid) else: - return persistent_matmul_lt(a, b, c, selector) + persistent_matmul_lt(a, b, out, selector) + + # Custom torch ops cannot return a value which is an alias of an input. So + # even though torch returns a pointer to the out arg when used, we can't. + return None + + +def matmul( + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> Optional[torch.Tensor]: + # If no out tensor provided - we do the allocation - we support autograd + if out is None: + return _matmul(a, b, enable_streamk, sk_grid) + + # If out tensor provided - in-place - we do NOT support autograd + # Check for autograd conditions (global and per-tensor) + if torch.is_grad_enabled() and ( + a.requires_grad + or b.requires_grad + or out.requires_grad + ): + raise RuntimeError( + "tritonblas.matmul(): functions with out=... arguments don't support " + "automatic differentiation, but one of the arguments requires grad." 
+ ) + return _matmul_out(a, b, out, enable_streamk, sk_grid) + def matmul_a8w8( a: torch.Tensor, @@ -396,3 +500,124 @@ def matmul_fp4( ) return c + + +@triton_op("tritonblas::_addmm", mutates_args={}) +def _addmm( + bias: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None, +) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" + M, K = a.shape + _, N = b.shape + + # Query Origami for solution + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk) + + # Allocate an output tensor + out = a.new_empty(M, N) + + if enable_streamk: + return streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid) + else: + return persistent_matmul_lt(a, b, out, selector, bias=bias) + + +def _setup_context_addmm_backwards( + ctx: Any, + inputs: tuple[Any, ...], + output: Any +): + bias, a, b, enable_streamk, sk_grid = inputs + ctx.save_for_backward(a, b) + ctx.enable_streamk = enable_streamk + ctx.sk_grid = sk_grid + + +def _addmm_backwards( + ctx: Any, + grad_output: torch.Tensor +): + a, b = ctx.saved_tensors + enable_streamk = ctx.enable_streamk + sk_grid = ctx.sk_grid + + # Make grad_output contiguous + grad_output_cont = grad_output.contiguous() + + # grad_a = grad_output @ b^T + b_t = b.T.contiguous() + grad_a = matmul(grad_output_cont, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # grad_b = a^T @ grad_output + a_t = a.T.contiguous() + grad_b = matmul(a_t, grad_output_cont, enable_streamk=enable_streamk, sk_grid=sk_grid) + + # grad_bias = sum(grad_output) + grad_bias = grad_output.sum(dim=0) + + # tuple[bias, a, b, enable_streamk, sk_grid] + # First 3 must be in the order that matches addmm()'s forward args + # Last 2 are not part of the gradient and so are None + return grad_bias, grad_a, grad_b, None, None + + +_addmm.register_autograd(_addmm_backwards, + setup_context=_setup_context_addmm_backwards) + + +@triton_op("tritonblas::_addmm_out", mutates_args={'out'}) +def _addmm_out( + bias: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None, +) -> None: + assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions" + M, K = a.shape + _, N = b.shape + + # Query Origami for solution + selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk) + + if enable_streamk: + streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid) + else: + persistent_matmul_lt(a, b, out, selector, bias=bias) + + # Custom torch ops cannot return a value which is an alias of an input. So + # even though torch returns a pointer to the out arg when used, we can't. + return None + + +def addmm( + bias: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + enable_streamk: Optional[bool] = False, + sk_grid: Optional[int] = None +) -> Optional[torch.Tensor]: + # If no out tensor provided - we do the allocation - we support autograd + if out is None: + return _addmm(bias, a, b, enable_streamk, sk_grid) + + # If out tensor provided - in-place - we do NOT support autograd + # Check for autograd conditions (global and per-tensor) + if torch.is_grad_enabled() and ( + bias.requires_grad + or a.requires_grad + or b.requires_grad + or out.requires_grad + ): + raise RuntimeError( + "tritonblas.addmm(): functions with out=... 
arguments don't support "
+            "automatic differentiation, but one of the arguments requires grad."
+        )
+    return _addmm_out(bias, a, b, out, enable_streamk, sk_grid)
+
diff --git a/include/tritonblas/origami.py b/include/tritonblas/origami.py
index 291e06c..e470c34 100644
--- a/include/tritonblas/origami.py
+++ b/include/tritonblas/origami.py
@@ -1,6 +1,7 @@
 import itertools
 import torch
 import origami
+import math
 
 from math import ceil
 
diff --git a/tests/test_addmm_correctness.py b/tests/test_addmm_correctness.py
new file mode 100644
index 0000000..5e3e8be
--- /dev/null
+++ b/tests/test_addmm_correctness.py
@@ -0,0 +1,316 @@
+"""
+Tests for tritonblas.addmm with torch.autograd and torch.compile support
+
+Tests cover:
+1. Forward pass correctness against torch.addmm
+2. Backward pass gradient correctness against torch.addmm
+3. In-place (out=...) functionality and autograd restrictions
+4. Edge cases including small dimensions
+5. torch.compile compatibility
+6. Persistent vs. StreamK compatibility
+"""
+
+import pytest
+import torch
+import tritonblas
+
+
+# If we don't increase this, torch will complain about too many recompilations.
+torch._dynamo.config.cache_size_limit = 10000
+# Also disable caches so every compile is fresh and new issues are caught.
+# Note this causes a single UserWarning that notes caches are disabled.
+torch._inductor.config.force_disable_caches = True
+# FIXME: Inductor appears to initialize multiple CUDA runtimes, seemingly related
+# to some of Triton's newer features, which causes errors unrelated to tritonBLAS.
+# The error suggests changing the multiprocessing strategy to 'spawn', but that
+# does not fix the issue - you have to force single-threaded compilation. This
+# needs to be fixed upstream in torch/triton.
+torch._inductor.config.compile_threads = 1 + +# Standard test dimensions +STANDARD_DIMS = [ + (128, 256, 512), # Medium sizes + (256, 256, 256), # Square + (512, 1024, 768), # Larger + (2048, 1024, 512), # Wide output + (1024, 2048, 512), # Tall output +] + +# Edge case dimensions (small dimensions, N < 16 cases) +EDGE_CASE_DIMS = [ + (32, 32, 32), # Small square + (64, 16, 128), # Small N + (16, 64, 128), # Small M + (128, 8, 256), # N < 16 + (8, 128, 256), # M < 16 + (12, 12, 512), # Small M and N + (15, 17, 512), # Weird and small M and N + (19, 13, 512), # Weird and small M and N + (128, 64, 12), # Small K +] + +# Skinny matrix dimensions (stress tests) +SKINNY_DIMS = [ + (16, 16, 4096), # Very large K + (32, 32, 8192), # Large K +] + +# Data types to test +DTYPES = [torch.bfloat16, torch.float16] + +# Whether to test with torch.compile +USE_COMPILE = [False, True] + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_addmm_forward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.addmm forward pass matches torch.addmm.""" + torch.manual_seed(42) + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + bias = torch.randn(n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # tritonblas result + result = addmm_fn(bias, a, b) + + # torch reference + expected = torch.addmm(bias, a, b) + + # Check forward correctness with relaxed tolerance for low precision + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_addmm_backward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.addmm backward pass produces correct gradients.""" + torch.manual_seed(42) + + # Create inputs with requires_grad for tritonblas + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + + # Clone for torch reference + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + bias_ref = bias.detach().clone().requires_grad_(True) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # Forward pass + result = addmm_fn(bias, a, b) + result_ref = torch.addmm(bias_ref, a_ref, b_ref) + + # Backward pass with same upstream gradient + grad_output = torch.randn_like(result) + result.backward(grad_output) + result_ref.backward(grad_output) + + # Check gradients match + torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1, + msg="bias gradient mismatch") + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1, + msg="a gradient mismatch") + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1, + msg="b gradient mismatch") + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", SKINNY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_addmm_skinny_matrices(m, n, k, dtype, use_compile): + """Test addmm with skinny matrices (large K dimension).""" + torch.manual_seed(42) + + a = 
torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + bias_ref = bias.detach().clone().requires_grad_(True) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # Forward + result = addmm_fn(bias, a, b) + result_ref = torch.addmm(bias_ref, a_ref, b_ref) + + torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) + + # Backward + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_inplace_with_grad_raises(use_compile): + """Test that addmm with out=... raises RuntimeError when autograd is enabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + with pytest.raises(RuntimeError, match="don't support automatic differentiation"): + addmm_fn(bias, a, b, out=out) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_inplace_without_grad_works(use_compile): + """Test that addmm with out=... 
works when autograd is disabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + # Should work with torch.no_grad() + with torch.no_grad(): + result = addmm_fn(bias, a, b, out=out) + + # In-place path returns None (custom ops don't support aliasing) + assert result is None, "in-place addmm should return None" + + # Verify correctness against torch + expected = torch.addmm(bias, a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_inplace_output_correctness(use_compile): + """Test that addmm in-place mode produces correct results.""" + torch.manual_seed(42) + m, n, k = 128, 256, 512 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + bias = torch.randn(n, device='cuda', dtype=dtype) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + with torch.no_grad(): + addmm_fn(bias, a, b, out=out) + + expected = torch.addmm(bias, a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_no_grad_tensors(use_compile): + """Test addmm works when input tensors don't require grad.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=False) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=False) + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + result = addmm_fn(bias, a, b) + expected = torch.addmm(bias, a, b) + + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_addmm_partial_grad(use_compile): + """Test addmm when only some inputs require grad.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + # Only a requires grad + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=False) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone() + bias_ref = bias.detach().clone() + + addmm_fn = tritonblas.addmm + if use_compile: + addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True) + + result = addmm_fn(bias, a, b) + result_ref = torch.addmm(bias_ref, a_ref, b_ref) + + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", [False, True]) +def test_addmm_streamk_modes(enable_streamk, use_compile): + """Test addmm with different streamk settings.""" + torch.manual_seed(42) + m, n, k = 256, 256, 256 + dtype = torch.bfloat16 + + a = 
torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True)
+    b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True)
+    bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True)
+
+    a_ref = a.detach().clone().requires_grad_(True)
+    b_ref = b.detach().clone().requires_grad_(True)
+    bias_ref = bias.detach().clone().requires_grad_(True)
+
+    addmm_fn = tritonblas.addmm
+    if use_compile:
+        addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
+
+    # Forward
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
+    result_ref = torch.addmm(bias_ref, a_ref, b_ref)
+
+    torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
+
+    # Backward
+    result.sum().backward()
+    result_ref.sum().backward()
+
+    torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
+    torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
+    torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1)
diff --git a/tests/test_matmul_correctness.py b/tests/test_matmul_correctness.py
new file mode 100644
index 0000000..d555ba5
--- /dev/null
+++ b/tests/test_matmul_correctness.py
@@ -0,0 +1,299 @@
+"""
+Tests for tritonblas.matmul with torch.autograd and torch.compile support
+
+Tests cover:
+1. Forward pass correctness against torch.mm
+2. Backward pass gradient correctness against torch.mm
+3. In-place (out=...) functionality and autograd restrictions
+4. Edge cases including small dimensions
+5. torch.compile compatibility
+6. Persistent vs. StreamK compatibility
+"""
+
+import pytest
+import torch
+import tritonblas
+
+
+# If we don't increase this, torch will complain about too many recompilations.
+torch._dynamo.config.cache_size_limit = 10000
+# Also disable caches so every compile is fresh and new issues are caught.
+# Note this causes a single UserWarning that notes caches are disabled.
+torch._inductor.config.force_disable_caches = True
+# FIXME: Inductor appears to initialize multiple CUDA runtimes, seemingly related
+# to some of Triton's newer features, which causes errors unrelated to tritonBLAS.
+# The error suggests changing the multiprocessing strategy to 'spawn', but that
+# does not fix the issue - you have to force single-threaded compilation. This
+# needs to be fixed upstream in torch/triton.
+torch._inductor.config.compile_threads = 1 + +# Standard test dimensions +STANDARD_DIMS = [ + (128, 256, 512), # Medium sizes + (256, 256, 256), # Square + (512, 1024, 768), # Larger + (2048, 1024, 512), # Wide output + (1024, 2048, 512), # Tall output +] + +# Edge case dimensions (small dimensions, N < 16 cases) +EDGE_CASE_DIMS = [ + (32, 32, 32), # Small square + (64, 16, 128), # Small N + (16, 64, 128), # Small M + (128, 8, 256), # N < 16 + (8, 128, 256), # M < 16 + (12, 12, 512), # Small M and N + (15, 17, 512), # Weird and small M and N + (19, 13, 512), # Weird and small M and N + (128, 64, 12), # Small K +] + +# Skinny matrix dimensions (stress tests) +SKINNY_DIMS = [ + (16, 16, 4096), # Very large K + (32, 32, 8192), # Large K +] + +# Data types to test +DTYPES = [torch.bfloat16, torch.float16] + +# Whether to test with torch.compile +USE_COMPILE = [False, True] + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_matmul_forward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.matmul forward pass matches torch.mm.""" + torch.manual_seed(42) + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # tritonblas result + result = matmul_fn(a, b) + + # torch reference + expected = torch.mm(a, b) + + # Check forward correctness with relaxed tolerance for low precision + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_matmul_backward_correctness(m, n, k, dtype, use_compile): + """Test that tritonblas.matmul backward pass produces correct gradients.""" + torch.manual_seed(42) + + # Create inputs with requires_grad for tritonblas + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + + # Clone for torch reference + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # Forward pass + result = matmul_fn(a, b) + result_ref = torch.mm(a_ref, b_ref) + + # Backward pass with same upstream gradient + grad_output = torch.randn_like(result) + result.backward(grad_output) + result_ref.backward(grad_output) + + # Check gradients match + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1, + msg="a gradient mismatch") + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1, + msg="b gradient mismatch") + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("m, n, k", SKINNY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_matmul_skinny_matrices(m, n, k, dtype, use_compile): + """Test matmul with skinny matrices (large K dimension).""" + torch.manual_seed(42) + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = 
torch.compile(tritonblas.matmul, fullgraph=True) + + # Forward + result = matmul_fn(a, b) + result_ref = torch.mm(a_ref, b_ref) + + torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) + + # Backward + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_inplace_with_grad_raises(use_compile): + """Test that matmul with out=... raises RuntimeError when autograd is enabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + with pytest.raises(RuntimeError, match="don't support automatic differentiation"): + matmul_fn(a, b, out=out) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_inplace_without_grad_works(use_compile): + """Test that matmul with out=... works when autograd is disabled.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # Should work with torch.no_grad() + with torch.no_grad(): + result = matmul_fn(a, b, out=out) + + # In-place path returns None (custom ops don't support aliasing) + assert result is None, "in-place matmul should return None" + + # Verify correctness against torch + expected = torch.mm(a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_inplace_output_correctness(use_compile): + """Test that matmul in-place mode produces correct results.""" + torch.manual_seed(42) + m, n, k = 128, 256, 512 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype) + b = torch.randn(k, n, device='cuda', dtype=dtype) + out = torch.empty(m, n, device='cuda', dtype=dtype) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + with torch.no_grad(): + matmul_fn(a, b, out=out) + + expected = torch.mm(a, b) + torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_no_grad_tensors(use_compile): + """Test matmul works when input tensors don't require grad.""" + torch.manual_seed(42) + m, n, k = 64, 64, 64 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=False) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + result = matmul_fn(a, b) + expected = torch.mm(a, b) + + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +def test_matmul_partial_grad(use_compile): + """Test matmul when only some inputs require grad.""" + torch.manual_seed(42) + m, n, 
k = 64, 64, 64 + dtype = torch.bfloat16 + + # Only a requires grad + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=False) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone() + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + result = matmul_fn(a, b) + result_ref = torch.mm(a_ref, b_ref) + + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("use_compile", USE_COMPILE) +@pytest.mark.parametrize("enable_streamk", [False, True]) +def test_matmul_streamk_modes(enable_streamk, use_compile): + """Test matmul with different streamk settings.""" + torch.manual_seed(42) + m, n, k = 256, 256, 256 + dtype = torch.bfloat16 + + a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True) + b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True) + + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + matmul_fn = tritonblas.matmul + if use_compile: + matmul_fn = torch.compile(tritonblas.matmul, fullgraph=True) + + # Forward + result = matmul_fn(a, b, enable_streamk=enable_streamk) + result_ref = torch.mm(a_ref, b_ref) + + torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1) + + # Backward + result.sum().backward() + result_ref.sum().backward() + + torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
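
For reference, a minimal usage sketch of the matmul/addmm API added by this patch (an illustrative sketch, not part of the patch itself; it assumes an installed tritonblas build and a ROCm or CUDA device):

import torch
import tritonblas

a = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(512, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
bias = torch.randn(128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Out-of-place calls allocate the output and support autograd.
c = tritonblas.matmul(a, b)           # C = A @ B
d = tritonblas.addmm(bias, a, b)      # D = bias + A @ B
d.sum().backward()                    # populates a.grad, b.grad, bias.grad

# Both ops are registered via torch.library.triton_op, so they compose with torch.compile.
compiled_addmm = torch.compile(tritonblas.addmm, fullgraph=True)

# With out=..., the call writes in place, returns None, and requires autograd to be disabled.
out = torch.empty(256, 128, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    tritonblas.matmul(a, b, out=out, enable_streamk=True)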