From 438c9d6b4ee8a3ef8a376840cac68bf1e153076a Mon Sep 17 00:00:00 2001 From: Devanshu Ladsaria Date: Fri, 12 Sep 2025 23:22:21 +0000 Subject: [PATCH 1/3] WCC Grouped Gemm Implementation --- include/tritonblas/__init__.py | 1 + include/tritonblas/grouped_gemm.py | 91 +++++ .../tritonblas/internal/wcc_grouped_gemm.py | 385 ++++++++++++++++++ tests/test_grouped_gemm.py | 49 +++ 4 files changed, 526 insertions(+) create mode 100644 include/tritonblas/grouped_gemm.py create mode 100644 include/tritonblas/internal/wcc_grouped_gemm.py create mode 100644 tests/test_grouped_gemm.py diff --git a/include/tritonblas/__init__.py b/include/tritonblas/__init__.py index 3ef3fe2..46a2d70 100644 --- a/include/tritonblas/__init__.py +++ b/include/tritonblas/__init__.py @@ -1,3 +1,4 @@ from .matmul import matmul from .matmul import matmul_lt +from .grouped_gemm import grouped_gemm from .origami import MatmulHeuristicResult diff --git a/include/tritonblas/grouped_gemm.py b/include/tritonblas/grouped_gemm.py new file mode 100644 index 0000000..d6c7208 --- /dev/null +++ b/include/tritonblas/grouped_gemm.py @@ -0,0 +1,91 @@ +import torch +import triton +import random +import functools +import time +import math +from .internal.wcc_grouped_gemm import wcc_groupgemm +from .origami import MatmulHeuristicResult + +_tensor_cache = {} +current_device_index = torch.cuda.current_device() +current_device = torch.cuda.get_device_properties(current_device_index) +MAX_SMS = current_device.multi_processor_count +#TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4 +MAX_BLOCK_SIZE = 65536 + + + +def grouped_gemm( + group_a: list[torch.Tensor], + group_b: list[torch.Tensor], + group_c: list[torch.Tensor], + BLK_M: int, + BLK_N: int, + BLK_K: int, + ): + + group_size = len(group_a) + a_addrs, b_addrs, c_addrs = [], [], [] + g_sizes, g_lds = [], [] + + for i in range(group_size): + A, B, C = group_a[i], group_b[i], group_c[i] + assert A.shape[1] == B.shape[0], "Incompatible Dimensions" + m, k = A.shape + _, n = B.shape + a_addrs.append(A.data_ptr()) + b_addrs.append(B.data_ptr()) + c_addrs.append(C.data_ptr()) + g_sizes.extend([m, n, k]) + g_lds.extend([A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1)]) + + d_a_ptrs = torch.tensor(a_addrs, device="cuda", dtype=torch.int64) + d_b_ptrs = torch.tensor(b_addrs, device="cuda", dtype=torch.int64) + d_c_ptrs = torch.tensor(c_addrs, device="cuda", dtype=torch.int64) + d_g_sizes = torch.tensor(g_sizes, device="cuda", dtype=torch.int32) + d_g_lds = torch.tensor(g_lds, device="cuda", dtype=torch.int32) + + grids = MAX_SMS + locks = torch.zeros((MAX_SMS,), device="cuda", dtype=torch.int32) + P = torch.zeros((MAX_SMS, BLK_M * BLK_N), device="cuda", dtype=torch.float32) + + group_tiles_count = [] + total = 0 + for g in range(group_size): + mm = math.ceil(g_sizes[g * 3] / BLK_M) + nn = math.ceil(g_sizes[g * 3 + 1] / BLK_N) + kk = math.ceil(g_sizes[g * 3 + 2] / BLK_K) + gemm_tiles = nn * mm * kk + total += gemm_tiles + group_tiles_count.append(int(gemm_tiles)) + + gemm_offsets = [0] + for count in group_tiles_count: + gemm_offsets.append(gemm_offsets[-1] + count) + + group_total_tiles = total + streamk_tiles_pcu = group_total_tiles // MAX_SMS + streamk_remainder_tiles = group_total_tiles % MAX_SMS + d_gemm_offsets = torch.tensor(gemm_offsets, dtype=torch.int32, device="cuda") + + wcc_groupgemm[(grids,)]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_gemm_offsets, + d_g_lds, + group_size, + P, + locks, + streamk_tiles_pcu=streamk_tiles_pcu, + 
        streamk_remainder_tiles=streamk_remainder_tiles,
+        BLOCK_SIZE_M=BLK_M,
+        BLOCK_SIZE_N=BLK_N,
+        BLOCK_SIZE_K=BLK_K,
+        GROUP_SIZE_M=1,
+        NUM_PRGMS=MAX_SMS,
+        NUM_XCDS=8,
+    )
+    return group_c
\ No newline at end of file
diff --git a/include/tritonblas/internal/wcc_grouped_gemm.py b/include/tritonblas/internal/wcc_grouped_gemm.py
new file mode 100644
index 0000000..487cf09
--- /dev/null
+++ b/include/tritonblas/internal/wcc_grouped_gemm.py
@@ -0,0 +1,385 @@
+import triton
+import triton.language as tl
+import triton.profiler.language as pl
+
+
+"""
+This is a user-defined function where partial results are stored and accumulated.
+
+Input:
+pid <- Uint : The program ID of this program. Positive integer in the range [0, NUM_PRGMS)
+start_index <- Uint : The starting index of "work_tile" that this program will process. Positive integer
+end_index <- Uint : The last index (exclusive) of "work_tile" that this program will process. Positive integer
+tile_id <- Uint : ID of the tile this program is processing. Positive integer
+tile_offset <- Uint : The offset within the tile. [0, work_tile)
+work_tile <- Uint : User provides the smallest chunk/tile of work. Positive integer
+NUM_PRGMS <- Uint : Total number of programs that the kernel was launched with. Positive integer
+
+
+The following inputs could be variadic arguments or arguments passed in an object by the user.
+However, they are passed in as plain arguments here because Triton does not support variadic arguments or objects.
+partials <- float : This holds the partials computed by the program so far
+P <- tensor : A user-defined tensor used to hold partial values
+locks <- tensor : A user-defined tensor used for bookkeeping of accumulated partials
+streamk_tiles_pcu <- uint : Total number of tiles per CU
+streamk_remainder_tiles <- uint : Remainder number of tiles
+BLOCK_SIZE_M <- uint : Used to compute the size of the partial tile to store
+BLOCK_SIZE_N <- uint : Used to compute the size of the partial tile to store
+
+Output:
+partials <- tensor : The fully accumulated tile, ready to be stored
+"""
+
+
+@triton.jit
+def accumulate_partials(
+    pid,
+    start_index,
+    end_index,
+    tile_id,
+    tile_offset,
+    work_tile,
+    partials,
+    P,
+    locks,
+    NUM_PRGMS: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    streamk_tiles_pcu: tl.constexpr,
+    streamk_remainder_tiles: tl.constexpr,
+):
+    if tile_offset != 0:
+        rm1 = tl.arange(0, BLOCK_SIZE_M)
+        rn1 = tl.arange(0, BLOCK_SIZE_N)
+        rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M)
+        rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N)
+        P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
+        tl.store(P_, partials, cache_modifier=".wt")
+        tl.store(locks + pid, 1, cache_modifier=".wt")
+    # Only the pid processing the first chunk of the tile does the reduction
+    else:
+        tile_iter = tile_id * work_tile
+        next_pid = pid + 1
+        tile_iter_end = tile_iter + work_tile
+        end = end_index
+        while end < tile_iter_end and next_pid < NUM_PRGMS:
+            while tl.load(locks + next_pid, cache_modifier=".cv", volatile=True) != 1:
+                pass
+            rm1 = tl.arange(0, BLOCK_SIZE_M)
+            rn1 = tl.arange(0, BLOCK_SIZE_N)
+            rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M)
+            rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N)
+            P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
+            partials += tl.load(tl.multiple_of(P_, (1, 16)), cache_modifier=".cv")
+
+            next_pid_start_index = next_pid * streamk_tiles_pcu + tl.minimum(next_pid, streamk_remainder_tiles)
+            next_pid_end_index = (next_pid + 1) * streamk_tiles_pcu + tl.minimum(next_pid + 1,
                                                                               streamk_remainder_tiles)
+            next_pid_curr_gemm_start = tl.maximum(next_pid_start_index, tile_iter)
+            next_pid_curr_gemm_end = tl.minimum(next_pid_end_index, tile_iter_end)
+            num_tiles = next_pid_curr_gemm_end - next_pid_curr_gemm_start
+
+            if num_tiles > 0:
+                end += num_tiles
+            next_pid += 1
+
+    return partials
+
+
+"""
+The function calculates the indices needed for each iteration. Specifically, it calculates the end index of this
+iteration, the tile id of the tile being processed, and the offset within that tile.
+
+Input:
+start_index <- Uint : The starting index of "work_tile" that this program will process. Positive integer
+last_index <- Uint : The last index (exclusive) of "work_tile" that this program will process. Positive integer
+work_tile <- Uint : User provides the smallest chunk/tile of work. Positive integer
+
+Output:
+end_index <- Uint : The index of the last "atomic tile" of this iteration
+tile_id <- Uint : Returns which tile this program is processing
+tile_offset <- Uint : Returns the offset within a tile. [0, work_tile)
+"""
+
+
+@triton.jit
+def per_iter_indices(start_index, last_index, work_tile):
+    tile_offset = start_index % work_tile
+    end_index = tl.minimum(start_index + (work_tile - tile_offset), last_index)
+    tile_id = start_index // work_tile
+    return (end_index, tile_id, tile_offset)
+
+
+"""
+Given streamk_tiles_pcu and streamk_remainder_tiles, the function computes the first and the last index of the work_tiles
+that the given pid will process.
+Inherently, the function splits the work evenly among the pids.
+
+Input:
+pid <- Uint : User provides the PID of the program. Positive integer in the range [0:NUM_PRGMS) # Assuming a 1D grid launch
+streamk_tiles_pcu <- Uint : Total number of tiles per CU
+streamk_remainder_tiles <- Uint : Remainder number of tiles
+NUM_PRGMS <- Uint : Total number of programs that the kernel was launched with. It is needed to split the work. Positive integer
+
+Output:
+start_index <- Uint : The starting index of "work_tile" that this program will process. Positive integer
+last_index <- Uint : The last index (exclusive) of "work_tile" that this program will process. Positive integer
+"""
+
+
+@triton.jit
+def work_split(
+    pid,
+    streamk_tiles_pcu: tl.constexpr,
+    streamk_remainder_tiles: tl.constexpr,
+    NUM_PRGMS: tl.constexpr,
+):
+    start_index = pid * streamk_tiles_pcu + tl.minimum(pid, streamk_remainder_tiles)
+    last_index = (pid + 1) * streamk_tiles_pcu + tl.minimum(pid + 1, streamk_remainder_tiles)
+    return (start_index, last_index)
+
+
+"""
+Work Centric Grouped GEMM implementation
+
+Inputs:
+group_a_ptrs <- pointer : A pointer which points to all the 'A' matrices
+group_b_ptrs <- pointer : A pointer which points to all the 'B' matrices
+group_c_ptrs <- pointer : A pointer which points to all the 'C' matrices
+group_gemm_sizes <- pointer : A pointer which points to all the matrix sizes: [m, n, k]
+gemm_offsets <- tensor : This tensor is essentially a prefix-sum array (with a leading 0) of
+                         the linearized tile counts of all the gemms.
It is needed for each pid + to know which gemm should it be processing +g_lds <- pointer : A pointer which points to all the stride values for each matrix + [A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1)] +group_size <- uint : Total number of gemm we have +P <- tensor : A user defined tensor used to hold partial values +locks <- tensor : A user defined tensor used for book keeping of accumulated partials +streamk_tiles_pcu <- uint : Total number of tiles per CU +streamk_remainder_tiles <- uint : Remainder number of tiles +BLOCK_SIZE_M <- uint : Block size in the 'm' dimension +BLOCK_SIZE_N <- uint : Block size in the 'n' dimension +BLOCK_SIZE_K <- uint : Block size in the 'k' dimension +GROUP_SIZE_M <- uint : To assign work in a more lds friendly manner +NUM_PRGMS <- uint : The number of program the kernel was launched with +NUM_XCDS <- uint : Total number of XCDS in the hardware + +Output +group_c_ptrs <- pointer : A pointer which points to all the 'C' matrices (now populated) +""" + + +@triton.jit() +def wcc_groupgemm( + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + group_gemm_sizes, + gemm_offsets, + g_lds, + group_size, + P, + locks, + streamk_tiles_pcu: tl.constexpr, + streamk_remainder_tiles: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_PRGMS: tl.constexpr, + NUM_XCDS: tl.constexpr, +): + pid = tl.program_id(0) + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_PRGMS // NUM_XCDS) + (pid // NUM_XCDS) + + # Calculate the start and last tile this pid would be processing. This is linearized tiles + # for all the gemms + start_index, last_iter = work_split(pid, streamk_tiles_pcu, streamk_remainder_tiles, NUM_PRGMS) + for g in range(group_size): + # Check to see if this pid needs to process the "g th" gemm + g_val = tl.load(gemm_offsets + g + 1) + if start_index < g_val and start_index != last_iter: + # If it does, find the end of that gemm + last_index = tl.minimum(last_iter, g_val) + # Core loop + while start_index < last_index: + # Load in all the corresponding data for that gemm + M = tl.load(group_gemm_sizes + g * 3) + N = tl.load(group_gemm_sizes + g * 3 + 1) + K = tl.load(group_gemm_sizes + g * 3 + 2) + + A = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + B = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + C = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + + stride_am = tl.load(g_lds + g * 6) + stride_ak = tl.load(g_lds + g * 6 + 1) + stride_bk = tl.load(g_lds + g * 6 + 2) + stride_bn = tl.load(g_lds + g * 6 + 3) + stride_cm = tl.load(g_lds + g * 6 + 4) + stride_cn = tl.load(g_lds + g * 6 + 5) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + work_tile = tl.cdiv(K, BLOCK_SIZE_K) + + acc_dtype = tl.float32 # if C.type.element_ty != tl.int8 else tl.int32 + end_index, tile_id, tile_offset = per_iter_indices(start_index, last_index, work_tile) + + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + tl.assume(pid_m > 0) + tl.assume(pid_n > 0) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + rk = tl.arange(0, BLOCK_SIZE_K) + + """ + The following two lines, support all transpose types, however the triton compiler is unable to optimize + it, leading to 'short' 
loads rather than 'dwordx4' loads + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder + """ + A_BASE = A + rm[:, None] * stride_am + rk[None, :] + (BLOCK_SIZE_K * tile_offset) + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] + (BLOCK_SIZE_K * stride_bk * tile_offset) + partials = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for current_iter in range(start_index, end_index): + """ + The following masking logic is omitted because it leads to 'short' loads rather than 'dwordx4' loads + However, it has been tested and can be added back anytime + """ + # global_k_offset = (current_iter % work_tile) * BLOCK_SIZE_K + # mask = global_k_offset + rk < K + A_BASE = tl.multiple_of(A_BASE, (16, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 16)) + a = tl.load(A_BASE) + b = tl.load(B_BASE) + # do the actual gemm computation + partials += tl.dot(a, b) + # The following line has been omitted to make sure loads are 'dwordx4' + # A_BASE += BLOCK_SIZE_K * stride_ak + A_BASE += BLOCK_SIZE_K + B_BASE += BLOCK_SIZE_K * stride_bk + + # work_fixup() + partials = accumulate_partials( + pid, + start_index, + end_index, + tile_id, + tile_offset, + work_tile, + partials, + P, + locks, + NUM_PRGMS, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + streamk_tiles_pcu, + streamk_remainder_tiles, + ) + # Only the pid which does the first chunk of the tiles, stores the whole tile to output + if tile_offset == 0: + c = partials.to(C.type.element_ty) + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] + """ + The following two lines are omitted because they cause the stores to be 'short' rather than + 'dwordx2', however, they are tested and can be added back anytime + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + """ + C_ = tl.multiple_of(C_, (16, 16)) + tl.store(C_, c) + + start_index = end_index + + +""" +The following function was adapted from https://triton-lang.org/main/getting-started/tutorials/08-grouped-gemm.html +This function is used as a reference to check correctness of WCC grouped gemm. It only supports square matrices +and only one transpose type +""" + + +@triton.jit +def grouped_matmul_kernel( + # device tensor of matrices pointers + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + # device tensor of gemm sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + group_gemm_sizes, + # device tensor of leading dimension sizes. 
its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + g_lds, + # number of gemms + group_size, + # number of virtual SM + NUM_SM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + # get the gemm size of the current problem + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + # iterate through the tiles in the current gemm problem + while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles: + # pick up a tile from the current gemm problem + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + # figure out tile coordinates + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + # do regular gemm here + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + # hint to Triton compiler to do proper loop pipelining + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + # assume full tile for now + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + # assumes full tile for now + tl.store(c_ptrs, c) + + # go to the next tile by advancing NUM_SM + tile_idx += NUM_SM + + # get ready to go to the next gemm problem + last_problem_end = last_problem_end + num_tiles diff --git a/tests/test_grouped_gemm.py b/tests/test_grouped_gemm.py new file mode 100644 index 0000000..90b4484 --- /dev/null +++ b/tests/test_grouped_gemm.py @@ -0,0 +1,49 @@ +import pytest +import torch +import triton +import tritonblas + + +@pytest.mark.parametrize( + "m, n, k, total_programs_streamk, in_dtype, out_dtype", + [ + (8192, 8192, 8192, 304, torch.float16, torch.float16), + (4864, 8192, 4160, 304, torch.float16, torch.float16), + ], +) +@pytest.mark.parametrize( + "BLK_M, BLK_N, BLK_K", + [ + (256, 256, 64), + (128, 128, 64), + (256, 128, 64), + ], +) +@pytest.mark.parametrize("gsize_m", [1]) +@pytest.mark.parametrize("group_size", [1, 2, 4, 6, 8]) +def test_grouped_gemm(m, n, k, total_programs_streamk, in_dtype, out_dtype, BLK_M, BLK_N, BLK_K, gsize_m, group_size): + + group_A = [] + group_B = [] + group_C = [] + torch_result = [] + + for i in range(group_size): + A = torch.randn(m, k, device="cuda", dtype=in_dtype) + B = torch.randn(k, n, device="cuda", dtype=in_dtype) + C = torch.empty((m, n), 
device="cuda", dtype=out_dtype) + group_A.append(A) + group_B.append(B) + group_C.append(C) + torch_result.append(torch.matmul(A, B)) + + tritonblas.grouped_gemm( + group_A, + group_B, + group_C, + BLK_M, + BLK_N, + BLK_K, + ) + for i in range(group_size): + torch.testing.assert_close(torch_result[i], group_C[i], atol=0.5, rtol=0.5) From 04ffcca286308fa902ed046963d22713a65b3baf Mon Sep 17 00:00:00 2001 From: Muhammad Osama Date: Fri, 7 Nov 2025 08:12:14 -0800 Subject: [PATCH 2/3] Update include/tritonblas/grouped_gemm.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/tritonblas/grouped_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tritonblas/grouped_gemm.py b/include/tritonblas/grouped_gemm.py index d6c7208..deae766 100644 --- a/include/tritonblas/grouped_gemm.py +++ b/include/tritonblas/grouped_gemm.py @@ -16,7 +16,7 @@ -def grouped_gemm( +def grouped_gemm( group_a: list[torch.Tensor], group_b: list[torch.Tensor], group_c: list[torch.Tensor], From a1dc950e361c4a0db2fce34efc3f1dbfafca5ed9 Mon Sep 17 00:00:00 2001 From: Muhammad Osama Date: Fri, 7 Nov 2025 08:14:11 -0800 Subject: [PATCH 3/3] Update include/tritonblas/internal/wcc_grouped_gemm.py --- include/tritonblas/internal/wcc_grouped_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tritonblas/internal/wcc_grouped_gemm.py b/include/tritonblas/internal/wcc_grouped_gemm.py index 487cf09..cce1b66 100644 --- a/include/tritonblas/internal/wcc_grouped_gemm.py +++ b/include/tritonblas/internal/wcc_grouped_gemm.py @@ -226,8 +226,8 @@ def wcc_groupgemm( pid_m = tile_id // num_pid_n pid_n = tile_id % num_pid_n - tl.assume(pid_m > 0) - tl.assume(pid_n > 0) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
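---
For context on the API this series adds, here is a minimal host-side usage sketch mirroring tests/test_grouped_gemm.py. The shapes, group size, and tolerances are illustrative, not taken from the patch; dimensions are chosen as exact multiples of the block sizes because the kernel deliberately omits boundary masks, and an AMD GPU is assumed (the wrapper hard-codes NUM_XCDS=8 and sizes the grid from multi_processor_count).

import torch
import tritonblas

# Four GEMMs of identical shape; every dimension is a multiple of the block sizes below.
group_A = [torch.randn(512, 256, device="cuda", dtype=torch.float16) for _ in range(4)]
group_B = [torch.randn(256, 384, device="cuda", dtype=torch.float16) for _ in range(4)]
group_C = [torch.empty(512, 384, device="cuda", dtype=torch.float16) for _ in range(4)]

# Positional BLK_M, BLK_N, BLK_K, as in the test parametrization.
tritonblas.grouped_gemm(group_A, group_B, group_C, 128, 128, 64)

for A, B, C in zip(group_A, group_B, group_C):
    torch.testing.assert_close(C, A @ B, atol=0.5, rtol=0.5)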
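The host-side scheduling in grouped_gemm() together with work_split()/per_iter_indices() in the kernel reduces to a small amount of integer arithmetic. The snippet below re-derives it in plain Python for intuition only; the problem sizes and program count are made up, and the helper names simply mirror the kernel's.

import math

# Plain-Python re-derivation of the tile bookkeeping in grouped_gemm() and of
# work_split()/per_iter_indices(). Illustration only: the group below holds two
# identically sized GEMMs and uses a made-up program count.
sizes = [(256, 384, 192), (256, 384, 192)]        # (M, N, K) per GEMM in the group
BLK_M, BLK_N, BLK_K, NUM_PRGMS = 128, 128, 64, 8  # toy values; the wrapper uses the SM count

# Linearized work per GEMM: every output tile is split into cdiv(K, BLK_K) K-chunks.
tiles = [math.ceil(m / BLK_M) * math.ceil(n / BLK_N) * math.ceil(k / BLK_K) for m, n, k in sizes]

offsets = [0]
for t in tiles:                                   # same prefix-sum array as d_gemm_offsets
    offsets.append(offsets[-1] + t)

total = offsets[-1]                               # group_total_tiles
tiles_pcu, rem = total // NUM_PRGMS, total % NUM_PRGMS

def work_split(pid):
    # Even split of the linearized K-chunks; the first `rem` programs take one extra.
    return pid * tiles_pcu + min(pid, rem), (pid + 1) * tiles_pcu + min(pid + 1, rem)

def per_iter_indices(start, last, work_tile):
    # Clip [start, last) to the output tile containing `start`; work_tile = cdiv(K, BLK_K).
    tile_offset = start % work_tile
    end = min(start + (work_tile - tile_offset), last)
    return end, start // work_tile, tile_offset

for pid in range(NUM_PRGMS):
    print(pid, work_split(pid))
print(per_iter_indices(3, 5, 3))  # -> (5, 1, 0): chunks 3..4 of output tile 1, starting at its first chunk

With these toy values each program gets 4 or 5 of the 36 linearized K-chunks while work_tile is 3, so most per-program ranges straddle an output-tile boundary; that is exactly the case the stream-K fix-up path handles.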
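Finally, a rough emulation of the fix-up performed by accumulate_partials(): programs that only own a later K-chunk of an output tile publish their partial (the P/locks pair), and the program owning the tile's first K-chunk folds those partials in before storing C. This is a sketch under simplifying assumptions (scalars instead of BLOCK_SIZE_M x BLOCK_SIZE_N tiles, hand-picked split points, reverse-pid processing standing in for the spin-wait), not the kernel's actual execution model.

# Lock-free, sequential emulation of accumulate_partials(); scalars stand in for the
# BLOCK_SIZE_M x BLOCK_SIZE_N partial tiles held in P.
work_tile = 4                            # K-chunks per output tile, i.e. cdiv(K, BLOCK_SIZE_K)
chunks = list(range(12))                 # per-chunk contributions for one GEMM (3 output tiles)
splits = [(0, 5), (5, 10), (10, 12)]     # hand-picked (start_index, last_index) per pid

published = {}                           # pid -> (tile_id, partial); stands in for P + locks
results = {}                             # tile_id -> fully accumulated value (the C store)

# Reverse pid order so every publisher runs before the owner that consumes it; in the
# real kernel the spin-wait on `locks` provides the same guarantee.
for pid, (start, last) in sorted(enumerate(splits), reverse=True):
    while start < last:
        tile_offset = start % work_tile                  # per_iter_indices()
        end = min(start + (work_tile - tile_offset), last)
        tile_id = start // work_tile
        partial = sum(chunks[start:end])                 # stands in for the tl.dot loop
        if tile_offset != 0:
            published[pid] = (tile_id, partial)          # store to P, raise the lock
        else:
            # Owner of the tile's first K-chunk folds in partials from higher pids.
            for other in range(pid + 1, len(splits)):
                if other in published and published[other][0] == tile_id:
                    partial += published[other][1]
            results[tile_id] = partial                   # the final store to C
        start = end

assert results == {t: sum(chunks[t * work_tile:(t + 1) * work_tile]) for t in range(3)}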