forked from tile-ai/tilelang
-
Notifications
You must be signed in to change notification settings - Fork 6
[Feature Request] [PASS] Lower GEMM to intrinsic #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
JiaqiGuoSunlune
merged 14 commits into
SUNMMIO:tilelang_mesh_main
from
wanghz18:gemmlower
Feb 13, 2026
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
caf5db8
change if statement position
wanghz18 e483cd0
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 1463fa3
Merge branch 'SUNMMIO:tilelang_mesh_main' into gemmlower
wanghz18 5394525
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 5ce5b7b
Merge branch 'SUNMMIO:tilelang_mesh_main' into gemmlower
wanghz18 6632bbf
Merge branch 'gemmlower' of github.com:wanghz18/Tilelang-Mesh into ge…
wanghz18 7e346d4
gemm lower implementation
wanghz18 ae0b81e
add rank for mma
wanghz18 3072e17
Merge branch 'SUNMMIO:tilelang_mesh_main' into gemmlower
wanghz18 56d8b1d
add scope check for mma sunmmio
wanghz18 36f1b31
Merge branch 'tilelang_mesh_main' of github.com:SUNMMIO/Tilelang-Mesh…
wanghz18 06c3128
fix parameters to region
wanghz18 0d8d1bc
fix data type in test cases
wanghz18 7d7cc55
fix according to comments
wanghz18 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| import tilelang | ||
| import pytest | ||
| from tilelang import tvm as tvm | ||
| from tilelang.utils.target import determine_target | ||
| import tilelang as tl | ||
| import tilelang.language as T | ||
|
|
||
| tilelang.env.disable_cache() | ||
|
|
||
|
|
||
def layout_func(i, j, continuous):
    """Map a (row, col) coordinate to its linear offset in a 32x32 block-wise layout.

    The matrix is carved into 32x32 tiles stored contiguously; `continuous` is the
    row length (number of columns) of the full matrix, used to compute how many
    tiles fit per tile-row.
    """
    tile_row, in_row = divmod(i, 32)
    tile_col, in_col = divmod(j, 32)
    tiles_per_row = continuous // 32
    # Offset of the whole 32x32 tile, then the element's offset inside it.
    tile_index = tile_row * tiles_per_row + tile_col
    return tile_index * 32 * 32 + in_row * 32 + in_col
|
|
||
|
|
||
def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float32):
    """Build an IRModule with a tiled matmul kernel exercising T.gemm_v1/T.gemm_v2.

    The gemm call operates on fixed sub-regions of the shared buffers
    (A[0:8, 16:32], B[0:16, 8:16], C[8:24, 16:32]) regardless of the block
    sizes, so the expected lowered intrinsic is identical across tile
    configurations.  `version` selects which gemm front-end op is emitted.

    NOTE(review): C_shared is never cleared (the T.clear call is commented
    out) — presumably intentional for this lowering test, since only the
    emitted statement is checked, not numerical results.  Confirm if this
    kernel is ever executed.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        # Grid over output tiles; 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)

            # T.clear(C_shared)
            # Software-pipelined reduction over K tiles.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                if version == 1:
                    T.gemm_v1(
                        A_shared[0:8, 16:32],
                        B_shared[0:16, 8:16],
                        C_shared[8:24, 16:32],
                        transpose_A=True,
                        transpose_B=True)
                elif version == 2:
                    T.gemm_v2(
                        A_shared[0:8, 16:32],
                        B_shared[0:16, 8:16],
                        C_shared[8:24, 16:32],
                        transpose_A=True,
                        transpose_B=True)
                else:
                    raise ValueError(f'unsupported gemm version: {version}')

            # Write the accumulated tile back to global memory.
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return tvm.IRModule({'main': main})
|
|
||
|
|
||
# Expected lowered TIR call for the Sunmmio mma intrinsic.  Both tile
# configurations lower to the same statement (the region extents come from the
# fixed T.gemm_v* sub-region arguments, not from block_M/block_N/block_K), so
# a single shared constant replaces the previously duplicated literal.
_MMA_SUNMMIO_STMT = (
    "T.mma_sunmmio(T.region(A_shared[0, 16], 1, 8, 16), "
    "T.region(B_shared[0, 8], 1, 16, 8), "
    "T.region(C_shared[8, 16], 3, 16, 16), "
    "T.bool(True), T.bool(True), T.bool(False))"
)

# Kept as a two-element list for backward compatibility with stmts[0]/stmts[1].
stmts = [_MMA_SUNMMIO_STMT, _MMA_SUNMMIO_STMT]

# (M, N, K, block_M, block_N, block_K, version, expected lowered statement)
TEST_CASES = [
    # gemm v1
    (128, 128, 128, 32, 32, 32, 1, stmts[0]),
    (128, 128, 128, 64, 32, 64, 1, stmts[1]),
    # gemm v2
    (128, 128, 128, 32, 32, 32, 2, stmts[0]),
    (128, 128, 128, 64, 32, 64, 2, stmts[1]),
]
|
|
||
|
|
||
@pytest.mark.parametrize(
    "M, N, K, block_M, block_N, block_K, version, lower_stmt",
    TEST_CASES,
)
def test_tilelang_gemm_sunmmio_layout(M, N, K, block_M, block_N, block_K, version, lower_stmt):
    """Lower a tiled GEMM for the Sunmmio target and check the emitted mma intrinsic."""
    target = determine_target("Sunmmio", return_object=True)
    with tvm.target.Target(target):
        mod = matmul(M, N, K, block_M, block_N, block_K, version)
        # Apply the relevant prefix of the lowering pipeline, in order.
        for transform in (
                tvm.tir.transform.BindTarget(target),
                tilelang.transform.InferSramScope(),
                tl.transform.LayoutInference(),
                tl.transform.LowerTileOp(),
        ):
            mod = transform(mod)
        # The lowered mma call is the last statement of the printed module;
        # mod.script() ends with a trailing newline, hence index -2.
        emitted = mod.script().split('\n')[-2].lstrip()
        assert emitted == lower_stmt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,128 +1,59 @@ | ||
| from .gemm_base import GemmBase | ||
| from tilelang.layout import make_blockwise_zz_layout | ||
| from tilelang.intrinsics.mma_macro_generator import ( | ||
| TensorCoreIntrinEmitter,) | ||
| from tilelang.utils.language import is_shared, is_full_region | ||
| from tilelang import tvm as tvm | ||
| from tvm.target import Target | ||
| from tvm import tir | ||
| from tilelang import language as T | ||
| from tilelang.transform.simplify import _Simplify | ||
| from tilelang import language as T | ||
| from tilelang.utils.language import ( | ||
| retrieve_shape,) | ||
| from tilelang.language.utils import ( | ||
| buffer_region_to_tile_region,) | ||
|
|
||
|
|
||
| class GemmSunmmio(GemmBase): | ||
|
|
||
| def infer_layout(self, target: Target, thread_nums: int): | ||
| if self.is_gemm_sss(): | ||
| if self.is_gemm_sunmmio_scope(): | ||
| return { | ||
| self.A: make_blockwise_zz_layout(self.A), | ||
| self.B: make_blockwise_zz_layout(self.B), | ||
| self.C: make_blockwise_zz_layout(self.C), | ||
| } | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" | ||
| f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" | ||
| ) | ||
|
|
||
    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        """Lower the tile-level GEMM to the opaque ``tl.mma_sunmmio`` intrinsic.

        The A/B/C buffer regions are wrapped into tile regions (A and B
        read-only, C read-write) so that strided sub-tile GEMMs — e.g.
        ``T.gemm(A_shared[0:8, 16:32], ...)`` — carry their extent
        information into the intrinsic call.

        NOTE(review): the PR diff also shows full-region asserts on the A/B
        regions near the args construction; confirm whether they belong in
        the final version of this method.
        """
        if self.is_gemm_sunmmio_scope():
            A_shape = retrieve_shape(self.ARegion)
            B_shape = retrieve_shape(self.BRegion)
            C_shape = retrieve_shape(self.CRegion)
            # Access masks: inputs are read-only ("r"), the accumulator is
            # read-write ("rw").
            A_arg = buffer_region_to_tile_region(self.ARegion, "r", [r for r in A_shape])
            B_arg = buffer_region_to_tile_region(self.BRegion, "r", [r for r in B_shape])
            C_arg = buffer_region_to_tile_region(self.CRegion, "rw", [r for r in C_shape])

            # Intrinsic operands: regions, transpose flags, clear-accumulator flag.
            args = [A_arg, B_arg, C_arg, self.trans_A, self.trans_B, self.clear_accum]

            @T.prim_func
            def _gemm_sss() -> None:
                # Emit a single opaque call; the actual mma scheduling is
                # handled by the Sunmmio backend.
                tir.call_intrin(
                    "handle",
                    tir.op.Op.get("tl.mma_sunmmio"),
                    *args,
                )

            # Simplify to optimize the index computing.
            # Must inline let statements to simplify the analysis.
            return _Simplify(_gemm_sss, inline_let=True)
        else:
            raise ValueError(
                f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}"
            )
|
|
||
| def is_gemm_sss(self) -> bool: | ||
| return is_shared(self.A) and is_shared(self.B) and is_shared(self.C) | ||
| def is_gemm_sunmmio_scope(self) -> bool: | ||
| a_check = self.A.scope() == 'shared.asram' | ||
| b_check = self.B.scope() == 'shared.wsram' | ||
| c_check = self.C.scope() == 'shared.rsram' | ||
| return a_check and b_check and c_check |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't support trans_A in Sunmmio. OK to keep it right now