Skip to content

Commit

Permalink
[flang][acc] Implement type categorization for FIR types (#126964)
Browse files Browse the repository at this point in the history
The OpenACC type interfaces have been updated to require that a type
self-identify which type category it belongs to. Ensure that FIR types
are able to provide this self identification.

In addition to implementing the new API, the PointerLikeType interface
attachment was moved to FIROpenACCSupport library like MappableType to
ensure all type interfaces and their implementation are now in the same
spot.
  • Loading branch information
razvanlupusoru authored Feb 13, 2025
1 parent 9456e7f commit 7b473df
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 31 deletions.
16 changes: 16 additions & 0 deletions flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@

namespace fir::acc {

template <typename T>
struct OpenACCPointerLikeModel
: public mlir::acc::PointerLikeType::ExternalModel<
OpenACCPointerLikeModel<T>, T> {
mlir::Type getElementType(mlir::Type pointer) const {
return mlir::cast<T>(pointer).getElementType();
}
mlir::acc::VariableTypeCategory
getPointeeTypeCategory(mlir::Type pointer,
mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const;
};

template <typename T>
struct OpenACCMappableModel
: public mlir::acc::MappableType::ExternalModel<OpenACCMappableModel<T>,
Expand All @@ -36,6 +49,9 @@ struct OpenACCMappableModel
llvm::SmallVector<mlir::Value>
generateAccBounds(mlir::Type type, mlir::Value var,
mlir::OpBuilder &builder) const;

mlir::acc::VariableTypeCategory getTypeCategory(mlir::Type type,
mlir::Value var) const;
};

} // namespace fir::acc
Expand Down
10 changes: 0 additions & 10 deletions flang/include/flang/Tools/PointerModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#ifndef FORTRAN_TOOLS_POINTER_MODELS_H
#define FORTRAN_TOOLS_POINTER_MODELS_H

#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"

/// models for FIR pointer like types that already provide a `getElementType`
Expand All @@ -24,13 +23,4 @@ struct OpenMPPointerLikeModel
}
};

template <typename T>
struct OpenACCPointerLikeModel
: public mlir::acc::PointerLikeType::ExternalModel<
OpenACCPointerLikeModel<T>, T> {
mlir::Type getElementType(mlir::Type pointer) const {
return mlir::cast<T>(pointer).getElementType();
}
};

#endif // FORTRAN_TOOLS_POINTER_MODELS_H
8 changes: 4 additions & 4 deletions flang/lib/Frontend/FrontendActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,12 @@ bool CodeGenAction::beginSourceFileAction() {
}

// Load the MLIR dialects required by Flang
mlir::DialectRegistry registry;
mlirCtx = std::make_unique<mlir::MLIRContext>(registry);
fir::support::registerNonCodegenDialects(registry);
fir::support::loadNonCodegenDialects(*mlirCtx);
mlirCtx = std::make_unique<mlir::MLIRContext>();
fir::support::loadDialects(*mlirCtx);
fir::support::registerLLVMTranslation(*mlirCtx);
mlir::DialectRegistry registry;
fir::acc::registerOpenACCExtensions(registry);
mlirCtx->appendDialectRegistry(registry);

const llvm::TargetMachine &targetMachine = ci.getTargetMachine();

Expand Down
11 changes: 0 additions & 11 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1370,23 +1370,12 @@ void FIROpsDialect::registerTypes() {
TypeDescType, fir::VectorType, fir::DummyScopeType>();
fir::ReferenceType::attachInterface<
OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
fir::ReferenceType::attachInterface<
OpenACCPointerLikeModel<fir::ReferenceType>>(*getContext());

fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
*getContext());
fir::PointerType::attachInterface<OpenACCPointerLikeModel<fir::PointerType>>(
*getContext());

fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
*getContext());
fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
*getContext());

fir::LLVMPointerType::attachInterface<
OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
fir::LLVMPointerType::attachInterface<
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
}

std::optional<std::pair<uint64_t, unsigned short>>
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/OpenACC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_flang_library(FIROpenACCSupport

DEPENDS
FIRBuilder
FIRCodeGen
FIRDialect
FIRDialectSupport
FIRSupport
Expand All @@ -14,6 +15,7 @@ add_flang_library(FIROpenACCSupport

LINK_LIBS
FIRBuilder
FIRCodeGen
FIRDialect
FIRDialectSupport
FIRSupport
Expand Down
143 changes: 143 additions & 0 deletions flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "flang/Optimizer/Builder/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/CodeGen/CGOps.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
Expand All @@ -24,6 +25,7 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"

namespace fir::acc {

Expand Down Expand Up @@ -224,4 +226,145 @@ OpenACCMappableModel<fir::BaseBoxType>::generateAccBounds(
return {};
}

static bool isScalarLike(mlir::Type type) {
return fir::isa_trivial(type) || fir::isa_ref_type(type);
}

static bool isArrayLike(mlir::Type type) {
return mlir::isa<fir::SequenceType>(type);
}

static bool isCompositeLike(mlir::Type type) {
return mlir::isa<fir::RecordType, fir::ClassType, mlir::TupleType>(type);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCMappableModel<fir::SequenceType>::getTypeCategory(
mlir::Type type, mlir::Value var) const {
return mlir::acc::VariableTypeCategory::array;
}

template <>
mlir::acc::VariableTypeCategory
OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
mlir::Value var) const {

mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);

// If the type enclosed by the box is a mappable type, then have it
// provide the type category.
if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableTy.getTypeCategory(var);

// For all arrays, despite whether they are allocatable, pointer, assumed,
// etc, we'd like to categorize them as "array".
if (isArrayLike(eleTy))
return mlir::acc::VariableTypeCategory::array;

// We got here because we don't have an array nor a mappable type. At this
// point, we know we have a type that fits the "aggregate" definition since it
// is a type with a descriptor. Try to refine it by checking if it matches the
// "composite" definition.
if (isCompositeLike(eleTy))
return mlir::acc::VariableTypeCategory::composite;

// Even if we have a scalar type - simply because it is wrapped in a box
// we want to categorize it as "nonscalar". Anything else would've been
// non-scalar anyway.
return mlir::acc::VariableTypeCategory::nonscalar;
}

static mlir::TypedValue<mlir::acc::PointerLikeType>
getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
// If there is no defining op - the unwrapped reference is the base one.
mlir::Operation *op = varPtr.getDefiningOp();
if (!op)
return varPtr;

// Look to find if this value originates from an interior pointer
// calculation op.
mlir::Value baseRef =
llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op)
.Case<hlfir::DesignateOp>([&](auto op) {
// Get the base object.
return op.getMemref();
})
.Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp>([&](auto op) {
// Get the base array on which the coordinate is being applied.
return op.getMemref();
})
.Case<fir::CoordinateOp>([&](auto op) {
// For coordinate operation which is applied on derived type
// object, get the base object.
return op.getRef();
})
.Default([&](mlir::Operation *) { return varPtr; });

return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
}

static mlir::acc::VariableTypeCategory
categorizePointee(mlir::Type pointer,
mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) {
// FIR uses operations to compute interior pointers.
// So for example, an array element or composite field access to a float
// value would both be represented as !fir.ref<f32>. We do not want to treat
// such a reference as a scalar. Thus unwrap interior pointer calculations.
auto baseRef = getBaseRef(varPtr);
mlir::Type eleTy = baseRef.getType().getElementType();

if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableTy.getTypeCategory(varPtr);

if (isScalarLike(eleTy))
return mlir::acc::VariableTypeCategory::scalar;
if (isArrayLike(eleTy))
return mlir::acc::VariableTypeCategory::array;
if (isCompositeLike(eleTy))
return mlir::acc::VariableTypeCategory::composite;
if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy))
return mlir::acc::VariableTypeCategory::nonscalar;
// "pointers" - in the sense of raw address point-of-view, are considered
// scalars. However
if (mlir::isa<fir::LLVMPointerType>(eleTy))
return mlir::acc::VariableTypeCategory::scalar;

// Without further checking, this type cannot be categorized.
return mlir::acc::VariableTypeCategory::uncategorized;
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::ReferenceType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::PointerType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::HeapType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::LLVMPointerType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

} // namespace fir::acc
9 changes: 9 additions & 0 deletions flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
fir::SequenceType::attachInterface<OpenACCMappableModel<fir::SequenceType>>(
*ctx);
fir::BoxType::attachInterface<OpenACCMappableModel<fir::BaseBoxType>>(*ctx);

fir::ReferenceType::attachInterface<
OpenACCPointerLikeModel<fir::ReferenceType>>(*ctx);
fir::PointerType::attachInterface<
OpenACCPointerLikeModel<fir::PointerType>>(*ctx);
fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
*ctx);
fir::LLVMPointerType::attachInterface<
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*ctx);
});
}

Expand Down
2 changes: 2 additions & 0 deletions flang/test/Fir/OpenACC/openacc-mappable.fir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>,

// CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
// CHECK: Mappable: !fir.box<!fir.array<10xf32>>
// CHECK: Type category: array
// CHECK: Size: 40
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
// CHECK: Mappable: !fir.array<10xf32>
// CHECK: Type category: array
// CHECK: Size: 40
49 changes: 49 additions & 0 deletions flang/test/Fir/OpenACC/openacc-type-categories.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s

program main
real :: scalar
real, allocatable :: scalaralloc
type tt
real :: field
real :: fieldarray(10)
end type tt
type(tt) :: ttvar
real :: arrayconstsize(10)
real, allocatable :: arrayalloc(:)
complex :: complexvar
character*1 :: charvar

!$acc enter data copyin(scalar, scalaralloc, ttvar, arrayconstsize, arrayalloc)
!$acc enter data copyin(complexvar, charvar, ttvar%field, ttvar%fieldarray, arrayconstsize(1))
end program

! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalar", structured = false}
! CHECK: Pointer-like: !fir.ref<f32>
! CHECK: Type category: scalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalaralloc", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<f32>>>
! CHECK: Type category: nonscalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.type<_QFTtt{field:f32,fieldarray:!fir.array<10xf32>}>>
! CHECK: Type category: composite
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
! CHECK: Type category: array
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayalloc", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! CHECK: Type category: array
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "complexvar", structured = false}
! CHECK: Pointer-like: !fir.ref<complex<f32>>
! CHECK: Type category: scalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "charvar", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.char<1>>
! CHECK: Type category: nonscalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%field", structured = false}
! CHECK: Pointer-like: !fir.ref<f32>
! CHECK: Type category: composite
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%fieldarray", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
! CHECK: Type category: array
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize(1)", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
! CHECK: Type category: array
Loading

0 comments on commit 7b473df

Please sign in to comment.