Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Optimize user-defined reductions on trivial types #12314

Draft
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

#pragma once

#include <sycl/builtins.hpp> // for sycl::min
#include <sycl/detail/defines.hpp>
#include <sycl/group_algorithm.hpp>
#include <sycl/sycl_span.hpp>

namespace sycl {
inline namespace _V1 {
Expand All @@ -19,18 +21,87 @@ template <typename GroupHelper, typename T, typename BinaryOperation>
T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements,
BinaryOperation binary_op) {
#ifdef __SYCL_DEVICE_ONLY__
T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
auto g = group_helper.get_group();
Memory[g.get_local_linear_id()] = x;
group_barrier(g);
T result = Memory[0];
if (g.leader()) {
for (int i = 1; i < num_elements; i++) {
result = binary_op(result, Memory[i]);
// IMPORTANT: num_elements is *guaranteed* to be less than
// g.get_local_linear_id()!

// It seems shift_group_left is overly restrictive and requires trivial types
// instead of trivially copyable.
if constexpr (sycl::detail::is_sub_group<decltype(g)>::value &&
std::is_trivial_v<T>) {
// sycl::ext::oneapi::sub_group isn't sycl::sub_group, and shift_group_left
// only accepts the latter.
auto ndi = sycl::ext::oneapi::experimental::this_nd_item<
decltype(g)::dimensions>();
auto sg = ndi.get_sub_group();
for (size_t offset = num_elements / 2; offset > 0; offset /= 2) {
auto y = shift_group_left(sg, x, offset);
// y is unspecified for the work items with id higher than offset. In
// theory, it might have a value that would cause unwanted side effects in
// binary_op, so only apply the operation in the work items that are using
// "live" values.
if (sg.get_local_linear_id() < offset)
x = binary_op(x, y);
}
return group_broadcast(sg, x);
} else {
T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
if constexpr (sycl::detail::is_group<decltype(g)>::value &&
std::is_trivial_v<T>) {
// TODO: Use get_child_group from sycl_ext_oneapi_root_group extension
// once it is implemented instead of this free function.
auto ndi = sycl::ext::oneapi::experimental::this_nd_item<
decltype(g)::dimensions>();
auto sg = ndi.get_sub_group();

auto lid = g.get_local_linear_id();
auto sg_lid = sg.get_local_linear_id();
auto sg_max_size = sg.get_max_local_range()[0];
auto sg_size = sg.get_local_linear_range();

auto sg_leader_lid = group_broadcast(sg, lid);

do {
if (num_elements <= 1) {
return group_broadcast(g, x);
}

if (num_elements % sg_max_size != 0) {
// The way work group is split into sub-groups is implementation
// defined and handling it in generic way would complicate things.
break;
}

auto sg_num_elements =
sg_leader_lid < num_elements
? sycl::min<size_t>(sg_size, num_elements - sg_leader_lid)
: 0;
group_barrier(g);
if (sg_num_elements > 0) {
auto reduce_sg = reduce_over_group_impl(
group_with_scratchpad<sycl::sub_group, 0>{
sg, sycl::span<std::byte, 0>{}},
x, sg_num_elements, binary_op);
Memory[sg.get_group_linear_id()] = reduce_sg;
}
group_barrier(g);
num_elements = num_elements / sg_max_size;
if (lid < num_elements)
x = Memory[lid];
} while (true);
}

Memory[g.get_local_linear_id()] = x;
group_barrier(g);
T result = Memory[0];
if (g.leader()) {
for (int i = 1; i < num_elements; i++) {
result = binary_op(result, Memory[i]);
}
}
group_barrier(g);
return group_broadcast(g, result);
}
group_barrier(g);
return group_broadcast(g, result);
#else
std::ignore = group_helper;
std::ignore = x;
Expand Down
Loading