@@ -26,6 +26,77 @@ struct LetWrapper {
2626 PrimExpr value;
2727};
2828
29+ /* !
30+ * \brief Collector to find all buffers used in a statement.
31+ *
32+ * This is used to collect buffers that are actually used in the pipeline loop
33+ * body, so that we can properly multi-version them for software pipelining.
34+ */
35+ class BufferUsageCollector : public StmtExprVisitor {
36+ public:
37+ BufferUsageCollector (
38+ const Map<Var, Buffer> &buffer_data_to_buffer,
39+ const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
40+ &allocated_buffers)
41+ : buffer_data_to_buffer_(buffer_data_to_buffer),
42+ allocated_buffers_ (allocated_buffers) {}
43+
44+ Array<Buffer> Collect (const Stmt &stmt) {
45+ this ->VisitStmt (stmt);
46+ Array<Buffer> result;
47+ for (const auto &buffer : used_buffers_) {
48+ result.push_back (buffer);
49+ }
50+ return result;
51+ }
52+
53+ private:
54+ void VisitStmt_ (const BufferStoreNode *op) final {
55+ AddBuffer (op->buffer );
56+ StmtExprVisitor::VisitStmt_ (op);
57+ }
58+
59+ void VisitExpr_ (const BufferLoadNode *op) final {
60+ AddBuffer (op->buffer );
61+ StmtExprVisitor::VisitExpr_ (op);
62+ }
63+
64+ void VisitExpr_ (const CallNode *op) final {
65+ // Handle tvm_access_ptr which also accesses buffers
66+ if (op->op .same_as (builtin::tvm_access_ptr ())) {
67+ if (op->args .size () > 1 ) {
68+ if (const auto *var = op->args [1 ].as <VarNode>()) {
69+ auto it = buffer_data_to_buffer_.find (GetRef<Var>(var));
70+ if (it != buffer_data_to_buffer_.end ()) {
71+ AddBuffer ((*it).second );
72+ }
73+ }
74+ }
75+ }
76+ StmtExprVisitor::VisitExpr_ (op);
77+ }
78+
79+ void VisitStmt_ (const BlockNode *op) final {
80+ // Also collect buffers allocated in nested blocks within the pipeline body
81+ for (const auto &buffer : op->alloc_buffers ) {
82+ used_buffers_.insert (buffer);
83+ }
84+ StmtExprVisitor::VisitStmt_ (op);
85+ }
86+
87+ void AddBuffer (const Buffer &buffer) {
88+ // Only add buffers that are allocated (not function input/output buffers)
89+ if (allocated_buffers_.count (buffer)) {
90+ used_buffers_.insert (buffer);
91+ }
92+ }
93+
94+ const Map<Var, Buffer> &buffer_data_to_buffer_;
95+ const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
96+ &allocated_buffers_;
97+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> used_buffers_;
98+ };
99+
29100/* !
30101 * \brief Create a block and infer the access region with the given body.
31102 *
@@ -219,13 +290,28 @@ class PipelineBodyRewriter : public StmtExprMutator {
219290 */
220291class PipelineRewriter : public StmtExprMutator {
221292public:
293+ /* !
294+ * \brief Constructor of PipelineRewriter.
295+ * \param buffer_data_to_buffer The map from buffer data to buffer.
296+ * \param pipeline_allocs All buffers that need multi-versioning in the
297+ * pipeline. This includes buffers allocated in the pipeline block and
298+ * buffers allocated in outer blocks that are used in the pipeline.
299+ * \param local_allocs Buffers that are allocated in the pipeline block
300+ * itself. These buffers will be re-allocated in the rewritten block.
301+ * Buffers in pipeline_allocs but not in local_allocs are allocated in outer
302+ * blocks and should not be re-allocated.
303+ * \param pipeline_loop The original loop to be software pipelined.
304+ * \param pipeline_info The pipeline annotation information.
305+ * \param loop_var_let_wrappers Let wrappers that depend on the loop var.
306+ */
222307 PipelineRewriter (Map<Var, Buffer> buffer_data_to_buffer,
223308 const Array<Buffer> &pipeline_allocs,
224- const For &pipeline_loop, const PipelineInfo &pipeline_info,
309+ const Array<Buffer> &local_allocs, const For &pipeline_loop,
310+ const PipelineInfo &pipeline_info,
225311 const std::vector<LetWrapper> &loop_var_let_wrappers)
226312 : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
227- pipeline_allocs_ (pipeline_allocs), pipeline_loop_(pipeline_loop ),
228- pipeline_info_(pipeline_info),
313+ pipeline_allocs_ (pipeline_allocs), local_allocs_(local_allocs ),
314+ pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info),
229315 loop_var_let_wrappers_(loop_var_let_wrappers) {}
230316
231317 Stmt BuildPipeline () {
@@ -234,7 +320,12 @@ class PipelineRewriter : public StmtExprMutator {
234320 std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
235321 infos = GetBufferAccessInfo ();
236322 for (const Buffer &buffer : pipeline_allocs_) {
237- int num_versions = ComputeBufferVersions (buffer, infos.at (buffer));
323+ auto it = infos.find (buffer);
324+ if (it == infos.end ()) {
325+ // Buffer is not accessed in the pipeline blocks, skip it
326+ continue ;
327+ }
328+ int num_versions = ComputeBufferVersions (buffer, it->second );
238329 if (num_versions > 1 ) {
239330 buffer_remap_.Set (buffer, RewriteAllocBuffer (buffer, num_versions));
240331 }
@@ -302,8 +393,12 @@ class PipelineRewriter : public StmtExprMutator {
302393
303394 // Step 3: Make a new block that contains new buffer allocations after
304395 // pipeline rewriting.
396+ // Only include buffers that are locally allocated in the pipeline block.
397+ // Buffers from outer blocks will be handled separately.
305398 Array<Buffer> alloc_buffers;
306- for (const auto &alloc : pipeline_allocs_) {
399+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> local_allocs_set (
400+ local_allocs_.begin (), local_allocs_.end ());
401+ for (const auto &alloc : local_allocs_) {
307402 alloc_buffers.push_back (buffer_remap_.Get (alloc).value_or (alloc));
308403 buffer_data_to_buffer_.erase (alloc->data );
309404 }
@@ -312,6 +407,12 @@ class PipelineRewriter : public StmtExprMutator {
312407 return BlockRealize ({}, Bool (true ), block);
313408 }
314409
410+ /* !
411+ * \brief Get the buffer remapping created during pipeline rewriting.
412+ * This is used to update alloc_buffers in outer blocks.
413+ */
414+ const Map<Buffer, Buffer> &GetBufferRemap () const { return buffer_remap_; }
415+
315416private:
316417 /* !
317418 * \brief Analyze accesses to the buffers in the software pipeline.
@@ -804,6 +905,7 @@ class PipelineRewriter : public StmtExprMutator {
804905 arith::Analyzer analyzer_;
805906 Map<Var, Buffer> buffer_data_to_buffer_;
806907 Array<Buffer> pipeline_allocs_;
908+ Array<Buffer> local_allocs_;
807909 For pipeline_loop_;
808910 PipelineInfo pipeline_info_;
809911 int max_stage_ = -1 ;
@@ -923,14 +1025,17 @@ class PipelineInjector : private StmtExprMutator {
9231025 Stmt pipeline_body_root{nullptr };
9241026 bool pipeline_body_from_block = false ;
9251027 Array<Buffer> pipeline_allocs;
1028+ Array<Buffer>
1029+ block_local_allocs; // buffers allocated in the pipeline block itself
9261030 if (const auto *realize = for_node->body .as <BlockRealizeNode>()) {
9271031 const auto &block = realize->block ;
9281032 for (const auto &buffer : block->alloc_buffers ) {
9291033 ICHECK (buffer->IsInstance <BufferNode>());
9301034 buffer_data_to_buffer_.Set (buffer->data , buffer);
1035+ allocated_buffers_.insert (buffer);
9311036 }
9321037 pipeline_body_root = block->body ;
933- pipeline_allocs = block->alloc_buffers ;
1038+ block_local_allocs = block->alloc_buffers ;
9341039 pipeline_body_from_block = true ;
9351040 } else {
9361041 pipeline_body_root = for_node->body ;
@@ -1021,13 +1126,49 @@ class PipelineInjector : private StmtExprMutator {
10211126 ICHECK (nested_pipeline_block->match_buffers
10221127 .empty ()); // match_buffer should have been lowered
10231128 for (const auto &buffer : nested_pipeline_block->alloc_buffers ) {
1024- pipeline_allocs.push_back (buffer);
10251129 buffer_data_to_buffer_.Set (buffer->data , buffer);
1130+ allocated_buffers_.insert (buffer);
10261131 }
10271132 }
10281133 f_add_child (child);
10291134 }
10301135
1136+ // Collect all buffers that are actually used in the pipeline loop body.
1137+ // This includes buffers allocated in outer blocks (like logits_smem) that
1138+ // are used inside the pipeline loop.
1139+ BufferUsageCollector collector (buffer_data_to_buffer_, allocated_buffers_);
1140+ pipeline_allocs = collector.Collect (SeqStmt (pipeline_body_seq->seq ));
1141+
1142+ // Build a set of local allocs (buffers allocated in the pipeline block
1143+ // itself) for efficient lookup
1144+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> local_allocs_set;
1145+ for (const auto &buffer : block_local_allocs) {
1146+ local_allocs_set.insert (buffer);
1147+ }
1148+ for (size_t i = 0 ; i < pipeline_body_seq->seq .size (); i++) {
1149+ const Stmt &child = pipeline_body_seq->seq [i];
1150+ const auto *nested_block_realize = child.as <BlockRealizeNode>();
1151+ if (nested_block_realize && is_one (nested_block_realize->predicate ) &&
1152+ nested_block_realize->block ->body ->IsInstance <SeqStmtNode>()) {
1153+ for (const auto &buffer : nested_block_realize->block ->alloc_buffers ) {
1154+ local_allocs_set.insert (buffer);
1155+ }
1156+ }
1157+ }
1158+
1159+ // Check if any external buffer (from outer blocks) is already used in
1160+ // another pipeline. This would cause conflicts in multi-versioning.
1161+ for (const auto &buffer : pipeline_allocs) {
1162+ // Only check external buffers (not locally allocated in this pipeline)
1163+ if (local_allocs_set.count (buffer) == 0 ) {
1164+ CHECK (buffers_used_in_pipeline_.count (buffer) == 0 )
1165+ << " Buffer '" << buffer->name
1166+ << " ' is used in multiple software pipeline loops. "
1167+ << " This is not supported because multi-versioning would conflict." ;
1168+ buffers_used_in_pipeline_.insert (buffer);
1169+ }
1170+ }
1171+
10311172 auto pipeline_stages = Downcast<Array<Integer>>(
10321173 op->annotations .at (tir::attr::software_pipeline_stage));
10331174 auto pipeline_orders = Downcast<Array<Integer>>(
@@ -1067,10 +1208,32 @@ class PipelineInjector : private StmtExprMutator {
10671208 ValidatePipelineBody (pipeline_info, original_order);
10681209
10691210 // Step 4: Rewrite the pipeline body.
1070- Stmt pipeline = PipelineRewriter (buffer_data_to_buffer_, pipeline_allocs,
1071- tvm::ffi::GetRef<For>(op), pipeline_info,
1072- loop_var_let_wrappers)
1073- .BuildPipeline ();
1211+ // local_allocs contains buffers allocated in the pipeline block itself.
1212+ // pipeline_allocs contains all buffers that need multi-versioning,
1213+ // including buffers from outer blocks.
1214+ Array<Buffer> local_allocs = block_local_allocs;
1215+ // Add nested block allocs to local_allocs
1216+ for (size_t i = 0 ; i < pipeline_body_seq->seq .size (); i++) {
1217+ const Stmt &child = pipeline_body_seq->seq [i];
1218+ const auto *nested_block_realize = child.as <BlockRealizeNode>();
1219+ if (nested_block_realize && is_one (nested_block_realize->predicate ) &&
1220+ nested_block_realize->block ->body ->IsInstance <SeqStmtNode>()) {
1221+ const Block &nested_pipeline_block = nested_block_realize->block ;
1222+ for (const auto &buffer : nested_pipeline_block->alloc_buffers ) {
1223+ local_allocs.push_back (buffer);
1224+ }
1225+ }
1226+ }
1227+
1228+ PipelineRewriter rewriter (buffer_data_to_buffer_, pipeline_allocs,
1229+ local_allocs, tvm::ffi::GetRef<For>(op),
1230+ pipeline_info, loop_var_let_wrappers);
1231+ Stmt pipeline = rewriter.BuildPipeline ();
1232+
1233+ // Store the buffer remapping for updating outer block alloc_buffers
1234+ for (const auto &kv : rewriter.GetBufferRemap ()) {
1235+ pending_buffer_remap_.Set (kv.first , kv.second );
1236+ }
10741237 auto apply_wrappers = [&](Stmt stmt) {
10751238 for (auto it = rewrap_fns.rbegin (); it != rewrap_fns.rend (); ++it) {
10761239 stmt = (*it)(stmt);
@@ -1097,6 +1260,7 @@ class PipelineInjector : private StmtExprMutator {
10971260 const auto &block = realize->block ;
10981261 for (const auto &buffer : block->alloc_buffers ) {
10991262 buffer_data_to_buffer_.erase (buffer->data );
1263+ allocated_buffers_.erase (buffer);
11001264 }
11011265 }
11021266 return pipeline;
@@ -1105,18 +1269,35 @@ class PipelineInjector : private StmtExprMutator {
11051269 Stmt VisitStmt_ (const BlockNode *op) final {
11061270 for (const auto &buffer : op->alloc_buffers ) {
11071271 buffer_data_to_buffer_.Set (buffer->data , buffer);
1272+ allocated_buffers_.insert (buffer);
11081273 }
11091274
11101275 Block block = Downcast<Block>(StmtExprMutator::VisitStmt_ (op));
11111276
1277+ // Update alloc_buffers with any pending buffer remaps from pipeline
1278+ // rewriting. This handles buffers allocated in this block but
1279+ // multi-versioned during pipeline rewriting of inner loops.
1280+ Array<Buffer> new_alloc_buffers;
1281+ for (const auto &buffer : block->alloc_buffers ) {
1282+ if (auto remapped = pending_buffer_remap_.Get (buffer)) {
1283+ new_alloc_buffers.push_back (remapped.value ());
1284+ // Remove from pending after applying
1285+ pending_buffer_remap_.erase (buffer);
1286+ } else {
1287+ new_alloc_buffers.push_back (buffer);
1288+ }
1289+ }
1290+
11121291 Array<Array<BufferRegion>> access =
11131292 GetBlockReadWriteRegion (block, buffer_data_to_buffer_);
11141293 BlockNode *n = block.CopyOnWrite ();
11151294 n->reads = access[0 ];
11161295 n->writes = access[1 ];
1296+ n->alloc_buffers = std::move (new_alloc_buffers);
11171297
11181298 for (const auto &buffer : op->alloc_buffers ) {
11191299 buffer_data_to_buffer_.erase (buffer->data );
1300+ allocated_buffers_.erase (buffer);
11201301 }
11211302 return block;
11221303 }
@@ -1141,6 +1322,12 @@ class PipelineInjector : private StmtExprMutator {
11411322 }
11421323
11431324 Map<Var, Buffer> buffer_data_to_buffer_;
// Buffers taken from the alloc_buffers of blocks: inserted when a block is
// entered and erased when it is left. Passed to BufferUsageCollector as the
// set of buffers eligible for collection.
1325+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocated_buffers_;
// Remaps (original buffer -> rewritten buffer) copied from
// PipelineRewriter::GetBufferRemap(). An entry is applied and erased when the
// block whose alloc_buffers contains the original buffer is visited.
1326+ Map<Buffer, Buffer> pending_buffer_remap_;
1327+ // Buffers from outer blocks that have been used in a pipeline loop.
1328+ // Used to detect if the same buffer is used in multiple pipeline loops.
1329+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
1330+ buffers_used_in_pipeline_;
11441331 Optional<String> global_symbol_;
11451332};
11461333} // namespace software_pipeline
0 commit comments