Skip to content

Commit

Permalink
[SYCL] Optimize user-defined reductions on trivial types
Browse files Browse the repository at this point in the history
Technically, the optimization should apply to all trivially copyable types, but
shift_group_left seems overly restrictive and requires trivial types.
  • Loading branch information
aelovikov-intel committed Jan 20, 2024
1 parent 7b62154 commit 3e3b384
Showing 1 changed file with 80 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

#pragma once

#include <sycl/builtins.hpp> // for sycl::min
#include <sycl/detail/defines.hpp>
#include <sycl/group_algorithm.hpp>
#include <sycl/sycl_span.hpp>

namespace sycl {
inline namespace _V1 {
Expand All @@ -19,18 +21,87 @@ template <typename GroupHelper, typename T, typename BinaryOperation>
T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements,
BinaryOperation binary_op) {
#ifdef __SYCL_DEVICE_ONLY__
T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
auto g = group_helper.get_group();
Memory[g.get_local_linear_id()] = x;
group_barrier(g);
T result = Memory[0];
if (g.leader()) {
for (int i = 1; i < num_elements; i++) {
result = binary_op(result, Memory[i]);
// IMPORTANT: num_elements is *guaranteed* to be less than
// g.get_local_linear_id()!

// It seems shift_group_left is overly restrictive and requires trivial types
// instead of trivially copyable.
if constexpr (sycl::detail::is_sub_group<decltype(g)>::value &&
std::is_trivial_v<T>) {
// sycl::ext::oneapi::sub_group isn't sycl::sub_group, and shift_group_left
// only accepts the latter.
auto ndi = sycl::ext::oneapi::experimental::this_nd_item<
decltype(g)::dimensions>();
auto sg = ndi.get_sub_group();
for (size_t offset = num_elements / 2; offset > 0; offset /= 2) {
auto y = shift_group_left(sg, x, offset);
// y is unspecified for the work items with id higher than offset. In
// theory, it might have a value that would cause unwanted side effects in
// binary_op, so only apply the operation in the work items that are using
// "live" values.
if (sg.get_local_linear_id() < offset)
x = binary_op(x, y);
}
return group_broadcast(sg, x);
} else {
T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
if constexpr (sycl::detail::is_group<decltype(g)>::value &&
std::is_trivial_v<T>) {
// TODO: Use get_child_group from sycl_ext_oneapi_root_group extension
// once it is implemented instead of this free function.
auto ndi = sycl::ext::oneapi::experimental::this_nd_item<
decltype(g)::dimensions>();
auto sg = ndi.get_sub_group();

auto lid = g.get_local_linear_id();
auto sg_lid = sg.get_local_linear_id();
auto sg_max_size = sg.get_max_local_range()[0];
auto sg_size = sg.get_local_linear_range();

auto sg_leader_lid = group_broadcast(sg, lid);

do {
if (num_elements <= 1) {
return group_broadcast(g, x);
}

if (num_elements % sg_max_size != 0) {
// The way work group is split into sub-groups is implementation
// defined and handling it in generic way would complicate things.
break;
}

auto sg_num_elements =
sg_leader_lid < num_elements
? sycl::min<size_t>(sg_size, num_elements - sg_leader_lid)
: 0;
group_barrier(g);
if (sg_num_elements > 0) {
auto reduce_sg = reduce_over_group_impl(
group_with_scratchpad<sycl::sub_group, 0>{
sg, sycl::span<std::byte, 0>{}},
x, sg_num_elements, binary_op);
Memory[sg.get_group_linear_id()] = reduce_sg;
}
group_barrier(g);
num_elements = num_elements / sg_max_size;
if (lid < num_elements)
x = Memory[lid];
} while (true);
}

Memory[g.get_local_linear_id()] = x;
group_barrier(g);
T result = Memory[0];
if (g.leader()) {
for (int i = 1; i < num_elements; i++) {
result = binary_op(result, Memory[i]);
}
}
group_barrier(g);
return group_broadcast(g, result);
}
group_barrier(g);
return group_broadcast(g, result);
#else
std::ignore = group_helper;
std::ignore = x;
Expand Down

0 comments on commit 3e3b384

Please sign in to comment.