Skip to content

Commit 77a25db

Browse files
[SYCL] Generalize group_algorithm helpers (#12726)
This commit generalizes two helper functions in group_algorithm.hpp to make it so they can also handle non-uniform groups. --------- Signed-off-by: Larsen, Steffen <steffen.larsen@intel.com>
1 parent c90de3c commit 77a25db

File tree

1 file changed

+13
-30
lines changed

1 file changed

+13
-30
lines changed

sycl/include/sycl/group_algorithm.hpp

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -59,42 +59,25 @@ template <> inline id<3> linear_id_to_id(range<3> r, size_t linear_id) {
5959
}
6060

6161
// ---- get_local_linear_range
62-
template <typename Group> size_t get_local_linear_range(Group g);
63-
template <> inline size_t get_local_linear_range<group<1>>(group<1> g) {
64-
return g.get_local_range(0);
65-
}
66-
template <> inline size_t get_local_linear_range<group<2>>(group<2> g) {
67-
return g.get_local_range(0) * g.get_local_range(1);
68-
}
69-
template <> inline size_t get_local_linear_range<group<3>>(group<3> g) {
70-
return g.get_local_range(0) * g.get_local_range(1) * g.get_local_range(2);
71-
}
72-
template <>
73-
inline size_t get_local_linear_range<sycl::sub_group>(sycl::sub_group g) {
74-
return g.get_local_range()[0];
62+
template <typename Group> inline auto get_local_linear_range(Group g) {
63+
auto local_range = g.get_local_range();
64+
auto result = local_range[0];
65+
for (size_t i = 1; i < Group::dimensions; ++i)
66+
result *= local_range[i];
67+
return result;
7568
}
7669

7770
// ---- get_local_linear_id
78-
template <typename Group>
79-
inline typename Group::linear_id_type get_local_linear_id(Group g);
80-
71+
template <typename Group> inline auto get_local_linear_id(Group g) {
8172
#ifdef __SYCL_DEVICE_ONLY__
82-
#define __SYCL_GROUP_GET_LOCAL_LINEAR_ID(D) \
83-
template <> \
84-
inline group<D>::linear_id_type get_local_linear_id<group<D>>(group<D>) { \
85-
nd_item<D> it = sycl::detail::Builder::getNDItem<D>(); \
86-
return it.get_local_linear_id(); \
73+
if constexpr (std::is_same_v<Group, group<1>> ||
74+
std::is_same_v<Group, group<2>> ||
75+
std::is_same_v<Group, group<3>>) {
76+
auto it = sycl::detail::Builder::getNDItem<Group::dimensions>();
77+
return it.get_local_linear_id();
8778
}
88-
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(1);
89-
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(2);
90-
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(3);
91-
#undef __SYCL_GROUP_GET_LOCAL_LINEAR_ID
9279
#endif // __SYCL_DEVICE_ONLY__
93-
94-
template <>
95-
inline sycl::sub_group::linear_id_type
96-
get_local_linear_id<sycl::sub_group>(sycl::sub_group g) {
97-
return g.get_local_id()[0];
80+
return g.get_local_linear_id();
9881
}
9982

10083
// ---- is_native_op

0 commit comments

Comments
 (0)