@@ -51,11 +51,10 @@ ::mlir::Type MLIRLoweringProvider::getMLIRType(Type type) {
51
51
case Type::ptr:
52
52
return mlir::LLVM::LLVMPointerType::get (context);
53
53
}
54
-
55
54
throw NotImplementedException (" No matching type for stamp " );
56
55
}
57
56
58
- std::vector<mlir::Type> MLIRLoweringProvider::getMLIRType (std::vector<ir::Operation*> types) {
57
+ std::vector<mlir::Type> MLIRLoweringProvider::getMLIRType (const std::vector<ir::Operation*>& types) {
59
58
std::vector<mlir::Type> resultTypes;
60
59
for (auto & type : types) {
61
60
resultTypes.push_back (getMLIRType (type->getStamp ()));
@@ -72,7 +71,6 @@ mlir::Value MLIRLoweringProvider::getConstBool(const std::string& location, bool
72
71
return builder->create <mlir::LLVM::ConstantOp>(getNameLoc (location), builder->getI1Type (), builder->getIntegerAttr (builder->getIndexType (), value));
73
72
}
74
73
75
- // Todo Issue #3004: Currently, we are simply adding 'Query_1' as the
76
74
// FileLineLoc name. Moreover,
77
75
// the provided 'name' often is not meaningful either.
78
76
mlir::Location MLIRLoweringProvider::getNameLoc (const std::string& name) {
@@ -182,7 +180,7 @@ mlir::arith::CmpIPredicate convertToBooleanMLIRComparison(ir::CompareOperation::
182
180
}
183
181
}
184
182
185
- mlir::FlatSymbolRefAttr MLIRLoweringProvider::insertExternalFunction (const std::string& name, void * functionPtr, mlir::Type resultType, std::vector<mlir::Type> argTypes, bool varArgs) {
183
+ mlir::FlatSymbolRefAttr MLIRLoweringProvider::insertExternalFunction (const std::string& name, void * functionPtr, const mlir::Type& resultType, const std::vector<mlir::Type>& argTypes, bool varArgs) {
186
184
// Create function arg & result types (currently only int for result).
187
185
mlir::LLVM::LLVMFunctionType llvmFnType = mlir::LLVM::LLVMFunctionType::get (resultType, argTypes, varArgs);
188
186
@@ -241,7 +239,6 @@ void MLIRLoweringProvider::generateMLIR(const ir::BasicBlock* basicBlock, ValueF
241
239
void MLIRLoweringProvider::generateMLIR (const std::unique_ptr<ir::Operation>& operation, ValueFrame& frame) {
242
240
switch (operation->getOperationType ()) {
243
241
case ir::Operation::OperationType::FunctionOp:
244
- // generateMLIR(as<ir::FunctionOperation>(operation), frame);
245
242
break ;
246
243
case ir::Operation::OperationType::ConstIntOp:
247
244
generateMLIR (as<ir::ConstIntOperation>(operation), frame);
@@ -336,24 +333,14 @@ void MLIRLoweringProvider::generateMLIR(ir::OrOperation* orOperation, ValueFrame
336
333
auto leftInput = frame.getValue (orOperation->getLeftInput ()->getIdentifier ());
337
334
auto rightInput = frame.getValue (orOperation->getRightInput ()->getIdentifier ());
338
335
auto mlirOrOp = builder->create <mlir::LLVM::OrOp>(getNameLoc (" binOpResult" ), leftInput, rightInput);
339
- frame.setValue (orOperation->
340
-
341
- getIdentifier (),
342
- mlirOrOp
343
-
344
- );
336
+ frame.setValue (orOperation->getIdentifier (), mlirOrOp);
345
337
}
346
338
347
339
void MLIRLoweringProvider::generateMLIR (ir::AndOperation* andOperation, ValueFrame& frame) {
348
340
auto leftInput = frame.getValue (andOperation->getLeftInput ()->getIdentifier ());
349
341
auto rightInput = frame.getValue (andOperation->getRightInput ()->getIdentifier ());
350
342
auto mlirAndOp = builder->create <mlir::LLVM::AndOp>(getNameLoc (" binOpResult" ), leftInput, rightInput);
351
- frame.setValue (andOperation->
352
-
353
- getIdentifier (),
354
- mlirAndOp
355
-
356
- );
343
+ frame.setValue (andOperation->getIdentifier (), mlirAndOp);
357
344
}
358
345
359
346
void MLIRLoweringProvider::generateMLIR (const ir::FunctionOperation& functionOp, ValueFrame& frame) {
@@ -363,7 +350,6 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
363
350
inputTypes.emplace_back (getMLIRType (inputArg->getStamp ()));
364
351
}
365
352
llvm::SmallVector<mlir::Type> outputTypes (1 , getMLIRType (functionOp.getOutputArg ()));
366
- ;
367
353
auto functionInOutTypes = builder->getFunctionType (inputTypes, outputTypes);
368
354
auto loc = getNameLoc (" EntryPoint" );
369
355
auto mlirFunction = builder->create <mlir::func::FuncOp>(loc, functionOp.getName (), functionInOutTypes);
@@ -375,30 +361,16 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
375
361
} else if (isSignedInteger (functionOp.getStamp ())) {
376
362
mlirFunction.setResultAttr (0 , " llvm.signext" , mlir::UnitAttr::get (context));
377
363
}
378
- // mlirFunction.setArgAttr(0, "llvm.signext", mlir::UnitAttr::get(context));
379
364
380
- mlirFunction.
381
-
382
- addEntryBlock ();
365
+ mlirFunction.addEntryBlock ();
383
366
384
367
// Set InsertPoint to beginning of the execute function.
385
- builder->setInsertionPointToStart (&mlirFunction
386
- .
387
-
388
- getBody ()
389
-
390
- .
391
-
392
- front ()
393
-
394
- );
368
+ builder->setInsertionPointToStart (&mlirFunction.getBody ().front ());
395
369
396
370
// Store references to function args in the valueMap map.
397
371
auto valueMapIterator = mlirFunction.args_begin ();
398
372
for (int i = 0 ; i < (int ) functionOp.getFunctionBasicBlock ().getArguments ().size (); ++i) {
399
- frame.setValue (functionOp.getFunctionBasicBlock ().getArguments ().at (i)->getIdentifier (), valueMapIterator[i]
400
-
401
- );
373
+ frame.setValue (functionOp.getFunctionBasicBlock ().getArguments ().at (i)->getIdentifier (), valueMapIterator[i]);
402
374
}
403
375
404
376
// Generate MLIR for operations in function body (BasicBlock).
@@ -408,27 +380,17 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
408
380
}
409
381
410
382
void MLIRLoweringProvider::generateMLIR (ir::LoadOperation* loadOp, ValueFrame& frame) {
411
-
412
383
auto address = frame.getValue (loadOp->getAddress ()->getIdentifier ());
413
-
414
- // auto bitcast = builder->create<mlir::LLVM::BitcastOp>(getNameLoc("Bitcasted
415
- // address"),
416
- // mlir::LLVM::LLVMPointerType::get(context),
417
- // address);
418
384
auto mlirLoadOp = builder->create <mlir::LLVM::LoadOp>(getNameLoc (" loadedValue" ), getMLIRType (loadOp->getStamp ()), address);
419
385
frame.setValue (loadOp->getIdentifier (), mlirLoadOp);
420
386
}
421
387
422
388
void MLIRLoweringProvider::generateMLIR (ir::ConstIntOperation* constIntOp, ValueFrame& frame) {
423
- if (!frame.contains (constIntOp->getIdentifier ())) {
424
- frame.setValue (constIntOp->getIdentifier (), getConstInt (" ConstantOp" , constIntOp->getStamp (), constIntOp->getValue ()));
425
- } else {
426
- frame.setValue (constIntOp->getIdentifier (), getConstInt (" ConstantOp" , constIntOp->getStamp (), constIntOp->getValue ()));
427
- }
389
+ frame.setValue (constIntOp->getIdentifier (), getConstInt (" ConstantOp" , constIntOp->getStamp (), constIntOp->getValue ()));
428
390
}
429
391
430
392
void MLIRLoweringProvider::generateMLIR (ir::ConstPtrOperation* constPtr, ValueFrame& frame) {
431
- int64_t val = (int64_t ) constPtr->getValue ();
393
+ auto val = (int64_t ) constPtr->getValue ();
432
394
auto constInt = builder->create <mlir::arith::ConstantOp>(getNameLoc (" location" ), builder->getI64Type (), builder->getIntegerAttr (builder->getI64Type (), val));
433
395
auto elementAddress = builder->create <mlir::LLVM::IntToPtrOp>(getNameLoc (" fieldAccess" ), mlir::LLVM::LLVMPointerType::get (context), constInt);
434
396
frame.setValue (constPtr->getIdentifier (), elementAddress);
@@ -451,7 +413,6 @@ void MLIRLoweringProvider::generateMLIR(ir::AddOperation* addOp, ValueFrame& fra
451
413
// if we add something to a ptr we have to use a llvm getelementptr
452
414
mlir::Value elementAddress = builder->create <mlir::LLVM::GEPOp>(getNameLoc (" fieldAccess" ), mlir::LLVM::LLVMPointerType::get (context), builder->getI8Type (), leftInput, mlir::ArrayRef<mlir::Value>({rightInput}));
453
415
frame.setValue (addOp->getIdentifier (), elementAddress);
454
-
455
416
} else if (isFloat (addOp->getStamp ())) {
456
417
auto mlirAddOp = builder->create <mlir::LLVM::FAddOp>(getNameLoc (" binOpResult" ), leftInput.getType (), leftInput, rightInput, mlir::LLVM::FastmathFlags::fast);
457
418
frame.setValue (addOp->getIdentifier (), mlirAddOp);
@@ -475,7 +436,6 @@ void MLIRLoweringProvider::generateMLIR(ir::SubOperation* subIntOp, ValueFrame&
475
436
// if we add something to a ptr we have to use a llvm getelementptr
476
437
mlir::Value elementAddress = builder->create <mlir::LLVM::GEPOp>(getNameLoc (" fieldAccess" ), mlir::LLVM::LLVMPointerType::get (context), builder->getI8Type (), leftInput, mlir::ArrayRef<mlir::Value>({rightInput}));
477
438
frame.setValue (subIntOp->getIdentifier (), elementAddress);
478
-
479
439
} else if (isFloat (subIntOp->getStamp ())) {
480
440
auto mlirSubOp = builder->create <mlir::LLVM::FSubOp>(getNameLoc (" binOpResult" ), leftInput, rightInput, mlir::LLVM::FastmathFlagsAttr::get (context, mlir::LLVM::FastmathFlags::fast));
481
441
frame.setValue (subIntOp->getIdentifier (), mlirSubOp);
@@ -576,18 +536,6 @@ void MLIRLoweringProvider::generateMLIR(ir::CompareOperation* compareOp, ValueFr
576
536
if ((isInteger (leftStamp) && isFloat (rightStamp)) || ((isInteger (rightStamp) && isFloat (leftStamp)))) {
577
537
// Avoid comparing integer to float
578
538
throw NotImplementedException (" Type missmatch: cannot compare" );
579
- } else if (compareOp->getComparator () == ir::CompareOperation::EQ && compareOp->getLeftInput ()->getStamp () == Type::ptr && isInteger (compareOp->getRightInput ()->getStamp ())) {
580
- // add null check
581
- throw NotImplementedException (" Null check is not implemented" );
582
- // auto null =
583
- // builder->create<mlir::LLVM::NullOp>(getNameLoc("null"),
584
- // mlir::LLVM::LLVMPointerType::get(context));
585
- // auto cmpOp =
586
- // builder->create<mlir::LLVM::ICmpOp>(getNameLoc("comparison"),
587
- // mlir::LLVM::ICmpPredicate::eq,
588
- // frame.getValue(compareOp->getLeftInput()->getIdentifier()),
589
- // null);
590
- // frame.setValue(compareOp->getIdentifier(), cmpOp);
591
539
} else if (isInteger (leftStamp) && isInteger (rightStamp)) {
592
540
// handle integer
593
541
auto cmpOp = builder->create <mlir::arith::CmpIOp>(getNameLoc (" comparison" ), convertToIntMLIRComparison (compareOp->getComparator (), leftStamp), frame.getValue (compareOp->getLeftInput ()->getIdentifier ()),
@@ -768,7 +716,6 @@ void MLIRLoweringProvider::generateMLIR(ir::CastOperation* castOperation, MLIRLo
768
716
void MLIRLoweringProvider::generateMLIR (ir::BinaryCompOperation* binaryCompOperation, nautilus::compiler::mlir::MLIRLoweringProvider::ValueFrame& frame) {
769
717
auto leftInput = frame.getValue (binaryCompOperation->getLeftInput ()->getIdentifier ());
770
718
auto rightInput = frame.getValue (binaryCompOperation->getRightInput ()->getIdentifier ());
771
-
772
719
mlir::Value op;
773
720
switch (binaryCompOperation->getType ()) {
774
721
case ir::BinaryCompOperation::BAND:
@@ -787,7 +734,6 @@ void MLIRLoweringProvider::generateMLIR(ir::BinaryCompOperation* binaryCompOpera
787
734
void MLIRLoweringProvider::generateMLIR (ir::ShiftOperation* shiftOperation, nautilus::compiler::mlir::MLIRLoweringProvider::ValueFrame& frame) {
788
735
auto leftInput = frame.getValue (shiftOperation->getLeftInput ()->getIdentifier ());
789
736
auto rightInput = frame.getValue (shiftOperation->getRightInput ()->getIdentifier ());
790
-
791
737
mlir::Value op;
792
738
switch (shiftOperation->getType ()) {
793
739
case ir::ShiftOperation::LS:
0 commit comments