From 0c317452ad7eceeaec68e93534db449c135cb5d4 Mon Sep 17 00:00:00 2001 From: zongzhengWei Date: Thu, 22 Jan 2026 15:42:29 +0800 Subject: [PATCH 1/4] add Tiles() in loop.py --- tilelang/language/loop.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py index 3478b6cc1..e728ea4e2 100644 --- a/tilelang/language/loop.py +++ b/tilelang/language/loop.py @@ -32,6 +32,30 @@ def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None): return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member +def Tiles(shared_buf: tir.Buffer, parallel: bool = False): + """Tools to construct tiled for loop over a shared memory buffer. + + Parameters + ---------- + shared_buf : tir.Buffer + The shared memory buffer to be tiled. + + parallel : bool + Whether to generate a parallel tiled loop. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + annotations = { + "tile_level_loop": tir.IntImm("int32", 1), + "tile.parallel": tir.IntImm("int32", 1 if parallel else 0), + } + + return _ffi_api.Parallel(tuple(shared_buf.shape), annotations) # type: ignore[attr-defined] # pylint: disable=no-member + + def Persistent( domain: list[tir.PrimExpr], wave_size: tir.PrimExpr, From e80db8a9bf5a43ca2c409bcd4bdf21cb2af0fe39 Mon Sep 17 00:00:00 2001 From: zongzhengWei Date: Fri, 6 Feb 2026 11:02:28 +0800 Subject: [PATCH 2/4] WIP: tiles loop related work, added legalize_tiles_loop.cc, tiles_loop.cc, modified src/ir.cc language/__init__.py, language/loop.py, transform/__init__.py --- src/ir.cc | 27 ++++ src/transform/legalize_tiles_loop.cc | 196 +++++++++++++++++++++++ src/transform/tiles_loop.cc | 230 +++++++++++++++++++++++++++ tilelang/language/__init__.py | 1 + tilelang/language/loop.py | 10 +- tilelang/transform/__init__.py | 22 +++ 6 files changed, 480 insertions(+), 6 deletions(-) create mode 100644 src/transform/legalize_tiles_loop.cc create mode 100644 src/transform/tiles_loop.cc diff --git a/src/ir.cc b/src/ir.cc index 3d2b3ecdc..1ef1c1edd 100644 --- a/src/ir.cc +++ b/src/ir.cc @@ -78,6 +78,32 @@ ForFrame ParallelFor(const Array &extents, }; return ForFrame(n); } +ForFrame TilesFor(const Array &extents, + const Map &annotations) { + using namespace tvm::tir; + ObjectPtr n = tvm::ffi::make_object(); + n->vars.reserve(extents.size()); + n->doms.reserve(extents.size()); + for (const auto &extent : extents) { + DataType dtype = extent.dtype(); + n->vars.push_back(Var("v", extent.dtype())); + n->doms.push_back(Range(make_const(dtype, 0), extent)); + } + n->f_make_for_loop = [annotations](const Array &vars, + const Array &doms, + Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), doms.size()); + int n = vars.size(); + for (int i = n - 1; i >= 0; --i) { + Range dom = doms[i]; + Var var = vars[i]; + body = For(var, dom->min, dom->extent, ForKind::kSerial, body, + /*thread_binding=*/std::nullopt, /*annotations=*/annotations); + } + return body; + }; + return ForFrame(n); +} ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, const Array &order, @@ -302,6 +328,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tl.Parallel", ParallelFor) + .def("tl.Tiles", TilesFor) .def("tl.Pipelined", PipelinedFor) .def("tl.Persistent", PersistentFor) .def("tl.KernelLaunch", KernelLaunch); diff --git a/src/transform/legalize_tiles_loop.cc b/src/transform/legalize_tiles_loop.cc new file mode 100644 index 000000000..0c606b18a --- /dev/null +++ b/src/transform/legalize_tiles_loop.cc @@ -0,0 +1,196 @@ +#include + +#include +#include +#include + +#include "../support/ffi_aliases.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +namespace attr { +// ---- block-level ---- +constexpr const char *tile_view = "tile_view"; +constexpr const char *tile_size = "tile_size"; +constexpr const char *dim_map = "dim_map"; +constexpr const char *new_shape = "new_shape"; + +// ---- loop-level ---- +constexpr const char *tile_level_loop = "tile_level_loop"; +constexpr const char *tiled_buffer = "tiled_buffer"; + +// ---- added by this pass ---- +constexpr const char *tile_execution = "tile_execution"; +constexpr const char *tile_new_shape = "tile_new_shape"; +constexpr const char *tile_tile_size = "tile.tile_size"; +constexpr const char *tile_dim_map = "tile.dim_map"; +} // namespace attr + +/* ============================================================ + * Collector + * + * Collect block-level tile_view: + * Map> + * ============================================================ */ +class TileViewCollector : public StmtExprVisitor { +public: + using TileViewMap = std::unordered_map, + ObjectPtrHash, ObjectPtrEqual>; + + /*! \brief Entry point */ + static TileViewMap Collect(const PrimFunc &f) { + TileViewCollector collector; + collector(f->body); + return std::move(collector.tile_views_); + } + +private: + /*! \brief Collect tile_view annotations from BlockNode */ + void VisitStmt_(const BlockNode *block) final { + auto it = block->annotations.find(attr::tile_view); + if (it != block->annotations.end()) { + auto tile_view = Downcast>>((*it).second); + + for (const auto &kv : tile_view) { + // kv.first : buffer.data (Var) + // kv.second : {"tile_size", "dim_map", "new_shape"} + + auto res = tile_views_.emplace(kv.first, kv.second); + ICHECK(res.second) << "Duplicate tile_view for buffer " << kv.first; + } + } + StmtExprVisitor::VisitStmt_(block); + } + +private: + TileViewMap tile_views_; +}; + +/* ============================================================ + * Rewriter + * + * Rewrite tile-level For loops: + * extent := new_shape[tile_dim] + * ============================================================ */ +class LegalizeTilesLoopRewriter : public StmtExprMutator { +public: + using TileViewMap = std::unordered_map, + ObjectPtrHash, ObjectPtrEqual>; + + /*! \brief Entry point */ + static PrimFunc Rewrite(PrimFunc f) { + LegalizeTilesLoopRewriter rewriter; + rewriter.tile_views_ = TileViewCollector::Collect(f); + LOG(INFO) << "Collected " << rewriter.tile_views_.size() << " tile_view(s)" + << " in LegalizeTilesLoopRewriter." << std::endl; + // Fast path: no tile_view, nothing to do + if (rewriter.tile_views_.empty()) { + return f; + } + + f.CopyOnWrite()->body = rewriter(f->body); + return f; + } + +private: + /*! \brief Rewrite tile-level for loops */ + Stmt VisitStmt_(const ForNode *loop) final { + // Only care about tile-level loops + if (!loop->annotations.count(attr::tile_level_loop)) { + return StmtExprMutator::VisitStmt_(loop); + } + + // Must have tiled_buffer + auto buf_it = loop->annotations.find(attr::tiled_buffer); + if (buf_it == loop->annotations.end()) { + return StmtExprMutator::VisitStmt_(loop); + } + + Var buffer_data = Downcast((*buf_it).second); + LOG(INFO) << "Legalizing tile loop for buffer " << buffer_data; + auto view_it = tile_views_.find(buffer_data); + if (view_it == tile_views_.end()) { + // No tile_view for this buffer + return StmtExprMutator::VisitStmt_(loop); + } + + // Enter tile loop (MVP assumption: nesting order == tile dim order) + int dim = tile_loop_depth_++; + Stmt new_body = VisitStmt(loop->body); + tile_loop_depth_--; + + const Map &view = view_it->second; + + ObjectRef obj = view.at(attr::new_shape); + auto arr = Downcast>(obj); + + LOG(INFO) << "new_shape raw size: " << arr.size(); + + for (int i = 0; i < arr.size(); i++) { + LOG(INFO) << " new_shape[" << i << "] type: " << arr[i]->GetTypeKey(); + } + + auto new_shape_opt = obj.as>(); + LOG(INFO) << "tile_view.new_shape must be Array, we getted " + << obj->GetTypeKey(); + + Array new_shape = new_shape_opt.value(); + + // auto new_shape = + // Downcast>(view.at(attr::new_shape)); + LOG(INFO) << " - Retrieved tile_view for buffer " << buffer_data + << " at tile dim " << dim << " with new_shape size " + << new_shape.size(); + ICHECK(dim < static_cast(new_shape.size())) + << "Tile loop depth exceeds new_shape rank"; + LOG(INFO) << " - Setting loop extent to new_shape[" << dim + << "] = " << new_shape[dim]; + // Rewrite loop + For new_for = ffi::GetRef(loop); + auto *n = new_for.CopyOnWrite(); + n->extent = new_shape[dim]; + n->body = new_body; + + // Attach normalized annotations for later passes + n->annotations.Set(attr::tile_execution, Integer(1)); + n->annotations.Set(attr::tile_new_shape, new_shape); + n->annotations.Set(attr::tile_tile_size, + Downcast>(view.at(attr::tile_size))); + n->annotations.Set(attr::tile_dim_map, + Downcast>(view.at(attr::dim_map))); + + return new_for; + } + +private: + /*! \brief Collected tile_view info */ + TileViewMap tile_views_; + + /*! \brief Tile loop nesting depth (MVP) */ + int tile_loop_depth_{0}; +}; + +using namespace tir::transform; + +tvm::transform::Pass LegalizeTilesLoop() { + auto pass_func = [](PrimFunc f, const IRModule &, + const PassContext &) -> PrimFunc { + return LegalizeTilesLoopRewriter::Rewrite(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, + /*opt_level=*/0, "tl.LegalizeTilesLoop", {}); +} + +/* ============================================================ + * FFI Registration + * ============================================================ */ +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::ffi::reflection::GlobalDef().def("tl.transform.LegalizeTilesLoop", + LegalizeTilesLoop); +} + +} // namespace tl +} // namespace tvm diff --git a/src/transform/tiles_loop.cc b/src/transform/tiles_loop.cc new file mode 100644 index 000000000..20fbde477 --- /dev/null +++ b/src/transform/tiles_loop.cc @@ -0,0 +1,230 @@ +#include +#include +#include +#include + +#include "../support/ffi_aliases.h" + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Annotation keys used by Tiles lowering pipeline. + * + * NOTE: + * - tile_execution : marks loops originating from T.Tiles() + * - tile.tile_size : 2D tile size, attached by LegalizeTilesLoop + */ +namespace attr { +constexpr const char *tile_execution = "tile_execution"; +constexpr const char *tile_tile_size = "tile.tile_size"; +} // namespace attr + +/*! + * \brief TilesLoopRewriter + * + * This pass performs the lowering of T.Tiles() into + * serial + vectorized inner loops. + * + * Design principles: + * ------------------ + * 1. Post-order traversal: + * - Always visit loop body first. + * - Structural decisions are made only after the body is stable. + * + * 2. Scope gating via annotation: + * - Only loops marked with `tile_execution` are considered. + * - Other loops (e.g. T.parallel, normal serial loops) are ignored. + * + * 3. Structural matching (not annotation-driven semantics): + * - Actual lowering only happens when we see: + * + * for i (serial, tile_execution): + * for j (serial, tile_execution): + * BODY + * + * - i, j are assumed to be the 2D tile-execution axes. + * + * 4. Exactly-once lowering: + * - The inner `for j` is consumed during lowering. + * - Lowering happens exactly once per T.Tiles(). + */ +class TilesLoopRewriter : public StmtExprMutator { +public: + static PrimFunc Rewrite(PrimFunc f) { + LOG(INFO) << "[TilesLoop] Start rewriting PrimFunc"; + TilesLoopRewriter rewriter; + f.CopyOnWrite()->body = rewriter(f->body); + LOG(INFO) << "[TilesLoop] Finished rewriting PrimFunc"; + return f; + } + +private: + /*! \brief Check whether a loop belongs to T.Tiles() scope */ + bool IsTilesScope(const ForNode *loop) const { + return loop->annotations.count(attr::tile_execution); + } + + /*! \brief Check whether a loop is a serial loop */ + bool IsSerialFor(const ForNode *loop) const { + return loop->kind == ForKind::kSerial; + } + + /*! \brief Read tile size annotation if present */ + Optional> GetTileSize(const ForNode *loop) const { + auto it = loop->annotations.find(attr::tile_tile_size); + if (it == loop->annotations.end()) { + return std::nullopt; + } + return Downcast>((*it).second); + } + + /*! + * \brief Update loop body while preserving other fields. + * + * This helper avoids unnecessary CopyOnWrite when body is unchanged. + */ + Stmt UpdateBody(const ForNode *loop, Stmt new_body) { + if (new_body.same_as(loop->body)) { + return ffi::GetRef(loop); + } + For f = ffi::GetRef(loop); + f.CopyOnWrite()->body = new_body; + return f; + } + + /*! + * \brief Visit ForNode (post-order). + * + * Execution order: + * 1. Recursively visit body. + * 2. Gate by tile_execution. + * 3. Try to match 2D tile pattern. + * 4. Perform lowering if matched. + */ + Stmt VisitStmt_(const ForNode *loop) final { + // ------------------------------------------------------------ + // (1) Post-order: first visit the body + // ------------------------------------------------------------ + Stmt new_body = VisitStmt(loop->body); + + // ------------------------------------------------------------ + // (2) Scope gate: only care about T.Tiles() loops + // ------------------------------------------------------------ + if (!IsTilesScope(loop)) { + // Not a Tiles loop: just propagate rewritten body + return UpdateBody(loop, new_body); + } + + LOG(INFO) << "[TilesLoop] Visiting tile loop: " + << loop->loop_var->name_hint; + + // ------------------------------------------------------------ + // (3) Structural pattern matching: + // for i: + // for j: + // BODY + // ------------------------------------------------------------ + const ForNode *inner = new_body.as(); + if (!inner) { + // Body is not a loop → cannot form a 2D tile + LOG(INFO) << "[TilesLoop] Body is not a ForNode, skip lowering"; + return UpdateBody(loop, new_body); + } + + if (!IsSerialFor(loop) || !IsSerialFor(inner)) { + // Only handle serial-serial pattern + LOG(INFO) << "[TilesLoop] Non-serial loop detected, skip lowering"; + return UpdateBody(loop, new_body); + } + + if (!IsTilesScope(inner)) { + // Inner loop must also belong to Tiles scope + LOG(INFO) + << "[TilesLoop] Inner loop is not tile_execution, skip lowering"; + return UpdateBody(loop, new_body); + } + + // ------------------------------------------------------------ + // (4) Read tile size (must exist and be 2D) + // ------------------------------------------------------------ + auto tile_size_opt = GetTileSize(loop); + if (!tile_size_opt.defined()) { + LOG(INFO) << "[TilesLoop] Missing tile_tile_size, skip lowering"; + return UpdateBody(loop, new_body); + } + + Array tile_size = tile_size_opt.value(); + ICHECK_EQ(tile_size.size(), 2) << "TilesLoop expects exactly 2D tile_size"; + + LOG(INFO) << "[TilesLoop] Performing 2D tile lowering"; + + // ------------------------------------------------------------ + // (5) Perform tile lowering + // ------------------------------------------------------------ + Var ti = loop->loop_var; + Var tj = inner->loop_var; + + // Tile-inner loop variables + Var ki("ki"); + Var kj("kj"); + + // Index substitution: + // i -> i * Ts0 + ki + // j -> j * Ts1 + kj + Map vmap; + vmap.Set(ti, ti * tile_size[0] + ki); + vmap.Set(tj, tj * tile_size[1] + kj); + + // Apply substitution to the original tile body + Stmt tiled_body = Substitute(inner->body, vmap); + + // Construct inner tile loops: + // ki : serial + // kj : vectorized + tiled_body = For(kj, 0, tile_size[1], ForKind::kVectorized, tiled_body); + + tiled_body = For(ki, 0, tile_size[0], ForKind::kSerial, tiled_body); + + // Replace the original inner loop body,j loop is consumed + For new_inner = ffi::GetRef(inner); + new_inner.CopyOnWrite()->body = tiled_body; + // Replace the original outer loop body, i loop remains + For new_outer = ffi::GetRef(loop); + new_outer.CopyOnWrite()->body = new_inner; + + LOG(INFO) << "[TilesLoop] Tile lowering done at loop: " + << loop->loop_var->name_hint; + + return new_outer; + } +}; + +using namespace tir::transform; + +/*! + * \brief Create TilesLoop pass. + */ +Pass TilesLoop() { + auto pass_func = [](PrimFunc f, const IRModule &, + const PassContext &) -> PrimFunc { + return TilesLoopRewriter::Rewrite(std::move(f)); + }; + + return CreatePrimFuncPass(pass_func, + /*opt_level=*/0, "tl.TilesLoop", {}); +} + +/*! + * \brief FFI registration. + */ +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::ffi::reflection::GlobalDef().def("tl.transform.TilesLoop", TilesLoop); +} + +} // namespace tl +} // namespace tvm diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 0ea385f5b..2b254f4f6 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -28,6 +28,7 @@ ) from .loop import ( Parallel, # noqa: F401 + Tiles, # noqa: F401 Persistent, # noqa: F401 Pipelined, # noqa: F401 serial, # noqa: F401 diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py index e728ea4e2..3b663fd98 100644 --- a/tilelang/language/loop.py +++ b/tilelang/language/loop.py @@ -48,12 +48,10 @@ def Tiles(shared_buf: tir.Buffer, parallel: bool = False): res : frame.ForFrame The ForFrame. """ - annotations = { - "tile_level_loop": tir.IntImm("int32", 1), - "tile.parallel": tir.IntImm("int32", 1 if parallel else 0), - } - - return _ffi_api.Parallel(tuple(shared_buf.shape), annotations) # type: ignore[attr-defined] # pylint: disable=no-member + annotations = {"tile_level_loop": tir.IntImm("int32", 0), "tiled_buffer": shared_buf.data} + if parallel: + annotations.update({"tile_level_loop": tir.IntImm("int32", 1)}) + return _ffi_api.Tiles(tuple(shared_buf.shape), annotations) # type: ignore[attr-defined] # pylint: disable=no-member def Persistent( diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 0216536c2..80cd7ed81 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -271,6 +271,28 @@ def LegalizeVectorizedLoop(): return _ffi_api.LegalizeVectorizedLoop() # type: ignore +def LegalizeTilesLoop(): + """LegalizeTilesLoop + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LegalizeTilesLoop() # type: ignore + + +def TilesLoop(): + """TilesLoop + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.TilesLoop() # type: ignore + + def LegalizeSafeMemoryAccess(): """LegalizeLoopVectorize From f9d563d1369659b911c041966b586182ec5a51a8 Mon Sep 17 00:00:00 2001 From: zongzhengWei Date: Tue, 10 Feb 2026 09:57:14 +0800 Subject: [PATCH 3/4] adapted to the TileView metadata and fixed some bugs --- src/transform/legalize_tiles_loop.cc | 164 +++++++++++++-------------- src/transform/tiles_loop.cc | 24 ++-- tilelang/engine/phase.py | 4 + 3 files changed, 94 insertions(+), 98 deletions(-) diff --git a/src/transform/legalize_tiles_loop.cc b/src/transform/legalize_tiles_loop.cc index 0c606b18a..ab86801f4 100644 --- a/src/transform/legalize_tiles_loop.cc +++ b/src/transform/legalize_tiles_loop.cc @@ -5,89 +5,87 @@ #include #include "../support/ffi_aliases.h" +#include "../tileview/tileview.h" namespace tvm { namespace tl { using namespace tir; +/* ============================================================ + * Attributes + * ============================================================ */ namespace attr { -// ---- block-level ---- -constexpr const char *tile_view = "tile_view"; -constexpr const char *tile_size = "tile_size"; -constexpr const char *dim_map = "dim_map"; -constexpr const char *new_shape = "new_shape"; - -// ---- loop-level ---- +// ---- loop-level (existing) ---- constexpr const char *tile_level_loop = "tile_level_loop"; constexpr const char *tiled_buffer = "tiled_buffer"; -// ---- added by this pass ---- -constexpr const char *tile_execution = "tile_execution"; -constexpr const char *tile_new_shape = "tile_new_shape"; +// ---- added / normalized by this pass ---- +// Mark the loops corresponding to the index map(index_map=(-2, -1)) for +// subsequent passes +constexpr const char *tile_execution_loop = "tile.execution"; +constexpr const char *tile_new_shape = "tile.buffer_new_shape"; constexpr const char *tile_tile_size = "tile.tile_size"; constexpr const char *tile_dim_map = "tile.dim_map"; } // namespace attr /* ============================================================ - * Collector + * TileView Collector * - * Collect block-level tile_view: - * Map> + * Collect block-level: + * block.annotations["tileview_map"] + * : Map * ============================================================ */ class TileViewCollector : public StmtExprVisitor { public: - using TileViewMap = std::unordered_map, - ObjectPtrHash, ObjectPtrEqual>; + using TileViewMap = + std::unordered_map; - /*! \brief Entry point */ static TileViewMap Collect(const PrimFunc &f) { TileViewCollector collector; collector(f->body); - return std::move(collector.tile_views_); + return std::move(collector.tileviews_); } private: - /*! \brief Collect tile_view annotations from BlockNode */ void VisitStmt_(const BlockNode *block) final { - auto it = block->annotations.find(attr::tile_view); + auto it = block->annotations.find(attr::kTileViewMap); if (it != block->annotations.end()) { - auto tile_view = Downcast>>((*it).second); - - for (const auto &kv : tile_view) { - // kv.first : buffer.data (Var) - // kv.second : {"tile_size", "dim_map", "new_shape"} - - auto res = tile_views_.emplace(kv.first, kv.second); - ICHECK(res.second) << "Duplicate tile_view for buffer " << kv.first; + auto tv_map = Downcast>((*it).second); + for (const auto &kv : tv_map) { + auto res = tileviews_.emplace(kv.first, kv.second); + ICHECK(res.second) << "Duplicate TileView for buffer " << kv.first; } } StmtExprVisitor::VisitStmt_(block); } private: - TileViewMap tile_views_; + TileViewMap tileviews_; }; /* ============================================================ - * Rewriter + * LegalizeTilesLoopRewriter * * Rewrite tile-level For loops: - * extent := new_shape[tile_dim] + * for ... in T.Tiles(...) + * into: + * extent := TileView::TiledBufferShape()[tile_dim] + * + * Assumptions: + * - Tile loop nesting order == TileView dimension order + * - TileView already validated semantic correctness * ============================================================ */ class LegalizeTilesLoopRewriter : public StmtExprMutator { public: - using TileViewMap = std::unordered_map, - ObjectPtrHash, ObjectPtrEqual>; + using TileViewMap = + std::unordered_map; - /*! \brief Entry point */ static PrimFunc Rewrite(PrimFunc f) { LegalizeTilesLoopRewriter rewriter; - rewriter.tile_views_ = TileViewCollector::Collect(f); - LOG(INFO) << "Collected " << rewriter.tile_views_.size() << " tile_view(s)" - << " in LegalizeTilesLoopRewriter." << std::endl; - // Fast path: no tile_view, nothing to do - if (rewriter.tile_views_.empty()) { + rewriter.tileviews_ = TileViewCollector::Collect(f); + + if (rewriter.tileviews_.empty()) { return f; } @@ -96,90 +94,88 @@ class LegalizeTilesLoopRewriter : public StmtExprMutator { } private: - /*! \brief Rewrite tile-level for loops */ Stmt VisitStmt_(const ForNode *loop) final { - // Only care about tile-level loops + // Only rewrite tile-level loops if (!loop->annotations.count(attr::tile_level_loop)) { return StmtExprMutator::VisitStmt_(loop); } - // Must have tiled_buffer + // Must be associated with a tiled buffer auto buf_it = loop->annotations.find(attr::tiled_buffer); if (buf_it == loop->annotations.end()) { return StmtExprMutator::VisitStmt_(loop); } Var buffer_data = Downcast((*buf_it).second); - LOG(INFO) << "Legalizing tile loop for buffer " << buffer_data; - auto view_it = tile_views_.find(buffer_data); - if (view_it == tile_views_.end()) { - // No tile_view for this buffer + + auto tv_it = tileviews_.find(buffer_data); + if (tv_it == tileviews_.end()) { return StmtExprMutator::VisitStmt_(loop); } - // Enter tile loop (MVP assumption: nesting order == tile dim order) + const TileView &tv = tv_it->second; + + // Enter tile loop (depth == tile dimension) int dim = tile_loop_depth_++; Stmt new_body = VisitStmt(loop->body); tile_loop_depth_--; - const Map &view = view_it->second; + Array tiled_shape = tv->TiledBufferShape(); - ObjectRef obj = view.at(attr::new_shape); - auto arr = Downcast>(obj); - - LOG(INFO) << "new_shape raw size: " << arr.size(); - - for (int i = 0; i < arr.size(); i++) { - LOG(INFO) << " new_shape[" << i << "] type: " << arr[i]->GetTypeKey(); - } + ICHECK(dim < static_cast(tiled_shape.size())) + << "Tile loop depth exceeds tiled buffer rank"; - auto new_shape_opt = obj.as>(); - LOG(INFO) << "tile_view.new_shape must be Array, we getted " - << obj->GetTypeKey(); - - Array new_shape = new_shape_opt.value(); - - // auto new_shape = - // Downcast>(view.at(attr::new_shape)); - LOG(INFO) << " - Retrieved tile_view for buffer " << buffer_data - << " at tile dim " << dim << " with new_shape size " - << new_shape.size(); - ICHECK(dim < static_cast(new_shape.size())) - << "Tile loop depth exceeds new_shape rank"; - LOG(INFO) << " - Setting loop extent to new_shape[" << dim - << "] = " << new_shape[dim]; // Rewrite loop For new_for = ffi::GetRef(loop); auto *n = new_for.CopyOnWrite(); - n->extent = new_shape[dim]; + n->extent = tiled_shape[dim]; n->body = new_body; - // Attach normalized annotations for later passes - n->annotations.Set(attr::tile_execution, Integer(1)); - n->annotations.Set(attr::tile_new_shape, new_shape); - n->annotations.Set(attr::tile_tile_size, - Downcast>(view.at(attr::tile_size))); - n->annotations.Set(attr::tile_dim_map, - Downcast>(view.at(attr::dim_map))); + // Attach normalized loop annotations + n->annotations.Set(attr::tile_new_shape, tiled_shape); + n->annotations.Set(attr::tile_tile_size, tv->TileShape()); + n->annotations.Set(attr::tile_dim_map, tv->IndexMap()); + // ---- Determine whether this loop is a tile execution dimension ---- + int buf_ndim = static_cast(tv->BufferShape().size()); + bool is_tile_execution = false; + + for (const PrimExpr &pe : tv->IndexMap()) { + const auto *imm = pe.as(); + ICHECK(imm) << "index_map must contain IntImm"; + + int mapped_dim = static_cast(imm->value); + if (mapped_dim < 0) { + mapped_dim += buf_ndim; + } + + if (mapped_dim == dim) { + is_tile_execution = true; + break; + } + } + + if (is_tile_execution) { + n->annotations.Set(attr::tile_execution_loop, Integer(1)); + } return new_for; } private: - /*! \brief Collected tile_view info */ - TileViewMap tile_views_; - - /*! \brief Tile loop nesting depth (MVP) */ + TileViewMap tileviews_; int tile_loop_depth_{0}; }; +/* ============================================================ + * Pass Registration + * ============================================================ */ using namespace tir::transform; tvm::transform::Pass LegalizeTilesLoop() { - auto pass_func = [](PrimFunc f, const IRModule &, - const PassContext &) -> PrimFunc { + auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) { return LegalizeTilesLoopRewriter::Rewrite(std::move(f)); }; + return CreatePrimFuncPass(pass_func, /*opt_level=*/0, "tl.LegalizeTilesLoop", {}); } diff --git a/src/transform/tiles_loop.cc b/src/transform/tiles_loop.cc index 20fbde477..51434b4a5 100644 --- a/src/transform/tiles_loop.cc +++ b/src/transform/tiles_loop.cc @@ -16,11 +16,11 @@ using namespace tir; * \brief Annotation keys used by Tiles lowering pipeline. * * NOTE: - * - tile_execution : marks loops originating from T.Tiles() - * - tile.tile_size : 2D tile size, attached by LegalizeTilesLoop + * - tile.execution : loops corresponding to the index map(index_map=(-2, -1)) + * - tile.tile_size : 2D tile size, e.g. (32, 32) */ namespace attr { -constexpr const char *tile_execution = "tile_execution"; +constexpr const char *tile_execution = "tile.execution"; constexpr const char *tile_tile_size = "tile.tile_size"; } // namespace attr @@ -43,15 +43,12 @@ constexpr const char *tile_tile_size = "tile.tile_size"; * 3. Structural matching (not annotation-driven semantics): * - Actual lowering only happens when we see: * - * for i (serial, tile_execution): - * for j (serial, tile_execution): + * for i (serial, annotation="tile.execution..."): + * for j (serial, annotation="tile.execution..."): * BODY * * - i, j are assumed to be the 2D tile-execution axes. - * - * 4. Exactly-once lowering: - * - The inner `for j` is consumed during lowering. - * - Lowering happens exactly once per T.Tiles(). + * 4. construct and insert two new ForNodes. */ class TilesLoopRewriter : public StmtExprMutator { public: @@ -104,7 +101,7 @@ class TilesLoopRewriter : public StmtExprMutator { * 1. Recursively visit body. * 2. Gate by tile_execution. * 3. Try to match 2D tile pattern. - * 4. Perform lowering if matched. + * 4. construct two new ForNodes if matched. */ Stmt VisitStmt_(const ForNode *loop) final { // ------------------------------------------------------------ @@ -154,7 +151,6 @@ class TilesLoopRewriter : public StmtExprMutator { // ------------------------------------------------------------ auto tile_size_opt = GetTileSize(loop); if (!tile_size_opt.defined()) { - LOG(INFO) << "[TilesLoop] Missing tile_tile_size, skip lowering"; return UpdateBody(loop, new_body); } @@ -164,7 +160,7 @@ class TilesLoopRewriter : public StmtExprMutator { LOG(INFO) << "[TilesLoop] Performing 2D tile lowering"; // ------------------------------------------------------------ - // (5) Perform tile lowering + // (5) Perform tile lowering, construct and insert new ForNodes // ------------------------------------------------------------ Var ti = loop->loop_var; Var tj = inner->loop_var; @@ -190,10 +186,10 @@ class TilesLoopRewriter : public StmtExprMutator { tiled_body = For(ki, 0, tile_size[0], ForKind::kSerial, tiled_body); - // Replace the original inner loop body,j loop is consumed + // Replace the original inner loop body For new_inner = ffi::GetRef(inner); new_inner.CopyOnWrite()->body = tiled_body; - // Replace the original outer loop body, i loop remains + // Replace the original outer loop body For new_outer = ffi::GetRef(loop); new_outer.CopyOnWrite()->body = new_inner; diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index fc42f1812..67b683505 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -151,6 +151,10 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: if should_force_let_inline(): # Force-let inline whenever the pass config requests it. mod = tilelang.transform.LetInline()(mod) + # read TileView metadata and attach them to Tiles loops. + mod = tilelang.transform.LegalizeTilesLoop()(mod) + # add two for loops for tile loops + mod = tilelang.transform.TilesLoop()(mod) # Add wrapper for single buf store mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) # Normalize negative indices to canonical non-negative form From 24e9658c2ca4edb2dfddcadb8a415d96f79751bc Mon Sep 17 00:00:00 2001 From: zongzhengWei Date: Tue, 10 Feb 2026 12:26:56 +0800 Subject: [PATCH 4/4] Added two test files for legalize_tiles_loop pass and tiles_loop pass --- ...lang_mesh_transform_legalize_tiles_loop.py | 203 +++++++++++++++ ...test_tilelang_mesh_transform_tiles_loop.py | 232 ++++++++++++++++++ 2 files changed, 435 insertions(+) create mode 100644 testing/python/transform/test_tilelang_mesh_transform_legalize_tiles_loop.py create mode 100644 testing/python/transform/test_tilelang_mesh_transform_tiles_loop.py diff --git a/testing/python/transform/test_tilelang_mesh_transform_legalize_tiles_loop.py b/testing/python/transform/test_tilelang_mesh_transform_legalize_tiles_loop.py new file mode 100644 index 000000000..d8788105e --- /dev/null +++ b/testing/python/transform/test_tilelang_mesh_transform_legalize_tiles_loop.py @@ -0,0 +1,203 @@ +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing +from tvm import tir +from tvm import IRModule +from tilelang.tileview import make_tileview + + +# --------------------------------------------------------- +# Helper: 2D tiled parallel kernel +# --------------------------------------------------------- +def dot_mul_tiled_parallel_2d( + M, + N, + block_M, + block_N, + dtype="float16", + accum_dtype="float16", +): + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128, + ) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) + C_shared = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Attach tileview metadata to buffers + T.annotate_tileview({ + A_shared: make_tileview(A_shared, (32, 32), (-2, -1)), + B_shared: make_tileview(B_shared, (32, 32), (-2, -1)), + C_shared: make_tileview(C_shared, (32, 32), (-2, -1)), + }) + + T.clear(C_shared) + + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(B[by * block_M, bx * block_N], B_shared) + + # Tile loop (target of LegalizeTilesLoop) + for i, j in T.Tiles(A_shared, parallel=True): + C_shared[i, j] = A_shared[i, j] * B_shared[i, j] + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +# --------------------------------------------------------- +# Helper: 3D tiled parallel kernel +# --------------------------------------------------------- +def dot_mul_tiled_parallel_3d( + B, + M, + N, + block_B, + block_M, + block_N, + dtype="float16", + accum_dtype="float16", +): + + @T.prim_func + def main( + A: T.Tensor((B, M, N), dtype), + B_: T.Tensor((B, M, N), dtype), + C: T.Tensor((B, M, N), dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + T.ceildiv(B, block_B), + threads=128, + ) as (bx, by, bz): + A_shared = T.alloc_shared((block_B, block_M, block_N), dtype) + B_shared = T.alloc_shared((block_B, block_M, block_N), dtype) + C_shared = T.alloc_fragment((block_B, block_M, block_N), accum_dtype) + + T.annotate_tileview({ + A_shared: make_tileview(A_shared, (32, 32), (-2, -1)), + B_shared: make_tileview(B_shared, (32, 32), (-2, -1)), + C_shared: make_tileview(C_shared, (32, 32), (-2, -1)), + }) + + T.clear(C_shared) + + T.copy( + A[bz * block_B:(bz + 1) * block_B, by * block_M:(by + 1) * block_M, + bx * block_N:(bx + 1) * block_N], + A_shared, + ) + T.copy( + B_[bz * block_B:(bz + 1) * block_B, by * block_M:(by + 1) * block_M, + bx * block_N:(bx + 1) * block_N], + B_shared, + ) + + for b, i, j in T.Tiles(A_shared, parallel=True): + C_shared[b, i, j] = A_shared[b, i, j] * B_shared[b, i, j] + + T.copy( + C_shared, + C[bz * block_B:(bz + 1) * block_B, by * block_M:(by + 1) * block_M, + bx * block_N:(bx + 1) * block_N], + ) + + return main + + +# --------------------------------------------------------- +# Core test: LegalizeTilesLoop (2D + 3D) +# --------------------------------------------------------- +def test_legalize_tiles_loop_attach_tileview_metadata(): + """ + Test that LegalizeTilesLoop correctly: + 1) reads tileview metadata from buffers + 2) attaches tile annotations to tile loops + 3) sets tile.execution only on inner tile loops + """ + + test_cases = [ + # (prim_func, expected_tile_execution_count) + ( + dot_mul_tiled_parallel_2d( + M=512, + N=1024, + block_M=256, + block_N=128, + ), + 2, # (i, j) + ), + ( + dot_mul_tiled_parallel_3d( + B=64, + M=512, + N=1024, + block_B=16, + block_M=256, + block_N=128, + ), + 2, # (b, i, j) -> execution only on i, j + ), + ] + + for prim_func, expected_exec_count in test_cases: + mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main")) + + # Apply the pass under test + mod = tl.transform.LegalizeTilesLoop()(mod) + + main_func = mod["main"] + + # ------------------------------------------------- + # Collect loop annotations + # ------------------------------------------------- + tile_execution_count = 0 + found_tile_metadata = { + "tile.tile_size": False, + "tile.dim_map": False, + "tile.buffer_new_shape": False, + "tile_level_loop": False, + } + + def visit( + stmt, + found_tile_metadata=found_tile_metadata, + ): + nonlocal tile_execution_count + if isinstance(stmt, tir.For): + ann = stmt.annotations + if ann is not None: + for k in found_tile_metadata: + if k in ann: + found_tile_metadata[k] = True + if "tile.execution" in ann: + nonlocal tile_execution_count + tile_execution_count += 1 + + tvm.tir.stmt_functor.post_order_visit(main_func.body, visit) + + # ------------------------------------------------- + # Assertions + # ------------------------------------------------- + for k, v in found_tile_metadata.items(): + assert v, f"Expected annotation '{k}' not found in tile loops" + + assert ( + tile_execution_count == expected_exec_count + ), f"Expected {expected_exec_count} tile.execution annotations, got {tile_execution_count}" + + +if __name__ == "__main__": + # tilelang.testing.main() + test_legalize_tiles_loop_attach_tileview_metadata() diff --git a/testing/python/transform/test_tilelang_mesh_transform_tiles_loop.py b/testing/python/transform/test_tilelang_mesh_transform_tiles_loop.py new file mode 100644 index 000000000..ba3aa1cd2 --- /dev/null +++ b/testing/python/transform/test_tilelang_mesh_transform_tiles_loop.py @@ -0,0 +1,232 @@ +import pytest +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +from tilelang.tileview import make_tileview +from tvm import tir +from tvm import IRModule + +# ========================================================= +# Helpers: build kernels +# ========================================================= + + +def dot_mul_tiled_parallel_2d( + M, + N, + block_M, + block_N, + tile_size, + index_map, + dtype="float16", + accum_dtype="float16", +): + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128, + ) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) + C_shared = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.annotate_tileview({ + A_shared: make_tileview(A_shared, tile_size, index_map), + B_shared: make_tileview(B_shared, tile_size, index_map), + C_shared: make_tileview(C_shared, tile_size, index_map), + }) + + T.clear(C_shared) + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(B[by * block_M, bx * block_N], B_shared) + + for i, j in T.Tiles(A_shared, parallel=True): + C_shared[i, j] = A_shared[i, j] * B_shared[i, j] + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def dot_mul_tiled_parallel_3d( + Batch, + M, + N, + block_B, + block_M, + block_N, + tile_size, + index_map, + dtype="float16", + accum_dtype="float16", +): + + @T.prim_func + def main( + A: T.Tensor((Batch, M, N), dtype), + B: T.Tensor((Batch, M, N), dtype), + C: T.Tensor((Batch, M, N), dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + T.ceildiv(Batch, block_B), + threads=128, + ) as (bx, by, bz): + A_shared = T.alloc_shared((block_B, block_M, block_N), dtype) + B_shared = T.alloc_shared((block_B, block_M, block_N), dtype) + C_shared = T.alloc_fragment((block_B, block_M, block_N), accum_dtype) + + T.annotate_tileview({ + A_shared: make_tileview(A_shared, tile_size, index_map), + B_shared: make_tileview(B_shared, tile_size, index_map), + C_shared: make_tileview(C_shared, tile_size, index_map), + }) + + T.clear(C_shared) + T.copy( + A[bz * block_B:(bz + 1) * block_B, by * block_M:(by + 1) * block_M, + bx * block_N:(bx + 1) * block_N], + A_shared, + ) + T.copy( + B[bz * block_B:(bz + 1) * block_B, by * block_M:(by + 1) * block_M, + bx * block_N:(bx + 1) * block_N], + B_shared, + ) + + for b, i, j in T.Tiles(A_shared, parallel=True): + C_shared[b, i, j] = A_shared[b, i, j] * B_shared[b, i, j] + + T.copy( + C_shared, + C[bz * block_B:(bz + 1) * block_B, by * block_M:(by + 1) * block_M, + bx * block_N:(bx + 1) * block_N], + ) + + return main + + +# ========================================================= +# Core test: TilesLoop +# ========================================================= + + +@pytest.mark.parametrize( + "prim_func_builder", + [ + # 2D + lambda: dot_mul_tiled_parallel_2d( + M=512, + N=1024, + block_M=256, + block_N=128, + tile_size=(32, 32), + index_map=(-2, -1), + ), + # 3D + lambda: dot_mul_tiled_parallel_3d( + Batch=64, + M=512, + N=1024, + block_B=16, + block_M=256, + block_N=128, + tile_size=(32, 32), + index_map=(-2, -1), + ), + ], +) +def test_tiles_loop_insert_and_index_rewrite(prim_func_builder): + """ + TilesLoop pass contract test. + + Verifies: + 1) tile.execution loops still represent tile counts + 2) serial(tile_size[0]) and vectorized(tile_size[1]) loops + are inserted inside tile.execution loop subtrees + 3) index expressions are rewritten as: + i * tile_size[0] + k + j * tile_size[1] + l + """ + + tile_size = (32, 32) + + mod = IRModule.from_expr(prim_func_builder().with_attr("global_symbol", "main")) + + # Required pipeline + mod = tl.transform.LegalizeTilesLoop()(mod) + mod = tl.transform.TilesLoop()(mod) + + main_func = mod["main"] + + # ----------------------------------------------------- + # 1. Collect tile.execution loops + # ----------------------------------------------------- + tile_exec_loops = [] + + def collect_tile_exec(stmt, tile_exec_loops=tile_exec_loops): + if isinstance(stmt, tir.For): + ann = stmt.annotations + if ann and ann.get("tile.execution", 0) == 1: + tile_exec_loops.append(stmt) + + tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_tile_exec) + + # Only i / j loops should be execution loops + assert len(tile_exec_loops) == 2 + + # ----------------------------------------------------- + # 2. Search each tile.execution subtree for k / l loops + # ----------------------------------------------------- + for exec_loop in tile_exec_loops: + found_serial = [] + found_vectorized = [] + + def visit_subtree( + stmt, + found_serial=found_serial, + found_vectorized=found_vectorized, + ): + if isinstance(stmt, tir.For): + if (stmt.kind == tir.ForKind.SERIAL and isinstance(stmt.extent, tir.IntImm) and + stmt.extent.value == tile_size[0]): + found_serial.append(stmt) + + if (stmt.kind == tir.ForKind.VECTORIZED and isinstance(stmt.extent, tir.IntImm) and + stmt.extent.value == tile_size[1]): + found_vectorized.append(stmt) + + tvm.tir.stmt_functor.post_order_visit(exec_loop.body, visit_subtree) + + assert found_serial, ("Expected serial(tile_size[0]) loop inside tile.execution subtree") + assert found_vectorized, ( + "Expected vectorized(tile_size[1]) loop inside tile.execution subtree") + + # ----------------------------------------------------- + # 3. Pattern check: index rewrite + # ----------------------------------------------------- + index_exprs = [] + + def collect_indices(stmt, index_exprs=index_exprs): + if isinstance(stmt, tir.BufferStore): + index_exprs.extend(stmt.indices) + + tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_indices) + + def contains_mul(expr, factor): + s = str(expr) + return f"* {factor}" in s or f"*{factor}" in s + + assert any(contains_mul(e, tile_size[0]) + for e in index_exprs), "Expected i * tile_size[0] in rewritten indices" + + assert any(contains_mul(e, tile_size[1]) + for e in index_exprs), "Expected j * tile_size[1] in rewritten indices"