From e64ba4c7fd53d8a54bb0fd6e90433bb49780ef50 Mon Sep 17 00:00:00 2001
From: Jiaqi Guo
Date: Wed, 4 Feb 2026 20:34:48 +0800
Subject: [PATCH 1/2] Add layouts for global buffers

---
 src/transform/common/global_layout_utils.h    |  47 +++
 src/transform/global_layout_utils.cc          | 117 +++++++
 src/transform/layout_inference.cc             |   5 +
 ..._tilelang_transform_global_layout_utils.py | 299 ++++++++++++++++++
 tilelang/language/v2/annot.py                 |  18 +-
 5 files changed, 485 insertions(+), 1 deletion(-)
 create mode 100644 src/transform/common/global_layout_utils.h
 create mode 100644 src/transform/global_layout_utils.cc
 create mode 100644 testing/python/transform/test_tilelang_transform_global_layout_utils.py

diff --git a/src/transform/common/global_layout_utils.h b/src/transform/common/global_layout_utils.h
new file mode 100644
index 000000000..bd80055b0
--- /dev/null
+++ b/src/transform/common/global_layout_utils.h
@@ -0,0 +1,47 @@
+/*!
+ * \file global_layout_utils.h
+ * \brief Utility functions to extract global buffer layouts from tensor_meta
+ *        attributes for the Sunmmio target.
+ */
+
+#ifndef TVM_TL_TRANSFORM_COMMON_GLOBAL_LAYOUT_UTILS_H_
+#define TVM_TL_TRANSFORM_COMMON_GLOBAL_LAYOUT_UTILS_H_
+
+#include <tvm/target/target.h>
+#include <tvm/tir/function.h>
+
+#include "../../layout/layout.h"
+#include "../../target/utils.h"
+
+namespace tvm {
+namespace tl {
+
+using LayoutMap = Map<tir::Buffer, Layout>;
+
+/*!
+ * \brief Populate layout_map with global buffer layouts from the tensor_meta
+ *        attribute. Only applies when the target is Sunmmio.
+ *
+ * \param f The PrimFunc carrying the tensor_meta attribute
+ * \param target The compilation target
+ * \param layout_map The layout map to update (in place)
+ * \return true if any layouts were added, false otherwise
+ */
+bool PopulateGlobalBufferLayouts(const tir::PrimFunc &f, Target target,
+                                 LayoutMap *layout_map);
+
+/*!
+ * \brief Parse a single buffer's hierarchical layout from a tensor_meta entry.
+ *
+ * \param meta_entry The metadata dict for one buffer
+ * \param buffer The buffer to create the layout for
+ * \return Layout object, or nullopt if parsing fails
+ */
+Optional<Layout>
+ParseGlobalBufferLayout(const Map<String, ObjectRef> &meta_entry,
+                        const tir::Buffer &buffer);
+
+} // namespace tl
+} // namespace tvm
+
+#endif // TVM_TL_TRANSFORM_COMMON_GLOBAL_LAYOUT_UTILS_H_
diff --git a/src/transform/global_layout_utils.cc b/src/transform/global_layout_utils.cc
new file mode 100644
index 000000000..69d719ca6
--- /dev/null
+++ b/src/transform/global_layout_utils.cc
@@ -0,0 +1,117 @@
+/*!
+ * \file global_layout_utils.cc
+ * \brief Implementation of utility functions to extract global buffer layouts
+ *        from tensor_meta attributes for the Sunmmio target.
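+ *
+ * For illustration, a tensor_meta entry as produced by the Python side
+ * (see get_metadata() in tilelang/language/v2/annot.py) is expected to
+ * look roughly like this; the concrete values are illustrative only:
+ *
+ * \code{.py}
+ *   tensor_meta["A"] = {
+ *       "sharded_hdims":    (32, 32),         # per-shard hierarchical dims
+ *       "sharded_hstrides": (32, 1),          # stride of each hierarchical dim
+ *       "sharded_hgroups":  ((0, 1), (1, 2)), # hierarchical dims per logical axis
+ *   }
+ * \endcode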
+ */
+
+#include "common/global_layout_utils.h"
+
+#include <tvm/tir/function.h>
+
+#include "../layout/layout.h"
+#include "../target/utils.h"
+
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+
+Optional<Layout>
+ParseGlobalBufferLayout(const Map<String, ObjectRef> &meta_entry,
+                        const Buffer &buffer) {
+  // Extract sharded layout info
+  auto hdims_obj = meta_entry.Get("sharded_hdims");
+  auto hstrides_obj = meta_entry.Get("sharded_hstrides");
+  auto hgroups_obj = meta_entry.Get("sharded_hgroups");
+
+  if (!hdims_obj || !hstrides_obj || !hgroups_obj) {
+    return Optional<Layout>();
+  }
+
+  // Convert to arrays for makeHierarchicalLayout
+  Array<Integer> hdims_arr, hstrides_arr, logical_shape_arr;
+  Array<Array<Integer>> groups_arr;
+
+  // Parse hdims - it's an Array<Integer> from Python
+  Array<Integer> hdims = Downcast<Array<Integer>>(hdims_obj.value());
+  for (const auto &dim : hdims) {
+    hdims_arr.push_back(dim);
+  }
+
+  // Parse hstrides
+  Array<Integer> hstrides = Downcast<Array<Integer>>(hstrides_obj.value());
+  for (const auto &stride : hstrides) {
+    hstrides_arr.push_back(stride);
+  }
+
+  // Parse hgroups - Array<Array<Integer>>
+  Array<Array<Integer>> hgroups =
+      Downcast<Array<Array<Integer>>>(hgroups_obj.value());
+  for (const auto &group : hgroups) {
+    groups_arr.push_back(group);
+  }
+
+  // Use buffer shape as logical shape
+  for (size_t i = 0; i < buffer->shape.size(); ++i) {
+    if (auto *imm = buffer->shape[i].as<IntImmNode>()) {
+      logical_shape_arr.push_back(Integer(imm->value));
+    } else {
+      return Optional<Layout>(); // Dynamic shape not supported
+    }
+  }
+
+  // Verify that groups_arr matches logical shape dimensions
+  if (groups_arr.size() != logical_shape_arr.size()) {
+    return Optional<Layout>();
+  }
+
+  return makeHierarchicalLayout(hdims_arr, hstrides_arr, groups_arr,
+                                logical_shape_arr);
+}
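+
+// Worked example (illustrative, mirroring the Python tests): with
+// hdims = (32, 32), hstrides = (32, 1), hgroups = ((0, 1), (1, 2)) and a
+// 32x32 logical shape, the resulting layout is plain row-major, i.e.
+// (i, j) -> i * 32 + j.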
+
+bool PopulateGlobalBufferLayouts(const PrimFunc &f, Target target,
+                                 LayoutMap *layout_map) {
+  if (!TargetIsSunmmio(target)) {
+    return false;
+  }
+
+  auto tensor_meta_opt = f->GetAttr<Map<String, ObjectRef>>("tensor_meta");
+  if (!tensor_meta_opt) {
+    return false;
+  }
+
+  auto tensor_meta = tensor_meta_opt.value();
+  bool any_added = false;
+
+  for (const auto &kv : f->buffer_map) {
+    const Var &var = kv.first;
+    const Buffer &buffer = kv.second;
+
+    if (buffer.scope() != "global") {
+      continue;
+    }
+
+    String buffer_name = buffer->name;
+    if (!tensor_meta.count(buffer_name)) {
+      continue;
+    }
+
+    auto meta_entry_obj = tensor_meta[buffer_name];
+    auto meta_entry = meta_entry_obj.as<Map<String, ObjectRef>>();
+    if (!meta_entry.has_value()) {
+      continue;
+    }
+
+    auto layout_opt = ParseGlobalBufferLayout(meta_entry.value(), buffer);
+
+    if (layout_opt) {
+      layout_map->Set(buffer, layout_opt.value());
+      any_added = true;
+    }
+  }
+
+  return any_added;
+}
+
+} // namespace tl
+} // namespace tvm
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index e505bc6ea..ae89f01c5 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -23,6 +23,7 @@
 
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
+#include "common/global_layout_utils.h"
 #include "common/loop_fusion_utils.h"
 #include "common/loop_parallel_transform_utils.h"
 #include "common/union_find.h"
@@ -416,6 +417,10 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
     ICHECK(target.defined())
         << "Layout_Inference: Require the target attribute";
     target_ = target.value();
+
+    // Populate global buffer layouts from tensor_meta (Sunmmio only)
+    PopulateGlobalBufferLayouts(f, target_, &annotated_layout_map_);
+
     this->operator()(f->body);
   }
diff --git a/testing/python/transform/test_tilelang_transform_global_layout_utils.py b/testing/python/transform/test_tilelang_transform_global_layout_utils.py
new file mode 100644
index 000000000..280425e28
--- /dev/null
+++ b/testing/python/transform/test_tilelang_transform_global_layout_utils.py
@@ -0,0 +1,299 @@
+"""
+Test global buffer layout extraction from tensor_meta for the Sunmmio target.
+
+This tests the C++ implementation in:
+- src/transform/global_layout_utils.cc
+- Integration in src/transform/layout_inference.cc
+"""
+
+from tilelang import tvm as tvm
+from tilelang.utils.target import determine_target, SUNMMIO_TARGET_DESC, target_is_sunmmio
+import tilelang as tl
+import tilelang.language as T
+from tilelang.language.v2.annot import MeshShardingPolicy, MeshReplicationType
+from tvm import tir
+from tvm.tir import PyStmtExprVisitor
+from tvm.tir.transform import prim_func_pass
+from tvm.target import Target
+
+# Global dict to collect layout_map from block annotations
+collected_layout_map = {}
+
+
+@tir.functor.visitor
+class _LayoutMapCollector(PyStmtExprVisitor):
+    """Visitor to extract layout_map from block annotations after LayoutInference."""
+
+    def __init__(self):
+        super().__init__()
+
+    def visit_block_(self, op: tir.Block) -> None:
+        if "layout_map" in op.annotations:
+            layout_map = op.annotations["layout_map"]
+            collected_layout_map.clear()
+            for key, layout in layout_map.items():
+                # key is a Buffer, use its name as dict key
+                collected_layout_map[key.name] = layout
+
+
+def CollectLayoutMap():
+    """TIR pass to collect layout_map from block annotations."""
+
+    def pass_fn(func: tir.PrimFunc, mod, ctx):
+        _LayoutMapCollector().visit_stmt(func.body)
+        return func
+
+    return prim_func_pass(pass_fn, opt_level=0)
+
+
+def test_sunmmio_target_detection():
+    """Verify that Sunmmio target detection works correctly."""
+    target = Target(SUNMMIO_TARGET_DESC)
+    assert target_is_sunmmio(target), "Should detect Sunmmio target"
+
+    cuda_target = Target("cuda")
+    assert not target_is_sunmmio(cuda_target), "CUDA should not be detected as Sunmmio"
+
+
+def test_global_buffer_layout_populated_for_sunmmio():
+    """
+    Test that global buffer layouts from tensor_meta are populated into
+    layout_map during the LayoutInference pass for the Sunmmio target.
+
+    Uses a GEMM-style kernel that triggers proper layout inference for
+    shared buffers.
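+
+    After the pass, the outermost block is expected to carry a "layout_map"
+    annotation mapping Buffer objects to Layouts; what this test then
+    observes is roughly (illustrative):
+
+        collected_layout_map == {"A": <hierarchical Layout>, ...}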
+    """
+    policy = MeshShardingPolicy(y=0, x=1, replicate=MeshReplicationType.NONE)
+    device_mesh = (2, 2)
+
+    # Use shapes that match the test requirements
+    M, N, K = 64, 64, 64
+    block_M, block_N, block_K = 32, 32, 32
+
+    # Simple row-major hierarchical layout
+    A_hdims = (64, 64)
+    A_hgroups = ((0, 1), (1, 2))
+    A_hstrides = (64, 1)
+
+    A_tensor = T.MeshTensor(
+        (M, K),
+        policy,
+        device_mesh,
+        dtype="float16",
+        hierarchical_dims=A_hdims,
+        hierarchical_groups=A_hgroups,
+        hierarchical_strides=A_hstrides,
+    )
+
+    B_tensor = T.MeshTensor(
+        (K, N),
+        policy,
+        device_mesh,
+        dtype="float16",
+        hierarchical_dims=(64, 64),
+        hierarchical_groups=((0, 1), (1, 2)),
+        hierarchical_strides=(64, 1),
+    )
+
+    C_tensor = T.MeshTensor(
+        (M, N),
+        policy,
+        device_mesh,
+        dtype="float32",
+        hierarchical_dims=(64, 64),
+        hierarchical_groups=((0, 1), (1, 2)),
+        hierarchical_strides=(64, 1),
+    )
+
+    @T.prim_func
+    def kernel(
+            A: A_tensor,
+            B: B_tensor,
+            C: C_tensor,
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), "float16")
+            B_shared = T.alloc_shared((block_K, block_N), "float16")
+            C_shared = T.alloc_shared((block_M, block_N), "float32")
+
+            T.clear(C_shared)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_shared)
+
+            T.copy(C_shared, C[by * block_M, bx * block_N])
+
+    # Verify tensor_meta exists before compilation
+    assert "tensor_meta" in kernel.attrs, "Kernel should have tensor_meta attribute"
+
+    # Get Sunmmio target
+    target = determine_target("Sunmmio", return_object=True)
+
+    # Create IR module and run passes
+    mod = tvm.IRModule({'main': kernel})
+
+    with tvm.target.Target(target):
+        mod = tvm.tir.transform.BindTarget(target)(mod)
+        mod = tl.transform.LayoutInference()(mod)
+        CollectLayoutMap()(mod)
+
+    # Verify that global buffer 'A' has a layout in the layout_map
+    assert "A" in collected_layout_map, \
+        f"Global buffer 'A' should be in layout_map after LayoutInference. Got: {list(collected_layout_map.keys())}"
+
+    # Verify the layout is a Layout object (hierarchical layout)
+    a_layout = collected_layout_map["A"]
+    assert a_layout is not None, "Layout for 'A' should not be None"
+
+    # The layout should have 2 input dimensions matching the buffer shape
+    assert len(a_layout.input_size) == 2, f"Expected 2 input dims, got {len(a_layout.input_size)}"
+
+
+def test_global_buffer_layout_not_populated_for_cuda():
+    """
+    Test that global buffer layouts are NOT populated for non-Sunmmio targets (CUDA).
+    The PopulateGlobalBufferLayouts function should return early for CUDA.
+
+    Uses a simple copy kernel to avoid GEMM instruction selection issues.
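+
+    Note: PopulateGlobalBufferLayouts checks TargetIsSunmmio(target) first and
+    returns false otherwise, so tensor_meta is never consulted on this path.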
+    """
+    M, N = 64, 64
+    block_M, block_N = 32, 32
+
+    @T.prim_func
+    def kernel(
+            A: T.Tensor((M, N), "float16"),
+            B: T.Tensor((M, N), "float16"),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_N), "float16")
+
+            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(A_shared, B[by * block_M, bx * block_N])
+
+    # Use CUDA target
+    target = Target("cuda")
+
+    mod = tvm.IRModule({'main': kernel})
+
+    with tvm.target.Target(target):
+        mod = tvm.tir.transform.BindTarget(target)(mod)
+        mod = tl.transform.LayoutInference()(mod)
+        CollectLayoutMap()(mod)
+
+    # For CUDA without MeshTensor, global buffer 'A' should NOT be in layout_map
+    # (only fragment/shared buffers get layouts inferred).
+    # This verifies that our code path for Sunmmio is not triggered for CUDA.
+    assert "A" not in collected_layout_map, \
+        "Global buffer 'A' should NOT be in layout_map for CUDA target"
+
+
+def test_hierarchical_layout_values():
+    """
+    Test that the hierarchical layout created from tensor_meta produces
+    the correct forward index mapping.
+    """
+    policy = MeshShardingPolicy(y=0, x=1, replicate=MeshReplicationType.NONE)
+    device_mesh = (2, 2)
+
+    M, N, K = 64, 64, 64
+    block_M, block_N, block_K = 32, 32, 32
+
+    # Simple row-major style hierarchical layout
+    # After sharding by 2x2: sharded_shape = (32, 32)
+    # sharded_hdims = (32, 32), sharded_hstrides = (32, 1)
+    hdims = (64, 64)
+    hgroups = ((0, 1), (1, 2))
+    hstrides = (64, 1)  # row-major
+
+    A_tensor = T.MeshTensor(
+        (M, K),
+        policy,
+        device_mesh,
+        dtype="float16",
+        hierarchical_dims=hdims,
+        hierarchical_groups=hgroups,
+        hierarchical_strides=hstrides,
+    )
+
+    B_tensor = T.MeshTensor(
+        (K, N),
+        policy,
+        device_mesh,
+        dtype="float16",
+        hierarchical_dims=(64, 64),
+        hierarchical_groups=((0, 1), (1, 2)),
+        hierarchical_strides=(64, 1),
+    )
+
+    C_tensor = T.MeshTensor(
+        (M, N),
+        policy,
+        device_mesh,
+        dtype="float32",
+        hierarchical_dims=(64, 64),
+        hierarchical_groups=((0, 1), (1, 2)),
+        hierarchical_strides=(64, 1),
+    )
+
+    @T.prim_func
+    def kernel(
+            A: A_tensor,
+            B: B_tensor,
+            C: C_tensor,
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), "float16")
+            B_shared = T.alloc_shared((block_K, block_N), "float16")
+            C_shared = T.alloc_shared((block_M, block_N), "float32")
+
+            T.clear(C_shared)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_shared)
+
+            T.copy(C_shared, C[by * block_M, bx * block_N])
+
+    target = determine_target("Sunmmio", return_object=True)
+    mod = tvm.IRModule({'main': kernel})
+
+    with tvm.target.Target(target):
+        mod = tvm.tir.transform.BindTarget(target)(mod)
+        mod = tl.transform.LayoutInference()(mod)
+        CollectLayoutMap()(mod)
+
+    assert "A" in collected_layout_map, "Global buffer 'A' should be in layout_map"
+
+    a_layout = collected_layout_map["A"]
+
+    # Verify the layout computes correct physical offsets.
+    # For row-major with sharded strides (32, 1):
+    #   offset(i, j) = i * 32 + j * 1
+
+    # Test a few index mappings
+    offset_0_0 = a_layout.map_forward_index([0, 0])
+    offset_0_1 = a_layout.map_forward_index([0, 1])
+    offset_1_0 = a_layout.map_forward_index([1, 0])
+
+    # offset(0,0) = 0
+    # offset(0,1) = 1
+    # offset(1,0) = 32
+    assert offset_0_0[0] == 0, f"Expected offset(0,0)=0, got {offset_0_0[0]}"
+    assert offset_0_1[0] == 1, f"Expected offset(0,1)=1, got {offset_0_1[0]}"
+    assert offset_1_0[0] == 32, f"Expected offset(1,0)=32, got {offset_1_0[0]}"
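+
+    # More generally (illustrative): with sharded_hstrides (32, 1), the
+    # per-shard mapping is plain row-major, i.e. offset(i, j) == 32 * i + j.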
+
+
+if __name__ == "__main__":
+    test_sunmmio_target_detection()
+    print("PASSED: test_sunmmio_target_detection")
+
+    test_global_buffer_layout_populated_for_sunmmio()
+    print("PASSED: test_global_buffer_layout_populated_for_sunmmio")
+
+    test_global_buffer_layout_not_populated_for_cuda()
+    print("PASSED: test_global_buffer_layout_not_populated_for_cuda")
+
+    test_hierarchical_layout_values()
+    print("PASSED: test_hierarchical_layout_values")
+
+    print("\nAll tests passed!")
diff --git a/tilelang/language/v2/annot.py b/tilelang/language/v2/annot.py
index f466fc7f6..b53458425 100644
--- a/tilelang/language/v2/annot.py
+++ b/tilelang/language/v2/annot.py
@@ -938,10 +938,26 @@ def from_sig_annots(cls, sig: inspect.Signature, func_annots: dict[str, Any]) ->
         return FuncAnnot(sig, arg_names, annots, arg_parser, ker_arg_names)
 
     def get_metadata(self):
+        """Get the metadata dict with values converted to TIR types for C++ FFI compatibility."""
+
+        def convert_to_tir(value):
+            """Convert Python values to TIR types recursively.
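+
+            Example (illustrative):
+                convert_to_tir({"sharded_hdims": (64, 64)})
+                # -> {"sharded_hdims": (IntImm("int32", 64),
+                #                       IntImm("int32", 64))}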
+            """
+            if isinstance(value, int):
+                # Use int32 to match the buffer shape dtype used in makeHierarchicalLayout
+                return tir.IntImm("int32", value)
+            elif isinstance(value, tuple):
+                return tuple(convert_to_tir(v) for v in value)
+            elif isinstance(value, list):
+                return [convert_to_tir(v) for v in value]
+            elif isinstance(value, dict):
+                return {k: convert_to_tir(v) for k, v in value.items()}
+            return value
+
         meta = {}
         for name, annot in self.annots.items():
             if isinstance(annot, TensorWithMetaAnnot):
-                meta[name] = annot.data.meta_data
+                # Convert the metadata to TIR types
+                meta[name] = convert_to_tir(annot.data.meta_data)
         return meta
 
     def parse_key(self, *args, **kws):

From c2b04d566ed44204a167e2b9b95bc0eee0abe83a Mon Sep 17 00:00:00 2001
From: Jiaqi Guo
Date: Tue, 10 Feb 2026 13:05:24 +0800
Subject: [PATCH 2/2] Define global_layout_map

---
 src/layout/layout.h                           |  4 +++
 src/op/copy.cc                                |  3 ++-
 src/op/operator.h                             |  2 ++
 src/transform/layout_inference.cc             | 25 +++++++++++++------
 src/transform/lower_tile_op.cc                | 16 +++++++++---
 ..._tilelang_transform_global_layout_utils.py |  4 +++
 6 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/src/layout/layout.h b/src/layout/layout.h
index 23c3f7445..c7680da94 100644
--- a/src/layout/layout.h
+++ b/src/layout/layout.h
@@ -249,6 +249,10 @@ Layout makeHierarchicalLayout(Array<Integer> hdims_arr,
 namespace attr {
 // BlockAttr, Containing the layout for all the buffers in the block
 constexpr const char *kLayoutMap = "layout_map";
+// BlockAttr, Containing the layout for global (DRAM) buffers.
+// Separated from kLayoutMap so that inference/lowering passes do not
+// accidentally overwrite or transform these read-only metadata layouts.
+constexpr const char *kGlobalLayoutMap = "global_layout_map";
 } // namespace attr
 
 } // namespace tl
diff --git a/src/op/copy.cc b/src/op/copy.cc
index 72e73e162..22cb4a074 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -1407,7 +1407,8 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
   Array<Range> shared_range = is_load ? dst_range : src_range;
   // TMA bulk copy cannot support a non-swizzled global layout; it will fall
   // back to a normal copy.
-  if (T.layout_map.count(global_tensor)) {
+  if (T.layout_map.count(global_tensor) ||
+      T.global_layout_map.count(global_tensor)) {
     LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global "
                     "layout, fallback to normal copy.";
     return LowerNormalCopy(T, analyzer);
diff --git a/src/op/operator.h b/src/op/operator.h
index 1453f9c1e..3490eede6 100644
--- a/src/op/operator.h
+++ b/src/op/operator.h
@@ -39,6 +39,7 @@ struct LowerArgs {
   AddWorkspaceCallback AddWorkspace;
   LayoutMap layout_map;
   Map<Buffer, Buffer> buffer_remap;
+  LayoutMap global_layout_map;
 };
 
 struct LayoutInferArgs {
@@ -48,6 +49,7 @@ struct LayoutInferArgs {
   arith::Analyzer *analyzer;
   bool buffer_oob = false;
   Map<Buffer, Buffer> buffer_remap;
+  LayoutMap global_layout_map;
 };
 
 class TileOperator;
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index ae89f01c5..b67aad852 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -63,6 +63,7 @@ struct LayoutInferenceResult {
   Map<Buffer, Layout> layout_map;
   Map<For, Fragment> for_map;
   Map<For, PrimExpr> predicate_map;
+  Map<Buffer, Layout> global_layout_map;
 };
 
 class BufferUseDefCollector : public IRVisitorWithAnalyzer {
@@ -110,10 +111,14 @@
            "required for layout inference.";
 
     // Run InferLayout
-    auto updates =
-        next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
-                                          cur_analyzer, buffer_oob},
-                          level);
+    auto updates = next->InferLayout(LayoutInferArgs{target_,
+                                                     thread_bounds,
+                                                     layout_map,
+                                                     cur_analyzer,
+                                                     buffer_oob,
+                                                     {},
+                                                     global_layout_map_},
+                                     level);
 
     // Process the returned updates
     for (const auto &[buffer, layout] : updates) {
@@ -400,7 +405,7 @@
       }
     }
 
-    return {layout_map, for_map, predicate_map};
+    return {layout_map, for_map, predicate_map, global_layout_map_};
   }
 
   void Collect(const PrimFunc &f) {
@@ -418,8 +423,9 @@
         << "Layout_Inference: Require the target attribute";
     target_ = target.value();
 
-    // Populate global buffer layouts from tensor_meta (Sunmmio only)
-    PopulateGlobalBufferLayouts(f, target_, &annotated_layout_map_);
+    // Populate global buffer layouts from tensor_meta (Sunmmio only).
+    // Stored separately so that inference does not overwrite them.
+    PopulateGlobalBufferLayouts(f, target_, &global_layout_map_);
 
     this->operator()(f->body);
   }
@@ -835,6 +841,7 @@
   std::vector<bool> buffer_oob_vec_;
   Target target_;
   LayoutMap annotated_layout_map_;
+  LayoutMap global_layout_map_;
   bool skip_thread_partition_{false};
 
   std::vector BackupInferList() {
@@ -1050,6 +1057,10 @@
     }
     auto block_ptr = block.CopyOnWrite();
     block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
+    if (!result_.global_layout_map.empty()) {
+      block_ptr->annotations.Set(attr::kGlobalLayoutMap,
+                                 result_.global_layout_map);
+    }
     return block;
   }
diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc
index e66fb106b..4089c600a 100644
--- a/src/transform/lower_tile_op.cc
+++ b/src/transform/lower_tile_op.cc
@@ -162,6 +162,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
         layout_map_.Set(buffer, layout);
       }
     }
+    // Read the global layout map separately; these entries are read-only
+    // metadata and must NOT be processed through makeBufferWithLayout/Forward.
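+    // (kGlobalLayoutMap is attached by LayoutInference; for Sunmmio it maps
+    // each annotated global buffer to the hierarchical layout parsed from
+    // tensor_meta.)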
+    if (op->annotations.count(attr::kGlobalLayoutMap)) {
+      global_layout_map_ = op->annotations.at(attr::kGlobalLayoutMap)
+                               .as<Map<Buffer, Layout>>()
+                               .value();
+    }
     // Begin a new workspace collection frame for this block scope
     workspace_stack_.emplace_back();
@@ -552,10 +559,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
       thread_bounds = Range::FromMinExtent(0, 1);
     }
 
-    auto lowered =
-        tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
-                                 callback, layout_map_, buffer_remap_},
-                       analyzer_);
+    auto lowered = tile_op->Lower(
+        LowerArgs{target_, thread_bounds, thread_var_->var, callback,
+                  layout_map_, buffer_remap_, global_layout_map_},
+        analyzer_);
 
     return IRMutatorWithAnalyzer::VisitStmt(lowered);
   }
@@ -577,6 +584,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
   Map<Buffer, Layout> layout_map_;
   Map<Buffer, Layout> layout_remap_;
   Map<Buffer, Buffer> buffer_remap_;
+  Map<Buffer, Layout> global_layout_map_;
   // This is a workaround for the CPU backend:
   // we need to define a thread_var for the serial loop.
   IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
diff --git a/testing/python/transform/test_tilelang_transform_global_layout_utils.py b/testing/python/transform/test_tilelang_transform_global_layout_utils.py
index 280425e28..923d6cabd 100644
--- a/testing/python/transform/test_tilelang_transform_global_layout_utils.py
+++ b/testing/python/transform/test_tilelang_transform_global_layout_utils.py
@@ -34,6 +34,10 @@ def visit_block_(self, op: tir.Block) -> None:
             for key, layout in layout_map.items():
                 # key is a Buffer, use its name as dict key
                 collected_layout_map[key.name] = layout
+        if "global_layout_map" in op.annotations:
+            global_layout_map = op.annotations["global_layout_map"]
+            for key, layout in global_layout_map.items():
+                collected_layout_map[key.name] = layout
 
 
 def CollectLayoutMap():