Skip to content

Commit 03cbd6f

Browse files
small refactoring for mlir backend
1 parent c5f2bbf commit 03cbd6f

File tree

7 files changed

+14
-200
lines changed

7 files changed

+14
-200
lines changed

nautilus/src/nautilus/compiler/backends/mlir/JITCompiler.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ std::unique_ptr<::mlir::ExecutionEngine> JITCompiler::jitCompileModule(::mlir::O
2727
LLVMInitializeNativeTarget();
2828
LLVMInitializeNativeAsmPrinter();
2929

30-
//(void) dumpHelper;
31-
// if (compilerOptions.isDumpToConsole() || compilerOptions.isDumpToFile()) {
32-
// dumpLLVMIR(mlirModule.get(), compilerOptions, dumpHelper);
33-
// }
34-
3530
// Create MLIR execution engine (wrapper around LLVM ExecutionEngine).
3631
::mlir::ExecutionEngineOptions options;
3732
options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
@@ -43,7 +38,6 @@ std::unique_ptr<::mlir::ExecutionEngine> JITCompiler::jitCompileModule(::mlir::O
4338
// We register all external functions (symbols) that we do not inline.
4439
const auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
4540
auto symbolMap = llvm::orc::SymbolMap();
46-
4741
for (int i = 0; i < (int) jitProxyFunctionSymbols.size(); ++i) {
4842
auto address = jitProxyFunctionTargetAddresses.at(i);
4943
symbolMap[interner(jitProxyFunctionSymbols.at(i))] = {llvm::orc::ExecutorAddr::fromPtr(address), llvm::JITSymbolFlags::Exported};

nautilus/src/nautilus/compiler/backends/mlir/LLVMIROptimizer.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ int getOptimizationLevel(const engine::Options& options) {
1515
return options.getOptionOrDefault("mlir.optimizationLevel", 3);
1616
}
1717

18+
LLVMIROptimizer::LLVMIROptimizer() = default;
19+
LLVMIROptimizer::~LLVMIROptimizer() = default;
20+
1821
std::function<llvm::Error(llvm::Module*)> LLVMIROptimizer::getLLVMOptimizerPipeline(const engine::Options& options, const DumpHandler& handler) {
1922
// Return LLVM optimizer pipeline.
2023
return [options, handler](llvm::Module* llvmIRModule) {
@@ -23,8 +26,6 @@ std::function<llvm::Error(llvm::Module*)> LLVMIROptimizer::getLLVMOptimizerPipel
2326
constexpr int SIZE_LEVEL = 0;
2427
// Create A target-specific target machine for the host
2528
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
26-
// NES_ASSERT2_FMT(tmBuilderOrError, "Failed to create a
27-
// JITTargetMachineBuilder for the host");
2829
auto targetMachine = tmBuilderOrError->createTargetMachine();
2930
llvm::TargetMachine* targetMachinePtr = targetMachine->get();
3031
targetMachinePtr->setOptLevel(llvm::CodeGenOptLevel::Aggressive);
@@ -35,17 +36,6 @@ std::function<llvm::Error(llvm::Module*)> LLVMIROptimizer::getLLVMOptimizerPipel
3536
llvmIRModule->getFunction("execute")->addAttributeAtIndex(~0, llvm::Attribute::get(llvmIRModule->getContext(), "tune-cpu", targetMachinePtr->getTargetCPU()));
3637
llvm::SMDiagnostic Err;
3738

38-
// Load LLVM IR module from proxy inlining input path (We assert that it
39-
// exists in CompilationOptions). if (options.isProxyInlining()) {
40-
// auto proxyFunctionsIR =
41-
// llvm::parseIRFile(options.getProxyInliningInputPath(), Err,
42-
// llvmIRModule->getContext());
43-
// Link the module with our generated LLVM IR module and optimize the linked
44-
// LLVM IR module (inlining happens during optimization).
45-
// llvm::Linker::linkModules(*llvmIRModule, std::move(proxyFunctionsIR),
46-
// llvm::Linker::Flags::OverrideFromSrc);
47-
// }
48-
4939
auto optPipeline = ::mlir::makeOptimizingTransformer(getOptimizationLevel(options), SIZE_LEVEL, targetMachinePtr);
5040
auto optimizedModule = optPipeline(llvmIRModule);
5141

nautilus/src/nautilus/compiler/backends/mlir/MLIRLoweringProvider.cpp

Lines changed: 9 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ ::mlir::Type MLIRLoweringProvider::getMLIRType(Type type) {
5151
case Type::ptr:
5252
return mlir::LLVM::LLVMPointerType::get(context);
5353
}
54-
5554
throw NotImplementedException("No matching type for stamp ");
5655
}
5756

58-
std::vector<mlir::Type> MLIRLoweringProvider::getMLIRType(std::vector<ir::Operation*> types) {
57+
std::vector<mlir::Type> MLIRLoweringProvider::getMLIRType(const std::vector<ir::Operation*>& types) {
5958
std::vector<mlir::Type> resultTypes;
6059
for (auto& type : types) {
6160
resultTypes.push_back(getMLIRType(type->getStamp()));
@@ -72,7 +71,6 @@ mlir::Value MLIRLoweringProvider::getConstBool(const std::string& location, bool
7271
return builder->create<mlir::LLVM::ConstantOp>(getNameLoc(location), builder->getI1Type(), builder->getIntegerAttr(builder->getIndexType(), value));
7372
}
7473

75-
// Todo Issue #3004: Currently, we are simply adding 'Query_1' as the
7674
// FileLineLoc name. Moreover,
7775
// the provided 'name' often is not meaningful either.
7876
mlir::Location MLIRLoweringProvider::getNameLoc(const std::string& name) {
@@ -182,7 +180,7 @@ mlir::arith::CmpIPredicate convertToBooleanMLIRComparison(ir::CompareOperation::
182180
}
183181
}
184182

185-
mlir::FlatSymbolRefAttr MLIRLoweringProvider::insertExternalFunction(const std::string& name, void* functionPtr, mlir::Type resultType, std::vector<mlir::Type> argTypes, bool varArgs) {
183+
mlir::FlatSymbolRefAttr MLIRLoweringProvider::insertExternalFunction(const std::string& name, void* functionPtr, const mlir::Type& resultType, const std::vector<mlir::Type>& argTypes, bool varArgs) {
186184
// Create function arg & result types (currently only int for result).
187185
mlir::LLVM::LLVMFunctionType llvmFnType = mlir::LLVM::LLVMFunctionType::get(resultType, argTypes, varArgs);
188186

@@ -241,7 +239,6 @@ void MLIRLoweringProvider::generateMLIR(const ir::BasicBlock* basicBlock, ValueF
241239
void MLIRLoweringProvider::generateMLIR(const std::unique_ptr<ir::Operation>& operation, ValueFrame& frame) {
242240
switch (operation->getOperationType()) {
243241
case ir::Operation::OperationType::FunctionOp:
244-
// generateMLIR(as<ir::FunctionOperation>(operation), frame);
245242
break;
246243
case ir::Operation::OperationType::ConstIntOp:
247244
generateMLIR(as<ir::ConstIntOperation>(operation), frame);
@@ -336,24 +333,14 @@ void MLIRLoweringProvider::generateMLIR(ir::OrOperation* orOperation, ValueFrame
336333
auto leftInput = frame.getValue(orOperation->getLeftInput()->getIdentifier());
337334
auto rightInput = frame.getValue(orOperation->getRightInput()->getIdentifier());
338335
auto mlirOrOp = builder->create<mlir::LLVM::OrOp>(getNameLoc("binOpResult"), leftInput, rightInput);
339-
frame.setValue(orOperation->
340-
341-
getIdentifier(),
342-
mlirOrOp
343-
344-
);
336+
frame.setValue(orOperation->getIdentifier(), mlirOrOp);
345337
}
346338

347339
void MLIRLoweringProvider::generateMLIR(ir::AndOperation* andOperation, ValueFrame& frame) {
348340
auto leftInput = frame.getValue(andOperation->getLeftInput()->getIdentifier());
349341
auto rightInput = frame.getValue(andOperation->getRightInput()->getIdentifier());
350342
auto mlirAndOp = builder->create<mlir::LLVM::AndOp>(getNameLoc("binOpResult"), leftInput, rightInput);
351-
frame.setValue(andOperation->
352-
353-
getIdentifier(),
354-
mlirAndOp
355-
356-
);
343+
frame.setValue(andOperation->getIdentifier(), mlirAndOp);
357344
}
358345

359346
void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp, ValueFrame& frame) {
@@ -363,7 +350,6 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
363350
inputTypes.emplace_back(getMLIRType(inputArg->getStamp()));
364351
}
365352
llvm::SmallVector<mlir::Type> outputTypes(1, getMLIRType(functionOp.getOutputArg()));
366-
;
367353
auto functionInOutTypes = builder->getFunctionType(inputTypes, outputTypes);
368354
auto loc = getNameLoc("EntryPoint");
369355
auto mlirFunction = builder->create<mlir::func::FuncOp>(loc, functionOp.getName(), functionInOutTypes);
@@ -375,30 +361,16 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
375361
} else if (isSignedInteger(functionOp.getStamp())) {
376362
mlirFunction.setResultAttr(0, "llvm.signext", mlir::UnitAttr::get(context));
377363
}
378-
// mlirFunction.setArgAttr(0, "llvm.signext", mlir::UnitAttr::get(context));
379364

380-
mlirFunction.
381-
382-
addEntryBlock();
365+
mlirFunction.addEntryBlock();
383366

384367
// Set InsertPoint to beginning of the execute function.
385-
builder->setInsertionPointToStart(&mlirFunction
386-
.
387-
388-
getBody()
389-
390-
.
391-
392-
front()
393-
394-
);
368+
builder->setInsertionPointToStart(&mlirFunction.getBody().front());
395369

396370
// Store references to function args in the valueMap map.
397371
auto valueMapIterator = mlirFunction.args_begin();
398372
for (int i = 0; i < (int) functionOp.getFunctionBasicBlock().getArguments().size(); ++i) {
399-
frame.setValue(functionOp.getFunctionBasicBlock().getArguments().at(i)->getIdentifier(), valueMapIterator[i]
400-
401-
);
373+
frame.setValue(functionOp.getFunctionBasicBlock().getArguments().at(i)->getIdentifier(), valueMapIterator[i]);
402374
}
403375

404376
// Generate MLIR for operations in function body (BasicBlock).
@@ -408,27 +380,17 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
408380
}
409381

410382
void MLIRLoweringProvider::generateMLIR(ir::LoadOperation* loadOp, ValueFrame& frame) {
411-
412383
auto address = frame.getValue(loadOp->getAddress()->getIdentifier());
413-
414-
// auto bitcast = builder->create<mlir::LLVM::BitcastOp>(getNameLoc("Bitcasted
415-
// address"),
416-
// mlir::LLVM::LLVMPointerType::get(context),
417-
// address);
418384
auto mlirLoadOp = builder->create<mlir::LLVM::LoadOp>(getNameLoc("loadedValue"), getMLIRType(loadOp->getStamp()), address);
419385
frame.setValue(loadOp->getIdentifier(), mlirLoadOp);
420386
}
421387

422388
void MLIRLoweringProvider::generateMLIR(ir::ConstIntOperation* constIntOp, ValueFrame& frame) {
423-
if (!frame.contains(constIntOp->getIdentifier())) {
424-
frame.setValue(constIntOp->getIdentifier(), getConstInt("ConstantOp", constIntOp->getStamp(), constIntOp->getValue()));
425-
} else {
426-
frame.setValue(constIntOp->getIdentifier(), getConstInt("ConstantOp", constIntOp->getStamp(), constIntOp->getValue()));
427-
}
389+
frame.setValue(constIntOp->getIdentifier(), getConstInt("ConstantOp", constIntOp->getStamp(), constIntOp->getValue()));
428390
}
429391

430392
void MLIRLoweringProvider::generateMLIR(ir::ConstPtrOperation* constPtr, ValueFrame& frame) {
431-
int64_t val = (int64_t) constPtr->getValue();
393+
auto val = (int64_t) constPtr->getValue();
432394
auto constInt = builder->create<mlir::arith::ConstantOp>(getNameLoc("location"), builder->getI64Type(), builder->getIntegerAttr(builder->getI64Type(), val));
433395
auto elementAddress = builder->create<mlir::LLVM::IntToPtrOp>(getNameLoc("fieldAccess"), mlir::LLVM::LLVMPointerType::get(context), constInt);
434396
frame.setValue(constPtr->getIdentifier(), elementAddress);
@@ -451,7 +413,6 @@ void MLIRLoweringProvider::generateMLIR(ir::AddOperation* addOp, ValueFrame& fra
451413
// if we add something to a ptr we have to use a llvm getelementptr
452414
mlir::Value elementAddress = builder->create<mlir::LLVM::GEPOp>(getNameLoc("fieldAccess"), mlir::LLVM::LLVMPointerType::get(context), builder->getI8Type(), leftInput, mlir::ArrayRef<mlir::Value>({rightInput}));
453415
frame.setValue(addOp->getIdentifier(), elementAddress);
454-
455416
} else if (isFloat(addOp->getStamp())) {
456417
auto mlirAddOp = builder->create<mlir::LLVM::FAddOp>(getNameLoc("binOpResult"), leftInput.getType(), leftInput, rightInput, mlir::LLVM::FastmathFlags::fast);
457418
frame.setValue(addOp->getIdentifier(), mlirAddOp);
@@ -475,7 +436,6 @@ void MLIRLoweringProvider::generateMLIR(ir::SubOperation* subIntOp, ValueFrame&
475436
// if we add something to a ptr we have to use a llvm getelementptr
476437
mlir::Value elementAddress = builder->create<mlir::LLVM::GEPOp>(getNameLoc("fieldAccess"), mlir::LLVM::LLVMPointerType::get(context), builder->getI8Type(), leftInput, mlir::ArrayRef<mlir::Value>({rightInput}));
477438
frame.setValue(subIntOp->getIdentifier(), elementAddress);
478-
479439
} else if (isFloat(subIntOp->getStamp())) {
480440
auto mlirSubOp = builder->create<mlir::LLVM::FSubOp>(getNameLoc("binOpResult"), leftInput, rightInput, mlir::LLVM::FastmathFlagsAttr::get(context, mlir::LLVM::FastmathFlags::fast));
481441
frame.setValue(subIntOp->getIdentifier(), mlirSubOp);
@@ -576,18 +536,6 @@ void MLIRLoweringProvider::generateMLIR(ir::CompareOperation* compareOp, ValueFr
576536
if ((isInteger(leftStamp) && isFloat(rightStamp)) || ((isInteger(rightStamp) && isFloat(leftStamp)))) {
577537
// Avoid comparing integer to float
578538
throw NotImplementedException("Type missmatch: cannot compare");
579-
} else if (compareOp->getComparator() == ir::CompareOperation::EQ && compareOp->getLeftInput()->getStamp() == Type::ptr && isInteger(compareOp->getRightInput()->getStamp())) {
580-
// add null check
581-
throw NotImplementedException("Null check is not implemented");
582-
// auto null =
583-
// builder->create<mlir::LLVM::NullOp>(getNameLoc("null"),
584-
// mlir::LLVM::LLVMPointerType::get(context));
585-
// auto cmpOp =
586-
// builder->create<mlir::LLVM::ICmpOp>(getNameLoc("comparison"),
587-
// mlir::LLVM::ICmpPredicate::eq,
588-
// frame.getValue(compareOp->getLeftInput()->getIdentifier()),
589-
// null);
590-
// frame.setValue(compareOp->getIdentifier(), cmpOp);
591539
} else if (isInteger(leftStamp) && isInteger(rightStamp)) {
592540
// handle integer
593541
auto cmpOp = builder->create<mlir::arith::CmpIOp>(getNameLoc("comparison"), convertToIntMLIRComparison(compareOp->getComparator(), leftStamp), frame.getValue(compareOp->getLeftInput()->getIdentifier()),
@@ -768,7 +716,6 @@ void MLIRLoweringProvider::generateMLIR(ir::CastOperation* castOperation, MLIRLo
768716
void MLIRLoweringProvider::generateMLIR(ir::BinaryCompOperation* binaryCompOperation, nautilus::compiler::mlir::MLIRLoweringProvider::ValueFrame& frame) {
769717
auto leftInput = frame.getValue(binaryCompOperation->getLeftInput()->getIdentifier());
770718
auto rightInput = frame.getValue(binaryCompOperation->getRightInput()->getIdentifier());
771-
772719
mlir::Value op;
773720
switch (binaryCompOperation->getType()) {
774721
case ir::BinaryCompOperation::BAND:
@@ -787,7 +734,6 @@ void MLIRLoweringProvider::generateMLIR(ir::BinaryCompOperation* binaryCompOpera
787734
void MLIRLoweringProvider::generateMLIR(ir::ShiftOperation* shiftOperation, nautilus::compiler::mlir::MLIRLoweringProvider::ValueFrame& frame) {
788735
auto leftInput = frame.getValue(shiftOperation->getLeftInput()->getIdentifier());
789736
auto rightInput = frame.getValue(shiftOperation->getRightInput()->getIdentifier());
790-
791737
mlir::Value op;
792738
switch (shiftOperation->getType()) {
793739
case ir::ShiftOperation::LS:

nautilus/src/nautilus/compiler/backends/mlir/MLIRLoweringProvider.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class MLIRLoweringProvider {
147147
* @param varArgs: Include variable arguments.
148148
* @return FlatSymbolRefAttr: Reference to function used in CallOps.
149149
*/
150-
::mlir::FlatSymbolRefAttr insertExternalFunction(const std::string& name, void* functionPtr, ::mlir::Type resultType, std::vector<::mlir::Type> argTypes, bool varArgs);
150+
::mlir::FlatSymbolRefAttr insertExternalFunction(const std::string& name, void* functionPtr, const ::mlir::Type& resultType, const std::vector<::mlir::Type>& argTypes, bool varArgs);
151151

152152
/**
153153
* @brief Generates a Name(d)Loc(ation) that is attached to the operation.
@@ -167,7 +167,7 @@ class MLIRLoweringProvider {
167167
* @param types: Vector of basic types.
168168
* @return mlir::Type: Vector of MLIR types.
169169
*/
170-
std::vector<::mlir::Type> getMLIRType(std::vector<ir::Operation*> types);
170+
std::vector<::mlir::Type> getMLIRType(const std::vector<ir::Operation*>& types);
171171

172172
/**
173173
* @brief Get a constant MLIR Integer.

nautilus/src/nautilus/compiler/backends/mlir/MLIRPassManager.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ namespace nautilus::compiler::mlir {
1010
class MLIRPassManager {
1111
public:
1212
enum class OptimizationPass : uint8_t { Inline };
13-
14-
MLIRPassManager(); // Disable default constructor
15-
~MLIRPassManager(); // Disable default destructor
16-
1713
static int lowerAndOptimizeMLIRModule(::mlir::OwningOpRef<::mlir::ModuleOp>& module, const std::vector<OptimizationPass>& optimizationPasses);
1814
};
1915
} // namespace nautilus::compiler::mlir

nautilus/src/nautilus/compiler/backends/mlir/MLIRUtility.cpp

Lines changed: 0 additions & 68 deletions
This file was deleted.

nautilus/src/nautilus/compiler/backends/mlir/MLIRUtility.hpp

Lines changed: 0 additions & 44 deletions
This file was deleted.

0 commit comments

Comments
 (0)