From b7896c5378f7d8118df947e41681e5d7c6aba1b9 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Sun, 4 Jan 2026 10:13:14 +0800 Subject: [PATCH 1/9] init a new branch for dma copy --- src/op/copy.cc | 477 ++++++++++++++++++++++++++++++++++++++++++++++++- src/op/copy.h | 41 +++++ 2 files changed, 516 insertions(+), 2 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 72e73e162..c8ba3ada0 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -564,6 +564,26 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, } return {}; } + + if (copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) { + // if can apply swizzling, we skip layout inference + // for dma load/store, we can directly apply the layout of normal copy + // This must be a global/shared layout, so we can skip the parallel op + // layout inference (parallel layout inference only annotate the loop layout + // and the register layout). + // the same implementation as TMA + bool is_load = copy_inst == CopyInst::kDMALoad; + Buffer global_tensor = is_load ? src : dst; + Buffer shared_tensor = is_load ? dst : src; + // check shared layout is non-swizzle + // skip layout inference if shared layout is already annotated + if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) { + // create a new layout map for tma linear layout + Layout linear_layout = ComputeLinearLayout(shared_tensor); + return Map({{shared_tensor, linear_layout}}); + } + return {}; + } // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy // Use parallel op to infer the layout @@ -573,6 +593,107 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, } return par_op_->InferLayout(T, level); } + +/** + * @brief Determine whether this CopyNode can be lowered to a DMA Load + * instruction. + * + * The function returns true when all of the following hold: + * - the target architecture advertises DMA support; + * - the source buffer resides in global memory; + * - the destination buffer resides in shared memory (either "shared" or + * "shared.dyn"); + * - the source and destination have the same element data type. + * + * If the source and destination dtypes differ, a warning is logged and the + * function returns false (the caller is expected to fall back to a normal + * copy). + * + * + * @param target The compilation target to query for dma load support. + * @return true if the copy can be implemented as a DMA Load; false + * otherwise. + */ +bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, + bool check_last_dim) const { + // 1. arch must support zpu + if (!TargetIsZpu(target)) + return false; + // 2. src and dst must be global and shared + if (src.scope() != "global" || + (dst.scope() != "shared.dyn" && dst.scope() != "shared")) + return false; + // 3. check shape. + // last dim of src * dtype.bits() must be a multiple of 16 + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // now we check src (gmem) as tma box dim is deduced from src + if (check_last_dim && + analyzer->CanProve( + FloorMod(src_range[src_range.size() - 1]->extent * src->dtype.bytes(), + 16) != 0, + arith::ProofStrength::kSymbolicBound)) { + LOG(WARNING) + << "src range must have last dim multiple of 16 for tma bulk load " + << src->name << " range " << src_range[src_range.size() - 1]->extent + << " * " << src->dtype.bytes() << " % 16 != 0"; + return false; + } + + // 4. 
src and dst must have the same dtype + if (src->dtype != dst->dtype) { + LOG(WARNING) << "src and dst must have the same dtype for tma load " + << src->name << " vs. " << dst->name << " dtype " << src->dtype + << " vs. " << dst->dtype << " will be fallback to normal copy"; + return false; + } + return true; +} + +/** + * @brief Determine if this CopyNode can be lowered to a CUDA DMA store. + * + * Checks whether the target supports DMA store, the source buffer is in shared + * memory (shared or shared.dyn), the destination buffer is in global memory, + * and both buffers have the same element data type. If the data types differ, + * a warning is logged and false is returned. + * + * @param target Target device/architecture to check for dma store support. + * @return true if all conditions are met; false otherwise. + */ +bool CopyNode::CheckDMAStore(Target target, arith::Analyzer *analyzer, + bool check_last_dim) const { + // 1. arch must support zpu + if (!TargetIsZpu(target)) + return false; + // 2. src and dst must be shared.dyn and local.fragment + if ((src.scope() != "shared.dyn" && src.scope() != "shared") || + dst.scope() != "global") + return false; + // 3. check shape. + // last dim of dst * dtype.bits() must be a multiple of 16 + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // now we check dst (gmem) as tma box dim is deduced from dst + if (check_last_dim && + analyzer->CanProve( + FloorMod(dst_range[dst_range.size() - 1]->extent * dst->dtype.bytes(), + 16) != 0, + arith::ProofStrength::kSymbolicBound)) { + LOG(WARNING) + << "dst range must have last dim multiple of 16 for tma bulk store " + << dst->name << " range " << dst_range[dst_range.size() - 1]->extent + << " * " << dst->dtype.bytes() << " % 16 != 0"; + return false; + } + // 4. src and dst must have the same dtype + if (src->dtype != dst->dtype) { + LOG(WARNING) << "src and dst must have the same dtype for tma store " + << src->name << " vs. " << dst->name << " dtype " << src->dtype + << " vs. " << dst->dtype << " will be fallback to normal copy"; + return false; + } + return true; +} + /** * @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA) * instruction. 
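The 16-byte requirement that CheckDMALoad and CheckDMAStore enforce is plain byte arithmetic, so it is easy to verify by hand. A minimal standalone sketch (illustration only; the helper name LastDimIsDmaAligned is invented here and is not part of this patch):

    #include <cstdint>

    // True when the innermost extent, measured in bytes, is a multiple of 16.
    // The checks above reject the copy when the analyzer can prove
    // FloorMod(extent * dtype.bytes(), 16) != 0, falling back to a normal copy.
    inline bool LastDimIsDmaAligned(int64_t last_extent, int64_t dtype_bytes) {
      return (last_extent * dtype_bytes) % 16 == 0;
    }

    // float16 (2 bytes): extent 64 -> 128 bytes, accepted;
    //                    extent 7  ->  14 bytes, rejected (normal-copy fallback).
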
@@ -830,7 +951,11 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, // Check tensor memory operations first (highest priority for SM100/Blackwell) // 1d tma access can not support out of bound access - if (!disable_tma_lower && !buffer_oob && + if (CheckDMALoad(target, analyzer)) { + return CopyInst::kDMALoad; + } else if (CheckDMAStore(target, analyzer)) { + return CopyInst::kDMAStore; + } else if (!disable_tma_lower && !buffer_oob && CheckBulkLoad1D(target, layout_map, analyzer)) { return CopyInst::kBulkLoad1D; } else if (!disable_tma_lower && !buffer_oob && @@ -874,7 +999,12 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, T.layout_map, analyzer); - if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { + + if (copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) { + auto bulk_copy = LowerDMACopy(T, analyzer, copy_inst); + ICHECK(bulk_copy.defined()) << "Failed to lower dma load/store"; + return bulk_copy; + } else if(copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { auto tmem_copy = LowerTmemCopy(T, analyzer); ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy"; return tmem_copy; @@ -899,6 +1029,317 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } +/** + * @brief Lower a Copy operator to a DMA transfer. + * + * Haoze TODO: the same as tma now + * Lowers the copy to an optimized DMA load or store when the target and buffer + * layouts permit. + * + * If preconditions are not satisfied (unsupported swizzle, stride/size limits, + * mismatched element counts, OOB risks, or other hardware constraints), this + * function falls back to LowerNormalCopy. + * + * @param T LowerArgs containing target information, thread/bounds variables, + * and layout/ buffer remap information used for descriptor + * construction. + * @param analyzer Analyzer used to prove shapes/contiguity/equality + * constraints. + * @param copy_inst Indicates whether to emit a BulkLoad (TMA load) or BulkStore + * (TMA store). Must be CopyInst::kDMALoad or kDMAStore. + * @return Stmt A TIR statement performing the bulk TMA copy (or the result of + * LowerNormalCopy when falling back). + */ +Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const { + ICHECK(copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) + << "Invalid copy inst " << static_cast(copy_inst); + bool is_load = copy_inst == CopyInst::kDMALoad; + Buffer global_tensor = is_load ? src : dst; + Buffer shared_tensor = is_load ? dst : src; + Array global_range = is_load ? src_range : dst_range; + Array shared_range = is_load ? 
dst_range : src_range; + // Cannot support a non-swizzled global layout, will be fallback to normal copy + if (T.layout_map.count(global_tensor)) { + LOG(WARNING) << "DMA copy cannot support a non-swizzled global " + "layout, fallback to normal copy."; + return LowerNormalCopy(T, analyzer); + } + + // linear layout must be computed before remapping + auto linear_layout = ComputeLinearLayout(shared_tensor); + + Array shared_indices; + for (auto r : shared_range) + shared_indices.push_back(r->min); + std::vector shared_strides; + PrimExpr shared_stride = 1; + for (size_t i = 0; i < shared_tensor->shape.size(); i++) { + auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1]; + shared_strides.insert(shared_strides.begin(), shared_stride); + shared_stride *= s; + } + + Array global_indices; + for (auto r : global_range) { + global_indices.push_back(r->min); + } + std::vector global_strides; + PrimExpr global_stride = 1; + for (size_t i = 0; i < global_tensor->shape.size(); i++) { + auto s = global_tensor->shape[global_tensor->shape.size() - i - 1]; + global_strides.insert(global_strides.begin(), global_stride); + global_stride *= s; + } + + ICHECK(shared_strides.size() == shared_indices.size()) + << "shared_strides.size() != shared_indices.size()" + << shared_strides.size() << " " << shared_indices.size(); + PrimExpr shared_offset = 0; + for (size_t i = 0; i < shared_indices.size(); i++) { + shared_offset += shared_indices[i] * shared_strides[i]; + } + PrimExpr global_offset = 0; + for (size_t i = 0; i < global_indices.size(); i++) { + global_offset += global_indices[i] * global_strides[i]; + } + + TMADesc desc; + // Verify copy rank + desc.rank = global_tensor->shape.size(); + ICHECK(desc.rank >= 1 && desc.rank <= 5) << desc.rank; + + // Verify datatype + ICHECK(global_tensor->dtype == shared_tensor->dtype) + << "Copy between buffer " << global_tensor->name << " and " + << shared_tensor->name << " with different data type " + << global_tensor->dtype << " and " << shared_tensor->dtype; + + desc.data_type = to_CUtensorMapDataType(global_tensor->dtype); + + // Global Tensor Shape and Stride + desc.global_addr = global_tensor->data; + desc.global_shape = ReverseArray(global_tensor->shape); + Array global_coords = + ReverseArray(global_range.Map([](Range r) { return r->min; })); + if (!global_tensor->strides.empty()) { + desc.global_stride = ReverseArray(global_tensor->strides); + } else { + // Create stride from shape + PrimExpr stride = 1; + desc.global_stride.reserve(desc.rank); + for (size_t i = 0; i < desc.rank; i++) { + desc.global_stride.push_back(stride); + stride *= desc.global_shape[i]; + } + } + // The first stride element should be 1 + ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; + // Make global stride in bytes + desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { + return cast(DataType::Int(64), e) * global_tensor->dtype.bytes(); + }); + for (size_t i{1}; i < desc.global_stride.size(); i++) { + auto stride = desc.global_stride[i].as(); + if (stride != nullptr) { + // otherwise, the stride is symbolic, we need to check in future with + // assumptions + if (stride->value % 16 != 0 || stride->value >= (1ULL << 40)) { + LOG(WARNING) << "TMA bulk copy cannot support a global stride of " + << desc.global_stride[i] << ", fallback to normal copy."; + return LowerNormalCopy(T, analyzer); + } + } + } + + // Smem Box + // check smem range and global range is legal + auto s_range_idx = 0; + for (size_t i = 0; i < global_range.size(); i++) { + auto g_range = 
global_range[i]; + if (is_one(g_range->extent)) { + continue; + } + // skip one range if it is 1 + // in case of global range is [128, 64], while shared range is [1, 128, 64] + // A_shared[0, :, :]. + while (is_one(shared_range[s_range_idx]->extent) && + s_range_idx < shared_range.size()) { + s_range_idx++; + } + if (s_range_idx >= shared_range.size()) { + LOG(FATAL) << "TMA bulk copy cannot support a global range of " + << global_range << ", shared_range " << shared_range; + } + auto s_range = shared_range[s_range_idx]; + s_range_idx++; + + ICHECK(StructuralEqual()(g_range->extent, s_range->extent)) + << global_tensor->name << "[" << i << "] is illegal, " + << global_tensor->name << "[" << i << "] = " << g_range->extent << ", " + << shared_tensor->name << "[" << s_range_idx + << "] = " << s_range->extent; + } + // TODO(lei): find a much smarter way to deduce smem box dim + // instead of using global_range + desc.smem_box = + ReverseArray(global_range.Map([](Range r) { return r->extent; })); + + desc.smem_stride = Array(desc.rank, PrimExpr(1)); + // L2 & OOB + desc.l2_promotion = static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); + desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + // Detect smem layout + // Shared memory swizzling is crucial for TMA performance + // It determines how data is arranged in shared memory banks to minimize bank + // conflicts Different swizzle patterns (32B, 64B, 128B) offer different + // trade-offs between access efficiency and memory usage + desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); + Layout shared_layout; + if (T.layout_map.count(shared_tensor)) { + shared_layout = T.layout_map.at(shared_tensor); + ICHECK(T.buffer_remap.count(shared_tensor)) + << "shared_tensor: " << shared_tensor->name + << " not found in buffer_remap"; + shared_tensor = T.buffer_remap.at(shared_tensor); + } + if (!shared_layout.defined()) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); + } else if (StructuralEqual()(shared_layout, linear_layout)) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); + } else { + ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout."; + auto stride = as_const_int(shared_layout->InputShape()[0]); + auto continuous = as_const_int(shared_layout->InputShape()[1]); + ICHECK(stride != nullptr && continuous != nullptr); + // We also need to check if the shape satisfies the following doc: + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout( + *stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); + } else if (StructuralEqual()( + shared_layout, + makeHalfBankSwizzleLayout(*stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); + } else if (StructuralEqual()( + shared_layout, + makeFullBankSwizzleLayout(*stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); + } else if (StructuralEqual()( + shared_layout, + makeGemmABLayoutPadded(*stride, *continuous, + shared_tensor->dtype.bits()))) { + LOG(WARNING) << "Bulk copy cannot support a padded layout for src: " + << src->name << ", dst: " << dst->name + << ", fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } else { + LOG(WARNING) << "Came across unsupported swizzle layout for src: " + << src->name << 
", dst: " << dst->name + << ", fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + } + + auto inner_box_dim = as_const_int(desc.smem_box[0]); + if (inner_box_dim == nullptr) { + LOG(WARNING) << "inner_box_dim " << desc.smem_box[0] + << " can only be a constant integer for TMA bulk copy, " + "fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + int instruction_dim = *inner_box_dim; + if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_64B)) { + instruction_dim = 64 / src->dtype.bytes(); + } else if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_128B)) { + instruction_dim = 128 / src->dtype.bytes(); + } + if (instruction_dim > 256) { + // smem_box dim must be in [0, 256] + // if is 512, we need to split the copy into two parts + ICHECK((*inner_box_dim) % 256 == 0) + << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256"; + instruction_dim = 256; + } + ICHECK((*inner_box_dim) % instruction_dim == 0) + << "inner_box_dim: " << *inner_box_dim + << " is not divisible by instruction_dim: " << instruction_dim; + desc.smem_box.Set(0, PrimExpr(instruction_dim)); + + int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes(); + + // Check inner_box_dim_ for each swizzle type in a cleaner way + struct SwizzleCheck { + int swizzle; + int max_dim; + }; + static const std::vector swizzle_checks = { + {static_cast(CU_TENSOR_MAP_SWIZZLE_32B), 32}, + {static_cast(CU_TENSOR_MAP_SWIZZLE_64B), 64}, + {static_cast(CU_TENSOR_MAP_SWIZZLE_128B), 128}, + }; + for (const auto &check : swizzle_checks) { + if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) { + LOG(WARNING) << "TMA bulk copy cannot support a swizzled global layout " + "with inner_box_dim_ > " + << check.max_dim << ", will be fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + } + + Call create_descriptor = + Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + + Array args; + args.reserve(desc.rank + 4); + args.push_back(create_descriptor); + if (is_load) + args.push_back(0); // mbarrier id placeholder + auto op = is_load ? tma_load() : tma_store(); + + Stmt tma_copy; + PrimExpr total_elements = 1; + for (auto e : desc.smem_box) + total_elements *= e; + + if ((*inner_box_dim) != instruction_dim) { + Var loop_var("i"); + int loop_extent = (*inner_box_dim) / instruction_dim; + + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 2 : 1, DataType::Handle(), 1, + shared_offset + total_elements * loop_var, total_elements); + args.push_back(shared_addr); + global_coords.Set(0, global_coords[0] + instruction_dim * loop_var); + for (auto coord : global_coords) + args.push_back(coord); + int need_reduce = 0; + if (!is_load) + args.push_back(need_reduce); + args.push_back(this->eviction_policy); + tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, + Evaluate(Call(DataType::Handle(), op, args))); + } else { + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements); + args.push_back(shared_addr); + for (auto coord : global_coords) + args.push_back(coord); + int need_reduce = 0; + if (!is_load) + args.push_back(need_reduce); + args.push_back(this->eviction_policy); + tma_copy = Evaluate(Call(DataType::Handle(), op, args)); + } + tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + + return tma_copy; +} + /** * @brief Lower the copy operator using the generic (non-specialized) path. 
* @@ -1762,6 +2203,38 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); return tma_copy; } + +/*! + * \brief Encode the DMA descriptor into an array of PrimExpr. + * This function serializes the DMA descriptor fields into a format suitable for + * passing to the create_dma_descriptor() builtin function. The encoding follows + * the expected argument order for the DMA descriptor creation. + * \return Array of PrimExpr representing the encoded DMA descriptor. + * the same implementation as TMA + */ +Array DMADesc::EncodeCallArgs() const { + Array args; + args.reserve(rank * 4 + 7); + + args.push_back(data_type); + args.push_back(static_cast(rank)); + args.push_back(global_addr); + for (auto e : global_shape) + args.push_back(e); + for (auto e : global_stride) + args.push_back(e); + for (auto e : smem_box) + args.push_back(e); + for (auto e : smem_stride) + args.push_back(e); + args.push_back(interleave); + args.push_back(swizzle); + args.push_back(l2_promotion); + args.push_back(oob_fill); + + return args; +} + /*! * \brief Encode the TMA descriptor into an array of PrimExpr. * This function serializes the TMA descriptor fields into a format suitable for diff --git a/src/op/copy.h b/src/op/copy.h index b08f57688..b2aac5864 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -26,6 +26,10 @@ enum class CopyInst : uint8_t { kBulkStore1D = 6, // utilize tma store 1d kTMemLoad = 7, // tcgen05.ld (tensor memory -> register) kTMemStore = 8, // tcgen05.st (register -> tensor memory) + + // dma + kDMALoad = 9, + kDMAStore = 10, }; /// Descriptor for Tensor Memory Access (TMA) copy operations @@ -46,6 +50,25 @@ struct TMADesc { Array EncodeCallArgs() const; }; +/// Descriptor for DMA copy operations +/// the same as TMADesc +struct DMADesc { + size_t rank; ///< Tensor rank (number of dimensions) + int data_type; ///< Data type identifier + Array global_shape; ///< Shape in global memory + Array global_stride; ///< Strides in global memory + Array smem_box; ///< Block shape in shared memory + Array smem_stride; ///< Strides in shared memory + PrimExpr global_addr; ///< Base address in global memory + int swizzle; ///< Memory layout swizzle parameter + int interleave; ///< Memory interleave parameter + int oob_fill; ///< Out-of-bound fill policy + int l2_promotion; ///< L2 cache promotion flag + + /// Encode descriptor fields into runtime call arguments + Array EncodeCallArgs() const; +}; + /*! * \brief Descriptor for TMA-based im2col transformation used in Conv2D. * @@ -128,6 +151,18 @@ class CopyNode : public TileOperatorNode { LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + /*! + * \brief Check if dma load is supported. + */ + bool CheckDMALoad(Target target, arith::Analyzer *analyzer, + bool check_last_dim = true) const; + + /*! + * \brief Check if dma store is supported. + */ + bool CheckDMAStore(Target target, arith::Analyzer *analyzer, + bool check_last_dim = true) const; + /*! * \brief Check if bulk copy is supported. */ @@ -189,6 +224,12 @@ class CopyNode : public TileOperatorNode { bool buffer_oob) const; protected: + /*! + * \brief Generate lowering for dma copy. + */ + Stmt LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const; + /*! * \brief Generate lowering for bulk/global-to-shared copy. 
*/ From c288cafca27c4b09781eac5a86889394ffe824c8 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 23 Jan 2026 13:58:09 +0800 Subject: [PATCH 2/9] change if statement position --- src/op/copy.cc | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index c8ba3ada0..a37065d5d 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -616,8 +616,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, */ bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, bool check_last_dim) const { - // 1. arch must support zpu - if (!TargetIsZpu(target)) + // 1. arch must support Sunmmio + if (!TargetIsSunmmio(target)) return false; // 2. src and dst must be global and shared if (src.scope() != "global" || @@ -662,8 +662,8 @@ bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, */ bool CopyNode::CheckDMAStore(Target target, arith::Analyzer *analyzer, bool check_last_dim) const { - // 1. arch must support zpu - if (!TargetIsZpu(target)) + // 1. arch must support Sunmmio + if (!TargetIsSunmmio(target)) return false; // 2. src and dst must be shared.dyn and local.fragment if ((src.scope() != "shared.dyn" && src.scope() != "shared") || @@ -951,11 +951,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, // Check tensor memory operations first (highest priority for SM100/Blackwell) // 1d tma access can not support out of bound access - if (CheckDMALoad(target, analyzer)) { - return CopyInst::kDMALoad; - } else if (CheckDMAStore(target, analyzer)) { - return CopyInst::kDMAStore; - } else if (!disable_tma_lower && !buffer_oob && + if (!disable_tma_lower && !buffer_oob && CheckBulkLoad1D(target, layout_map, analyzer)) { return CopyInst::kBulkLoad1D; } else if (!disable_tma_lower && !buffer_oob && @@ -973,6 +969,10 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, return CopyInst::kTMemLoad; } else if (CheckTMemStore(target)) { return CopyInst::kTMemStore; + } else if (CheckDMALoad(target, analyzer)) { + return CopyInst::kDMALoad; + } else if (CheckDMAStore(target, analyzer)) { + return CopyInst::kDMAStore; } else { return CopyInst::kNormal; } @@ -1000,11 +1000,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, T.layout_map, analyzer); - if (copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) { - auto bulk_copy = LowerDMACopy(T, analyzer, copy_inst); - ICHECK(bulk_copy.defined()) << "Failed to lower dma load/store"; - return bulk_copy; - } else if(copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { + if(copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { auto tmem_copy = LowerTmemCopy(T, analyzer); ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy"; return tmem_copy; @@ -1024,6 +1020,10 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return ldsm_copy; } else if (copy_inst == CopyInst::kNormal) { return LowerNormalCopy(T, analyzer); + } else if (copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) { + auto bulk_copy = LowerDMACopy(T, analyzer, copy_inst); + ICHECK(bulk_copy.defined()) << "Failed to lower dma load/store"; + return bulk_copy; } else { LOG(FATAL) << "Unsupported copy inst " << static_cast(copy_inst); } From cc998ec4208a42555dfdfc6ec221ee109aa4b152 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 30 Jan 2026 15:32:31 +0800 
Subject: [PATCH 3/9] change if statement position in gemm --- tilelang/tileop/gemm/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 86caee01f..1715d3a96 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -183,9 +183,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): NotImplementedError: If the instruction type is not supported ValueError: If the instruction type is unknown """ - if gemm_inst.is_sunmmio(): - return GemmSunmmio - elif gemm_inst.is_mma(): + if gemm_inst.is_mma(): if target_is_volta(target): return GemmMMASm70 return GemmMMA @@ -195,6 +193,8 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): return GemmTCGEN5 elif gemm_inst.is_mfma(): return GemmMFMA + elif gemm_inst.is_sunmmio(): + return GemmSunmmio elif gemm_inst.is_tcgen5mma(): raise NotImplementedError("TCGEN5MMA is not implemented") else: From 3b1a68dfe2ee69183431a8bcdc2172c42aaeee93 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Wed, 4 Feb 2026 16:25:41 +0800 Subject: [PATCH 4/9] implementation dma copy, remain global layout to do --- src/op/builtin.h | 22 +- src/op/copy.cc | 514 +++++------------- src/op/copy.h | 4 +- ...test_tilelang_mesh_language_copy_to_dma.py | 250 +++++++++ tilelang/language/copy.py | 7 + 5 files changed, 410 insertions(+), 387 deletions(-) create mode 100644 testing/python/language/test_tilelang_mesh_language_copy_to_dma.py diff --git a/src/op/builtin.h b/src/op/builtin.h index 6709f7511..ca56e196a 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -211,7 +211,7 @@ TVM_DLL const Op &tma_load(); * A[32:64, 128:192, 0:256] * then: * src_rank = 3 - * src_shape = [128, 256, 512] + * src_region_shape = [32, 64, 256] * coord = [32, 128, 0] * * \param data_type @@ -220,17 +220,17 @@ TVM_DLL const Op &tma_load(); * \param src_rank * Rank (number of dimensions) of the source tensor. * - * \param src_shape - * Logical shape of the source tensor. - * For example, Tensor(128, 256, 512) -> [128, 256, 512]. + * \param src_region_shape + * Logical shape of the source buffer region. + * For example, A[32:64, 128:192, 0:256] -> [32, 64, 256]. * * \param src_input_size - * Input shape of the source layout, retrievable via Layout::getInputShape(). + * Input shape of the source layout, retrievable via LayoutNode::InputShape(). * For a row-major 3D tensor, this is identical to src_shape. * * \param src_forward * Forward index mapping of the source layout, retrievable via - * Layout::GetForwardIndex(). + * LayoutNode::GetForwardIndex(). * For a row-major layout of Tensor(128, 256, 512), * this is [256 * 512, 512, 1]. * @@ -241,8 +241,8 @@ TVM_DLL const Op &tma_load(); * \param dst_rank * Rank (number of dimensions) of the destination tensor. * - * \param dst_shape - * Logical shape of the destination tensor. + * \param dst_region_shape + * Logical shape of the destination buffer region. * * \param dst_input_size * Input shape of the destination layout, retrievable via @@ -259,13 +259,17 @@ TVM_DLL const Op &tma_load(); * \param src_addr * Base address of the source tensor in memory . * - * \param coord + * \param src_coord * Coordinate offset specifying the starting point of the copy in the source * tensor. Its length must equal src_rank. * * \param dst_addr * Base address of the destination tensor in memory . 
* + * \param dst_coord + * Coordinate offset specifying the starting point of the copy in the + * destination tensor. Its length must equal dst_rank. + * * \note * Out-of-bound fill policies are currently not supported. */ diff --git a/src/op/copy.cc b/src/op/copy.cc index a37065d5d..7406b29d0 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -566,21 +566,30 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, } if (copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) { - // if can apply swizzling, we skip layout inference - // for dma load/store, we can directly apply the layout of normal copy - // This must be a global/shared layout, so we can skip the parallel op - // layout inference (parallel layout inference only annotate the loop layout - // and the register layout). - // the same implementation as TMA + // for dma load/store, we can directly apply the blockwise_zz_layout bool is_load = copy_inst == CopyInst::kDMALoad; - Buffer global_tensor = is_load ? src : dst; - Buffer shared_tensor = is_load ? dst : src; - // check shared layout is non-swizzle - // skip layout inference if shared layout is already annotated - if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) { - // create a new layout map for tma linear layout - Layout linear_layout = ComputeLinearLayout(shared_tensor); - return Map({{shared_tensor, linear_layout}}); + const auto f = + ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); + if (!is_load) { + // DMA Store, only src in shared + if (level == InferLevel::kFree && !T.layout_map.count(src)) { + auto layout = Downcast((*f)(src)); + return Map({{src, layout}}); + } + return {}; + } else { + // DMA Load, src may in shared, dst in shared + auto result = Map(); + if (level == InferLevel::kFree && src.scope() != "global" && + !T.layout_map.count(src)) { + auto layout = Downcast((*f)(src)); + result.Set(src, layout); + } + if (level == InferLevel::kFree && !T.layout_map.count(dst)) { + auto layout = Downcast((*f)(dst)); + result.Set(dst, layout); + } + return result; } return {}; } @@ -595,51 +604,47 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, } /** - * @brief Determine whether this CopyNode can be lowered to a DMA Load + * @brief Determine whether this CopyNode can be lowered to a DMA Load * instruction. * * The function returns true when all of the following hold: * - the target architecture advertises DMA support; - * - the source buffer resides in global memory; - * - the destination buffer resides in shared memory (either "shared" or - * "shared.dyn"); + * - the source buffer and the destination buffer are legal; * - the source and destination have the same element data type. * * If the source and destination dtypes differ, a warning is logged and the * function returns false (the caller is expected to fall back to a normal * copy). - * + * * * @param target The compilation target to query for dma load support. * @return true if the copy can be implemented as a DMA Load; false * otherwise. */ -bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, - bool check_last_dim) const { +bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, + bool check_last_dim) const { // 1. arch must support Sunmmio if (!TargetIsSunmmio(target)) return false; - // 2. src and dst must be global and shared - if (src.scope() != "global" || - (dst.scope() != "shared.dyn" && dst.scope() != "shared")) - return false; - // 3. check shape. 
- // last dim of src * dtype.bits() must be a multiple of 16 - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - // now we check src (gmem) as tma box dim is deduced from src - if (check_last_dim && - analyzer->CanProve( - FloorMod(src_range[src_range.size() - 1]->extent * src->dtype.bytes(), - 16) != 0, - arith::ProofStrength::kSymbolicBound)) { - LOG(WARNING) - << "src range must have last dim multiple of 16 for tma bulk load " - << src->name << " range " << src_range[src_range.size() - 1]->extent - << " * " << src->dtype.bytes() << " % 16 != 0"; + + // 2. src and dst must be legal + bool scope_check = false; + // 2.1 DRAM -> RSRAM + if (src.scope() == "global" && dst.scope() == "shared.rsram") + scope_check = true; + // 2.2 RSRAM -> WSRAM + if (src.scope() == "shared.rsram" && dst.scope() == "shared.wsram") + scope_check = true; + // 2.3 RSRAM -> ASRAM + if (src.scope() == "shared.rsram" && dst.scope() == "shared.asram") + scope_check = true; + // 2.4 RSRAM <-> RSRAM + if (src.scope() == "shared.rsram" && dst.scope() == "shared.rsram") + scope_check = true; + if (!scope_check) return false; - } - // 4. src and dst must have the same dtype + // 3. src and dst must have the same dtype if (src->dtype != dst->dtype) { LOG(WARNING) << "src and dst must have the same dtype for tma load " << src->name << " vs. " << dst->name << " dtype " << src->dtype @@ -653,7 +658,7 @@ bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, * @brief Determine if this CopyNode can be lowered to a CUDA DMA store. * * Checks whether the target supports DMA store, the source buffer is in shared - * memory (shared or shared.dyn), the destination buffer is in global memory, + * memory (shared.rsram), the destination buffer is in global memory, * and both buffers have the same element data type. If the data types differ, * a warning is logged and false is returned. * @@ -661,32 +666,16 @@ bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, * @return true if all conditions are met; false otherwise. */ bool CopyNode::CheckDMAStore(Target target, arith::Analyzer *analyzer, - bool check_last_dim) const { + bool check_last_dim) const { // 1. arch must support Sunmmio if (!TargetIsSunmmio(target)) return false; // 2. src and dst must be shared.dyn and local.fragment - if ((src.scope() != "shared.dyn" && src.scope() != "shared") || - dst.scope() != "global") - return false; - // 3. check shape. - // last dim of dst * dtype.bits() must be a multiple of 16 - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - // now we check dst (gmem) as tma box dim is deduced from dst - if (check_last_dim && - analyzer->CanProve( - FloorMod(dst_range[dst_range.size() - 1]->extent * dst->dtype.bytes(), - 16) != 0, - arith::ProofStrength::kSymbolicBound)) { - LOG(WARNING) - << "dst range must have last dim multiple of 16 for tma bulk store " - << dst->name << " range " << dst_range[dst_range.size() - 1]->extent - << " * " << dst->dtype.bytes() << " % 16 != 0"; + if (src.scope() != "shared.rsram" || dst.scope() != "global") return false; - } - // 4. src and dst must have the same dtype + // 3. src and dst must have the same dtype if (src->dtype != dst->dtype) { - LOG(WARNING) << "src and dst must have the same dtype for tma store " + LOG(WARNING) << "src and dst must have the same dtype for dma store " << src->name << " vs. 
" << dst->name << " dtype " << src->dtype << " vs. " << dst->dtype << " will be fallback to normal copy"; return false; @@ -969,10 +958,15 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, return CopyInst::kTMemLoad; } else if (CheckTMemStore(target)) { return CopyInst::kTMemStore; - } else if (CheckDMALoad(target, analyzer)) { - return CopyInst::kDMALoad; - } else if (CheckDMAStore(target, analyzer)) { - return CopyInst::kDMAStore; + } else if (TargetIsSunmmio(target)) { + auto is_load = CheckDMALoad(target, analyzer); + auto is_store = CheckDMAStore(target, analyzer); + if (is_load) + return CopyInst::kDMALoad; + if (is_store) + return CopyInst::kDMAStore; + ICHECK(0) << "Unsupported copy from " << src.scope() << " to " + << dst.scope() << " of Sunmmio target."; } else { return CopyInst::kNormal; } @@ -985,6 +979,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, * determined copy instruction type: * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions * - LDSM/STSM: Uses matrix load/store instructions for tensor cores + * - DMA Load/Store: Sunmmio specified instructions for copy * - Normal: Uses standard load/store operations with loop transformations * \param T LowerArgs containing target and layout map. * \param analyzer Arithmetic analyzer for simplification. @@ -1000,7 +995,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, T.layout_map, analyzer); - if(copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { + if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { auto tmem_copy = LowerTmemCopy(T, analyzer); ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy"; return tmem_copy; @@ -1020,7 +1015,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return ldsm_copy; } else if (copy_inst == CopyInst::kNormal) { return LowerNormalCopy(T, analyzer); - } else if (copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) { + } else if (copy_inst == CopyInst::kDMALoad || + copy_inst == CopyInst::kDMAStore) { auto bulk_copy = LowerDMACopy(T, analyzer, copy_inst); ICHECK(bulk_copy.defined()) << "Failed to lower dma load/store"; return bulk_copy; @@ -1032,312 +1028,109 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { /** * @brief Lower a Copy operator to a DMA transfer. * - * Haoze TODO: the same as tma now * Lowers the copy to an optimized DMA load or store when the target and buffer - * layouts permit. - * - * If preconditions are not satisfied (unsupported swizzle, stride/size limits, - * mismatched element counts, OOB risks, or other hardware constraints), this - * function falls back to LowerNormalCopy. + * layouts permit. * * @param T LowerArgs containing target information, thread/bounds variables, - * and layout/ buffer remap information used for descriptor + * and layout/ buffer remap information * construction. * @param analyzer Analyzer used to prove shapes/contiguity/equality * constraints. - * @param copy_inst Indicates whether to emit a BulkLoad (TMA load) or BulkStore - * (TMA store). Must be CopyInst::kDMALoad or kDMAStore. - * @return Stmt A TIR statement performing the bulk TMA copy (or the result of - * LowerNormalCopy when falling back). + * @param copy_inst Indicates whether to emit a DMA load or DMA store. Must be + * CopyInst::kDMALoad or kDMAStore. 
+ * @return Stmt A TIR statement performing the DMA copy. */ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, - CopyInst copy_inst) const { + CopyInst copy_inst) const { ICHECK(copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) << "Invalid copy inst " << static_cast(copy_inst); - bool is_load = copy_inst == CopyInst::kDMALoad; - Buffer global_tensor = is_load ? src : dst; - Buffer shared_tensor = is_load ? dst : src; - Array global_range = is_load ? src_range : dst_range; - Array shared_range = is_load ? dst_range : src_range; - // Cannot support a non-swizzled global layout, will be fallback to normal copy - if (T.layout_map.count(global_tensor)) { - LOG(WARNING) << "DMA copy cannot support a non-swizzled global " - "layout, fallback to normal copy."; - return LowerNormalCopy(T, analyzer); - } - // linear layout must be computed before remapping - auto linear_layout = ComputeLinearLayout(shared_tensor); - - Array shared_indices; - for (auto r : shared_range) - shared_indices.push_back(r->min); - std::vector shared_strides; - PrimExpr shared_stride = 1; - for (size_t i = 0; i < shared_tensor->shape.size(); i++) { - auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1]; - shared_strides.insert(shared_strides.begin(), shared_stride); - shared_stride *= s; - } - - Array global_indices; - for (auto r : global_range) { - global_indices.push_back(r->min); - } - std::vector global_strides; - PrimExpr global_stride = 1; - for (size_t i = 0; i < global_tensor->shape.size(); i++) { - auto s = global_tensor->shape[global_tensor->shape.size() - i - 1]; - global_strides.insert(global_strides.begin(), global_stride); - global_stride *= s; - } - - ICHECK(shared_strides.size() == shared_indices.size()) - << "shared_strides.size() != shared_indices.size()" - << shared_strides.size() << " " << shared_indices.size(); - PrimExpr shared_offset = 0; - for (size_t i = 0; i < shared_indices.size(); i++) { - shared_offset += shared_indices[i] * shared_strides[i]; - } - PrimExpr global_offset = 0; - for (size_t i = 0; i < global_indices.size(); i++) { - global_offset += global_indices[i] * global_strides[i]; - } - - TMADesc desc; - // Verify copy rank - desc.rank = global_tensor->shape.size(); - ICHECK(desc.rank >= 1 && desc.rank <= 5) << desc.rank; - - // Verify datatype - ICHECK(global_tensor->dtype == shared_tensor->dtype) - << "Copy between buffer " << global_tensor->name << " and " - << shared_tensor->name << " with different data type " - << global_tensor->dtype << " and " << shared_tensor->dtype; - - desc.data_type = to_CUtensorMapDataType(global_tensor->dtype); - - // Global Tensor Shape and Stride - desc.global_addr = global_tensor->data; - desc.global_shape = ReverseArray(global_tensor->shape); - Array global_coords = - ReverseArray(global_range.Map([](Range r) { return r->min; })); - if (!global_tensor->strides.empty()) { - desc.global_stride = ReverseArray(global_tensor->strides); + bool is_load = copy_inst == CopyInst::kDMALoad; + Array args; + // \param data_type + args.push_back(to_CUtensorMapDataType(src->dtype)); + // \param src_rank + args.push_back(static_cast(src->shape.size())); + // \param src_region_shape + for (auto r : src_range) { + args.push_back(r->extent); + } + // \param src_input_size & \param src_forward + if (src.scope() == "global") { } else { - // Create stride from shape - PrimExpr stride = 1; - desc.global_stride.reserve(desc.rank); - for (size_t i = 0; i < desc.rank; i++) { - desc.global_stride.push_back(stride); - stride 
*= desc.global_shape[i]; + ICHECK(T.layout_map.count(src)) + << "Layout of buffer " << src << " not found."; + auto layout = T.layout_map.at(src); + for (auto s : layout->InputShape()) { + args.push_back(s); } - } - // The first stride element should be 1 - ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; - // Make global stride in bytes - desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { - return cast(DataType::Int(64), e) * global_tensor->dtype.bytes(); - }); - for (size_t i{1}; i < desc.global_stride.size(); i++) { - auto stride = desc.global_stride[i].as(); - if (stride != nullptr) { - // otherwise, the stride is symbolic, we need to check in future with - // assumptions - if (stride->value % 16 != 0 || stride->value >= (1ULL << 40)) { - LOG(WARNING) << "TMA bulk copy cannot support a global stride of " - << desc.global_stride[i] << ", fallback to normal copy."; - return LowerNormalCopy(T, analyzer); - } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); } } + // \param src_scope + auto src_scope = StringImm(src.scope()); + args.push_back(src_scope); - // Smem Box - // check smem range and global range is legal - auto s_range_idx = 0; - for (size_t i = 0; i < global_range.size(); i++) { - auto g_range = global_range[i]; - if (is_one(g_range->extent)) { - continue; - } - // skip one range if it is 1 - // in case of global range is [128, 64], while shared range is [1, 128, 64] - // A_shared[0, :, :]. - while (is_one(shared_range[s_range_idx]->extent) && - s_range_idx < shared_range.size()) { - s_range_idx++; + // \param dst_rank + args.push_back(static_cast(dst->shape.size())); + // \param dst_shape + for (auto r : dst_range) { + args.push_back(r->extent); + } + // \param dst_input_size & \param dst_forward + if (dst.scope() == "global") { + } else { + ICHECK(T.layout_map.count(dst)) + << "Layout of buffer " << dst << " not found."; + auto layout = T.layout_map.at(dst); + for (auto s : layout->InputShape()) { + args.push_back(s); } - if (s_range_idx >= shared_range.size()) { - LOG(FATAL) << "TMA bulk copy cannot support a global range of " - << global_range << ", shared_range " << shared_range; + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); } - auto s_range = shared_range[s_range_idx]; - s_range_idx++; - - ICHECK(StructuralEqual()(g_range->extent, s_range->extent)) - << global_tensor->name << "[" << i << "] is illegal, " - << global_tensor->name << "[" << i << "] = " << g_range->extent << ", " - << shared_tensor->name << "[" << s_range_idx - << "] = " << s_range->extent; } - // TODO(lei): find a much smarter way to deduce smem box dim - // instead of using global_range - desc.smem_box = - ReverseArray(global_range.Map([](Range r) { return r->extent; })); + // \param dst_scope + auto dst_scope = StringImm(dst.scope()); + args.push_back(dst_scope); - desc.smem_stride = Array(desc.rank, PrimExpr(1)); - // L2 & OOB - desc.l2_promotion = static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); - desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - - // Detect smem layout - // Shared memory swizzling is crucial for TMA performance - // It determines how data is arranged in shared memory banks to minimize bank - // conflicts Different swizzle patterns (32B, 64B, 128B) offer different - // trade-offs between access efficiency and memory usage - desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); - Layout shared_layout; - if (T.layout_map.count(shared_tensor)) { - shared_layout = T.layout_map.at(shared_tensor); - 
ICHECK(T.buffer_remap.count(shared_tensor)) - << "shared_tensor: " << shared_tensor->name - << " not found in buffer_remap"; - shared_tensor = T.buffer_remap.at(shared_tensor); - } - if (!shared_layout.defined()) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); - } else if (StructuralEqual()(shared_layout, linear_layout)) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); + // \param src_addr + if (src.scope() == "global") { + args.push_back(src->data); } else { - ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout."; - auto stride = as_const_int(shared_layout->InputShape()[0]); - auto continuous = as_const_int(shared_layout->InputShape()[1]); - ICHECK(stride != nullptr && continuous != nullptr); - // We also need to check if the shape satisfies the following doc: - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout( - *stride, *continuous, - shared_tensor->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); - } else if (StructuralEqual()( - shared_layout, - makeHalfBankSwizzleLayout(*stride, *continuous, - shared_tensor->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); - } else if (StructuralEqual()( - shared_layout, - makeFullBankSwizzleLayout(*stride, *continuous, - shared_tensor->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); - } else if (StructuralEqual()( - shared_layout, - makeGemmABLayoutPadded(*stride, *continuous, - shared_tensor->dtype.bits()))) { - LOG(WARNING) << "Bulk copy cannot support a padded layout for src: " - << src->name << ", dst: " << dst->name - << ", fallback to normal copy"; - return LowerNormalCopy(T, analyzer); - } else { - LOG(WARNING) << "Came across unsupported swizzle layout for src: " - << src->name << ", dst: " << dst->name - << ", fallback to normal copy"; - return LowerNormalCopy(T, analyzer); + PrimExpr total_elements = 1; + for (auto e : src_range) { + total_elements *= e->extent; } + auto addr = src.access_ptr(1, DataType::Handle(), 1, 0, total_elements); + args.push_back(addr); } - - auto inner_box_dim = as_const_int(desc.smem_box[0]); - if (inner_box_dim == nullptr) { - LOG(WARNING) << "inner_box_dim " << desc.smem_box[0] - << " can only be a constant integer for TMA bulk copy, " - "fallback to normal copy"; - return LowerNormalCopy(T, analyzer); - } - int instruction_dim = *inner_box_dim; - if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_64B)) { - instruction_dim = 64 / src->dtype.bytes(); - } else if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_128B)) { - instruction_dim = 128 / src->dtype.bytes(); + // \param src_coord + for (auto r : src_range) { + args.push_back(r->min); } - if (instruction_dim > 256) { - // smem_box dim must be in [0, 256] - // if is 512, we need to split the copy into two parts - ICHECK((*inner_box_dim) % 256 == 0) - << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256"; - instruction_dim = 256; - } - ICHECK((*inner_box_dim) % instruction_dim == 0) - << "inner_box_dim: " << *inner_box_dim - << " is not divisible by instruction_dim: " << instruction_dim; - desc.smem_box.Set(0, PrimExpr(instruction_dim)); - - int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes(); - - // Check inner_box_dim_ for each swizzle type in a cleaner way - struct SwizzleCheck { - int swizzle; - int max_dim; - }; - static const std::vector 
swizzle_checks = { - {static_cast(CU_TENSOR_MAP_SWIZZLE_32B), 32}, - {static_cast(CU_TENSOR_MAP_SWIZZLE_64B), 64}, - {static_cast(CU_TENSOR_MAP_SWIZZLE_128B), 128}, - }; - for (const auto &check : swizzle_checks) { - if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) { - LOG(WARNING) << "TMA bulk copy cannot support a swizzled global layout " - "with inner_box_dim_ > " - << check.max_dim << ", will be fallback to normal copy"; - return LowerNormalCopy(T, analyzer); + // \param dst_addr + if (dst.scope() == "global") { + args.push_back(dst->data); + } else { + PrimExpr total_elements = 1; + for (auto e : dst_range) { + total_elements *= e->extent; } + auto addr = dst.access_ptr(2, DataType::Handle(), 1, 0, total_elements); + args.push_back(addr); } - - Call create_descriptor = - Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); - - Array args; - args.reserve(desc.rank + 4); - args.push_back(create_descriptor); - if (is_load) - args.push_back(0); // mbarrier id placeholder - auto op = is_load ? tma_load() : tma_store(); - - Stmt tma_copy; - PrimExpr total_elements = 1; - for (auto e : desc.smem_box) - total_elements *= e; - - if ((*inner_box_dim) != instruction_dim) { - Var loop_var("i"); - int loop_extent = (*inner_box_dim) / instruction_dim; - - PrimExpr shared_addr = shared_tensor.access_ptr( - is_load ? 2 : 1, DataType::Handle(), 1, - shared_offset + total_elements * loop_var, total_elements); - args.push_back(shared_addr); - global_coords.Set(0, global_coords[0] + instruction_dim * loop_var); - for (auto coord : global_coords) - args.push_back(coord); - int need_reduce = 0; - if (!is_load) - args.push_back(need_reduce); - args.push_back(this->eviction_policy); - tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, - Evaluate(Call(DataType::Handle(), op, args))); - } else { - PrimExpr shared_addr = shared_tensor.access_ptr( - is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements); - args.push_back(shared_addr); - for (auto coord : global_coords) - args.push_back(coord); - int need_reduce = 0; - if (!is_load) - args.push_back(need_reduce); - args.push_back(this->eviction_policy); - tma_copy = Evaluate(Call(DataType::Handle(), op, args)); + // \param dst_coord + for (auto r : dst_range) { + args.push_back(r->min); } - tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); - return tma_copy; + auto op = is_load ? dma_load() : dma_store(); + Stmt dma_copy; + dma_copy = Evaluate(Call(DataType::Handle(), op, args)); + + return dma_copy; } /** @@ -2204,37 +1997,6 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, return tma_copy; } -/*! - * \brief Encode the DMA descriptor into an array of PrimExpr. - * This function serializes the DMA descriptor fields into a format suitable for - * passing to the create_dma_descriptor() builtin function. The encoding follows - * the expected argument order for the DMA descriptor creation. - * \return Array of PrimExpr representing the encoded DMA descriptor. 
- * the same implementation as TMA - */ -Array DMADesc::EncodeCallArgs() const { - Array args; - args.reserve(rank * 4 + 7); - - args.push_back(data_type); - args.push_back(static_cast(rank)); - args.push_back(global_addr); - for (auto e : global_shape) - args.push_back(e); - for (auto e : global_stride) - args.push_back(e); - for (auto e : smem_box) - args.push_back(e); - for (auto e : smem_stride) - args.push_back(e); - args.push_back(interleave); - args.push_back(swizzle); - args.push_back(l2_promotion); - args.push_back(oob_fill); - - return args; -} - /*! * \brief Encode the TMA descriptor into an array of PrimExpr. * This function serializes the TMA descriptor fields into a format suitable for diff --git a/src/op/copy.h b/src/op/copy.h index b2d5406dd..04fb44200 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -135,7 +135,7 @@ class CopyNode : public TileOperatorNode { * \brief Check if dma load is supported. */ bool CheckDMALoad(Target target, arith::Analyzer *analyzer, - bool check_last_dim = true) const; + bool check_last_dim = true) const; /*! * \brief Check if dma store is supported. @@ -208,7 +208,7 @@ class CopyNode : public TileOperatorNode { * \brief Generate lowering for dma copy. */ Stmt LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, - CopyInst copy_inst) const; + CopyInst copy_inst) const; /*! * \brief Generate lowering for bulk/global-to-shared copy. diff --git a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py new file mode 100644 index 000000000..58f69aa75 --- /dev/null +++ b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py @@ -0,0 +1,250 @@ +import tilelang +import pytest +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target +import tilelang.language as T +from tilelang.language.v2.annot import MeshShardingPolicy + + +def copy(K, block_M, block_N, block_K, dtype="float32", accum_dtype="float32"): + MyTensor = T.MeshTensor((128, 128), + sharding_policy=MeshShardingPolicy(cross_mesh_dim=0), + device_mesh_config=(2, 2), + hierarchical_dims=(4, 32, 128), + hierarchical_groups=((0, 2), (2, 3)), + hierarchical_strides=(32, 1, 4096)) + + @T.prim_func + def main(C: MyTensor): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(128, block_N), T.ceildiv(128, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # DRAM -> RSRAM + T.copy(C[by * block_M, ko * block_K], C_shared) + # DRAM <- RSRAM + T.copy(C_shared, C[by * block_M, ko * block_K]) + # RSRAM -> ASRAM + T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) + # RSRAM -> WSRAM + T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) + # RSRAM <-> RSRAM + T.copy(C_shared, D_shared) + + return tvm.IRModule({'main': main}) + + +TEST_CASES = [ + (128, 64, 64, 32), +] + + +@pytest.mark.parametrize( + "K, block_M, block_N, block_K", + TEST_CASES, +) +def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with tvm.target.Target(target): + mod = copy(K, block_M, block_N, block_K) + mod = 
tvm.tir.transform.BindTarget(target)(mod) + # Add wrapper for single buf store + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + # Normalize negative indices to canonical non-negative form + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) + # Simplify the IR expressions + mod = tilelang.transform.Simplify()(mod) + # Infer shared memory SRAM scope + mod = tilelang.transform.InferSramScope()(mod) + # Set layouts for reducers + mod = tilelang.transform.LayoutReducer()(mod) + # Infer memory layouts for fragments and shared memory + mod = tilelang.transform.LayoutInference()(mod) + # Lower high-level tile operations to low-level operations + mod = tilelang.transform.LowerTileOp()(mod) + print(mod) + + +def wrong_copy_1(M, + N, + K, + block_M, + block_N, + block_K, + error_type, + dtype="float16", + accum_dtype="float16"): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + A_shared_2 = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + B_shared_2 = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + if error_type == 'D->A': + T.copy(C[by * block_M, ko * block_K], A_shared) + elif error_type == 'A->D': + T.copy(A_shared, C[by * block_M, ko * block_K]) + elif error_type == 'D->W': + T.copy(C[by * block_M, ko * block_K], B_shared) + elif error_type == 'W->D': + T.copy(B_shared, C[by * block_M, ko * block_K]) + elif error_type == 'A->R': + T.copy(A_shared, C_shared) + elif error_type == 'W->R': + T.copy(B_shared, C_shared) + elif error_type == 'D<->D': + T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) + elif error_type == 'A<->A': + T.copy(A_shared, A_shared_2) + elif error_type == 'W<->W': + T.copy(B_shared, B_shared_2) + + return tvm.IRModule({'main': main}) + + +WRONG_TEST_CASES = [ + (128, 128, 128, 32, 32, 32, "D->A", + "Unsupported copy from global to shared.asram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->D", + "Unsupported copy from shared.asram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "D->W", + "Unsupported copy from global to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->D", + "Unsupported copy from shared.wsram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->R", + "Unsupported copy from shared.asram to shared.rsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->R", + "Unsupported copy from shared.wsram to shared.rsram of Sunmmio target."), + # (128, 128, 128, 32, 32, 32, "D<->D", + # "Unsupported copy from global to global of Sunmmio target."), + # D<->D not work now + (128, 128, 128, 32, 32, 32, "A<->A", + "Unsupported copy from shared.asram to shared.asram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W<->W", + "Unsupported copy from shared.wsram to shared.wsram of Sunmmio target."), +] + + +@pytest.mark.parametrize( + "M, N, K, block_M, block_N, block_K, error_type, error_msg", + WRONG_TEST_CASES, +) +def test_tilelang_mesh_wrong_copy_to_dma_1(M, N, K, 
block_M, block_N, block_K, error_type, + error_msg): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): + mod = wrong_copy_1(M, N, K, block_M, block_N, block_K, error_type) + mod = tvm.tir.transform.BindTarget(target)(mod) + # Add wrapper for single buf store + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + # Normalize negative indices to canonical non-negative form + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) + # Simplify the IR expressions + mod = tilelang.transform.Simplify()(mod) + # Infer shared memory SRAM scope + mod = tilelang.transform.InferSramScope()(mod) + # Set layouts for reducers + mod = tilelang.transform.LayoutReducer()(mod) + # Infer memory layouts for fragments and shared memory + mod = tilelang.transform.LayoutInference()(mod) + # Lower high-level tile operations to low-level operations + mod = tilelang.transform.LowerTileOp()(mod) + + +def wrong_copy_2(M, + N, + K, + block_M, + block_N, + block_K, + error_type, + dtype="float16", + accum_dtype="float16"): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # DRAM -> RSRAM + T.copy(C[by * block_M, ko * block_K], C_shared) + # DRAM <- RSRAM + T.copy(C_shared, C[by * block_M, ko * block_K]) + # RSRAM -> ASRAM + T.copy(C_shared, A_shared) + # RSRAM -> WSRAM + T.copy(C_shared, B_shared) + # RSRAM <-> RSRAM + T.copy(C_shared, D_shared) + if error_type == 'A->W': + T.copy(A_shared, B_shared) + elif error_type == 'W->A': + T.copy(B_shared, A_shared) + + return tvm.IRModule({'main': main}) + + +WRONG_TEST_CASES = [ + (128, 128, 128, 32, 32, 32, "A->W", + "Unsupported copy from shared.asram to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->A", + "Unsupported copy from shared.wsram to shared.asram of Sunmmio target."), +] + + +@pytest.mark.parametrize( + "M, N, K, block_M, block_N, block_K, error_type, error_msg", + WRONG_TEST_CASES, +) +def test_tilelang_mesh_wrong_copy_to_dma_2(M, N, K, block_M, block_N, block_K, error_type, + error_msg): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): + mod = wrong_copy_2(M, N, K, block_M, block_N, block_K, error_type) + mod = tvm.tir.transform.BindTarget(target)(mod) + # Add wrapper for single buf store + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + # Normalize negative indices to canonical non-negative form + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) + # Simplify the IR expressions + mod = tilelang.transform.Simplify()(mod) + # Infer shared memory SRAM scope + mod = tilelang.transform.InferSramScope()(mod) + # Set layouts 
for reducers + mod = tilelang.transform.LayoutReducer()(mod) + # Infer memory layouts for fragments and shared memory + mod = tilelang.transform.LayoutInference()(mod) + # Lower high-level tile operations to low-level operations + mod = tilelang.transform.LowerTileOp()(mod) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index cabc4a3e4..aca3a5f7f 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -67,6 +67,13 @@ def get_extent(data): # Combine the nested if statements into a single if statement as suggested by SIM102 if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and isinstance(dst, tir.BufferLoad)): + # FIXME + # Now an invalid D<->D copy operation will enter here, for example: + # T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) -> + # for ko in T.serial(4, annotations={"num_stages": 3}): + # B[by * 32, ko * 32] = C[by * 32, ko * 32] + # which causes an exception can't be caught. + # # check if the case is like this: # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i] From c0e766e0fd39a82bc09ea56de8e11a6a7cb1466c Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Thu, 5 Feb 2026 13:42:20 +0800 Subject: [PATCH 5/9] let test script work, global layout todo --- src/op/copy.cc | 3 ++ ...test_tilelang_mesh_language_copy_to_dma.py | 48 ++++++++++++++----- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 7406b29d0..a675b24c8 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -641,6 +641,9 @@ bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, // 2.4 RSRAM <-> RSRAM if (src.scope() == "shared.rsram" && dst.scope() == "shared.rsram") scope_check = true; + // 2.5 DRAM -> WSRAM + if (src.scope() == "global" && dst.scope() == "shared.wsram") + scope_check = true; if (!scope_check) return false; diff --git a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py index 58f69aa75..67db45db1 100644 --- a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py +++ b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py @@ -25,9 +25,11 @@ def main(C: MyTensor): for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): # DRAM -> RSRAM - T.copy(C[by * block_M, ko * block_K], C_shared) - # DRAM <- RSRAM - T.copy(C_shared, C[by * block_M, ko * block_K]) + # T.copy(C[by * block_M, ko * block_K], C_shared) + # # DRAM -> WSRAM + # T.copy(C[by * block_M, ko * block_K], B_shared) + # # DRAM <- RSRAM + # T.copy(C_shared, C[by * block_M, ko * block_K]) # RSRAM -> ASRAM T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) # RSRAM -> WSRAM @@ -39,15 +41,39 @@ def main(C: MyTensor): TEST_CASES = [ - (128, 64, 64, 32), + ( + 128, + 64, + 64, + 32, + [ + # DRAM -> RSRAM + # T.copy(C[by * block_M, ko * block_K], C_shared) + # "1", + # DRAM -> WSRAM + # T.copy(C[by * block_M, ko * block_K], B_shared) + # "1", + # DRAM <- RSRAM + # T.copy(C_shared, C[by * block_M, ko * block_K]) + # "1", + # RSRAM -> ASRAM + # T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) + "T.dma_load(7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 512, 1), 8, 16, 
T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 512, 2), 24, 8)", + # RSRAM -> WSRAM + # T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) + "T.dma_load(7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.wsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 192, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 192, 2), 40, 0)", + # RSRAM <-> RSRAM + # T.copy(C_shared, D_shared) + "T.dma_load(7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, T.tvm_access_ptr(T.type_annotation(\"float32\"), D_shared.data, 0, 4096, 2), 0, 0)", + ]), ] @pytest.mark.parametrize( - "K, block_M, block_N, block_K", + "K, block_M, block_N, block_K, lower_stmt", TEST_CASES, ) -def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K): +def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K, lower_stmt): target_name = "Sunmmio" target = determine_target(target_name, return_object=True) with tvm.target.Target(target): @@ -69,7 +95,11 @@ def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K): mod = tilelang.transform.LayoutInference()(mod) # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) - print(mod) + texts = mod.script().split('\n') + texts = texts[29:] + texts = [it[20:] for it in texts] + for i in range(len(texts)): + assert texts[i] == lower_stmt[i] def wrong_copy_1(M, @@ -101,8 +131,6 @@ def main( T.copy(C[by * block_M, ko * block_K], A_shared) elif error_type == 'A->D': T.copy(A_shared, C[by * block_M, ko * block_K]) - elif error_type == 'D->W': - T.copy(C[by * block_M, ko * block_K], B_shared) elif error_type == 'W->D': T.copy(B_shared, C[by * block_M, ko * block_K]) elif error_type == 'A->R': @@ -124,8 +152,6 @@ def main( "Unsupported copy from global to shared.asram of Sunmmio target."), (128, 128, 128, 32, 32, 32, "A->D", "Unsupported copy from shared.asram to global of Sunmmio target."), - (128, 128, 128, 32, 32, 32, "D->W", - "Unsupported copy from global to shared.wsram of Sunmmio target."), (128, 128, 128, 32, 32, 32, "W->D", "Unsupported copy from shared.wsram to global of Sunmmio target."), (128, 128, 128, 32, 32, 32, "A->R", From eb83efd9d6ea56bcb6d0ef1f8e6f1e77a413c8d3 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Thu, 5 Feb 2026 15:59:46 +0800 Subject: [PATCH 6/9] remove load and store, now use dma_copy --- src/op/builtin.cc | 5 +- src/op/builtin.h | 12 +- src/op/copy.cc | 101 ++++++----------- src/op/copy.h | 13 +-- ...test_tilelang_mesh_language_copy_to_dma.py | 105 +++--------------- 5 files changed, 56 insertions(+), 180 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index ed5f3067a..058c395a8 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -106,10 +106,7 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr( +TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); 
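For reference, dma_copy is registered as a variadic builtin (set_num_inputs(-1)) with kOpaque call effect, and CheckDMACopy later in this patch only accepts a fixed set of (src, dst) scope pairs on the Sunmmio target. The Python sketch below summarizes that support matrix; ALLOWED_DMA_COPY_SCOPES and dma_copy_scope_supported are illustrative names for this note only, not tilelang APIs, and the pairs are copied from the scope checks and the error messages asserted in the tests.

# Illustrative summary of the (src, dst) scope pairs CheckDMACopy accepts in
# this patch; any other combination fails the ICHECK in GetCopyInst.
ALLOWED_DMA_COPY_SCOPES = {
    ("global", "shared.rsram"),        # DRAM  -> RSRAM
    ("global", "shared.wsram"),        # DRAM  -> WSRAM
    ("global", "shared.asram"),        # DRAM  -> ASRAM
    ("shared.rsram", "shared.wsram"),  # RSRAM -> WSRAM
    ("shared.rsram", "shared.asram"),  # RSRAM -> ASRAM
    ("shared.rsram", "shared.rsram"),  # RSRAM <-> RSRAM
    ("shared.rsram", "global"),        # RSRAM -> DRAM
}


def dma_copy_scope_supported(src_scope: str, dst_scope: str) -> bool:
    # Mirrors only the scope test; the dtype equality check is handled separately.
    return (src_scope, dst_scope) in ALLOWED_DMA_COPY_SCOPES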
TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) diff --git a/src/op/builtin.h b/src/op/builtin.h index ca56e196a..caf7d2224 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -189,7 +189,7 @@ TVM_DLL const Op &get_mbarrier(); TVM_DLL const Op &tma_load(); /*! - * \brief Perform a DMA load operation from source memory to destination memory. + * \brief Perform a DMA copy operation from source memory to destination memory. * * This function describes a DMA-based tensor copy with explicit shape, layout, * memory scope. It is typically used to lower a high-level @@ -273,15 +273,7 @@ TVM_DLL const Op &tma_load(); * \note * Out-of-bound fill policies are currently not supported. */ -TVM_DLL const Op &dma_load(); - -/*! - * \brief Perform a DMA store operation from source memory to destination - * memory. see dma_load for details. - * - * - */ -TVM_DLL const Op &dma_store(); +TVM_DLL const Op &dma_copy(); /*! * \brief tvm intrinsics for loading image from global tensor to columns in diff --git a/src/op/copy.cc b/src/op/copy.cc index a675b24c8..be32021cd 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -565,33 +565,26 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, return {}; } - if (copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) { - // for dma load/store, we can directly apply the blockwise_zz_layout - bool is_load = copy_inst == CopyInst::kDMALoad; + if (copy_inst == CopyInst::kDMACopy) { + // for dma copy, we can directly apply the blockwise_zz_layout const auto f = ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); - if (!is_load) { - // DMA Store, only src in shared - if (level == InferLevel::kFree && !T.layout_map.count(src)) { - auto layout = Downcast((*f)(src)); - return Map({{src, layout}}); - } - return {}; - } else { - // DMA Load, src may in shared, dst in shared - auto result = Map(); - if (level == InferLevel::kFree && src.scope() != "global" && - !T.layout_map.count(src)) { + auto result = Map(); + + if (level == InferLevel::kFree && !T.layout_map.count(src)) { + if (src.scope() != "global") { auto layout = Downcast((*f)(src)); result.Set(src, layout); } - if (level == InferLevel::kFree && !T.layout_map.count(dst)) { + } + + if (level == InferLevel::kFree && !T.layout_map.count(dst)) { + if (dst.scope() != "global") { auto layout = Downcast((*f)(dst)); result.Set(dst, layout); } - return result; } - return {}; + return result; } // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy @@ -621,7 +614,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, * @return true if the copy can be implemented as a DMA Load; false * otherwise. */ -bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, +bool CopyNode::CheckDMACopy(Target target, arith::Analyzer *analyzer, bool check_last_dim) const { // 1. 
arch must support Sunmmio if (!TargetIsSunmmio(target)) @@ -632,53 +625,30 @@ bool CopyNode::CheckDMALoad(Target target, arith::Analyzer *analyzer, // 2.1 DRAM -> RSRAM if (src.scope() == "global" && dst.scope() == "shared.rsram") scope_check = true; - // 2.2 RSRAM -> WSRAM + // 2.2 DRAM -> WSRAM + if (src.scope() == "global" && dst.scope() == "shared.wsram") + scope_check = true; + // 2.3 DRAM -> ASRAM + if (src.scope() == "global" && dst.scope() == "shared.asram") + scope_check = true; + // 2.4 RSRAM -> WSRAM if (src.scope() == "shared.rsram" && dst.scope() == "shared.wsram") scope_check = true; - // 2.3 RSRAM -> ASRAM + // 2.5 RSRAM -> ASRAM if (src.scope() == "shared.rsram" && dst.scope() == "shared.asram") scope_check = true; - // 2.4 RSRAM <-> RSRAM + // 2.6 RSRAM <-> RSRAM if (src.scope() == "shared.rsram" && dst.scope() == "shared.rsram") scope_check = true; - // 2.5 DRAM -> WSRAM - if (src.scope() == "global" && dst.scope() == "shared.wsram") + // 2.7 RSRAM -> DRAM + if (src.scope() == "shared.rsram" && dst.scope() == "global") scope_check = true; if (!scope_check) return false; // 3. src and dst must have the same dtype if (src->dtype != dst->dtype) { - LOG(WARNING) << "src and dst must have the same dtype for tma load " - << src->name << " vs. " << dst->name << " dtype " << src->dtype - << " vs. " << dst->dtype << " will be fallback to normal copy"; - return false; - } - return true; -} - -/** - * @brief Determine if this CopyNode can be lowered to a CUDA DMA store. - * - * Checks whether the target supports DMA store, the source buffer is in shared - * memory (shared.rsram), the destination buffer is in global memory, - * and both buffers have the same element data type. If the data types differ, - * a warning is logged and false is returned. - * - * @param target Target device/architecture to check for dma store support. - * @return true if all conditions are met; false otherwise. - */ -bool CopyNode::CheckDMAStore(Target target, arith::Analyzer *analyzer, - bool check_last_dim) const { - // 1. arch must support Sunmmio - if (!TargetIsSunmmio(target)) - return false; - // 2. src and dst must be shared.dyn and local.fragment - if (src.scope() != "shared.rsram" || dst.scope() != "global") - return false; - // 3. src and dst must have the same dtype - if (src->dtype != dst->dtype) { - LOG(WARNING) << "src and dst must have the same dtype for dma store " + LOG(WARNING) << "src and dst must have the same dtype for dma copy " << src->name << " vs. " << dst->name << " dtype " << src->dtype << " vs. 
" << dst->dtype << " will be fallback to normal copy"; return false; @@ -962,12 +932,9 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, } else if (CheckTMemStore(target)) { return CopyInst::kTMemStore; } else if (TargetIsSunmmio(target)) { - auto is_load = CheckDMALoad(target, analyzer); - auto is_store = CheckDMAStore(target, analyzer); - if (is_load) - return CopyInst::kDMALoad; - if (is_store) - return CopyInst::kDMAStore; + auto is_copy = CheckDMACopy(target, analyzer); + if (is_copy) + return CopyInst::kDMACopy; ICHECK(0) << "Unsupported copy from " << src.scope() << " to " << dst.scope() << " of Sunmmio target."; } else { @@ -1018,11 +985,10 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return ldsm_copy; } else if (copy_inst == CopyInst::kNormal) { return LowerNormalCopy(T, analyzer); - } else if (copy_inst == CopyInst::kDMALoad || - copy_inst == CopyInst::kDMAStore) { - auto bulk_copy = LowerDMACopy(T, analyzer, copy_inst); - ICHECK(bulk_copy.defined()) << "Failed to lower dma load/store"; - return bulk_copy; + } else if (copy_inst == CopyInst::kDMACopy) { + auto dma_copy = LowerDMACopy(T, analyzer, copy_inst); + ICHECK(dma_copy.defined()) << "Failed to lower dma load/store"; + return dma_copy; } else { LOG(FATAL) << "Unsupported copy inst " << static_cast(copy_inst); } @@ -1045,10 +1011,9 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { */ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, CopyInst copy_inst) const { - ICHECK(copy_inst == CopyInst::kDMALoad || copy_inst == CopyInst::kDMAStore) + ICHECK(copy_inst == CopyInst::kDMACopy) << "Invalid copy inst " << static_cast(copy_inst); - bool is_load = copy_inst == CopyInst::kDMALoad; Array args; // \param data_type args.push_back(to_CUtensorMapDataType(src->dtype)); @@ -1129,7 +1094,7 @@ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(r->min); } - auto op = is_load ? dma_load() : dma_store(); + auto op = dma_copy(); Stmt dma_copy; dma_copy = Evaluate(Call(DataType::Handle(), op, args)); diff --git a/src/op/copy.h b/src/op/copy.h index 04fb44200..e58eb4fc8 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -28,8 +28,7 @@ enum class CopyInst : uint8_t { kTMemStore = 8, // tcgen05.st (register -> tensor memory) // dma - kDMALoad = 9, - kDMAStore = 10, + kDMACopy = 9, }; /// Descriptor for Tensor Memory Access (TMA) copy operations @@ -132,17 +131,11 @@ class CopyNode : public TileOperatorNode { InferLevel level) const override; /*! - * \brief Check if dma load is supported. + * \brief Check if dma copy is supported. */ - bool CheckDMALoad(Target target, arith::Analyzer *analyzer, + bool CheckDMACopy(Target target, arith::Analyzer *analyzer, bool check_last_dim = true) const; - /*! - * \brief Check if dma store is supported. - */ - bool CheckDMAStore(Target target, arith::Analyzer *analyzer, - bool check_last_dim = true) const; - /*! * \brief Check if bulk copy is supported. 
*/ diff --git a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py index 67db45db1..ea93034df 100644 --- a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py +++ b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py @@ -24,13 +24,15 @@ def main(C: MyTensor): D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # DRAM -> RSRAM + # # DRAM -> RSRAM # T.copy(C[by * block_M, ko * block_K], C_shared) # # DRAM -> WSRAM # T.copy(C[by * block_M, ko * block_K], B_shared) # # DRAM <- RSRAM # T.copy(C_shared, C[by * block_M, ko * block_K]) - # RSRAM -> ASRAM + # # DRAM -> ASRAM + # T.copy(C[by * block_M, ko * block_K], A_shared) + # # RSRAM -> ASRAM T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) # RSRAM -> WSRAM T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) @@ -56,15 +58,18 @@ def main(C: MyTensor): # DRAM <- RSRAM # T.copy(C_shared, C[by * block_M, ko * block_K]) # "1", + # DRAM -> ASRAM + # T.copy(C[by * block_M, ko * block_K], A_shared) + # "1", # RSRAM -> ASRAM # T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) - "T.dma_load(7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 512, 1), 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 512, 2), 24, 8)", + "T.dma_copy(7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 512, 1), 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 512, 2), 24, 8)", # RSRAM -> WSRAM # T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) - "T.dma_load(7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.wsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 192, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 192, 2), 40, 0)", + "T.dma_copy(7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.wsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 192, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 192, 2), 40, 0)", # RSRAM <-> RSRAM # T.copy(C_shared, D_shared) - "T.dma_load(7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, T.tvm_access_ptr(T.type_annotation(\"float32\"), D_shared.data, 0, 4096, 2), 0, 0)", + "T.dma_copy(7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, 
T.tvm_access_ptr(T.type_annotation(\"float32\"), D_shared.data, 0, 4096, 2), 0, 0)", ]), ] @@ -127,9 +132,7 @@ def main( C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - if error_type == 'D->A': - T.copy(C[by * block_M, ko * block_K], A_shared) - elif error_type == 'A->D': + if error_type == 'A->D': T.copy(A_shared, C[by * block_M, ko * block_K]) elif error_type == 'W->D': T.copy(B_shared, C[by * block_M, ko * block_K]) @@ -143,13 +146,15 @@ def main( T.copy(A_shared, A_shared_2) elif error_type == 'W<->W': T.copy(B_shared, B_shared_2) + elif error_type == 'A->W': + T.copy(A_shared, B_shared) + elif error_type == 'W->A': + T.copy(B_shared, A_shared) return tvm.IRModule({'main': main}) WRONG_TEST_CASES = [ - (128, 128, 128, 32, 32, 32, "D->A", - "Unsupported copy from global to shared.asram of Sunmmio target."), (128, 128, 128, 32, 32, 32, "A->D", "Unsupported copy from shared.asram to global of Sunmmio target."), (128, 128, 128, 32, 32, 32, "W->D", @@ -165,81 +170,6 @@ def main( "Unsupported copy from shared.asram to shared.asram of Sunmmio target."), (128, 128, 128, 32, 32, 32, "W<->W", "Unsupported copy from shared.wsram to shared.wsram of Sunmmio target."), -] - - -@pytest.mark.parametrize( - "M, N, K, block_M, block_N, block_K, error_type, error_msg", - WRONG_TEST_CASES, -) -def test_tilelang_mesh_wrong_copy_to_dma_1(M, N, K, block_M, block_N, block_K, error_type, - error_msg): - target_name = "Sunmmio" - target = determine_target(target_name, return_object=True) - with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): - mod = wrong_copy_1(M, N, K, block_M, block_N, block_K, error_type) - mod = tvm.tir.transform.BindTarget(target)(mod) - # Add wrapper for single buf store - mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) - # Normalize negative indices to canonical non-negative form - mod = tilelang.transform.LegalizeNegativeIndex()(mod) - # Inject assumes to speedup tvm prover - mod = tilelang.transform.InjectAssumes()(mod) - # Simplify the IR expressions - mod = tilelang.transform.Simplify()(mod) - # Infer shared memory SRAM scope - mod = tilelang.transform.InferSramScope()(mod) - # Set layouts for reducers - mod = tilelang.transform.LayoutReducer()(mod) - # Infer memory layouts for fragments and shared memory - mod = tilelang.transform.LayoutInference()(mod) - # Lower high-level tile operations to low-level operations - mod = tilelang.transform.LowerTileOp()(mod) - - -def wrong_copy_2(M, - N, - K, - block_M, - block_N, - block_K, - error_type, - dtype="float16", - accum_dtype="float16"): - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Initialize Kernel Context - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") - B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") - C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") - D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # DRAM -> RSRAM - T.copy(C[by * block_M, ko * block_K], C_shared) - # DRAM <- RSRAM - T.copy(C_shared, C[by * block_M, ko * block_K]) - # RSRAM -> ASRAM - T.copy(C_shared, A_shared) - # RSRAM -> WSRAM - T.copy(C_shared, B_shared) - # RSRAM <-> 
RSRAM - T.copy(C_shared, D_shared) - if error_type == 'A->W': - T.copy(A_shared, B_shared) - elif error_type == 'W->A': - T.copy(B_shared, A_shared) - - return tvm.IRModule({'main': main}) - - -WRONG_TEST_CASES = [ (128, 128, 128, 32, 32, 32, "A->W", "Unsupported copy from shared.asram to shared.wsram of Sunmmio target."), (128, 128, 128, 32, 32, 32, "W->A", @@ -251,12 +181,11 @@ def main( "M, N, K, block_M, block_N, block_K, error_type, error_msg", WRONG_TEST_CASES, ) -def test_tilelang_mesh_wrong_copy_to_dma_2(M, N, K, block_M, block_N, block_K, error_type, - error_msg): +def test_tilelang_mesh_wrong_copy_to_dma(M, N, K, block_M, block_N, block_K, error_type, error_msg): target_name = "Sunmmio" target = determine_target(target_name, return_object=True) with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): - mod = wrong_copy_2(M, N, K, block_M, block_N, block_K, error_type) + mod = wrong_copy_1(M, N, K, block_M, block_N, block_K, error_type) mod = tvm.tir.transform.BindTarget(target)(mod) # Add wrapper for single buf store mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) From 408c431e0cd276a67b7d44b710280be32b1a0564 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 6 Feb 2026 13:46:40 +0800 Subject: [PATCH 7/9] fix global layout logic --- src/op/copy.cc | 38 ++++++++----------- ...test_tilelang_mesh_language_copy_to_dma.py | 20 +++++----- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index be32021cd..71a2e8bd1 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1024,17 +1024,14 @@ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(r->extent); } // \param src_input_size & \param src_forward - if (src.scope() == "global") { - } else { - ICHECK(T.layout_map.count(src)) - << "Layout of buffer " << src << " not found."; - auto layout = T.layout_map.at(src); - for (auto s : layout->InputShape()) { - args.push_back(s); - } - for (auto s : layout->GetForwardIndex()) { - args.push_back(s); - } + ICHECK(T.layout_map.count(src)) + << "Layout of buffer " << src << " not found."; + auto layout = T.layout_map.at(src); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); } // \param src_scope auto src_scope = StringImm(src.scope()); @@ -1047,17 +1044,14 @@ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(r->extent); } // \param dst_input_size & \param dst_forward - if (dst.scope() == "global") { - } else { - ICHECK(T.layout_map.count(dst)) - << "Layout of buffer " << dst << " not found."; - auto layout = T.layout_map.at(dst); - for (auto s : layout->InputShape()) { - args.push_back(s); - } - for (auto s : layout->GetForwardIndex()) { - args.push_back(s); - } + ICHECK(T.layout_map.count(dst)) + << "Layout of buffer " << dst << " not found."; + layout = T.layout_map.at(dst); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); } // \param dst_scope auto dst_scope = StringImm(dst.scope()); diff --git a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py index ea93034df..e5c4ff6fa 100644 --- a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py +++ b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py @@ -107,15 +107,15 @@ def test_tilelang_mesh_copy_to_dma(K, 
block_M, block_N, block_K, lower_stmt): assert texts[i] == lower_stmt[i] -def wrong_copy_1(M, - N, - K, - block_M, - block_N, - block_K, - error_type, - dtype="float16", - accum_dtype="float16"): +def wrong_copy(M, + N, + K, + block_M, + block_N, + block_K, + error_type, + dtype="float16", + accum_dtype="float16"): @T.prim_func def main( @@ -185,7 +185,7 @@ def test_tilelang_mesh_wrong_copy_to_dma(M, N, K, block_M, block_N, block_K, err target_name = "Sunmmio" target = determine_target(target_name, return_object=True) with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): - mod = wrong_copy_1(M, N, K, block_M, block_N, block_K, error_type) + mod = wrong_copy(M, N, K, block_M, block_N, block_K, error_type) mod = tvm.tir.transform.BindTarget(target)(mod) # Add wrapper for single buf store mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) From 56940f06813a81d6b001e0f48443852f877c1de8 Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Fri, 6 Feb 2026 19:53:44 +0800 Subject: [PATCH 8/9] fix bug --- src/op/copy.cc | 8 ++++---- .../language/test_tilelang_mesh_language_copy_to_dma.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 71a2e8bd1..a05b67b5a 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1062,8 +1062,8 @@ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(src->data); } else { PrimExpr total_elements = 1; - for (auto e : src_range) { - total_elements *= e->extent; + for (auto e : src->shape) { + total_elements *= e; } auto addr = src.access_ptr(1, DataType::Handle(), 1, 0, total_elements); args.push_back(addr); @@ -1077,8 +1077,8 @@ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(dst->data); } else { PrimExpr total_elements = 1; - for (auto e : dst_range) { - total_elements *= e->extent; + for (auto e : dst->shape) { + total_elements *= e; } auto addr = dst.access_ptr(2, DataType::Handle(), 1, 0, total_elements); args.push_back(addr); diff --git a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py index e5c4ff6fa..8253a71a9 100644 --- a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py +++ b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py @@ -63,10 +63,10 @@ def main(C: MyTensor): # "1", # RSRAM -> ASRAM # T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) - "T.dma_copy(7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 512, 1), 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 512, 2), 24, 8)", + "T.dma_copy(7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 4096, 2), 24, 8)", # RSRAM -> WSRAM # T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) - "T.dma_copy(7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.wsram\", 
T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 192, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 192, 2), 40, 0)", + "T.dma_copy(7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.wsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 4096, 2), 40, 0)", # RSRAM <-> RSRAM # T.copy(C_shared, D_shared) "T.dma_copy(7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, T.tvm_access_ptr(T.type_annotation(\"float32\"), D_shared.data, 0, 4096, 2), 0, 0)", @@ -102,7 +102,7 @@ def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K, lower_stmt): mod = tilelang.transform.LowerTileOp()(mod) texts = mod.script().split('\n') texts = texts[29:] - texts = [it[20:] for it in texts] + texts = [it.lstrip() for it in texts] for i in range(len(texts)): assert texts[i] == lower_stmt[i] From e651b830b69ccc891817507c6b70a63de084b62d Mon Sep 17 00:00:00 2001 From: wanghz18 Date: Tue, 10 Feb 2026 15:40:03 +0800 Subject: [PATCH 9/9] move scope position in dma copy and merge global layout --- src/op/builtin.h | 16 ++--- src/op/copy.cc | 71 +++++++++++++------ ...test_tilelang_mesh_language_copy_to_dma.py | 34 ++++----- 3 files changed, 73 insertions(+), 48 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index caf7d2224..de9d9ab0e 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -214,6 +214,14 @@ TVM_DLL const Op &tma_load(); * src_region_shape = [32, 64, 256] * coord = [32, 128, 0] * + * \param src_scope + * Memory scope of the source tensor. + * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". + * + * \param dst_scope + * Memory scope of the destination tensor. + * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". + * * \param data_type * Element data type of the tensor (e.g. float32, float16). * @@ -234,10 +242,6 @@ TVM_DLL const Op &tma_load(); * For a row-major layout of Tensor(128, 256, 512), * this is [256 * 512, 512, 1]. * - * \param src_scope - * Memory scope of the source tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * * \param dst_rank * Rank (number of dimensions) of the destination tensor. * @@ -252,10 +256,6 @@ TVM_DLL const Op &tma_load(); * Forward index mapping of the destination layout, retrievable via * Layout::GetForwardIndex(). * - * \param dst_scope - * Memory scope of the destination tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * * \param src_addr * Base address of the source tensor in memory . 
* diff --git a/src/op/copy.cc b/src/op/copy.cc index 4c9064156..66104bd87 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1015,8 +1015,15 @@ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, << "Invalid copy inst " << static_cast(copy_inst); Array args; + // \param src_scope + auto src_scope = StringImm(src.scope()); + args.push_back(src_scope); + // \param dst_scope + auto dst_scope = StringImm(dst.scope()); + args.push_back(dst_scope); // \param data_type args.push_back(to_CUtensorMapDataType(src->dtype)); + // \param src_rank args.push_back(static_cast(src->shape.size())); // \param src_region_shape @@ -1024,38 +1031,56 @@ Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(r->extent); } // \param src_input_size & \param src_forward - ICHECK(T.layout_map.count(src)) - << "Layout of buffer " << src << " not found."; - auto layout = T.layout_map.at(src); - for (auto s : layout->InputShape()) { - args.push_back(s); - } - for (auto s : layout->GetForwardIndex()) { - args.push_back(s); + if (src.scope() == "global") { + ICHECK(T.global_layout_map.count(src)) + << "Layout of buffer " << src << " not found."; + auto layout = T.global_layout_map.at(src); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); + } + } else { + ICHECK(T.layout_map.count(src)) + << "Layout of buffer " << src << " not found."; + auto layout = T.layout_map.at(src); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); + } } - // \param src_scope - auto src_scope = StringImm(src.scope()); - args.push_back(src_scope); // \param dst_rank args.push_back(static_cast(dst->shape.size())); - // \param dst_shape + // \param dst_region_shape for (auto r : dst_range) { args.push_back(r->extent); } // \param dst_input_size & \param dst_forward - ICHECK(T.layout_map.count(dst)) - << "Layout of buffer " << dst << " not found."; - layout = T.layout_map.at(dst); - for (auto s : layout->InputShape()) { - args.push_back(s); - } - for (auto s : layout->GetForwardIndex()) { - args.push_back(s); + if (dst.scope() == "global") { + ICHECK(T.global_layout_map.count(dst)) + << "Layout of buffer " << dst << " not found."; + auto layout = T.global_layout_map.at(dst); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); + } + } else { + ICHECK(T.layout_map.count(dst)) + << "Layout of buffer " << dst << " not found."; + auto layout = T.layout_map.at(dst); + for (auto s : layout->InputShape()) { + args.push_back(s); + } + for (auto s : layout->GetForwardIndex()) { + args.push_back(s); + } } - // \param dst_scope - auto dst_scope = StringImm(dst.scope()); - args.push_back(dst_scope); // \param src_addr if (src.scope() == "global") { diff --git a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py index 8253a71a9..a4e8eb290 100644 --- a/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py +++ b/testing/python/language/test_tilelang_mesh_language_copy_to_dma.py @@ -24,15 +24,15 @@ def main(C: MyTensor): D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # # DRAM -> RSRAM - # T.copy(C[by * block_M, ko * block_K], C_shared) - # # DRAM -> WSRAM - # T.copy(C[by * block_M, 
ko * block_K], B_shared) - # # DRAM <- RSRAM - # T.copy(C_shared, C[by * block_M, ko * block_K]) - # # DRAM -> ASRAM - # T.copy(C[by * block_M, ko * block_K], A_shared) - # # RSRAM -> ASRAM + # DRAM -> RSRAM + T.copy(C[by * block_M, ko * block_K], C_shared) + # DRAM -> WSRAM + T.copy(C[by * block_M, ko * block_K], B_shared) + # DRAM <- RSRAM + T.copy(C_shared, C[by * block_M, ko * block_K]) + # DRAM -> ASRAM + T.copy(C[by * block_M, ko * block_K], A_shared) + # RSRAM -> ASRAM T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) # RSRAM -> WSRAM T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) @@ -51,25 +51,25 @@ def main(C: MyTensor): [ # DRAM -> RSRAM # T.copy(C[by * block_M, ko * block_K], C_shared) - # "1", + "T.dma_copy(\"global\", \"shared.rsram\", 7, 2, 64, 64, 32, 128, _j * 32 + _i, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, C.data, by * 64, ko * 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 2), 0, 0)", # DRAM -> WSRAM # T.copy(C[by * block_M, ko * block_K], B_shared) - # "1", + "T.dma_copy(\"global\", \"shared.wsram\", 7, 2, 64, 64, 32, 128, _j * 32 + _i, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, C.data, by * 64, ko * 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 4096, 2), 0, 0)", # DRAM <- RSRAM # T.copy(C_shared, C[by * block_M, ko * block_K]) - # "1", + "T.dma_copy(\"shared.rsram\", \"global\", 7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 64, 64, 32, 128, _j * 32 + _i, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, C.data, by * 64, ko * 32)", # DRAM -> ASRAM # T.copy(C[by * block_M, ko * block_K], A_shared) - # "1", + "T.dma_copy(\"global\", \"shared.asram\", 7, 2, 64, 64, 32, 128, _j * 32 + _i, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, C.data, by * 64, ko * 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 4096, 2), 0, 0)", # RSRAM -> ASRAM # T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) - "T.dma_copy(7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.asram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 4096, 2), 24, 8)", + "T.dma_copy(\"shared.rsram\", \"shared.asram\", 7, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 16, 32, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 16, T.tvm_access_ptr(T.type_annotation(\"float32\"), A_shared.data, 0, 4096, 2), 24, 8)", # RSRAM -> WSRAM # T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) - "T.dma_copy(7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.wsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 4096, 2), 40, 0)", + "T.dma_copy(\"shared.rsram\", \"shared.wsram\", 7, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 24, 8, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 
T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 8, 48, T.tvm_access_ptr(T.type_annotation(\"float32\"), B_shared.data, 0, 4096, 2), 40, 0)", # RSRAM <-> RSRAM # T.copy(C_shared, D_shared) - "T.dma_copy(7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, \"shared.rsram\", T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, T.tvm_access_ptr(T.type_annotation(\"float32\"), D_shared.data, 0, 4096, 2), 0, 0)", + "T.dma_copy(\"shared.rsram\", \"shared.rsram\", 7, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, 2, 64, 64, 64, 64, _i // 32 * 2048 + _j // 32 * 1024 + _i % 32 * 32 + _j % 32, T.tvm_access_ptr(T.type_annotation(\"float32\"), C_shared.data, 0, 4096, 1), 0, 0, T.tvm_access_ptr(T.type_annotation(\"float32\"), D_shared.data, 0, 4096, 2), 0, 0)", ]), ] @@ -101,7 +101,7 @@ def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K, lower_stmt): # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) texts = mod.script().split('\n') - texts = texts[29:] + texts = texts[29:-2] texts = [it.lstrip() for it in texts] for i in range(len(texts)): assert texts[i] == lower_stmt[i]
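After PATCH 9/9 the scope strings lead the T.dma_copy call, and each side contributes its rank, region shape, layout input shape, and layout forward index, followed by the source address and coordinates and then the destination address and coordinates; global buffers read their layout from the lowering context's global_layout_map and pass the raw data var, while shared buffers pass a tvm_access_ptr. The Python sketch below reconstructs that argument order to make the expected strings in TEST_CASES easier to read; DMASide and build_dma_copy_args are hypothetical helpers for this note, not tilelang APIs.

from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class DMASide:
    scope: str                         # e.g. "global", "shared.rsram"
    region_shape: Sequence[Any]        # extents of the copied region (src/dst_range)
    layout_input_shape: Sequence[Any]  # Layout::InputShape() of the buffer layout
    layout_forward: Sequence[Any]      # Layout::GetForwardIndex() expressions
    addr: Any                          # data var (global) or tvm_access_ptr (shared)
    coords: Sequence[Any]              # per-dimension region minimums


def build_dma_copy_args(dtype_code: Any, src: DMASide, dst: DMASide) -> list:
    """Assemble T.dma_copy arguments in the order LowerDMACopy uses after this
    series: scopes, dtype, per-side shape/layout groups, then src address and
    coordinates followed by dst address and coordinates."""
    args = [src.scope, dst.scope, dtype_code]
    for side in (src, dst):
        args.append(len(side.region_shape))   # rank
        args.extend(side.region_shape)        # region shape (extents)
        args.extend(side.layout_input_shape)  # layout input shape
        args.extend(side.layout_forward)      # layout forward index
    args.extend([src.addr, *src.coords, dst.addr, *dst.coords])
    return args

Read against the first DRAM -> RSRAM case in TEST_CASES: "global" and "shared.rsram" are the two scopes, 7 is the dtype code, 32, 128 and _j * 32 + _i come from the global layout of C, and C.data with by * 64, ko * 32 is the source address plus region minimums.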