Skip to content

Commit

Permalink
Fix pointer mismatches
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanradanov committed Sep 14, 2023
1 parent 510a866 commit ae7fcb2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
15 changes: 7 additions & 8 deletions lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ struct Memref2PointerOpLowering
auto space0 = op.getSource().getType().getMemorySpaceAsInt();
if (transformed.getSource().getType().isa<LLVM::LLVMPointerType>()) {
mlir::Value ptr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(LPT.getElementType(), space0),
loc, LLVM::LLVMPointerType::get(op.getContext(), space0),
transformed.getSource());
if (space0 != LPT.getAddressSpace())
ptr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, LPT, ptr);
Expand All @@ -262,7 +262,7 @@ struct Memref2PointerOpLowering
Value idxs[] = {baseOffset};
ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), ptr, idxs);
ptr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(LPT.getElementType(), space0), ptr);
loc, LLVM::LLVMPointerType::get(op.getContext(), space0), ptr);
if (space0 != LPT.getAddressSpace())
ptr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, LPT, ptr);

Expand Down Expand Up @@ -988,10 +988,9 @@ struct CAllocOpLowering : public AllocLikeOpLowering<memref::AllocOp> {
innerSizes));
}
Value null = rewriter.create<LLVM::NullOp>(loc, convertedType);
auto next =
rewriter.create<LLVM::GEPOp>(loc, convertedType, null, LLVM::GEPArg(1));
Value elementSize =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), next);
Value elementSize = rewriter.create<polygeist::TypeSizeOp>(
loc, rewriter.getIndexType(),
mlir::TypeAttr::get(originalType.getElementType()));
Value size = rewriter.create<LLVM::MulOp>(loc, totalSize, elementSize);

if (auto F = module.lookupSymbol<mlir::func::FuncOp>("malloc")) {
Expand Down Expand Up @@ -2748,7 +2747,7 @@ struct ConvertPolygeistToLLVMPass
return Type();

if (type.getRank() == 0) {
return LLVM::LLVMPointerType::get(converted,
return LLVM::LLVMPointerType::get(type.getContext(),
type.getMemorySpaceAsInt());
}

Expand All @@ -2766,7 +2765,7 @@ struct ConvertPolygeistToLLVMPass
for (int64_t size : llvm::reverse(type.getShape().drop_front()))
converted = LLVM::LLVMArrayType::get(converted, size);
}
return LLVM::LLVMPointerType::get(converted,
return LLVM::LLVMPointerType::get(type.getContext(),
type.getMemorySpaceAsInt());
});
}
Expand Down
13 changes: 9 additions & 4 deletions tools/cgeist/Lib/CGCall.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ mlir::Value MLIRScanner::getLLVM(Expr *E, bool isRef) {
if (auto mt = val.getType().dyn_cast<MemRefType>()) {
val =
builder.create<polygeist::Memref2PointerOp>(loc, getOpaquePtr(), val);
} else if (auto pt = val.getType().dyn_cast<LLVM::LLVMPointerType>()) {
if (!pt.isOpaque())
val = builder.create<LLVM::BitcastOp>(loc, getOpaquePtr(), val);
}
return val;
}
Expand Down Expand Up @@ -437,8 +440,10 @@ mlir::Value MLIRScanner::getLLVM(Expr *E, bool isRef) {
ct = Glob.CGM.getContext().getLValueReferenceType(E->getType());
}
if (auto mt = val.getType().dyn_cast<MemRefType>()) {
val = builder.create<polygeist::Memref2PointerOp>(
loc, LLVM::LLVMPointerType::get(builder.getContext()), val);
val = builder.create<polygeist::Memref2PointerOp>(loc, getOpaquePtr(), val);
} else if (auto pt = val.getType().dyn_cast<LLVM::LLVMPointerType>()) {
if (!pt.isOpaque())
val = builder.create<LLVM::BitcastOp>(loc, getOpaquePtr(), val);
}
return val;
}
Expand Down Expand Up @@ -507,8 +512,8 @@ MLIRScanner::EmitClangBuiltinCallExpr(clang::CallExpr *expr) {
if (toDelete.getType().isa<mlir::MemRefType>()) {
builder.create<mlir::memref::DeallocOp>(loc, toDelete);
} else {
mlir::Value args[1] = {builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(builder.getI8Type()), toDelete)};
mlir::Value args[1] = {
builder.create<LLVM::BitcastOp>(loc, getOpaquePtr(), toDelete)};
builder.create<mlir::LLVM::CallOp>(loc, Glob.GetOrCreateFreeFunction(),
args);
}
Expand Down
8 changes: 4 additions & 4 deletions tools/cgeist/Lib/clang-mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1373,8 +1373,8 @@ ValueCategory MLIRScanner::VisitCXXDeleteExpr(clang::CXXDeleteExpr *expr) {
if (toDelete.getType().isa<mlir::MemRefType>()) {
builder.create<mlir::memref::DeallocOp>(loc, toDelete);
} else {
mlir::Value args[1] = {builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(builder.getI8Type()), toDelete)};
mlir::Value args[1] = {
builder.create<LLVM::BitcastOp>(loc, getOpaquePtr(), toDelete)};
builder.create<mlir::LLVM::CallOp>(loc, Glob.GetOrCreateFreeFunction(),
args);
}
Expand Down Expand Up @@ -1907,8 +1907,8 @@ MLIRScanner::EmitGPUCallExpr(clang::CallExpr *expr) {
if (arg.getType().isa<mlir::LLVM::LLVMPointerType>()) {
auto callee = EmitCallee(expr->getCallee());
auto strcmpF = Glob.GetOrCreateLLVMFunction(callee);
mlir::Value args[] = {builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(builder.getIntegerType(8)), arg)};
mlir::Value args[] = {
builder.create<LLVM::BitcastOp>(loc, getOpaquePtr(), arg)};
builder.create<mlir::LLVM::CallOp>(loc, strcmpF, args);
} else {
assert(arg.getType().isa<MemRefType>());
Expand Down

0 comments on commit ae7fcb2

Please sign in to comment.