Skip to content

Commit 19bb017

Browse files
authored
[SYCL] Mark parallel_for_work_item even when called indirectly (#12805)
Previously, we would mark the parallel_for_work_item FunctionDecl only when it was called directly from a parallel_for_work_group region. This change marks it even when it is called indirectly.
1 parent ea400f7 commit 19bb017

File tree

3 files changed

+63
-6
lines changed

3 files changed

+63
-6
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,25 @@ class DeviceFunctionTracker {
726726
}
727727
};
728728

729+
/// This function checks whether the given Decl is declared within a topmost
730+
/// namespace with name "sycl".
731+
static bool isDeclaredInSYCLNamespace(const Decl *D) {
732+
const DeclContext *DC = D->getDeclContext()->getEnclosingNamespaceContext();
733+
const auto *ND = dyn_cast<NamespaceDecl>(DC);
734+
// If this is not a namespace, then we are done.
735+
if (!ND)
736+
return false;
737+
738+
// While it is a namespace, find its parent scope.
739+
while (const DeclContext *Parent = ND->getParent()) {
740+
if (!isa<NamespaceDecl>(Parent))
741+
break;
742+
ND = cast<NamespaceDecl>(Parent);
743+
}
744+
745+
return ND && ND->getName() == "sycl";
746+
}
747+
729748
// This type does the heavy lifting for the management of device functions,
730749
// recursive function detection, and attribute collection for a single
731750
// kernel/external function. It walks the callgraph to find all functions that
@@ -770,6 +789,20 @@ class SingleDeviceFunctionTracker {
770789
Parent.SemaRef.addFDToReachableFromSyclDevice(CurrentDecl,
771790
CallStack.back());
772791

792+
// If this is a parallel_for_work_item that is declared in the
793+
// sycl namespace, mark it with the WorkItem scope attribute.
794+
// Note: Here, we assume that this is called from within a
795+
// parallel_for_work_group; it is undefined to call it otherwise.
796+
// We deliberately do not diagnose a violation.
797+
if (CurrentDecl->getIdentifier() &&
798+
CurrentDecl->getIdentifier()->getName() == "parallel_for_work_item" &&
799+
isDeclaredInSYCLNamespace(CurrentDecl) &&
800+
!CurrentDecl->hasAttr<SYCLScopeAttr>()) {
801+
CurrentDecl->addAttr(
802+
SYCLScopeAttr::CreateImplicit(Parent.SemaRef.getASTContext(),
803+
SYCLScopeAttr::Level::WorkItem));
804+
}
805+
773806
// We previously thought we could skip this function if we'd seen it before,
774807
// but if we haven't seen it before in this call graph, we can end up
775808
// missing a recursive call. SO, we have to revisit call-graphs we've
@@ -919,14 +952,13 @@ class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> {
919952
// not a member of sycl::group - continue search
920953
return true;
921954
auto Name = Callee->getName();
922-
if (((Name != "parallel_for_work_item") && (Name != "wait_for")) ||
955+
if (Name != "wait_for" ||
923956
Callee->hasAttr<SYCLScopeAttr>())
924957
return true;
925-
// it is a call to sycl::group::parallel_for_work_item/wait_for -
926-
// mark the callee
958+
// it is a call to sycl::group::wait_for - mark the callee
927959
Callee->addAttr(
928960
SYCLScopeAttr::CreateImplicit(Ctx, SYCLScopeAttr::Level::WorkItem));
929-
// continue search as there can be other PFWI or wait_for calls
961+
// continue search as there can be other wait_for calls
930962
return true;
931963
}
932964

@@ -2968,7 +3000,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
29683000

29693001
assert(CallOperator && "non callable object is passed as kernel obj");
29703002
// Mark the function that it "works" in a work group scope:
2971-
// NOTE: In case of parallel_for_work_item the marker call itself is
3003+
// NOTE: In case of wait_for the marker call itself is
29723004
// marked with work item scope attribute, here the '()' operator of the
29733005
// object passed as parameter is marked. This is an optimization -
29743006
// there are a lot of locals created at parallel_for_work_group
@@ -2979,7 +3011,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
29793011
if (!CallOperator->hasAttr<SYCLScopeAttr>()) {
29803012
CallOperator->addAttr(SYCLScopeAttr::CreateImplicit(
29813013
SemaRef.getASTContext(), SYCLScopeAttr::Level::WorkGroup));
2982-
// Search and mark parallel_for_work_item calls:
3014+
// Search and mark wait_for calls:
29833015
MarkWIScopeFnVisitor MarkWIScope(SemaRef.getASTContext());
29843016
MarkWIScope.TraverseDecl(CallOperator);
29853017
// Now mark local variables declared in the PFWG lambda with work group

clang/test/CodeGenSYCL/Inputs/sycl.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ template <int dimensions = 1>
3333
class __SYCL_TYPE(group) group {
3434
public:
3535
group() = default; // fake constructor
36+
// Dummy parallel_for_work_item function to mimic calls from
37+
// parallel_for_work_group.
38+
void parallel_for_work_item() {
39+
}
3640
};
3741

3842
namespace access {
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown -internal-isystem %S/Inputs -emit-llvm %s -o - | FileCheck %s
2+
// This test checks if the parallel_for_work_item called indirectly from
3+
// parallel_for_work_group gets the work_item_scope marker on it.
4+
#include <sycl.hpp>
5+
6+
void foo(sycl::group<1> work_group) {
7+
work_group.parallel_for_work_item();
8+
}
9+
10+
int main(int argc, char **argv) {
11+
sycl::queue q;
12+
q.submit([&](sycl::handler &cgh) {
13+
cgh.parallel_for_work_group(
14+
sycl::range<1>{1}, sycl::range<1>{1024}, ([=](sycl::group<1> wGroup) {
15+
foo(wGroup);
16+
}));
17+
});
18+
return 0;
19+
}
20+
21+
// CHECK: define {{.*}} void @{{.*}}sycl{{.*}}group{{.*}}parallel_for_work_item{{.*}}(ptr addrspace(4) noundef align 1 dereferenceable_or_null(1) %this) {{.*}}!work_item_scope {{.*}}!parallel_for_work_item

0 commit comments

Comments
 (0)