From b2dd2e0b55b085ed05f837e6303a9b26ddbc370e Mon Sep 17 00:00:00 2001 From: Jiaqi Guo Date: Wed, 11 Feb 2026 14:02:27 +0800 Subject: [PATCH 1/6] Built-in with region --- src/op/builtin.cc | 3 + src/op/builtin.h | 14 + src/op/copy.cc | 27 ++ src/op/copy.h | 8 + src/op/utils.cc | 19 ++ src/op/utils.h | 7 + src/transform/lower_tile_op.cc | 7 +- .../transform/test_tilelang_dma_copy.py | 254 ++++++++++++++++++ 8 files changed, 337 insertions(+), 2 deletions(-) create mode 100644 testing/python/transform/test_tilelang_dma_copy.py diff --git a/src/op/builtin.cc b/src/op/builtin.cc index ed5f3067a..6a2442887 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -112,6 +112,9 @@ TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr( TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(2).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) .set_num_inputs(-1) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 6709f7511..97b631e17 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -279,6 +279,20 @@ TVM_DLL const Op &dma_load(); */ TVM_DLL const Op &dma_store(); +/*! + * \brief Perform a DMA copy operation preserving full buffer region semantics. + * + * This intrinsic encodes a high-level copy between two buffer regions as + * tl.dma_copy(src_region, dst_region), where each argument is a + * tl.tileop.region Call carrying the buffer, access mask, and per-axis + * extents. It is emitted by the SUNMMIO lowering path of CopyNode and + * consumed by later target-specific codegen passes. + * + * \param src_region A tl.tileop.region PrimExpr describing the source. + * \param dst_region A tl.tileop.region PrimExpr describing the destination. + */ +TVM_DLL const Op &dma_copy(); + /*! 
* \brief tvm intrinsics for loading image from global tensor to columns in * shared memory diff --git a/src/op/copy.cc b/src/op/copy.cc index 22cb4a074..e2cf86368 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -868,6 +868,11 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; + // SUNMMIO: emit a dma_copy intrinsic instead of GPU-style lowering + if (TargetIsSunmmio(target)) { + return LowerDmaCopy(T, analyzer); + } + using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); bool disable_tma_lower = @@ -899,6 +904,28 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } +/** + * @brief Lower the copy operator for the SUNMMIO target. + * + * Emits a `tl.dma_copy(src_region, dst_region)` intrinsic call that preserves + * full buffer region semantics (buffer identity, per-axis min/extent, and + * memory scope). This intrinsic is consumed by later SUNMMIO-specific codegen + * passes to generate actual DMA instructions. + * + * @param T Lowering context (target, layout map, etc.). + * @param analyzer Arithmetic analyzer (unused here but kept for interface + * consistency). + * @return Stmt An Evaluate wrapping the tl.dma_copy Call. + */ +Stmt CopyNode::LowerDmaCopy(const LowerArgs &T, + arith::Analyzer *analyzer) const { + // access_mask: 1=read for src, 2=write for dst + PrimExpr src_region = MakeRegionExpr(src, src_range, /*access_mask=*/1); + PrimExpr dst_region = MakeRegionExpr(dst, dst_range, /*access_mask=*/2); + return Evaluate( + Call(DataType::Handle(), dma_copy(), {src_region, dst_region})); +} + /** * @brief Lower the copy operator using the generic (non-specialized) path. 
* diff --git a/src/op/copy.h b/src/op/copy.h index fd3e01a40..7d8e842d6 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -217,6 +217,14 @@ class CopyNode : public TileOperatorNode { */ Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; + /*! + * \brief Generate lowering for SUNMMIO DMA copy. + * + * Emits a tl.dma_copy(src_region, dst_region) intrinsic that preserves full + * buffer region semantics for later SUNMMIO codegen consumption. + */ + Stmt LowerDmaCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; + /*! * \brief Generate SIMT (thread-level) loop for copying. */ diff --git a/src/op/utils.cc b/src/op/utils.cc index 7e56ae8c7..a1d53977d 100644 --- a/src/op/utils.cc +++ b/src/op/utils.cc @@ -52,6 +52,25 @@ BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) { throw; // Unreachable } +PrimExpr MakeRegionExpr(const Buffer &buffer, const Array &ranges, + int access_mask) { + // Build BufferLoad with indices = per-axis minima + Array indices; + for (const auto &r : ranges) { + indices.push_back(r->min); + } + BufferLoad load(buffer, indices); + + // Pack args: [load, access_mask, extent_0, extent_1, ...] + Array args; + args.push_back(load); + args.push_back(IntImm(DataType::Int(32), access_mask)); + for (const auto &r : ranges) { + args.push_back(r->extent); + } + return Call(DataType::Handle(), RegionOp::Get(), args); +} + PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, int rw_mask, bool require_2d) { Buffer buf = region->buffer; diff --git a/src/op/utils.h b/src/op/utils.h index d386b1a58..1c297e334 100644 --- a/src/op/utils.h +++ b/src/op/utils.h @@ -21,6 +21,13 @@ using namespace tir; // Note: tvm_access_ptr is no longer supported here. TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg); +// Build a tl.tileop.region Call from a Buffer + Array. +// This is the inverse of NormalizeToBufferRegion: it packages buffer, access +// mask, and per-axis extents into a Call(RegionOp::Get(), ...) 
that can be +// passed as an argument to builtins like dma_copy. +TVM_DLL PrimExpr MakeRegionExpr(const Buffer &buffer, + const Array &ranges, int access_mask); + // Build a tvm_access_ptr(handle) from a BufferRegion. // - If `require_2d` is true, checks buffer ndim >= 2. // - For 1D regions (when allowed), offset=min, extent=extent. diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 4089c600a..c59b72763 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -18,6 +18,7 @@ #include "../op/gemm.h" #include "../op/gemm_sp.h" #include "../op/operator.h" +#include "../target/utils.h" #include "common/remap_buffer_rewriter.h" #include "arith/ir_mutator_with_analyzer.h" @@ -157,8 +158,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { .as>() .value(); for (auto [buffer, layout] : layout_map) { - buffer_remap_.Set(buffer, - makeBufferWithLayout(buffer, layout, var_remap_)); + if (!TargetIsSunmmio(target_)) { + buffer_remap_.Set(buffer, + makeBufferWithLayout(buffer, layout, var_remap_)); + } layout_map_.Set(buffer, layout); } } diff --git a/testing/python/transform/test_tilelang_dma_copy.py b/testing/python/transform/test_tilelang_dma_copy.py new file mode 100644 index 000000000..bff708009 --- /dev/null +++ b/testing/python/transform/test_tilelang_dma_copy.py @@ -0,0 +1,254 @@ +"""Test that SUNMMIO copy lowering emits tl.dma_copy with tl.tileop.region args, +and that each region can be normalized back to a BufferRegion with full metadata.""" + +import tilelang +import tilelang as tl +import tilelang.language as T +from tilelang import tvm as tvm +from tilelang.utils.target import SUNMMIO_TARGET_DESC +from tvm import tir +from tvm.tir import PyStmtExprVisitor +import tilelang.env as env +import pytest + +tilelang.env.disable_cache() + + +def simple_copy_kernel(M, N, block_M, block_N, dtype="float16"): + """A minimal kernel with T.copy from global to shared memory.""" + + @T.prim_func + def main(A: 
T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + T.copy(A[by * block_M, bx * block_N], A_shared) + + return tvm.IRModule({"main": main}) + + +def gemm_copy_kernel(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float32"): + """A kernel with T.copy + T.gemm to trigger layout inference.""" + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + + T.clear(C_shared) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return tvm.IRModule({"main": main}) + + +class RegionRange: + """Simple container for (min, extent) of a region axis.""" + + def __init__(self, min_val, extent): + self.min = min_val + self.extent = extent + + +def normalize_region(region_call): + """Decode a tl.tileop.region Call back into (buffer, extents, access_mask). + + This mirrors what NormalizeToBufferRegion does in C++: + args[0] = BufferLoad (indices are per-axis minima) + args[1] = access_mask (int) + args[2+i] = extent for axis i + + On SUNMMIO, buffer remap is disabled so buffers retain their original + N-D shape and indices always match the number of extents. + + Returns (buffer, extents, access_mask) where extents is a list of + RegionRange objects with .min and .extent attributes. 
+ """ + assert isinstance(region_call, tir.Call) + assert region_call.op.name == "tl.tileop.region" + + load = region_call.args[0] + assert isinstance(load, tir.BufferLoad) + + access_mask = int(region_call.args[1]) + num_extents = len(region_call.args) - 2 + + assert len( + load.indices) == num_extents, (f"Expected {num_extents} indices, got {len(load.indices)}") + + ranges = [] + for i in range(num_extents): + ranges.append(RegionRange(load.indices[i], region_call.args[2 + i])) + + return load.buffer, ranges, access_mask + + +@tir.functor.visitor +class _DmaCopyVisitor(PyStmtExprVisitor): + """Walk TIR and collect tl.dma_copy calls and their region arguments.""" + + def __init__(self): + super().__init__() + self.dma_copy_calls = [] + self.layout_map = {} + + def visit_block_(self, op: tir.Block) -> None: + if "layout_map" in op.annotations: + for key, layout in op.annotations["layout_map"].items(): + self.layout_map[key.name] = layout + self.visit_stmt(op.body) + + def visit_call_(self, op: tir.Call) -> None: + if hasattr(op, "op") and hasattr(op.op, "name") and op.op.name == "tl.dma_copy": + self.dma_copy_calls.append(op) + # Visit children + for arg in op.args: + self.visit_expr(arg) + + def visit_evaluate_(self, op: tir.Evaluate) -> None: + self.visit_expr(op.value) + + +TEST_CASES = [ + # (M, N, block_M, block_N) + (128, 128, 32, 32), + (256, 256, 64, 64), + (128, 256, 32, 64), +] + + +@pytest.mark.parametrize("M, N, block_M, block_N", TEST_CASES) +def test_tilelang_dma_copy(M, N, block_M, block_N): + target = tvm.target.Target(SUNMMIO_TARGET_DESC) + mod = simple_copy_kernel(M, N, block_M, block_N) + + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tl.transform.LayoutInference()(mod) + mod = tl.transform.LowerTileOp()(mod) + print(mod) + + # Walk the lowered IR and find dma_copy calls + visitor = _DmaCopyVisitor() + func = mod["main"] + visitor.visit_stmt(func.body) + + # Verify that at least one tl.dma_copy call 
was emitted + assert len(visitor.dma_copy_calls) > 0, ("Expected at least one tl.dma_copy call in lowered IR") + + for call in visitor.dma_copy_calls: + # dma_copy should have exactly 2 arguments (src_region, dst_region) + assert len(call.args) == 2, (f"Expected 2 args for dma_copy, got {len(call.args)}") + + # Each argument should be a tl.tileop.region Call + for i, arg in enumerate(call.args): + assert isinstance(arg, + tir.Call), (f"dma_copy arg[{i}] should be a Call, got {type(arg)}") + assert hasattr(arg.op, "name") and arg.op.name == "tl.tileop.region", ( + f"dma_copy arg[{i}] should be tl.tileop.region, got {arg.op.name}") + + # --- Normalize regions back to buffer metadata --- + src_buf, src_ranges, src_mask = normalize_region(call.args[0]) + dst_buf, dst_ranges, dst_mask = normalize_region(call.args[1]) + + # access_mask: 1=read for src, 2=write for dst + assert src_mask == 1, f"Source access_mask should be 1 (read), got {src_mask}" + assert dst_mask == 2, f"Destination access_mask should be 2 (write), got {dst_mask}" + + # Buffer dtype + assert src_buf.dtype == "float16", f"Source dtype should be float16, got {src_buf.dtype}" + assert dst_buf.dtype == "float16", f"Dest dtype should be float16, got {dst_buf.dtype}" + + # Buffer shapes: both stay 2D (no flattening on SUNMMIO) + assert len(src_buf.shape) == 2, f"Source buffer should be 2D, got {len(src_buf.shape)}D" + assert int(src_buf.shape[0]) == M, f"Source shape[0] should be {M}, got {src_buf.shape[0]}" + assert int(src_buf.shape[1]) == N, f"Source shape[1] should be {N}, got {src_buf.shape[1]}" + + assert len(dst_buf.shape) == 2, f"Dest buffer should be 2D, got {len(dst_buf.shape)}D" + assert int(dst_buf.shape[0]) == block_M, ( + f"Dest shape[0] should be {block_M}, got {dst_buf.shape[0]}") + assert int(dst_buf.shape[1]) == block_N, ( + f"Dest shape[1] should be {block_N}, got {dst_buf.shape[1]}") + + # Buffer scope + src_scope = src_buf.scope() + assert src_scope == "" or src_scope == "global", ( + 
f"Source buffer should be in global scope, got '{src_scope}'") + dst_scope = dst_buf.scope() + assert "shared" in dst_scope, ( + f"Destination buffer should be in shared scope, got '{dst_scope}'") + + # Region extents match block dimensions + assert len(src_ranges) == 2 + src_extent_0 = int(src_ranges[0].extent) + src_extent_1 = int(src_ranges[1].extent) + assert src_extent_0 == block_M, ( + f"Source extent[0] should be {block_M}, got {src_extent_0}") + assert src_extent_1 == block_N, ( + f"Source extent[1] should be {block_N}, got {src_extent_1}") + + assert len(dst_ranges) == 2 + dst_extent_0 = int(dst_ranges[0].extent) + dst_extent_1 = int(dst_ranges[1].extent) + assert dst_extent_0 == block_M, (f"Dest extent[0] should be {block_M}, got {dst_extent_0}") + assert dst_extent_1 == block_N, (f"Dest extent[1] should be {block_N}, got {dst_extent_1}") + + +GEMM_TEST_CASES = [ + # (M, N, K, block_M, block_N, block_K) + (128, 128, 128, 32, 32, 32), + (128, 128, 128, 64, 64, 64), +] + + +@pytest.mark.parametrize("M, N, K, block_M, block_N, block_K", GEMM_TEST_CASES) +def test_tilelang_dma_copy_layout_query(M, N, K, block_M, block_N, block_K): + """Verify that after LayoutInference, the layout_map annotation is populated + for shared buffers, and that a downstream pass can look up layouts by buffer. + + NOTE: This test only checks layout annotations after LayoutInference. + LowerTileOp is not called here because gemm does not yet have a SUNMMIO + lowering path. The dma_copy lowering for copy ops is verified separately + in test_tilelang_dma_copy above. + """ + env.TILELANG_USE_GEMM_V1 = 0 + target = tvm.target.Target(SUNMMIO_TARGET_DESC) + mod = gemm_copy_kernel(M, N, K, block_M, block_N, block_K) + + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tl.transform.LayoutInference()(mod) + + # After LayoutInference but before LowerTileOp, the layout_map + # annotation is present on the Block node. 
+ visitor = _DmaCopyVisitor() + visitor.visit_stmt(mod["main"].body) + assert len(visitor.layout_map) > 0, ("Expected layout_map annotation after LayoutInference") + + # The shared buffers should have layout entries + assert "A_shared" in visitor.layout_map, ( + f"Expected A_shared in layout_map, got keys: {list(visitor.layout_map.keys())}") + layout_a = visitor.layout_map["A_shared"] + input_shape_a = layout_a.input_size + assert len(input_shape_a) == 2 + assert int(input_shape_a[0]) == block_M + assert int(input_shape_a[1]) == block_K + + assert "B_shared" in visitor.layout_map + layout_b = visitor.layout_map["B_shared"] + input_shape_b = layout_b.input_size + assert len(input_shape_b) == 2 + assert int(input_shape_b[0]) == block_K + assert int(input_shape_b[1]) == block_N + + +if __name__ == "__main__": + tilelang.testing.main() From 3ca4fd6da83901627e856ab7bb8556064b968ede Mon Sep 17 00:00:00 2001 From: Jiaqi Guo Date: Wed, 11 Feb 2026 16:13:23 +0800 Subject: [PATCH 2/6] Add copy layout inference and checking --- src/op/builtin.cc | 8 +- src/op/copy.cc | 100 ++++- src/op/copy.h | 16 +- src/op/utils.cc | 22 ++ ...y => test_tilelang_sunmmio_gemm_layout.py} | 3 +- .../ops/test_tilelang_ops_sunmmio_dma_copy.py | 371 ++++++++++++++++++ .../transform/test_tilelang_dma_copy.py | 254 ------------ ...ang_transform_sunmmio_infer_sram_scope.py} | 0 tilelang/utils/language.py | 4 +- 9 files changed, 506 insertions(+), 272 deletions(-) rename testing/python/layout/{test_tilelang_gemm_sunmmio_layout.py => test_tilelang_sunmmio_gemm_layout.py} (97%) create mode 100644 testing/python/ops/test_tilelang_ops_sunmmio_dma_copy.py delete mode 100644 testing/python/transform/test_tilelang_dma_copy.py rename testing/python/transform/{test_tilelang_transform_infer_sram_scope.py => test_tilelang_transform_sunmmio_infer_sram_scope.py} (100%) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 6a2442887..058c395a8 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -106,13 
+106,7 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(2).set_attr( +TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(create_tma_descriptor) diff --git a/src/op/copy.cc b/src/op/copy.cc index e2cf86368..a26205e63 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -564,6 +564,30 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, } return {}; } + + // Sunmmio DMA Layout Inference + if (copy_inst == CopyInst::kSunmmioDMACopy) { + // for dma copy, we can directly apply the blockwise_zz_layout + const auto f = + ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout"); + auto result = Map(); + + if (level == InferLevel::kFree && !T.layout_map.count(src)) { + if (src.scope() != "global") { + auto layout = Downcast((*f)(src)); + result.Set(src, layout); + } + } + + if (level == InferLevel::kFree && !T.layout_map.count(dst)) { + if (dst.scope() != "global") { + auto layout = Downcast((*f)(dst)); + result.Set(dst, layout); + } + } + return result; + } + // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy // Use parallel op to infer the layout @@ -804,6 +828,65 @@ bool CopyNode::CheckTMemStore(Target target) const { dst.scope() == "shared.tmem"; } +/** + * @brief Determine whether this CopyNode can be lowered to a DMA Load + * instruction for Sunmmio target. 
+ * + * The function returns true when all of the following hold: + * - the target architecture advertises DMA support; + * - the source buffer and the destination buffer are legal; + * - the source and destination have the same element data type. + * + * If the source and destination dtypes differ, a warning is logged and the + * function returns false (the caller is expected to fall back to a normal + * copy). + * + * + * @param target The compilation target to query for dma load support. + * @return true if the copy can be implemented as a DMA Load; false + * otherwise. + */ +bool CopyNode::CheckSunmmioDMACopy(Target target) const { + // 1. arch must support Sunmmio + if (!TargetIsSunmmio(target)) + return false; + + // 2. src and dst must be legal + bool scope_check = false; + // 2.1 DRAM -> RSRAM + if (src.scope() == "global" && dst.scope() == "shared.rsram") + scope_check = true; + // 2.2 DRAM -> WSRAM + if (src.scope() == "global" && dst.scope() == "shared.wsram") + scope_check = true; + // 2.3 DRAM -> ASRAM + if (src.scope() == "global" && dst.scope() == "shared.asram") + scope_check = true; + // 2.4 RSRAM -> WSRAM + if (src.scope() == "shared.rsram" && dst.scope() == "shared.wsram") + scope_check = true; + // 2.5 RSRAM -> ASRAM + if (src.scope() == "shared.rsram" && dst.scope() == "shared.asram") + scope_check = true; + // 2.6 RSRAM <-> RSRAM + if (src.scope() == "shared.rsram" && dst.scope() == "shared.rsram") + scope_check = true; + // 2.7 RSRAM -> DRAM + if (src.scope() == "shared.rsram" && dst.scope() == "global") + scope_check = true; + if (!scope_check) + return false; + + // 3. src and dst must have the same dtype + if (src->dtype != dst->dtype) { + LOG(WARNING) << "src and dst must have the same dtype for dma copy " + << src->name << " vs. " << dst->name << " dtype " << src->dtype + << " vs. 
" << dst->dtype << " will be fallback to normal copy"; + return false; + } + return true; +} + /** * @brief Selects the most specific copy instruction supported for the given * target and buffers. @@ -848,6 +931,12 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, return CopyInst::kTMemLoad; } else if (CheckTMemStore(target)) { return CopyInst::kTMemStore; + } else if (TargetIsSunmmio(target)) { + auto is_copy = CheckSunmmioDMACopy(target); + if (is_copy) + return CopyInst::kSunmmioDMACopy; + ICHECK(0) << "Unsupported copy from " << src.scope() << " to " + << dst.scope() << " of Sunmmio target."; } else { return CopyInst::kNormal; } @@ -860,6 +949,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, * determined copy instruction type: * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions * - LDSM/STSM: Uses matrix load/store instructions for tensor cores + * - DMA Load/Store: Sunmmio specified instructions for copy * - Normal: Uses standard load/store operations with loop transformations * \param T LowerArgs containing target and layout map. * \param analyzer Arithmetic analyzer for simplification. 
@@ -870,7 +960,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { // SUNMMIO: emit a dma_copy intrinsic instead of GPU-style lowering if (TargetIsSunmmio(target)) { - return LowerDmaCopy(T, analyzer); + return LowerSunmmioDmaCopy(T, analyzer); } using namespace tvm::transform; @@ -899,6 +989,10 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return ldsm_copy; } else if (copy_inst == CopyInst::kNormal) { return LowerNormalCopy(T, analyzer); + } else if (copy_inst == CopyInst::kSunmmioDMACopy) { + auto dma_copy = LowerSunmmioDmaCopy(T, analyzer); + ICHECK(dma_copy.defined()) << "Failed to lower dma load/store"; + return dma_copy; } else { LOG(FATAL) << "Unsupported copy inst " << static_cast(copy_inst); } @@ -917,8 +1011,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { * consistency). * @return Stmt An Evaluate wrapping the tl.dma_copy Call. */ -Stmt CopyNode::LowerDmaCopy(const LowerArgs &T, - arith::Analyzer *analyzer) const { +Stmt CopyNode::LowerSunmmioDmaCopy(const LowerArgs &T, + arith::Analyzer *analyzer) const { // access_mask: 1=read for src, 2=write for dst PrimExpr src_region = MakeRegionExpr(src, src_range, /*access_mask=*/1); PrimExpr dst_region = MakeRegionExpr(dst, dst_range, /*access_mask=*/2); diff --git a/src/op/copy.h b/src/op/copy.h index 7d8e842d6..7566a3308 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -22,10 +22,11 @@ enum class CopyInst : uint8_t { kBulkStore = 4, // utilize tma store // we should separate the bulk load and store for 1d and multi-dim // as they have different memory access patterns - kBulkLoad1D = 5, // utilize tma load 1d - kBulkStore1D = 6, // utilize tma store 1d - kTMemLoad = 7, // tcgen05.ld (tensor memory -> register) - kTMemStore = 8, // tcgen05.st (register -> tensor memory) + kBulkLoad1D = 5, // utilize tma load 1d + kBulkStore1D = 6, // utilize tma store 1d + kTMemLoad = 7, // tcgen05.ld (tensor memory -> register) + 
kTMemStore = 8, // tcgen05.st (register -> tensor memory) + kSunmmioDMACopy = 9, // Sunmmio DMA }; /// Descriptor for Tensor Memory Access (TMA) copy operations @@ -180,6 +181,11 @@ class CopyNode : public TileOperatorNode { */ bool CheckTMemStore(Target target) const; + /*! + * \brief Check if Sunmmio dma copy is supported. + */ + bool CheckSunmmioDMACopy(Target target) const; + /*! * \brief Get the copy instruction type. */ @@ -223,7 +229,7 @@ class CopyNode : public TileOperatorNode { * Emits a tl.dma_copy(src_region, dst_region) intrinsic that preserves full * buffer region semantics for later SUNMMIO codegen consumption. */ - Stmt LowerDmaCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; + Stmt LowerSunmmioDmaCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; /*! * \brief Generate SIMT (thread-level) loop for copying. diff --git a/src/op/utils.cc b/src/op/utils.cc index a1d53977d..17401b4d2 100644 --- a/src/op/utils.cc +++ b/src/op/utils.cc @@ -52,6 +52,28 @@ BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) { throw; // Unreachable } +/*! + * \brief Encode a Buffer + Array into a tl.tileop.region Call + * expression. + * + * This is the inverse of NormalizeToBufferRegion: it packs buffer region + * metadata into a PrimExpr so it can travel through Call arguments (where + * BufferRegion cannot appear directly). + * + * Use this when emitting intrinsic calls (e.g. tl.dma_copy) that need to + * carry full region semantics — buffer identity, per-axis min/extent, and + * access mode — as opaque PrimExpr arguments for later codegen consumption. + * + * Encoding layout: + * args[0] = BufferLoad(buffer, {range[0].min, range[1].min, ...}) + * args[1] = access_mask (1=read, 2=write, 3=read-write) + * args[2+i] = range[i].extent + * + * \param buffer The buffer this region refers to. + * \param ranges Per-axis [min, extent) ranges describing the tile. + * \param access_mask 1=read, 2=write, 3=read-write. + * \return A Call(tl.tileop.region, ...) 
expression. + */ PrimExpr MakeRegionExpr(const Buffer &buffer, const Array &ranges, int access_mask) { // Build BufferLoad with indices = per-axis minima diff --git a/testing/python/layout/test_tilelang_gemm_sunmmio_layout.py b/testing/python/layout/test_tilelang_sunmmio_gemm_layout.py similarity index 97% rename from testing/python/layout/test_tilelang_gemm_sunmmio_layout.py rename to testing/python/layout/test_tilelang_sunmmio_gemm_layout.py index a7feb5a57..75ed349f6 100644 --- a/testing/python/layout/test_tilelang_gemm_sunmmio_layout.py +++ b/testing/python/layout/test_tilelang_sunmmio_gemm_layout.py @@ -24,7 +24,7 @@ def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_d def main( A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -106,6 +106,7 @@ def test_tilelang_gemm_sunmmio_layout(M, N, K, block_M, block_N, block_K, versio with tvm.target.Target(target): mod = matmul(M, N, K, block_M, block_N, block_K, version) mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tl.transform.InferSramScope()(mod) mod = tl.transform.LayoutInference()(mod) LayoutVisual()(mod) diff --git a/testing/python/ops/test_tilelang_ops_sunmmio_dma_copy.py b/testing/python/ops/test_tilelang_ops_sunmmio_dma_copy.py new file mode 100644 index 000000000..7ee559f70 --- /dev/null +++ b/testing/python/ops/test_tilelang_ops_sunmmio_dma_copy.py @@ -0,0 +1,371 @@ +"""Test that SUNMMIO copy lowering emits tl.dma_copy with tl.tileop.region args, +and that each region can be normalized back to a BufferRegion with full metadata.""" + +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tilelang.utils.target import SUNMMIO_TARGET_DESC, determine_target +from tilelang.language.v2.annot import MeshShardingPolicy +from tvm 
import tir +from tvm.tir import PyStmtExprVisitor +import pytest + +tilelang.env.disable_cache() + + +def simple_copy_kernel(M, N, block_M, block_N, dtype="float16"): + """A minimal kernel with T.copy from global to shared memory.""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + T.copy(A[by * block_M, bx * block_N], A_shared) + + return tvm.IRModule({"main": main}) + + +def apply_sunmmio_passes(mod, target): + """Apply the full SUNMMIO pass pipeline used for DMA copy lowering.""" + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + mod = tilelang.transform.InjectAssumes()(mod) + mod = tilelang.transform.Simplify()(mod) + mod = tilelang.transform.InferSramScope()(mod) + mod = tilelang.transform.LayoutReducer()(mod) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) + return mod + + +class RegionRange: + """Simple container for (min, extent) of a region axis.""" + + def __init__(self, min_val, extent): + self.min = min_val + self.extent = extent + + +def normalize_region(region_call): + """Decode a tl.tileop.region Call back into (buffer, extents, access_mask). + + This mirrors what NormalizeToBufferRegion does in C++: + args[0] = BufferLoad (indices are per-axis minima) + args[1] = access_mask (int) + args[2+i] = extent for axis i + + On SUNMMIO, buffer remap is disabled so buffers retain their original + N-D shape and indices always match the number of extents. + + Returns (buffer, extents, access_mask) where extents is a list of + RegionRange objects with .min and .extent attributes. 
+ """ + assert isinstance(region_call, tir.Call) + assert region_call.op.name == "tl.tileop.region" + + load = region_call.args[0] + assert isinstance(load, tir.BufferLoad) + + access_mask = int(region_call.args[1]) + num_extents = len(region_call.args) - 2 + + assert len( + load.indices) == num_extents, (f"Expected {num_extents} indices, got {len(load.indices)}") + + ranges = [] + for i in range(num_extents): + ranges.append(RegionRange(load.indices[i], region_call.args[2 + i])) + + return load.buffer, ranges, access_mask + + +@tir.functor.visitor +class _DmaCopyVisitor(PyStmtExprVisitor): + """Walk TIR and collect tl.dma_copy calls and their region arguments.""" + + def __init__(self): + super().__init__() + self.dma_copy_calls = [] + self.layout_map = {} + + def visit_block_(self, op: tir.Block) -> None: + if "layout_map" in op.annotations: + for key, layout in op.annotations["layout_map"].items(): + self.layout_map[key.name] = layout + self.visit_stmt(op.body) + + def visit_call_(self, op: tir.Call) -> None: + if hasattr(op, "op") and hasattr(op.op, "name") and op.op.name == "tl.dma_copy": + self.dma_copy_calls.append(op) + # Visit children + for arg in op.args: + self.visit_expr(arg) + + def visit_evaluate_(self, op: tir.Evaluate) -> None: + self.visit_expr(op.value) + + +def extract_dma_copy_lines(mod): + """Extract T.dma_copy lines from TIR script, robust to formatting changes.""" + return [line.lstrip() for line in mod.script().split('\n') if 'T.dma_copy' in line] + + +SIMPLE_COPY_CASES = [ + # (M, N, block_M, block_N) + (128, 128, 32, 32), + (256, 256, 64, 64), + (128, 256, 32, 64), +] + + +@pytest.mark.parametrize("M, N, block_M, block_N", SIMPLE_COPY_CASES) +def test_tilelang_dma_copy(M, N, block_M, block_N): + target = tvm.target.Target(SUNMMIO_TARGET_DESC) + mod = simple_copy_kernel(M, N, block_M, block_N) + + with tvm.target.Target(target): + mod = apply_sunmmio_passes(mod, target) + + # Walk the lowered IR and find dma_copy calls + visitor = 
_DmaCopyVisitor() + func = mod["main"] + visitor.visit_stmt(func.body) + + # Verify that exactly one tl.dma_copy call was emitted + assert len(visitor.dma_copy_calls) == 1, ( + f"Expected exactly 1 tl.dma_copy call, got {len(visitor.dma_copy_calls)}") + + call = visitor.dma_copy_calls[0] + + # dma_copy should have exactly 2 arguments (src_region, dst_region) + assert len(call.args) == 2, (f"Expected 2 args for dma_copy, got {len(call.args)}") + + # Each argument should be a tl.tileop.region Call + for i, arg in enumerate(call.args): + assert isinstance(arg, tir.Call), (f"dma_copy arg[{i}] should be a Call, got {type(arg)}") + assert hasattr(arg.op, "name") and arg.op.name == "tl.tileop.region", ( + f"dma_copy arg[{i}] should be tl.tileop.region, got {arg.op.name}") + + # --- Normalize regions back to buffer metadata --- + src_buf, src_ranges, src_mask = normalize_region(call.args[0]) + dst_buf, dst_ranges, dst_mask = normalize_region(call.args[1]) + + # access_mask: 1=read for src, 2=write for dst + assert src_mask == 1, f"Source access_mask should be 1 (read), got {src_mask}" + assert dst_mask == 2, f"Destination access_mask should be 2 (write), got {dst_mask}" + + # Buffer dtype + assert src_buf.dtype == "float16", f"Source dtype should be float16, got {src_buf.dtype}" + assert dst_buf.dtype == "float16", f"Dest dtype should be float16, got {dst_buf.dtype}" + + # Buffer shapes: both stay 2D (no flattening on SUNMMIO) + assert len(src_buf.shape) == 2, f"Source buffer should be 2D, got {len(src_buf.shape)}D" + assert int(src_buf.shape[0]) == M, f"Source shape[0] should be {M}, got {src_buf.shape[0]}" + assert int(src_buf.shape[1]) == N, f"Source shape[1] should be {N}, got {src_buf.shape[1]}" + + assert len(dst_buf.shape) == 2, f"Dest buffer should be 2D, got {len(dst_buf.shape)}D" + assert int( + dst_buf.shape[0]) == block_M, (f"Dest shape[0] should be {block_M}, got {dst_buf.shape[0]}") + assert int( + dst_buf.shape[1]) == block_N, (f"Dest shape[1] should be 
{block_N}, got {dst_buf.shape[1]}") + + # Buffer scope: InferSramScope assigns shared.rsram for a plain alloc_shared + src_scope = src_buf.scope() + assert src_scope == "" or src_scope == "global", ( + f"Source buffer should be in global scope, got '{src_scope}'") + dst_scope = dst_buf.scope() + assert dst_scope == "shared.rsram", ( + f"Destination buffer should be shared.rsram after InferSramScope, got '{dst_scope}'") + + # Region extents match block dimensions + assert len(src_ranges) == 2 + src_extent_0 = int(src_ranges[0].extent) + src_extent_1 = int(src_ranges[1].extent) + assert src_extent_0 == block_M, (f"Source extent[0] should be {block_M}, got {src_extent_0}") + assert src_extent_1 == block_N, (f"Source extent[1] should be {block_N}, got {src_extent_1}") + + assert len(dst_ranges) == 2 + dst_extent_0 = int(dst_ranges[0].extent) + dst_extent_1 = int(dst_ranges[1].extent) + assert dst_extent_0 == block_M, (f"Dest extent[0] should be {block_M}, got {dst_extent_0}") + assert dst_extent_1 == block_N, (f"Dest extent[1] should be {block_N}, got {dst_extent_1}") + + +def wrong_copy(M, + N, + K, + block_M, + block_N, + block_K, + error_type, + dtype="float16", + accum_dtype="float16"): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + A_shared_2 = T.alloc_shared((block_M, block_K), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + B_shared_2 = T.alloc_shared((block_K, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + if error_type == 'A->D': + T.copy(A_shared, C[by * block_M, ko * block_K]) + elif 
error_type == 'W->D': + T.copy(B_shared, C[by * block_M, ko * block_K]) + elif error_type == 'A->R': + T.copy(A_shared, C_shared) + elif error_type == 'W->R': + T.copy(B_shared, C_shared) + elif error_type == 'D<->D': + T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) + elif error_type == 'A<->A': + T.copy(A_shared, A_shared_2) + elif error_type == 'W<->W': + T.copy(B_shared, B_shared_2) + elif error_type == 'A->W': + T.copy(A_shared, B_shared) + elif error_type == 'W->A': + T.copy(B_shared, A_shared) + + return tvm.IRModule({'main': main}) + + +WRONG_TEST_CASES = [ + (128, 128, 128, 32, 32, 32, "A->D", + "Unsupported copy from shared.asram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->D", + "Unsupported copy from shared.wsram to global of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->R", + "Unsupported copy from shared.asram to shared.rsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->R", + "Unsupported copy from shared.wsram to shared.rsram of Sunmmio target."), + # (128, 128, 128, 32, 32, 32, "D<->D", + # "Unsupported copy from global to global of Sunmmio target."), + # D<->D not work now + (128, 128, 128, 32, 32, 32, "A<->A", + "Unsupported copy from shared.asram to shared.asram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W<->W", + "Unsupported copy from shared.wsram to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "A->W", + "Unsupported copy from shared.asram to shared.wsram of Sunmmio target."), + (128, 128, 128, 32, 32, 32, "W->A", + "Unsupported copy from shared.wsram to shared.asram of Sunmmio target."), +] + + +@pytest.mark.parametrize( + "M, N, K, block_M, block_N, block_K, error_type, error_msg", + WRONG_TEST_CASES, +) +def test_tilelang_mesh_wrong_copy_to_dma(M, N, K, block_M, block_N, block_K, error_type, error_msg): + target = tvm.target.Target(SUNMMIO_TARGET_DESC) + with pytest.raises(tvm.error.InternalError, match=error_msg), tvm.target.Target(target): + mod = 
wrong_copy(M, N, K, block_M, block_N, block_K, error_type) + mod = apply_sunmmio_passes(mod, target) + + +def copy(K, block_M, block_N, block_K, dtype="float32", accum_dtype="float32"): + MyTensor = T.MeshTensor((128, 128), + sharding_policy=MeshShardingPolicy(cross_mesh_dim=0), + device_mesh_config=(2, 2), + hierarchical_dims=(4, 32, 128), + hierarchical_groups=((0, 2), (2, 3)), + hierarchical_strides=(32, 1, 4096)) + + @T.prim_func + def main(C: MyTensor): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(128, block_N), T.ceildiv(128, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.asram") + B_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.wsram") + C_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + D_shared = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared.rsram") + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # DRAM -> RSRAM + T.copy(C[by * block_M, ko * block_K], C_shared) + # DRAM -> WSRAM + T.copy(C[by * block_M, ko * block_K], B_shared) + # DRAM <- RSRAM + T.copy(C_shared, C[by * block_M, ko * block_K]) + # DRAM -> ASRAM + T.copy(C[by * block_M, ko * block_K], A_shared) + # RSRAM -> ASRAM + T.copy(C_shared[8:24, 16:48], A_shared[24:40, 8:40]) + # RSRAM -> WSRAM + T.copy(C_shared[8:32, 48:56], B_shared[40:64, 0:8]) + # RSRAM <-> RSRAM + T.copy(C_shared, D_shared) + + return tvm.IRModule({'main': main}) + + +# fmt: off +MESH_COPY_CASES = [ + ( + 128, 64, 64, 32, + [ + # DRAM -> RSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 32], 1, 64, 64), T.region(C_shared[0, 0], 2, 64, 64))', + # DRAM -> WSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 32], 1, 64, 64), T.region(B_shared[0, 0], 2, 64, 64))', + # DRAM <- RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(C[by * 64, ko * 32], 2, 64, 64))', + # DRAM -> ASRAM + 'T.dma_copy(T.region(C[by * 64, ko * 32], 1, 64, 64), T.region(A_shared[0, 0], 2, 64, 64))', + # 
RSRAM -> ASRAM + 'T.dma_copy(T.region(C_shared[8, 16], 1, 16, 32), T.region(A_shared[24, 8], 2, 16, 32))', + # RSRAM -> WSRAM + 'T.dma_copy(T.region(C_shared[8, 48], 1, 24, 8), T.region(B_shared[40, 0], 2, 24, 8))', + # RSRAM <-> RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(D_shared[0, 0], 2, 64, 64))', + ], + ), + ( + 256, 64, 64, 64, + [ + # DRAM -> RSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 64], 1, 64, 64), T.region(C_shared[0, 0], 2, 64, 64))', + # DRAM -> WSRAM + 'T.dma_copy(T.region(C[by * 64, ko * 64], 1, 64, 64), T.region(B_shared[0, 0], 2, 64, 64))', + # DRAM <- RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(C[by * 64, ko * 64], 2, 64, 64))', + # DRAM -> ASRAM + 'T.dma_copy(T.region(C[by * 64, ko * 64], 1, 64, 64), T.region(A_shared[0, 0], 2, 64, 64))', + # RSRAM -> ASRAM + 'T.dma_copy(T.region(C_shared[8, 16], 1, 16, 32), T.region(A_shared[24, 8], 2, 16, 32))', + # RSRAM -> WSRAM + 'T.dma_copy(T.region(C_shared[8, 48], 1, 24, 8), T.region(B_shared[40, 0], 2, 24, 8))', + # RSRAM <-> RSRAM + 'T.dma_copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(D_shared[0, 0], 2, 64, 64))', + ], + ), +] +# fmt: on + + +@pytest.mark.parametrize( + "K, block_M, block_N, block_K, lower_stmt", + MESH_COPY_CASES, +) +def test_tilelang_mesh_copy_to_dma(K, block_M, block_N, block_K, lower_stmt): + target_name = "Sunmmio" + target = determine_target(target_name, return_object=True) + with tvm.target.Target(target): + mod = copy(K, block_M, block_N, block_K) + mod = apply_sunmmio_passes(mod, target) + texts = extract_dma_copy_lines(mod) + assert len(texts) == len(lower_stmt), ( + f"Expected {len(lower_stmt)} dma_copy lines, got {len(texts)}") + for i in range(len(texts)): + assert texts[i] == lower_stmt[i], ( + f"Line {i} mismatch:\n actual: {texts[i]}\n expected: {lower_stmt[i]}") diff --git a/testing/python/transform/test_tilelang_dma_copy.py b/testing/python/transform/test_tilelang_dma_copy.py deleted file mode 100644 index 
bff708009..000000000 --- a/testing/python/transform/test_tilelang_dma_copy.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Test that SUNMMIO copy lowering emits tl.dma_copy with tl.tileop.region args, -and that each region can be normalized back to a BufferRegion with full metadata.""" - -import tilelang -import tilelang as tl -import tilelang.language as T -from tilelang import tvm as tvm -from tilelang.utils.target import SUNMMIO_TARGET_DESC -from tvm import tir -from tvm.tir import PyStmtExprVisitor -import tilelang.env as env -import pytest - -tilelang.env.disable_cache() - - -def simple_copy_kernel(M, N, block_M, block_N, dtype="float16"): - """A minimal kernel with T.copy from global to shared memory.""" - - @T.prim_func - def main(A: T.Tensor((M, N), dtype),): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_N), dtype) - T.copy(A[by * block_M, bx * block_N], A_shared) - - return tvm.IRModule({"main": main}) - - -def gemm_copy_kernel(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float32"): - """A kernel with T.copy + T.gemm to trigger layout inference.""" - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), accum_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_shared = T.alloc_shared((block_M, block_N), accum_dtype) - - T.clear(C_shared) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return tvm.IRModule({"main": main}) - - -class RegionRange: - """Simple container for (min, extent) of a region axis.""" - - def __init__(self, min_val, extent): - 
self.min = min_val - self.extent = extent - - -def normalize_region(region_call): - """Decode a tl.tileop.region Call back into (buffer, extents, access_mask). - - This mirrors what NormalizeToBufferRegion does in C++: - args[0] = BufferLoad (indices are per-axis minima) - args[1] = access_mask (int) - args[2+i] = extent for axis i - - On SUNMMIO, buffer remap is disabled so buffers retain their original - N-D shape and indices always match the number of extents. - - Returns (buffer, extents, access_mask) where extents is a list of - RegionRange objects with .min and .extent attributes. - """ - assert isinstance(region_call, tir.Call) - assert region_call.op.name == "tl.tileop.region" - - load = region_call.args[0] - assert isinstance(load, tir.BufferLoad) - - access_mask = int(region_call.args[1]) - num_extents = len(region_call.args) - 2 - - assert len( - load.indices) == num_extents, (f"Expected {num_extents} indices, got {len(load.indices)}") - - ranges = [] - for i in range(num_extents): - ranges.append(RegionRange(load.indices[i], region_call.args[2 + i])) - - return load.buffer, ranges, access_mask - - -@tir.functor.visitor -class _DmaCopyVisitor(PyStmtExprVisitor): - """Walk TIR and collect tl.dma_copy calls and their region arguments.""" - - def __init__(self): - super().__init__() - self.dma_copy_calls = [] - self.layout_map = {} - - def visit_block_(self, op: tir.Block) -> None: - if "layout_map" in op.annotations: - for key, layout in op.annotations["layout_map"].items(): - self.layout_map[key.name] = layout - self.visit_stmt(op.body) - - def visit_call_(self, op: tir.Call) -> None: - if hasattr(op, "op") and hasattr(op.op, "name") and op.op.name == "tl.dma_copy": - self.dma_copy_calls.append(op) - # Visit children - for arg in op.args: - self.visit_expr(arg) - - def visit_evaluate_(self, op: tir.Evaluate) -> None: - self.visit_expr(op.value) - - -TEST_CASES = [ - # (M, N, block_M, block_N) - (128, 128, 32, 32), - (256, 256, 64, 64), - (128, 256, 32, 
64), -] - - -@pytest.mark.parametrize("M, N, block_M, block_N", TEST_CASES) -def test_tilelang_dma_copy(M, N, block_M, block_N): - target = tvm.target.Target(SUNMMIO_TARGET_DESC) - mod = simple_copy_kernel(M, N, block_M, block_N) - - with tvm.target.Target(target): - mod = tvm.tir.transform.BindTarget(target)(mod) - mod = tl.transform.LayoutInference()(mod) - mod = tl.transform.LowerTileOp()(mod) - print(mod) - - # Walk the lowered IR and find dma_copy calls - visitor = _DmaCopyVisitor() - func = mod["main"] - visitor.visit_stmt(func.body) - - # Verify that at least one tl.dma_copy call was emitted - assert len(visitor.dma_copy_calls) > 0, ("Expected at least one tl.dma_copy call in lowered IR") - - for call in visitor.dma_copy_calls: - # dma_copy should have exactly 2 arguments (src_region, dst_region) - assert len(call.args) == 2, (f"Expected 2 args for dma_copy, got {len(call.args)}") - - # Each argument should be a tl.tileop.region Call - for i, arg in enumerate(call.args): - assert isinstance(arg, - tir.Call), (f"dma_copy arg[{i}] should be a Call, got {type(arg)}") - assert hasattr(arg.op, "name") and arg.op.name == "tl.tileop.region", ( - f"dma_copy arg[{i}] should be tl.tileop.region, got {arg.op.name}") - - # --- Normalize regions back to buffer metadata --- - src_buf, src_ranges, src_mask = normalize_region(call.args[0]) - dst_buf, dst_ranges, dst_mask = normalize_region(call.args[1]) - - # access_mask: 1=read for src, 2=write for dst - assert src_mask == 1, f"Source access_mask should be 1 (read), got {src_mask}" - assert dst_mask == 2, f"Destination access_mask should be 2 (write), got {dst_mask}" - - # Buffer dtype - assert src_buf.dtype == "float16", f"Source dtype should be float16, got {src_buf.dtype}" - assert dst_buf.dtype == "float16", f"Dest dtype should be float16, got {dst_buf.dtype}" - - # Buffer shapes: both stay 2D (no flattening on SUNMMIO) - assert len(src_buf.shape) == 2, f"Source buffer should be 2D, got {len(src_buf.shape)}D" - assert 
int(src_buf.shape[0]) == M, f"Source shape[0] should be {M}, got {src_buf.shape[0]}" - assert int(src_buf.shape[1]) == N, f"Source shape[1] should be {N}, got {src_buf.shape[1]}" - - assert len(dst_buf.shape) == 2, f"Dest buffer should be 2D, got {len(dst_buf.shape)}D" - assert int(dst_buf.shape[0]) == block_M, ( - f"Dest shape[0] should be {block_M}, got {dst_buf.shape[0]}") - assert int(dst_buf.shape[1]) == block_N, ( - f"Dest shape[1] should be {block_N}, got {dst_buf.shape[1]}") - - # Buffer scope - src_scope = src_buf.scope() - assert src_scope == "" or src_scope == "global", ( - f"Source buffer should be in global scope, got '{src_scope}'") - dst_scope = dst_buf.scope() - assert "shared" in dst_scope, ( - f"Destination buffer should be in shared scope, got '{dst_scope}'") - - # Region extents match block dimensions - assert len(src_ranges) == 2 - src_extent_0 = int(src_ranges[0].extent) - src_extent_1 = int(src_ranges[1].extent) - assert src_extent_0 == block_M, ( - f"Source extent[0] should be {block_M}, got {src_extent_0}") - assert src_extent_1 == block_N, ( - f"Source extent[1] should be {block_N}, got {src_extent_1}") - - assert len(dst_ranges) == 2 - dst_extent_0 = int(dst_ranges[0].extent) - dst_extent_1 = int(dst_ranges[1].extent) - assert dst_extent_0 == block_M, (f"Dest extent[0] should be {block_M}, got {dst_extent_0}") - assert dst_extent_1 == block_N, (f"Dest extent[1] should be {block_N}, got {dst_extent_1}") - - -GEMM_TEST_CASES = [ - # (M, N, K, block_M, block_N, block_K) - (128, 128, 128, 32, 32, 32), - (128, 128, 128, 64, 64, 64), -] - - -@pytest.mark.parametrize("M, N, K, block_M, block_N, block_K", GEMM_TEST_CASES) -def test_tilelang_dma_copy_layout_query(M, N, K, block_M, block_N, block_K): - """Verify that after LayoutInference, the layout_map annotation is populated - for shared buffers, and that a downstream pass can look up layouts by buffer. - - NOTE: This test only checks layout annotations after LayoutInference. 
- LowerTileOp is not called here because gemm does not yet have a SUNMMIO - lowering path. The dma_copy lowering for copy ops is verified separately - in test_tilelang_dma_copy above. - """ - env.TILELANG_USE_GEMM_V1 = 0 - target = tvm.target.Target(SUNMMIO_TARGET_DESC) - mod = gemm_copy_kernel(M, N, K, block_M, block_N, block_K) - - with tvm.target.Target(target): - mod = tvm.tir.transform.BindTarget(target)(mod) - mod = tl.transform.LayoutInference()(mod) - - # After LayoutInference but before LowerTileOp, the layout_map - # annotation is present on the Block node. - visitor = _DmaCopyVisitor() - visitor.visit_stmt(mod["main"].body) - assert len(visitor.layout_map) > 0, ("Expected layout_map annotation after LayoutInference") - - # The shared buffers should have layout entries - assert "A_shared" in visitor.layout_map, ( - f"Expected A_shared in layout_map, got keys: {list(visitor.layout_map.keys())}") - layout_a = visitor.layout_map["A_shared"] - input_shape_a = layout_a.input_size - assert len(input_shape_a) == 2 - assert int(input_shape_a[0]) == block_M - assert int(input_shape_a[1]) == block_K - - assert "B_shared" in visitor.layout_map - layout_b = visitor.layout_map["B_shared"] - input_shape_b = layout_b.input_size - assert len(input_shape_b) == 2 - assert int(input_shape_b[0]) == block_K - assert int(input_shape_b[1]) == block_N - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_infer_sram_scope.py b/testing/python/transform/test_tilelang_transform_sunmmio_infer_sram_scope.py similarity index 100% rename from testing/python/transform/test_tilelang_transform_infer_sram_scope.py rename to testing/python/transform/test_tilelang_transform_sunmmio_infer_sram_scope.py diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 41da8ab0a..d7fe90aac 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -54,7 +54,7 @@ def is_shared(buffer: Buffer | BufferLoad | 
BufferRegion, allow_dynamic: bool = """ buffer = _get_buffer(buffer) conditions = [False] - conditions.append(buffer.scope() == "shared") + conditions.append(buffer.scope().startswith("shared")) if allow_dynamic: conditions.append(is_shared_dynamic(buffer)) return any(conditions) @@ -71,7 +71,7 @@ def is_shared_dynamic(buffer: Buffer | BufferLoad | BufferRegion) -> bool: bool: True if the buffer is in dynamic shared memory, False otherwise. """ buffer = _get_buffer(buffer) - return buffer.scope() == "shared.dyn" + return buffer.scope().startswith("shared") and buffer.scope().endswith(".dyn") def is_tensor_memory(buffer: Buffer | BufferLoad | BufferRegion) -> bool: From d802a39eeb0f1adcf1ed4a373d4bf970401785ec Mon Sep 17 00:00:00 2001 From: Jiaqi Guo Date: Wed, 11 Feb 2026 16:23:13 +0800 Subject: [PATCH 3/6] Remove duplicated check --- src/op/copy.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index a26205e63..a8b442b6c 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -958,11 +958,6 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; - // SUNMMIO: emit a dma_copy intrinsic instead of GPU-style lowering - if (TargetIsSunmmio(target)) { - return LowerSunmmioDmaCopy(T, analyzer); - } - using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); bool disable_tma_lower = From 1c11fe7897bfc2f3139186182201d2aec3008288 Mon Sep 17 00:00:00 2001 From: Jiaqi Guo Date: Wed, 11 Feb 2026 17:09:56 +0800 Subject: [PATCH 4/6] Update comment & remove unnecessary builtins --- src/op/builtin.h | 91 ------- src/op/copy.cc | 519 +------------------------------------- tilelang/language/copy.py | 7 + 3 files changed, 13 insertions(+), 604 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index 97b631e17..3779a8562 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -188,97 +188,6 @@ TVM_DLL 
const Op &get_mbarrier(); */ TVM_DLL const Op &tma_load(); -/*! - * \brief Perform a DMA load operation from source memory to destination memory. - * - * This function describes a DMA-based tensor copy with explicit shape, layout, - * memory scope. It is typically used to lower a high-level - * tensor copy into a hardware-specific DMA instruction. - * - * The source and destination tensors are described in terms of: - * - data type - * - rank and logical shape - * - layout (input shape + forward index), The T.Layout type is ObjectRef, - * which is not suitable for backend parsing, so it's two members are extracted: - * input shape and forward index, which are both Array - * - memory scope - * - * A sub-region of the source tensor can be copied by specifying the coordinate - * offset (`coord`) relative to the source base address. - * - * Example: - * For a 3D tensor A: Tensor(128, 256, 512), copying - * A[32:64, 128:192, 0:256] - * then: - * src_rank = 3 - * src_shape = [128, 256, 512] - * coord = [32, 128, 0] - * - * \param data_type - * Element data type of the tensor (e.g. float32, float16). - * - * \param src_rank - * Rank (number of dimensions) of the source tensor. - * - * \param src_shape - * Logical shape of the source tensor. - * For example, Tensor(128, 256, 512) -> [128, 256, 512]. - * - * \param src_input_size - * Input shape of the source layout, retrievable via Layout::getInputShape(). - * For a row-major 3D tensor, this is identical to src_shape. - * - * \param src_forward - * Forward index mapping of the source layout, retrievable via - * Layout::GetForwardIndex(). - * For a row-major layout of Tensor(128, 256, 512), - * this is [256 * 512, 512, 1]. - * - * \param src_scope - * Memory scope of the source tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * - * \param dst_rank - * Rank (number of dimensions) of the destination tensor. - * - * \param dst_shape - * Logical shape of the destination tensor. 
- * - * \param dst_input_size - * Input shape of the destination layout, retrievable via - * Layout::getInputShape(). - * - * \param dst_forward - * Forward index mapping of the destination layout, retrievable via - * Layout::GetForwardIndex(). - * - * \param dst_scope - * Memory scope of the destination tensor. - * Examples: "global", "shared.asram", "shared.wsram", "shared.rsram". - * - * \param src_addr - * Base address of the source tensor in memory . - * - * \param coord - * Coordinate offset specifying the starting point of the copy in the source - * tensor. Its length must equal src_rank. - * - * \param dst_addr - * Base address of the destination tensor in memory . - * - * \note - * Out-of-bound fill policies are currently not supported. - */ -TVM_DLL const Op &dma_load(); - -/*! - * \brief Perform a DMA store operation from source memory to destination - * memory. see dma_load for details. - * - * - */ -TVM_DLL const Op &dma_store(); - /*! * \brief Perform a DMA copy operation preserving full buffer region semantics. * diff --git a/src/op/copy.cc b/src/op/copy.cc index a8b442b6c..96bef70bf 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -829,8 +829,8 @@ bool CopyNode::CheckTMemStore(Target target) const { } /** - * @brief Determine whether this CopyNode can be lowered to a DMA Load - * instruction for Sunmmio target. + * @brief Determine whether this CopyNode can be lowered to a DMA Copy + * Intrinsic for Sunmmio target. * * The function returns true when all of the following hold: * - the target architecture advertises DMA support; @@ -842,8 +842,8 @@ bool CopyNode::CheckTMemStore(Target target) const { * copy). * * - * @param target The compilation target to query for dma load support. - * @return true if the copy can be implemented as a DMA Load; false + * @param target The compilation target to query for dma copy support. + * @return true if the copy can be implemented as a DMA Copy; false * otherwise. 
*/ bool CopyNode::CheckSunmmioDMACopy(Target target) const { @@ -949,7 +949,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, * determined copy instruction type: * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions * - LDSM/STSM: Uses matrix load/store instructions for tensor cores - * - DMA Load/Store: Sunmmio specified instructions for copy + * - DMA Copy: Sunmmio specified instructions for copy * - Normal: Uses standard load/store operations with loop transformations * \param T LowerArgs containing target and layout map. * \param analyzer Arithmetic analyzer for simplification. @@ -986,7 +986,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return LowerNormalCopy(T, analyzer); } else if (copy_inst == CopyInst::kSunmmioDMACopy) { auto dma_copy = LowerSunmmioDmaCopy(T, analyzer); - ICHECK(dma_copy.defined()) << "Failed to lower dma load/store"; + ICHECK(dma_copy.defined()) << "Failed to lower dma copy"; return dma_copy; } else { LOG(FATAL) << "Unsupported copy inst " << static_cast(copy_inst); @@ -1685,510 +1685,3 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); } else if (StructuralEqual()( shared_layout, - makeHalfBankSwizzleLayout(*stride, *continuous, - shared_tensor->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); - } else if (StructuralEqual()( - shared_layout, - makeFullBankSwizzleLayout(*stride, *continuous, - shared_tensor->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); - } else if (StructuralEqual()( - shared_layout, - makeGemmABLayoutPadded(*stride, *continuous, - shared_tensor->dtype.bits()))) { - LOG(WARNING) << "Bulk copy cannot support a padded layout for src: " - << src->name << ", dst: " << dst->name - << ", fallback to normal copy"; - return LowerNormalCopy(T, analyzer); - } else { - LOG(WARNING) << "Came across unsupported swizzle 
layout for src: " - << src->name << ", dst: " << dst->name - << ", fallback to normal copy"; - return LowerNormalCopy(T, analyzer); - } - } - - auto inner_box_dim = as_const_int(desc.smem_box[0]); - if (inner_box_dim == nullptr) { - LOG(WARNING) << "inner_box_dim " << desc.smem_box[0] - << " can only be a constant integer for TMA bulk copy, " - "fallback to normal copy"; - return LowerNormalCopy(T, analyzer); - } - int instruction_dim = *inner_box_dim; - if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_64B)) { - instruction_dim = 64 / src->dtype.bytes(); - } else if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_128B)) { - instruction_dim = 128 / src->dtype.bytes(); - } - if (instruction_dim > 256) { - // smem_box dim must be in [0, 256] - // if is 512, we need to split the copy into two parts - ICHECK((*inner_box_dim) % 256 == 0) - << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256"; - instruction_dim = 256; - } - ICHECK((*inner_box_dim) % instruction_dim == 0) - << "inner_box_dim: " << *inner_box_dim - << " is not divisible by instruction_dim: " << instruction_dim; - desc.smem_box.Set(0, PrimExpr(instruction_dim)); - - int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes(); - - // Check inner_box_dim_ for each swizzle type in a cleaner way - struct SwizzleCheck { - int swizzle; - int max_dim; - }; - static const std::vector swizzle_checks = { - {static_cast(CU_TENSOR_MAP_SWIZZLE_32B), 32}, - {static_cast(CU_TENSOR_MAP_SWIZZLE_64B), 64}, - {static_cast(CU_TENSOR_MAP_SWIZZLE_128B), 128}, - }; - for (const auto &check : swizzle_checks) { - if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) { - LOG(WARNING) << "TMA bulk copy cannot support a swizzled global layout " - "with inner_box_dim_ > " - << check.max_dim << ", will be fallback to normal copy"; - return LowerNormalCopy(T, analyzer); - } - } - - Call create_descriptor = - Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); - - Array 
args; - args.reserve(desc.rank + 4); - args.push_back(create_descriptor); - if (is_load) - args.push_back(0); // mbarrier id placeholder - auto op = is_load ? tma_load() : tma_store(); - - Stmt tma_copy; - PrimExpr total_elements = 1; - for (auto e : desc.smem_box) - total_elements *= e; - - if ((*inner_box_dim) != instruction_dim) { - Var loop_var("i"); - int loop_extent = (*inner_box_dim) / instruction_dim; - - PrimExpr shared_addr = shared_tensor.access_ptr( - is_load ? 2 : 1, DataType::Handle(), 1, - shared_offset + total_elements * loop_var, total_elements); - args.push_back(shared_addr); - global_coords.Set(0, global_coords[0] + instruction_dim * loop_var); - for (auto coord : global_coords) - args.push_back(coord); - int need_reduce = 0; - if (!is_load) - args.push_back(need_reduce); - args.push_back(this->eviction_policy); - tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, - Evaluate(Call(DataType::Handle(), op, args))); - } else { - PrimExpr shared_addr = shared_tensor.access_ptr( - is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements); - args.push_back(shared_addr); - for (auto coord : global_coords) - args.push_back(coord); - int need_reduce = 0; - if (!is_load) - args.push_back(need_reduce); - args.push_back(this->eviction_policy); - tma_copy = Evaluate(Call(DataType::Handle(), op, args)); - } - tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); - - return tma_copy; -} - -Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, - CopyInst copy_inst) const { - ICHECK(copy_inst == CopyInst::kBulkLoad1D || - copy_inst == CopyInst::kBulkStore1D); - - // Add 1D TMA copy when the global and shared memory is contiguous - // Check if shared_tensor->name is present in T.buffer_var_gemm - // (Array) to avoid use 1D TMA copy for swizzled layout - bool is_load = copy_inst == CopyInst::kBulkLoad1D; - auto shared_range = is_load ? dst_range : src_range; - auto global_range = is_load ? 
src_range : dst_range; - auto shared_tensor = is_load ? dst : src; - auto global_tensor = is_load ? src : dst; - - PrimExpr shared_elements = 1; - for (size_t i = 0; i < shared_range.size(); i++) { - shared_elements *= shared_range[i]->extent; - } - - std::vector shared_strides; - PrimExpr shared_stride = 1; - for (size_t i = 0; i < shared_tensor->shape.size(); i++) { - auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1]; - shared_strides.insert(shared_strides.begin(), shared_stride); - shared_stride *= s; - } - - Array shared_indices; - for (auto r : shared_range) - shared_indices.push_back(r->min); - - Array global_indices; - for (auto r : global_range) { - global_indices.push_back(r->min); - } - std::vector global_strides; - PrimExpr global_stride = 1; - for (size_t i = 0; i < global_tensor->shape.size(); i++) { - auto s = global_tensor->shape[global_tensor->shape.size() - i - 1]; - global_strides.insert(global_strides.begin(), global_stride); - global_stride *= s; - } - - PrimExpr global_offset = 0; - for (size_t i = 0; i < global_indices.size(); i++) { - global_offset += global_indices[i] * global_strides[i]; - } - - PrimExpr shared_offset = 0; - for (size_t i = 0; i < shared_indices.size(); i++) { - shared_offset += shared_indices[i] * shared_strides[i]; - } - - PrimExpr elements = analyzer->Simplify(shared_elements); - PrimExpr shared_addr = shared_tensor.access_ptr( - is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements); - PrimExpr global_addr = global_tensor.access_ptr( - is_load ? 
1 : 2, DataType::Handle(), 1, global_offset, elements); - Stmt tma_copy; - if (is_load) { - // the zero is a placeholder for mbarrier ids - tma_copy = Evaluate( - Call(DataType::Handle(), tma_load(), - {shared_addr, global_addr, 0, - elements * shared_tensor->dtype.bytes(), this->eviction_policy})); - } else { - int need_reduce = 0; - tma_copy = Evaluate( - Call(DataType::Handle(), tma_store(), - {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(), - need_reduce, this->eviction_policy})); - } - tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); - return tma_copy; -} -/*! - * \brief Encode the TMA descriptor into an array of PrimExpr. - * This function serializes the TMA descriptor fields into a format suitable for - * passing to the create_tma_descriptor() builtin function. The encoding follows - * the expected argument order for the TMA descriptor creation. - * \return Array of PrimExpr representing the encoded TMA descriptor. - */ -Array TMADesc::EncodeCallArgs() const { - Array args; - args.reserve(rank * 4 + 7); - - args.push_back(data_type); - args.push_back(static_cast(rank)); - args.push_back(global_addr); - for (auto e : global_shape) - args.push_back(e); - for (auto e : global_stride) - args.push_back(e); - for (auto e : smem_box) - args.push_back(e); - for (auto e : smem_stride) - args.push_back(e); - args.push_back(interleave); - args.push_back(swizzle); - args.push_back(l2_promotion); - args.push_back(oob_fill); - - return args; -} - -/** - * @brief Construct a Conv2DIm2ColOp node. - * - * Initializes a Conv2DIm2ColOpNode from raw TL-call arguments and a buffer map. 
- * The constructor extracts source and destination Buffers from vmap and reads - * convolution parameters encoded in args: - * - args[0]: source tensor access pointer - * - args[1]: destination tensor access pointer - * - args[2]: nhw_step (PrimExpr) - * - args[3]: c_step (PrimExpr) - * - args[4]: kernel (IntImm) - * - args[5]: stride (IntImm) - * - args[6]: dilation (IntImm) - * - args[7]: padding (IntImm) - * - args[8]: eviction_policy (IntImm) - * - * The created node stores these values (src, dst, nhw_step, c_step, kernel, - * stride, dilation, padding, eviction_policy) for later lowering to TMA-based - * GPU intrinsics. - * - * @param args Array of PrimExpr TL-call arguments (see list above). - */ -Conv2DIm2ColOp::Conv2DIm2ColOp(Array args) { - ObjectPtr node = - tvm::ffi::make_object(); - node->srcRegion_ = NormalizeToBufferRegion(args[0]); - node->dstRegion_ = NormalizeToBufferRegion(args[1]); - node->src_ = node->srcRegion_->buffer; - node->dst_ = node->dstRegion_->buffer; - node->nhw_step_ = args[2]; - node->c_step_ = args[3]; - node->kernel_ = args[4].as().value()->value; - node->stride_ = args[5].as().value()->value; - node->dilation_ = args[6].as().value()->value; - node->padding_ = args[7].as().value()->value; - node->eviction_policy_ = args[8].as().value()->value; - data_ = std::move(node); -} - -/** - * @brief Create a shallow copy of this Conv2DIm2ColOpNode wrapped as a - * TileOperator. - * - * Produces a new Conv2DIm2ColOp that owns a freshly allocated - * Conv2DIm2ColOpNode initialized from this node (member-wise copy). This is - * used to duplicate the operator node for compiler passes that require - * independent operator instances. - * - * @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode. - */ -TileOperator Conv2DIm2ColOpNode::Clone() const { - auto op = tvm::ffi::make_object(*this); - return Conv2DIm2ColOp(op); -} - -/** - * @brief Lower Conv2D im2col into a TMA-backed PTX sequence for Hopper. 
- * - * Constructs a TMA im2col descriptor from the Conv2DIm2ColOp parameters - * (kernel, stride, dilation, padding, channel/image tiling, dtype and shapes), - * emits a call to create the im2col descriptor, and returns a statement that - * invokes the corresponding tma_load_im2col builtin guarded to a single - * thread. The lowering assumes the destination resides in shared memory and the - * source in global memory and uses the provided layout information (when - * available) to select the appropriate shared-memory swizzle. - * - * Preconditions (checked with ICHECK): - * - Target is Hopper. - * - src.scope() == "global" and dst.scope() is "shared.dyn" or "shared". - * - src->shape has rank 4 and dst->shape has rank 2. - * - src and dst have the same dtype. - * - When a shared layout is supplied it must match a recognized TMA swizzle - * pattern (32B/64B/128B) or an ICHECK will fail. - * - * @param T Lowering context (target, layout map, thread_var, thread_bounds, - * buffer remapping, etc.). Used to fetch target/layout and to emit a - * thread-guarded TMA call. - * @param analyzer Arithmetic analyzer used to prove divisibility and simplify - * expressions required by descriptor construction. - * @return Stmt A TIR statement that performs a tma_load_im2col call wrapped in - * a thread-min guard (IfThenElse). The returned statement is ready - * to be inserted into the lowered TIR. 
- */ -Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, - arith::Analyzer *analyzer) const { - ICHECK(TargetIsHopper(T.target)); - ICHECK(src_.scope() == "global" && - (dst_.scope() == "shared.dyn" || dst_.scope() == "shared")); - ICHECK(src_->shape.size() == 4); - ICHECK(dst_->shape.size() == 2); - ICHECK(src_->dtype == dst_->dtype); - Layout shared_layout; - if (T.layout_map.count(dst_)) { - shared_layout = T.layout_map[dst_]; - } - - TMAIm2ColDesc desc; - desc.rank = src_->shape.size(); - desc.data_type = to_CUtensorMapDataType(src_->dtype); - desc.global_addr = src_->data; - desc.global_shape = ReverseArray(src_->shape); - - if (!src_->strides.empty()) { - desc.global_stride = ReverseArray(src_->strides); - } else { - // Create stride from shape - PrimExpr stride = 1; - desc.global_stride.reserve(desc.rank); - for (size_t i = 0; i < desc.rank; i++) { - desc.global_stride.push_back(stride); - stride *= desc.global_shape[i]; - } - } - // The first stride element should be 1 - ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; - // Make global stride in bytes - desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { - return cast(DataType::Int(64), e) * src_->dtype.bytes(); - }); - desc.elem_stride = {1, stride_, stride_, 1}; - desc.lower_corner = {-padding_, -padding_}; - desc.upper_corner = {-padding_, -padding_}; - desc.smem_box_pixel = Downcast(dst_->shape[0])->value; - desc.smem_box_channel = Downcast(dst_->shape[1])->value; - desc.l2_promotion = static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); - desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); - if (!shared_layout.defined()) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); - } else { - ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout."; - auto stride = as_const_int(shared_layout->InputShape()[0]); - auto continuous = as_const_int(shared_layout->InputShape()[1]); - ICHECK(stride != 
nullptr && continuous != nullptr); - - if (StructuralEqual()(shared_layout, - makeQuarterBankSwizzleLayout(*stride, *continuous, - dst_->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); - } else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout( - *stride, *continuous, - dst_->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); - } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( - *stride, *continuous, - dst_->dtype.bits()))) { - desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); - } else { - ICHECK(0) << "Cannot detect TMA layout."; - } - } - - Call create_desc = Call(DataType::Handle(), create_tma_im2col_descriptor(), - desc.EncodeCallArgs()); - - Array global_coords; // c, w, h, n - Array image_offset; // w, h - global_coords.reserve(desc.rank); - - ICHECK(analyzer->CanProveEqual( - FloorMod(desc.global_shape[0], desc.smem_box_channel), 0)) - << "Currently can only support divisible channel case"; - - global_coords.push_back( - FloorMod(c_step_ * desc.smem_box_channel, desc.global_shape[0])); - image_offset.push_back( - dilation_ * - FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]), - kernel_)); - image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel, - desc.global_shape[0] * kernel_)); - - PrimExpr h_dim = - FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1, - stride_) + - 1; - PrimExpr w_dim = - FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1, - stride_) + - 1; - global_coords.push_back( - stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_); - global_coords.push_back( - stride_ * - FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) - - padding_); - global_coords.push_back( - FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim)); - - Array args; - args.reserve(desc.rank * 2 + 2); - args.push_back(create_desc); - args.push_back(0); // mbar 
placeholder - auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_; - auto shared_addr = dst_buffer.access_ptr(2); - args.push_back(shared_addr); - for (auto coord : global_coords) - args.push_back(coord); - for (auto offset : image_offset) - args.push_back(offset); - args.push_back(this->eviction_policy_); - Stmt tma_copy = - IfThenElse(EQ(T.thread_var, T.thread_bounds->min), - Evaluate(Call(DataType::Handle(), tma_load_im2col(), args))); - return tma_copy; -} - -/*! - * \brief Encode the TMA im2col descriptor into an array of PrimExpr. - * This function serializes the TMA im2col descriptor fields for passing to the - * create_tma_im2col_descriptor() builtin function. It includes - * convolution-specific parameters like kernel size, stride, padding, and - * dilation in addition to standard tensor descriptor fields. \return Array of - * PrimExpr representing the encoded TMA im2col descriptor. - */ -Array TMAIm2ColDesc::EncodeCallArgs() const { - Array args; - args.reserve(rank * 5 + 5); - - args.push_back(data_type); - args.push_back(static_cast(rank)); - args.push_back(global_addr); - for (auto e : global_shape) - args.push_back(e); - for (auto e : global_stride) - args.push_back(e); - for (auto e : elem_stride) - args.push_back(e); - for (auto e : lower_corner) - args.push_back(e); - for (auto e : upper_corner) - args.push_back(e); - args.push_back(smem_box_pixel); - args.push_back(smem_box_channel); - args.push_back(interleave); - args.push_back(swizzle); - args.push_back(l2_promotion); - args.push_back(oob_fill); - - return args; -} - -// Register the Copy operation with TVM's TIR system -// This makes the copy operation available for use in TVM programs -// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, -// eviction_policy -// - Marked as opaque since it has side effects (memory writes) -TIR_REGISTER_TL_TILE_OP(Copy, copy) - .set_num_inputs(5) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - 
-
/**
 * @brief Layout inference hook for Conv2DIm2ColOpNode.
 *
 * This operator does not provide any layout inference; the function
 * intentionally returns an empty LayoutMap to indicate no layout suggestions.
 *
 * @param T Context for layout inference (ignored).
 * @param level Inference level (ignored).
 * @return LayoutMap An empty map.
 */
LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
                                          InferLevel level) const {
  return {};
}

// Register the Conv2DIm2Col operation with TVM's TIR system
// This operation performs im2col transformation for 2D convolutions using TMA
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
//   dilation, padding, eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col)
    .set_num_inputs(9)
    .set_attr("TCallEffectKind",
              Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK() {
  CopyNode::RegisterReflection();
  Conv2DIm2ColOpNode::RegisterReflection();
}
} // namespace tl
} // namespace tvm
diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py
index cabc4a3e4..ae2c46322 100644
--- a/tilelang/language/copy.py
+++ b/tilelang/language/copy.py
@@ -67,6 +67,13 @@ def get_extent(data):
     # Combine the nested if statements into a single if statement as suggested by SIM102
     if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and
             isinstance(dst, tir.BufferLoad)):
+        # FIXME
+        # For Sunmmio an invalid D<->D copy operation will enter here, for example:
+        # T.copy(C[by * block_M, ko * block_K], B[by * block_M, ko * block_K]) ->
+        # for ko in T.serial(4, annotations={"num_stages": 3}):
+        # B[by * 32, ko * 32] = C[by * 32, ko * 32]
+        # which causes an exception that can't be caught.
+ # # check if the case is like this: # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i] From 49314ae47f8956f73b55f705baf1a1bfd5d8c742 Mon Sep 17 00:00:00 2001 From: Jiaqi Guo Date: Wed, 11 Feb 2026 17:11:39 +0800 Subject: [PATCH 5/6] Fix copy.cc --- src/op/copy.cc | 509 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 508 insertions(+), 1 deletion(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 96bef70bf..76504f2de 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -949,7 +949,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, * determined copy instruction type: * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions * - LDSM/STSM: Uses matrix load/store instructions for tensor cores - * - DMA Copy: Sunmmio specified instructions for copy + * - DMA Load/Store: Sunmmio specified instructions for copy * - Normal: Uses standard load/store operations with loop transformations * \param T LowerArgs containing target and layout map. * \param analyzer Arithmetic analyzer for simplification. 
@@ -1685,3 +1685,510 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); } else if (StructuralEqual()( shared_layout, + makeHalfBankSwizzleLayout(*stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); + } else if (StructuralEqual()( + shared_layout, + makeFullBankSwizzleLayout(*stride, *continuous, + shared_tensor->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); + } else if (StructuralEqual()( + shared_layout, + makeGemmABLayoutPadded(*stride, *continuous, + shared_tensor->dtype.bits()))) { + LOG(WARNING) << "Bulk copy cannot support a padded layout for src: " + << src->name << ", dst: " << dst->name + << ", fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } else { + LOG(WARNING) << "Came across unsupported swizzle layout for src: " + << src->name << ", dst: " << dst->name + << ", fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + } + + auto inner_box_dim = as_const_int(desc.smem_box[0]); + if (inner_box_dim == nullptr) { + LOG(WARNING) << "inner_box_dim " << desc.smem_box[0] + << " can only be a constant integer for TMA bulk copy, " + "fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + int instruction_dim = *inner_box_dim; + if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_64B)) { + instruction_dim = 64 / src->dtype.bytes(); + } else if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_128B)) { + instruction_dim = 128 / src->dtype.bytes(); + } + if (instruction_dim > 256) { + // smem_box dim must be in [0, 256] + // if is 512, we need to split the copy into two parts + ICHECK((*inner_box_dim) % 256 == 0) + << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256"; + instruction_dim = 256; + } + ICHECK((*inner_box_dim) % instruction_dim == 0) + << "inner_box_dim: " << *inner_box_dim + << " is not divisible by instruction_dim: 
" << instruction_dim; + desc.smem_box.Set(0, PrimExpr(instruction_dim)); + + int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes(); + + // Check inner_box_dim_ for each swizzle type in a cleaner way + struct SwizzleCheck { + int swizzle; + int max_dim; + }; + static const std::vector swizzle_checks = { + {static_cast(CU_TENSOR_MAP_SWIZZLE_32B), 32}, + {static_cast(CU_TENSOR_MAP_SWIZZLE_64B), 64}, + {static_cast(CU_TENSOR_MAP_SWIZZLE_128B), 128}, + }; + for (const auto &check : swizzle_checks) { + if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) { + LOG(WARNING) << "TMA bulk copy cannot support a swizzled global layout " + "with inner_box_dim_ > " + << check.max_dim << ", will be fallback to normal copy"; + return LowerNormalCopy(T, analyzer); + } + } + + Call create_descriptor = + Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + + Array args; + args.reserve(desc.rank + 4); + args.push_back(create_descriptor); + if (is_load) + args.push_back(0); // mbarrier id placeholder + auto op = is_load ? tma_load() : tma_store(); + + Stmt tma_copy; + PrimExpr total_elements = 1; + for (auto e : desc.smem_box) + total_elements *= e; + + if ((*inner_box_dim) != instruction_dim) { + Var loop_var("i"); + int loop_extent = (*inner_box_dim) / instruction_dim; + + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 2 : 1, DataType::Handle(), 1, + shared_offset + total_elements * loop_var, total_elements); + args.push_back(shared_addr); + global_coords.Set(0, global_coords[0] + instruction_dim * loop_var); + for (auto coord : global_coords) + args.push_back(coord); + int need_reduce = 0; + if (!is_load) + args.push_back(need_reduce); + args.push_back(this->eviction_policy); + tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, + Evaluate(Call(DataType::Handle(), op, args))); + } else { + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 
2 : 1, DataType::Handle(), 1, shared_offset, total_elements); + args.push_back(shared_addr); + for (auto coord : global_coords) + args.push_back(coord); + int need_reduce = 0; + if (!is_load) + args.push_back(need_reduce); + args.push_back(this->eviction_policy); + tma_copy = Evaluate(Call(DataType::Handle(), op, args)); + } + tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + + return tma_copy; +} + +Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const { + ICHECK(copy_inst == CopyInst::kBulkLoad1D || + copy_inst == CopyInst::kBulkStore1D); + + // Add 1D TMA copy when the global and shared memory is contiguous + // Check if shared_tensor->name is present in T.buffer_var_gemm + // (Array) to avoid use 1D TMA copy for swizzled layout + bool is_load = copy_inst == CopyInst::kBulkLoad1D; + auto shared_range = is_load ? dst_range : src_range; + auto global_range = is_load ? src_range : dst_range; + auto shared_tensor = is_load ? dst : src; + auto global_tensor = is_load ? 
src : dst; + + PrimExpr shared_elements = 1; + for (size_t i = 0; i < shared_range.size(); i++) { + shared_elements *= shared_range[i]->extent; + } + + std::vector shared_strides; + PrimExpr shared_stride = 1; + for (size_t i = 0; i < shared_tensor->shape.size(); i++) { + auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1]; + shared_strides.insert(shared_strides.begin(), shared_stride); + shared_stride *= s; + } + + Array shared_indices; + for (auto r : shared_range) + shared_indices.push_back(r->min); + + Array global_indices; + for (auto r : global_range) { + global_indices.push_back(r->min); + } + std::vector global_strides; + PrimExpr global_stride = 1; + for (size_t i = 0; i < global_tensor->shape.size(); i++) { + auto s = global_tensor->shape[global_tensor->shape.size() - i - 1]; + global_strides.insert(global_strides.begin(), global_stride); + global_stride *= s; + } + + PrimExpr global_offset = 0; + for (size_t i = 0; i < global_indices.size(); i++) { + global_offset += global_indices[i] * global_strides[i]; + } + + PrimExpr shared_offset = 0; + for (size_t i = 0; i < shared_indices.size(); i++) { + shared_offset += shared_indices[i] * shared_strides[i]; + } + + PrimExpr elements = analyzer->Simplify(shared_elements); + PrimExpr shared_addr = shared_tensor.access_ptr( + is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, elements); + PrimExpr global_addr = global_tensor.access_ptr( + is_load ? 
1 : 2, DataType::Handle(), 1, global_offset, elements); + Stmt tma_copy; + if (is_load) { + // the zero is a placeholder for mbarrier ids + tma_copy = Evaluate( + Call(DataType::Handle(), tma_load(), + {shared_addr, global_addr, 0, + elements * shared_tensor->dtype.bytes(), this->eviction_policy})); + } else { + int need_reduce = 0; + tma_copy = Evaluate( + Call(DataType::Handle(), tma_store(), + {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(), + need_reduce, this->eviction_policy})); + } + tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); + return tma_copy; +} +/*! + * \brief Encode the TMA descriptor into an array of PrimExpr. + * This function serializes the TMA descriptor fields into a format suitable for + * passing to the create_tma_descriptor() builtin function. The encoding follows + * the expected argument order for the TMA descriptor creation. + * \return Array of PrimExpr representing the encoded TMA descriptor. + */ +Array TMADesc::EncodeCallArgs() const { + Array args; + args.reserve(rank * 4 + 7); + + args.push_back(data_type); + args.push_back(static_cast(rank)); + args.push_back(global_addr); + for (auto e : global_shape) + args.push_back(e); + for (auto e : global_stride) + args.push_back(e); + for (auto e : smem_box) + args.push_back(e); + for (auto e : smem_stride) + args.push_back(e); + args.push_back(interleave); + args.push_back(swizzle); + args.push_back(l2_promotion); + args.push_back(oob_fill); + + return args; +} + +/** + * @brief Construct a Conv2DIm2ColOp node. + * + * Initializes a Conv2DIm2ColOpNode from raw TL-call arguments and a buffer map. 
+ * The constructor extracts source and destination Buffers from vmap and reads + * convolution parameters encoded in args: + * - args[0]: source tensor access pointer + * - args[1]: destination tensor access pointer + * - args[2]: nhw_step (PrimExpr) + * - args[3]: c_step (PrimExpr) + * - args[4]: kernel (IntImm) + * - args[5]: stride (IntImm) + * - args[6]: dilation (IntImm) + * - args[7]: padding (IntImm) + * - args[8]: eviction_policy (IntImm) + * + * The created node stores these values (src, dst, nhw_step, c_step, kernel, + * stride, dilation, padding, eviction_policy) for later lowering to TMA-based + * GPU intrinsics. + * + * @param args Array of PrimExpr TL-call arguments (see list above). + */ +Conv2DIm2ColOp::Conv2DIm2ColOp(Array args) { + ObjectPtr node = + tvm::ffi::make_object(); + node->srcRegion_ = NormalizeToBufferRegion(args[0]); + node->dstRegion_ = NormalizeToBufferRegion(args[1]); + node->src_ = node->srcRegion_->buffer; + node->dst_ = node->dstRegion_->buffer; + node->nhw_step_ = args[2]; + node->c_step_ = args[3]; + node->kernel_ = args[4].as().value()->value; + node->stride_ = args[5].as().value()->value; + node->dilation_ = args[6].as().value()->value; + node->padding_ = args[7].as().value()->value; + node->eviction_policy_ = args[8].as().value()->value; + data_ = std::move(node); +} + +/** + * @brief Create a shallow copy of this Conv2DIm2ColOpNode wrapped as a + * TileOperator. + * + * Produces a new Conv2DIm2ColOp that owns a freshly allocated + * Conv2DIm2ColOpNode initialized from this node (member-wise copy). This is + * used to duplicate the operator node for compiler passes that require + * independent operator instances. + * + * @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode. + */ +TileOperator Conv2DIm2ColOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Conv2DIm2ColOp(op); +} + +/** + * @brief Lower Conv2D im2col into a TMA-backed PTX sequence for Hopper. 
+ * + * Constructs a TMA im2col descriptor from the Conv2DIm2ColOp parameters + * (kernel, stride, dilation, padding, channel/image tiling, dtype and shapes), + * emits a call to create the im2col descriptor, and returns a statement that + * invokes the corresponding tma_load_im2col builtin guarded to a single + * thread. The lowering assumes the destination resides in shared memory and the + * source in global memory and uses the provided layout information (when + * available) to select the appropriate shared-memory swizzle. + * + * Preconditions (checked with ICHECK): + * - Target is Hopper. + * - src.scope() == "global" and dst.scope() is "shared.dyn" or "shared". + * - src->shape has rank 4 and dst->shape has rank 2. + * - src and dst have the same dtype. + * - When a shared layout is supplied it must match a recognized TMA swizzle + * pattern (32B/64B/128B) or an ICHECK will fail. + * + * @param T Lowering context (target, layout map, thread_var, thread_bounds, + * buffer remapping, etc.). Used to fetch target/layout and to emit a + * thread-guarded TMA call. + * @param analyzer Arithmetic analyzer used to prove divisibility and simplify + * expressions required by descriptor construction. + * @return Stmt A TIR statement that performs a tma_load_im2col call wrapped in + * a thread-min guard (IfThenElse). The returned statement is ready + * to be inserted into the lowered TIR. 
+ */ +Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + ICHECK(TargetIsHopper(T.target)); + ICHECK(src_.scope() == "global" && + (dst_.scope() == "shared.dyn" || dst_.scope() == "shared")); + ICHECK(src_->shape.size() == 4); + ICHECK(dst_->shape.size() == 2); + ICHECK(src_->dtype == dst_->dtype); + Layout shared_layout; + if (T.layout_map.count(dst_)) { + shared_layout = T.layout_map[dst_]; + } + + TMAIm2ColDesc desc; + desc.rank = src_->shape.size(); + desc.data_type = to_CUtensorMapDataType(src_->dtype); + desc.global_addr = src_->data; + desc.global_shape = ReverseArray(src_->shape); + + if (!src_->strides.empty()) { + desc.global_stride = ReverseArray(src_->strides); + } else { + // Create stride from shape + PrimExpr stride = 1; + desc.global_stride.reserve(desc.rank); + for (size_t i = 0; i < desc.rank; i++) { + desc.global_stride.push_back(stride); + stride *= desc.global_shape[i]; + } + } + // The first stride element should be 1 + ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; + // Make global stride in bytes + desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { + return cast(DataType::Int(64), e) * src_->dtype.bytes(); + }); + desc.elem_stride = {1, stride_, stride_, 1}; + desc.lower_corner = {-padding_, -padding_}; + desc.upper_corner = {-padding_, -padding_}; + desc.smem_box_pixel = Downcast(dst_->shape[0])->value; + desc.smem_box_channel = Downcast(dst_->shape[1])->value; + desc.l2_promotion = static_cast(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); + desc.oob_fill = static_cast(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + desc.interleave = static_cast(CU_TENSOR_MAP_INTERLEAVE_NONE); + if (!shared_layout.defined()) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_NONE); + } else { + ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout."; + auto stride = as_const_int(shared_layout->InputShape()[0]); + auto continuous = as_const_int(shared_layout->InputShape()[1]); + ICHECK(stride != 
nullptr && continuous != nullptr); + + if (StructuralEqual()(shared_layout, + makeQuarterBankSwizzleLayout(*stride, *continuous, + dst_->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_32B); + } else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout( + *stride, *continuous, + dst_->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_64B); + } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( + *stride, *continuous, + dst_->dtype.bits()))) { + desc.swizzle = static_cast(CU_TENSOR_MAP_SWIZZLE_128B); + } else { + ICHECK(0) << "Cannot detect TMA layout."; + } + } + + Call create_desc = Call(DataType::Handle(), create_tma_im2col_descriptor(), + desc.EncodeCallArgs()); + + Array global_coords; // c, w, h, n + Array image_offset; // w, h + global_coords.reserve(desc.rank); + + ICHECK(analyzer->CanProveEqual( + FloorMod(desc.global_shape[0], desc.smem_box_channel), 0)) + << "Currently can only support divisible channel case"; + + global_coords.push_back( + FloorMod(c_step_ * desc.smem_box_channel, desc.global_shape[0])); + image_offset.push_back( + dilation_ * + FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]), + kernel_)); + image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel, + desc.global_shape[0] * kernel_)); + + PrimExpr h_dim = + FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1, + stride_) + + 1; + PrimExpr w_dim = + FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1, + stride_) + + 1; + global_coords.push_back( + stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_); + global_coords.push_back( + stride_ * + FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) - + padding_); + global_coords.push_back( + FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim)); + + Array args; + args.reserve(desc.rank * 2 + 2); + args.push_back(create_desc); + args.push_back(0); // mbar 
placeholder + auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_; + auto shared_addr = dst_buffer.access_ptr(2); + args.push_back(shared_addr); + for (auto coord : global_coords) + args.push_back(coord); + for (auto offset : image_offset) + args.push_back(offset); + args.push_back(this->eviction_policy_); + Stmt tma_copy = + IfThenElse(EQ(T.thread_var, T.thread_bounds->min), + Evaluate(Call(DataType::Handle(), tma_load_im2col(), args))); + return tma_copy; +} + +/*! + * \brief Encode the TMA im2col descriptor into an array of PrimExpr. + * This function serializes the TMA im2col descriptor fields for passing to the + * create_tma_im2col_descriptor() builtin function. It includes + * convolution-specific parameters like kernel size, stride, padding, and + * dilation in addition to standard tensor descriptor fields. \return Array of + * PrimExpr representing the encoded TMA im2col descriptor. + */ +Array TMAIm2ColDesc::EncodeCallArgs() const { + Array args; + args.reserve(rank * 5 + 5); + + args.push_back(data_type); + args.push_back(static_cast(rank)); + args.push_back(global_addr); + for (auto e : global_shape) + args.push_back(e); + for (auto e : global_stride) + args.push_back(e); + for (auto e : elem_stride) + args.push_back(e); + for (auto e : lower_corner) + args.push_back(e); + for (auto e : upper_corner) + args.push_back(e); + args.push_back(smem_box_pixel); + args.push_back(smem_box_channel); + args.push_back(interleave); + args.push_back(swizzle); + args.push_back(l2_promotion); + args.push_back(oob_fill); + + return args; +} + +// Register the Copy operation with TVM's TIR system +// This makes the copy operation available for use in TVM programs +// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, +// eviction_policy +// - Marked as opaque since it has side effects (memory writes) +TIR_REGISTER_TL_TILE_OP(Copy, copy) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + 
+/** + * @brief Layout inference hook for Conv2DIm2ColOpNode. + * + * This operator does not provide any layout inference; the function + * intentionally returns an empty LayoutMap to indicate no layout suggestions. + * + * @param T Context for layout inference (ignored). + * @param level Inference level (ignored). + * @return LayoutMap An empty map. + */ +LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + return {}; +} + +// Register the Conv2DIm2Col operation with TVM's TIR system +// This operation performs im2col transformation for 2D convolutions using TMA +// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, +// dilation, padding, eviction_policy +// - Marked as opaque since it has side effects (memory writes) +TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col) + .set_num_inputs(9) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { + CopyNode::RegisterReflection(); + Conv2DIm2ColOpNode::RegisterReflection(); +} +} // namespace tl +} // namespace tvm From c03c8ccff8b92213297f6abaf6caf975424a1202 Mon Sep 17 00:00:00 2001 From: Jiaqi Guo Date: Wed, 11 Feb 2026 17:12:45 +0800 Subject: [PATCH 6/6] Improve comment consistency --- src/op/copy.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 76504f2de..f09b8a05f 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -949,7 +949,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, * determined copy instruction type: * - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions * - LDSM/STSM: Uses matrix load/store instructions for tensor cores - * - DMA Load/Store: Sunmmio specified instructions for copy + * - DMA copy: Sunmmio specified instructions for copy * - Normal: Uses standard load/store operations with loop transformations * \param T LowerArgs containing target and layout map. 
* \param analyzer Arithmetic analyzer for simplification.