41 changes: 41 additions & 0 deletions src/op/builtin_dma.cc
@@ -0,0 +1,41 @@
/*!
 * \file tl/op/builtin_dma.cc
 * \brief DMA builtin intrinsics for the SUNMMIO GPU.
 * \note Separated from the original TileLang code.
 */

#include "builtin_dma.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/cuda.h"
#include "../target/utils.h"

namespace tvm {
namespace tl {


#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
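// TIR_DEFINE_TL_BUILTIN defines an accessor OpName() that returns the op
// registered as "tl.<OpName>" and sets its TVMScript printer name in one place.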

TIR_DEFINE_TL_BUILTIN(create_dma_descriptor)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
52 changes: 52 additions & 0 deletions src/op/builtin_dma.h
@@ -0,0 +1,52 @@
/*!
 * \file tl/op/builtin_dma.h
 * \brief DMA builtin intrinsics for the SUNMMIO GPU.
 * \note Separated from the original TileLang code.
 */

#ifndef TVM_TL_OP_BUILTIN_DMA_H_
#define TVM_TL_OP_BUILTIN_DMA_H_


#include "operator.h"
#include <tvm/ir/transform.h>


namespace tvm {

namespace tl {

/*!
 * \brief TVM intrinsic for creating a DMA descriptor for a tiled load.
 *
 * CuTensorMap* create_dma_descriptor(data_type, rank, global_addr,
 * global_shape..., global_stride..., smem_box..., smem_stride..., interleave,
 * swizzle, l2_promotion, oob_fill)
 */
TVM_DLL const Op &create_dma_descriptor();
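// Example call (rank-2 tile, taken verbatim from the accompanying test):
//   create_dma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0)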


/*!
 * \brief TVM intrinsic for loading data described by a global tensor
 * descriptor into shared memory via DMA.
 */
TVM_DLL const Op &dma_load();
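// Example call shape as exercised by the accompanying test:
//   dma_load(create_dma_descriptor(...), 0,
//            tvm_access_ptr(..., A_shared.data, ...), k * 32, by * 64)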


/*!
 * \brief TVM intrinsic for storing data from shared memory to a global
 * tensor descriptor via DMA.
 */
TVM_DLL const Op &dma_store();

} // namespace tl
} // namespace tvm

#endif  // TVM_TL_OP_BUILTIN_DMA_H_
@@ -0,0 +1,88 @@
"""
Docstring for examples.mytest.testDMADescriptor
just to test create_dma_descriptor() intrin
to run this script successfully, you should
1. remove tilelang/language/dma.py and line "from .dma import dma_load, dma_store # noqa: F401" in tilelang/language/__init__.py
2. add `def create_dma_descriptor(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.create_dma_descriptor"), *args)`
in tilelang/builtin.py and decomment the dma_load and dma_store functions in the same file.
"""

from tilelang import tvm as tvm
import tilelang.language as T


def test_multi_version_buffer(M, N, K, dtype, block_M, block_N, block_K):

@T.prim_func
def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128)
with T.block(""):
T.reads(A[by * 64, 0:481], B[0:481, bx * 64])
T.writes()
A_shared = T.alloc_buffer((1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((1, 4, 512), "float16", scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(2):
C_local[i * 2 + vec] = T.float32(0)
for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
if v == 0:
T.dma_load(
T.create_dma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), 0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2),
k * 32, by * 64)
if v == 0:
T.dma_load(
T.create_dma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), 0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2),
bx * 64, k * 32)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))

return before


expected_result = """# from tvm.script import tir as T

@T.prim_func
def before(A_handle: T.handle, B_handle: T.handle):
A = T.match_buffer(A_handle, (512, 512), "float16", strides=(512, 1))
B = T.match_buffer(B_handle, (512, 512), "float16", strides=(512, 1))
# with T.block("root"):
v = T.launch_thread("blockIdx.x", 8)
v_1 = T.launch_thread("blockIdx.y", 8)
v_2 = T.launch_thread("threadIdx.x", 128)
with T.block(""):
T.reads(A[v_1 * 64, 0:481], B[0:481, v * 64])
T.writes()
A_shared = T.alloc_buffer((1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((1, 4, 512), "float16", scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(2):
C_local[i * 2 + vec] = T.float32(0.0)
for k in T.serial(16, annotations={"num_stages": 3}):
if v_2 == 0:
T.dma_load(T.create_dma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2), k * 32, v_1 * 64)
if v_2 == 0:
T.dma_load(T.create_dma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2), v * 64, k * 32)
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))"""

M = 512
N = 512
K = 512
dtype = "float16"
block_M = 64
block_N = 64
block_K = 32
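# Build the TIR PrimFunc, print its script, and check it against expected_result.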
test_kernel = test_multi_version_buffer(M, N, K, dtype, block_M, block_N, block_K)
test_kernel.show()
print(test_kernel.script() == expected_result)