Skip to content

Commit

Permalink
Split continues with convergent calls in some cases (#1439)
Browse files Browse the repository at this point in the history
* Adds a new transform to FixupStructuredCFG to split a continue (latch)
  into two blocks under certain circumstances:
  * The loop header is a conditional branch to the body and latch
  * The latch has two predecessors
  * The latch contains a convergent call

* This transformation prevents clspv forces (along with the
  breakConditionalHeader transform in the same pass) to prevent
  convergent operations from being placed in the loop continue contruct.
  Instead they end up as a structured selection in the body. This
  ensures reconvergence more robustly than previously. SPIRV-Cross, for
  example, inlines continues into the body under the assumption that
  reconvergence is not expected
  • Loading branch information
alan-baker authored Jan 22, 2025
1 parent 7b4ea7e commit 85146eb
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 5 deletions.
96 changes: 94 additions & 2 deletions lib/FixupStructuredCFGPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ PreservedAnalyses
clspv::FixupStructuredCFGPass::run(Function &F, FunctionAnalysisManager &FAM) {
// Assumes CFG has been structurized.
isolateContinue(F, FAM);
// Run after isolateContinue since this can invalidate loop info.
isolateConvergentLatch(F, FAM);
breakConditionalHeader(F, FAM);

removeUndefPHI(F);
Expand Down Expand Up @@ -75,6 +75,97 @@ void clspv::FixupStructuredCFGPass::removeUndefPHI(Function &F) {
}
}

void clspv::FixupStructuredCFGPass::isolateConvergentLatch(
Function &F, FunctionAnalysisManager &FAM) {
auto &LI = FAM.getResult<LoopAnalysis>(F);

std::vector<BasicBlock *> blocks;
blocks.reserve(F.size());
for (auto &BB : F) {
blocks.push_back(&BB);
}

for (auto *BB : blocks) {
if (!LI.isLoopHeader(BB))
continue;

auto *loop = LI.getLoopFor(BB);
auto *latch = loop->getLoopLatch();
// Skip single block loops.
if (!latch || latch == BB) {
continue;
}

// Latch needs two predecessors.
if (!latch->hasNPredecessors(2)) {
continue;
}

// Header is a conditional branch.
auto header_terminator = dyn_cast_or_null<BranchInst>(BB->getTerminator());
if (!header_terminator || !header_terminator->isConditional()) {
continue;
}

// One edge jumps to the continue target.
if (header_terminator->getSuccessor(0) != latch &&
header_terminator->getSuccessor(1) != latch) {
continue;
}

// The continue contains a convergent call.
bool has_convergent_call = false;
for (auto &inst : *latch) {
if (auto *call = dyn_cast<CallInst>(&inst)) {
if (call->hasFnAttr(Attribute::Convergent)) {
has_convergent_call = true;
break;
}
}
}
if (!has_convergent_call) {
continue;
}

auto *latch_terminator =
dyn_cast_or_null<BranchInst>(latch->getTerminator());
if (!latch_terminator)
continue;

// Break the latch such that it is a single-entry single-exit block.
// This will force later transforms in this fixup to break the loop header
// which puts the whole loop body as a selection.
if (latch_terminator->isConditional()) {
// Safety valve: if this is not an exiting block then the loop is not
// structured as expected.
if (!loop->isLoopExiting(latch)) {
continue;
}

// Conditional branch case: one edge back to header and one out of the
// loop. Transformed into one edge out of the loop and one edge to the new
// continue and thence to the header.
auto new_latch =
BasicBlock::Create(F.getContext(), "", &F, latch->getNextNode());
BranchInst::Create(BB, new_latch);
loop->addBlockEntry(new_latch);

const auto idx = latch_terminator->getSuccessor(0) == BB ? 0 : 1;
latch_terminator->setSuccessor(idx, new_latch);

// Update phis to use the new basic block.
for (auto iter = BB->begin(); &*iter != BB->getFirstNonPHI(); ++iter) {
PHINode *phi = cast<PHINode>(&*iter);
phi->replaceIncomingBlockWith(latch, new_latch);
}
} else {
// Simple case: just split the block.
auto new_block = latch->splitBasicBlockBefore(latch_terminator);
loop->addBlockEntry(new_block);
}
}
}

void clspv::FixupStructuredCFGPass::breakConditionalHeader(
Function &F, FunctionAnalysisManager &FAM) {
auto &LI = FAM.getResult<LoopAnalysis>(F);
Expand Down Expand Up @@ -106,7 +197,8 @@ void clspv::FixupStructuredCFGPass::breakConditionalHeader(
bool succ2_in_body = succ2 != latch && succ2 != exit;

if (succ1_in_body && succ2_in_body) {
BB->splitBasicBlockBefore(terminator);
auto new_block = BB->splitBasicBlockBefore(terminator);
loop->addBlockEntry(new_block);
}
}
}
Expand Down
29 changes: 29 additions & 0 deletions lib/FixupStructuredCFGPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,35 @@ struct FixupStructuredCFGPass : llvm::PassInfoMixin<FixupStructuredCFGPass> {
void breakConditionalHeader(llvm::Function &F, llvm::FunctionAnalysisManager &FAM);
void isolateContinue(llvm::Function &F, llvm::FunctionAnalysisManager &FAM);

/**
* Transforms a loop such as:
*
* header --\
* / \ |
* body | |
* \ / ^
* latch |
* / \ |
* exit ---/
*
* Into:
* header --------\
* / \ |
* body | |
* \ / ^
* old_latch |
* / \ |
* exit new_latch -/
*
* When the latch contains a convergent call (e.g. a barrier). This will force
* breakConditionalHeader to transform the loop also and effectively
* encapsulates body within a selection now fully contained in the body of the
* loop. This effectively moves the convergent call out of the latch where
* SPIR-V does not guarantee reconvergence (without maximal reconvergence)
* into a fully structured section where reconvergence is guaranteed.
*/
void isolateConvergentLatch(llvm::Function &F,
llvm::FunctionAnalysisManager &FAM);
};
} // namespace clspv

Expand Down
42 changes: 42 additions & 0 deletions test/FixupStructuredCFG/split_convergent_continue_branch.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: clspv-opt --passes=fixup-structured-cfg %s -o %t.ll
; RUN: FileCheck %s < %t.ll

; CHECK: entry:
; CHECK-NEXT: br label %[[new_header:[a-zA-Z0-9_.]+]]
; CHECK: [[new_header]]:
; CHECK-NEXT: br label %loop
; CHECK: loop:
; CHECK-NEXT: br i1 undef, label %then, label %[[pre_cont:[a-zA-Z0-9_.]+]]
; CHECK: then:
; CHECK-NEXT: br i1 undef, label %[[pre_cont]], label %exit
; CHECK: [[pre_cont]]:
; CHECK: call void @_Z8spirv.op.224
; CHECK-NEXT: br label %[[cont:[a-zA-Z0-9_.]+]]
; CHECK: [[cont]]:
; CHECK-NEXT: br label %[[new_header]]

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

define spir_kernel void @test() {
entry:
br label %loop

loop:
br i1 undef, label %then, label %cont

then:
br i1 undef, label %cont, label %exit

cont:
tail call void @_Z8spirv.op.224.jjj(i32 224, i32 2, i32 2, i32 264) #0
br label %loop

exit:
ret void
}

attributes #0 = { convergent }

declare void @_Z8spirv.op.224.jjj(i32, i32, i32, i32) #0

44 changes: 44 additions & 0 deletions test/FixupStructuredCFG/split_convergent_continue_cond_branch.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
; RUN: clspv-opt --passes=fixup-structured-cfg %s -o %t.ll
; RUN: FileCheck %s < %t.ll

; CHECK: entry:
; CHECK-NEXT: br label %[[new_header:[a-zA-Z0-9_.]+]]
; CHECK: [[new_header]]:
; CHECK-NEXT: phi i32 [ 0, %entry ], [ 1, %[[cont:[a-zA-Z0-9_.]+]] ]
; CHECK-NEXT: br label %loop
; CHECK: loop:
; CHECK-NEXT: br i1 undef, label %then, label %[[pre_cont:[a-zA-Z0-9_.]+]]
; CHECK: then:
; CHECK-NEXT: br label %[[pre_cont]]
; CHECK: [[pre_cont]]:
; CHECK: call void @_Z8spirv.op.224
; CHECK-NEXT: br i1 undef, label %[[cont]], label %exit
; CHECK: [[cont]]:
; CHECK-NEXT: br label %[[new_header]]

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

define spir_kernel void @test() {
entry:
br label %loop

loop:
%0 = phi i32 [ 0, %entry ], [ 1, %cont ]
br i1 undef, label %then, label %cont

then:
br label %cont

cont:
tail call void @_Z8spirv.op.224.jjj(i32 224, i32 2, i32 2, i32 264) #0
br i1 undef, label %loop, label %exit

exit:
ret void
}

attributes #0 = { convergent }

declare void @_Z8spirv.op.224.jjj(i32, i32, i32, i32) #0

4 changes: 1 addition & 3 deletions test/loop_continue_no_selection_merge.cl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// CHECK: OpBranch [[CONT]]
// CHECK: [[CONT]] = OpLabel
// CHECK-NOT: OpLabel
// CHECK: OpControlBarrier
// CHECK-NOT: OpLabel
// CHECK: OpBranchConditional {{.*}} [[MERGE]] [[LOOP]]

__kernel void
Expand All @@ -27,6 +25,7 @@ top_scan(__global uint * isums,
int last_thread = (get_local_id(0) < n &&
(get_local_id(0)+1) == n) ? 1 : 0;

#pragma unroll 0
for (int d = 0; d < 16; d++)
{
int idx = get_local_id(0);
Expand All @@ -35,7 +34,6 @@ top_scan(__global uint * isums,
{
s_seed += 42;
}
barrier(CLK_LOCAL_MEM_FENCE);
}
}

Expand Down

0 comments on commit 85146eb

Please sign in to comment.