4 changes: 4 additions & 0 deletions src/layout/layout.h
@@ -249,6 +249,10 @@ Layout makeHierarchicalLayout(Array<Integer> hdims_arr,
namespace attr {
// BlockAttr, Containing the layout for all the buffers in the block
constexpr const char *kLayoutMap = "layout_map";
// BlockAttr, Containing the layout for global (DRAM) buffers.
// Separated from kLayoutMap so that inference/lowering passes do not
// accidentally overwrite or transform these read-only metadata layouts.
constexpr const char *kGlobalLayoutMap = "global_layout_map";
} // namespace attr

} // namespace tl
3 changes: 2 additions & 1 deletion src/op/copy.cc
@@ -1407,7 +1407,8 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
Array<Range> shared_range = is_load ? dst_range : src_range;
// TMA bulk copy cannot support a non-swizzled global layout, will be fallback
// to normal copy
if (T.layout_map.count(global_tensor)) {
if (T.layout_map.count(global_tensor) ||
T.global_layout_map.count(global_tensor)) {
LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global "
"layout, fallback to normal copy.";
return LowerNormalCopy(T, analyzer);
2 changes: 2 additions & 0 deletions src/op/operator.h
@@ -39,6 +39,7 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
LayoutMap global_layout_map;
};

struct LayoutInferArgs {
@@ -48,6 +49,7 @@
arith::Analyzer *analyzer;
bool buffer_oob = false;
Map<Buffer, Buffer> buffer_remap;
LayoutMap global_layout_map;
};

class TileOperator;
47 changes: 47 additions & 0 deletions src/transform/common/global_layout_utils.h
@@ -0,0 +1,47 @@
/*!
* \file global_layout_utils.h
* \brief Utility functions to extract global buffer layouts from tensor_meta
* attributes for the Sunmmio target.
*/

#ifndef TVM_TL_TRANSFORM_COMMON_GLOBAL_LAYOUT_UTILS_H_
#define TVM_TL_TRANSFORM_COMMON_GLOBAL_LAYOUT_UTILS_H_

#include <tvm/target/target.h>
#include <tvm/tir/function.h>

#include "../../layout/layout.h"
#include "../../target/utils.h"

namespace tvm {
namespace tl {

using LayoutMap = Map<tir::Buffer, Layout>;

/*!
* \brief Populate layout_map with global buffer layouts from tensor_meta
* attribute. Only applies when target is Sunmmio.
*
* \param f The PrimFunc containing tensor_meta attribute
* \param target The compilation target
* \param layout_map The layout map to update (in-place)
* \return true if any layouts were added, false otherwise
*/
bool PopulateGlobalBufferLayouts(const tir::PrimFunc &f, Target target,
LayoutMap *layout_map);

/*!
* \brief Parse a single buffer's hierarchical layout from tensor_meta entry
*
* \param meta_entry The metadata dict for one buffer
* \param buffer The buffer to create layout for
* \return Layout object, or nullopt if parsing fails
*/
Optional<Layout>
ParseGlobalBufferLayout(const Map<String, ObjectRef> &meta_entry,
const tir::Buffer &buffer);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_TRANSFORM_COMMON_GLOBAL_LAYOUT_UTILS_H_
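
For orientation, a minimal usage sketch of the two helpers declared above. This is an illustration only, not part of the diff: the wrapper function name is hypothetical, and `f`/`target` are assumed to come from the surrounding pass context (layout_inference.cc below calls PopulateGlobalBufferLayouts the same way).

#include "common/global_layout_utils.h"

// Sketch: gather read-only DRAM layouts before running layout inference.
static tvm::tl::LayoutMap ExampleCollectGlobalLayouts(const tvm::tir::PrimFunc &f,
                                                      const tvm::Target &target) {
  tvm::tl::LayoutMap global_layouts;
  // Leaves the map untouched unless the target is Sunmmio and `f` carries a
  // parseable "tensor_meta" attribute.
  tvm::tl::PopulateGlobalBufferLayouts(f, target, &global_layouts);
  return global_layouts;
}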
117 changes: 117 additions & 0 deletions src/transform/global_layout_utils.cc
@@ -0,0 +1,117 @@
/*!
* \file global_layout_utils.cc
* \brief Implementation of utility functions to extract global buffer layouts
* from tensor_meta attributes for the Sunmmio target.
*/

#include "common/global_layout_utils.h"

#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

Optional<Layout>
ParseGlobalBufferLayout(const Map<String, ObjectRef> &meta_entry,
const Buffer &buffer) {
// Extract sharded layout info
auto hdims_obj = meta_entry.Get("sharded_hdims");
auto hstrides_obj = meta_entry.Get("sharded_hstrides");
auto hgroups_obj = meta_entry.Get("sharded_hgroups");

if (!hdims_obj || !hstrides_obj || !hgroups_obj) {
return Optional<Layout>();
}

// Convert to arrays for makeHierarchicalLayout
Array<Integer> hdims_arr, hstrides_arr, logical_shape_arr;
Array<Array<Integer>> groups_arr;

// Parse hdims - it's an Array<Integer> from Python
Array<Integer> hdims = Downcast<Array<Integer>>(hdims_obj.value());
for (const auto &dim : hdims) {
hdims_arr.push_back(dim);
}

// Parse hstrides
Array<Integer> hstrides = Downcast<Array<Integer>>(hstrides_obj.value());
for (const auto &stride : hstrides) {
hstrides_arr.push_back(stride);
}

// Parse hgroups - Array<Array<Integer>>
Array<Array<Integer>> hgroups =
Downcast<Array<Array<Integer>>>(hgroups_obj.value());
for (const auto &group : hgroups) {
groups_arr.push_back(group);
}

// Use buffer shape as logical shape
for (size_t i = 0; i < buffer->shape.size(); ++i) {
if (auto *imm = buffer->shape[i].as<IntImmNode>()) {
logical_shape_arr.push_back(Integer(imm->value));
} else {
return Optional<Layout>(); // Dynamic shape not supported
}
}

// Verify that groups_arr matches logical shape dimensions
if (groups_arr.size() != logical_shape_arr.size()) {
return Optional<Layout>();
}

return makeHierarchicalLayout(hdims_arr, hstrides_arr, groups_arr,
logical_shape_arr);
}

bool PopulateGlobalBufferLayouts(const PrimFunc &f, Target target,
LayoutMap *layout_map) {
if (!TargetIsSunmmio(target)) {
return false;
}

auto tensor_meta_opt = f->GetAttr<Map<String, ObjectRef>>("tensor_meta");
if (!tensor_meta_opt) {
return false;
}

auto tensor_meta = tensor_meta_opt.value();
bool any_added = false;

for (const auto &kv : f->buffer_map) {
const Var &var = kv.first;
const Buffer &buffer = kv.second;

if (buffer.scope() != "global") {
continue;
}

String buffer_name = buffer->name;
if (!tensor_meta.count(buffer_name)) {
continue;
}

auto meta_entry_obj = tensor_meta[buffer_name];
auto meta_entry = meta_entry_obj.as<Map<String, ObjectRef>>();
if (!meta_entry.has_value()) {
continue;
}

auto layout_opt = ParseGlobalBufferLayout(meta_entry.value(), buffer);

if (layout_opt) {
layout_map->Set(buffer, layout_opt.value());
any_added = true;
}
}

return any_added;
}

} // namespace tl
} // namespace tvm
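
For reference, an illustration-only sketch (not part of this diff) of the per-buffer tensor_meta entry that ParseGlobalBufferLayout consumes; the numeric values are invented for a hypothetical 2-D buffer, and the snippet is written as if inside namespace tvm::tl so Array, Map, and Integer resolve as in the file above.

// Hypothetical metadata for one global buffer; the keys are the ones the
// parser looks up, the values are example data only.
Map<String, ObjectRef> meta_entry;
// Hierarchical dims and strides, each parsed as Array<Integer>.
meta_entry.Set("sharded_hdims", Array<Integer>{Integer(8), Integer(128)});
meta_entry.Set("sharded_hstrides", Array<Integer>{Integer(128), Integer(1)});
// One group of hierarchical axes per logical buffer dimension.
meta_entry.Set("sharded_hgroups",
               Array<Array<Integer>>{Array<Integer>{Integer(0)},
                                     Array<Integer>{Integer(1)}});
// ParseGlobalBufferLayout(meta_entry, buffer) forwards these arrays, together
// with the buffer's static shape, to makeHierarchicalLayout; it returns
// nullopt when a key is missing, the shape is dynamic, or the number of
// groups does not match the buffer rank.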
26 changes: 21 additions & 5 deletions src/transform/layout_inference.cc
@@ -23,6 +23,7 @@

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/global_layout_utils.h"
#include "common/loop_fusion_utils.h"
#include "common/loop_parallel_transform_utils.h"
#include "common/union_find.h"
@@ -62,6 +63,7 @@ struct LayoutInferenceResult {
Map<Buffer, Layout> layout_map;
Map<For, Fragment> for_map;
Map<For, PrimExpr> predicate_map;
Map<Buffer, Layout> global_layout_map;
};

class BufferUseDefCollector : public IRVisitorWithAnalyzer {
@@ -109,10 +111,14 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
"required for layout inference.";

// Run InferLayout
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
cur_analyzer, buffer_oob},
level);
auto updates = next->InferLayout(LayoutInferArgs{target_,
thread_bounds,
layout_map,
cur_analyzer,
buffer_oob,
{},
global_layout_map_},
level);

// Process the returned updates
for (const auto &[buffer, layout] : updates) {
@@ -399,7 +405,7 @@ }
}
}

return {layout_map, for_map, predicate_map};
return {layout_map, for_map, predicate_map, global_layout_map_};
}

void Collect(const PrimFunc &f) {
@@ -416,6 +422,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
ICHECK(target.defined())
<< "Layout_Inference: Require the target attribute";
target_ = target.value();

// Populate global buffer layouts from tensor_meta (Sunmmio only).
// Stored separately so that inference does not overwrite them.
PopulateGlobalBufferLayouts(f, target_, &global_layout_map_);

this->operator()(f->body);
}

@@ -830,6 +841,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
std::vector<bool> buffer_oob_vec_;
Target target_;
LayoutMap annotated_layout_map_;
LayoutMap global_layout_map_;
bool skip_thread_partition_{false};

std::vector<TileOperator> BackupInferList() {
@@ -1045,6 +1057,10 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
}
auto block_ptr = block.CopyOnWrite();
block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
if (!result_.global_layout_map.empty()) {
block_ptr->annotations.Set(attr::kGlobalLayoutMap,
result_.global_layout_map);
}
return block;
}

16 changes: 12 additions & 4 deletions src/transform/lower_tile_op.cc
@@ -162,6 +162,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
layout_map_.Set(buffer, layout);
}
}
// Read global layout map separately — these are read-only metadata
// and must NOT be processed through makeBufferWithLayout/Forward.
if (op->annotations.count(attr::kGlobalLayoutMap)) {
global_layout_map_ = op->annotations.at(attr::kGlobalLayoutMap)
.as<Map<Buffer, Layout>>()
.value();
}
// Begin a new workspace collection frame for this block scope
workspace_stack_.emplace_back();

@@ -552,10 +559,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
thread_bounds = Range::FromMinExtent(0, 1);
}

auto lowered =
tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
callback, layout_map_, buffer_remap_},
analyzer_);
auto lowered = tile_op->Lower(
LowerArgs{target_, thread_bounds, thread_var_->var, callback,
layout_map_, buffer_remap_, global_layout_map_},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}

@@ -577,6 +584,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
Map<Buffer, Layout> layout_map_;
Map<Buffer, Layout> layout_remap_;
Map<Buffer, Buffer> buffer_remap_;
Map<Buffer, Layout> global_layout_map_;
// This is a workaround for cpu backend,
// we need to define a thread_var for the serial loop.
IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),