diff --git a/src/op/builtin_dma.cc b/src/op/builtin_dma.cc
new file mode 100644
index 000000000..c332a78f2
--- /dev/null
+++ b/src/op/builtin_dma.cc
@@ -0,0 +1,41 @@
+/*
+ * \file tl/op/builtin_dma.cc
+ * \brief DMA builtin intrinsics for the SUNMMIO GPU.
+ *        Separated from the original TileLang code.
+ */
+
+#include "builtin_dma.h"
+
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/op_attr_types.h>
+
+#include "../target/cuda.h"
+#include "../target/utils.h"
+
+namespace tvm {
+namespace tl {
+
+#define TIR_DEFINE_TL_BUILTIN(OpName)                                          \
+  const Op &OpName() {                                                         \
+    static const Op &op = Op::Get("tl." #OpName);                              \
+    return op;                                                                 \
+  }                                                                            \
+  TVM_REGISTER_OP("tl." #OpName)                                               \
+      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
+
+TIR_DEFINE_TL_BUILTIN(create_dma_descriptor)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+} // namespace tl
+} // namespace tvm
diff --git a/src/op/builtin_dma.h b/src/op/builtin_dma.h
new file mode 100644
index 000000000..4f847e7fb
--- /dev/null
+++ b/src/op/builtin_dma.h
@@ -0,0 +1,52 @@
+/*
+ * \file tl/op/builtin_dma.h
+ * \brief DMA builtin intrinsics for the SUNMMIO GPU.
+ *        Separated from the original TileLang code.
+ */
+
+#ifndef TVM_TL_OP_BUILTIN_DMA_H_
+#define TVM_TL_OP_BUILTIN_DMA_H_
+
+#include "operator.h"
+
+#include <tvm/ir/op.h>
+
+namespace tvm {
+namespace tl {
+
+/*!
+ * \brief TVM intrinsic for DMA descriptor creation for a tiled load.
+ *
+ * CuTensorMap* create_dma_descriptor(data_type, rank, global_addr,
+ * global_shape..., global_stride..., smem_box..., smem_stride..., interleave,
+ * swizzle, l2_promotion, oob_fill)
+ */
+TVM_DLL const Op &create_dma_descriptor();
+
+/*!
+ * \brief TVM intrinsic for loading data described by a global tensor
+ * descriptor into shared memory via DMA.
+ */
+TVM_DLL const Op &dma_load();
+
+/*!
+ * \brief TVM intrinsic for storing data from shared memory to a global tensor
+ * descriptor via DMA.
+ */
+TVM_DLL const Op &dma_store();
+
+} // namespace tl
+} // namespace tvm
+
+#endif // TVM_TL_OP_BUILTIN_DMA_H_
\ No newline at end of file
diff --git a/testing/python/language/test_tilelang_mesh_language_DMA_descriptor.py b/testing/python/language/test_tilelang_mesh_language_DMA_descriptor.py
new file mode 100644
index 000000000..0338761b0
--- /dev/null
+++ b/testing/python/language/test_tilelang_mesh_language_DMA_descriptor.py
@@ -0,0 +1,88 @@
+"""
+Test for the create_dma_descriptor() intrinsic.
+
+To run this script successfully, you should:
+1. remove tilelang/language/dma.py and the line "from .dma import dma_load, dma_store  # noqa: F401" from tilelang/language/__init__.py;
+2. add
+       def create_dma_descriptor(*args):
+           return tir.call_intrin("handle", tir.op.Op.get("tl.create_dma_descriptor"), *args)
+   to tilelang/builtin.py and uncomment the dma_load and dma_store functions in the same file.
+""" + +from tilelang import tvm as tvm +import tilelang.language as T + + +def test_multi_version_buffer(M, N, K, dtype, block_M, block_N, block_K): + + @T.prim_func + def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + v = T.launch_thread("threadIdx.x", 128) + with T.block(""): + T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) + T.writes() + A_shared = T.alloc_buffer((1, 8, 256), "float16", scope="shared.dyn") + B_shared = T.alloc_buffer((1, 4, 512), "float16", scope="shared.dyn") + C_local = T.alloc_buffer((32,), scope="local") + for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(2): + C_local[i * 2 + vec] = T.float32(0) + for k in T.serial(16, annotations={"num_stages": T.int32(3)}): + if v == 0: + T.dma_load( + T.create_dma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, + 2, 0), 0, + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2), + k * 32, by * 64) + if v == 0: + T.dma_load( + T.create_dma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, + 2, 0), 0, + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2), + bx * 64, k * 32) + T.call_extern( + "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + + return before + + +expected_result = """# from tvm.script import tir as T + +@T.prim_func +def before(A_handle: T.handle, B_handle: T.handle): + A = T.match_buffer(A_handle, (512, 512), "float16", strides=(512, 1)) + B = T.match_buffer(B_handle, (512, 512), "float16", strides=(512, 1)) + # with T.block("root"): + v = T.launch_thread("blockIdx.x", 8) + v_1 = T.launch_thread("blockIdx.y", 8) + v_2 = T.launch_thread("threadIdx.x", 128) + with T.block(""): + T.reads(A[v_1 * 64, 0:481], B[0:481, v * 64]) + T.writes() + A_shared = T.alloc_buffer((1, 8, 256), "float16", scope="shared.dyn") + B_shared = T.alloc_buffer((1, 4, 512), "float16", scope="shared.dyn") + C_local = T.alloc_buffer((32,), scope="local") + for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): + for vec in T.vectorized(2): + C_local[i * 2 + vec] = T.float32(0.0) + for k in T.serial(16, annotations={"num_stages": 3}): + if v_2 == 0: + T.dma_load(T.create_dma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2), k * 32, v_1 * 64) + if v_2 == 0: + T.dma_load(T.create_dma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2), v * 64, k * 32) + T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))""" + +M = 512 +N = 512 +K = 512 +dtype = "float16" +block_M = 64 +block_N = 64 +block_K = 32 +test_kernel = test_multi_version_buffer(M, N, K, dtype, block_M, block_N, block_K) +test_kernel.show() +print(test_kernel.script() == expected_result) diff --git a/testing/python/language/test_tilelang_mesh_language_DMA_op.py 
b/testing/python/language/test_tilelang_mesh_language_DMA_op.py
new file mode 100644
index 000000000..19b2f150e
--- /dev/null
+++ b/testing/python/language/test_tilelang_mesh_language_DMA_op.py
@@ -0,0 +1,241 @@
+import tilelang.language as T
+
+# from mesh_layout import MeshLayout, ReplicationType
+from typing import Tuple
+
+
+class MeshTensorDescriptor:
+    """
+    MeshTensorDescriptor describes the layout of a MeshTensor at compile time.
+
+    Args:
+        global_shape: a tuple with an arbitrary number of elements (e.g. 1, 2, 3, or 4)
+        mesh_layout: a 2-element tuple, default value (4, 4)
+        axis_partitions: a tuple with one or two elements
+    """
+
+    def __init__(self, global_shape: Tuple[int, ...], mesh_layout: Tuple[int, int],
+                 axis_partitions: Tuple[int, ...]):
+        self.global_shape = global_shape
+        self.mesh_layout = mesh_layout
+        self.axis_partitions = axis_partitions
+        self.mesh_size = mesh_layout[0] * mesh_layout[1]
+
+    def get_local_shape(self) -> Tuple[int, ...]:
+        local_shape = list(self.global_shape)
+        if len(self.axis_partitions) == 1:
+            dim_idx = self.axis_partitions[0]
+            local_shape[dim_idx] = self.global_shape[dim_idx] // self.mesh_size
+        elif len(self.axis_partitions) == 2:
+            dim_idx_0 = self.axis_partitions[0]
+            dim_idx_1 = self.axis_partitions[1]
+            local_shape[dim_idx_0] = self.global_shape[dim_idx_0] // self.mesh_layout[0]
+            local_shape[dim_idx_1] = self.global_shape[dim_idx_1] // self.mesh_layout[1]
+        return tuple(local_shape)
+
+
+# def make_mesh_tensor_descriptor():
+
+
+def get_gpu_info():
+    return (4, 4)  # For example, a 4x4 core mesh
+
+
+def flashattn_fwd(batch, heads, seq_len, dim, block_M, block_N):
+
+    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    shape = [batch, seq_len, heads, dim]  # 32b 16h
+    dtype = "bfloat16"
+    accum_dtype = "float32"
+
+    current_mesh_layout = get_gpu_info()  # (4, 4) or (4, 3)...
+    shape = MeshTensorDescriptor(
+        global_shape=(batch, seq_len, heads, dim),
+        mesh_layout=current_mesh_layout,
+        axis_partitions=(0, 2)  # partition along the batch and heads dimensions
+    ).get_local_shape()
+    print(f"\n******Local shape on each core: {shape}*****\n")
+    batch, seq_len, heads, dim = shape
+
+    @T.prim_func
+    def flash_attention(
+            Q: T.Tensor((batch, seq_len, heads, dim), dtype),
+            K: T.Tensor((batch, seq_len, heads, dim), dtype),
+            V: T.Tensor((batch, seq_len, heads, dim), dtype),
+            Output: T.Tensor((batch, seq_len, heads, dim), dtype),
+    ):
+        # Launch a specialized T.Kernel with 3D mapping: (bx, by, bz)
+        # bx: block index in the sequence dimension
+        # by: block index in the heads dimension
+        # bz: block index in the batch dimension
+        # Assume each core is responsible for a block of size (block_M, dim) for Q and (block_N, dim) for K, V
+
+        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch) as (bx, by, bz):
+            # Allocate shared memory for Q, K, V to reduce global memory accesses
+            Q_shared = T.alloc_shared([block_M, dim], dtype)
+            K_shared = T.alloc_shared([block_N, dim], dtype)
+            V_shared = T.alloc_shared([block_N, dim], dtype)
+            # Allocate score/output accumulation buffers in shared memory
+            # acc_s: buffer to hold intermediate attention scores
+            acc_s = T.alloc_shared([block_M, block_N], accum_dtype)
+            # acc_s_cast: buffer for storing casted/adjusted scores
+            acc_s_cast = T.alloc_shared([block_M, block_N], dtype)
+            # acc_o: partial accumulation of the output
+            acc_o = T.alloc_shared([block_M, dim], accum_dtype)
+            # Buffers to track the per-row maximum score and related stats
+            scores_max = T.alloc_shared([block_M], accum_dtype)
+            scores_max_prev = T.alloc_shared([block_M], accum_dtype)
+            scores_scale = T.alloc_shared([block_M], accum_dtype)
+            scores_sum = T.alloc_shared([block_M], accum_dtype)
+            logsum =
T.alloc_shared([block_M], accum_dtype) + + # Copy a block of Q from global memory to Q_shared + T.dma_load(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + + # Initialize accumulators + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((bx + 1) * block_M, block_N) + + # Pipeline the loop to overlap copies/gemm stages + for k in T.Pipelined(loop_range, num_stages=3): + # Copy K block into shared memory + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + # Perform the Q*K^T multiplication, Here, transpose_B=True indicates that K_shared is transposed, + # policy=T.GemmWarpPolicy.FullRow means each warp is responsible for computing an entire row + # of acc_s, and the resulting acc_s is retained in registers. + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Copy V block into shared memory + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + for i, j in T.Parallel(block_M, dim): + acc_s[i, j] *= scale + + # Save old scores_max, then reset scores_max + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # Compute the maximum value per row on dimension 1 (block_N) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # Compute the factor by which we need to rescale previous partial sums + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + + # Rescale the partial output accumulation to keep exponents consistent + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + # Exponentiate (scores - max) for the new block + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + + # Make a cast of acc_s to fp16 for the next GEMM + T.copy(acc_s, acc_s_cast) + + # Multiply the attention acc_s_cast by V and add to partial output (acc_o) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + # Update the "logsum" tracker with the newly accumulated sum + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + # Final step: divide each partial output by logsum (completing the softmax) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + + # Write back the final output block from acc_o to the Output buffer + # T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + T.dma_store(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + + return flash_attention + + +flashattn = flashattn_fwd(batch=32, heads=16, seq_len=4096, dim=64, block_M=64, block_N=64) +# flashattn.show() + +expected_print = """# from tvm.script import tir as T + +@T.prim_func +def flash_attention(Q_handle: T.handle, K_handle: T.handle, V_handle: T.handle, Output_handle: T.handle): + Q = T.match_buffer(Q_handle, (8, 4096, 4, 64), "bfloat16", strides=(1048576, 256, 64, 1)) + K = T.match_buffer(K_handle, (8, 4096, 4, 64), "bfloat16", strides=(1048576, 256, 64, 1)) + V = T.match_buffer(V_handle, (8, 4096, 4, 64), "bfloat16", strides=(1048576, 256, 64, 1)) + Output = T.match_buffer(Output_handle, (8, 4096, 4, 64), "bfloat16", strides=(1048576, 256, 64, 1)) + # with T.block("root"): + bx = 
T.launch_thread("blockIdx.x", 64) + by = T.launch_thread("blockIdx.y", 4) + bz = T.launch_thread("blockIdx.z", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + loop_range = T.int32() + T.reads(Q[bz, bx * 64, by, 0], K[bz, 0:loop_range * 64 - 63, by, 0], V[bz, 0:loop_range * 64 - 63, by, 0], Output[bz, bx * 64, by, 0]) + T.writes() + Q_shared = T.alloc_buffer((64, 64), "bfloat16", scope="shared.dyn") + K_shared = T.alloc_buffer((64, 64), "bfloat16", scope="shared.dyn") + V_shared = T.alloc_buffer((64, 64), "bfloat16", scope="shared.dyn") + acc_s = T.alloc_buffer((64, 64), scope="shared.dyn") + acc_s_cast = T.alloc_buffer((64, 64), "bfloat16", scope="shared.dyn") + acc_o = T.alloc_buffer((64, 64), scope="shared.dyn") + scores_max = T.alloc_buffer((64,), scope="shared.dyn") + scores_max_prev = T.alloc_buffer((64,), scope="shared.dyn") + scores_scale = T.alloc_buffer((64,), scope="shared.dyn") + scores_sum = T.alloc_buffer((64,), scope="shared.dyn") + logsum = T.alloc_buffer((64,), scope="shared.dyn") + acc_s_frag = T.alloc_buffer((64, 64), scope="local.fragment") + scores_max_frag = T.alloc_buffer((64,), scope="local.fragment") + acc_s_frag_1 = T.alloc_buffer((64, 64), scope="local.fragment") + scores_sum_frag = T.alloc_buffer((64,), scope="local.fragment") + T.dma_load(T.region(Q[bz, bx * 64, by, 0], 1, 1, 64, 1, 64), T.region(Q_shared[0, 0], 2, 64, 64), 0) + T.fill(T.region(acc_o[0, 0], 2, 64, 64), 0) + T.fill(T.region(logsum[0], 2, 64), 0) + T.fill(T.region(scores_max[0], 2, 64), T.infinity("float32") * T.float32(-1.0)) + with T.LetStmt(((bx + 1) * 64 + 64 - 1) // 64, var=loop_range): + for k in T.serial(loop_range, annotations={"num_stages": 3}): + T.copy(T.region(K[bz, k * 64, by, 0], 1, 1, 64, 1, 64), T.region(K_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0) + for i in T.parallel(64): + for j in T.parallel(64): + acc_s[i, j] = T.if_then_else(bx * 64 + i >= k * 64 + j, T.float32(0.0), T.infinity("float32") * T.float32(-1.0)) + T.gemm_py(T.region(Q_shared[0, 0], 1, 64, 64), T.region(K_shared[0, 0], 1, 64, 64), T.region(acc_s[0, 0], 3, 64, 64), T.bool(False), T.bool(True), 64, 64, 64, 1, T.bool(False), 64, 64, 0, 0, 1, 0, T.uint32(0), 0, 0) + T.copy(T.region(V[bz, k * 64, by, 0], 1, 1, 64, 1, 64), T.region(V_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0) + for i in T.parallel(64): + for j in T.parallel(64): + acc_s[i, j] = acc_s[i, j] * T.float32(0.18033688) + T.copy(T.region(scores_max[0], 1, 64), T.region(scores_max_prev[0], 2, 64), -1, T.bool(False), 0) + T.fill(T.region(scores_max[0], 2, 64), T.infinity("float32") * T.float32(-1.0)) + T.copy(T.region(acc_s[0, 0], 1, 64, 64), T.region(acc_s_frag[0, 0], 2, 64, 64), -1, T.bool(False), 0) + T.reduce(acc_s_frag[0:64, 0:64], scores_max_frag[0:64], "max", 1, T.bool(False)) + T.copy(T.region(scores_max_frag[0], 1, 64), T.region(scores_max[0], 2, 64), -1, T.bool(False), 0) + for i in T.parallel(64): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.parallel(64): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + for i in T.parallel(64): + for j in T.parallel(64): + acc_o[i, j] = acc_o[i, j] * scores_scale[i] + for i in T.parallel(64): + for j in T.parallel(64): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + T.copy(T.region(acc_s[0, 0], 1, 64, 64), T.region(acc_s_cast[0, 0], 2, 64, 64), -1, T.bool(False), 0) + T.gemm_py(T.region(acc_s_cast[0, 0], 1, 64, 64), T.region(V_shared[0, 0], 1, 64, 
64), T.region(acc_o[0, 0], 3, 64, 64), T.bool(False), T.bool(False), 64, 64, 64, 1, T.bool(False), 64, 64, 0, 0, 1, 0, T.uint32(0), 0, 0) + T.copy(T.region(acc_s[0, 0], 1, 64, 64), T.region(acc_s_frag_1[0, 0], 2, 64, 64), -1, T.bool(False), 0) + T.reduce(acc_s_frag_1[0:64, 0:64], scores_sum_frag[0:64], "sum", 1, T.bool(True)) + T.copy(T.region(scores_sum_frag[0], 1, 64), T.region(scores_sum[0], 2, 64), -1, T.bool(False), 0) + for i in T.parallel(64): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i in T.parallel(64): + for j in T.parallel(64): + acc_o[i, j] = acc_o[i, j] / logsum[i] + T.dma_store(T.region(acc_o[0, 0], 1, 64, 64), T.region(Output[bz, bx * 64, by, 0], 2, 1, 64, 1, 64), 0)""" + +flashattn.show() +print(flashattn.script() == expected_print) diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 9eae861eb..212ad8557 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -106,6 +106,8 @@ from .annotate_mesh_tensor import mesh_tensor_functions # noqa: F401 +from .dma import dma_load, dma_store # noqa: F401 + def import_source(source: str | None = None): # source is the source code to be imported return block_attr({"pragma_import_c": source}) if source is not None else None \ No newline at end of file diff --git a/tilelang/language/dma.py b/tilelang/language/dma.py new file mode 100644 index 000000000..dc5677535 --- /dev/null +++ b/tilelang/language/dma.py @@ -0,0 +1,102 @@ +# tilelang/lang/dma.py + +from __future__ import annotations +from typing import Literal +from tilelang import language as T +from tvm import tir +from tilelang.utils.language import ( + get_buffer_region_from_load, + legalize_pairwise_extents, +) +from tilelang.language.utils import ( + buffer_region_to_tile_region, + buffer_load_to_tile_region, +) + +def _to_region_general(data, access_type, extent): + """Convert Buffer / BufferRegion / BufferLoad into TileRegion.""" + if isinstance(data, tir.Buffer): + # Full buffer → region starting at 0 + zeros = [tir.IntImm("int32", 0) for _ in extent] + load = tir.BufferLoad(data, zeros) + return buffer_load_to_tile_region(load, access_type, extent) + + elif isinstance(data, tir.BufferRegion): + # Already a region: simply convert + return buffer_region_to_tile_region(data, access_type, extent) + + elif isinstance(data, tir.BufferLoad): + region = get_buffer_region_from_load(data) + if region is None: # scalar load + return buffer_load_to_tile_region(data, access_type, extent) + return buffer_region_to_tile_region(region, access_type, extent) + + else: + # Fallback: treat anything else as a load + return buffer_load_to_tile_region(data, access_type, extent) + + +def _get_extent(data): + """Detect extent from Buffer / BufferRegion / BufferLoad.""" + if isinstance(data, tir.Buffer): + return list(data.shape) + + elif isinstance(data, tir.BufferRegion): + return [r.extent for r in data.region] + + elif isinstance(data, tir.BufferLoad): + region = get_buffer_region_from_load(data) + if region is None: + return None + return [r.extent for r in region.region] + + return None + + +def dma_load(src, dst, eviction_policy: int = 0): + """Global -> Shared TMA Load.""" + src_extent = _get_extent(src) + dst_extent = _get_extent(dst) + + assert src_extent or dst_extent, "Can't deduce extents for dma_load()" + + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) + dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) + + # Pairwise extent legalize (same as T.copy) + src_extent, 
dst_extent = legalize_pairwise_extents(src_extent, dst_extent) + + src_region = _to_region_general(src, "r", src_extent) + dst_region = _to_region_general(dst, "w", dst_extent) + + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.dma_load"), + src_region, + dst_region, + eviction_policy, + ) + + +def dma_store(src, dst, eviction_policy: int = 0): + """Shared -> Global TMA Store.""" + src_extent = _get_extent(src) + dst_extent = _get_extent(dst) + + assert src_extent or dst_extent, "Can't deduce extents for dma_store()" + + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) + dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) + + src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) + + src_region = _to_region_general(src, "r", src_extent) + dst_region = _to_region_general(dst, "w", dst_extent) + + return tir.call_intrin( + "handle", + tir.op.Op.get("tl.dma_store"), + src_region, + dst_region, + eviction_policy, + )
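
For reference, a minimal standalone sketch of the per-core shape arithmetic used by the flash-attention test above; it mirrors MeshTensorDescriptor.get_local_shape with the (4, 4) mesh from get_gpu_info() and axis_partitions=(0, 2), and the result matches the (8, 4096, 4, 64) buffer shapes in expected_print.

# Minimal sketch: per-core local shape for the flash-attention test above,
# mirroring MeshTensorDescriptor.get_local_shape with a two-axis partition.
global_shape = (32, 4096, 16, 64)   # (batch, seq_len, heads, dim)
mesh_layout = (4, 4)                # 4x4 core mesh reported by get_gpu_info()
axis_partitions = (0, 2)            # split batch and heads across the mesh

local_shape = list(global_shape)
local_shape[axis_partitions[0]] //= mesh_layout[0]  # batch: 32 // 4 = 8
local_shape[axis_partitions[1]] //= mesh_layout[1]  # heads: 16 // 4 = 4
print(tuple(local_shape))  # (8, 4096, 4, 64), matching the match_buffer shapes in expected_print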