diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 058c395a8..8131c97bd 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -109,6 +109,11 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(mma_sunmmio) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) .set_num_inputs(-1) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 3779a8562..78dd808a1 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -202,6 +202,29 @@ TVM_DLL const Op &tma_load(); */ TVM_DLL const Op &dma_copy(); +/*! + * \brief tvm intrinsic for mma operation of Sunmmio target. + * + * \param A_region + * A tl.tileop.region PrimExpr describing A. + * + * \param B_region + * A tl.tileop.region PrimExpr describing B. + * + * \param C_region + * A tl.tileop.region PrimExpr describing C. + * + * \param trans_A + * Whether to transpose A operand. + * + * \param trans_B + * Whether to transpose B operand. + * + * \param clear_accum + * Whether to clear accumulation buffer. + */ +TVM_DLL const Op &mma_sunmmio(); + /*!
* \brief tvm intrinsics for loading image from global tensor to columns in * shared memory diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 8f4450d5f..5c963cd12 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -150,6 +150,10 @@ std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition( return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning } + if (gemm_inst == GemmInst::kSunmmioMMA) { + return {1, num_warps}; // kSunmmioMMA doesn't care about warp partitioning + } + int m_warp = 1, n_warp = 1; constexpr int kMPerWarp = 16; // Rows processed by a single warp int kNPerWarp = 8; // Columns processed by a single warp @@ -439,6 +443,31 @@ static int GetArchInt(Target target) { Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto block_size = *as_const_int(T.thread_bounds->extent); GemmInst gemm_inst = getGemmInst(block_size, T.target); + + if (gemm_inst == GemmInst::kSunmmioMMA) { + ICHECK(a_.scope() == "shared.asram") + << "Invalid scope of buffer " << a_ << " in SunmmioMMA."; + ICHECK(b_.scope() == "shared.wsram") + << "Invalid scope of buffer " << b_ << " in SunmmioMMA."; + ICHECK(c_.scope() == "shared.rsram") + << "Invalid scope of buffer " << c_ << " in SunmmioMMA."; + + PrimExpr A_region = + MakeRegionExpr(aRegion_->buffer, aRegion_->region, /*access_mask=*/1); + PrimExpr B_region = + MakeRegionExpr(bRegion_->buffer, bRegion_->region, /*access_mask=*/1); + PrimExpr C_region = + MakeRegionExpr(cRegion_->buffer, cRegion_->region, /*access_mask=*/3); + Array<PrimExpr> args = {A_region, B_region, C_region, + Bool(transA_), Bool(transB_), clearAccum_}; + + auto op = mma_sunmmio(); + Stmt mma_sunmmio; + mma_sunmmio = Evaluate(Call(DataType::Handle(), op, args)); + + return mma_sunmmio; + } + auto [warp_m, warp_n] = policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); @@ -810,12 +839,12 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, ICHECK(0); } } else if (gemm_inst == GemmInst::kSunmmioMMA) { - 
ICHECK((((std::string)a_.scope()).compare(0, 6, "shared") == 0)) - << "Sunmmio Gemm only supports A in shared scope, got " << a_.scope(); - ICHECK((((std::string)b_.scope()).compare(0, 6, "shared") == 0)) - << "Sunmmio Gemm only supports B in shared scope, got " << b_.scope(); - ICHECK((((std::string)c_.scope()).compare(0, 6, "shared") == 0)) - << "Sunmmio Gemm only supports C in shared scope, got " << c_.scope(); + ICHECK(a_.scope() == "shared.asram") + << "Invalid scope of buffer " << a_ << " in SunmmioMMA."; + ICHECK(b_.scope() == "shared.wsram") + << "Invalid scope of buffer " << b_ << " in SunmmioMMA."; + ICHECK(c_.scope() == "shared.rsram") + << "Invalid scope of buffer " << c_ << " in SunmmioMMA."; const auto f = ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); diff --git a/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py new file mode 100644 index 000000000..53a4b0a71 --- /dev/null +++ b/testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py @@ -0,0 +1,83 @@ +import tilelang +import pytest +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target +import tilelang as tl +import tilelang.language as T + +tilelang.env.disable_cache() + + +def layout_func(i, j, continuous): + return (i // 32 * (continuous // 32) + j // 32) * 32 * 32 + i % 32 * 32 + j % 32 + + +def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float32): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + + # T.clear(C_shared) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + 
T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + if version == 1: + T.gemm_v1( + A_shared[0:8, 16:32], + B_shared[0:16, 8:16], + C_shared[8:24, 16:32], + transpose_A=True, + transpose_B=True) + elif version == 2: + T.gemm_v2( + A_shared[0:8, 16:32], + B_shared[0:16, 8:16], + C_shared[8:24, 16:32], + transpose_A=True, + transpose_B=True) + else: + raise ValueError(f'unsupported gemm version: {version}') + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return tvm.IRModule({'main': main}) + + +stmts = [ + "T.mma_sunmmio(T.region(A_shared[0, 16], 1, 8, 16), T.region(B_shared[0, 8], 1, 16, 8), T.region(C_shared[8, 16], 3, 16, 16), T.bool(True), T.bool(True), T.bool(False))", + "T.mma_sunmmio(T.region(A_shared[0, 16], 1, 8, 16), T.region(B_shared[0, 8], 1, 16, 8), T.region(C_shared[8, 16], 3, 16, 16), T.bool(True), T.bool(True), T.bool(False))" +] +TEST_CASES = [ + # gemm v1 + (128, 128, 128, 32, 32, 32, 1, stmts[0]), + (128, 128, 128, 64, 32, 64, 1, stmts[1]), + # # gemm v2 + (128, 128, 128, 32, 32, 32, 2, stmts[0]), + (128, 128, 128, 64, 32, 64, 2, stmts[1]), +] + + +@pytest.mark.parametrize( + "M, N, K, block_M, block_N, block_K, version, lower_stmt", + TEST_CASES, +) +def test_tilelang_gemm_sunmmio_layout(M, N, K, block_M, block_N, block_K, version, lower_stmt): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with tvm.target.Target(target): + mod = matmul(M, N, K, block_M, block_N, block_K, version) + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.InferSramScope()(mod) + mod = tl.transform.LayoutInference()(mod) + mod = tl.transform.LowerTileOp()(mod) + texts = mod.script().split('\n') + text = texts[-2].lstrip() + assert text == lower_stmt diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 86caee01f..a292f56d3 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -183,9 
+183,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): NotImplementedError: If the instruction type is not supported ValueError: If the instruction type is unknown """ - if gemm_inst.is_sunmmio(): - return GemmSunmmio - elif gemm_inst.is_mma(): + if gemm_inst.is_mma(): if target_is_volta(target): return GemmMMASm70 return GemmMMA @@ -197,5 +195,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): return GemmMFMA elif gemm_inst.is_tcgen5mma(): raise NotImplementedError("TCGEN5MMA is not implemented") + elif gemm_inst.is_sunmmio(): + return GemmSunmmio else: raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}") diff --git a/tilelang/tileop/gemm/gemm_sunmmio.py b/tilelang/tileop/gemm/gemm_sunmmio.py index 0254bed7f..59cf598f4 100644 --- a/tilelang/tileop/gemm/gemm_sunmmio.py +++ b/tilelang/tileop/gemm/gemm_sunmmio.py @@ -1,19 +1,20 @@ from .gemm_base import GemmBase from tilelang.layout import make_blockwise_zz_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter,) -from tilelang.utils.language import is_shared, is_full_region from tilelang import tvm as tvm from tvm.target import Target from tvm import tir -from tilelang import language as T from tilelang.transform.simplify import _Simplify +from tilelang import language as T +from tilelang.utils.language import ( + retrieve_shape,) +from tilelang.language.utils import ( + buffer_region_to_tile_region,) class GemmSunmmio(GemmBase): def infer_layout(self, target: Target, thread_nums: int): - if self.is_gemm_sss(): + if self.is_gemm_sunmmio_scope(): return { self.A: make_blockwise_zz_layout(self.A), self.B: make_blockwise_zz_layout(self.B), @@ -21,108 +22,38 @@ def infer_layout(self, target: Target, thread_nums: int): } else: raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" + f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: 
{self.B.scope()}, C: {self.C.scope()}" ) def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): - # TODO: lower not implemented - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, - False) - warp_row_tiles = int(self.M // m_warp) - warp_col_tiles = int(self.N // n_warp) - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype, - accum_dtype=self.accum_dtype, - a_transposed=self.trans_A, - b_transposed=self.trans_B, - block_row_warps=m_warp, - block_col_warps=n_warp, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=self.chunk, - thread_var=thread_var, - ) - - in_dtype = self.in_dtype - warp_rows = mma_emitter.warp_rows - warp_cols = mma_emitter.warp_cols - local_size_a = mma_emitter.local_size_a - local_size_b = mma_emitter.local_size_b - block_K = mma_emitter.chunk - micro_size_k = mma_emitter.micro_size_k - # We use region for memory input to support strided gemm - # T.gemm(A_shared[0:128, :], B_shared, C_local) - A_region = self.ARegion - B_region = self.BRegion - C_region = self.CRegion - - A_buf = A_region.buffer - B_buf = B_region.buffer - C_buf = C_region.buffer - - clear_accum = self.clear_accum - - assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" - - assert is_full_region(C_region), "Fragment output C must be a full region" - - if self.is_gemm_sss(): - - @T.prim_func - def _gemm_ssr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. 
- """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - if clear_accum: - T.clear(C_buf) - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_region, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_region, - ki, - ) - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_buf, ki) + if self.is_gemm_sunmmio_scope(): + A_shape = retrieve_shape(self.ARegion) + B_shape = retrieve_shape(self.BRegion) + C_shape = retrieve_shape(self.CRegion) + A_arg = buffer_region_to_tile_region(self.ARegion, "r", [r for r in A_shape]) + B_arg = buffer_region_to_tile_region(self.BRegion, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(self.CRegion, "rw", [r for r in C_shape]) - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_ssr, inline_let=True) - assert is_full_region(A_region), "Fragment input A must be a full region" - assert is_full_region(B_region), "Fragment input B must be a full region" + args = [A_arg, B_arg, C_arg, self.trans_A, self.trans_B, self.clear_accum] @T.prim_func - def _gemm_rrr() -> None: - """ - The inner macro that loads data from shared buffers A_shared and - B_shared into local fragments, then issues Tensor Core mma ops, - accumulating into C_local. 
- """ + def _gemm_sss() -> None: + tir.call_intrin( + "handle", + tir.op.Op.get("tl.mma_sunmmio"), + *args, + ) - for ki in T.serial(0, (block_K // micro_size_k)): - # Perform Matrix Multiplication - mma_emitter.mma(A_buf, B_buf, C_buf, ki) + return _Simplify(_gemm_sss, inline_let=True) - # Simplify to optimize the index computing - # Must inline let statements to simplify the analysis - return _Simplify(_gemm_rrr, inline_let=True) else: raise ValueError( - f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" + f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}" ) - def is_gemm_sss(self) -> bool: - return is_shared(self.A) and is_shared(self.B) and is_shared(self.C) + def is_gemm_sunmmio_scope(self) -> bool: + a_check = self.A.scope() == 'shared.asram' + b_check = self.B.scope() == 'shared.wsram' + c_check = self.C.scope() == 'shared.rsram' + return a_check and b_check and c_check