Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

// Register the Sunmmio MMA intrinsic. It is variadic (num_inputs = -1) and
// marked opaque so TIR passes will not reorder or eliminate calls to it;
// the argument contract is documented at the mma_sunmmio() declaration.
TIR_DEFINE_TL_BUILTIN(mma_sunmmio)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(create_tma_descriptor)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
23 changes: 23 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,29 @@ TVM_DLL const Op &tma_load();
*/
TVM_DLL const Op &dma_copy();

/*!
 * \brief tvm intrinsic for mma operation of Sunmmio target.
 *
 * \param A_region
 *  A tl.tileop.region PrimExpr describing A.
 *
 * \param B_region
 *  A tl.tileop.region PrimExpr describing B.
 *
 * \param C_region
 *  A tl.tileop.region PrimExpr describing C.
 *
 * \param trans_A
 *  Whether to transpose A operand.
 *  NOTE(review): trans_A is not supported by Sunmmio yet; the flag is kept
 *  for signature parity with the other mma intrinsics.
 *
 * \param trans_B
 *  Whether to transpose B operand.
 *
 * \param clear_accum
 *  Whether to clear accumulation buffer.
 */
TVM_DLL const Op &mma_sunmmio();

/*!
* \brief tvm intrinsics for loading image from global tensor to columns in
* shared memory
Expand Down
41 changes: 35 additions & 6 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
}

if (gemm_inst == GemmInst::kSunmmioMMA) {
return {1, num_warps}; // kSunmmioMMA doesn't care about warp partitioning
}

int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
int kNPerWarp = 8; // Columns processed by a single warp
Expand Down Expand Up @@ -439,6 +443,31 @@ static int GetArchInt(Target target) {
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = getGemmInst(block_size, T.target);

if (gemm_inst == GemmInst::kSunmmioMMA) {
ICHECK(a_.scope() == "shared.asram")
<< "Invalid scope of buffer " << a_ << " in SunmmioMMA.";
ICHECK(b_.scope() == "shared.wsram")
<< "Invalid scope of buffer " << b_ << " in SunmmioMMA.";
ICHECK(c_.scope() == "shared.rsram")
<< "Invalid scope of buffer " << c_ << " in SunmmioMMA.";

PrimExpr A_region =
MakeRegionExpr(aRegion_->buffer, aRegion_->region, /*access_mask=*/1);
PrimExpr B_region =
MakeRegionExpr(bRegion_->buffer, bRegion_->region, /*access_mask=*/1);
PrimExpr C_region =
MakeRegionExpr(cRegion_->buffer, cRegion_->region, /*access_mask=*/3);
Array<PrimExpr> args = {A_region, B_region, C_region,
Bool(transA_), Bool(transB_), clearAccum_};

auto op = mma_sunmmio();
Stmt mma_sunmmio;
mma_sunmmio = Evaluate(Call(DataType::Handle(), op, args));

return mma_sunmmio;
}

auto [warp_m, warp_n] =
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);

Expand Down Expand Up @@ -810,12 +839,12 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
ICHECK(0);
}
} else if (gemm_inst == GemmInst::kSunmmioMMA) {
ICHECK((((std::string)a_.scope()).compare(0, 6, "shared") == 0))
<< "Sunmmio Gemm only supports A in shared scope, got " << a_.scope();
ICHECK((((std::string)b_.scope()).compare(0, 6, "shared") == 0))
<< "Sunmmio Gemm only supports B in shared scope, got " << b_.scope();
ICHECK((((std::string)c_.scope()).compare(0, 6, "shared") == 0))
<< "Sunmmio Gemm only supports C in shared scope, got " << c_.scope();
ICHECK(a_.scope() == "shared.asram")
<< "Invalid scope of buffer " << a_ << " in SunmmioMMA.";
ICHECK(b_.scope() == "shared.wsram")
<< "Invalid scope of buffer " << b_ << " in SunmmioMMA.";
ICHECK(c_.scope() == "shared.rsram")
<< "Invalid scope of buffer " << c_ << " in SunmmioMMA.";

const auto f =
ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
Expand Down
83 changes: 83 additions & 0 deletions testing/python/language/test_tilelang_mesh_language_mma_sunmmio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import tilelang
import pytest
from tilelang import tvm as tvm
from tilelang.utils.target import determine_target
import tilelang as tl
import tilelang.language as T

tilelang.env.disable_cache()


def layout_func(i, j, continuous):
    """Map a (row, col) element to its flat offset in a 32x32 blockwise layout.

    The matrix is tiled into 32x32 blocks laid out row-major across a row of
    ``continuous`` columns; elements inside each block are row-major too.
    """
    block_row, intra_row = divmod(i, 32)
    block_col, intra_col = divmod(j, 32)
    blocks_per_row = continuous // 32
    block_index = block_row * blocks_per_row + block_col
    return block_index * 32 * 32 + intra_row * 32 + intra_col


def matmul(M, N, K, block_M, block_N, block_K, version, dtype=T.float16, accum_dtype=T.float32):
    """Build an IRModule containing one tile-level matmul kernel.

    The kernel computes C = A @ B over (M, K) x (K, N) tiles, with every
    operand staged through shared memory and gemm issued on sliced
    (non-full-region) shared views so the Sunmmio lowering path is exercised.

    Args:
        M, N, K: Global problem dimensions.
        block_M, block_N, block_K: Tile sizes per thread block.
        version: 1 selects T.gemm_v1, 2 selects T.gemm_v2; anything else raises.
        dtype: Element type of A and B.
        accum_dtype: Element type of the C accumulator.

    Returns:
        tvm.IRModule with a single "main" PrimFunc.

    Raises:
        ValueError: If ``version`` is neither 1 nor 2.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        # One thread block per (bx, by) output tile; 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)

            # T.clear(C_shared)
            # Software-pipelined reduction over K tiles (3 stages).
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # Deliberately sliced sub-regions (not full buffers) with both
                # transpose flags set, to pin the lowered region arguments.
                if version == 1:
                    T.gemm_v1(
                        A_shared[0:8, 16:32],
                        B_shared[0:16, 8:16],
                        C_shared[8:24, 16:32],
                        transpose_A=True,
                        transpose_B=True)
                elif version == 2:
                    T.gemm_v2(
                        A_shared[0:8, 16:32],
                        B_shared[0:16, 8:16],
                        C_shared[8:24, 16:32],
                        transpose_A=True,
                        transpose_B=True)
                else:
                    raise ValueError(f'unsupported gemm version: {version}')

            T.copy(C_shared, C[by * block_M, bx * block_N])

    return tvm.IRModule({'main': main})


# Expected lowered tl.mma_sunmmio calls. Both tile configurations lower to the
# same call because the gemm operands are the same fixed slices of the shared
# buffers regardless of block size. (Strings must match mod.script() exactly.)
stmts = [
    "T.mma_sunmmio(T.region(A_shared[0, 16], 1, 8, 16), T.region(B_shared[0, 8], 1, 16, 8), T.region(C_shared[8, 16], 3, 16, 16), T.bool(True), T.bool(True), T.bool(False))",
    "T.mma_sunmmio(T.region(A_shared[0, 16], 1, 8, 16), T.region(B_shared[0, 8], 1, 16, 8), T.region(C_shared[8, 16], 3, 16, 16), T.bool(True), T.bool(True), T.bool(False))"
]
# (M, N, K, block_M, block_N, block_K, gemm version, expected lowered stmt)
TEST_CASES = [
    # gemm v1
    (128, 128, 128, 32, 32, 32, 1, stmts[0]),
    (128, 128, 128, 64, 32, 64, 1, stmts[1]),
    # gemm v2
    (128, 128, 128, 32, 32, 32, 2, stmts[0]),
    (128, 128, 128, 64, 32, 64, 2, stmts[1]),
]


@pytest.mark.parametrize(
    "M, N, K, block_M, block_N, block_K, version, lower_stmt",
    TEST_CASES,
)
def test_tilelang_gemm_sunmmio_layout(M, N, K, block_M, block_N, block_K, version, lower_stmt):
    """Lower a tile matmul for the Sunmmio target and check the emitted mma call.

    Runs the compiler pipeline (BindTarget -> InferSramScope -> LayoutInference
    -> LowerTileOp) and asserts the final statement of the printed module is
    exactly the expected T.mma_sunmmio(...) intrinsic call.
    """
    target_name = "Sunmmio"
    target = determine_target(target_name, return_object=True)
    with tvm.target.Target(target):
        mod = matmul(M, N, K, block_M, block_N, block_K, version)
        mod = tvm.tir.transform.BindTarget(target)(mod)
        mod = tilelang.transform.InferSramScope()(mod)
        mod = tl.transform.LayoutInference()(mod)
        mod = tl.transform.LowerTileOp()(mod)
    texts = mod.script().split('\n')
    # texts[-1] is the trailing empty line of the script; texts[-2] holds the
    # last real statement, which should be the lowered mma intrinsic.
    text = texts[-2].lstrip()
    assert text == lower_stmt
6 changes: 3 additions & 3 deletions tilelang/tileop/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
NotImplementedError: If the instruction type is not supported
ValueError: If the instruction type is unknown
"""
if gemm_inst.is_sunmmio():
return GemmSunmmio
elif gemm_inst.is_mma():
if gemm_inst.is_mma():
if target_is_volta(target):
return GemmMMASm70
return GemmMMA
Expand All @@ -197,5 +195,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
return GemmMFMA
elif gemm_inst.is_tcgen5mma():
raise NotImplementedError("TCGEN5MMA is not implemented")
elif gemm_inst.is_sunmmio():
return GemmSunmmio
else:
raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
125 changes: 28 additions & 97 deletions tilelang/tileop/gemm/gemm_sunmmio.py
Original file line number Diff line number Diff line change
@@ -1,128 +1,59 @@
from .gemm_base import GemmBase
from tilelang.layout import make_blockwise_zz_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
from tilelang import language as T
from tilelang.utils.language import (
retrieve_shape,)
from tilelang.language.utils import (
buffer_region_to_tile_region,)


class GemmSunmmio(GemmBase):

def infer_layout(self, target: Target, thread_nums: int):
if self.is_gemm_sss():
if self.is_gemm_sunmmio_scope():
return {
self.A: make_blockwise_zz_layout(self.A),
self.B: make_blockwise_zz_layout(self.B),
self.C: make_blockwise_zz_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}"
f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}"
)

def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
# TODO: lower not implemented
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)

in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion

A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer

clear_accum = self.clear_accum

assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"

assert is_full_region(C_region), "Fragment output C must be a full region"

if self.is_gemm_sss():

@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_region,
ki,
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_region,
ki,
)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_buf, ki)
if self.is_gemm_sunmmio_scope():
A_shape = retrieve_shape(self.ARegion)
B_shape = retrieve_shape(self.BRegion)
C_shape = retrieve_shape(self.CRegion)
A_arg = buffer_region_to_tile_region(self.ARegion, "r", [r for r in A_shape])
B_arg = buffer_region_to_tile_region(self.BRegion, "r", [r for r in B_shape])
C_arg = buffer_region_to_tile_region(self.CRegion, "rw", [r for r in C_shape])

# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(B_region), "Fragment input B must be a full region"
args = [A_arg, B_arg, C_arg, self.trans_A, self.trans_B, self.clear_accum]

@T.prim_func
def _gemm_rrr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
def _gemm_sss() -> None:
tir.call_intrin(
"handle",
tir.op.Op.get("tl.mma_sunmmio"),
*args,
)

for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_buf, B_buf, C_buf, ki)
return _Simplify(_gemm_sss, inline_let=True)

# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}"
f"Unsupported gemm combination of Sunmmio, A: {self.A.scope()}, B: {self.B.scope()}, C: {self.C.scope()}"
)

def is_gemm_sss(self) -> bool:
return is_shared(self.A) and is_shared(self.B) and is_shared(self.C)
def is_gemm_sunmmio_scope(self) -> bool:
    """Return True iff A/B/C reside in the Sunmmio asram/wsram/rsram scopes."""
    expected = (
        (self.A, 'shared.asram'),
        (self.B, 'shared.wsram'),
        (self.C, 'shared.rsram'),
    )
    return all(buf.scope() == scope for buf, scope in expected)