Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[group] Refine scan_over_group for sub-group #839

Merged
merged 5 commits into from
Dec 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 81 additions & 42 deletions tests/group_functions/group_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ struct ScanOverGroupDataStruct {
ScanOverGroupDataStruct(size_t range_size)
: ref_input(range_size),
res(range_size * 4, T(-1)),
local_id(range_size * 2, 0) {
local_id(range_size, 0),
sub_group_id(range_size, 0) {
std::iota(ref_input.begin(), ref_input.end(), U(1));
}

Expand All @@ -337,35 +338,70 @@ struct ScanOverGroupDataStruct {
T init_value = with_init ? T(init) : sycl::known_identity<OpT, T>::value;
// res consists of 4 series of results: an exclusive and an inclusive scan
// result series made over 'group', followed by the same pair made over
// 'sub_group'.
for (int group_i = 0; group_i < 2; group_i++) {
std::string group_name = group_i == 0 ? "group" : "sub_group";
size_t group_offset = range_size * group_i;
{
std::vector<T> reference(range_size, T(-1));
// There is only one work-group so we can scan over all the input data.
std::exclusive_scan(ref_input.begin(), ref_input.end(), reference.begin(),
init_value, op);
for (int i = 0; i < range_size; i++) {
int shift = i - local_id[i + group_offset];
auto startIter = ref_input.begin() + shift;
// Each group contains two sets of results.
size_t res_i = i + 2 * group_offset;
int res_i = i;
INFO("Check exclusive_scan_over_group on group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[i]));
CHECK(res[res_i] == reference[i]);
}
std::inclusive_scan(ref_input.begin(), ref_input.end(), reference.begin(),
op, init_value);
for (int i = 0; i < range_size; i++) {
int res_i = range_size + i;
INFO("Check inclusive_scan_over_group on group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[i]));
CHECK(res[res_i] == reference[i]);
}
}
{
// Mapping from "sub-group id" to "vector of input data (ordered by item
// linear id within the sub-group)"
std::unordered_map<size_t, std::vector<T>> ref_input_per_sub_group;
for (int i = 0; i < range_size; i++) {
size_t sgid = sub_group_id[i];
size_t lid = local_id[i];
std::vector<T>& input_vec = ref_input_per_sub_group[sgid];
// Extend input vector dynamically.
if (input_vec.size() <= lid) input_vec.resize(lid + 1);
// Place the data identified by (sgid, lid).
input_vec[lid] = ref_input[i];
}
// Compute the reference results and verify.
for (int i = 0; i < range_size; i++) {
size_t sgid = sub_group_id[i];
size_t lid = local_id[i];
const std::vector<T>& input_vec = ref_input_per_sub_group[sgid];
// Scan over the first (lid + 1) elements of input_vec to obtain the
// result identified by i.
std::vector<T> reference(lid + 1, T(-1));
std::exclusive_scan(input_vec.begin(), input_vec.begin() + lid + 1,
reference.begin(), init_value, op);
{
INFO("Check exclusive_scan_over_group on " + group_name +
" for element " + std::to_string(i) + " (Operator: " + op_name +
")");
std::vector<T> reference(i + 1, T(-1));
std::exclusive_scan(startIter, ref_input.begin() + i + 1,
reference.begin(), init_value, op);
int res_i = range_size * 2 + i;
INFO("Check exclusive_scan_over_group on sub_group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[i - shift]));
CHECK(res[res_i] == reference[i - shift]);
INFO("Expected: " + std::to_string(reference[lid]));
CHECK(res[res_i] == reference[lid]);
}
std::inclusive_scan(input_vec.begin(), input_vec.begin() + lid + 1,
reference.begin(), op, init_value);
{
INFO("Check inclusive_scan_over_group on " + group_name +
" for element " + std::to_string(i) + " (Operator: " + op_name +
")");
std::vector<T> reference(i + 1, T(-1));
std::inclusive_scan(startIter, ref_input.begin() + i + 1,
reference.begin(), op, init_value);
INFO("Result: " + std::to_string(res[res_i + range_size]));
INFO("Expected: " + std::to_string(reference[i - shift]));
CHECK(res[res_i + range_size] == reference[i - shift]);
int res_i = range_size * 3 + i;
INFO("Check inclusive_scan_over_group on sub_group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[lid]));
CHECK(res[res_i] == reference[lid]);
}
}
}
Expand All @@ -383,10 +419,15 @@ struct ScanOverGroupDataStruct {
return {local_id.data(), local_id.size()};
}

// Builds a one-dimensional SYCL buffer of size_t over the storage of the
// sub_group_id vector, with one element per vector entry.
sycl::buffer<size_t, 1> create_sub_group_id_buffer() {
  size_t* const host_storage = sub_group_id.data();
  const size_t element_count = sub_group_id.size();
  return {host_storage, element_count};
}

std::vector<U> ref_input;
std::vector<T> res;
bool ret_type[4] = {false, false, false, false};
std::vector<size_t> local_id;
std::vector<size_t> sub_group_id;
};

template <int D, typename T, typename U = T, typename OpT>
Expand All @@ -402,6 +443,7 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
auto res_sycl = host_data.create_res_buffer();
auto ret_type_sycl = host_data.create_ret_type_buffer();
auto local_id_sycl = host_data.create_local_id_buffer();
auto sub_group_id_sycl = host_data.create_sub_group_id_buffer();

queue
.submit([&](sycl::handler& cgh) {
Expand All @@ -410,18 +452,14 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
sycl::accessor<T, 1> res_acc(res_sycl, cgh);
sycl::accessor<bool, 1> ret_type_acc(ret_type_sycl, cgh);
sycl::accessor<size_t, 1> local_id_acc(local_id_sycl, cgh);
sycl::accessor<size_t, 1> sub_group_id_acc(sub_group_id_sycl, cgh);

cgh.parallel_for<scan_over_group_kernel<D, T, U, OpT>>(
sycl::nd_range<D>(range, range), [=](sycl::nd_item<D> item) {
sycl::group<D> group = item.get_group();
sycl::sub_group sub_group = item.get_sub_group();

// Use the local id of the item in the group to place results of
// the scan operation in the order of the items.
auto g_index = group.get_group_linear_id() *
group.get_local_linear_range() +
group.get_local_linear_id();
local_id_acc[g_index] = group.get_local_linear_id();
auto g_index = item.get_global_linear_id();
steffenlarsen marked this conversation as resolved.
Show resolved Hide resolved

auto res_g_e = exclusive_scan_over_group_helper<T>(
group, ref_input_acc[g_index], op, with_init);
Expand All @@ -433,22 +471,23 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
res_acc[range_size + g_index] = res_g_i;
ret_type_acc[1] = std::is_same_v<T, decltype(res_g_i)>;

// Use the local id of the item in the sub-group to place
// results of the scan operation in the order of the items.
auto sg_index = sub_group.get_group_linear_id() *
sub_group.get_local_linear_range() +
sub_group.get_local_linear_id();
local_id_acc[range_size + sg_index] =
sub_group.get_local_linear_id();
// Input data is indexed by the global linear id of the item (g_index);
// however, sub-group partitioning and ordering are
// implementation-defined.
// We therefore store both the sub-group id and the item's linear id
// within the sub-group, so that the sub-group layout can be
// reconstructed during verification.
sub_group_id_acc[g_index] = sub_group.get_group_linear_id();
local_id_acc[g_index] = sub_group.get_local_linear_id();

auto res_sg_e = exclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 2 + sg_index] = res_sg_e;
sub_group, ref_input_acc[g_index], op, with_init);
res_acc[range_size * 2 + g_index] = res_sg_e;
ret_type_acc[2] = std::is_same_v<T, decltype(res_sg_e)>;

auto res_sg_i = inclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 3 + sg_index] = res_sg_i;
sub_group, ref_input_acc[g_index], op, with_init);
res_acc[range_size * 3 + g_index] = res_sg_i;
ret_type_acc[3] = std::is_same_v<T, decltype(res_sg_i)>;
});
})
Expand Down