Skip to content

Commit

Permalink
[Clang][Sema][OpenMP] Allow thread_limit to accept multiple express…
Browse files Browse the repository at this point in the history
…ions (#102715)
  • Loading branch information
shiltian authored Aug 10, 2024
1 parent ac47edd commit 1c26992
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 82 deletions.
3 changes: 2 additions & 1 deletion clang/docs/OpenMPSupport.rst
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ considered for standardization. Please post on the
| device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 |
| | <https://www.osti.gov/servlets/purl/2205717>`_ | | |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
| device extension | Multi-dim 'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 |
| device extension | Multi-dim 'num_teams' and 'thread_limit' clause on 'target teams ompx_bare' | :good:`partial` | #99732, #101407, #102715 |
| | construct | | |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+

.. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35
5 changes: 3 additions & 2 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,9 @@ Improvements
^^^^^^^^^^^^
- Improve the handling of mapping array-section for struct containing nested structs with user defined mappers

- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
This allows the target region to be launched with multi-dim grid on GPUs.
- `num_teams` and `thead_limit` now accept multiple expressions when it is used
along in ``target teams ompx_bare`` construct. This allows the target region
to be launched with multi-dim grid on GPUs.

Additional Information
======================
Expand Down
81 changes: 49 additions & 32 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -6462,61 +6462,78 @@ class OMPNumTeamsClause final
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'thread_limit'
/// with single expression 'n'.
class OMPThreadLimitClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
///
/// When 'ompx_bare' clause exists on a 'target' directive, 'thread_limit'
/// clause can accept up to three expressions.
///
/// \code
/// #pragma omp target teams ompx_bare thread_limit(x, y, z)
/// \endcode
class OMPThreadLimitClause final
: public OMPVarListClause<OMPThreadLimitClause>,
public OMPClauseWithPreInit,
private llvm::TrailingObjects<OMPThreadLimitClause, Expr *> {
friend OMPVarListClause;
friend TrailingObjects;

/// Location of '('.
SourceLocation LParenLoc;

/// ThreadLimit number.
Stmt *ThreadLimit = nullptr;
OMPThreadLimitClause(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc,
unsigned N)
: OMPVarListClause(llvm::omp::OMPC_thread_limit, StartLoc, LParenLoc,
EndLoc, N),
OMPClauseWithPreInit(this) {}

/// Set the ThreadLimit number.
///
/// \param E ThreadLimit number.
void setThreadLimit(Expr *E) { ThreadLimit = E; }
/// Build an empty clause.
OMPThreadLimitClause(unsigned N)
: OMPVarListClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
SourceLocation(), SourceLocation(), N),
OMPClauseWithPreInit(this) {}

public:
/// Build 'thread_limit' clause.
/// Creates clause with a list of variables \a VL.
///
/// \param E Expression associated with this clause.
/// \param HelperE Helper Expression associated with this clause.
/// \param CaptureRegion Innermost OpenMP region where expressions in this
/// clause must be captured.
/// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPThreadLimitClause(Expr *E, Stmt *HelperE,
OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_thread_limit, StartLoc, EndLoc),
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), ThreadLimit(E) {
setPreInitStmt(HelperE, CaptureRegion);
}
/// \param VL List of references to the variables.
/// \param PreInit
static OMPThreadLimitClause *
Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);

/// Build an empty clause.
OMPThreadLimitClause()
: OMPClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
SourceLocation()),
OMPClauseWithPreInit(this) {}
/// Creates an empty clause with \a N variables.
///
/// \param C AST context.
/// \param N The number of variables.
static OMPThreadLimitClause *CreateEmpty(const ASTContext &C, unsigned N);

/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }

/// Return ThreadLimit number.
Expr *getThreadLimit() { return cast<Expr>(ThreadLimit); }
/// Return ThreadLimit expressions.
ArrayRef<Expr *> getThreadLimit() { return getVarRefs(); }

/// Return ThreadLimit number.
Expr *getThreadLimit() const { return cast<Expr>(ThreadLimit); }
/// Return ThreadLimit expressions.
ArrayRef<Expr *> getThreadLimit() const {
return const_cast<OMPThreadLimitClause *>(this)->getThreadLimit();
}

child_range children() { return child_range(&ThreadLimit, &ThreadLimit + 1); }
child_range children() {
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
reinterpret_cast<Stmt **>(varlist_end()));
}

const_child_range children() const {
return const_child_range(&ThreadLimit, &ThreadLimit + 1);
auto Children = const_cast<OMPThreadLimitClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

child_range used_children() {
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3836,8 +3836,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPThreadLimitClause(
OMPThreadLimitClause *C) {
TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getThreadLimit()));
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1264,7 +1264,7 @@ class SemaOpenMP : public SemaBase {
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
OMPClause *ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
OMPClause *ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
Expand Down
26 changes: 23 additions & 3 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,24 @@ OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
return new (Mem) OMPNumTeamsClause(N);
}

OMPThreadLimitClause *OMPThreadLimitClause::Create(
const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
ArrayRef<Expr *> VL, Stmt *PreInit) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
OMPThreadLimitClause *Clause =
new (Mem) OMPThreadLimitClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
Clause->setVarRefs(VL);
Clause->setPreInitStmt(PreInit, CaptureRegion);
return Clause;
}

OMPThreadLimitClause *OMPThreadLimitClause::CreateEmpty(const ASTContext &C,
unsigned N) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
return new (Mem) OMPThreadLimitClause(N);
}

//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2081,9 +2099,11 @@ void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
}

void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
OS << "thread_limit(";
Node->getThreadLimit()->printPretty(OS, nullptr, Policy, 0);
OS << ")";
if (!Node->varlist_empty()) {
OS << "thread_limit";
VisitOMPClauseList(Node, '(');
OS << ")";
}
}

void OMPClausePrinter::VisitOMPPriorityClause(OMPPriorityClause *Node) {
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,9 +862,8 @@ void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
if (C->getThreadLimit())
Profiler->VisitStmt(C->getThreadLimit());
}
void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) {
VistOMPClauseWithPreInit(C);
Expand Down
15 changes: 10 additions & 5 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6332,7 +6332,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
CGOpenMPInnerExprInfo CGInfo(CGF, *CS);
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
CodeGenFunction::LexicalScope Scope(
CGF, ThreadLimitClause->getThreadLimit()->getSourceRange());
CGF,
ThreadLimitClause->getThreadLimit().front()->getSourceRange());
if (const auto *PreInit =
cast_or_null<DeclStmt>(ThreadLimitClause->getPreInitStmt())) {
for (const auto *I : PreInit->decls()) {
Expand All @@ -6349,7 +6350,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
}
}
if (ThreadLimitClause)
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
if (const auto *Dir = dyn_cast_or_null<OMPExecutableDirective>(Child)) {
if (isOpenMPTeamsDirective(Dir->getDirectiveKind()) &&
!isOpenMPDistributeDirective(Dir->getDirectiveKind())) {
Expand All @@ -6370,7 +6372,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
}
const CapturedStmt *CS = D.getInnermostCapturedStmt();
getNumThreads(CGF, CS, NTPtr, UpperBound, UpperBoundOnly, CondVal);
Expand All @@ -6388,7 +6391,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
}
getNumThreads(CGF, D.getInnermostCapturedStmt(), NTPtr, UpperBound,
UpperBoundOnly, CondVal);
Expand Down Expand Up @@ -6424,7 +6428,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
}
if (D.hasClausesOfKind<OMPNumThreadsClause>()) {
CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF);
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5259,7 +5259,7 @@ void CodeGenFunction::EmitOMPTargetTaskBasedDirective(
// enclosing this target region. This will indirectly set the thread_limit
// for every applicable construct within target region.
CGF.CGM.getOpenMPRuntime().emitThreadLimitClause(
CGF, TL->getThreadLimit(), S.getBeginLoc());
CGF, TL->getThreadLimit().front(), S.getBeginLoc());
}
BodyGen(CGF);
};
Expand Down Expand Up @@ -6860,7 +6860,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
if (NT || TL) {
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit().front() : nullptr;

CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
S.getBeginLoc());
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
case OMPC_num_tasks:
Expand Down Expand Up @@ -3332,6 +3331,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
: ParseOpenMPClause(CKind, WrongDirective);
break;
case OMPC_num_teams:
case OMPC_thread_limit:
if (!FirstClause) {
Diag(Tok, diag::err_omp_more_one_clause)
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
Expand Down
Loading

0 comments on commit 1c26992

Please sign in to comment.