31 changes: 31 additions & 0 deletions src/layout/layout.cc
@@ -55,6 +55,12 @@ LayoutNode::LayoutNode(Array<PrimExpr> input_size,
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
}

TileLayoutNode::TileLayoutNode(Array<PrimExpr> input_shape,
                               Array<PrimExpr> tile_size,
                               Array<PrimExpr> dim_map) {
  input_shape_ = input_shape;
  tile_size_ = tile_size;
  dim_map_ = dim_map;
}

Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
@@ -69,11 +75,25 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
  data_ = std::move(n);
}

TileLayout::TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size,
                       Array<PrimExpr> dim_map) {
  auto n = tvm::ffi::make_object<TileLayoutNode>(input_shape, tile_size, dim_map);
  data_ = std::move(n);
}

Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
  auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
  data_ = std::move(n);
}

void TileLayoutNode::RegisterReflection() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<TileLayoutNode>()
      .def_ro("input_shape", &TileLayoutNode::input_shape_)
      .def_ro("tile_size", &TileLayoutNode::tile_size_)
      .def_ro("dim_map", &TileLayoutNode::dim_map_)
      .def("_DebugOutput", &TileLayoutNode::DebugOutput);
}

void LayoutNode::RegisterReflection() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<LayoutNode>()
@@ -668,6 +688,10 @@ std::string LayoutNode::DebugOutput() const {
  return ss.str();
}

std::string TileLayoutNode::DebugOutput() const {
  std::stringstream ss;
  ss << "TileLayout(" << input_shape_ << ", tile=" << tile_size_
     << ", dims=" << dim_map_ << ")";
  return ss.str();
}

std::string FragmentNode::DebugOutput() const {
  std::stringstream ss;
  ss << "Fragment(" << InputShape() << " -> " << OutputShape()
@@ -722,6 +746,12 @@ void FragmentNode::RegisterReflection() {
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def_packed("tl.TileLayout",
                  [](PackedArgs args, Any *rv) {
                    *rv = TileLayout(args[0].cast<Array<PrimExpr>>(),
                                     args[1].cast<Array<PrimExpr>>(),
                                     args[2].cast<Array<PrimExpr>>());
                  })
      .def_packed("tl.Layout",
                  [](PackedArgs args, Any *rv) {
                    *rv = Layout(args[0].cast<Array<IterVar>>(),
@@ -819,6 +849,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  LayoutNode::RegisterReflection();
  TileLayoutNode::RegisterReflection();
  FragmentNode::RegisterReflection();
}
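
For reference, the "tl.TileLayout" global registered above is the packed function that tilelang/layout/layout.py binds through _ffi_api. A minimal Python sketch of calling it directly, assuming the usual tvm.get_global_func lookup applies in this build (the supported path is tilelang.layout.TileLayout, which wraps the same function):

import tvm
from tvm.tir import IntImm

# Hypothetical direct call; explicit IntImm arguments sidestep the
# Array<PrimExpr> overload-resolution caveat noted in allocate.py.
make_tile_layout = tvm.get_global_func("tl.TileLayout")
obj = make_tile_layout(
    [IntImm("int32", 4), IntImm("int32", 4), IntImm("int32", 32), IntImm("int32", 32)],
    [IntImm("int32", 32), IntImm("int32", 32)],
    [IntImm("int32", 0), IntImm("int32", 1)],
)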

24 changes: 24 additions & 0 deletions src/layout/layout.h
@@ -39,8 +39,25 @@ class LoopLayoutInjectiveException : public std::exception {
};

class Layout;
class TileLayout;
class Fragment;

class TileLayoutNode : public Object {
 public:
  TileLayoutNode() = default;
  TileLayoutNode(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size,
                 Array<PrimExpr> dim_map);

  virtual std::string DebugOutput() const;

  static void RegisterReflection();
  TVM_FFI_DECLARE_OBJECT_INFO("tl.TileLayout", TileLayoutNode, Object);
  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
      kTVMFFISEqHashKindTreeNode;

  Array<PrimExpr> input_shape_;
  Array<PrimExpr> tile_size_;
  Array<PrimExpr> dim_map_;
};

class LayoutNode : public Object {
 public:
  LayoutNode() = default;
@@ -93,6 +110,13 @@ class LayoutNode : public Object {
  Array<PrimExpr> input_size_;
};

class TileLayout : public ObjectRef {
 public:
  // Only keep the PrimExpr version to avoid FFI overload resolution issues.
  TVM_DLL TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size,
                     Array<PrimExpr> dim_map);
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileLayout, ObjectRef,
                                             TileLayoutNode);
};

/*!
 * \brief Layout reference class.
 */
66 changes: 66 additions & 0 deletions testing/python/language/test_alloc_tile.py
@@ -0,0 +1,66 @@
import tilelang
import tilelang.language as T


tilelang.disable_cache()


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared_with_tileview((block_M, block_K), dtype)
            B_shared = T.alloc_shared_with_tileview((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def main():
    kernel = matmul(1024, 1024, 1024, 128, 128, 32)

    import torch

    a = torch.randn(1024, 1024).cuda().half()
    b = torch.randn(1024, 1024).cuda().half()

    c = kernel(a, b)

    ref_c = a @ b

    print("c:")
    print(c)
    print("ref_c:")
    print(ref_c)

    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All checks passed.")

    # Get CUDA source
    print("CUDA Source:")
    print(kernel.get_kernel_source())

    # Benchmark
    profiler = kernel.get_profiler()
    latency = profiler.do_bench(backend="cupti")
    # latency = profiler.do_bench()
    print(f"tilelang latency: {latency} ms")


if __name__ == "__main__":
    main()

1 change: 1 addition & 0 deletions tilelang/language/__init__.py
@@ -47,6 +47,7 @@
    alloc_var,  # noqa: F401
    alloc_local,  # noqa: F401
    alloc_shared,  # noqa: F401
    alloc_shared_with_tileview,  # noqa: F401
    alloc_fragment,  # noqa: F401
    alloc_barrier,  # noqa: F401
    alloc_tmem,  # noqa: F401
81 changes: 79 additions & 2 deletions tilelang/language/allocate.py
@@ -22,10 +22,11 @@
from typing_extensions import TypeVarTuple, Unpack # type: ignore
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
from tvm.tir import PrimExpr, IntImm
from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
from tvm.tir.expr import FloatImm
from tvm.ir import Array
from .v2.dtypes import dtype as tl_dtype
from .v2.builder import OutTensor
from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
@@ -34,6 +35,82 @@
_DType = TypeVar('_DType')


def handle_shape_tuple_when_alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]],
                                                       tile_size=None,
                                                       dim_map=None) -> tuple:
    """Return (new_shape, tile_size, dim_map) after validation and default-value processing."""
    # check shape
    if shape is None or len(shape) == 0:
        raise ValueError("shape must be a non-empty tuple")
    # fill in a default tile_size
    if tile_size is None:
        tile_size = ()
    if len(tile_size) == 0:
        tile_size = (32, 1) if len(shape) == 1 else (32, 32)
    # check tile_size
    if len(tile_size) > len(shape) and len(shape) != 1:
        raise ValueError("tile_size cannot have more entries than shape has dimensions")
    if len(shape) == 1:
        if len(tile_size) > 2:
            raise ValueError("a 1-D shape accepts at most two tile_size entries")
        if len(tile_size) == 2 and tile_size[0] != 1 and tile_size[1] != 1:
            raise ValueError("for a 1-D shape with two tile_size entries, one of them must be 1")
    for a_tile in tile_size:
        if a_tile == 0:
            raise ValueError("tile_size entries must be non-zero")
    # fill in a default dim_map
    if dim_map is None:
        dim_map = ()
    if len(dim_map) == 0:
        if len(shape) == 1:
            dim_map = (0,)
        else:
            dim_map = tuple(range(len(shape) - len(tile_size), len(shape)))
    # check dim_map
    if len(dim_map) != len(set(dim_map)):
        raise ValueError("dim_map entries must be unique")
    for a_dim in dim_map:
        if a_dim >= len(shape) or a_dim < 0:
            raise ValueError("dim_map entries must be valid dimension indices of shape")
    if len(dim_map) != len(tile_size) and len(shape) != 1:
        raise ValueError("dim_map and tile_size must have the same length")
    # special-case len(shape) == 1
    if len(shape) == 1:
        if len(tile_size) == 1:
            return ((T.ceildiv(shape[0], tile_size[0]), tile_size[0]), tile_size, dim_map)
        if tile_size[0] == 1:
            # The non-unit tile lives in tile_size[1].
            return ((T.ceildiv(shape[0], tile_size[1]), 1, tile_size[1]), tile_size, dim_map)
        if tile_size[1] == 1:
            return ((T.ceildiv(shape[0], tile_size[0]), tile_size[0], 1), tile_size, dim_map)
    # general case: split each mapped dimension and append its tile extent
    shape_list = list(shape)
    for a_tile, a_dim in zip(tile_size, dim_map):
        shape_list[a_dim] = T.ceildiv(shape[a_dim], a_tile)
        shape_list += [a_tile]
    return (tuple(shape_list), tile_size, dim_map)
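
To make the shape transformation concrete, here is an illustrative sketch of what the helper returns (results shown as plain integers; in practice T.ceildiv yields PrimExprs that fold to these constants):

# Defaults: tile_size (32, 32), dim_map (0, 1); each mapped dim is split
# into (outer, tile), with the tile extents appended at the end.
handle_shape_tuple_when_alloc_shared_with_tileview((128, 128))
# -> ((4, 4, 32, 32), (32, 32), (0, 1))

# A 1-D shape with a single tile entry splits into (outer, tile).
handle_shape_tuple_when_alloc_shared_with_tileview((1024,), tile_size=(32,))
# -> ((32, 32), (32,), (0,))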

from tilelang.layout import TileLayout

def alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]],
                               dtype: _DType,
                               tile_size=None,
                               dim_map=None,
                               scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    new_shape, tile_size, dim_map = handle_shape_tuple_when_alloc_shared_with_tileview(
        shape=shape, tile_size=tile_size, dim_map=dim_map)
    if dtype == "bool":
        scope = "shared"
    buffer = T.alloc_buffer(shape, dtype, scope=scope)
    # Convert Python ints to IntImm so the tuples become Array<PrimExpr>,
    # which keeps FFI overload resolution unambiguous.
    new_shape_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in new_shape]
    tile_size_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in tile_size]
    dim_map_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in dim_map]
    tile_layout = TileLayout(Array(new_shape_expr), Array(tile_size_expr), Array(dim_map_expr))
    block_attr({"tile_view": {buffer.data: tile_layout}})
    return buffer
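
A usage sketch for the 1-D case, which the GEMM test in this PR does not exercise (hypothetical kernel fragment; the buffer name is illustrative):

# Inside a T.Kernel body: view a 1024-element vector as 32 tiles of 32.
X_shared = T.alloc_shared_with_tileview((1024,), "float16", tile_size=(32,))
# The attached TileLayout records tiled shape (32, 32), tile_size (32,),
# and dim_map (0,).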


def alloc_shared(shape: tuple[Unpack[_Shapes]],
                 dtype: _DType,
                 scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
1 change: 1 addition & 0 deletions tilelang/layout/__init__.py
@@ -2,6 +2,7 @@
# pylint: disable=invalid-name, unsupported-binary-operation

from .layout import Layout # noqa: F401
from .layout import TileLayout # noqa: F401
from .fragment import Fragment # noqa: F401
from .swizzle import (
    make_swizzled_layout,  # noqa: F401
9 changes: 9 additions & 0 deletions tilelang/layout/layout.py
@@ -6,6 +6,15 @@
from tilelang import _ffi_api


# Register the TileLayout class as a TVM object under the name "tl.TileLayout"
@tvm_ffi.register_object("tl.TileLayout")
class TileLayout(Node):

    def __init__(self, input_shape, tile_size, dim_map):
        # Call the FFI constructor to create the TileLayout object on the C++ side.
        # Parameter order matches C++: TileLayout(Array<PrimExpr> input_shape,
        #                                         Array<PrimExpr> tile_size,
        #                                         Array<PrimExpr> dim_map)
        self.__init_handle_by_constructor__(_ffi_api.TileLayout, input_shape, tile_size, dim_map)
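
A construction sketch mirroring the IntImm conversion done in alloc_shared_with_tileview (plain Python ints may be ambiguous for the Array<PrimExpr>-only constructor); the field access assumes the reflection-based attribute lookup registered in src/layout/layout.cc:

from tvm.tir import IntImm

tile_layout = TileLayout(
    [IntImm("int32", 4), IntImm("int32", 4), IntImm("int32", 32), IntImm("int32", 32)],
    [IntImm("int32", 32), IntImm("int32", 32)],
    [IntImm("int32", 0), IntImm("int32", 1)],
)
print(tile_layout.tile_size)  # read-only field exposed by RegisterReflection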


# Register the Layout class as a TVM object under the name "tl.Layout"
@tvm_ffi.register_object("tl.Layout")
class Layout(Node):