Skip to content

Commit

Permalink
better unreachable code removal in dce
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 4, 2025
1 parent 8b7bd38 commit 936cb35
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 20 deletions.
6 changes: 5 additions & 1 deletion include/luisa/xir/passes/dom_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ class LC_XIR_API DomTree {
public:
[[nodiscard]] auto root() const noexcept { return _root; }
[[nodiscard]] auto &nodes() const noexcept { return _nodes; }
[[nodiscard]] const DomTreeNode *node(BasicBlock *block) const noexcept;
[[nodiscard]] auto node(BasicBlock *block) const noexcept -> const DomTreeNode *;
[[nodiscard]] bool contains(BasicBlock *block) const noexcept;
[[nodiscard]] bool dominates(BasicBlock *src, BasicBlock *dst) const noexcept;
[[nodiscard]] bool strictly_dominates(BasicBlock *src, BasicBlock *dst) const noexcept;
[[nodiscard]] auto immediate_dominator(BasicBlock *block) const noexcept -> BasicBlock *;
};

[[nodiscard]] LC_XIR_API DomTree compute_dom_tree(Function *function) noexcept;
Expand Down
2 changes: 2 additions & 0 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2917,6 +2917,7 @@ class FallbackCodegen {
}

[[nodiscard]] llvm::BasicBlock *_find_or_create_basic_block(CurrentFunction &current, const xir::BasicBlock *bb) noexcept {
if (bb == nullptr) { return nullptr; }
auto iter = current.value_map.try_emplace(bb, nullptr).first;
if (iter->second) { return llvm::cast<llvm::BasicBlock>(iter->second); }
auto llvm_bb = llvm::BasicBlock::Create(_llvm_context, _get_name_from_metadata(bb), current.func);
Expand All @@ -2925,6 +2926,7 @@ class FallbackCodegen {
}

void _translate_instructions_in_basic_block(CurrentFunction &current, llvm::BasicBlock *llvm_bb, const xir::BasicBlock *bb) noexcept {
if (bb == nullptr) { return; }
if (current.translated_basic_blocks.emplace(llvm_bb).second) {
for (auto &inst : bb->instructions()) {
IRBuilder b{llvm_bb};
Expand Down
104 changes: 86 additions & 18 deletions src/xir/passes/dce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,74 @@ void eliminate_dead_alloca_in_function(Function *function, DCEInfo &info) noexce
}
}

void eliminate_unreachable_code_in_function(Function *function, DCEInfo &info) noexcept {
// Tells whether the terminator of `block` carries the UNREACHABLE derived-instruction tag.
// The terminator is dereferenced unconditionally, so the block must have one.
[[nodiscard]] static bool is_block_terminated_by_unreachable(BasicBlock *block) noexcept {
    const auto terminator_tag = block->terminator()->derived_instruction_tag();
    return terminator_tag == DerivedInstructionTag::UNREACHABLE;
}

// Strips every block in `blocks` down to a single `unreachable` terminator.
//
// For each block:
//   1. if its terminator is not already `unreachable`, the terminator is detached and a
//      new `unreachable` is emitted through a builder whose insertion point is the block;
//   2. every instruction except the last one (the terminator) is detached.
//
// Every detached instruction is recorded in `info.removed_instructions`. This includes
// a terminator that gets replaced in step 1, so that the bookkeeping reflects all
// instructions this function takes out of the block.
void eliminate_instructions_in_unreachable_blocks(const luisa::unordered_set<BasicBlock *> &blocks, DCEInfo &info) noexcept {
    // scratch buffer reused across blocks to avoid mutating a block while iterating it
    luisa::vector<Instruction *> cache;
    for (auto b : blocks) {
        // replace the terminator with an unreachable instruction if it's not already
        if (!is_block_terminated_by_unreachable(b)) {
            auto old_terminator = b->terminator();
            old_terminator->remove_self();
            // the replaced terminator is removed as well, so report it like any other
            info.removed_instructions.emplace(old_terminator);
            xir::Builder builder;
            builder.set_insertion_point(b);
            builder.unreachable_();
        }
        // collect all instructions in the unreachable block
        cache.clear();
        for (auto &&inst : b->instructions()) {
            cache.emplace_back(&inst);
        }
        // the last collected instruction is the terminator, which must stay in place
        cache.pop_back();
        // detach and record everything else
        for (auto inst : cache) {
            inst->remove_self();
            info.removed_instructions.emplace(inst);
        }
    }
}

// Propagates "unreachable" marks backwards through the control-flow graph and then
// strips the marked blocks.
//
// A block is marked when either
//   - its terminator is already `unreachable`, or
//   - it has at least one successor and every successor is the block itself, an already
//     marked block, or a block terminated by `unreachable`.
// The sweep over the post-order block list repeats until no new block gets marked.
// Does nothing when the function has no definition.
void propagate_unreachable_marks_in_function(Function *function, DCEInfo &info) noexcept {
    auto definition = function->definition();
    if (!definition) { return; }

    // snapshot the blocks in post order
    luisa::vector<BasicBlock *> ordered_blocks;
    definition->traverse_basic_blocks(BasicBlockTraversalOrder::POST_ORDER, [&](BasicBlock *bb) noexcept {
        ordered_blocks.emplace_back(bb);
    });

    // iterate to a fixed point: stop once a full sweep adds no new mark
    luisa::unordered_set<BasicBlock *> unreachable;
    auto marked_before = unreachable.size();
    do {
        marked_before = unreachable.size();
        for (auto bb : ordered_blocks) {
            if (unreachable.contains(bb)) { continue; }
            if (is_block_terminated_by_unreachable(bb)) {
                unreachable.emplace(bb);
                continue;
            }
            auto saw_successor = false;
            auto found_live_successor = false;
            bb->traverse_successors(false, [&](BasicBlock *next) noexcept {
                saw_successor = true;
                // self edges never keep a block alive
                const auto is_dead_end = next == bb ||
                                         unreachable.contains(next) ||
                                         is_block_terminated_by_unreachable(next);
                if (!is_dead_end) { found_live_successor = true; }
            });
            if (saw_successor && !found_live_successor) {
                unreachable.emplace(bb);
            }
        }
    } while (unreachable.size() != marked_before);

    // eliminate all instructions in unreachable blocks
    eliminate_instructions_in_unreachable_blocks(unreachable, info);
}

void eliminate_unreachable_blocks_in_function(Function *function, DCEInfo &info) noexcept {
if (auto definition = function->definition()) {
luisa::unordered_set<BasicBlock *> reachable;
definition->traverse_basic_blocks([&](BasicBlock *block) noexcept {
Expand All @@ -155,22 +222,8 @@ void eliminate_unreachable_code_in_function(Function *function, DCEInfo &info) n
}
});
}
luisa::vector<Instruction *> dead;
for (auto b : unreachable) {
dead.clear();
// collect and remove all instructions in the unreachable block
for (auto &&inst : b->instructions()) {
dead.emplace_back(&inst);
}
for (auto &&inst : dead) {
info.removed_instructions.emplace(inst);
inst->remove_self();
}
// replace with an unreachable instruction
xir::Builder builder;
builder.set_insertion_point(b);
builder.unreachable_();
}
// eliminate all instructions in unreachable blocks
eliminate_instructions_in_unreachable_blocks(unreachable, info);
}
}

Expand Down Expand Up @@ -213,9 +266,24 @@ void fix_phi_nodes_in_function(Function *function) noexcept {
}
}

// Clears the merge block of every control-flow-merge terminator whose merge block is
// terminated by `unreachable`. Terminators without a control-flow merge, and merges
// whose merge block is already null, are left untouched.
// Does nothing when the function has no definition.
void fix_control_flow_merges_in_function(Function *function) noexcept {
    auto definition = function->definition();
    if (!definition) { return; }
    definition->traverse_basic_blocks([&](BasicBlock *block) noexcept {
        auto merge = block->terminator()->control_flow_merge();
        if (!merge) { return; }
        auto target = merge->merge_block();
        if (target == nullptr) { return; }
        if (is_block_terminated_by_unreachable(target)) {
            merge->set_merge_block(nullptr);
        }
    });
}

void run_dce_pass_on_function(Function *function, DCEInfo &info) noexcept {
eliminate_unreachable_code_in_function(function, info);
propagate_unreachable_marks_in_function(function, info);
eliminate_unreachable_blocks_in_function(function, info);
fix_phi_nodes_in_function(function);
fix_control_flow_merges_in_function(function);
for (;;) {
auto prev_count = info.removed_instructions.size();
eliminate_dead_code_in_function(function, info);
Expand Down
28 changes: 27 additions & 1 deletion src/xir/passes/dom_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,38 @@ inline void DomTree::compute_dominance_frontiers() noexcept {
}
}

const DomTreeNode *DomTree::node(BasicBlock *block) const noexcept {
// Looks up the dominator-tree node registered for `block`.
// Asserts that the block is present in the tree.
auto DomTree::node(BasicBlock *block) const noexcept -> const DomTreeNode * {
    const auto iter = _nodes.find(block);
    LUISA_ASSERT(iter != _nodes.end(), "Block not found in the dom tree.");
    const auto &entry = iter->second;
    return entry.get();
}

// Reports whether `block` has a node in this dominator tree.
bool DomTree::contains(BasicBlock *block) const noexcept {
    const auto found = _nodes.find(block);
    return found != _nodes.end();
}

// Returns true when `src` dominates `dst`, i.e. when
//   - both are the same block, or
//   - the node of `src` is the root of the tree, or
//   - the node of `src` is met while following parent links from the node of `dst`
//     up to (but not including) the root.
// Unless `src == dst`, both blocks are looked up through node(), which asserts that
// they are present in the tree.
bool DomTree::dominates(BasicBlock *src, BasicBlock *dst) const noexcept {
    if (src == dst) { return true; }
    const auto candidate = node(src);
    auto cursor = node(dst);
    // the root dominates everything in the tree
    if (candidate == _root) { return true; }
    // climb from dst towards the root looking for src
    for (; cursor != _root; cursor = cursor->parent()) {
        if (cursor == candidate) { return true; }
    }
    return false;
}

// Returns true when `src` dominates `dst` and the two are different blocks.
bool DomTree::strictly_dominates(BasicBlock *src, BasicBlock *dst) const noexcept {
    if (src == dst) { return false; }
    return dominates(src, dst);
}

// Returns the block held by the parent of `block`'s node in the dominator tree,
// or nullptr when `block` maps to the root node.
// The lookup goes through node(), which asserts that `block` is in the tree.
auto DomTree::immediate_dominator(BasicBlock *block) const noexcept -> BasicBlock * {
    const auto tree_node = this->node(block);
    if (tree_node != _root) {
        return tree_node->parent()->block();
    }
    return nullptr;
}

// Reference: A Simple, Fast Dominance Algorithm [Cooper et al. 2001]
DomTree compute_dom_tree(Function *function) noexcept {
auto definition = function->definition();
Expand Down

0 comments on commit 936cb35

Please sign in to comment.