diff --git a/src/layout/layout.cc b/src/layout/layout.cc
index c3f99f307..586cd11d7 100644
--- a/src/layout/layout.cc
+++ b/src/layout/layout.cc
@@ -55,6 +55,12 @@ LayoutNode::LayoutNode(Array<PrimExpr> input_size,
                        [&](const PrimExpr &e) { return analyzer.Simplify(e); });
 }
 
+TileLayoutNode::TileLayoutNode(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map) {
+  input_shape_ = input_shape;
+  tile_size_ = tile_size;
+  dim_map_ = dim_map;
+}
+
 Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
   Map<Var, PrimExpr> vmap;
   Array<PrimExpr> input_size;
@@ -69,11 +75,25 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
   data_ = std::move(n);
 }
 
+TileLayout::TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map) {
+  auto n = tvm::ffi::make_object<TileLayoutNode>(input_shape, tile_size, dim_map);
+  data_ = std::move(n);
+}
+
 Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
   auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
   data_ = std::move(n);
 }
 
+void TileLayoutNode::RegisterReflection() {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectDef<TileLayoutNode>()
+      .def_ro("input_shape", &TileLayoutNode::input_shape_)
+      .def_ro("tile_size", &TileLayoutNode::tile_size_)
+      .def_ro("dim_map", &TileLayoutNode::dim_map_)
+      .def("_DebugOutput", &TileLayoutNode::DebugOutput);
+}
+
 void LayoutNode::RegisterReflection() {
   namespace refl = tvm::ffi::reflection;
   refl::ObjectDef<LayoutNode>()
@@ -668,6 +688,13 @@ std::string LayoutNode::DebugOutput() const {
   return ss.str();
 }
 
+std::string TileLayoutNode::DebugOutput() const {
+  std::stringstream ss;
+  ss << "TileLayout(input_shape=" << input_shape_ << ", tile_size=" << tile_size_
+     << ", dim_map=" << dim_map_ << ")";
+  return ss.str();
+}
+
 std::string FragmentNode::DebugOutput() const {
   std::stringstream ss;
   ss << "Fragment(" << InputShape() << " -> " << OutputShape()
@@ -722,6 +749,12 @@ void FragmentNode::RegisterReflection() {
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
+      .def_packed("tl.TileLayout",
+                  [](PackedArgs args, Any *rv) {
+                    *rv = TileLayout(args[0].cast<Array<PrimExpr>>(),
+                                     args[1].cast<Array<PrimExpr>>(),
+                                     args[2].cast<Array<PrimExpr>>());
+                  })
       .def_packed("tl.Layout",
                   [](PackedArgs args, Any *rv) {
                     *rv = Layout(args[0].cast<Array<PrimExpr>>(),
@@ -819,6 +852,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   LayoutNode::RegisterReflection();
+  TileLayoutNode::RegisterReflection();
   FragmentNode::RegisterReflection();
 }
 
diff --git a/src/layout/layout.h b/src/layout/layout.h
index 369df4f2e..ba8b0ac8f 100644
--- a/src/layout/layout.h
+++ b/src/layout/layout.h
@@ -39,8 +39,25 @@ class LoopLayoutInjectiveException : public std::exception {
 };
 
 class Layout;
+class TileLayout;
 class Fragment;
 
+class TileLayoutNode : public Object {
+public:
+  TileLayoutNode() = default;
+  TileLayoutNode(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map);
+
+  virtual std::string DebugOutput() const;
+
+  static void RegisterReflection();
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.TileLayout", TileLayoutNode, Object);
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+
+  Array<PrimExpr> input_shape_;
+  Array<PrimExpr> tile_size_;
+  Array<PrimExpr> dim_map_;
+};
+
 class LayoutNode : public Object {
 public:
   LayoutNode() = default;
@@ -93,6 +110,13 @@ class LayoutNode : public Object {
   Array<PrimExpr> input_size_;
 };
 
+class TileLayout : public ObjectRef {
+public:
+  // Only the Array<PrimExpr> constructor is kept, to avoid FFI overload resolution issues.
+  TVM_DLL TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileLayout, ObjectRef, TileLayoutNode);
+};
+
 /*!
  * \brief Layout reference class.
 */
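A note on the two C++ hunks above: `def_ro` in `TileLayoutNode::RegisterReflection` is what exposes the three arrays as read-only attributes on the Python proxy, and the `tl.TileLayout` packed function is the constructor that proxy calls. A minimal round-trip sketch, assuming a tilelang build that includes this patch and that attribute access through tvm-ffi reflection works as it does for the existing `Layout` object:

    from tvm import tir
    from tilelang.layout import TileLayout

    def exprs(xs):
        # Wrap plain ints as IntImm so the FFI sees Array<PrimExpr> arguments.
        return [tir.IntImm("int32", x) for x in xs]

    # Tiled view of a (128, 32) buffer cut into 32x32 tiles on dims (0, 1).
    layout = TileLayout(exprs((4, 1, 32, 32)), exprs((32, 32)), exprs((0, 1)))
    print(layout.input_shape)  # [4, 1, 32, 32]
    print(layout.tile_size)    # [32, 32]
    print(layout.dim_map)      # [0, 1]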
diff --git a/testing/python/language/test_alloc_tile.py b/testing/python/language/test_alloc_tile.py
new file mode 100644
index 000000000..281abceda
--- /dev/null
+++ b/testing/python/language/test_alloc_tile.py
@@ -0,0 +1,66 @@
+import tilelang
+import tilelang.language as T
+
+
+tilelang.disable_cache()
+
+
+@tilelang.jit(out_idx=[-1])
+def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+
+    @T.prim_func
+    def gemm(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared_with_tileview((block_M, block_K), dtype)
+            B_shared = T.alloc_shared_with_tileview((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local)
+
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return gemm
+
+
+def main():
+    kernel = matmul(1024, 1024, 1024, 128, 128, 32)
+
+    import torch
+
+    a = torch.randn(1024, 1024).cuda().half()
+    b = torch.randn(1024, 1024).cuda().half()
+
+    c = kernel(a, b)
+
+    ref_c = a @ b
+
+    print("c:")
+    print(c)
+    print("ref_c:")
+    print(ref_c)
+
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+    print("All checks passed.")
+
+    # Get CUDA source
+    print("CUDA Source:")
+    print(kernel.get_kernel_source())
+
+    # Benchmark
+    profiler = kernel.get_profiler()
+    latency = profiler.do_bench(backend="cupti")
+    # latency = profiler.do_bench()
+    print(f"tilelang Latency: {latency}ms")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 9eae861eb..20fdd3c6c 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -47,6 +47,7 @@
     alloc_var,  # noqa: F401
     alloc_local,  # noqa: F401
     alloc_shared,  # noqa: F401
+    alloc_shared_with_tileview,  # noqa: F401
     alloc_fragment,  # noqa: F401
     alloc_barrier,  # noqa: F401
     alloc_tmem,  # noqa: F401
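Both allocations in the test above rely on the defaults, so `handle_shape_tuple_when_alloc_shared_with_tileview` (added below) picks `tile_size=(32, 32)` and `dim_map=(0, 1)` for them. A quick sanity sketch of the resulting tiled shapes, with `math.ceil` standing in for `T.ceildiv`:

    import math

    def tiled_shape(shape, tile_size=(32, 32), dim_map=(0, 1)):
        # Mirrors the general branch of the helper: divide each mapped
        # dimension by its tile, then append the tile extents.
        out = list(shape)
        for tile, dim in zip(tile_size, dim_map):
            out[dim] = math.ceil(shape[dim] / tile)
            out.append(tile)
        return tuple(out)

    print(tiled_shape((128, 32)))  # A_shared (block_M, block_K) -> (4, 1, 32, 32)
    print(tiled_shape((32, 128)))  # B_shared (block_K, block_N) -> (1, 4, 32, 32)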
diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py
index 73377822b..b3ed3e028 100644
--- a/tilelang/language/allocate.py
+++ b/tilelang/language/allocate.py
@@ -22,10 +22,11 @@
 from typing_extensions import TypeVarTuple, Unpack  # type: ignore
 from tilelang import tvm as tvm
 from tvm.script import tir as T
-from tvm.tir import PrimExpr
+from tvm.tir import PrimExpr, IntImm
 from tvm.script.parser.tir import block_attr
 from tvm.tir.buffer import Buffer
-from tvm.tir.expr import FloatImm, IntImm
+from tvm.tir.expr import FloatImm
+from tvm.ir import Array
 from .v2.dtypes import dtype as tl_dtype
 from .v2.builder import OutTensor
 from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
@@ -34,6 +35,82 @@
 _DType = TypeVar('_DType')
 
 
+def handle_shape_tuple_when_alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]], tile_size=None, dim_map=None) -> tuple:
+    """Validate the arguments, fill in defaults, and return (new_shape, tile_size, dim_map)."""
+    # Check shape.
+    if shape is None:
+        raise ValueError("shape must be a non-empty tuple")
+    if len(shape) == 0:
+        raise ValueError("shape must be a non-empty tuple")
+    # Fill in the default tile_size: (32, 1) for 1-D buffers, (32, 32) otherwise.
+    if tile_size is None:
+        tile_size = ()
+    if len(tile_size) == 0:
+        if len(shape) == 1:
+            tile_size = (32, 1)
+        if len(shape) > 1:
+            tile_size = (32, 32)
+    # Check tile_size.
+    if len(tile_size) > len(shape) and len(shape) != 1:
+        raise ValueError("tile_size cannot have more entries than shape")
+    if len(shape) == 1:
+        if len(tile_size) > 2:
+            raise ValueError("a 1-D buffer takes at most two tile_size entries")
+        if len(tile_size) == 2:
+            if tile_size[0] != 1 and tile_size[1] != 1:
+                raise ValueError("for a 1-D buffer, one of the two tile_size entries must be 1")
+    for a_tile in tile_size:
+        if a_tile == 0:
+            raise ValueError("tile_size entries must be non-zero")
+    # Fill in the default dim_map: tile the innermost dimensions.
+    if dim_map is None:
+        dim_map = ()
+    if len(dim_map) == 0:
+        if len(shape) == 1:
+            dim_map = (0,)
+        if len(shape) > 1:
+            dim_map = tuple(range(len(shape) - len(tile_size), len(shape)))
+    # Check dim_map.
+    if len(dim_map) != len(set(dim_map)):
+        raise ValueError("dim_map entries must be unique")
+    for a_dim in dim_map:
+        if a_dim >= len(shape) or a_dim < 0:
+            raise ValueError("dim_map entries must be valid indices into shape")
+    if len(dim_map) != len(tile_size) and len(shape) != 1:
+        raise ValueError("dim_map and tile_size must have the same length")
+    # 1-D case: (S,) -> (num_tiles, tile), optionally with an explicit unit dimension.
+    if len(shape) == 1:
+        if len(tile_size) == 1:
+            return ((T.ceildiv(shape[0], tile_size[0]), tile_size[0]), tile_size, dim_map)
+        if tile_size[0] == 1:
+            return ((T.ceildiv(shape[0], tile_size[1]), 1, tile_size[1]), tile_size, dim_map)
+        if tile_size[1] == 1:
+            return ((T.ceildiv(shape[0], tile_size[0]), tile_size[0], 1), tile_size, dim_map)
+    # General case: divide each mapped dimension by its tile and append the tile extents.
+    shape_list = list(shape)
+    for a_tile, a_dim in zip(tile_size, dim_map):
+        shape_list[a_dim] = T.ceildiv(shape[a_dim], a_tile)
+        shape_list += [a_tile]
+    return (tuple(shape_list), tile_size, dim_map)
+
+
+from tilelang.layout import TileLayout
+
+
+def alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]], dtype: _DType, tile_size=None, dim_map=None, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
+    new_shape, tile_size, dim_map = handle_shape_tuple_when_alloc_shared_with_tileview(shape=shape, tile_size=tile_size, dim_map=dim_map)
+    if dtype == "bool":
+        scope = "shared"
+    buffer = T.alloc_buffer(shape, dtype, scope=scope)
+    # Convert the tuples to TVM Arrays whose elements are PrimExpr, so FFI overload
+    # resolution picks the Array<PrimExpr> constructor; plain Python ints become IntImm.
+    new_shape_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in new_shape]
+    tile_size_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in tile_size]
+    dim_map_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in dim_map]
+    tile_layout = TileLayout(Array(new_shape_expr), Array(tile_size_expr), Array(dim_map_expr))
+    block_attr({"tile_view": {buffer.data: tile_layout}})
+    return buffer
+
+
 def alloc_shared(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
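The 1-D branches are the least obvious part of the helper above, so here is a small sketch of the three return shapes, assuming the function is importable from `tilelang.language.allocate` (the length-100 buffer is hypothetical; `T.ceildiv` returns a PrimExpr, shown here as plain ints for readability):

    from tilelang.language.allocate import handle_shape_tuple_when_alloc_shared_with_tileview as tile_view_shape

    # (100,) with tile_size=(16,):   ((7, 16), (16,), (0,))
    print(tile_view_shape((100,), tile_size=(16,)))
    # (100,) with tile_size=(1, 16): ((7, 1, 16), (1, 16), (0,))
    print(tile_view_shape((100,), tile_size=(1, 16)))
    # (100,) with tile_size=(16, 1): ((7, 16, 1), (16, 1), (0,))
    print(tile_view_shape((100,), tile_size=(16, 1)))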
diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py
index 777802d2c..b311c5fe7 100644
--- a/tilelang/layout/__init__.py
+++ b/tilelang/layout/__init__.py
@@ -2,6 +2,7 @@
 # pylint: disable=invalid-name, unsupported-binary-operation
 
 from .layout import Layout  # noqa: F401
+from .layout import TileLayout  # noqa: F401
 from .fragment import Fragment  # noqa: F401
 from .swizzle import (
     make_swizzled_layout,  # noqa: F401
diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py
index 10e0357e6..198e008e4 100644
--- a/tilelang/layout/layout.py
+++ b/tilelang/layout/layout.py
@@ -6,6 +6,15 @@
 from tilelang import _ffi_api
 
 
+# Register the TileLayout class as a TVM object under the name "tl.TileLayout".
+@tvm_ffi.register_object("tl.TileLayout")
+class TileLayout(Node):
+    def __init__(self, input_shape, tile_size, dim_map):
+        # Invoke the FFI constructor; the parameter order matches the C++ signature
+        # TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map).
+        self.__init_handle_by_constructor__(_ffi_api.TileLayout, input_shape, tile_size, dim_map)
+
+
 # Register the Layout class as a TVM object under the name "tl.Layout"
 @tvm_ffi.register_object("tl.Layout")
 class Layout(Node):
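Finally, the test only exercises the default tiling; below is a hedged sketch of an allocation with explicit parameters (the buffer size, dtype, and name are illustrative, not part of this patch), to be placed inside a kernel body:

    import tilelang.language as T

    # Inside a T.prim_func / T.Kernel body: a (256, 64) fp16 shared buffer viewed
    # as 64x16 tiles over dims (0, 1); the helper computes the tiled shape
    # (4, 4, 64, 16) and attaches it to the block as the "tile_view" attribute.
    X_shared = T.alloc_shared_with_tileview((256, 64), "float16", tile_size=(64, 16), dim_map=(0, 1))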