From 936cb353959e579a946f724ccb1ba24eecd4f271 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sat, 4 Jan 2025 16:49:55 +0800 Subject: [PATCH] better unreachable code removal in dce --- include/luisa/xir/passes/dom_tree.h | 6 +- src/backends/fallback/fallback_codegen.cpp | 2 + src/xir/passes/dce.cpp | 104 +++++++++++++++++---- src/xir/passes/dom_tree.cpp | 28 +++++- 4 files changed, 120 insertions(+), 20 deletions(-) diff --git a/include/luisa/xir/passes/dom_tree.h b/include/luisa/xir/passes/dom_tree.h index f7474c228..31c3dfdab 100644 --- a/include/luisa/xir/passes/dom_tree.h +++ b/include/luisa/xir/passes/dom_tree.h @@ -42,7 +42,11 @@ class LC_XIR_API DomTree { public: [[nodiscard]] auto root() const noexcept { return _root; } [[nodiscard]] auto &nodes() const noexcept { return _nodes; } - [[nodiscard]] const DomTreeNode *node(BasicBlock *block) const noexcept; + [[nodiscard]] auto node(BasicBlock *block) const noexcept -> const DomTreeNode *; + [[nodiscard]] bool contains(BasicBlock *block) const noexcept; + [[nodiscard]] bool dominates(BasicBlock *src, BasicBlock *dst) const noexcept; + [[nodiscard]] bool strictly_dominates(BasicBlock *src, BasicBlock *dst) const noexcept; + [[nodiscard]] auto immediate_dominator(BasicBlock *block) const noexcept -> BasicBlock *; }; [[nodiscard]] LC_XIR_API DomTree compute_dom_tree(Function *function) noexcept; diff --git a/src/backends/fallback/fallback_codegen.cpp b/src/backends/fallback/fallback_codegen.cpp index f23db1786..2b05cf356 100644 --- a/src/backends/fallback/fallback_codegen.cpp +++ b/src/backends/fallback/fallback_codegen.cpp @@ -2917,6 +2917,7 @@ class FallbackCodegen { } [[nodiscard]] llvm::BasicBlock *_find_or_create_basic_block(CurrentFunction ¤t, const xir::BasicBlock *bb) noexcept { + if (bb == nullptr) { return nullptr; } auto iter = current.value_map.try_emplace(bb, nullptr).first; if (iter->second) { return llvm::cast(iter->second); } auto llvm_bb = llvm::BasicBlock::Create(_llvm_context, _get_name_from_metadata(bb), current.func); @@ -2925,6 +2926,7 @@ class FallbackCodegen { } void _translate_instructions_in_basic_block(CurrentFunction ¤t, llvm::BasicBlock *llvm_bb, const xir::BasicBlock *bb) noexcept { + if (bb == nullptr) { return; } if (current.translated_basic_blocks.emplace(llvm_bb).second) { for (auto &inst : bb->instructions()) { IRBuilder b{llvm_bb}; diff --git a/src/xir/passes/dce.cpp b/src/xir/passes/dce.cpp index 8dc8c6eac..b4e6c1cbc 100644 --- a/src/xir/passes/dce.cpp +++ b/src/xir/passes/dce.cpp @@ -132,7 +132,74 @@ void eliminate_dead_alloca_in_function(Function *function, DCEInfo &info) noexce } } -void eliminate_unreachable_code_in_function(Function *function, DCEInfo &info) noexcept { +[[nodiscard]] static bool is_block_terminated_by_unreachable(BasicBlock *block) noexcept { + return block->terminator()->derived_instruction_tag() == DerivedInstructionTag::UNREACHABLE; +} + +void eliminate_instructions_in_unreachable_blocks(const luisa::unordered_set &blocks, DCEInfo &info) noexcept { + luisa::vector cache; + for (auto b : blocks) { + // replace the terminator with an unreachable instruction if it's not already + if (!is_block_terminated_by_unreachable(b)) { + b->terminator()->remove_self(); + xir::Builder builder; + builder.set_insertion_point(b); + builder.unreachable_(); + } + // collect all instructions in the unreachable block + for (auto &&inst : b->instructions()) { + cache.emplace_back(&inst); + } + // pop the terminator + cache.pop_back(); + // remove all instructions + for (auto &&inst : cache) { + inst->remove_self(); + info.removed_instructions.emplace(inst); + } + cache.clear(); + } +} + +void propagate_unreachable_marks_in_function(Function *function, DCEInfo &info) noexcept { + // run a backward dataflow analysis to propagate unreachable marks: + // we should mark a block as unreachable if all its successors are marked as unreachable + if (auto definition = function->definition()) { + luisa::vector postorder; + definition->traverse_basic_blocks(BasicBlockTraversalOrder::POST_ORDER, [&](BasicBlock *block) noexcept { + postorder.emplace_back(block); + }); + luisa::unordered_set unreachable; + for (;;) { + auto prev_reachable_count = unreachable.size(); + for (auto block : postorder) { + if (!unreachable.contains(block)) { + if (is_block_terminated_by_unreachable(block)) { + unreachable.emplace(block); + } else { + auto has_any_successor = false; + auto all_successors_unreachable = true; + block->traverse_successors(false, [&](BasicBlock *succ) noexcept { + has_any_successor = true; + if (succ != block && !unreachable.contains(succ) && + !is_block_terminated_by_unreachable(succ)) { + all_successors_unreachable = false; + } + }); + if (has_any_successor && all_successors_unreachable) { + unreachable.emplace(block); + } + } + } + } + if (unreachable.size() == prev_reachable_count) { break; } + } + // eliminate all instructions in unreachable blocks + eliminate_instructions_in_unreachable_blocks(unreachable, info); + } +} + +void eliminate_unreachable_blocks_in_function(Function *function, DCEInfo &info) noexcept { if (auto definition = function->definition()) { luisa::unordered_set reachable; definition->traverse_basic_blocks([&](BasicBlock *block) noexcept { @@ -155,22 +222,8 @@ void eliminate_unreachable_code_in_function(Function *function, DCEInfo &info) n } }); } - luisa::vector dead; - for (auto b : unreachable) { - dead.clear(); - // collect and remove all instructions in the unreachable block - for (auto &&inst : b->instructions()) { - dead.emplace_back(&inst); - } - for (auto &&inst : dead) { - info.removed_instructions.emplace(inst); - inst->remove_self(); - } - // replace with an unreachable instruction - xir::Builder builder; - builder.set_insertion_point(b); - builder.unreachable_(); - } + // eliminate all instructions in unreachable blocks + eliminate_instructions_in_unreachable_blocks(unreachable, info); } } @@ -213,9 +266,24 @@ void fix_phi_nodes_in_function(Function *function) noexcept { } } +void fix_control_flow_merges_in_function(Function *function) noexcept { + if (auto definition = function->definition()) { + definition->traverse_basic_blocks([&](BasicBlock *block) noexcept { + if (auto merge = block->terminator()->control_flow_merge()) { + if (auto merge_block = merge->merge_block(); + merge_block != nullptr && is_block_terminated_by_unreachable(merge_block)) { + merge->set_merge_block(nullptr); + } + } + }); + } +} + void run_dce_pass_on_function(Function *function, DCEInfo &info) noexcept { - eliminate_unreachable_code_in_function(function, info); + propagate_unreachable_marks_in_function(function, info); + eliminate_unreachable_blocks_in_function(function, info); fix_phi_nodes_in_function(function); + fix_control_flow_merges_in_function(function); for (;;) { auto prev_count = info.removed_instructions.size(); eliminate_dead_code_in_function(function, info); diff --git a/src/xir/passes/dom_tree.cpp b/src/xir/passes/dom_tree.cpp index ae1caef90..e3a783c98 100644 --- a/src/xir/passes/dom_tree.cpp +++ b/src/xir/passes/dom_tree.cpp @@ -60,12 +60,38 @@ inline void DomTree::compute_dominance_frontiers() noexcept { } } -const DomTreeNode *DomTree::node(BasicBlock *block) const noexcept { +auto DomTree::node(BasicBlock *block) const noexcept -> const DomTreeNode * { auto iter = _nodes.find(block); LUISA_ASSERT(iter != _nodes.end(), "Block not found in the dom tree."); return iter->second.get(); } +bool DomTree::contains(BasicBlock *block) const noexcept { + return _nodes.contains(block); +} + +bool DomTree::dominates(BasicBlock *src, BasicBlock *dst) const noexcept { + if (src == dst) { return true; } + auto src_node = node(src); + auto dst_node = node(dst); + if (src_node == _root) { return true; } + while (dst_node != _root) { + if (dst_node == src_node) { return true; } + dst_node = dst_node->parent(); + } + return false; +} + +bool DomTree::strictly_dominates(BasicBlock *src, BasicBlock *dst) const noexcept { + return src != dst && dominates(src, dst); +} + +auto DomTree::immediate_dominator(BasicBlock *block) const noexcept -> BasicBlock * { + auto node = this->node(block); + if (node == _root) { return nullptr; } + return node->parent()->block(); +} + // Reference: A Simple, Fast Dominance Algorithm [Cooper et al. 2001] DomTree compute_dom_tree(Function *function) noexcept { auto definition = function->definition();