Skip to content

Commit

Permalink
Integrate polymer's pluto-opt into polygeist pipeline (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanradanov authored Jun 19, 2024
1 parent 5bb6194 commit 62f8635
Show file tree
Hide file tree
Showing 46 changed files with 847 additions and 245 deletions.
5 changes: 5 additions & 0 deletions include/polygeist/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
#include "polygeist/Dialect.h"
#include <memory>

#ifdef POLYGEIST_ENABLE_POLYMER
#include "polymer/Transforms/PlutoTransform.h"
#endif

enum PolygeistAlternativesMode { PAM_Static, PAM_PGO_Profile, PAM_PGO_Opt };
enum PolygeistGPUStructureMode {
PGSM_Discard,
Expand All @@ -35,6 +39,7 @@ std::unique_ptr<Pass> createCPUifyPass(StringRef method = "");
std::unique_ptr<Pass> createBarrierRemovalContinuation();
std::unique_ptr<Pass> detectReductionPass();
std::unique_ptr<Pass> createRemoveTrivialUsePass();
std::unique_ptr<Pass> createPolyhedralOptPass();
std::unique_ptr<Pass> createParallelLowerPass(
bool wrapParallelOps = false,
PolygeistGPUStructureMode gpuKernelStructureMode = PGSM_Discard);
Expand Down
10 changes: 10 additions & 0 deletions include/polygeist/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
include "mlir/Pass/PassBase.td"
include "mlir/Rewrite/PassUtil.td"

def PolyhedralOpt : Pass<"polyhedral-opt"> {
let summary = "Optimize affine regions with pluto";
let constructor = "mlir::polygeist::createPolyhedralOptPass()";
let dependentDialects = [
"scf::SCFDialect",
"arith::ArithDialect",
"memref::MemRefDialect",
];
}

def AffineCFG : Pass<"affine-cfg"> {
let summary = "Replace scf.if and similar with affine.if";
let constructor = "mlir::polygeist::replaceAffineCFGPass()";
Expand Down
4 changes: 2 additions & 2 deletions lib/polygeist/Passes/AffineCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,9 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
if ((t = fix(t, false))) {
assert(isValidSymbolInt(t, /*recur*/ false));
} else
assert(0 && "cannot move");
llvm_unreachable("cannot move");
} else
assert(0 && "cannot move2");
llvm_unreachable("cannot move2");
}
if (i < numDims) {
// b. The mathematical composition of AffineMap composes dims.
Expand Down
2 changes: 1 addition & 1 deletion lib/polygeist/Passes/AffineReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ struct AffineForReductionIter : public OpRewritePattern<affine::AffineForOp> {
load->getResult(0).replaceAllUsesWith(store.getOperand(0));
} else {

assert(0 && "illegal behavior");
llvm_unreachable("illegal behavior");
}
}

Expand Down
113 changes: 113 additions & 0 deletions lib/polygeist/Passes/AlwaysInliner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#ifndef _POLYGEIST_PASSES_ALWAYSINLINER_H_
#define _POLYGEIST_PASSES_ALWAYSINLINER_H_

#include "PassDetails.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "polygeist/Ops.h"
#include "polygeist/Passes/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringRef.h"

struct AlwaysInlinerInterface : public mlir::InlinerInterface {
using InlinerInterface::InlinerInterface;

//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//

/// All call operations within standard ops can be inlined.
bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable,
bool wouldBeCloned) const final {
return true;
}

/// All operations within standard ops can be inlined.
bool isLegalToInline(mlir::Region *, mlir::Region *, bool,
mlir::IRMapping &) const final {
return true;
}

/// All operations within standard ops can be inlined.
bool isLegalToInline(mlir::Operation *, mlir::Region *, bool,
mlir::IRMapping &) const final {
return true;
}

//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//

/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(mlir::Operation *op, mlir::Block *newDest) const final {
// Only "std.return" needs to be handled here.
auto returnOp = mlir::dyn_cast<mlir::func::ReturnOp>(op);
if (!returnOp)
return;

// Replace the return with a branch to the dest.
mlir::OpBuilder builder(op);
builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
returnOp.getOperands());
op->erase();
}

/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(mlir::Operation *op,
mlir::ArrayRef<mlir::Value> valuesToRepl) const final {
// Only "std.return" needs to be handled here.
auto returnOp = mlir::cast<mlir::func::ReturnOp>(op);

// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
};

[[maybe_unused]] static void alwaysInlineCall(mlir::func::CallOp caller) {
// Build the inliner interface.
AlwaysInlinerInterface interface(caller.getContext());

auto callable = caller.getCallableForCallee();
mlir::CallableOpInterface callableOp;
if (mlir::SymbolRefAttr symRef =
mlir::dyn_cast<mlir::SymbolRefAttr>(callable)) {
auto *symbolOp =
caller->getParentOfType<mlir::ModuleOp>().lookupSymbol(symRef);
callableOp = mlir::dyn_cast_or_null<mlir::CallableOpInterface>(symbolOp);
} else {
return;
}
mlir::Region *targetRegion = callableOp.getCallableRegion();
if (!targetRegion)
return;
if (targetRegion->empty())
return;
if (inlineCall(interface, caller, callableOp, targetRegion,
/*shouldCloneInlinedRegion=*/true)
.succeeded()) {
caller.erase();
}
};

#endif // _POLYGEIST_PASSES_ALWAYSINLINER_H_
13 changes: 13 additions & 0 deletions lib/polygeist/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRPolygeistTransforms
ConvertToOpaquePtr.cpp
AffineCFG.cpp
PolyhedralOpt.cpp
AffineReduction.cpp
CanonicalizeFor.cpp
LoopRestructure.cpp
Expand Down Expand Up @@ -132,3 +133,15 @@ if(POLYGEIST_ENABLE_ROCM)
)

endif()
if(POLYGEIST_ENABLE_POLYMER)
target_include_directories(obj.MLIRPolygeistTransforms PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/../../../tools/polymer/include"
)
target_link_libraries(obj.MLIRPolygeistTransforms PRIVATE
PolymerTransforms
)
target_compile_definitions(obj.MLIRPolygeistTransforms
PRIVATE
POLYGEIST_ENABLE_POLYMER=1
)
endif()
23 changes: 12 additions & 11 deletions lib/polygeist/Passes/ConvertParallelToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ struct ParallelizeBlockOps : public OpRewritePattern<scf::ParallelOp> {
Operation &op = *it;
Operation *newOp;
if (isa<scf::ParallelOp>(&op)) {
assert(0 && "Unhandled case");
llvm_unreachable("Unhandled case");
break;
} else if (isa<scf::YieldOp>(&op)) {
continue;
Expand All @@ -908,9 +908,9 @@ struct ParallelizeBlockOps : public OpRewritePattern<scf::ParallelOp> {
collectEffects(&op, effects, /*ignoreBarriers*/ false);
if (effects.empty()) {
} else if (hasEffect<MemoryEffects::Allocate>(effects)) {
assert(0 && "??");
llvm_unreachable("??");
} else if (hasEffect<MemoryEffects::Free>(effects)) {
assert(0 && "??");
llvm_unreachable("??");
} else if (hasEffect<MemoryEffects::Write>(effects)) {
getIf();
assert(ifOp);
Expand Down Expand Up @@ -947,15 +947,15 @@ struct ParallelizeBlockOps : public OpRewritePattern<scf::ParallelOp> {
for (; it != outerBlock->end(); ++it) {
Operation &op = *it;
if (isa<scf::ParallelOp>(&op)) {
assert(0 && "Unhandled case");
llvm_unreachable("Unhandled case");
break;
} else if (isa<scf::YieldOp>(&op)) {
continue;
} else if (auto alloca = dyn_cast<memref::AllocaOp>(&op)) {
assert(0 && "Unhandled case");
llvm_unreachable("Unhandled case");
break;
} else if (auto alloca = dyn_cast<LLVM::AllocaOp>(&op)) {
assert(0 && "Unhandled case");
llvm_unreachable("Unhandled case");
break;
} else {
rewriter.clone(op, mapping);
Expand Down Expand Up @@ -1126,9 +1126,9 @@ struct HandleWrapperRootOps : public OpRewritePattern<polygeist::GPUWrapperOp> {
} else if (hasEffect<MemoryEffects::Allocate>(effects)) {
// I think this can actually happen if we lower a kernel with a barrier
// and shared memory with gridDim = 1 TODO handle
assert(0 && "what?");
llvm_unreachable("what?");
} else if (hasEffect<MemoryEffects::Free>(effects)) {
assert(0 && "what?");
llvm_unreachable("what?");
} else if (write) {
rewriter.setInsertionPoint(newWrapper.getBody()->getTerminator());
cloned = rewriter.clone(*op, splitMapping)->getResults();
Expand Down Expand Up @@ -1210,7 +1210,7 @@ struct HandleWrapperRootOps : public OpRewritePattern<polygeist::GPUWrapperOp> {
}
}
} else {
assert(0 && "are there other effects?");
llvm_unreachable("are there other effects?");
}
rewriter.replaceOpWithIf(op, cloned, [&](OpOperand &use) {
Operation *owner = use.getOwner();
Expand Down Expand Up @@ -1339,7 +1339,8 @@ struct RemovePolygeistNoopOp : public OpRewritePattern<polygeist::NoopOp> {
"must belong to the same parallel op");
threadIndices.push_back(blockArg.getArgNumber());
} else {
assert(0 && "noop block arg operands must be scf parallel op args");
llvm_unreachable(
"noop block arg operands must be scf parallel op args");
}
} else {
auto cst = getConstantInteger(operand);
Expand Down Expand Up @@ -1583,7 +1584,7 @@ struct ParallelToGPULaunch : public OpRewritePattern<polygeist::GPUWrapperOp> {
return gpu::Dimension::y;
if (index == 2)
return gpu::Dimension::z;
assert(0 && "Invalid index");
llvm_unreachable("Invalid index");
return gpu::Dimension::z;
};

Expand Down
2 changes: 1 addition & 1 deletion lib/polygeist/Passes/LowerAlternatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ struct LowerAlternativesPass
funcName = funcOp.getName();
funcName += ".func";
} else {
assert(0 && "How?");
llvm_unreachable("How?");
}
if (num.count(funcName) == 0)
num[funcName] = 0;
Expand Down
4 changes: 2 additions & 2 deletions lib/polygeist/Passes/ParallelLoopDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,12 +406,12 @@ static bool hasNestedBarrier(Operation *op, SmallVector<BlockArgument> &vals) {
if (parallel->isAncestor(op))
vals.push_back(ba);
} else {
assert(0 && "unknown barrier arg\n");
llvm_unreachable("unknown barrier arg\n");
}
} else if (arg.getDefiningOp<ConstantIndexOp>())
continue;
else {
assert(0 && "unknown barrier arg\n");
llvm_unreachable("unknown barrier arg\n");
}
}
});
Expand Down
4 changes: 2 additions & 2 deletions lib/polygeist/Passes/ParallelLoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ LogicalResult mlir::polygeist::scfParallelUnrollByFactor(
int64_t ubCst = ubCstOp.value();
int64_t stepCst = stepCstOp.value();
if (!(lbCst == 0 && ubCst >= 0 && stepCst == 1)) {
assert(0 && "expected positive loop bounds and step");
llvm_unreachable("expected positive loop bounds and step");
return failure();
}
int64_t upperBoundRem = mlir::mod(ubCst, unrollFactor);
Expand All @@ -355,7 +355,7 @@ LogicalResult mlir::polygeist::scfParallelUnrollByFactor(
int64_t lbCst = lbCstOp.value();
int64_t stepCst = stepCstOp.value();
if (!(lbCst == 0 && stepCst == 1)) {
assert(0 && "expected positive loop bounds and step");
llvm_unreachable("expected positive loop bounds and step");
return failure();
}
// auto lowerBound = pop.getLowerBound()[dim];
Expand Down
Loading

0 comments on commit 62f8635

Please sign in to comment.