@@ -726,6 +726,25 @@ class DeviceFunctionTracker {
726
726
}
727
727
};
728
728
729
+ // / This function checks whether given DeclContext contains a topmost
730
+ // / namespace with name "sycl".
731
+ static bool isDeclaredInSYCLNamespace (const Decl *D) {
732
+ const DeclContext *DC = D->getDeclContext ()->getEnclosingNamespaceContext ();
733
+ const auto *ND = dyn_cast<NamespaceDecl>(DC);
734
+ // If this is not a namespace, then we are done.
735
+ if (!ND)
736
+ return false ;
737
+
738
+ // While it is a namespace, find its parent scope.
739
+ while (const DeclContext *Parent = ND->getParent ()) {
740
+ if (!isa<NamespaceDecl>(Parent))
741
+ break ;
742
+ ND = cast<NamespaceDecl>(Parent);
743
+ }
744
+
745
+ return ND && ND->getName () == " sycl" ;
746
+ }
747
+
729
748
// This type does the heavy lifting for the management of device functions,
730
749
// recursive function detection, and attribute collection for a single
731
750
// kernel/external function. It walks the callgraph to find all functions that
@@ -770,6 +789,20 @@ class SingleDeviceFunctionTracker {
770
789
Parent.SemaRef .addFDToReachableFromSyclDevice (CurrentDecl,
771
790
CallStack.back ());
772
791
792
+ // If this is a parallel_for_work_item that is declared in the
793
+ // sycl namespace, mark it with the WorkItem scope attribute.
794
+ // Note: Here, we assume that this is called from within a
795
+ // parallel_for_work_group; it is undefined to call it otherwise.
796
+ // We deliberately do not diagnose a violation.
797
+ if (CurrentDecl->getIdentifier () &&
798
+ CurrentDecl->getIdentifier ()->getName () == " parallel_for_work_item" &&
799
+ isDeclaredInSYCLNamespace (CurrentDecl) &&
800
+ !CurrentDecl->hasAttr <SYCLScopeAttr>()) {
801
+ CurrentDecl->addAttr (
802
+ SYCLScopeAttr::CreateImplicit (Parent.SemaRef .getASTContext (),
803
+ SYCLScopeAttr::Level::WorkItem));
804
+ }
805
+
773
806
// We previously thought we could skip this function if we'd seen it before,
774
807
// but if we haven't seen it before in this call graph, we can end up
775
808
// missing a recursive call. SO, we have to revisit call-graphs we've
@@ -919,14 +952,13 @@ class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> {
919
952
// not a member of sycl::group - continue search
920
953
return true ;
921
954
auto Name = Callee->getName ();
922
- if ((( Name != " parallel_for_work_item " ) && (Name != " wait_for" )) ||
955
+ if (Name != " wait_for" ||
923
956
Callee->hasAttr <SYCLScopeAttr>())
924
957
return true ;
925
- // it is a call to sycl::group::parallel_for_work_item/wait_for -
926
- // mark the callee
958
+ // it is a call to sycl::group::wait_for - mark the callee
927
959
Callee->addAttr (
928
960
SYCLScopeAttr::CreateImplicit (Ctx, SYCLScopeAttr::Level::WorkItem));
929
- // continue search as there can be other PFWI or wait_for calls
961
+ // continue search as there can be other wait_for calls
930
962
return true ;
931
963
}
932
964
@@ -2968,7 +3000,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
2968
3000
2969
3001
assert (CallOperator && " non callable object is passed as kernel obj" );
2970
3002
// Mark the function that it "works" in a work group scope:
2971
- // NOTE: In case of parallel_for_work_item the marker call itself is
3003
+ // NOTE: In case of wait_for the marker call itself is
2972
3004
// marked with work item scope attribute, here the '()' operator of the
2973
3005
// object passed as parameter is marked. This is an optimization -
2974
3006
// there are a lot of locals created at parallel_for_work_group
@@ -2979,7 +3011,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
2979
3011
if (!CallOperator->hasAttr <SYCLScopeAttr>()) {
2980
3012
CallOperator->addAttr (SYCLScopeAttr::CreateImplicit (
2981
3013
SemaRef.getASTContext (), SYCLScopeAttr::Level::WorkGroup));
2982
- // Search and mark parallel_for_work_item calls:
3014
+ // Search and mark wait_for calls:
2983
3015
MarkWIScopeFnVisitor MarkWIScope (SemaRef.getASTContext ());
2984
3016
MarkWIScope.TraverseDecl (CallOperator);
2985
3017
// Now mark local variables declared in the PFWG lambda with work group
0 commit comments