Skip to content

Commit

Permalink
[OpenACC] Implement 'deviceptr' and 'attach' sema for compute constructs
Browse files Browse the repository at this point in the history
These two are very similar to the other 'var-list' variants, except they
require that the type of the variable be a pointer.  This patch
implements that restriction.
  • Loading branch information
erichkeane committed May 6, 2024
1 parent e50060f commit 48c8a57
Show file tree
Hide file tree
Showing 20 changed files with 733 additions and 18 deletions.
38 changes: 38 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,44 @@ class OpenACCFirstPrivateClause final
ArrayRef<Expr *> VarList, SourceLocation EndLoc);
};

class OpenACCDevicePtrClause final
: public OpenACCClauseWithVarList,
public llvm::TrailingObjects<OpenACCDevicePtrClause, Expr *> {

OpenACCDevicePtrClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<Expr *> VarList, SourceLocation EndLoc)
: OpenACCClauseWithVarList(OpenACCClauseKind::DevicePtr, BeginLoc,
LParenLoc, EndLoc) {
std::uninitialized_copy(VarList.begin(), VarList.end(),
getTrailingObjects<Expr *>());
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
}

public:
static OpenACCDevicePtrClause *
Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<Expr *> VarList, SourceLocation EndLoc);
};

class OpenACCAttachClause final
: public OpenACCClauseWithVarList,
public llvm::TrailingObjects<OpenACCAttachClause, Expr *> {

OpenACCAttachClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<Expr *> VarList, SourceLocation EndLoc)
: OpenACCClauseWithVarList(OpenACCClauseKind::Attach, BeginLoc, LParenLoc,
EndLoc) {
std::uninitialized_copy(VarList.begin(), VarList.end(),
getTrailingObjects<Expr *>());
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
}

public:
static OpenACCAttachClause *
Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<Expr *> VarList, SourceLocation EndLoc);
};

class OpenACCNoCreateClause final
: public OpenACCClauseWithVarList,
public llvm::TrailingObjects<OpenACCNoCreateClause, Expr *> {
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12337,4 +12337,7 @@ def warn_acc_deprecated_alias_name
: Warning<"OpenACC clause name '%0' is a deprecated clause name and is "
"now an alias for '%1'">,
InGroup<DiagGroup<"openacc-deprecated-clause-alias">>;
def err_acc_var_not_pointer_type
: Error<"expected pointer in '%0' clause, type is %1">;
def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
} // end of sema component.
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define CLAUSE_ALIAS(ALIAS_NAME, CLAUSE_NAME)
#endif

VISIT_CLAUSE(Attach)
VISIT_CLAUSE(Copy)
CLAUSE_ALIAS(PCopy, Copy)
CLAUSE_ALIAS(PresentOrCopy, Copy)
Expand All @@ -34,6 +35,7 @@ VISIT_CLAUSE(Create)
CLAUSE_ALIAS(PCreate, Create)
CLAUSE_ALIAS(PresentOrCreate, Create)
VISIT_CLAUSE(Default)
VISIT_CLAUSE(DevicePtr)
VISIT_CLAUSE(FirstPrivate)
VISIT_CLAUSE(If)
VISIT_CLAUSE(NoCreate)
Expand Down
10 changes: 10 additions & 0 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class SemaOpenACC : public SemaBase {
ClauseKind == OpenACCClauseKind::Create ||
ClauseKind == OpenACCClauseKind::PCreate ||
ClauseKind == OpenACCClauseKind::PresentOrCreate ||
ClauseKind == OpenACCClauseKind::Attach ||
ClauseKind == OpenACCClauseKind::DevicePtr ||
ClauseKind == OpenACCClauseKind::FirstPrivate) &&
"Parsed clause kind does not have a var-list");
return std::get<VarListDetails>(Details).VarList;
Expand Down Expand Up @@ -217,6 +219,8 @@ class SemaOpenACC : public SemaBase {
ClauseKind == OpenACCClauseKind::Create ||
ClauseKind == OpenACCClauseKind::PCreate ||
ClauseKind == OpenACCClauseKind::PresentOrCreate ||
ClauseKind == OpenACCClauseKind::Attach ||
ClauseKind == OpenACCClauseKind::DevicePtr ||
ClauseKind == OpenACCClauseKind::FirstPrivate) &&
"Parsed clause kind does not have a var-list");
assert((!IsReadOnly || ClauseKind == OpenACCClauseKind::CopyIn ||
Expand Down Expand Up @@ -251,6 +255,8 @@ class SemaOpenACC : public SemaBase {
ClauseKind == OpenACCClauseKind::Create ||
ClauseKind == OpenACCClauseKind::PCreate ||
ClauseKind == OpenACCClauseKind::PresentOrCreate ||
ClauseKind == OpenACCClauseKind::Attach ||
ClauseKind == OpenACCClauseKind::DevicePtr ||
ClauseKind == OpenACCClauseKind::FirstPrivate) &&
"Parsed clause kind does not have a var-list");
assert((!IsReadOnly || ClauseKind == OpenACCClauseKind::CopyIn ||
Expand Down Expand Up @@ -315,6 +321,10 @@ class SemaOpenACC : public SemaBase {
/// declaration reference to a variable of the correct type.
ExprResult ActOnVar(Expr *VarExpr);

/// Called to check the 'var' type is a variable of pointer type, necessary
/// for 'deviceptr' and 'attach' clauses. Returns true on success.
bool CheckVarIsPointerType(OpenACCClauseKind ClauseKind, Expr *VarExpr);

/// Checks and creates an Array Section used in an OpenACC construct/clause.
ExprResult ActOnArraySectionExpr(Expr *Base, SourceLocation LBLoc,
Expr *LowerBound,
Expand Down
35 changes: 35 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ OpenACCFirstPrivateClause *OpenACCFirstPrivateClause::Create(
OpenACCFirstPrivateClause(BeginLoc, LParenLoc, VarList, EndLoc);
}

OpenACCAttachClause *OpenACCAttachClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
ArrayRef<Expr *> VarList,
SourceLocation EndLoc) {
void *Mem =
C.Allocate(OpenACCAttachClause::totalSizeToAlloc<Expr *>(VarList.size()));
return new (Mem) OpenACCAttachClause(BeginLoc, LParenLoc, VarList, EndLoc);
}

OpenACCDevicePtrClause *OpenACCDevicePtrClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
ArrayRef<Expr *> VarList,
SourceLocation EndLoc) {
void *Mem = C.Allocate(
OpenACCDevicePtrClause::totalSizeToAlloc<Expr *>(VarList.size()));
return new (Mem) OpenACCDevicePtrClause(BeginLoc, LParenLoc, VarList, EndLoc);
}

OpenACCNoCreateClause *OpenACCNoCreateClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expand Down Expand Up @@ -282,6 +302,21 @@ void OpenACCClausePrinter::VisitFirstPrivateClause(
OS << ")";
}

void OpenACCClausePrinter::VisitAttachClause(const OpenACCAttachClause &C) {
OS << "attach(";
llvm::interleaveComma(C.getVarList(), OS,
[&](const Expr *E) { printExpr(E); });
OS << ")";
}

void OpenACCClausePrinter::VisitDevicePtrClause(
const OpenACCDevicePtrClause &C) {
OS << "deviceptr(";
llvm::interleaveComma(C.getVarList(), OS,
[&](const Expr *E) { printExpr(E); });
OS << ")";
}

void OpenACCClausePrinter::VisitNoCreateClause(const OpenACCNoCreateClause &C) {
OS << "no_create(";
llvm::interleaveComma(C.getVarList(), OS,
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2543,6 +2543,18 @@ void OpenACCClauseProfiler::VisitFirstPrivateClause(
Profiler.VisitStmt(E);
}

void OpenACCClauseProfiler::VisitAttachClause(
const OpenACCAttachClause &Clause) {
for (auto *E : Clause.getVarList())
Profiler.VisitStmt(E);
}

void OpenACCClauseProfiler::VisitDevicePtrClause(
const OpenACCDevicePtrClause &Clause) {
for (auto *E : Clause.getVarList())
Profiler.VisitStmt(E);
}

void OpenACCClauseProfiler::VisitNoCreateClause(
const OpenACCNoCreateClause &Clause) {
for (auto *E : Clause.getVarList())
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,12 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
case OpenACCClauseKind::Default:
OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
break;
case OpenACCClauseKind::Attach:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::PCopy:
case OpenACCClauseKind::PresentOrCopy:
case OpenACCClauseKind::If:
case OpenACCClauseKind::DevicePtr:
case OpenACCClauseKind::FirstPrivate:
case OpenACCClauseKind::NoCreate:
case OpenACCClauseKind::NumGangs:
Expand Down
8 changes: 6 additions & 2 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,17 +945,21 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
// make sure we get the right differentiator.
assert(DirKind == OpenACCDirectiveKind::Update);
[[fallthrough]];
case OpenACCClauseKind::Attach:
case OpenACCClauseKind::Delete:
case OpenACCClauseKind::Detach:
case OpenACCClauseKind::Device:
case OpenACCClauseKind::DeviceResident:
case OpenACCClauseKind::DevicePtr:
case OpenACCClauseKind::Host:
case OpenACCClauseKind::Link:
case OpenACCClauseKind::UseDevice:
ParseOpenACCVarList();
break;
case OpenACCClauseKind::Attach:
case OpenACCClauseKind::DevicePtr:
// TODO: ERICH: Figure out how to limit to just ptrs?
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
/*IsReadOnly=*/false, /*IsZero=*/false);
break;
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::PCopy:
case OpenACCClauseKind::PresentOrCopy:
Expand Down
100 changes: 100 additions & 0 deletions clang/lib/Sema/SemaOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,34 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
default:
return false;
}
case OpenACCClauseKind::Attach:
switch (DirectiveKind) {
case OpenACCDirectiveKind::Parallel:
case OpenACCDirectiveKind::Serial:
case OpenACCDirectiveKind::Kernels:
case OpenACCDirectiveKind::Data:
case OpenACCDirectiveKind::EnterData:
case OpenACCDirectiveKind::ParallelLoop:
case OpenACCDirectiveKind::SerialLoop:
case OpenACCDirectiveKind::KernelsLoop:
return true;
default:
return false;
}
case OpenACCClauseKind::DevicePtr:
switch (DirectiveKind) {
case OpenACCDirectiveKind::Parallel:
case OpenACCDirectiveKind::Serial:
case OpenACCDirectiveKind::Kernels:
case OpenACCDirectiveKind::Data:
case OpenACCDirectiveKind::Declare:
case OpenACCDirectiveKind::ParallelLoop:
case OpenACCDirectiveKind::SerialLoop:
case OpenACCDirectiveKind::KernelsLoop:
return true;
default:
return false;
}
default:
// Do nothing so we can go to the 'unimplemented' diagnostic instead.
return true;
Expand Down Expand Up @@ -513,6 +541,48 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
Clause.getLParenLoc(), Clause.isZero(),
Clause.getVarList(), Clause.getEndLoc());
}
case OpenACCClauseKind::Attach: {
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
break;

// ActOnVar ensured that everything is a valid variable reference, but we
// still have to make sure it is a pointer type.
llvm::SmallVector<Expr *> VarList{Clause.getVarList().begin(),
Clause.getVarList().end()};
VarList.erase(std::remove_if(VarList.begin(), VarList.end(), [&](Expr *E) {
return CheckVarIsPointerType(OpenACCClauseKind::Attach, E);
}), VarList.end());
Clause.setVarListDetails(VarList,
/*IsReadOnly=*/false, /*IsZero=*/false);

return OpenACCAttachClause::Create(getASTContext(), Clause.getBeginLoc(),
Clause.getLParenLoc(),
Clause.getVarList(), Clause.getEndLoc());
}
case OpenACCClauseKind::DevicePtr: {
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
break;

// ActOnVar ensured that everything is a valid variable reference, but we
// still have to make sure it is a pointer type.
llvm::SmallVector<Expr *> VarList{Clause.getVarList().begin(),
Clause.getVarList().end()};
VarList.erase(std::remove_if(VarList.begin(), VarList.end(), [&](Expr *E) {
return CheckVarIsPointerType(OpenACCClauseKind::DevicePtr, E);
}), VarList.end());
Clause.setVarListDetails(VarList,
/*IsReadOnly=*/false, /*IsZero=*/false);

return OpenACCDevicePtrClause::Create(
getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
Clause.getVarList(), Clause.getEndLoc());
}
default:
break;
}
Expand Down Expand Up @@ -641,6 +711,36 @@ ExprResult SemaOpenACC::ActOnIntExpr(OpenACCDirectiveKind DK,
return IntExpr;
}

bool SemaOpenACC::CheckVarIsPointerType(OpenACCClauseKind ClauseKind,
Expr *VarExpr) {
// We already know that VarExpr is a proper reference to a variable, so we
// should be able to just take the type of the expression to get the type of
// the referenced variable.

// We've already seen an error, don't diagnose anything else.
if (!VarExpr || VarExpr->containsErrors())
return false;

if (isa<ArraySectionExpr>(VarExpr->IgnoreParenImpCasts()) ||
VarExpr->hasPlaceholderType(BuiltinType::ArraySection)) {
Diag(VarExpr->getExprLoc(), diag::err_array_section_use) << /*OpenACC=*/0;
Diag(VarExpr->getExprLoc(), diag::note_acc_expected_pointer_var);
return true;
}

QualType Ty = VarExpr->getType();
Ty = Ty.getNonReferenceType().getUnqualifiedType();

// Nothing we can do if this is a dependent type.
if (Ty->isDependentType())
return false;

if (!Ty->isPointerType())
return Diag(VarExpr->getExprLoc(), diag::err_acc_var_not_pointer_type)
<< ClauseKind << Ty;
return false;
}

ExprResult SemaOpenACC::ActOnVar(Expr *VarExpr) {
// We still need to retain the array subscript/subarray exprs, so work on a
// copy.
Expand Down
37 changes: 37 additions & 0 deletions clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -11318,6 +11318,43 @@ void OpenACCClauseTransform<Derived>::VisitCreateClause(
ParsedClause.isZero(), ParsedClause.getVarList(),
ParsedClause.getEndLoc());
}
template <typename Derived>
void OpenACCClauseTransform<Derived>::VisitAttachClause(
const OpenACCAttachClause &C) {
llvm::SmallVector<Expr *> VarList = VisitVarList(C.getVarList());

// Ensure each var is a pointer type.
VarList.erase(std::remove_if(VarList.begin(), VarList.end(), [&](Expr *E) {
return Self.getSema().OpenACC().CheckVarIsPointerType(
OpenACCClauseKind::Attach, E);
}), VarList.end());

ParsedClause.setVarListDetails(VarList,
/*IsReadOnly=*/false, /*IsZero=*/false);
NewClause = OpenACCAttachClause::Create(
Self.getSema().getASTContext(), ParsedClause.getBeginLoc(),
ParsedClause.getLParenLoc(), ParsedClause.getVarList(),
ParsedClause.getEndLoc());
}

template <typename Derived>
void OpenACCClauseTransform<Derived>::VisitDevicePtrClause(
const OpenACCDevicePtrClause &C) {
llvm::SmallVector<Expr *> VarList = VisitVarList(C.getVarList());

// Ensure each var is a pointer type.
VarList.erase(std::remove_if(VarList.begin(), VarList.end(), [&](Expr *E) {
return Self.getSema().OpenACC().CheckVarIsPointerType(
OpenACCClauseKind::DevicePtr, E);
}), VarList.end());

ParsedClause.setVarListDetails(VarList,
/*IsReadOnly=*/false, /*IsZero=*/false);
NewClause = OpenACCDevicePtrClause::Create(
Self.getSema().getASTContext(), ParsedClause.getBeginLoc(),
ParsedClause.getLParenLoc(), ParsedClause.getVarList(),
ParsedClause.getEndLoc());
}

template <typename Derived>
void OpenACCClauseTransform<Derived>::VisitNumWorkersClause(
Expand Down
Loading

0 comments on commit 48c8a57

Please sign in to comment.