diff --git a/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp b/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp index 8cbc88ccf6194..9f1abb84e30ad 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/root_group.hpp @@ -42,6 +42,13 @@ template <> struct detail::PropertyToKind { template <> struct detail::IsCompileTimeProperty : std::true_type {}; +enum class execution_scope { + work_item, + sub_group, + work_group, + root_group, +}; + template class root_group { public: using id_type = id; @@ -78,6 +85,75 @@ template class root_group { bool leader() const { return get_local_id() == 0; }; + template + using checkScopeTy = std::enable_if_t<(Scope == execution_scope::work_item || + Scope == execution_scope::sub_group || + Scope == execution_scope::work_group), + RetTy>; + + template + std::enable_if_t<(Scope == execution_scope::work_item || + Scope == execution_scope::work_group), + id> + get_id() const { + if constexpr (Scope == execution_scope::work_item) + return it.get_global_id(); + else if constexpr (Scope == execution_scope::work_group) + return it.get_group().get_group_id(); + } + + template + std::enable_if_t> get_id() const { + return get_linear_id(); + } + + template + checkScopeTy get_linear_id() const { + if constexpr (Scope == execution_scope::work_item) { + return it.get_global_linear_id(); + } else if constexpr (Scope == execution_scope::sub_group) { + size_t WIId = it.get_global_linear_id(); + size_t SGSize = it.get_sub_group().get_local_linear_range(); + return WIId / SGSize; + } else if constexpr (Scope == execution_scope::work_group) { + return it.get_group().get_group_linear_id(); + } + } + + template + std::enable_if_t<(Scope == execution_scope::work_item || + Scope == execution_scope::work_group), + range> + get_range() const { + if constexpr (Scope == execution_scope::work_item) + return it.get_global_range(); + else if constexpr (Scope == execution_scope::work_group) + return it.get_group().get_group_range(); + } + + template + std::enable_if_t> + get_range() const { + return get_linear_range(); + } + + template + checkScopeTy get_linear_range() const { + if constexpr (Scope == execution_scope::work_item) { + range Range = it.get_global_range(); + size_t linRange = 1; + for (int i = 0; i < Dimensions; ++i) + linRange *= Range[i]; + return linRange; + } else if constexpr (Scope == execution_scope::sub_group) { + uint32_t NumWG = it.get_group().get_group_linear_range(); + uint32_t NumSGPerWG = it.get_sub_group().get_group_linear_range(); + return NumWG * NumSGPerWG; + } else if constexpr (Scope == execution_scope::work_group) { + return it.get_group().get_group_linear_range(); + } + } + private: friend root_group nd_item::ext_oneapi_get_root_group() const; diff --git a/sycl/test-e2e/GroupAlgorithm/root_group.cpp b/sycl/test-e2e/GroupAlgorithm/root_group.cpp index ba0c49fa68bf7..3a1e78b1d3ce6 100644 --- a/sycl/test-e2e/GroupAlgorithm/root_group.cpp +++ b/sycl/test-e2e/GroupAlgorithm/root_group.cpp @@ -75,7 +75,7 @@ void testRootGroupFunctions() { const auto props = sycl::ext::oneapi::experimental::properties{ sycl::ext::oneapi::experimental::use_root_sync}; - constexpr int testCount = 10; + constexpr int testCount = 22; bool *testResults = sycl::malloc_shared(testCount, q); const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize}; q.parallel_for( @@ -107,6 +107,57 @@ void testRootGroupFunctions() { sycl::sub_group>, "get_child_group(sycl::group) must return a sycl::sub_group"); } + + auto SG = it.get_sub_group(); + size_t SGSize = SG.get_local_linear_range(); + using execution_scope = + sycl::ext::oneapi::experimental::execution_scope; + + if (root.leader()) { + size_t NumSGPerWG = SG.get_group_linear_range(); + + testResults[10] = + root.template get_range()[0] == + maxWGs; + testResults[11] = + root.template get_range()[0] == + NumSGPerWG * maxWGs; + testResults[12] = + root.template get_range()[0] == + maxWGs * WorkGroupSize; + + testResults[13] = + root.template get_linear_range() == + maxWGs; + testResults[14] = + root.template get_linear_range() == + NumSGPerWG * maxWGs; + testResults[15] = + root.template get_linear_range() == + maxWGs * WorkGroupSize; + } + + if (root.get_local_id() == 3) { + testResults[16] = + root.template get_id() == + it.get_global_linear_id() / WorkGroupSize; + testResults[17] = + root.template get_id() == + it.get_global_linear_id() / SGSize; + testResults[18] = + root.template get_id() == + it.get_global_linear_id(); + + testResults[19] = + root.template get_linear_id() == + it.get_global_linear_id() / WorkGroupSize; + testResults[20] = + root.template get_linear_id() == + it.get_global_linear_id() / SGSize; + testResults[21] = + root.template get_linear_id() == + it.get_global_linear_id(); + } }); q.wait();