diff --git a/src/op/builtin.cc b/src/op/builtin.cc index ed5f3067a..058c395a8 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -106,10 +106,7 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr( +TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) diff --git a/src/op/builtin.h b/src/op/builtin.h index 6709f7511..3779a8562 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -189,95 +189,18 @@ TVM_DLL const Op &get_mbarrier(); TVM_DLL const Op &tma_load(); /*! - * \brief Perform a DMA load operation from source memory to destination memory. - * - * This function describes a DMA-based tensor copy with explicit shape, layout, - * memory scope. It is typically used to lower a high-level - * tensor copy into a hardware-specific DMA instruction. - * - * The source and destination tensors are described in terms of: - * - data type - * - rank and logical shape - * - layout (input shape + forward index), The T.Layout type is ObjectRef, - * which is not suitable for backend parsing, so it's two members are extracted: - * input shape and forward index, which are both Array - * - memory scope - * - * A sub-region of the source tensor can be copied by specifying the coordinate - * offset (`coord`) relative to the source base address. - * - * Example: - * For a 3D tensor A: Tensor(128, 256, 512), copying - * A[32:64, 128:192, 0:256] - * then: - * src_rank = 3 - * src_shape = [128, 256, 512] - * coord = [32, 128, 0] - * - * \param data_type - * Element data type of the tensor (e.g. float32, float16). - * - * \param src_rank - * Rank (number of dimensions) of the source tensor. - * - * \param src_shape - * Logical shape of the source tensor. - * For example, Tensor(128, 256, 512) -> [128, 256, 512]. - * - * \param src_input_size - * Input shape of the source layout, retrievable via Layout::getInputShape(). - * For a row-major 3D tensor, this is identical to src_shape. - * - * \param src_forward - * Forward index mapping of the source layout, retrievable via - * Layout::GetForwardIndex(). - * For a row-major layout of Tensor(128, 256, 512), - * this is [256 * 512, 512, 1]. - * - * \param src_scope - * Memory scope of the source tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * - * \param dst_rank - * Rank (number of dimensions) of the destination tensor. - * - * \param dst_shape - * Logical shape of the destination tensor. - * - * \param dst_input_size - * Input shape of the destination layout, retrievable via - * Layout::getInputShape(). - * - * \param dst_forward - * Forward index mapping of the destination layout, retrievable via - * Layout::GetForwardIndex(). - * - * \param dst_scope - * Memory scope of the destination tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * - * \param src_addr - * Base address of the source tensor in memory . - * - * \param coord - * Coordinate offset specifying the starting point of the copy in the source - * tensor. Its length must equal src_rank. - * - * \param dst_addr - * Base address of the destination tensor in memory . - * - * \note - * Out-of-bound fill policies are currently not supported. - */ -TVM_DLL const Op &dma_load(); - -/*! 
- * \brief Perform a DMA store operation from source memory to destination
- * memory. see dma_load for details.
+ * \brief Perform a DMA copy operation preserving full buffer region semantics.
  *
+ * This intrinsic encodes a high-level copy between two buffer regions as
+ * tl.dma_copy(src_region, dst_region), where each argument is a
+ * tl.tileop.region Call carrying the buffer, access mask, and per-axis
+ * extents. It is emitted by the SUNMMIO lowering path of CopyNode and
+ * consumed by later target-specific codegen passes.
  *
+ * \param src_region A tl.tileop.region PrimExpr describing the source.
+ * \param dst_region A tl.tileop.region PrimExpr describing the destination.
  */
-TVM_DLL const Op &dma_store();
+TVM_DLL const Op &dma_copy();
 
 /*!
  * \brief tvm intrinsics for loading image from global tensor to columns in
diff --git a/src/op/copy.cc b/src/op/copy.cc
index 22cb4a074..f09b8a05f 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -564,6 +564,30 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
     }
     return {};
   }
+
+  // Sunmmio DMA Layout Inference
+  if (copy_inst == CopyInst::kSunmmioDMACopy) {
+    // for dma copy, we can directly apply the blockwise_zz_layout
+    const auto f =
+        ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
+    auto result = Map<Buffer, Layout>();
+
+    if (level == InferLevel::kFree && !T.layout_map.count(src)) {
+      if (src.scope() != "global") {
+        auto layout = Downcast<Layout>((*f)(src));
+        result.Set(src, layout);
+      }
+    }
+
+    if (level == InferLevel::kFree && !T.layout_map.count(dst)) {
+      if (dst.scope() != "global") {
+        auto layout = Downcast<Layout>((*f)(dst));
+        result.Set(dst, layout);
+      }
+    }
+    return result;
+  }
+
   // for LDSM/STSM, the layout was deduced from register layout
   // so we can directly apply the layout of normal copy
   // Use parallel op to infer the layout
@@ -804,6 +828,65 @@ bool CopyNode::CheckTMemStore(Target target) const {
          dst.scope() == "shared.tmem";
 }
 
+/**
+ * @brief Determine whether this CopyNode can be lowered to a DMA copy
+ * intrinsic for the Sunmmio target.
+ *
+ * The function returns true when all of the following hold:
+ * - the target architecture advertises DMA support;
+ * - the source and destination memory scopes form one of the supported
+ *   Sunmmio DMA paths (see the scope checks below);
+ * - the source and destination have the same element data type.
+ *
+ * If the source and destination dtypes differ, a warning is logged and the
+ * function returns false (the caller is expected to fall back to a normal
+ * copy).
+ *
+ * @param target The compilation target to query for DMA copy support.
+ * @return true if the copy can be implemented as a DMA copy; false
+ * otherwise.
+ */
+bool CopyNode::CheckSunmmioDMACopy(Target target) const {
+  // 1. arch must support Sunmmio
+  if (!TargetIsSunmmio(target))
+    return false;
+
+  // 2. src and dst scopes must form a supported DMA path
+  bool scope_check = false;
+  // 2.1 DRAM -> RSRAM
+  if (src.scope() == "global" && dst.scope() == "shared.rsram")
+    scope_check = true;
+  // 2.2 DRAM -> WSRAM
+  if (src.scope() == "global" && dst.scope() == "shared.wsram")
+    scope_check = true;
+  // 2.3 DRAM -> ASRAM
+  if (src.scope() == "global" && dst.scope() == "shared.asram")
+    scope_check = true;
+  // 2.4 RSRAM -> WSRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "shared.wsram")
+    scope_check = true;
+  // 2.5 RSRAM -> ASRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "shared.asram")
+    scope_check = true;
+  // 2.6 RSRAM <-> RSRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "shared.rsram")
+    scope_check = true;
+  // 2.7 RSRAM -> DRAM
+  if (src.scope() == "shared.rsram" && dst.scope() == "global")
+    scope_check = true;
+  if (!scope_check)
+    return false;
+
+  // 3. src and dst must have the same dtype
+  if (src->dtype != dst->dtype) {
+    LOG(WARNING) << "src and dst must have the same dtype for dma copy "
+                 << src->name << " vs. " << dst->name << " dtype " << src->dtype
+                 << " vs. " << dst->dtype << "; falling back to normal copy";
+    return false;
+  }
+  return true;
+}
+
 /**
  * @brief Selects the most specific copy instruction supported for the given
  * target and buffers.
@@ -848,6 +931,12 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
     return CopyInst::kTMemLoad;
   } else if (CheckTMemStore(target)) {
     return CopyInst::kTMemStore;
+  } else if (TargetIsSunmmio(target)) {
+    auto is_copy = CheckSunmmioDMACopy(target);
+    if (is_copy)
+      return CopyInst::kSunmmioDMACopy;
+    ICHECK(0) << "Unsupported copy from " << src.scope() << " to "
+              << dst.scope() << " of Sunmmio target.";
   } else {
     return CopyInst::kNormal;
   }
@@ -860,6 +949,7 @@
  * determined copy instruction type:
  * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions
  * - LDSM/STSM: Uses matrix load/store instructions for tensor cores
+ * - DMA copy: Uses Sunmmio-specific DMA instructions
  * - Normal: Uses standard load/store operations with loop transformations
  * \param T LowerArgs containing target and layout map.
  * \param analyzer Arithmetic analyzer for simplification.
@@ -894,11 +984,37 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return ldsm_copy;
   } else if (copy_inst == CopyInst::kNormal) {
     return LowerNormalCopy(T, analyzer);
+  } else if (copy_inst == CopyInst::kSunmmioDMACopy) {
+    auto dma_copy = LowerSunmmioDmaCopy(T, analyzer);
+    ICHECK(dma_copy.defined()) << "Failed to lower dma copy";
+    return dma_copy;
   } else {
     LOG(FATAL) << "Unsupported copy inst " << static_cast<int>(copy_inst);
   }
 }
 
+/**
+ * @brief Lower the copy operator for the SUNMMIO target.
+ *
+ * Emits a `tl.dma_copy(src_region, dst_region)` intrinsic call that preserves
+ * full buffer region semantics (buffer identity, per-axis min/extent, and
+ * memory scope). This intrinsic is consumed by later SUNMMIO-specific codegen
+ * passes to generate actual DMA instructions.
+ *
+ * @param T Lowering context (target, layout map, etc.).
+ * @param analyzer Arithmetic analyzer (unused here but kept for interface
+ * consistency).
+ * @return Stmt An Evaluate wrapping the tl.dma_copy Call.
+ */
+Stmt CopyNode::LowerSunmmioDmaCopy(const LowerArgs &T,
+                                   arith::Analyzer *analyzer) const {
+  // access_mask: 1=read for src, 2=write for dst
+  PrimExpr src_region = MakeRegionExpr(src, src_range, /*access_mask=*/1);
+  PrimExpr dst_region = MakeRegionExpr(dst, dst_range, /*access_mask=*/2);
+  return Evaluate(
+      Call(DataType::Handle(), dma_copy(), {src_region, dst_region}));
+}
+
 /**
  * @brief Lower the copy operator using the generic (non-specialized) path.
  *
diff --git a/src/op/copy.h b/src/op/copy.h
index fd3e01a40..7566a3308 100644
--- a/src/op/copy.h
+++ b/src/op/copy.h
@@ -22,10 +22,11 @@ enum class CopyInst : uint8_t {
   kBulkStore = 4, // utilize tma store
   // we should separate the bulk load and store for 1d and multi-dim
   // as they have different memory access patterns
-  kBulkLoad1D = 5,  // utilize tma load 1d
-  kBulkStore1D = 6, // utilize tma store 1d
-  kTMemLoad = 7,    // tcgen05.ld (tensor memory -> register)
-  kTMemStore = 8,   // tcgen05.st (register -> tensor memory)
+  kBulkLoad1D = 5,      // utilize tma load 1d
+  kBulkStore1D = 6,     // utilize tma store 1d
+  kTMemLoad = 7,        // tcgen05.ld (tensor memory -> register)
+  kTMemStore = 8,       // tcgen05.st (register -> tensor memory)
+  kSunmmioDMACopy = 9,  // Sunmmio DMA
 };
 
 /// Descriptor for Tensor Memory Access (TMA) copy operations
@@ -180,6 +181,11 @@ class CopyNode : public TileOperatorNode {
    */
   bool CheckTMemStore(Target target) const;
 
+  /*!
+   * \brief Check if Sunmmio dma copy is supported.
+   */
+  bool CheckSunmmioDMACopy(Target target) const;
+
   /*!
    * \brief Get the copy instruction type.
    */
@@ -217,6 +223,14 @@ class CopyNode : public TileOperatorNode {
    */
   Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
 
+  /*!
+   * \brief Generate lowering for SUNMMIO DMA copy.
+   *
+   * Emits a tl.dma_copy(src_region, dst_region) intrinsic that preserves full
+   * buffer region semantics for later SUNMMIO codegen consumption.
+   */
+  Stmt LowerSunmmioDmaCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
+
   /*!
    * \brief Generate SIMT (thread-level) loop for copying.
    */
diff --git a/src/op/utils.cc b/src/op/utils.cc
index 7e56ae8c7..17401b4d2 100644
--- a/src/op/utils.cc
+++ b/src/op/utils.cc
@@ -52,6 +52,47 @@ BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) {
   throw; // Unreachable
 }
 
+/*!
+ * \brief Encode a Buffer + Array<Range> into a tl.tileop.region Call
+ * expression.
+ *
+ * This is the inverse of NormalizeToBufferRegion: it packs buffer region
+ * metadata into a PrimExpr so it can travel through Call arguments (where
+ * BufferRegion cannot appear directly).
+ *
+ * Use this when emitting intrinsic calls (e.g. tl.dma_copy) that need to
+ * carry full region semantics — buffer identity, per-axis min/extent, and
+ * access mode — as opaque PrimExpr arguments for later codegen consumption.
+ *
+ * Encoding layout:
+ *   args[0]   = BufferLoad(buffer, {range[0].min, range[1].min, ...})
+ *   args[1]   = access_mask (1=read, 2=write, 3=read-write)
+ *   args[2+i] = range[i].extent
+ *
+ * \param buffer The buffer this region refers to.
+ * \param ranges Per-axis [min, extent) ranges describing the tile.
+ * \param access_mask 1=read, 2=write, 3=read-write.
+ * \return A Call(tl.tileop.region, ...) expression.
+ */
+PrimExpr MakeRegionExpr(const Buffer &buffer, const Array<Range> &ranges,
+                        int access_mask) {
+  // Build BufferLoad with indices = per-axis minima
+  Array<PrimExpr> indices;
+  for (const auto &r : ranges) {
+    indices.push_back(r->min);
+  }
+  BufferLoad load(buffer, indices);
+
+  // Pack args: [load, access_mask, extent_0, extent_1, ...]
+  Array<PrimExpr> args;
+  args.push_back(load);
+  args.push_back(IntImm(DataType::Int(32), access_mask));
+  for (const auto &r : ranges) {
+    args.push_back(r->extent);
+  }
+  return Call(DataType::Handle(), RegionOp::Get(), args);
+}
+
 PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
                                  bool require_2d) {
   Buffer buf = region->buffer;
diff --git a/src/op/utils.h b/src/op/utils.h
index d386b1a58..1c297e334 100644
--- a/src/op/utils.h
+++ b/src/op/utils.h
@@ -21,6 +21,13 @@ using namespace tir;
 // Note: tvm_access_ptr is no longer supported here.
 TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);
 
+// Build a tl.tileop.region Call from a Buffer + Array<Range>.
+// This is the inverse of NormalizeToBufferRegion: it packages buffer, access
+// mask, and per-axis extents into a Call(RegionOp::Get(), ...) that can be
+// passed as an argument to builtins like dma_copy.
+TVM_DLL PrimExpr MakeRegionExpr(const Buffer &buffer,
+                                const Array<Range> &ranges, int access_mask);
+
 // Build a tvm_access_ptr(handle) from a BufferRegion.
 // - If `require_2d` is true, checks buffer ndim >= 2.
 // - For 1D regions (when allowed), offset=min, extent=extent.
diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc
index 4089c600a..c59b72763 100644
--- a/src/transform/lower_tile_op.cc
+++ b/src/transform/lower_tile_op.cc
@@ -18,6 +18,7 @@
 #include "../op/gemm.h"
 #include "../op/gemm_sp.h"
 #include "../op/operator.h"
+#include "../target/utils.h"
 #include "common/remap_buffer_rewriter.h"
 
 #include "arith/ir_mutator_with_analyzer.h"
@@ -157,8 +158,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
             .as<Map<Buffer, Layout>>()
             .value();
     for (auto [buffer, layout] : layout_map) {
-      buffer_remap_.Set(buffer,
-                        makeBufferWithLayout(buffer, layout, var_remap_));
+      if (!TargetIsSunmmio(target_)) {
+        buffer_remap_.Set(buffer,
+                          makeBufferWithLayout(buffer, layout, var_remap_));
+      }
       layout_map_.Set(buffer, layout);
     }
   }
diff --git a/testing/python/layout/test_tilelang_gemm_sunmmio_layout.py b/testing/python/layout/test_tilelang_sunmmio_gemm_layout.py
similarity index 97%
rename from testing/python/layout/test_tilelang_gemm_sunmmio_layout.py
rename to testing/python/layout/test_tilelang_sunmmio_gemm_layout.py
index a7feb5a57..75ed349f6 100644
--- a/testing/python/layout/test_tilelang_gemm_sunmmio_layout.py
+++ b/testing/python/layout/test_tilelang_sunmmio_gemm_layout.py
@@ -24,7 +24,7 @@ def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_d
     def main(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+            C: T.Tensor((M, N), accum_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -106,6 +106,7 @@ def test_tilelang_gemm_sunmmio_layout(M, N, K, block_M, block_N, block_K, versio
     with tvm.target.Target(target):
         mod = matmul(M, N, K, block_M, block_N, block_K, version)
         mod = tvm.tir.transform.BindTarget(target)(mod)
+        mod = tl.transform.InferSramScope()(mod)
         mod = tl.transform.LayoutInference()(mod)
         LayoutVisual()(mod)
 
diff --git a/testing/python/ops/test_tilelang_ops_sunmmio_dma_copy.py
b/testing/python/ops/test_tilelang_ops_sunmmio_dma_copy.py new file mode 100644 index 000000000..7ee559f70 --- /dev/null +++ b/testing/python/ops/test_tilelang_ops_sunmmio_dma_copy.py @@ -0,0 +1,371 @@ +"""Test that SUNMMIO copy lowering emits tl.dma_copy with tl.tileop.region args, +and that each region can be normalized back to a BufferRegion with full metadata.""" + +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tilelang.utils.target import SUNMMIO_TARGET_DESC, determine_target +from tilelang.language.v2.annot import MeshShardingPolicy +from tvm import tir +from tvm.tir import PyStmtExprVisitor +import pytest + +tilelang.env.disable_cache() + + +def simple_copy_kernel(M, N, block_M, block_N, dtype="float16"): + """A minimal kernel with T.copy from global to shared memory.""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + T.copy(A[by * block_M, bx * block_N], A_shared) + + return tvm.IRModule({"main": main}) + + +def apply_sunmmio_passes(mod, target): + """Apply the full SUNMMIO pass pipeline used for DMA copy lowering.""" + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + mod = tilelang.transform.InjectAssumes()(mod) + mod = tilelang.transform.Simplify()(mod) + mod = tilelang.transform.InferSramScope()(mod) + mod = tilelang.transform.LayoutReducer()(mod) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) + return mod + + +class RegionRange: + """Simple container for (min, extent) of a region axis.""" + + def __init__(self, min_val, extent): + self.min = min_val + self.extent = extent + + +def normalize_region(region_call): + """Decode a tl.tileop.region Call back into (buffer, extents, access_mask). + + This mirrors what NormalizeToBufferRegion does in C++: + args[0] = BufferLoad (indices are per-axis minima) + args[1] = access_mask (int) + args[2+i] = extent for axis i + + On SUNMMIO, buffer remap is disabled so buffers retain their original + N-D shape and indices always match the number of extents. + + Returns (buffer, extents, access_mask) where extents is a list of + RegionRange objects with .min and .extent attributes. 
+ """ + assert isinstance(region_call, tir.Call) + assert region_call.op.name == "tl.tileop.region" + + load = region_call.args[0] + assert isinstance(load, tir.BufferLoad) + + access_mask = int(region_call.args[1]) + num_extents = len(region_call.args) - 2 + + assert len( + load.indices) == num_extents, (f"Expected {num_extents} indices, got {len(load.indices)}") + + ranges = [] + for i in range(num_extents): + ranges.append(RegionRange(load.indices[i], region_call.args[2 + i])) + + return load.buffer, ranges, access_mask + + +@tir.functor.visitor +class _DmaCopyVisitor(PyStmtExprVisitor): + """Walk TIR and collect tl.dma_copy calls and their region arguments.""" + + def __init__(self): + super().__init__() + self.dma_copy_calls = [] + self.layout_map = {} + + def visit_block_(self, op: tir.Block) -> None: + if "layout_map" in op.annotations: + for key, layout in op.annotations["layout_map"].items(): + self.layout_map[key.name] = layout + self.visit_stmt(op.body) + + def visit_call_(self, op: tir.Call) -> None: + if hasattr(op, "op") and hasattr(op.op, "name") and op.op.name == "tl.dma_copy": + self.dma_copy_calls.append(op) + # Visit children + for arg in op.args: + self.visit_expr(arg) + + def visit_evaluate_(self, op: tir.Evaluate) -> None: + self.visit_expr(op.value) + + +def extract_dma_copy_lines(mod): + """Extract T.dma_copy lines from TIR script, robust to formatting changes.""" + return [line.lstrip() for line in mod.script().split('\n') if 'T.dma_copy' in line] + + +SIMPLE_COPY_CASES = [ + # (M, N, block_M, block_N) + (128, 128, 32, 32), + (256, 256, 64, 64), + (128, 256, 32, 64), +] + + +@pytest.mark.parametrize("M, N, block_M, block_N", SIMPLE_COPY_CASES) +def test_tilelang_dma_copy(M, N, block_M, block_N): + target = tvm.target.Target(SUNMMIO_TARGET_DESC) + mod = simple_copy_kernel(M, N, block_M, block_N) + + with tvm.target.Target(target): + mod = apply_sunmmio_passes(mod, target) + + # Walk the lowered IR and find dma_copy calls + visitor = _DmaCopyVisitor() + func = mod["main"] + visitor.visit_stmt(func.body) + + # Verify that exactly one tl.dma_copy call was emitted + assert len(visitor.dma_copy_calls) == 1, ( + f"Expected exactly 1 tl.dma_copy call, got {len(visitor.dma_copy_calls)}") + + call = visitor.dma_copy_calls[0] + + # dma_copy should have exactly 2 arguments (src_region, dst_region) + assert len(call.args) == 2, (f"Expected 2 args for dma_copy, got {len(call.args)}") + + # Each argument should be a tl.tileop.region Call + for i, arg in enumerate(call.args): + assert isinstance(arg, tir.Call), (f"dma_copy arg[{i}] should be a Call, got {type(arg)}") + assert hasattr(arg.op, "name") and arg.op.name == "tl.tileop.region", ( + f"dma_copy arg[{i}] should be tl.tileop.region, got {arg.op.name}") + + # --- Normalize regions back to buffer metadata --- + src_buf, src_ranges, src_mask = normalize_region(call.args[0]) + dst_buf, dst_ranges, dst_mask = normalize_region(call.args[1]) + + # access_mask: 1=read for src, 2=write for dst + assert src_mask == 1, f"Source access_mask should be 1 (read), got {src_mask}" + assert dst_mask == 2, f"Destination access_mask should be 2 (write), got {dst_mask}" + + # Buffer dtype + assert src_buf.dtype == "float16", f"Source dtype should be float16, got {src_buf.dtype}" + assert dst_buf.dtype == "float16", f"Dest dtype should be float16, got {dst_buf.dtype}" + + # Buffer shapes: both stay 2D (no flattening on SUNMMIO) + assert len(src_buf.shape) == 2, f"Source buffer should be 2D, got {len(src_buf.shape)}D" + assert 
int(src_buf.shape[0]) == M, f"Source shape[0] should be {M}, got {src_buf.shape[0]}" + assert int(src_buf.shape[1]) == N, f"Source shape[1] should be {N}, got {src_buf.shape[1]}" + + assert len(dst_buf.shape) == 2, f"Dest buffer should be 2D, got {len(dst_buf.shape)}D" + assert int( + dst_buf.shape[0]) == block_M, (f"Dest shape[0] should be {block_M}, got {dst_buf.shape[0]}") + assert int( + dst_buf.shape[1]) == block_N, (f"Dest shape[1] should be {block_N}, got {dst_buf.shape[1]}") + + # Buffer scope: InferSramScope assigns shared.rsram for a plain alloc_shared + src_scope = src_buf.scope() + assert src_scope == "" or src_scope == "global", ( + f"Source buffer should be in global scope, got '{src_scope}'") + dst_scope = dst_buf.scope() + assert dst_scope == "shared.rsram", ( + f"Destination buffer should be shared.rsram after InferSramScope, got '{dst_scope}'") + + # Region extents match block dimensions + assert len(src_ranges) == 2 + src_extent_0 = int(src_ranges[0].extent) + src_extent_1 = int(src_ranges[1].extent) + assert src_extent_0 == block_M, (f"Source extent[0] should be {block_M}, got {src_extent_0}") + assert src_extent_1 == block_N, (f"Source extent[1] should be {block_N}, got {src_extent_1}") + + assert len(dst_ranges) == 2 + dst_extent_0 = int(dst_ranges[0].extent) + dst_extent_1 = int(dst_ranges[1].extent) + assert dst_extent_0 == block_M, (f"Dest extent[0] should be {block_M}, got {dst_extent_0}") + assert dst_extent_1 == block_N, (f"Dest extent[1] should be {block_N}, got {dst_extent_1}") + + +def wrong_copy(M, + N, + K, + block_M, + block_N, + block_K, + error_type, + dtype="float16", + accum_dtype="float16"): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + A_shared_2 = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + B_shared_2 = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + if error_type == 'A->D': + T.copy(A_shared, C[by * block_M, ko * block_K]) + elif error_type == 'W->D': + T.copy(B_shared, C[by * block_M, ko * block_K]) + elif error_type == 'A->R': + T.copy(A_shared, C_shared) + elif error_type == 'W->R': + T.copy(B_shared, C_shared) + elif error_type == 'D<->D': + T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) + elif error_type == 'A<->A': + T.copy(A_shared, A_shared_2) + elif error_type == 'W<->W': + T.copy(B_shared, B_shared_2) + elif error_type == 'A->W': + T.copy(A_shared, B_shared) + elif error_type == 'W->A': + T.copy(B_shared, A_shared) + + return tvm.IRModule({'main': main}) + + +WRONG_TEST_CASES = [ + (128, 128, 128, 32, 32, 32, "A->D", + "Unsupported copy from shared.asram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->D", + "Unsupported copy from shared.wsram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->R", + "Unsupported copy from shared.asram to shared.rsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->R", + "Unsupported copy from shared.wsram to shared.rsram of Sunmmio target."), + # (128, 128, 128, 32, 32, 32, "D<->D", + # "Unsupported copy from 
global to global of Sunmmio target."), + # D<->D not work now + (128, 128, 128, 32, 32, 32, "A<->A", + "Unsupported copy from shared.asram to shared.asram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W<->W", + "Unsupported copy from shared.wsram to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->W", + "Unsupported copy from shared.asram to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->A", + "Unsupported copy from shared.wsram to shared.asram of Sunmmio target."), +] + + +@pytest.mark.parametrize( + "M, N, K, block_M, block_N, block_K, error_type, error_msg", + WRONG_TEST_CASES, +) +def test_tilelang_mesh_wrong_copy_to_dma(M, N, K, block_M, block_N, block_K, error_type, error_msg): + target = tvm.target.Target(SUNMMIO_TARGET_DESC) + with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): + mod = wrong_copy(M, N, K, block_M, block_N, block_K, error_type) + mod = apply_sunmmio_passes(mod, target) + + +def copy(K, block_M, block_N, block_K, dtype="float32", accum_dtype="float32"): + MyTensor = T.MeshTensor((128, 128), + sharding_policy=MeshShardingPolicy(cross_mesh_dim=0), + device_mesh_config=(2, 2), + hierarchical_dims=(4, 32, 128), + hierarchical_groups=((0, 2), (2, 3)), + hierarchical_strides=(32, 1, 4096)) + + @T.prim_func + def main(C: MyTensor): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(128, block_N), T.ceildiv(128, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # DRAM -> RSRAM + T.copy(C[by * block_M, ko * block_K], C_shared) + # DRAM -> WSRAM + T.copy(C[by * block_M, ko * block_K], B_shared) + # DRAM <- RSRAM + T.copy(C_shared, C[by * block_M, ko * block_K]) + # DRAM -> ASRAM + T.copy(C[by * block_M, ko * block_K], A_shared) + # RSRAM -> ASRAM + T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) + # RSRAM -> WSRAM + T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) + # RSRAM <-> RSRAM + T.copy(C_shared, D_shared) + + return tvm.IRModule({'main': main}) + + +# fmt: off +MESH_COPY_CASES = [ + ( + 128, 64, 64, 32, + [ + # DRAM -> RSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 32], 1, 64, 64), T.region(C_shared[0, 0], 2, 64, 64))', + # DRAM -> WSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 32], 1, 64, 64), T.region(B_shared[0, 0], 2, 64, 64))', + # DRAM <- RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(C[by * 64, ko * 32], 2, 64, 64))', + # DRAM -> ASRAM + 'T.dma_copy(T.region(C[by * 64, ko * 32], 1, 64, 64), T.region(A_shared[0, 0], 2, 64, 64))', + # RSRAM -> ASRAM + 'T.dma_copy(T.region(C_shared[8, 16], 1, 16, 32), T.region(A_shared[24, 8], 2, 16, 32))', + # RSRAM -> WSRAM + 'T.dma_copy(T.region(C_shared[8, 48], 1, 24, 8), T.region(B_shared[40, 0], 2, 24, 8))', + # RSRAM <-> RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(D_shared[0, 0], 2, 64, 64))', + ], + ), + ( + 256, 64, 64, 64, + [ + # DRAM -> RSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 64], 1, 64, 64), T.region(C_shared[0, 0], 2, 64, 64))', + # DRAM -> WSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 64], 1, 64, 64), T.region(B_shared[0, 0], 2, 64, 64))', + # DRAM <- RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), 
T.region(C[by * 64, ko * 64], 2, 64, 64))', + # DRAM -> ASRAM + 'T.dma_copy(T.region(C[by * 64, ko * 64], 1, 64, 64), T.region(A_shared[0, 0], 2, 64, 64))', + # RSRAM -> ASRAM + 'T.dma_copy(T.region(C_shared[8, 16], 1, 16, 32), T.region(A_shared[24, 8], 2, 16, 32))', + # RSRAM -> WSRAM + 'T.dma_copy(T.region(C_shared[8, 48], 1, 24, 8), T.region(B_shared[40, 0], 2, 24, 8))', + # RSRAM <-> RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(D_shared[0, 0], 2, 64, 64))', + ], + ), +] +# fmt: on + + +@pytest.mark.parametrize( + "K, block_M, block_N, block_K, lower_stmt", + MESH_COPY_CASES, +) +def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K, lower_stmt): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with tvm.target.Target(target): + mod = copy(K, block_M, block_N, block_K) + mod = apply_sunmmio_passes(mod, target) + texts = extract_dma_copy_lines(mod) + assert len(texts) == len(lower_stmt), ( + f"Expected {len(lower_stmt)} dma_copy lines, got {len(texts)}") + for i in range(len(texts)): + assert texts[i] == lower_stmt[i], ( + f"Line {i} mismatch:\n actual: {texts[i]}\n expected: {lower_stmt[i]}") diff --git a/testing/python/transform/test_tilelang_transform_infer_sram_scope.py b/testing/python/transform/test_tilelang_transform_sunmmio_infer_sram_scope.py similarity index 100% rename from testing/python/transform/test_tilelang_transform_infer_sram_scope.py rename to testing/python/transform/test_tilelang_transform_sunmmio_infer_sram_scope.py diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index cabc4a3e4..ae2c46322 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -67,6 +67,13 @@ def get_extent(data): # Combine the nested if statements into a single if statement as suggested by SIM102 if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and isinstance(dst, tir.BufferLoad)): + # FIXME + # For Sunmmio an invalid D<->D copy operation will enter here, for example: + # T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) -> + # for ko in T.serial(4, annotations={"num_stages": 3}): + # B[by * 32, ko * 32] = C[by * 32, ko * 32] + # which causes an exception can't be caught. + # # check if the case is like this: # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i] diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 41da8ab0a..d7fe90aac 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -54,7 +54,7 @@ def is_shared(buffer: Buffer | BufferLoad | BufferRegion, allow_dynamic: bool = """ buffer = _get_buffer(buffer) conditions = [False] - conditions.append(buffer.scope() == "shared") + conditions.append(buffer.scope().startswith("shared")) if allow_dynamic: conditions.append(is_shared_dynamic(buffer)) return any(conditions) @@ -71,7 +71,7 @@ def is_shared_dynamic(buffer: Buffer | BufferLoad | BufferRegion) -> bool: bool: True if the buffer is in dynamic shared memory, False otherwise. """ buffer = _get_buffer(buffer) - return buffer.scope() == "shared.dyn" + return buffer.scope().startswith("shared") and buffer.scope().endswith(".dyn") def is_tensor_memory(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
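
A minimal sketch, not part of the patch above: it illustrates how a later SUNMMIO codegen pass might decode one tl.tileop.region argument of a tl.dma_copy call back into buffer, per-axis ranges, and access mask, assuming only the encoding documented for MakeRegionExpr (args[0] = BufferLoad of per-axis minima, args[1] = integer access mask, args[2+i] = per-axis extents). The helper name DecodeRegionCall and the DecodedRegion struct are hypothetical; the op names tl.dma_copy and tl.tileop.region and the argument layout come from the patch.

// Hypothetical decode helper for one tl.tileop.region argument of tl.dma_copy.
#include <tvm/ir/op.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

struct DecodedRegion {
  Buffer buffer;        // buffer identity, including its memory scope
  Array<Range> ranges;  // per-axis [min, extent)
  int access_mask;      // 1 = read (src), 2 = write (dst)
};

DecodedRegion DecodeRegionCall(const PrimExpr &expr) {
  Call call = Downcast<Call>(expr);
  ICHECK(call->op.same_as(Op::Get("tl.tileop.region")));
  // args[0] is a BufferLoad whose indices carry the per-axis minima.
  BufferLoad load = Downcast<BufferLoad>(call->args[0]);
  DecodedRegion region;
  region.buffer = load->buffer;
  region.access_mask = static_cast<int>(Downcast<IntImm>(call->args[1])->value);
  // args[2+i] holds the extent of axis i; pair it with the i-th minimum.
  for (size_t i = 0; i + 2 < call->args.size(); ++i) {
    region.ranges.push_back(
        Range::FromMinExtent(load->indices[i], call->args[i + 2]));
  }
  return region;
}

For the simple_copy_kernel case in the test above, decoding the first dma_copy argument this way would yield buffer A with ranges [(by * block_M, block_M), (bx * block_N, block_N)] and access_mask 1, mirroring what the Python-side normalize_region helper checks.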