Skip to content

Commit d140415

Browse files
authored
[Pipeline] Refactor buffer allocation in Inject Pipeline Pass (tile-ai#1525)
* [Feature] Introduce BufferUsageCollector for software pipelining
  * Added the BufferUsageCollector class, which identifies and collects the buffers used in pipeline loop bodies so that they can be properly multi-versioned for software pipelining.
  * Updated PipelineRewriter to handle local and outer-block buffer allocations more effectively, ensuring that only the necessary buffers are included in the pipeline.
  * Enhanced the buffer remapping logic to prevent conflicts when buffers from outer blocks are used in multiple pipeline loops.
  This update improves the efficiency and correctness of buffer management during software pipelining.
* Refactor buffer allocation declarations in inject_pipeline.cc
  * Adjusted the formatting of buffer allocation declarations for improved readability.
  * Ensured a consistent style in the codebase by aligning variable declarations.
  This change enhances code clarity without altering functionality.
* test fix
1 parent 3c11823 commit d140415

File tree

2 files changed

+199
-12
lines changed

2 files changed

+199
-12
lines changed

examples/gdn/test_example_gdn_compilation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
block_DK = 64
2121
block_DV = 32
2222
threads = 128
23-
num_stages = 1
23+
num_stages = 0
2424

2525

2626
def test_example_wy_fast_compilation():

src/transform/inject_pipeline.cc

Lines changed: 198 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,77 @@ struct LetWrapper {
2626
PrimExpr value;
2727
};
2828

29+
/*!
30+
* \brief Collector to find all buffers used in a statement.
31+
*
32+
* This is used to collect buffers that are actually used in the pipeline loop
33+
* body, so that we can properly multi-version them for software pipelining.
34+
*/
35+
class BufferUsageCollector : public StmtExprVisitor {
36+
public:
37+
BufferUsageCollector(
38+
const Map<Var, Buffer> &buffer_data_to_buffer,
39+
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
40+
&allocated_buffers)
41+
: buffer_data_to_buffer_(buffer_data_to_buffer),
42+
allocated_buffers_(allocated_buffers) {}
43+
44+
Array<Buffer> Collect(const Stmt &stmt) {
45+
this->VisitStmt(stmt);
46+
Array<Buffer> result;
47+
for (const auto &buffer : used_buffers_) {
48+
result.push_back(buffer);
49+
}
50+
return result;
51+
}
52+
53+
private:
54+
void VisitStmt_(const BufferStoreNode *op) final {
55+
AddBuffer(op->buffer);
56+
StmtExprVisitor::VisitStmt_(op);
57+
}
58+
59+
void VisitExpr_(const BufferLoadNode *op) final {
60+
AddBuffer(op->buffer);
61+
StmtExprVisitor::VisitExpr_(op);
62+
}
63+
64+
void VisitExpr_(const CallNode *op) final {
65+
// Handle tvm_access_ptr which also accesses buffers
66+
if (op->op.same_as(builtin::tvm_access_ptr())) {
67+
if (op->args.size() > 1) {
68+
if (const auto *var = op->args[1].as<VarNode>()) {
69+
auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
70+
if (it != buffer_data_to_buffer_.end()) {
71+
AddBuffer((*it).second);
72+
}
73+
}
74+
}
75+
}
76+
StmtExprVisitor::VisitExpr_(op);
77+
}
78+
79+
void VisitStmt_(const BlockNode *op) final {
80+
// Also collect buffers allocated in nested blocks within the pipeline body
81+
for (const auto &buffer : op->alloc_buffers) {
82+
used_buffers_.insert(buffer);
83+
}
84+
StmtExprVisitor::VisitStmt_(op);
85+
}
86+
87+
void AddBuffer(const Buffer &buffer) {
88+
// Only add buffers that are allocated (not function input/output buffers)
89+
if (allocated_buffers_.count(buffer)) {
90+
used_buffers_.insert(buffer);
91+
}
92+
}
93+
94+
const Map<Var, Buffer> &buffer_data_to_buffer_;
95+
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
96+
&allocated_buffers_;
97+
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> used_buffers_;
98+
};
99+
29100
/*!
30101
* \brief Create a block and infer the access region with the given body.
31102
*
@@ -219,13 +290,28 @@ class PipelineBodyRewriter : public StmtExprMutator {
219290
*/
220291
class PipelineRewriter : public StmtExprMutator {
221292
public:
293+
/*!
294+
* \brief Constructor of PipelineRewriter.
295+
* \param buffer_data_to_buffer The map from buffer data to buffer.
296+
* \param pipeline_allocs All buffers that need multi-versioning in the
297+
* pipeline. This includes buffers allocated in the pipeline block and
298+
* buffers allocated in outer blocks that are used in the pipeline.
299+
* \param local_allocs Buffers that are allocated in the pipeline block
300+
* itself. These buffers will be re-allocated in the rewritten block.
301+
* Buffers in pipeline_allocs but not in local_allocs are allocated in outer
302+
* blocks and should not be re-allocated.
303+
* \param pipeline_loop The original loop to be software pipelined.
304+
* \param pipeline_info The pipeline annotation information.
305+
* \param loop_var_let_wrappers Let wrappers that depend on the loop var.
306+
*/
222307
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
223308
const Array<Buffer> &pipeline_allocs,
224-
const For &pipeline_loop, const PipelineInfo &pipeline_info,
309+
const Array<Buffer> &local_allocs, const For &pipeline_loop,
310+
const PipelineInfo &pipeline_info,
225311
const std::vector<LetWrapper> &loop_var_let_wrappers)
226312
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
227-
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
228-
pipeline_info_(pipeline_info),
313+
pipeline_allocs_(pipeline_allocs), local_allocs_(local_allocs),
314+
pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info),
229315
loop_var_let_wrappers_(loop_var_let_wrappers) {}
230316

231317
Stmt BuildPipeline() {
@@ -234,7 +320,12 @@ class PipelineRewriter : public StmtExprMutator {
234320
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
235321
infos = GetBufferAccessInfo();
236322
for (const Buffer &buffer : pipeline_allocs_) {
237-
int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
323+
auto it = infos.find(buffer);
324+
if (it == infos.end()) {
325+
// Buffer is not accessed in the pipeline blocks, skip it
326+
continue;
327+
}
328+
int num_versions = ComputeBufferVersions(buffer, it->second);
238329
if (num_versions > 1) {
239330
buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
240331
}
@@ -302,8 +393,12 @@ class PipelineRewriter : public StmtExprMutator {
302393

303394
// Step 3: Make a new block that contains new buffer allocations after
304395
// pipeline rewriting.
396+
// Only include buffers that are locally allocated in the pipeline block.
397+
// Buffers from outer blocks will be handled separately.
305398
Array<Buffer> alloc_buffers;
306-
for (const auto &alloc : pipeline_allocs_) {
399+
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> local_allocs_set(
400+
local_allocs_.begin(), local_allocs_.end());
401+
for (const auto &alloc : local_allocs_) {
307402
alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
308403
buffer_data_to_buffer_.erase(alloc->data);
309404
}
@@ -312,6 +407,12 @@ class PipelineRewriter : public StmtExprMutator {
312407
return BlockRealize({}, Bool(true), block);
313408
}
314409

410+
/*!
411+
* \brief Get the buffer remapping created during pipeline rewriting.
412+
* This is used to update alloc_buffers in outer blocks.
413+
*/
414+
const Map<Buffer, Buffer> &GetBufferRemap() const { return buffer_remap_; }
415+
315416
private:
316417
/*!
317418
* \brief Analyze accesses to the buffers in the software pipeline.
@@ -804,6 +905,7 @@ class PipelineRewriter : public StmtExprMutator {
804905
arith::Analyzer analyzer_;
805906
Map<Var, Buffer> buffer_data_to_buffer_;
806907
Array<Buffer> pipeline_allocs_;
908+
Array<Buffer> local_allocs_;
807909
For pipeline_loop_;
808910
PipelineInfo pipeline_info_;
809911
int max_stage_ = -1;
@@ -923,14 +1025,17 @@ class PipelineInjector : private StmtExprMutator {
9231025
Stmt pipeline_body_root{nullptr};
9241026
bool pipeline_body_from_block = false;
9251027
Array<Buffer> pipeline_allocs;
1028+
Array<Buffer>
1029+
block_local_allocs; // buffers allocated in the pipeline block itself
9261030
if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
9271031
const auto &block = realize->block;
9281032
for (const auto &buffer : block->alloc_buffers) {
9291033
ICHECK(buffer->IsInstance<BufferNode>());
9301034
buffer_data_to_buffer_.Set(buffer->data, buffer);
1035+
allocated_buffers_.insert(buffer);
9311036
}
9321037
pipeline_body_root = block->body;
933-
pipeline_allocs = block->alloc_buffers;
1038+
block_local_allocs = block->alloc_buffers;
9341039
pipeline_body_from_block = true;
9351040
} else {
9361041
pipeline_body_root = for_node->body;
@@ -1021,13 +1126,49 @@ class PipelineInjector : private StmtExprMutator {
10211126
ICHECK(nested_pipeline_block->match_buffers
10221127
.empty()); // match_buffer should have been lowered
10231128
for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
1024-
pipeline_allocs.push_back(buffer);
10251129
buffer_data_to_buffer_.Set(buffer->data, buffer);
1130+
allocated_buffers_.insert(buffer);
10261131
}
10271132
}
10281133
f_add_child(child);
10291134
}
10301135

1136+
// Collect all buffers that are actually used in the pipeline loop body.
1137+
// This includes buffers allocated in outer blocks (like logits_smem) that
1138+
// are used inside the pipeline loop.
1139+
BufferUsageCollector collector(buffer_data_to_buffer_, allocated_buffers_);
1140+
pipeline_allocs = collector.Collect(SeqStmt(pipeline_body_seq->seq));
1141+
1142+
// Build a set of local allocs (buffers allocated in the pipeline block
1143+
// itself) for efficient lookup
1144+
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> local_allocs_set;
1145+
for (const auto &buffer : block_local_allocs) {
1146+
local_allocs_set.insert(buffer);
1147+
}
1148+
for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
1149+
const Stmt &child = pipeline_body_seq->seq[i];
1150+
const auto *nested_block_realize = child.as<BlockRealizeNode>();
1151+
if (nested_block_realize && is_one(nested_block_realize->predicate) &&
1152+
nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
1153+
for (const auto &buffer : nested_block_realize->block->alloc_buffers) {
1154+
local_allocs_set.insert(buffer);
1155+
}
1156+
}
1157+
}
1158+
1159+
// Check if any external buffer (from outer blocks) is already used in
1160+
// another pipeline. This would cause conflicts in multi-versioning.
1161+
for (const auto &buffer : pipeline_allocs) {
1162+
// Only check external buffers (not locally allocated in this pipeline)
1163+
if (local_allocs_set.count(buffer) == 0) {
1164+
CHECK(buffers_used_in_pipeline_.count(buffer) == 0)
1165+
<< "Buffer '" << buffer->name
1166+
<< "' is used in multiple software pipeline loops. "
1167+
<< "This is not supported because multi-versioning would conflict.";
1168+
buffers_used_in_pipeline_.insert(buffer);
1169+
}
1170+
}
1171+
10311172
auto pipeline_stages = Downcast<Array<Integer>>(
10321173
op->annotations.at(tir::attr::software_pipeline_stage));
10331174
auto pipeline_orders = Downcast<Array<Integer>>(
@@ -1067,10 +1208,32 @@ class PipelineInjector : private StmtExprMutator {
10671208
ValidatePipelineBody(pipeline_info, original_order);
10681209

10691210
// Step 4: Rewrite the pipeline body.
1070-
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
1071-
tvm::ffi::GetRef<For>(op), pipeline_info,
1072-
loop_var_let_wrappers)
1073-
.BuildPipeline();
1211+
// local_allocs contains buffers allocated in the pipeline block itself.
1212+
// pipeline_allocs contains all buffers that need multi-versioning,
1213+
// including buffers from outer blocks.
1214+
Array<Buffer> local_allocs = block_local_allocs;
1215+
// Add nested block allocs to local_allocs
1216+
for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
1217+
const Stmt &child = pipeline_body_seq->seq[i];
1218+
const auto *nested_block_realize = child.as<BlockRealizeNode>();
1219+
if (nested_block_realize && is_one(nested_block_realize->predicate) &&
1220+
nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
1221+
const Block &nested_pipeline_block = nested_block_realize->block;
1222+
for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
1223+
local_allocs.push_back(buffer);
1224+
}
1225+
}
1226+
}
1227+
1228+
PipelineRewriter rewriter(buffer_data_to_buffer_, pipeline_allocs,
1229+
local_allocs, tvm::ffi::GetRef<For>(op),
1230+
pipeline_info, loop_var_let_wrappers);
1231+
Stmt pipeline = rewriter.BuildPipeline();
1232+
1233+
// Store the buffer remapping for updating outer block alloc_buffers
1234+
for (const auto &kv : rewriter.GetBufferRemap()) {
1235+
pending_buffer_remap_.Set(kv.first, kv.second);
1236+
}
10741237
auto apply_wrappers = [&](Stmt stmt) {
10751238
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
10761239
stmt = (*it)(stmt);
@@ -1097,6 +1260,7 @@ class PipelineInjector : private StmtExprMutator {
10971260
const auto &block = realize->block;
10981261
for (const auto &buffer : block->alloc_buffers) {
10991262
buffer_data_to_buffer_.erase(buffer->data);
1263+
allocated_buffers_.erase(buffer);
11001264
}
11011265
}
11021266
return pipeline;
@@ -1105,18 +1269,35 @@ class PipelineInjector : private StmtExprMutator {
11051269
/*!
 * \brief Rewrite a block, tracking its allocations while the body is visited.
 *
 * The block's alloc_buffers are registered before the body is mutated and
 * unregistered afterwards. Once the body has been rewritten, any allocation
 * that has an entry in pending_buffer_remap_ is replaced by its remapped
 * buffer (and the entry is consumed), and the read/write regions are
 * recomputed.
 */
Stmt VisitStmt_(const BlockNode *op) final {
  // Register this block's allocations for the duration of the body visit.
  for (const auto &alloc : op->alloc_buffers) {
    buffer_data_to_buffer_.Set(alloc->data, alloc);
    allocated_buffers_.insert(alloc);
  }

  Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

  // Swap in the remapped version of every buffer allocated here that has a
  // pending remap from pipeline rewriting of an inner loop.
  Array<Buffer> updated_allocs;
  for (const auto &alloc : block->alloc_buffers) {
    auto remapped = pending_buffer_remap_.Get(alloc);
    if (remapped) {
      updated_allocs.push_back(remapped.value());
      // The remap has been applied; drop it from the pending set.
      pending_buffer_remap_.erase(alloc);
      continue;
    }
    updated_allocs.push_back(alloc);
  }

  Array<Array<BufferRegion>> regions =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
  BlockNode *mutable_block = block.CopyOnWrite();
  mutable_block->reads = regions[0];
  mutable_block->writes = regions[1];
  mutable_block->alloc_buffers = std::move(updated_allocs);

  // Unregister the allocations registered on entry.
  for (const auto &alloc : op->alloc_buffers) {
    buffer_data_to_buffer_.erase(alloc->data);
    allocated_buffers_.erase(alloc);
  }
  return block;
}
@@ -1141,6 +1322,12 @@ class PipelineInjector : private StmtExprMutator {
11411322
}
11421323

11431324
Map<Var, Buffer> buffer_data_to_buffer_;
1325+
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocated_buffers_;
1326+
Map<Buffer, Buffer> pending_buffer_remap_;
1327+
// Buffers from outer blocks that have been used in a pipeline loop.
1328+
// Used to detect if the same buffer is used in multiple pipeline loops.
1329+
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
1330+
buffers_used_in_pipeline_;
11441331
Optional<String> global_symbol_;
11451332
};
11461333
} // namespace software_pipeline

0 commit comments

Comments
 (0)