Skip to content

Commit

Permalink
[flang][openacc/mp] Do not read bounds on absent box (llvm#75252)
Browse files Browse the repository at this point in the history
Make sure we only load box and read its bounds when it is present.
- Add `AddrAndBoundInfo` struct to be able to carry around the `addr`
and `isPresent` values. This is likely to grow so we can make all the
access in a single `fir.if` operation.
  • Loading branch information
clementval authored Dec 15, 2023
1 parent 809ee6c commit 22426d9
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 98 deletions.
183 changes: 142 additions & 41 deletions flang/lib/Lower/DirectivesCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@
namespace Fortran {
namespace lower {

/// Information gathered to generate bounds operation and data entry/exit
/// operations.
struct AddrAndBoundsInfo {
explicit AddrAndBoundsInfo() {}
explicit AddrAndBoundsInfo(mlir::Value addr) : addr(addr) {}
explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value isPresent)
: addr(addr), isPresent(isPresent) {}
mlir::Value addr = nullptr;
mlir::Value isPresent = nullptr;
};

/// Checks if the assignment statement has a single variable on the RHS.
static inline bool checkForSingleVariableOnRHS(
const Fortran::parser::AssignmentStmt &assignmentStmt) {
Expand Down Expand Up @@ -598,7 +609,7 @@ void createEmptyRegionBlocks(
}
}

inline mlir::Value
inline AddrAndBoundsInfo
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &builder,
Fortran::lower::SymbolRef sym, mlir::Location loc) {
Expand All @@ -620,25 +631,42 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,

// Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
// `fir.ref<fir.class<T>>` type.
if (symAddr.getType().isa<fir::ReferenceType>())
return builder.create<fir::LoadOp>(loc, symAddr);
if (symAddr.getType().isa<fir::ReferenceType>()) {
if (Fortran::semantics::IsOptional(sym)) {
mlir::Value isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
mlir::Value addr =
builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
.genThen([&]() {
mlir::Value load = builder.create<fir::LoadOp>(loc, symAddr);
builder.create<fir::ResultOp>(loc, mlir::ValueRange{load});
})
.genElse([&] {
mlir::Value absent =
builder.create<fir::AbsentOp>(loc, boxTy);
builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
})
.getResults()[0];
return AddrAndBoundsInfo(addr, isPresent);
}
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
return AddrAndBoundsInfo(addr);
;
}
}
return symAddr;
return AddrAndBoundsInfo(symAddr);
}

/// Generate the bounds operation from the descriptor information.
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::AbstractConverter &converter,
fir::ExtendedValue dataExv, mlir::Value box) {
llvm::SmallVector<mlir::Value> bounds;
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
fir::ExtendedValue dataExv, mlir::Value box,
bool collectValuesOnly = false) {
llvm::SmallVector<mlir::Value> values;
mlir::Value byteStride;
mlir::Type idxTy = builder.getIndexType();
mlir::Type boundTy = builder.getType<BoundsType>();
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
assert(box.getType().isa<fir::BaseBoxType>() &&
"expect fir.box or fir.class");
mlir::Value byteStride;
for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
mlir::Value baseLb =
Expand All @@ -650,12 +678,79 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one);
if (dim == 0) // First stride is the element size.
byteStride = dimInfo.getByteStride();
mlir::Value bound = builder.create<BoundsOp>(
loc, boundTy, lb, ub, dimInfo.getExtent(), byteStride, true, baseLb);
if (collectValuesOnly) {
values.push_back(lb);
values.push_back(ub);
values.push_back(dimInfo.getExtent());
values.push_back(byteStride);
values.push_back(baseLb);
} else {
mlir::Value bound = builder.create<BoundsOp>(
loc, boundTy, lb, ub, dimInfo.getExtent(), byteStride, true, baseLb);
values.push_back(bound);
}
// Compute the stride for the next dimension.
byteStride = builder.create<mlir::arith::MulIOp>(loc, byteStride,
dimInfo.getExtent());
bounds.push_back(bound);
}
return values;
}

/// Generate the bounds operation from the descriptor information.
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::AbstractConverter &converter,
fir::ExtendedValue dataExv,
Fortran::lower::AddrAndBoundsInfo &info) {
llvm::SmallVector<mlir::Value> bounds;
mlir::Type idxTy = builder.getIndexType();
mlir::Type boundTy = builder.getType<BoundsType>();

assert(info.addr.getType().isa<fir::BaseBoxType>() &&
"expect fir.box or fir.class");

if (info.isPresent) {
llvm::SmallVector<mlir::Type> resTypes;
constexpr unsigned nbValuesPerBound = 5;
for (unsigned dim = 0; dim < dataExv.rank() * nbValuesPerBound; ++dim)
resTypes.push_back(idxTy);

mlir::Operation::result_range ifRes =
builder.genIfOp(loc, resTypes, info.isPresent, /*withElseRegion=*/true)
.genThen([&]() {
llvm::SmallVector<mlir::Value> boundValues =
gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
builder, loc, dataExv, info.addr,
/*collectValuesOnly=*/true);
builder.create<fir::ResultOp>(loc, boundValues);
})
.genElse([&] {
// Box is not present. Populate bound values with default values.
llvm::SmallVector<mlir::Value> boundValues;
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
mlir::Value mOne = builder.createIntegerConstant(loc, idxTy, -1);
for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
boundValues.push_back(zero); // lb
boundValues.push_back(mOne); // ub
boundValues.push_back(zero); // extent
boundValues.push_back(zero); // byteStride
boundValues.push_back(zero); // baseLb
}
builder.create<fir::ResultOp>(loc, boundValues);
})
.getResults();
// Create the bound operations outside the if-then-else with the if op
// results.
for (unsigned i = 0; i < ifRes.size(); i += nbValuesPerBound) {
mlir::Value bound = builder.create<BoundsOp>(
loc, boundTy, ifRes[i], ifRes[i + 1], ifRes[i + 2], ifRes[i + 3],
true, ifRes[i + 4]);
bounds.push_back(bound);
}
} else {
bounds = gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
builder, loc, dataExv, info.addr);
}
return bounds;
}
Expand Down Expand Up @@ -843,14 +938,13 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}

template <typename ObjectType, typename BoundsOp, typename BoundsType>
mlir::Value gatherDataOperandAddrAndBounds(
AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx, const ObjectType &object,
mlir::Location operandLocation, std::stringstream &asFortran,
llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false) {
mlir::Value baseAddr;

AddrAndBoundsInfo info;
std::visit(
Fortran::common::visitors{
[&](const Fortran::parser::Designator &designator) {
Expand All @@ -872,13 +966,13 @@ mlir::Value gatherDataOperandAddrAndBounds(
semanticsContext, arrayElement->base);
dataExv = converter.genExprAddr(operandLocation, *exprBase,
stmtCtx);
baseAddr = fir::getBase(dataExv);
info.addr = fir::getBase(dataExv);
asFortran << (*exprBase).AsFortran();
} else {
const Fortran::parser::Name &name =
Fortran::parser::GetLastName(*dataRef);
baseAddr = getDataOperandBaseAddr(
converter, builder, *name.symbol, operandLocation);
info = getDataOperandBaseAddr(converter, builder,
*name.symbol, operandLocation);
dataExv = converter.getSymbolExtendedValue(*name.symbol);
asFortran << name.ToString();
}
Expand All @@ -887,38 +981,44 @@ mlir::Value gatherDataOperandAddrAndBounds(
asFortran << '(';
bounds = genBoundsOps<BoundsOp, BoundsType>(
builder, operandLocation, converter, stmtCtx,
arrayElement->subscripts, asFortran, dataExv, baseAddr,
arrayElement->subscripts, asFortran, dataExv, info.addr,
treatIndexAsSection);
}
asFortran << ')';
} else if (Fortran::parser::Unwrap<
} else if (auto structComp = Fortran::parser::Unwrap<
Fortran::parser::StructureComponent>(designator)) {
fir::ExtendedValue compExv =
converter.genExprAddr(operandLocation, *expr, stmtCtx);
baseAddr = fir::getBase(compExv);
if (fir::unwrapRefType(baseAddr.getType())
info.addr = fir::getBase(compExv);
if (fir::unwrapRefType(info.addr.getType())
.isa<fir::SequenceType>())
bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
builder, operandLocation, converter, compExv);
asFortran << (*expr).AsFortran();

bool isOptional = Fortran::semantics::IsOptional(
*Fortran::parser::GetLastName(*structComp).symbol);
if (isOptional)
info.isPresent = builder.create<fir::IsPresentOp>(
operandLocation, builder.getI1Type(), info.addr);

if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(
baseAddr.getDefiningOp())) {
info.addr.getDefiningOp())) {
if (fir::isAllocatableType(loadOp.getType()) ||
fir::isPointerType(loadOp.getType()))
baseAddr = builder.create<fir::BoxAddrOp>(operandLocation,
baseAddr);
info.addr = builder.create<fir::BoxAddrOp>(operandLocation,
info.addr);
}

// If the component is an allocatable or pointer the result of
// genExprAddr will be the result of a fir.box_addr operation or
// a fir.box_addr has been inserted just before.
// Retrieve the box so we handle it like other descriptor.
if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
baseAddr.getDefiningOp())) {
baseAddr = boxAddrOp.getVal();
info.addr.getDefiningOp())) {
info.addr = boxAddrOp.getVal();
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
builder, operandLocation, converter, compExv, baseAddr);
builder, operandLocation, converter, compExv, info);
}
} else {
if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
Expand All @@ -930,7 +1030,7 @@ mlir::Value gatherDataOperandAddrAndBounds(
(void)arrayElement;
fir::ExtendedValue compExv =
converter.genExprAddr(operandLocation, *expr, stmtCtx);
baseAddr = fir::getBase(compExv);
info.addr = fir::getBase(compExv);
asFortran << (*expr).AsFortran();
} else if (const auto *dataRef{
std::get_if<Fortran::parser::DataRef>(
Expand All @@ -940,13 +1040,14 @@ mlir::Value gatherDataOperandAddrAndBounds(
Fortran::parser::GetLastName(*dataRef);
fir::ExtendedValue dataExv =
converter.getSymbolExtendedValue(*name.symbol);
baseAddr = getDataOperandBaseAddr(
converter, builder, *name.symbol, operandLocation);
if (fir::unwrapRefType(baseAddr.getType())
.isa<fir::BaseBoxType>())
info = getDataOperandBaseAddr(converter, builder,
*name.symbol, operandLocation);
if (fir::unwrapRefType(info.addr.getType())
.isa<fir::BaseBoxType>()) {
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
builder, operandLocation, converter, dataExv, baseAddr);
if (fir::unwrapRefType(baseAddr.getType())
builder, operandLocation, converter, dataExv, info);
}
if (fir::unwrapRefType(info.addr.getType())
.isa<fir::SequenceType>())
bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
builder, operandLocation, converter, dataExv);
Expand All @@ -959,12 +1060,12 @@ mlir::Value gatherDataOperandAddrAndBounds(
}
},
[&](const Fortran::parser::Name &name) {
baseAddr = getDataOperandBaseAddr(converter, builder, *name.symbol,
operandLocation);
info = getDataOperandBaseAddr(converter, builder, *name.symbol,
operandLocation);
asFortran << name.ToString();
}},
object.u);
return baseAddr;
return info;
}

} // namespace lower
Expand Down
Loading

0 comments on commit 22426d9

Please sign in to comment.