From e1d898ed6c52f34f0cd8ecaf5b5068f622be9bdd Mon Sep 17 00:00:00 2001
From: xudemin
Date: Fri, 9 Jan 2026 15:47:50 +0800
Subject: [PATCH 1/5] tile_based_on_share

---
 examples/gemm/example_gemm.py |  7 +++-
 src/layout/layout.cc          | 53 +++++++++++++++++++++++++++
 src/layout/layout.h           | 16 ++++++++
 tilelang/language/__init__.py |  1 +
 tilelang/language/allocate.py | 69 +++++++++++++++++++++++++++++++++++
 tilelang/layout/__init__.py   |  1 +
 tilelang/layout/layout.py     | 11 ++++++
 7 files changed, 156 insertions(+), 2 deletions(-)

diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py
index f18cd388a..4327cd016 100644
--- a/examples/gemm/example_gemm.py
+++ b/examples/gemm/example_gemm.py
@@ -1,6 +1,7 @@
 import tilelang
 import tilelang.language as T
 
+tilelang.disable_cache()
 
 @tilelang.jit(out_idx=[-1])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@@ -12,8 +13,10 @@ def gemm(
         C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            # A_shared = T.alloc_shared((block_M, block_K), dtype)
+            # B_shared = T.alloc_shared((block_K, block_N), dtype)
+            A_shared = T.alloc_shared_with_tileview((block_M, block_K), dtype)
+            B_shared = T.alloc_shared_with_tileview((block_K, block_N), dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
 
             T.clear(C_local)
diff --git a/src/layout/layout.cc b/src/layout/layout.cc
index c3f99f307..6204824fd 100644
--- a/src/layout/layout.cc
+++ b/src/layout/layout.cc
@@ -55,6 +55,34 @@ LayoutNode::LayoutNode(Array<PrimExpr> input_size,
                  [&](const PrimExpr &e) { return analyzer.Simplify(e); });
 }
 
+TileLayoutNode::TileLayoutNode(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map) {
+  input_shape_ = input_shape;
+  tile_size_ = tile_size;
+  dim_map_ = dim_map;
+}
+
+TileLayout::TileLayout(Array<IterVar> input_shape_Iter, Array<IterVar> tile_size_Iter, Array<IterVar> dim_map_Iter) {
+  Map<Var, PrimExpr> vmap;
+  Array<PrimExpr> input_shape, tile_size, dim_map;
+  for (size_t i = 0; i < input_shape_Iter.size(); i++) {
+    vmap.Set(input_shape_Iter[i]->var, InputPlaceholder(i));
+    CHECK(is_zero(input_shape_Iter[i]->dom->min));
+    input_shape.push_back(input_shape_Iter[i]->dom->extent);
+  }
+  for (size_t i = 0; i < tile_size_Iter.size(); i++) {
+    vmap.Set(tile_size_Iter[i]->var, InputPlaceholder(i));
+    CHECK(is_zero(tile_size_Iter[i]->dom->min));
+    tile_size.push_back(tile_size_Iter[i]->dom->extent);
+  }
+  for (size_t i = 0; i < dim_map_Iter.size(); i++) {
+    vmap.Set(dim_map_Iter[i]->var, InputPlaceholder(i));
+    CHECK(is_zero(dim_map_Iter[i]->dom->min));
+    dim_map.push_back(dim_map_Iter[i]->dom->extent);
+  }
+  auto n = tvm::ffi::make_object<TileLayoutNode>(input_size, tile_size, dim_map);
+  data_ = std::move(n);
+}
+
 Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
   Map<Var, PrimExpr> vmap;
   Array<PrimExpr> input_size;
@@ -69,11 +97,25 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
   data_ = std::move(n);
 }
 
+TileLayout::TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map) {
+  auto n = tvm::ffi::make_object<TileLayoutNode>(input_shape, tile_size, dim_map);
+  data_ = std::move(n);
+}
+
 Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
   auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
   data_ = std::move(n);
 }
 
+void TileLayoutNode::RegisterReflection() {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectDef<TileLayoutNode>()
+      .def_ro("input_shape", &TileLayoutNode::input_shape_)
+      .def_ro("tile_size", &TileLayoutNode::tile_size_)
+      .def_ro("dim_map", &TileLayoutNode::dim_map_)
+      .def("_DebugOutput", &TileLayoutNode::DebugOutput);
+}
+
 void LayoutNode::RegisterReflection() {
   namespace refl = tvm::ffi::reflection;
   refl::ObjectDef<LayoutNode>()
@@ -668,6 +710,10 @@ std::string LayoutNode::DebugOutput() const {
   return ss.str();
 }
 
+std::string TileLayoutNode::DebugOutput() const {
+  return "";
+}
+
 std::string FragmentNode::DebugOutput() const {
   std::stringstream ss;
   ss << "Fragment(" << InputShape() << " -> " << OutputShape()
@@ -722,6 +768,12 @@ void FragmentNode::RegisterReflection() {
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
+      .def_packed("tl.TileLayout",
+                  [](PackedArgs args, Any *rv) {
+                    *rv = TileLayout(args[0].cast<Array<IterVar>>(),
+                                     args[1].cast<Array<IterVar>>(),
+                                     args[2].cast<Array<IterVar>>());
+                  })
       .def_packed("tl.Layout",
                   [](PackedArgs args, Any *rv) {
                     *rv = Layout(args[0].cast<Array<IterVar>>(),
@@ -819,6 +871,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   LayoutNode::RegisterReflection();
+  TileLayoutNode::RegisterReflection();
   FragmentNode::RegisterReflection();
 }
 
diff --git a/src/layout/layout.h b/src/layout/layout.h
index 369df4f2e..510e885c6 100644
--- a/src/layout/layout.h
+++ b/src/layout/layout.h
@@ -39,8 +39,17 @@ class LoopLayoutInjectiveException : public std::exception {
 };
 
 class Layout;
+class TileLayout;
 class Fragment;
 
+class TileLayoutNode : public Object {
+  TileLayoutNode() = default;
+  TileLayoutNode(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map);
+  static void RegisterReflection();
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", TileLayoutNode, Object);
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+};
+
 class LayoutNode : public Object {
 public:
   LayoutNode() = default;
@@ -93,6 +102,13 @@ class LayoutNode : public Object {
   Array<PrimExpr> input_size_;
 };
 
+class TileLayout : public ObjectRef {
+public:
+  TVM_DLL TileLayout(Array<IterVar> input_shape_Iter, Array<IterVar> tile_size_Iter, Array<IterVar> dim_map_Iter);
+  TVM_DLL TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileLayout, ObjectRef, TileLayoutNode);
+};
+
 /*!
  * \brief Layout reference class.
 */
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 9eae861eb..20fdd3c6c 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -47,6 +47,7 @@
     alloc_var,  # noqa: F401
     alloc_local,  # noqa: F401
     alloc_shared,  # noqa: F401
+    alloc_shared_with_tileview,  # noqa: F401
    alloc_fragment,  # noqa: F401
     alloc_barrier,  # noqa: F401
     alloc_tmem,  # noqa: F401
diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py
index 73377822b..3c4d0140c 100644
--- a/tilelang/language/allocate.py
+++ b/tilelang/language/allocate.py
@@ -34,6 +34,75 @@
 _DType = TypeVar('_DType')
 
 
+def handle_shape_tuple_when_alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]], tile_size=None, dim_map=None) -> tuple[Unpack[_Shapes]]:
+    # check shape
+    if shape is None:
+        raise ValueError("shape must be a non-empty tuple")
+    if len(shape) == 0:
+        raise ValueError("shape must be a non-empty tuple")
+    # check tile_size None
+    if tile_size is None:
+        tile_size = ()
+    if len(tile_size) == 0:
+        if len(shape) == 1:
+            tile_size = (32, 1)
+        if len(shape) > 1:
+            tile_size = (32, 32)
+    # check tile_size
+    if len(tile_size) > len(shape) and len(shape) != 1:
+        raise ValueError("invalid tile_size")
+    if len(shape) == 1:
+        if len(tile_size) > 2:
+            raise ValueError("invalid tile_size")
+        if len(tile_size) == 2:
+            if tile_size[0] != 1 and tile_size[1] != 1:
+                raise ValueError("invalid tile_size")
+    for a_tile in tile_size:
+        if a_tile == 0:
+            raise ValueError("invalid tile_size")
+    # check dim_map None
+    if dim_map is None:
+        dim_map = ()
+    if len(dim_map) == 0:
+        if len(shape) == 1:
+            dim_map = (0,)
+        if len(shape) > 1:
+            dim_map = tuple(range(len(shape) - len(tile_size), len(shape)))
+    # check dim_map
+    if len(dim_map) != len(set(dim_map)):
+        raise ValueError("invalid dim_map")
+    for a_dim in dim_map:
+        if a_dim >= len(shape) or a_dim < 0:
+            raise ValueError("invalid dim_map")
+    if len(dim_map) != len(tile_size) and len(shape) != 1:
+        raise ValueError("invalid dim_map")
+    # len(shape) = 1
+    if len(shape) == 1:
+        if len(tile_size) == 1:
+            return (T.ceildiv(shape[0], tile_size[0]), tile_size[0])
+        if tile_size[0] == 1:
+            return (T.ceildiv(shape[0], tile_size[0]), 1, tile_size[0])
+        if tile_size[1] == 1:
+            return (T.ceildiv(shape[0], tile_size[0]), tile_size[0], 1)
+    # normal
+    shapeList = list(shape)
+    for a_tile, a_dim in zip(tile_size, dim_map):
+        shapeList[a_dim] = T.ceildiv(shape[a_dim], a_tile)
+        shapeList += [a_tile]
+    return tuple(shapeList)
+
+
+# from tilelang.layout import TileLayout
+def alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]], dtype: _DType, tile_size=None, dim_map=None, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
+    newShape = handle_shape_tuple_when_alloc_shared_with_tileview(shape=shape, tile_size=tile_size, dim_map=dim_map)
+    if dtype == "bool":
+        scope = "shared"
+    buffer = T.alloc_buffer(shape, dtype, scope=scope)
+    tileLayout = TileLayout(newShape, dim_map, tile_size)
+    # block_attr({"tile_view": {buffer.data: {"tile_size": tile_size, "dim_map": dim_map, "tiled_shape": newShape}}})
+    block_attr({"tile_view": {buffer.data: tileLayout}})
+    return buffer
+
+
 def alloc_shared(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py
index 777802d2c..b311c5fe7 100644
--- a/tilelang/layout/__init__.py
+++ b/tilelang/layout/__init__.py
@@ -2,6 +2,7 @@
 # pylint: disable=invalid-name, unsupported-binary-operation
 
 from .layout import Layout  # noqa: F401
+from .layout import TileLayout  # noqa: F401
 from .fragment import Fragment  # noqa: F401
 from .swizzle import (
     make_swizzled_layout,  # noqa: F401
diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py
index 10e0357e6..94d79febf 100644
--- a/tilelang/layout/layout.py
+++ b/tilelang/layout/layout.py
@@ -6,6 +6,17 @@
 from tilelang import _ffi_api
 
 
+# Register the TileLayout class as a TVM object under the name "tl.TileLayout"
+@tvm_ffi.register_object("tl.TileLayout")
+class TileLayout(Node):
+    def __init__(self, tile_shape, dim_map, tile_size):
+        self.tile_shape = tile_shape
+        self.dim_map = dim_map
+        self.tile_size = tile_size
+        # Call the FFI constructor to create the TileLayout object in C++ backend
+        self.__init_handle_by_constructor__(_ffi_api.TileLayout, tile_shape, dim_map, tile_size)
+
+
 # Register the Layout class as a TVM object under the name "tl.Layout"
 @tvm_ffi.register_object("tl.Layout")
 class Layout(Node):
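
The reshaping rule introduced above splits every dimension named in dim_map into ceildiv(extent, tile) blocks and appends the tile extents as trailing dimensions. A minimal pure-Python sketch of the same rule (math.ceil stands in for T.ceildiv, which really returns a PrimExpr; the name tiled_shape is hypothetical):

    import math

    def tiled_shape(shape, tile_size, dim_map):
        # Replace each mapped dimension by its block count, then append the
        # tile extents as trailing dimensions, mirroring the "normal" path of
        # handle_shape_tuple_when_alloc_shared_with_tileview.
        out = list(shape)
        for a_tile, a_dim in zip(tile_size, dim_map):
            out[a_dim] = math.ceil(shape[a_dim] / a_tile)
        return tuple(out) + tuple(tile_size)

    # (128, 32) with the default 2-D tile (32, 32) over both dimensions:
    assert tiled_shape((128, 32), (32, 32), (0, 1)) == (4, 1, 32, 32)
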
From 0266f67440aa255ba6475e45271884349d8af417 Mon Sep 17 00:00:00 2001
From: xudemin
Date: Mon, 12 Jan 2026 20:13:30 +0800
Subject: [PATCH 2/5] tile_based_on_share

---
 src/layout/layout.cc          |  2 +-
 src/layout/layout.h           | 10 +++++++++-
 tilelang/language/allocate.py | 28 ++++++++++++++++----------
 tilelang/layout/layout.py     |  8 +++-----
 4 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/src/layout/layout.cc b/src/layout/layout.cc
index 6204824fd..501db6840 100644
--- a/src/layout/layout.cc
+++ b/src/layout/layout.cc
@@ -79,7 +79,7 @@ TileLayout::TileLayout(Array<IterVar> input_shape_Iter, Array<IterVar> tile_size_Iter,
     CHECK(is_zero(dim_map_Iter[i]->dom->min));
     dim_map.push_back(dim_map_Iter[i]->dom->extent);
   }
-  auto n = tvm::ffi::make_object<TileLayoutNode>(input_size, tile_size, dim_map);
+  auto n = tvm::ffi::make_object<TileLayoutNode>(input_shape, tile_size, dim_map);
   data_ = std::move(n);
 }
 
diff --git a/src/layout/layout.h b/src/layout/layout.h
index 510e885c6..c0c33a3cb 100644
--- a/src/layout/layout.h
+++ b/src/layout/layout.h
@@ -43,11 +43,19 @@ class TileLayout;
 class Fragment;
 
 class TileLayoutNode : public Object {
+public:
   TileLayoutNode() = default;
   TileLayoutNode(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map);
+
+  virtual std::string DebugOutput() const;
+
   static void RegisterReflection();
-  TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", TileLayoutNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.TileLayout", TileLayoutNode, Object);
   static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+
+  Array<PrimExpr> input_shape_;
+  Array<PrimExpr> tile_size_;
+  Array<PrimExpr> dim_map_;
 };
 
diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py
index 3c4d0140c..b3ed3e028 100644
--- a/tilelang/language/allocate.py
+++ b/tilelang/language/allocate.py
@@ -22,10 +22,11 @@
 from typing_extensions import TypeVarTuple, Unpack  # type: ignore
 from tilelang import tvm as tvm
 from tvm.script import tir as T
-from tvm.tir import PrimExpr
+from tvm.tir import PrimExpr, IntImm
 from tvm.script.parser.tir import block_attr
 from tvm.tir.buffer import Buffer
-from tvm.tir.expr import FloatImm, IntImm
+from tvm.tir.expr import FloatImm
+from tvm.ir import Array
 from .v2.dtypes import dtype as tl_dtype
 from .v2.builder import OutTensor
 from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
@@ -34,7 +35,8 @@
 _DType = TypeVar('_DType')
 
 
-def handle_shape_tuple_when_alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]], tile_size=None, dim_map=None) -> tuple[Unpack[_Shapes]]:
+def handle_shape_tuple_when_alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]], tile_size=None, dim_map=None) -> tuple:
+    """Returns (newShape, tile_size, dim_map) after validation and default value processing."""
     # check shape
     if shape is None:
         raise ValueError("shape must be a non-empty tuple")
@@ -79,25 +81,31 @@
     # len(shape) = 1
     if len(shape) == 1:
         if len(tile_size) == 1:
-            return (T.ceildiv(shape[0], tile_size[0]), tile_size[0])
+            return ((T.ceildiv(shape[0], tile_size[0]), tile_size[0]), tile_size, dim_map)
         if tile_size[0] == 1:
-            return (T.ceildiv(shape[0], tile_size[0]), 1, tile_size[0])
+            return ((T.ceildiv(shape[0], tile_size[0]), 1, tile_size[0]), tile_size, dim_map)
         if tile_size[1] == 1:
-            return (T.ceildiv(shape[0], tile_size[0]), tile_size[0], 1)
+            return ((T.ceildiv(shape[0], tile_size[0]), tile_size[0], 1), tile_size, dim_map)
     # normal
     shapeList = list(shape)
     for a_tile, a_dim in zip(tile_size, dim_map):
         shapeList[a_dim] = T.ceildiv(shape[a_dim], a_tile)
         shapeList += [a_tile]
-    return tuple(shapeList)
+    return (tuple(shapeList), tile_size, dim_map)
+
+
+from tilelang.layout import TileLayout
 
-# from tilelang.layout import TileLayout
 def alloc_shared_with_tileview(shape: tuple[Unpack[_Shapes]], dtype: _DType, tile_size=None, dim_map=None, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
-    newShape = handle_shape_tuple_when_alloc_shared_with_tileview(shape=shape, tile_size=tile_size, dim_map=dim_map)
+    newShape, tile_size, dim_map = handle_shape_tuple_when_alloc_shared_with_tileview(shape=shape, tile_size=tile_size, dim_map=dim_map)
     if dtype == "bool":
         scope = "shared"
     buffer = T.alloc_buffer(shape, dtype, scope=scope)
-    tileLayout = TileLayout(newShape, dim_map, tile_size)
+    # Convert tuples to TVM Arrays and ensure elements are PrimExpr for correct FFI overload resolution
+    # Convert int to IntImm to match the Array<PrimExpr> constructor signature
+    newShape_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in newShape]
+    tile_size_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in tile_size]
+    dim_map_expr = [e if isinstance(e, PrimExpr) else IntImm("int32", e) for e in dim_map]
+    tileLayout = TileLayout(Array(newShape_expr), Array(tile_size_expr), Array(dim_map_expr))
     # block_attr({"tile_view": {buffer.data: {"tile_size": tile_size, "dim_map": dim_map, "tiled_shape": newShape}}})
     block_attr({"tile_view": {buffer.data: tileLayout}})
     return buffer
diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py
index 94d79febf..198e008e4 100644
--- a/tilelang/layout/layout.py
+++ b/tilelang/layout/layout.py
@@ -9,12 +9,10 @@
 # Register the TileLayout class as a TVM object under the name "tl.TileLayout"
 @tvm_ffi.register_object("tl.TileLayout")
 class TileLayout(Node):
-    def __init__(self, tile_shape, dim_map, tile_size):
-        self.tile_shape = tile_shape
-        self.dim_map = dim_map
-        self.tile_size = tile_size
+    def __init__(self, input_shape, tile_size, dim_map):
         # Call the FFI constructor to create the TileLayout object in C++ backend
-        self.__init_handle_by_constructor__(_ffi_api.TileLayout, tile_shape, dim_map, tile_size)
+        # Parameter order matches C++: TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map)
+        self.__init_handle_by_constructor__(_ffi_api.TileLayout, input_shape, tile_size, dim_map)
 
 
 # Register the Layout class as a TVM object under the name "tl.Layout"
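
After this revision the helper returns the fully resolved (newShape, tile_size, dim_map) triple, so the defaults it fills in become visible to the caller. Illustrative results, with plain ints standing in for the T.ceildiv PrimExprs the real code produces:

    # 2-D shape, no tile_size/dim_map given: defaults to a (32, 32) tile
    # over the last two dimensions.
    handle_shape_tuple_when_alloc_shared_with_tileview((128, 64))
    # -> ((4, 2, 32, 32), (32, 32), (0, 1))

    # 1-D shape: defaults to tile_size (32, 1) and dim_map (0,); the
    # dedicated 1-D branch pads the result to three dimensions.
    handle_shape_tuple_when_alloc_shared_with_tileview((128,))
    # -> ((4, 32, 1), (32, 1), (0,))
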
From 92bf7f3051d13fc647e997aa19bda9dae1b082d6 Mon Sep 17 00:00:00 2001
From: xudemin
Date: Tue, 13 Jan 2026 11:00:03 +0800
Subject: [PATCH 3/5] tile_based_on_share

---
 src/layout/layout.cc | 28 +++-------------------------
 src/layout/layout.h  |  2 +-
 2 files changed, 4 insertions(+), 26 deletions(-)

diff --git a/src/layout/layout.cc b/src/layout/layout.cc
index 501db6840..586cd11d7 100644
--- a/src/layout/layout.cc
+++ b/src/layout/layout.cc
@@ -61,28 +61,6 @@ TileLayoutNode::TileLayoutNode(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size,
   dim_map_ = dim_map;
 }
 
-TileLayout::TileLayout(Array<IterVar> input_shape_Iter, Array<IterVar> tile_size_Iter, Array<IterVar> dim_map_Iter) {
-  Map<Var, PrimExpr> vmap;
-  Array<PrimExpr> input_shape, tile_size, dim_map;
-  for (size_t i = 0; i < input_shape_Iter.size(); i++) {
-    vmap.Set(input_shape_Iter[i]->var, InputPlaceholder(i));
-    CHECK(is_zero(input_shape_Iter[i]->dom->min));
-    input_shape.push_back(input_shape_Iter[i]->dom->extent);
-  }
-  for (size_t i = 0; i < tile_size_Iter.size(); i++) {
-    vmap.Set(tile_size_Iter[i]->var, InputPlaceholder(i));
-    CHECK(is_zero(tile_size_Iter[i]->dom->min));
-    tile_size.push_back(tile_size_Iter[i]->dom->extent);
-  }
-  for (size_t i = 0; i < dim_map_Iter.size(); i++) {
-    vmap.Set(dim_map_Iter[i]->var, InputPlaceholder(i));
-    CHECK(is_zero(dim_map_Iter[i]->dom->min));
-    dim_map.push_back(dim_map_Iter[i]->dom->extent);
-  }
-  auto n = tvm::ffi::make_object<TileLayoutNode>(input_shape, tile_size, dim_map);
-  data_ = std::move(n);
-}
-
 Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
   Map<Var, PrimExpr> vmap;
   Array<PrimExpr> input_size;
@@ -770,9 +748,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef()
       .def_packed("tl.TileLayout",
                   [](PackedArgs args, Any *rv) {
-                    *rv = TileLayout(args[0].cast<Array<IterVar>>(),
-                                     args[1].cast<Array<IterVar>>(),
-                                     args[2].cast<Array<IterVar>>());
+                    *rv = TileLayout(args[0].cast<Array<PrimExpr>>(),
+                                     args[1].cast<Array<PrimExpr>>(),
+                                     args[2].cast<Array<PrimExpr>>());
                   })
       .def_packed("tl.Layout",
                   [](PackedArgs args, Any *rv) {
diff --git a/src/layout/layout.h b/src/layout/layout.h
index c0c33a3cb..ba8b0ac8f 100644
--- a/src/layout/layout.h
+++ b/src/layout/layout.h
@@ -112,7 +112,7 @@ class LayoutNode : public Object {
 
 class TileLayout : public ObjectRef {
 public:
-  TVM_DLL TileLayout(Array<IterVar> input_shape_Iter, Array<IterVar> tile_size_Iter, Array<IterVar> dim_map_Iter);
+  // Only keep the PrimExpr version to avoid FFI overload resolution issues
   TVM_DLL TileLayout(Array<PrimExpr> input_shape, Array<PrimExpr> tile_size, Array<PrimExpr> dim_map);
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileLayout, ObjectRef, TileLayoutNode);
 };
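
With the IterVar overload removed, tl.TileLayout accepts only Array<PrimExpr> arguments, so Python callers have to box plain ints the way allocate.py does since patch 2. A standalone sketch of that conversion (the wrapper name make_tile_layout is hypothetical):

    from tvm.ir import Array
    from tvm.tir import IntImm, PrimExpr
    from tilelang.layout import TileLayout

    def make_tile_layout(input_shape, tile_size, dim_map):
        # Box plain Python ints as int32 IntImm so the packed FFI call
        # resolves to the Array<PrimExpr> constructor on the C++ side.
        def to_exprs(xs):
            return Array([x if isinstance(x, PrimExpr) else IntImm("int32", x) for x in xs])
        return TileLayout(to_exprs(input_shape), to_exprs(tile_size), to_exprs(dim_map))
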
From e853ae0cb90b3dc8b12b41b0aa452ea40d097b7f Mon Sep 17 00:00:00 2001
From: xudemin
Date: Tue, 13 Jan 2026 11:03:12 +0800
Subject: [PATCH 4/5] tile_based_on_share

---
 examples/gemm/example_gemm.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py
index 4327cd016..f18cd388a 100644
--- a/examples/gemm/example_gemm.py
+++ b/examples/gemm/example_gemm.py
@@ -1,7 +1,6 @@
 import tilelang
 import tilelang.language as T
 
-tilelang.disable_cache()
 
 @tilelang.jit(out_idx=[-1])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@@ -13,10 +12,8 @@ def gemm(
         C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
-            # A_shared = T.alloc_shared((block_M, block_K), dtype)
-            # B_shared = T.alloc_shared((block_K, block_N), dtype)
-            A_shared = T.alloc_shared_with_tileview((block_M, block_K), dtype)
-            B_shared = T.alloc_shared_with_tileview((block_K, block_N), dtype)
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
 
             T.clear(C_local)

From 1f2b37df40a10c2169e19752bd4309a52761d648 Mon Sep 17 00:00:00 2001
From: xudemin
Date: Tue, 13 Jan 2026 11:29:06 +0800
Subject: [PATCH 5/5] tile_based_on_share

---
 testing/python/language/test_alloc_tile.py | 66 ++++++++++++++++++++++
 1 file changed, 66 insertions(+)
 create mode 100644 testing/python/language/test_alloc_tile.py

diff --git a/testing/python/language/test_alloc_tile.py b/testing/python/language/test_alloc_tile.py
new file mode 100644
index 000000000..281abceda
--- /dev/null
+++ b/testing/python/language/test_alloc_tile.py
@@ -0,0 +1,66 @@
+import tilelang
+import tilelang.language as T
+
+
+tilelang.disable_cache()
+
+
+@tilelang.jit(out_idx=[-1])
+def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+
+    @T.prim_func
+    def gemm(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared_with_tileview((block_M, block_K), dtype)
+            B_shared = T.alloc_shared_with_tileview((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local)
+
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return gemm
+
+
+def main():
+    kernel = matmul(1024, 1024, 1024, 128, 128, 32)
+
+    import torch
+
+    a = torch.randn(1024, 1024).cuda().half()
+    b = torch.randn(1024, 1024).cuda().half()
+
+    c = kernel(a, b)
+
+    ref_c = a @ b
+
+    print("c:")
+    print(c)
+    print("ref_c:")
+    print(ref_c)
+
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+    print("All checks passed.")
+
+    # Get CUDA source
+    print("CUDA Source:")
+    print(kernel.get_kernel_source())
+
+    # benchmark
+    profiler = kernel.get_profiler()
+    latency = profiler.do_bench(backend="cupti")
+    # latency = profiler.do_bench()
+    print(f"tilelang Latency: {latency}ms")
+
+
+if __name__ == "__main__":
+    main()
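
For reference, alloc_shared_with_tileview also takes an explicit tile_size and dim_map; the buffer keeps its original shape, and only the tile_view block attribute records the tiled layout. A hedged sketch following the validation rules above (tile extents must be nonzero, dim_map entries unique and in range, one tile extent per mapped dimension):

    # Tile only the K dimension of a (block_M, block_K) buffer in 16-wide strips.
    A_shared = T.alloc_shared_with_tileview((block_M, block_K), "float16",
                                            tile_size=(16,), dim_map=(1,))
    # Buffer shape stays (block_M, block_K); the attached TileLayout records
    # input_shape (block_M, ceildiv(block_K, 16), 16), tile_size (16,), dim_map (1,).
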