diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 65103b68b3ac..d98d09044281 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -569,9 +569,9 @@ uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout, assert(idx < getNumElements()); auto members = getMembers(); - unsigned offset = 0; + unsigned offset = 0, recordSize = 0; - for (unsigned i = 0, e = idx; i != e; ++i) { + for (unsigned i = 0, e = idx; i != e + 1; ++i) { auto ty = members[i]; // This matches LLVM since it uses the ABI instead of preferred alignment. @@ -579,10 +579,12 @@ uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout, llvm::Align(getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty)); // Add padding if necessary to align the data element properly. - offset = llvm::alignTo(offset, tyAlign); + recordSize = llvm::alignTo(recordSize, tyAlign); + if (i == idx) + offset = recordSize; // Consume space for this data item - offset += dataLayout.getTypeSize(ty); + recordSize += dataLayout.getTypeSize(ty); } // Account for padding, if necessary, for the alignment of the field whose @@ -781,8 +783,8 @@ LongDoubleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout, uint64_t LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout, mlir::DataLayoutEntryListRef params) const { - return mlir::cast(getUnderlying()).getABIAlignment( - dataLayout, params); + return mlir::cast(getUnderlying()) + .getABIAlignment(dataLayout, params); } //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 5d2b4180571a..1adbbded43d7 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "LowerToMLIRHelpers.h" +#include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -35,8 +36,10 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" @@ -48,19 +51,17 @@ #include "mlir/Transforms/DialectConversion.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/IR/CIRTypes.h" +#include "clang/CIR/Interfaces/CIRLoopOpInterface.h" #include "clang/CIR/LowerToLLVM.h" #include "clang/CIR/LowerToMLIR.h" #include "clang/CIR/LoweringHelpers.h" #include "clang/CIR/Passes.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/ErrorHandling.h" -#include "clang/CIR/Interfaces/CIRLoopOpInterface.h" -#include "clang/CIR/LowerToLLVM.h" -#include "clang/CIR/Passes.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Value.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/TimeProfiler.h" using namespace cir; @@ -288,17 +289,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern { matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type mlirType = - convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType()); + mlir::Type allocaType = adaptor.getAllocaType(); + mlir::Type mlirType = convertTypeForMemory(*getTypeConverter(), allocaType); // FIXME: Some types can not be converted yet (e.g. struct) if (!mlirType) return mlir::LogicalResult::failure(); auto memreftype = mlir::dyn_cast(mlirType); - if (memreftype && mlir::isa(adaptor.getAllocaType())) { - // if the type is an array, - // we don't need to wrap with memref. + if (memreftype && (mlir::isa(allocaType) || + mlir::isa(allocaType))) { + // Arrays and structs are already memref. No need to wrap another one. } else { memreftype = mlir::MemRefType::get({}, mlirType); } @@ -946,8 +947,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern { } else { // For scopes with results, use scf.execute_region SmallVector types; - if (mlir::failed( - getTypeConverter()->convertTypes(scopeOp->getResultTypes(), types))) + if (mlir::failed(getTypeConverter()->convertTypes( + scopeOp->getResultTypes(), types))) return mlir::failure(); auto exec = rewriter.create(scopeOp.getLoc(), types); @@ -1485,6 +1486,35 @@ class CIRPtrStrideOpLowering } }; +class CIRGetMemberOpLowering + : public mlir::OpConversionPattern { +public: + CIRGetMemberOpLowering(mlir::TypeConverter &converter, mlir::MLIRContext *ctx, + const mlir::DataLayout &layout) + : OpConversionPattern(converter, ctx), layout(layout) {} + mlir::LogicalResult + matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto baseAddr = op.getAddr(); + auto structType = + mlir::cast(baseAddr.getType().getPointee()); + + uint64_t byteOffset = structType.getElementOffset(layout, op.getIndex()); + auto fieldType = op.getResult().getType(); + + auto resultType = mlir::cast( + getTypeConverter()->convertType(fieldType)); + mlir::Value offsetValue = + rewriter.create(op.getLoc(), byteOffset); + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getAddr(), offsetValue, mlir::ValueRange{}); + return mlir::success(); + } + +private: + const mlir::DataLayout &layout; +}; + class CIRUnreachableOpLowering : public mlir::OpConversionPattern { public: @@ -1516,37 +1546,41 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern { }; void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, - mlir::TypeConverter &converter) { + mlir::TypeConverter &converter, + mlir::DataLayout layout) { patterns.add(patterns.getContext()); - patterns - .add(converter, patterns.getContext()); + patterns.add< + CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering, + CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering, + CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering, + CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering, + CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering, + CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering, + CIRGetElementOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering, + CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering, CIRAbsOpLowering, + CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering, + CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering, + CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering, + CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering, + CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering, + CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering, + CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering, + CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext()); + + patterns.add(converter, patterns.getContext(), + layout); } -static mlir::TypeConverter prepareTypeConverter() { +static mlir::TypeConverter prepareTypeConverter(mlir::DataLayout layout) { mlir::TypeConverter converter; converter.addConversion([&](cir::PointerType type) -> mlir::Type { - auto ty = convertTypeForMemory(converter, type.getPointee()); + auto pointee = type.getPointee(); + auto ty = convertTypeForMemory(converter, pointee); // FIXME: The pointee type might not be converted (e.g. struct) if (!ty) return nullptr; - if (isa(type.getPointee())) + if (isa(pointee) || isa(pointee)) return ty; return mlir::MemRefType::get({}, ty); }); @@ -1598,6 +1632,13 @@ static mlir::TypeConverter prepareTypeConverter() { return nullptr; return mlir::MemRefType::get(shape, elementType); }); + converter.addConversion([&](cir::RecordType type) -> mlir::Type { + // Reinterpret structs as raw bytes. Don't use tuples as they can't be put + // in memref. + auto size = type.getTypeSize(layout, {}); + auto i8 = mlir::IntegerType::get(type.getContext(), /*width=*/8); + return mlir::MemRefType::get(size.getFixedValue(), i8); + }); converter.addConversion([&](cir::VectorType type) -> mlir::Type { auto ty = converter.convertType(type.getElementType()); return mlir::VectorType::get(type.getSize(), ty); @@ -1609,12 +1650,15 @@ void ConvertCIRToMLIRPass::runOnOperation() { mlir::MLIRContext *context = &getContext(); mlir::ModuleOp theModule = getOperation(); - auto converter = prepareTypeConverter(); - + mlir::DataLayoutAnalysis layoutAnalysis(theModule); + const mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove(theModule); + + auto converter = prepareTypeConverter(layout); + mlir::RewritePatternSet patterns(&getContext()); populateCIRLoopToSCFConversionPatterns(patterns, converter); - populateCIRToMLIRConversionPatterns(patterns, converter); + populateCIRToMLIRConversionPatterns(patterns, converter, layout); mlir::ConversionTarget target(getContext()); target.addLegalOp(); @@ -1628,10 +1672,11 @@ void ConvertCIRToMLIRPass::runOnOperation() { // cir dialect, for example the `cir.continue`. If we marked cir as illegal // here, then MLIR would think any remaining `cir.continue` indicates a // failure, which is not what we want. - - patterns.add(converter, context); - if (mlir::failed(mlir::applyPartialConversion(theModule, target, + patterns.add(converter, context); + + if (mlir::failed(mlir::applyPartialConversion(theModule, target, std::move(patterns)))) { signalPassFailure(); } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/struct.cir b/clang/test/CIR/Lowering/ThroughMLIR/struct.cir new file mode 100644 index 000000000000..ad36f7f7932b --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/struct.cir @@ -0,0 +1,25 @@ +// RUN: cir-opt %s -cir-to-mlir -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +!s32i = !cir.int +!u8i = !cir.int +!u32i = !cir.int +!ty_S = !cir.record + +module { + cir.func @test() { + %1 = cir.alloca !ty_S, !cir.ptr, ["x"] {alignment = 4 : i64} + %3 = cir.get_member %1[0] {name = "c"} : !cir.ptr -> !cir.ptr + %5 = cir.get_member %1[1] {name = "i"} : !cir.ptr -> !cir.ptr + cir.return + } + + // CHECK: func.func @test() { + // CHECK: %[[alloca:[a-z0-9]+]] = memref.alloca() {alignment = 4 : i64} : memref<8xi8> + // CHECK: %[[zero:[a-z0-9]+]] = arith.constant 0 : index + // CHECK: memref.view %[[alloca]][%[[zero]]][] : memref<8xi8> to memref + // CHECK: %[[four:[a-z0-9]+]] = arith.constant 4 : index + // CHECK: %view_0 = memref.view %[[alloca]][%[[four]]][] : memref<8xi8> to memref + // CHECK: return + // CHECK: } +} \ No newline at end of file