From 1ffd9a224b768af11e085988e17b7ccd8f7ec9e7 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Sun, 3 Dec 2023 17:18:33 +0800 Subject: [PATCH] Add support for num_threads clause in OpenMP parallel pragma --- tools/cgeist/Lib/CGStmt.cc | 45 ++++++++++++------- .../Test/Verification/ompParallelNumThreads.c | 12 +++++ 2 files changed, 40 insertions(+), 17 deletions(-) create mode 100644 tools/cgeist/Test/Verification/ompParallelNumThreads.c diff --git a/tools/cgeist/Lib/CGStmt.cc b/tools/cgeist/Lib/CGStmt.cc index e8cdfbdb5e8c..acaad6499238 100644 --- a/tools/cgeist/Lib/CGStmt.cc +++ b/tools/cgeist/Lib/CGStmt.cc @@ -537,24 +537,8 @@ MLIRScanner::VisitOMPParallelDirective(clang::OMPParallelDirective *par) { IfScope scope(*this); auto loc = getMLIRLocation(par->getBeginLoc()); - auto affineOp = builder.create(loc); - - auto oldpoint = builder.getInsertionPoint(); - auto *oldblock = builder.getInsertionBlock(); - - affineOp.getRegion().push_back(new Block()); - builder.setInsertionPointToStart(&affineOp.getRegion().front()); - - auto executeRegion = - builder.create(loc, ArrayRef()); - executeRegion.getRegion().push_back(new Block()); - builder.create(loc); - builder.setInsertionPointToStart(&executeRegion.getRegion().back()); - - auto *oldScope = allocationScope; - allocationScope = &executeRegion.getRegion().back(); - std::map prevInduction; + Value numThreads; for (auto *f : par->clauses()) { switch (f->getClauseKind()) { case llvm::omp::OMPC_private: @@ -582,11 +566,38 @@ MLIRScanner::VisitOMPParallelDirective(clang::OMPParallelDirective *par) { params[name].store(loc, builder, prevInduction[name], isArray); } break; + case llvm::omp::OMPC_num_threads: { + auto *numThreadsClause = cast(f); + numThreadsClause->getNumThreads(); + numThreads = + Visit(numThreadsClause->getNumThreads()).getValue(loc, builder); + break; + } default: llvm::errs() << "may not handle omp clause " << (int)f->getClauseKind() << "\n"; } } + auto affineOp = builder.create( + loc, /*if_expr_var*/ Value{}, numThreads, /*allocate_vars*/ ValueRange{}, + /*allocators_vars*/ ValueRange{}, /*reduction_vars*/ ValueRange{}, + /*reductions*/ ArrayAttr{}, + /*proc_bind_val*/ omp::ClauseProcBindKindAttr{}); + + auto oldpoint = builder.getInsertionPoint(); + auto *oldblock = builder.getInsertionBlock(); + + affineOp.getRegion().push_back(new Block()); + builder.setInsertionPointToStart(&affineOp.getRegion().front()); + + auto executeRegion = + builder.create(loc, ArrayRef()); + executeRegion.getRegion().push_back(new Block()); + builder.create(loc); + builder.setInsertionPointToStart(&executeRegion.getRegion().back()); + + auto *oldScope = allocationScope; + allocationScope = &executeRegion.getRegion().back(); Visit(cast(par->getAssociatedStmt()) ->getCapturedDecl() diff --git a/tools/cgeist/Test/Verification/ompParallelNumThreads.c b/tools/cgeist/Test/Verification/ompParallelNumThreads.c new file mode 100644 index 000000000000..608035901910 --- /dev/null +++ b/tools/cgeist/Test/Verification/ompParallelNumThreads.c @@ -0,0 +1,12 @@ +// RUN: cgeist %s --function=* -fopenmp -S | FileCheck %s +#include + +void test_parallel_num_threads(double* x, int sinc) { + // CHECK: %[[c32:.+]] = arith.constant 32 : i32 + // CHECK: omp.parallel num_threads(%[[c32]] : i32) { + #pragma omp parallel num_threads(32) + { + int tid = omp_get_thread_num(); + x[tid] = 1; + } +}