From caf5db81699c4609e10180a075418c74a431e064 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 23 Jan 2026 14:41:15 +0800 Subject: [PATCH 1/7] change if statement position --- tilelang/tileop/gemm/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 86caee01f..a292f56d3 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -183,9 +183,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): NotImplementedError: If the instruction type is not supported ValueError: If the instruction type is unknown """ - if gemm_inst.is_sunmmio(): - return GemmSunmmio - elif gemm_inst.is_mma(): + if gemm_inst.is_mma(): if target_is_volta(target): return GemmMMASm70 return GemmMMA @@ -197,5 +195,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): return GemmMFMA elif gemm_inst.is_tcgen5mma(): raise NotImplementedError("TCGEN5MMA is not implemented") + elif gemm_inst.is_sunmmio(): + return GemmSunmmio else: raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}") From 7e346d4a1cef75d543577ad8fcaacb41731f9761 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 6 Feb 2026 19:26:28 +0800 Subject: [PATCH 2/7] gemm lower implementation --- src/op/builtin.cc | 5 + src/op/builtin.h | 77 +++++++++++ src/op/gemm.cc | 99 ++++++++++++++ ...test_tilelang_mesh_language_mma_sunmmio.py | 83 +++++++++++ tilelang/tileop/gemm/gemm_sunmmio.py | 129 +++++------------- tilelang/utils/language.py | 3 + 6 files changed, 303 insertions(+), 93 deletions(-) create mode 100644 testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py diff --git a/src/op/builtin.cc b/src/op/builtin.cc index ed5f3067a..265fa1c42 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -112,6 +112,11 @@ TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr( TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr( 
"TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(mma_sunmmio) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) .set_num_inputs(-1) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 6709f7511..1d068a307 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -279,6 +279,83 @@ TVM_DLL const Op &dma_load(); */ TVM_DLL const Op &dma_store(); +/*! + * \brief tvm intrinsic for mma operation of Sunmmio target. + * + * \param A_region_shape + * Region shape of A operand + * + * \param A_dtype + * Dtype of A operand + * + * \param A_input_size + * Input_shape and of A operand. Input_size + forward = layout. + * + * \param A_forward + * Forward of A operand. Input_size + forward = layout. + * + * \param A_scope + * Scope of A operand. + * + * \param A_coord + * Coord of A operand. + * + * \param A_addr + * Addr of A operand. + * + * \param B_region_shape + * Region shape of B operand + * + * \param B_dtype + * Dtype of B operand + * + * \param B_input_size + * Input_shape and of B operand. Input_size + forward = layout. + * + * \param B_forward + * Forward of B operand. Input_size + forward = layout. + * + * \param B_scope + * Scope of B operand. + * + * \param B_coord + * Coord of B operand. + * + * \param B_addr + * Addr of B operand. + * + * \param C_region_shape + * Region shape of C operand + * + * \param C_dtype + * Dtype of C operand + * + * \param C_input_size + * Input_shape and of C operand. Input_size + forward = layout. + * + * \param C_forward + * Forward of C operand. Input_size + forward = layout. + * + * \param C_scope + * Scope of C operand. + * + * \param C_coord + * Coord of C operand. + * + * \param C_addr + * Addr of C operand. + * + * \param trans_A + * Whether to transpose A operand. + * + * \param trans_B + * Whether to transpose B operand. 
+ * + * \param clear_accum + * Whether to clear accmulation buffer. + */ +TVM_DLL const Op &mma_sunmmio(); + /*! * \brief tvm intrinsics for loading image from global tensor to columns in * shared memory diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 8f4450d5f..c5ba63d1b 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -150,6 +150,10 @@ std::pair GemmWarpPolicyNode::computeWarpPartition( return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning } + if (gemm_inst == GemmInst::kSunmmioMMA) { + return {1, num_warps}; // kSunmmioMMA doesn't care about warp partitioning + } + int m_warp = 1, n_warp = 1; constexpr int kMPerWarp = 16; // Rows processed by a single warp int kNPerWarp = 8; // Columns processed by a single warp @@ -522,6 +526,101 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } + if (gemm_inst == GemmInst::kSunmmioMMA) { + Array args; + + { + for (auto it : aRegion_->region) { + args.push_back(it->extent); + } + args.push_back(StringImm(tvm::runtime::DLDataTypeToString(a_->dtype))); + + ICHECK(T.layout_map.count(a_)) + << "Layout of buffer " << a_ << " not found."; + auto layout = T.layout_map.at(a_); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); + } + args.push_back(StringImm(a_.scope())); + for (auto it : aRegion_->region) { + args.push_back(it->min); + } + PrimExpr total_elements = 1; + for (auto e : a_->shape) { + total_elements *= e; + } + auto addr = a_.access_ptr(1, DataType::Handle(), 1, 0, total_elements); + args.push_back(addr); + } + + { + for (auto it : bRegion_->region) { + args.push_back(it->extent); + } + args.push_back(StringImm(tvm::runtime::DLDataTypeToString(b_->dtype))); + + ICHECK(T.layout_map.count(b_)) + << "Layout of buffer " << b_ << " not found."; + auto layout = T.layout_map.at(b_); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + 
args.push_back(s); + } + args.push_back(StringImm(b_.scope())); + for (auto it : bRegion_->region) { + args.push_back(it->min); + } + PrimExpr total_elements = 1; + for (auto e : b_->shape) { + total_elements *= e; + } + auto addr = b_.access_ptr(1, DataType::Handle(), 1, 0, total_elements); + args.push_back(addr); + } + + { + for (auto it : cRegion_->region) { + args.push_back(it->extent); + } + args.push_back(StringImm(tvm::runtime::DLDataTypeToString(c_->dtype))); + + ICHECK(T.layout_map.count(c_)) + << "Layout of buffer " << c_ << " not found."; + auto layout = T.layout_map.at(c_); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); + } + args.push_back(StringImm(c_.scope())); + for (auto it : cRegion_->region) { + args.push_back(it->min); + } + PrimExpr total_elements = 1; + for (auto e : c_->shape) { + total_elements *= e; + } + auto addr = c_.access_ptr(2, DataType::Handle(), 1, 0, total_elements); + args.push_back(addr); + } + + args.push_back(Bool(transA_)); + args.push_back(Bool(transB_)); + args.push_back(clearAccum_); + + auto op = mma_sunmmio(); + Stmt mma_sunmmio; + mma_sunmmio = Evaluate(Call(DataType::Handle(), op, args)); + + return mma_sunmmio; + } + if (a_.scope() == "local.fragment") { ICHECK(b_.scope() != "local.fragment"); ICHECK(!transA_) diff --git a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py new file mode 100644 index 000000000..85f98095d --- /dev/null +++ b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py @@ -0,0 +1,83 @@ +import tilelang +import pytest +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target +import tilelang as tl +import tilelang.language as T + +tilelang.env.disable_cache() + + +def layout_func(i, j, continuous): + return (i // 32 * (continuous // 32) + j // 32) * 32 * 32 + i % 32 * 32 + j % 32 + + +def 
matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float32): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + + # T.clear(C_shared) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + if version == 1: + T.gemm_v1( + A_shared[0:8, 16:32], + B_shared[0:16, 8:16], + C_shared[8:24, 16:32], + transpose_A=True, + transpose_B=True) + elif version == 2: + T.gemm_v2( + A_shared[0:8, 16:32], + B_shared[0:16, 8:16], + C_shared[8:24, 16:32], + transpose_A=True, + transpose_B=True) + else: + raise ValueError(f'unsupported gemm version: {version}') + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return tvm.IRModule({'main': main}) + + +stmts = [ + "T.mma_sunmmio(8, 16, \"float16\", 32, 32, _i * 32 + _j, \"shared.asram\", 0, 16, T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 1024, 1), 16, 8, \"float16\", 32, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 1024, 1), 16, 16, \"float32\", 32, 32, _i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 1024, 2), T.bool(True), T.bool(True), T.bool(False))", + "T.mma_sunmmio(8, 16, \"float16\", 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", 0, 16, T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 4096, 1), 16, 8, \"float16\", 64, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 2048, 1), 16, 16, \"float32\", 64, 32, 
_i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 2048, 2), T.bool(True), T.bool(True), T.bool(False))" +] +TEST_CASES = [ + # gemm v1 + (128, 128, 128, 32, 32, 32, 1, stmts[0]), + (128, 128, 128, 64, 32, 64, 1, stmts[1]), + # # gemm v2 + (128, 128, 128, 32, 32, 32, 2, stmts[0]), + (128, 128, 128, 64, 32, 64, 2, stmts[1]), +] + + +@pytest.mark.parametrize( + "M, N, K, block_M, block_N, block_K, version, lower_stmt", + TEST_CASES, +) +def test_tilelang_gemm_sunmmio_layout(M, N, K, block_M, block_N, block_K, version, lower_stmt): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with tvm.target.Target(target): + mod = matmul(M, N, K, block_M, block_N, block_K, version) + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.InferSramScope()(mod) + mod = tl.transform.LayoutInference()(mod) + mod = tl.transform.LowerTileOp()(mod) + texts = mod.script().split('\n') + text = texts[-6].lstrip() + assert text == lower_stmt diff --git a/tilelang/tileop/gemm/gemm_sunmmio.py b/tilelang/tileop/gemm/gemm_sunmmio.py index 0254bed7f..94ad3eae4 100644 --- a/tilelang/tileop/gemm/gemm_sunmmio.py +++ b/tilelang/tileop/gemm/gemm_sunmmio.py @@ -1,13 +1,12 @@ from .gemm_base import GemmBase from tilelang.layout import make_blockwise_zz_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) -from tilelang.utils.language import is_shared, is_full_region +from tilelang.utils.language import is_shared from tilelang import tvm as tvm from tvm.target import Target from tvm import tir -from tilelang import language as T from tilelang.transform.simplify import _Simplify +from tilelang import language as T +from tilelang.utils.language import retrieve_ptr class GemmSunmmio(GemmBase): @@ -25,100 +24,44 @@ def infer_layout(self, target: Target, thread_nums: int): ) def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: 
tir.Var): - # TODO: lower not implemented - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) - warp_row_tiles = int(self.M // m_warp) - warp_col_tiles = int(self.N // n_warp) - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype, - accum_dtype=self.accum_dtype, - a_transposed=self.trans_A, - b_transposed=self.trans_B, - block_row_warps=m_warp, - block_col_warps=n_warp, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=self.chunk, - thread_var=thread_var, - ) - - in_dtype = self.in_dtype - warp_rows = mma_emitter.warp_rows - warp_cols = mma_emitter.warp_cols - local_size_a = mma_emitter.local_size_a - local_size_b = mma_emitter.local_size_b - block_K = mma_emitter.chunk - micro_size_k = mma_emitter.micro_size_k - # We use region for memory input to support strided gemm - # T.gemm(A_shared[0:128, :], B_shared, C_local) - A_region = self.ARegion - B_region = self.BRegion - C_region = self.CRegion - - A_buf = A_region.buffer - B_buf = B_region.buffer - C_buf = C_region.buffer - - clear_accum = self.clear_accum - - assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" - - assert is_full_region(C_region), "Fragment output C must be a full region" - if self.is_gemm_sss(): + args = [] + + def add_info(args, region): + for it in region.region: + args.append(it.extent) + args.append(region.buffer.dtype) + layout = layout_map[region.buffer] + for it in layout.input_size: + args.append(it) + for it in layout.forward_index: + args.append(it) + args.append(region.buffer.scope()) + for it in region.region: + args.append(it.min) + if region != self.CRegion: + args.append(retrieve_ptr(region.buffer, access_type="r")) + else: + args.append(retrieve_ptr(region.buffer, access_type="w")) + + add_info(args, self.ARegion) + add_info(args, self.BRegion) + add_info(args, self.CRegion) + + args.append(self.trans_A) + args.append(self.trans_B) 
+ args.append(self.clear_accum) @T.prim_func - def _gemm_ssr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. - """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - if clear_accum: - T.clear(C_buf) - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_region, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_region, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_buf, ki) - - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_ssr, inline_let=True) - assert is_full_region(A_region), "Fragment input A must be a full region" - assert is_full_region(B_region), "Fragment input B must be a full region" - - @T.prim_func - def _gemm_rrr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. 
- """ + def _gemm_sss() -> None: + tir.call_intrin( + "handle", + tir.op.Op.get("tl.mma_sunmmio"), + *args, + ) - for ki in T.serial(0, (block_K // micro_size_k)): - # Perform Matrix Multiplication - mma_emitter.mma(A_buf, B_buf, C_buf, ki) + return _Simplify(_gemm_sss, inline_let=True) - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_rrr, inline_let=True) else: raise ValueError( f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 41da8ab0a..fb671f06c 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -55,6 +55,9 @@ def is_shared(buffer: Buffer | BufferLoad | BufferRegion, allow_dynamic: bool = buffer = _get_buffer(buffer) conditions = [False] conditions.append(buffer.scope() == "shared") + conditions.append(buffer.scope() == "shared.asram") + conditions.append(buffer.scope() == "shared.wsram") + conditions.append(buffer.scope() == "shared.rsram") if allow_dynamic: conditions.append(is_shared_dynamic(buffer)) return any(conditions) From ae0b81e9a59452bceda268469fee727dc9ba680e Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 6 Feb 2026 21:31:57 +0800 Subject: [PATCH 3/7] add rank for mma --- src/op/builtin.h | 9 +++++++++ src/op/gemm.cc | 3 +++ .../language/test_tilelang_mesh_language_mma_sunmmio.py | 4 ++-- tilelang/tileop/gemm/gemm_sunmmio.py | 1 + 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index 1d068a307..84e2f6aa5 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -282,6 +282,9 @@ TVM_DLL const Op &dma_store(); /*! * \brief tvm intrinsic for mma operation of Sunmmio target. * + * \param A_rank + * Rank (number of dimensions) of A operand + * * \param A_region_shape * Region shape of A operand * @@ -303,6 +306,9 @@ TVM_DLL const Op &dma_store(); * \param A_addr * Addr of A operand. 
* + * \param B_rank + * Rank (number of dimensions) of B operand + * * \param B_region_shape * Region shape of B operand * @@ -324,6 +330,9 @@ TVM_DLL const Op &dma_store(); * \param B_addr * Addr of B operand. * + * \param C_rank + * Rank (number of dimensions) of C operand + * * \param C_region_shape * Region shape of C operand * diff --git a/src/op/gemm.cc b/src/op/gemm.cc index c5ba63d1b..5330e6f62 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -530,6 +530,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Array args; { + args.push_back(static_cast(a_->shape.size())); for (auto it : aRegion_->region) { args.push_back(it->extent); } @@ -557,6 +558,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } { + args.push_back(static_cast(b_->shape.size())); for (auto it : bRegion_->region) { args.push_back(it->extent); } @@ -584,6 +586,7 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } { + args.push_back(static_cast(c_->shape.size())); for (auto it : cRegion_->region) { args.push_back(it->extent); } diff --git a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py index 85f98095d..72a526b91 100644 --- a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py +++ b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py @@ -52,8 +52,8 @@ def main( stmts = [ - "T.mma_sunmmio(8, 16, \"float16\", 32, 32, _i * 32 + _j, \"shared.asram\", 0, 16, T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 1024, 1), 16, 8, \"float16\", 32, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 1024, 1), 16, 16, \"float32\", 32, 32, _i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 1024, 2), T.bool(True), T.bool(True), T.bool(False))", - "T.mma_sunmmio(8, 16, 
\"float16\", 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", 0, 16, T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 4096, 1), 16, 8, \"float16\", 64, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 2048, 1), 16, 16, \"float32\", 64, 32, _i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 2048, 2), T.bool(True), T.bool(True), T.bool(False))" + "T.mma_sunmmio(2, 8, 16, \"float16\", 32, 32, _i * 32 + _j, \"shared.asram\", 0, 16, T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 1024, 1), 2, 16, 8, \"float16\", 32, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 1024, 1), 2, 16, 16, \"float32\", 32, 32, _i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 1024, 2), T.bool(True), T.bool(True), T.bool(False))", + "T.mma_sunmmio(2, 8, 16, \"float16\", 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", 0, 16, T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 4096, 1), 2, 16, 8, \"float16\", 64, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 2048, 1), 2, 16, 16, \"float32\", 64, 32, _i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 2048, 2), T.bool(True), T.bool(True), T.bool(False))" ] TEST_CASES = [ # gemm v1 diff --git a/tilelang/tileop/gemm/gemm_sunmmio.py b/tilelang/tileop/gemm/gemm_sunmmio.py index 94ad3eae4..2378df58e 100644 --- a/tilelang/tileop/gemm/gemm_sunmmio.py +++ b/tilelang/tileop/gemm/gemm_sunmmio.py @@ -28,6 +28,7 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: args = [] def add_info(args, region): + args.append(len(region.buffer.shape)) for it in 
region.region: args.append(it.extent) args.append(region.buffer.dtype) From 56d8b1dacf39c844e26f191e37a20603a3395ea0 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Tue, 10 Feb 2026 15:22:25 +0800 Subject: [PATCH 4/7] add scope check for mma sunmmio --- src/op/gemm.cc | 18 ++++++++++++------ tilelang/tileop/gemm/gemm_sunmmio.py | 6 ++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 5330e6f62..44ad9982e 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -527,6 +527,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } if (gemm_inst == GemmInst::kSunmmioMMA) { + ICHECK(a_.scope() == "shared.asram") + << "Invalid scope of buffer " << a_ << " in SunmmioMMA."; + ICHECK(b_.scope() == "shared.wsram") + << "Invalid scope of buffer " << b_ << " in SunmmioMMA."; + ICHECK(c_.scope() == "shared.rsram") + << "Invalid scope of buffer " << c_ << " in SunmmioMMA."; Array args; { @@ -912,12 +918,12 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, ICHECK(0); } } else if (gemm_inst == GemmInst::kSunmmioMMA) { - ICHECK((((std::string)a_.scope()).compare(0, 6, "shared") == 0)) - << "Sunmmio Gemm only supports A in shared scope, got " << a_.scope(); - ICHECK((((std::string)b_.scope()).compare(0, 6, "shared") == 0)) - << "Sunmmio Gemm only supports B in shared scope, got " << b_.scope(); - ICHECK((((std::string)c_.scope()).compare(0, 6, "shared") == 0)) - << "Sunmmio Gemm only supports C in shared scope, got " << c_.scope(); + ICHECK(a_.scope() == "shared.asram") + << "Invalid scope of buffer " << a_ << " in SunmmioMMA."; + ICHECK(b_.scope() == "shared.wsram") + << "Invalid scope of buffer " << b_ << " in SunmmioMMA."; + ICHECK(c_.scope() == "shared.rsram") + << "Invalid scope of buffer " << c_ << " in SunmmioMMA."; const auto f = ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); diff --git a/tilelang/tileop/gemm/gemm_sunmmio.py b/tilelang/tileop/gemm/gemm_sunmmio.py index 
2378df58e..cae9c3034 100644 --- a/tilelang/tileop/gemm/gemm_sunmmio.py +++ b/tilelang/tileop/gemm/gemm_sunmmio.py @@ -12,6 +12,9 @@ class GemmSunmmio(GemmBase): def infer_layout(self, target: Target, thread_nums: int): + assert self.A.scope() == 'shared.asram' + assert self.B.scope() == 'shared.wsram' + assert self.C.scope() == 'shared.rsram' if self.is_gemm_sss(): return { self.A: make_blockwise_zz_layout(self.A), @@ -24,6 +27,9 @@ def infer_layout(self, target: Target, thread_nums: int): ) def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + assert self.A.scope() == 'shared.asram' + assert self.B.scope() == 'shared.wsram' + assert self.C.scope() == 'shared.rsram' if self.is_gemm_sss(): args = [] From 06c31285ce767153d2c1ea51f877ced9a1c0954e Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Wed, 11 Feb 2026 17:54:52 +0800 Subject: [PATCH 5/7] fix parameters to region --- src/op/builtin.h | 75 ++------------- src/op/gemm.cc | 96 ++----------------- ...test_tilelang_mesh_language_mma_sunmmio.py | 8 +- tilelang/tileop/gemm/gemm_sunmmio.py | 40 +++----- 4 files changed, 30 insertions(+), 189 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index 6e415e455..78dd808a1 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -205,77 +205,14 @@ TVM_DLL const Op &dma_copy(); /*! * \brief tvm intrinsic for mma operation of Sunmmio target. * - * \param A_rank - * Rank (number of dimensions) of A operand + * \param A_region + * A tl.tileop.region PrimExpr describing A. * - * \param A_region_shape - * Region shape of A operand + * \param B_region + * A tl.tileop.region PrimExpr describing B. * - * \param A_dtype - * Dtype of A operand - * - * \param A_input_size - * Input_shape and of A operand. Input_size + forward = layout. - * - * \param A_forward - * Forward of A operand. Input_size + forward = layout. - * - * \param A_scope - * Scope of A operand. - * - * \param A_coord - * Coord of A operand. 
- * - * \param A_addr - * Addr of A operand. - * - * \param B_rank - * Rank (number of dimensions) of B operand - * - * \param B_region_shape - * Region shape of B operand - * - * \param B_dtype - * Dtype of B operand - * - * \param B_input_size - * Input_shape and of B operand. Input_size + forward = layout. - * - * \param B_forward - * Forward of B operand. Input_size + forward = layout. - * - * \param B_scope - * Scope of B operand. - * - * \param B_coord - * Coord of B operand. - * - * \param B_addr - * Addr of B operand. - * - * \param C_rank - * Rank (number of dimensions) of C operand - * - * \param C_region_shape - * Region shape of C operand - * - * \param C_dtype - * Dtype of C operand - * - * \param C_input_size - * Input_shape and of C operand. Input_size + forward = layout. - * - * \param C_forward - * Forward of C operand. Input_size + forward = layout. - * - * \param C_scope - * Scope of C operand. - * - * \param C_coord - * Coord of C operand. - * - * \param C_addr - * Addr of C operand. + * \param C_region + * A tl.tileop.region PrimExpr describing C. * * \param trans_A * Whether to transpose A operand. 
diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 44ad9982e..ca7e9d943 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -533,95 +533,15 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { << "Invalid scope of buffer " << b_ << " in SunmmioMMA."; ICHECK(c_.scope() == "shared.rsram") << "Invalid scope of buffer " << c_ << " in SunmmioMMA."; - Array args; - { - args.push_back(static_cast(a_->shape.size())); - for (auto it : aRegion_->region) { - args.push_back(it->extent); - } - args.push_back(StringImm(tvm::runtime::DLDataTypeToString(a_->dtype))); - - ICHECK(T.layout_map.count(a_)) - << "Layout of buffer " << a_ << " not found."; - auto layout = T.layout_map.at(a_); - for (auto s : layout->InputShape()) { - args.push_back(s); - } - for (auto s : layout->GetForwardIndex()) { - args.push_back(s); - } - args.push_back(StringImm(a_.scope())); - for (auto it : aRegion_->region) { - args.push_back(it->min); - } - PrimExpr total_elements = 1; - for (auto e : a_->shape) { - total_elements *= e; - } - auto addr = a_.access_ptr(1, DataType::Handle(), 1, 0, total_elements); - args.push_back(addr); - } - - { - args.push_back(static_cast(b_->shape.size())); - for (auto it : bRegion_->region) { - args.push_back(it->extent); - } - args.push_back(StringImm(tvm::runtime::DLDataTypeToString(b_->dtype))); - - ICHECK(T.layout_map.count(b_)) - << "Layout of buffer " << b_ << " not found."; - auto layout = T.layout_map.at(b_); - for (auto s : layout->InputShape()) { - args.push_back(s); - } - for (auto s : layout->GetForwardIndex()) { - args.push_back(s); - } - args.push_back(StringImm(b_.scope())); - for (auto it : bRegion_->region) { - args.push_back(it->min); - } - PrimExpr total_elements = 1; - for (auto e : b_->shape) { - total_elements *= e; - } - auto addr = b_.access_ptr(1, DataType::Handle(), 1, 0, total_elements); - args.push_back(addr); - } - - { - args.push_back(static_cast(c_->shape.size())); - for (auto it : cRegion_->region) { - 
args.push_back(it->extent); - } - args.push_back(StringImm(tvm::runtime::DLDataTypeToString(c_->dtype))); - - ICHECK(T.layout_map.count(c_)) - << "Layout of buffer " << c_ << " not found."; - auto layout = T.layout_map.at(c_); - for (auto s : layout->InputShape()) { - args.push_back(s); - } - for (auto s : layout->GetForwardIndex()) { - args.push_back(s); - } - args.push_back(StringImm(c_.scope())); - for (auto it : cRegion_->region) { - args.push_back(it->min); - } - PrimExpr total_elements = 1; - for (auto e : c_->shape) { - total_elements *= e; - } - auto addr = c_.access_ptr(2, DataType::Handle(), 1, 0, total_elements); - args.push_back(addr); - } - - args.push_back(Bool(transA_)); - args.push_back(Bool(transB_)); - args.push_back(clearAccum_); + PrimExpr A_region = + MakeRegionExpr(aRegion_->buffer, aRegion_->region, /*access_mask=*/1); + PrimExpr B_region = + MakeRegionExpr(bRegion_->buffer, bRegion_->region, /*access_mask=*/1); + PrimExpr C_region = + MakeRegionExpr(cRegion_->buffer, cRegion_->region, /*access_mask=*/3); + Array args = {A_region, B_region, C_region, + Bool(transA_), Bool(transB_), clearAccum_}; auto op = mma_sunmmio(); Stmt mma_sunmmio; diff --git a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py index 72a526b91..ef86b4078 100644 --- a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py +++ b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py @@ -12,7 +12,7 @@ def layout_func(i, j, continuous): return (i // 32 * (continuous // 32) + j // 32) * 32 * 32 + i % 32 * 32 + j % 32 -def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float32): +def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float16): @T.prim_func def main( @@ -52,8 +52,8 @@ def main( stmts = [ - "T.mma_sunmmio(2, 8, 16, \"float16\", 32, 32, _i * 32 + _j, \"shared.asram\", 0, 16, 
T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 1024, 1), 2, 16, 8, \"float16\", 32, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 1024, 1), 2, 16, 16, \"float32\", 32, 32, _i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 1024, 2), T.bool(True), T.bool(True), T.bool(False))", - "T.mma_sunmmio(2, 8, 16, \"float16\", 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", 0, 16, T.tvm_access_ptr(T.type_annotation(\"float16\"), A_shared.data, 0, 4096, 1), 2, 16, 8, \"float16\", 64, 32, _i * 32 + _j, \"shared.wsram\", 0, 8, T.tvm_access_ptr(T.type_annotation(\"float16\"), B_shared.data, 0, 2048, 1), 2, 16, 16, \"float32\", 64, 32, _i * 32 + _j, \"shared.rsram\", 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 2048, 2), T.bool(True), T.bool(True), T.bool(False))" + "T.mma_sunmmio(T.region(A_shared[0, 16], 1, 8, 16), T.region(B_shared[0, 8], 1, 16, 8), T.region(C_shared[8, 16], 3, 16, 16), T.bool(True), T.bool(True), T.bool(False))", + "T.mma_sunmmio(T.region(A_shared[0, 16], 1, 8, 16), T.region(B_shared[0, 8], 1, 16, 8), T.region(C_shared[8, 16], 3, 16, 16), T.bool(True), T.bool(True), T.bool(False))" ] TEST_CASES = [ # gemm v1 @@ -79,5 +79,5 @@ def test_tilelang_gemm_sunmmio_layout(M, N, K, block_M, block_N, block_K, versio mod = tl.transform.LayoutInference()(mod) mod = tl.transform.LowerTileOp()(mod) texts = mod.script().split('\n') - text = texts[-6].lstrip() + text = texts[-2].lstrip() assert text == lower_stmt diff --git a/tilelang/tileop/gemm/gemm_sunmmio.py b/tilelang/tileop/gemm/gemm_sunmmio.py index cae9c3034..63ffd50c5 100644 --- a/tilelang/tileop/gemm/gemm_sunmmio.py +++ b/tilelang/tileop/gemm/gemm_sunmmio.py @@ -6,7 +6,10 @@ from tvm import tir from tilelang.transform.simplify import _Simplify from tilelang import language as T -from tilelang.utils.language import 
retrieve_ptr +from tilelang.utils.language import ( + retrieve_shape,) +from tilelang.language.utils import ( + buffer_region_to_tile_region,) class GemmSunmmio(GemmBase): @@ -31,33 +34,14 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: assert self.B.scope() == 'shared.wsram' assert self.C.scope() == 'shared.rsram' if self.is_gemm_sss(): - args = [] - - def add_info(args, region): - args.append(len(region.buffer.shape)) - for it in region.region: - args.append(it.extent) - args.append(region.buffer.dtype) - layout = layout_map[region.buffer] - for it in layout.input_size: - args.append(it) - for it in layout.forward_index: - args.append(it) - args.append(region.buffer.scope()) - for it in region.region: - args.append(it.min) - if region != self.CRegion: - args.append(retrieve_ptr(region.buffer, access_type="r")) - else: - args.append(retrieve_ptr(region.buffer, access_type="w")) - - add_info(args, self.ARegion) - add_info(args, self.BRegion) - add_info(args, self.CRegion) - - args.append(self.trans_A) - args.append(self.trans_B) - args.append(self.clear_accum) + A_shape = retrieve_shape(self.ARegion) + B_shape = retrieve_shape(self.BRegion) + C_shape = retrieve_shape(self.CRegion) + A_arg = buffer_region_to_tile_region(self.ARegion, "r", [r for r in A_shape]) + B_arg = buffer_region_to_tile_region(self.BRegion, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(self.CRegion, "rw", [r for r in C_shape]) + + args = [A_arg, B_arg, C_arg, self.trans_A, self.trans_B, self.clear_accum] @T.prim_func def _gemm_sss() -> None: From 0d8d1bc7f7af9fdee069da1b68be9c2005fd52a6 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Wed, 11 Feb 2026 18:05:42 +0800 Subject: [PATCH 6/7] fix data type in test cases --- .../language/test_tilelang_mesh_language_mma_sunmmio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py 
b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py index ef86b4078..53a4b0a71 100644 --- a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py +++ b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py @@ -12,13 +12,13 @@ def layout_func(i, j, continuous): return (i // 32 * (continuous // 32) + j // 32) * 32 * 32 + i % 32 * 32 + j % 32 -def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float16): +def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) From 7d7cc55159b0818976018a60d09380efc90d42d7 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 13 Feb 2026 12:48:44 +0800 Subject: [PATCH 7/7] fix according to comments --- src/op/gemm.cc | 49 ++++++++++++++-------------- tilelang/tileop/gemm/gemm_sunmmio.py | 23 ++++++------- 2 files changed, 35 insertions(+), 37 deletions(-) diff --git a/src/op/gemm.cc b/src/op/gemm.cc index ca7e9d943..5c963cd12 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -443,6 +443,31 @@ static int GetArchInt(Target target) { Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto block_size = *as_const_int(T.thread_bounds->extent); GemmInst gemm_inst = getGemmInst(block_size, T.target); + + if (gemm_inst == GemmInst::kSunmmioMMA) { + ICHECK(a_.scope() == "shared.asram") + << "Invalid scope of buffer " << a_ << " in SunmmioMMA."; + ICHECK(b_.scope() == "shared.wsram") + << "Invalid scope of buffer " << b_ << " in SunmmioMMA."; + ICHECK(c_.scope() == "shared.rsram") + << "Invalid scope of buffer " << c_ << " in SunmmioMMA."; + + PrimExpr A_region = + MakeRegionExpr(aRegion_->buffer, aRegion_->region, 
/*access_mask=*/1); + PrimExpr B_region = + MakeRegionExpr(bRegion_->buffer, bRegion_->region, /*access_mask=*/1); + PrimExpr C_region = + MakeRegionExpr(cRegion_->buffer, cRegion_->region, /*access_mask=*/3); + Array<PrimExpr> args = {A_region, B_region, C_region, + Bool(transA_), Bool(transB_), clearAccum_}; + + auto op = mma_sunmmio(); + Stmt mma_sunmmio; + mma_sunmmio = Evaluate(Call(DataType::Handle(), op, args)); + + return mma_sunmmio; + } + auto [warp_m, warp_n] = policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); @@ -526,30 +551,6 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } - if (gemm_inst == GemmInst::kSunmmioMMA) { - ICHECK(a_.scope() == "shared.asram") - << "Invalid scope of buffer " << a_ << " in SunmmioMMA."; - ICHECK(b_.scope() == "shared.wsram") - << "Invalid scope of buffer " << b_ << " in SunmmioMMA."; - ICHECK(c_.scope() == "shared.rsram") - << "Invalid scope of buffer " << c_ << " in SunmmioMMA."; - - PrimExpr A_region = - MakeRegionExpr(aRegion_->buffer, aRegion_->region, /*access_mask=*/1); - PrimExpr B_region = - MakeRegionExpr(bRegion_->buffer, bRegion_->region, /*access_mask=*/1); - PrimExpr C_region = - MakeRegionExpr(cRegion_->buffer, cRegion_->region, /*access_mask=*/3); - Array<PrimExpr> args = {A_region, B_region, C_region, - Bool(transA_), Bool(transB_), clearAccum_}; - - auto op = mma_sunmmio(); - Stmt mma_sunmmio; - mma_sunmmio = Evaluate(Call(DataType::Handle(), op, args)); - - return mma_sunmmio; - } - if (a_.scope() == "local.fragment") { ICHECK(b_.scope() != "local.fragment"); ICHECK(!transA_) diff --git a/tilelang/tileop/gemm/gemm_sunmmio.py b/tilelang/tileop/gemm/gemm_sunmmio.py index 63ffd50c5..59cf598f4 100644 --- a/tilelang/tileop/gemm/gemm_sunmmio.py +++ b/tilelang/tileop/gemm/gemm_sunmmio.py @@ -1,6 +1,5 @@ from .gemm_base import GemmBase from tilelang.layout import make_blockwise_zz_layout -from tilelang.utils.language import is_shared from tilelang import tvm as tvm from 
tvm.target import Target from tvm import tir @@ -15,10 +14,7 @@ class GemmSunmmio(GemmBase): def infer_layout(self, target: Target, thread_nums: int): - assert self.A.scope() == 'shared.asram' - assert self.B.scope() == 'shared.wsram' - assert self.C.scope() == 'shared.rsram' - if self.is_gemm_sss(): + if self.is_gemm_sunmmio_scope(): return { self.A: make_blockwise_zz_layout(self.A), self.B: make_blockwise_zz_layout(self.B), @@ -26,14 +22,12 @@ def infer_layout(self, target: Target, thread_nums: int): } else: raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" + f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" ) def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - assert self.A.scope() == 'shared.asram' - assert self.B.scope() == 'shared.wsram' - assert self.C.scope() == 'shared.rsram' - if self.is_gemm_sss(): + + if self.is_gemm_sunmmio_scope(): A_shape = retrieve_shape(self.ARegion) B_shape = retrieve_shape(self.BRegion) C_shape = retrieve_shape(self.CRegion) @@ -55,8 +49,11 @@ def _gemm_sss() -> None: else: raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" + f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" ) - def is_gemm_sss(self) -> bool: - return is_shared(self.A) and is_shared(self.B) and is_shared(self.C) + def is_gemm_sunmmio_scope(self) -> bool: + a_check = self.A.scope() == 'shared.asram' + b_check = self.B.scope() == 'shared.wsram' + c_check = self.C.scope() == 'shared.rsram' + return a_check and b_check and c_check