diff --git a/src/op/builtin.cc b/src/op/builtin.cc index ed5f3067a..058c395a8 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -106,10 +106,7 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr( +TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) diff --git a/src/op/builtin.h b/src/op/builtin.h index 6709f7511..de9d9ab0e 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -189,7 +189,7 @@ TVM_DLL const Op &get_mbarrier(); TVM_DLL const Op &tma_load(); /*! - * \brief Perform a DMA load operation from source memory to destination memory. + * \brief Perform a DMA copy operation from source memory to destination memory. * * This function describes a DMA-based tensor copy with explicit shape, layout, * memory scope. It is typically used to lower a high-level @@ -211,38 +211,42 @@ TVM_DLL const Op &tma_load(); * A[32:64, 128:192, 0:256] * then: * src_rank = 3 - * src_shape = [128, 256, 512] + * src_region_shape = [32, 64, 256] * coord = [32, 128, 0] * + * \param src_scope + * Memory scope of the source tensor. + * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". + * + * \param dst_scope + * Memory scope of the destination tensor. + * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". + * * \param data_type * Element data type of the tensor (e.g. float32, float16). * * \param src_rank * Rank (number of dimensions) of the source tensor. * - * \param src_shape - * Logical shape of the source tensor. - * For example, Tensor(128, 256, 512) -> [128, 256, 512]. + * \param src_region_shape + * Logical shape of the source buffer region. + * For example, A[32:64, 128:192, 0:256] -> [32, 64, 256]. * * \param src_input_size - * Input shape of the source layout, retrievable via Layout::getInputShape(). + * Input shape of the source layout, retrievable via LayoutNode::InputShape(). * For a row-major 3D tensor, this is identical to src_shape. * * \param src_forward * Forward index mapping of the source layout, retrievable via - * Layout::GetForwardIndex(). + * LayoutNode::GetForwardIndex(). * For a row-major layout of Tensor(128, 256, 512), * this is [256 * 512, 512, 1]. * - * \param src_scope - * Memory scope of the source tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * * \param dst_rank * Rank (number of dimensions) of the destination tensor. * - * \param dst_shape - * Logical shape of the destination tensor. + * \param dst_region_shape + * Logical shape of the destination buffer region. * * \param dst_input_size * Input shape of the destination layout, retrievable via @@ -252,32 +256,24 @@ TVM_DLL const Op &tma_load(); * Forward index mapping of the destination layout, retrievable via * Layout::GetForwardIndex(). * - * \param dst_scope - * Memory scope of the destination tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * * \param src_addr * Base address of the source tensor in memory . * - * \param coord + * \param src_coord * Coordinate offset specifying the starting point of the copy in the source * tensor. Its length must equal src_rank. * * \param dst_addr * Base address of the destination tensor in memory . 
 *
+ * \param dst_coord
+ * Coordinate offset specifying the starting point of the copy in the
+ * destination tensor. Its length must equal dst_rank.
+ *
  * \note
  * Out-of-bound fill policies are currently not supported.
  */
-TVM_DLL const Op &dma_load();
-
-/*!
- * \brief Perform a DMA store operation from source memory to destination
- * memory. see dma_load for details.
- *
- *
- */
-TVM_DLL const Op &dma_store();
+TVM_DLL const Op &dma_copy();
 
 /*!
  * \brief tvm intrinsics for loading image from global tensor to columns in
diff --git a/src/op/copy.cc b/src/op/copy.cc
index 22cb4a074..66104bd87 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -564,6 +564,28 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
     }
     return {};
   }
+
+  if (copy_inst == CopyInst::kDMACopy) {
+    // For DMA copy we can directly apply the blockwise_zz_layout.
+    const auto f =
+        ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
+    auto result = Map<Buffer, Layout>();
+
+    if (level == InferLevel::kFree && !T.layout_map.count(src)) {
+      if (src.scope() != "global") {
+        auto layout = Downcast<Layout>((*f)(src));
+        result.Set(src, layout);
+      }
+    }
+
+    if (level == InferLevel::kFree && !T.layout_map.count(dst)) {
+      if (dst.scope() != "global") {
+        auto layout = Downcast<Layout>((*f)(dst));
+        result.Set(dst, layout);
+      }
+    }
+    return result;
+  }
   // for LDSM/STSM, the layout was deduced from register layout
   // so we can directly apply the layout of normal copy
   // Use parallel op to infer the layout
@@ -573,6 +595,67 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
   }
   return par_op_->InferLayout(T, level);
 }
+
+/**
+ * @brief Determine whether this CopyNode can be lowered to a DMA copy
+ * instruction.
+ *
+ * The function returns true when all of the following hold:
+ * - the target architecture advertises DMA support;
+ * - the source/destination memory-scope pair is one of the supported DMA
+ *   transfer paths;
+ * - the source and destination have the same element data type.
+ *
+ * If the source and destination dtypes differ, a warning is logged and the
+ * function returns false (the caller is expected to fall back to a normal
+ * copy).
+ *
+ * @param target The compilation target to query for DMA copy support.
+ * @param analyzer Arithmetic analyzer (currently unused).
+ * @param check_last_dim Whether to validate the innermost dimension
+ * (currently unused).
+ * @return true if the copy can be implemented as a DMA copy; false
+ * otherwise.
+ */
+bool CopyNode::CheckDMACopy(Target target, arith::Analyzer *analyzer,
+                            bool check_last_dim) const {
+  // 1. The target architecture must be Sunmmio.
+  if (!TargetIsSunmmio(target))
+    return false;
+
+  // 2. The src/dst scope pair must be a supported DMA path.
+  bool scope_check = false;
+  // 2.1 DRAM -> RSRAM
+  if (src.scope() == "global" && dst.scope() == "shared.rsram")
+    scope_check = true;
+  // 2.2 DRAM -> WSRAM
+  if (src.scope() == "global" && dst.scope() == "shared.wsram")
+    scope_check = true;
+  // 2.3 DRAM -> ASRAM
+  if (src.scope() == "global" && dst.scope() == "shared.asram")
+    scope_check = true;
+  // 2.4 RSRAM -> WSRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "shared.wsram")
+    scope_check = true;
+  // 2.5 RSRAM -> ASRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "shared.asram")
+    scope_check = true;
+  // 2.6 RSRAM <-> RSRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "shared.rsram")
+    scope_check = true;
+  // 2.7 RSRAM -> DRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "global")
+    scope_check = true;
+  if (!scope_check)
+    return false;
+
+  // 3. src and dst must have the same dtype.
+  if (src->dtype != dst->dtype) {
+    LOG(WARNING) << "src and dst must have the same dtype for dma copy: "
                 << src->name << " vs. " << dst->name << ", dtype "
+                 << src->dtype << " vs. " << dst->dtype
+                 << "; falling back to normal copy";
+    return false;
+  }
+  return true;
+}
+
 /**
  * @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA)
  * instruction.
@@ -848,6 +931,12 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
     return CopyInst::kTMemLoad;
   } else if (CheckTMemStore(target)) {
     return CopyInst::kTMemStore;
+  } else if (TargetIsSunmmio(target)) {
+    auto is_copy = CheckDMACopy(target, analyzer);
+    if (is_copy)
+      return CopyInst::kDMACopy;
+    ICHECK(0) << "Unsupported copy from " << src.scope() << " to "
+              << dst.scope() << " of Sunmmio target.";
   } else {
     return CopyInst::kNormal;
   }
@@ -860,6 +949,7 @@
  * determined copy instruction type:
  * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions
  * - LDSM/STSM: Uses matrix load/store instructions for tensor cores
+ * - DMA Copy: Uses Sunmmio-specific DMA instructions for copies
  * - Normal: Uses standard load/store operations with loop transformations
  * \param T LowerArgs containing target and layout map.
  * \param analyzer Arithmetic analyzer for simplification.
@@ -874,6 +964,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value();
   auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
                                T.layout_map, analyzer);
+
   if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
     auto tmem_copy = LowerTmemCopy(T, analyzer);
     ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy";
@@ -894,11 +985,141 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return ldsm_copy;
   } else if (copy_inst == CopyInst::kNormal) {
     return LowerNormalCopy(T, analyzer);
+  } else if (copy_inst == CopyInst::kDMACopy) {
+    auto dma_copy = LowerDMACopy(T, analyzer, copy_inst);
+    ICHECK(dma_copy.defined()) << "Failed to lower dma copy";
+    return dma_copy;
   } else {
     LOG(FATAL) << "Unsupported copy inst " << static_cast<int>(copy_inst);
   }
 }
 
+/**
+ * @brief Lower a Copy operator to a DMA transfer.
+ *
+ * Lowers the copy to a single dma_copy intrinsic call when the target and
+ * buffer layouts permit.
+ *
+ * @param T LowerArgs containing target information, thread/bounds variables,
+ * and layout / buffer remap information.
+ * @param analyzer Analyzer used to prove shapes/contiguity/equality
+ * constraints.
+ * @param copy_inst The copy instruction to emit. Must be CopyInst::kDMACopy.
+ * @return Stmt A TIR statement performing the DMA copy.
+ */
+Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer,
+                            CopyInst copy_inst) const {
+  ICHECK(copy_inst == CopyInst::kDMACopy)
+      << "Invalid copy inst " << static_cast<int>(copy_inst);
+
+  Array<PrimExpr> args;
+  // \param src_scope
+  auto src_scope = StringImm(src.scope());
+  args.push_back(src_scope);
+  // \param dst_scope
+  auto dst_scope = StringImm(dst.scope());
+  args.push_back(dst_scope);
+  // \param data_type
+  args.push_back(to_CUtensorMapDataType(src->dtype));
+
+  // \param src_rank
+  args.push_back(static_cast<int>(src->shape.size()));
+  // \param src_region_shape
+  for (auto r : src_range) {
+    args.push_back(r->extent);
+  }
+  // \param src_input_size & \param src_forward
+  if (src.scope() == "global") {
+    ICHECK(T.global_layout_map.count(src))
+        << "Layout of buffer " << src << " not found.";
+    auto layout = T.global_layout_map.at(src);
+    for (auto s : layout->InputShape()) {
+      args.push_back(s);
+    }
+    for (auto s : layout->GetForwardIndex()) {
+      args.push_back(s);
+    }
+  } else {
+    ICHECK(T.layout_map.count(src))
+        << "Layout of buffer " << src << " not found.";
+    auto layout = T.layout_map.at(src);
+    for (auto s : layout->InputShape()) {
+      args.push_back(s);
+    }
+    for (auto s : layout->GetForwardIndex()) {
+      args.push_back(s);
+    }
+  }
+
+  // \param dst_rank
+  args.push_back(static_cast<int>(dst->shape.size()));
+  // \param dst_region_shape
+  for (auto r : dst_range) {
+    args.push_back(r->extent);
+  }
+  // \param dst_input_size & \param dst_forward
+  if (dst.scope() == "global") {
+    ICHECK(T.global_layout_map.count(dst))
+        << "Layout of buffer " << dst << " not found.";
+    auto layout = T.global_layout_map.at(dst);
+    for (auto s : layout->InputShape()) {
+      args.push_back(s);
+    }
+    for (auto s : layout->GetForwardIndex()) {
+      args.push_back(s);
+    }
+  } else {
+    ICHECK(T.layout_map.count(dst))
+        << "Layout of buffer " << dst << " not found.";
+    auto layout = T.layout_map.at(dst);
+    for (auto s : layout->InputShape()) {
+      args.push_back(s);
+    }
+    for (auto s : layout->GetForwardIndex()) {
+      args.push_back(s);
+    }
+  }
+
+  // \param src_addr
+  if (src.scope() == "global") {
+    args.push_back(src->data);
+  } else {
+    PrimExpr total_elements = 1;
+    for (auto e : src->shape) {
+      total_elements *= e;
+    }
+    auto addr = src.access_ptr(1, DataType::Handle(), 1, 0, total_elements);
+    args.push_back(addr);
+  }
+  // \param src_coord
+  for (auto r : src_range) {
+    args.push_back(r->min);
+  }
+  // \param dst_addr
+  if (dst.scope() == "global") {
+    args.push_back(dst->data);
+  } else {
+    PrimExpr total_elements = 1;
+    for (auto e : dst->shape) {
+      total_elements *= e;
+    }
+    auto addr = dst.access_ptr(2, DataType::Handle(), 1, 0, total_elements);
+    args.push_back(addr);
+  }
+  // \param dst_coord
+  for (auto r : dst_range) {
+    args.push_back(r->min);
+  }
+
+  return Evaluate(Call(DataType::Handle(), dma_copy(), args));
+}
+
 /**
  * @brief Lower the copy operator using the generic (non-specialized) path.
  *
@@ -1763,6 +1984,7 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
   tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
   return tma_copy;
 }
+
 /*!
  * \brief Encode the TMA descriptor into an array of PrimExpr.
* This function serializes the TMA descriptor fields into a format suitable for diff --git a/src/op/copy.h b/src/op/copy.h index fd3e01a40..e58eb4fc8 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -26,6 +26,9 @@ enum class CopyInst : uint8_t { kBulkStore1D = 6, // utilize tma store 1d kTMemLoad = 7, // tcgen05.ld (tensor memory -> register) kTMemStore = 8, // tcgen05.st (register -> tensor memory) + + // dma + kDMACopy = 9, }; /// Descriptor for Tensor Memory Access (TMA) copy operations @@ -127,6 +130,12 @@ class CopyNode : public TileOperatorNode { LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + /*! + * \brief Check if dma copy is supported. + */ + bool CheckDMACopy(Target target, arith::Analyzer *analyzer, + bool check_last_dim = true) const; + /*! * \brief Check if bulk copy is supported. */ @@ -188,6 +197,12 @@ class CopyNode : public TileOperatorNode { bool buffer_oob) const; protected: + /*! + * \brief Generate lowering for dma copy. + */ + Stmt LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const; + /*! * \brief Generate lowering for bulk/global-to-shared copy. */ diff --git a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py new file mode 100644 index 000000000..a4e8eb290 --- /dev/null +++ b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py @@ -0,0 +1,205 @@ +import tilelang +import pytest +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target +import tilelang.language as T +from tilelang.language.v2.annot import MeshShardingPolicy + + +def copy(K, block_M, block_N, block_K, dtype="float32", accum_dtype="float32"): + MyTensor = T.MeshTensor((128, 128), + sharding_policy=MeshShardingPolicy(cross_mesh_dim=0), + device_mesh_config=(2, 2), + hierarchical_dims=(4, 32, 128), + hierarchical_groups=((0, 2), (2, 3)), + hierarchical_strides=(32, 1, 4096)) + + @T.prim_func + def main(C: MyTensor): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(128, block_N), T.ceildiv(128, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # DRAM -> RSRAM + T.copy(C[by * block_M, ko * block_K], C_shared) + # DRAM -> WSRAM + T.copy(C[by * block_M, ko * block_K], B_shared) + # DRAM <- RSRAM + T.copy(C_shared, C[by * block_M, ko * block_K]) + # DRAM -> ASRAM + T.copy(C[by * block_M, ko * block_K], A_shared) + # RSRAM -> ASRAM + T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) + # RSRAM -> WSRAM + T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) + # RSRAM <-> RSRAM + T.copy(C_shared, D_shared) + + return tvm.IRModule({'main': main}) + + +TEST_CASES = [ + ( + 128, + 64, + 64, + 32, + [ + # DRAM -> RSRAM + # T.copy(C[by * block_M, ko * block_K], C_shared) + "T.dma_copy(\"global\", \"shared.rsram\", 7, 2, 64, 64, 32, 128, _j * 32 + _i, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, C.data, by * 64, ko * 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 2), 0, 0)", + # DRAM -> WSRAM + # T.copy(C[by * block_M, ko * block_K], B_shared) + "T.dma_copy(\"global\", 
\"shared.wsram\", 7, 2, 64, 64, 32, 128, _j * 32 + _i, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, C.data, by * 64, ko * 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 4096, 2), 0, 0)", + # DRAM <- RSRAM + # T.copy(C_shared, C[by * block_M, ko * block_K]) + "T.dma_copy(\"shared.rsram\", \"global\", 7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 64, 64, 32, 128, _j * 32 + _i, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, C.data, by * 64, ko * 32)", + # DRAM -> ASRAM + # T.copy(C[by * block_M, ko * block_K], A_shared) + "T.dma_copy(\"global\", \"shared.asram\", 7, 2, 64, 64, 32, 128, _j * 32 + _i, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, C.data, by * 64, ko * 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 4096, 2), 0, 0)", + # RSRAM -> ASRAM + # T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) + "T.dma_copy(\"shared.rsram\", \"shared.asram\", 7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 4096, 2), 24, 8)", + # RSRAM -> WSRAM + # T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) + "T.dma_copy(\"shared.rsram\", \"shared.wsram\", 7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 4096, 2), 40, 0)", + # RSRAM <-> RSRAM + # T.copy(C_shared, D_shared) + "T.dma_copy(\"shared.rsram\", \"shared.rsram\", 7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, T.tvm_access_ptr(T.type_annotation(\"float32\"), D_shared.data, 0, 4096, 2), 0, 0)", + ]), +] + + +@pytest.mark.parametrize( + "K, block_M, block_N, block_K, lower_stmt", + TEST_CASES, +) +def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K, lower_stmt): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with tvm.target.Target(target): + mod = copy(K, block_M, block_N, block_K) + mod = tvm.tir.transform.BindTarget(target)(mod) + # Add wrapper for single buf store + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + # Normalize negative indices to canonical non-negative form + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) + # Simplify the IR expressions + mod = tilelang.transform.Simplify()(mod) + # Infer shared memory SRAM scope + mod = tilelang.transform.InferSramScope()(mod) + # Set layouts for reducers + mod = tilelang.transform.LayoutReducer()(mod) + # Infer memory layouts for fragments and shared memory + mod = tilelang.transform.LayoutInference()(mod) + # Lower high-level tile operations to low-level operations + mod = tilelang.transform.LowerTileOp()(mod) + texts = mod.script().split('\n') + texts = texts[29:-2] + texts = [it.lstrip() for it in texts] + for i in range(len(texts)): + 
assert texts[i] == lower_stmt[i] + + +def wrong_copy(M, + N, + K, + block_M, + block_N, + block_K, + error_type, + dtype="float16", + accum_dtype="float16"): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + A_shared_2 = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + B_shared_2 = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + if error_type == 'A->D': + T.copy(A_shared, C[by * block_M, ko * block_K]) + elif error_type == 'W->D': + T.copy(B_shared, C[by * block_M, ko * block_K]) + elif error_type == 'A->R': + T.copy(A_shared, C_shared) + elif error_type == 'W->R': + T.copy(B_shared, C_shared) + elif error_type == 'D<->D': + T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) + elif error_type == 'A<->A': + T.copy(A_shared, A_shared_2) + elif error_type == 'W<->W': + T.copy(B_shared, B_shared_2) + elif error_type == 'A->W': + T.copy(A_shared, B_shared) + elif error_type == 'W->A': + T.copy(B_shared, A_shared) + + return tvm.IRModule({'main': main}) + + +WRONG_TEST_CASES = [ + (128, 128, 128, 32, 32, 32, "A->D", + "Unsupported copy from shared.asram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->D", + "Unsupported copy from shared.wsram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->R", + "Unsupported copy from shared.asram to shared.rsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->R", + "Unsupported copy from shared.wsram to shared.rsram of Sunmmio target."), + # (128, 128, 128, 32, 32, 32, "D<->D", + # "Unsupported copy from global to global of Sunmmio target."), + # D<->D not work now + (128, 128, 128, 32, 32, 32, "A<->A", + "Unsupported copy from shared.asram to shared.asram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W<->W", + "Unsupported copy from shared.wsram to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->W", + "Unsupported copy from shared.asram to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->A", + "Unsupported copy from shared.wsram to shared.asram of Sunmmio target."), +] + + +@pytest.mark.parametrize( + "M, N, K, block_M, block_N, block_K, error_type, error_msg", + WRONG_TEST_CASES, +) +def test_tilelang_mesh_wrong_copy_to_dma(M, N, K, block_M, block_N, block_K, error_type, error_msg): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): + mod = wrong_copy(M, N, K, block_M, block_N, block_K, error_type) + mod = tvm.tir.transform.BindTarget(target)(mod) + # Add wrapper for single buf store + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + # Normalize negative indices to canonical non-negative form + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) + # Simplify the IR expressions + mod = tilelang.transform.Simplify()(mod) + # Infer shared memory SRAM scope + mod = tilelang.transform.InferSramScope()(mod) + # 
Set layouts for reducers
+        mod = tilelang.transform.LayoutReducer()(mod)
+        # Infer memory layouts for fragments and shared memory
+        mod = tilelang.transform.LayoutInference()(mod)
+        # Lower high-level tile operations to low-level operations
+        mod = tilelang.transform.LowerTileOp()(mod)
diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py
index cabc4a3e4..aca3a5f7f 100644
--- a/tilelang/language/copy.py
+++ b/tilelang/language/copy.py
@@ -67,6 +67,13 @@ def get_extent(data):
     # Combine the nested if statements into a single if statement as suggested by SIM102
     if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and
             isinstance(dst, tir.BufferLoad)):
+        # FIXME
+        # Currently an invalid D<->D copy operation will reach this branch, for example:
+        # T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) ->
+        #     for ko in T.serial(4, annotations={"num_stages": 3}):
+        #         B[by * 32, ko * 32] = C[by * 32, ko * 32]
+        # which raises an exception that cannot be caught.
+        #
         # check if the case is like this:
         # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
         # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py
index 86caee01f..1715d3a96 100644
--- a/tilelang/tileop/gemm/__init__.py
+++ b/tilelang/tileop/gemm/__init__.py
@@ -183,9 +183,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
             NotImplementedError: If the instruction type is not supported
             ValueError: If the instruction type is unknown
         """
-        if gemm_inst.is_sunmmio():
-            return GemmSunmmio
-        elif gemm_inst.is_mma():
+        if gemm_inst.is_mma():
            if target_is_volta(target):
                 return GemmMMASm70
             return GemmMMA
@@ -195,6 +193,8 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
             return GemmTCGEN5
         elif gemm_inst.is_mfma():
             return GemmMFMA
+        elif gemm_inst.is_sunmmio():
+            return GemmSunmmio
         elif gemm_inst.is_tcgen5mma():
             raise NotImplementedError("TCGEN5MMA is not implemented")
         else:
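
For reference, below is a minimal standalone sketch of how a single global -> shared.rsram T.copy is driven through the pass pipeline exercised by the tests above so that it lowers to the new dma_copy intrinsic. It assumes the Sunmmio target from this patch; the 128x128 tensor and 32x32 tile sizes are illustrative only, not part of the change.

import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tilelang.utils.target import determine_target


def build():

    @T.prim_func
    def main(C: T.Tensor((128, 128), "float32")):
        with T.Kernel(4, 4, threads=128) as (bx, by):
            # shared.rsram is one of the scopes CheckDMACopy accepts as a DMA destination.
            C_shared = T.alloc_shared((32, 32), "float32", scope="shared.rsram")
            # DRAM -> RSRAM: expected to lower to a tl.dma_copy intrinsic on Sunmmio.
            T.copy(C[by * 32, bx * 32], C_shared)

    return tvm.IRModule({"main": main})


target = determine_target("Sunmmio", return_object=True)
with tvm.target.Target(target):
    mod = build()
    mod = tvm.tir.transform.BindTarget(target)(mod)
    mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
    mod = tilelang.transform.LegalizeNegativeIndex()(mod)
    mod = tilelang.transform.InjectAssumes()(mod)
    mod = tilelang.transform.Simplify()(mod)
    mod = tilelang.transform.InferSramScope()(mod)
    mod = tilelang.transform.LayoutReducer()(mod)
    mod = tilelang.transform.LayoutInference()(mod)
    mod = tilelang.transform.LowerTileOp()(mod)
    # The copy should now appear as T.dma_copy("global", "shared.rsram", ...).
    print(mod.script())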