diff --git a/sycl/include/sycl/ext/oneapi/experimental/user_defined_reductions.hpp b/sycl/include/sycl/ext/oneapi/experimental/user_defined_reductions.hpp index c7993175556f1..0b3a935cacbfc 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/user_defined_reductions.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/user_defined_reductions.hpp @@ -8,8 +8,10 @@ #pragma once +#include // for sycl::min #include #include +#include namespace sycl { inline namespace _V1 { @@ -19,18 +21,87 @@ template T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements, BinaryOperation binary_op) { #ifdef __SYCL_DEVICE_ONLY__ - T *Memory = reinterpret_cast(group_helper.get_memory().data()); auto g = group_helper.get_group(); - Memory[g.get_local_linear_id()] = x; - group_barrier(g); - T result = Memory[0]; - if (g.leader()) { - for (int i = 1; i < num_elements; i++) { - result = binary_op(result, Memory[i]); + // IMPORTANT: num_elements is *guaranteed* to be less than + // g.get_local_linear_id()! + + // It seems shift_group_left is overly restrictive and requires trivial types + // instead of trivially copyable. + if constexpr (sycl::detail::is_sub_group::value && + std::is_trivial_v) { + // sycl::ext::oneapi::sub_group isn't sycl::sub_group, and shift_group_left + // only accepts the latter. + auto ndi = sycl::ext::oneapi::experimental::this_nd_item< + decltype(g)::dimensions>(); + auto sg = ndi.get_sub_group(); + for (size_t offset = num_elements / 2; offset > 0; offset /= 2) { + auto y = shift_group_left(sg, x, offset); + // y is unspecified for the work items with id higher than offset. In + // theory, it might have a value that would cause unwanted side effects in + // binary_op, so only apply the operation in the work items that are using + // "live" values. + if (sg.get_local_linear_id() < offset) + x = binary_op(x, y); } + return group_broadcast(sg, x); + } else { + T *Memory = reinterpret_cast(group_helper.get_memory().data()); + if constexpr (sycl::detail::is_group::value && + std::is_trivial_v) { + // TODO: Use get_child_group from sycl_ext_oneapi_root_group extension + // once it is implemented instead of this free function. + auto ndi = sycl::ext::oneapi::experimental::this_nd_item< + decltype(g)::dimensions>(); + auto sg = ndi.get_sub_group(); + + auto lid = g.get_local_linear_id(); + auto sg_lid = sg.get_local_linear_id(); + auto sg_max_size = sg.get_max_local_range()[0]; + auto sg_size = sg.get_local_linear_range(); + + auto sg_leader_lid = group_broadcast(sg, lid); + + do { + if (num_elements <= 1) { + return group_broadcast(g, x); + } + + if (num_elements % sg_max_size != 0) { + // The way work group is split into sub-groups is implementation + // defined and handling it in generic way would complicate things. + break; + } + + auto sg_num_elements = + sg_leader_lid < num_elements + ? sycl::min(sg_size, num_elements - sg_leader_lid) + : 0; + group_barrier(g); + if (sg_num_elements > 0) { + auto reduce_sg = reduce_over_group_impl( + group_with_scratchpad{ + sg, sycl::span{}}, + x, sg_num_elements, binary_op); + Memory[sg.get_group_linear_id()] = reduce_sg; + } + group_barrier(g); + num_elements = num_elements / sg_max_size; + if (lid < num_elements) + x = Memory[lid]; + } while (true); + } + + Memory[g.get_local_linear_id()] = x; + group_barrier(g); + T result = Memory[0]; + if (g.leader()) { + for (int i = 1; i < num_elements; i++) { + result = binary_op(result, Memory[i]); + } + } + group_barrier(g); + return group_broadcast(g, result); } - group_barrier(g); - return group_broadcast(g, result); #else std::ignore = group_helper; std::ignore = x;