35 changes: 22 additions & 13 deletions .github/workflows/test.yml
@@ -12,30 +12,39 @@ jobs:
     runs-on: [self-hosted, amd-gpu]
 
     container:
-      image: rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
+      image: rocm/dev-ubuntu-24.04:7.2-complete
       options: --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined
 
     steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          submodules: recursive
+
+      - name: Install system dependencies
+        run: |
+          apt-get update
+          apt-get install -y git python3.12 python3.12-venv python3-pip python3.12-dev
+          update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
+          # Remove PEP 668 externally-managed marker so pip works in this disposable container since we're not using a virtual environment
+          rm -f /usr/lib/python3.12/EXTERNALLY-MANAGED
+
-      - name: Set up environment
-        run: |
-          echo "Setting up ROCm environment..."
-          export ROCM_PATH=/opt/rocm
-          export PATH=$ROCM_PATH/bin:$PATH
 
-      - name: Install system dependencies
+      - name: Install PyTorch with ROCm support
         run: |
-          apt-get update
-          apt-get install -y git
 
-      - name: Install Python dependencies
+          pip3 install torch --index-url https://download.pytorch.org/whl/rocm7.1
+
+      - name: Install Triton
         run: |
           python3 -m pip install --upgrade pip
           pip3 install -U triton
 
-      - name: Checkout tritonBLAS code
-        uses: actions/checkout@v4
-        with:
-          submodules: recursive
-
       - name: Install tritonBLAS
         run: |
           pip3 install -e .
 
       - name: Verify installation
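The "Verify installation" step is collapsed in the diff view above. For a ROCm wheel installed this way, a smoke test along these lines is typical (a sketch of what such a step usually checks, not the PR's actual contents):

    import torch
    import triton

    print("torch", torch.__version__, "hip:", torch.version.hip)  # torch.version.hip is non-None on ROCm builds
    print("triton", triton.__version__)
    print("gpu visible:", torch.cuda.is_available())  # needs the /dev/kfd and /dev/dri devices passed in the container options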
1 change: 1 addition & 0 deletions include/tritonblas/__init__.py
@@ -1,4 +1,5 @@
 from .matmul import matmul, matmul_a8w8
 from .matmul import matmul_lt, matmul_a8w8_lt
+from .matmul import matmul_fp4
 from .matmul import addmm
 from .origami import OrigamiMatmulSelector
9 changes: 5 additions & 4 deletions include/tritonblas/kernels/persistent_gemm.py
@@ -13,7 +13,8 @@
 import torch
 
 from tritonblas.kernels.stages import (
-    ScheduleContext,
+    ScheduleContext,
+    make_schedule_context,
     GemmContext,
     make_input_view,
     make_output_view,
@@ -52,7 +53,7 @@ def persistent_matmul(
     BIAS: tl.constexpr,
     EVEN_K: tl.constexpr,
     QUANTIZED: tl.constexpr = False,
-    ALLOW_TF32: tl.constexpr = torch.backends.cuda.matmul.allow_tf32,
+    ALLOW_TF32: tl.constexpr = True,
 ):
     """
     Persistent GEMM kernel using GemmContext aggregate.
@@ -85,7 +86,7 @@ def persistent_matmul(
     # CREATE EPILOGUE VIEWS (optional scale and bias)
     # ════════════════════════════════════════════════════════════════════════
     scale_view = make_scale_view(A_scale_ptr, B_scale_ptr, M, N) if A_scale_ptr is not None else None
-    bias_view = make_bias_view(bias_ptr, M, stride_bias) if BIAS else None
+    bias_view = make_bias_view(bias_ptr, N, stride_bias) if BIAS else None
 
@@ -101,7 +102,7 @@ def persistent_matmul(
     # ════════════════════════════════════════════════════════════════════════
     # CREATE SCHEDULE CONTEXT FROM GEMM CONTEXT TO MANAGE OUTER LOOP ITERATION
     # ════════════════════════════════════════════════════════════════════════
-    sched = ScheduleContext(M, N, K, ctx)
+    sched = make_schedule_context(M, N, K, ctx)
 
     # ════════════════════════════════════════════════════════════════════════
     # PERSISTENT LOOP: Process multiple tiles per workgroup
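A note on the ALLOW_TF32 change: the old default read torch.backends.cuda.matmul.allow_tf32 once, when persistent_matmul was defined at import time, so later changes to that process-global toggle were silently ignored. Python evaluates default arguments at definition time, which a minimal standalone example shows (plain Python, for illustration, not from the PR):

    import torch

    def f(allow_tf32=torch.backends.cuda.matmul.allow_tf32):
        # The default was captured when f was defined, not when it is called.
        return allow_tf32

    torch.backends.cuda.matmul.allow_tf32 = not torch.backends.cuda.matmul.allow_tf32
    print(f())  # still the definition-time value, not the flipped one

With a constant True default, a caller that wants to honor the runtime flag can instead pass ALLOW_TF32 explicitly when launching the kernel.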
4 changes: 3 additions & 1 deletion include/tritonblas/kernels/stages/__init__.py
@@ -68,7 +68,9 @@ def kernel(A, B, C, A_scale_ptr, B_scale_ptr, bias_ptr, M, N, K,
 # Core aggregates
 from .tile import Tile
 from .gemm_context import GemmContext
-from .schedule import ScheduleContext
+from .schedule import (
+    ScheduleContext, make_schedule_context,
+)
 from .matrix_view import (
     InputView, OutputView, ScaleView, BiasView,
     make_input_view, make_tensor_view, make_output_view,
53 changes: 29 additions & 24 deletions include/tritonblas/kernels/stages/matrix_view.py
@@ -181,13 +181,13 @@ class BiasView:
         stride: Stride for bias vector (default: 1)
     """
     ptr: tl.tensor
-    M: tl.tensor
+    N: tl.tensor
     stride: tl.tensor
 
     @triton.constexpr_function
-    def __init__(self, ptr, M, stride):
+    def __init__(self, ptr, N, stride):
         self.ptr = ptr
-        self.M = M
+        self.N = N
         self.stride = stride
 
     @triton.jit
@@ -202,9 +202,9 @@ def apply(self, acc, tile: Tile):
         Returns:
             Accumulator with bias added
         """
-        rm, _ = tile.indices()
-        bias_vector = tl.load(self.ptr + rm * self.stride, mask=rm < self.M, other=0.0)
-        acc = acc + bias_vector[:, None]
+        _, rn = tile.indices()
+        bias_vector = tl.load(self.ptr + rn * self.stride, mask=rn < self.N, other=0.0)
+        acc = acc + bias_vector[None, :]
         return acc
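This orientation change is the substantive fix in this file: a GEMM bias is a length-N vector, one value per output column, broadcast across the M rows, matching torch.addmm and nn.Linear semantics. A quick reference check in plain PyTorch (illustration only, not part of the PR):

    import torch

    M, N, K = 4, 8, 16
    a = torch.randn(M, K)
    b = torch.randn(K, N)
    bias = torch.randn(N)            # one value per output column

    out = a @ b + bias[None, :]      # broadcast across rows, as the fixed kernel now does
    ref = torch.addmm(bias, a, b)    # PyTorch's fused equivalent
    assert torch.allclose(out, ref, atol=1e-5)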


@@ -317,7 +317,7 @@ def load(self, tile: Tile, boundary: tl.constexpr = False, cache_modifier: tl.co
 # =============================================================================
 
 @triton.jit
-def make_input_view(ptr, rows, cols, stride_row, stride_col) -> InputView:
+def make_input_view(ptr, rows, cols, stride_row, stride_col):
     """
     Create an InputView with automatic stride type coercion.
 
@@ -327,7 +327,7 @@ def make_input_view(ptr, rows, cols, stride_row, stride_col):
     Args:
         ptr: Base pointer to matrix data
-        rows: Number of rows (first dimension) - must be a tensor
+        rows: Number of rows (first dimension)
         cols: Number of columns (second dimension)
         stride_row: Stride when moving along rows
         stride_col: Stride when moving along columns
@@ -347,23 +347,25 @@ def make_input_view(ptr, rows, cols, stride_row, stride_col):
     # TYPE PROMOTION TRICK
     # ═══════════════════════════════════════════════════════════════════════
     # Triton aggregates require strongly-typed fields (tl.tensor). However,
-    # strides can be either Python ints (stride=1 for contiguous dimensions)
-    # or Triton tensors (stride>1 from kernel params).
+    # dimensions and strides can be either Python ints or Triton tensors,
+    # especially under torch.compile which may pass ints during tracing.
     #
-    # The pattern `stride + 0 * rows` promotes any int to a tensor:
-    #   - 0 * rows produces a tensor with value 0 (since rows is a tensor)
-    #   - stride + (tensor 0) = tensor with stride's value
+    # The pattern `value + 0 * stride_row` promotes any int to a tensor:
+    #   - 0 * stride_row produces a tensor with value 0 (since stride_row is a tensor)
+    #   - value + (tensor 0) = tensor with value
     #
     # This has ZERO runtime cost - the compiler constant-folds 0*x and x+0.
     # ═══════════════════════════════════════════════════════════════════════
+    rows_t = rows + 0 * rows
+    cols_t = cols + 0 * rows
     stride_row_t = stride_row + 0 * rows
     stride_col_t = stride_col + 0 * rows
 
-    return InputView(ptr, rows, cols, stride_row_t, stride_col_t)
+    return InputView(ptr, rows_t, cols_t, stride_row_t, stride_col_t)


 @triton.jit
-def make_output_view(ptr, rows, cols, stride_row, stride_col) -> OutputView:
+def make_output_view(ptr, rows, cols, stride_row, stride_col):
     """
     Create an OutputView with automatic stride type coercion.
 
@@ -372,7 +374,7 @@ def make_output_view(ptr, rows, cols, stride_row, stride_col):
     Args:
         ptr: Base pointer to matrix data
-        rows: Number of rows (first dimension) - must be a tensor
+        rows: Number of rows (first dimension)
         cols: Number of columns (second dimension)
         stride_row: Stride when moving along rows
         stride_col: Stride when moving along columns
@@ -388,18 +390,20 @@ def make_output_view(ptr, rows, cols, stride_row, stride_col):
     # ═══════════════════════════════════════════════════════════════════════
     # TYPE PROMOTION TRICK - See make_input_view() for detailed explanation
     # ═══════════════════════════════════════════════════════════════════════
+    rows_t = rows + 0 * rows
+    cols_t = cols + 0 * rows
     stride_row_t = stride_row + 0 * rows
     stride_col_t = stride_col + 0 * rows
 
-    return OutputView(ptr, rows, cols, stride_row_t, stride_col_t)
+    return OutputView(ptr, rows_t, cols_t, stride_row_t, stride_col_t)
 
 
 # Alias for backward compatibility
 make_tensor_view = make_input_view
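The promotion idiom both factories rely on is easy to check in isolation. A scalar kernel argument normally arrives inside a Triton kernel as a tl.tensor, while a literal like stride = 1 stays a Python int at trace time; multiplying the tensor by zero and adding it converts the int at no runtime cost. A standalone sketch (the _promote_demo kernel is hypothetical, for illustration, and needs a ROCm or CUDA device):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _promote_demo(out_ptr, rows):
        stride = 1                    # a plain Python int inside the kernel
        stride_t = stride + 0 * rows  # 0 * rows is a tensor zero, so the sum becomes a tl.tensor
        tl.store(out_ptr, stride_t)

    out = torch.zeros(1, dtype=torch.int32, device="cuda")
    _promote_demo[(1,)](out, 7)       # pass rows != 1 so Triton does not specialize it to a constexpr
    assert out.item() == 1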


 @triton.jit
-def make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1) -> ScaleView:
+def make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1):
     """
     Create a ScaleView for quantized GEMM epilogue.
 
@@ -430,29 +434,30 @@ def make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1) -> ScaleView:
 
 
 @triton.jit
-def make_bias_view(bias_ptr, M, stride=1) -> BiasView:
+def make_bias_view(bias_ptr, N, stride=1):
     """
     Create a BiasView for GEMM epilogue.
 
     Stores bias vector pointer with automatic stride type coercion.
 
     Args:
-        bias_ptr: Pointer to bias vector (length M)
-        M: Number of rows (for bounds checking) - must be a tensor
+        bias_ptr: Pointer to bias vector (length N)
+        N: Number of columns (for bounds checking)
         stride: Stride for bias vector (default: 1)
 
     Returns:
         BiasView with all fields as tensors
 
     Example::
 
-        bias_view = make_bias_view(bias_ptr, M, stride_bias)
+        bias_view = make_bias_view(bias_ptr, N, stride_bias)
         tensorC.store(acc, out_tile, bias=bias_view)
     """
     # Type promotion for stride
-    stride_t = stride + 0 * M
+    stride_t = stride + 0 * N
+    N_t = N + 0 * N
 
-    return BiasView(bias_ptr, M, stride_t)
+    return BiasView(bias_ptr, N_t, stride_t)
 
 
 # =============================================================================
16 changes: 16 additions & 0 deletions include/tritonblas/kernels/stages/schedule.py
@@ -255,3 +255,19 @@ def total_tiles(self):
         num_pid_m = tl.cdiv(self.M, self.ctx.block_m)
         num_pid_n = tl.cdiv(self.N, self.ctx.block_n)
         return num_pid_m * num_pid_n
+
+
+@triton.jit
+def make_schedule_context(M, N, K, ctx: GemmContext, streamk_tiles=0):
+    """
+    Create a ScheduleContext from a GemmContext.
+
+    Args:
+        M, N, K: Problem dimensions
+        ctx: GemmContext with block sizes and scheduling parameters
+        streamk_tiles: Number of tiles for Stream-K (0 = persistent only)
+    """
+    M_t = M + 0 * M
+    N_t = N + 0 * M
+    K_t = K + 0 * M
+    return ScheduleContext(M_t, N_t, K_t, ctx, streamk_tiles)

Collaborator review comment on make_schedule_context: "This API is required because M, N, K can be const or non-const depending on their value (const if it's one), so we type promote here to match strong typing in aggregates."