Skip to content

Commit

Permalink
Merge pull request #838 from maarquitos14/maronas/fix_scan_overflow
Browse files Browse the repository at this point in the history
Fix overflow issues in scan tests.
  • Loading branch information
bader authored Dec 5, 2023
2 parents 3b98e3e + f7ef992 commit d3d479a
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tests/group_functions/group_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
// Forward declaration only; no definition is visible in this excerpt.
// NOTE(review): presumably serves as the SYCL kernel name type for the joint
// scan tests, parameterized on dimension, types and operation — confirm
// against the kernel submission code.
template <int D, typename T, typename U, typename I, typename OpT>
class joint_scan_group_kernel;

// Initial value used by the scan tests when `with_init` is true.
// It must never be higher than std::numeric_limits<T>::max() for the
// smallest type tested. Currently, the smallest type tested is
// char/int8_t, so it must not be higher than 127.
constexpr int init = 42;
// NOTE(review): presumably the number of elements used per test — confirm
// against the usages of test_size.
constexpr size_t test_size = 12;

Expand Down Expand Up @@ -54,14 +57,25 @@ auto joint_exclusive_scan_helper(Group group, T* v_begin, T* v_end,
op);
}

template <typename T, typename U>
template <typename T, typename U, typename I, typename OpT>
struct JointScanDataStruct {
JointScanDataStruct(size_t range_size)
JointScanDataStruct(size_t range_size, OpT op, bool with_init)
: ref_input(range_size), res(range_size * 4, U(-1)) {
std::iota(ref_input.begin(), ref_input.end(), T(1));
if constexpr (std::is_same_v<OpT, sycl::multiplies<I>> ||
std::is_same_v<OpT, sycl::plus<I>>) {
auto identity = sycl::known_identity_v<OpT, I>;
auto acc = with_init ? I{init} : identity;
for (size_t i = 0; i < range_size; ++i) {
I tmp = op(I(acc), I(ref_input[i]));
if (tmp > std::numeric_limits<U>::max()) {
ref_input[i] = identity;
}
acc = op(acc, ref_input[i]);
}
}
}

template <typename I, typename OpT>
void check_results(size_t range_size, OpT op, const std::string& op_name,
bool with_init) {
CHECK(end[0]);
Expand Down Expand Up @@ -128,7 +142,7 @@ template <int D, typename T, typename U, typename I = U, typename OpT>
void check_scan(sycl::queue& queue, size_t size,
sycl::nd_range<D> executionRange, OpT op,
const std::string& op_name, bool with_init) {
JointScanDataStruct<T, U> host_data{size};
JointScanDataStruct<T, U, I, OpT> host_data{size, op, with_init};
{
sycl::buffer<T, 1> ref_input_sycl = host_data.create_ref_input_buffer();
sycl::buffer<U, 1> res_sycl = host_data.create_res_buffer();
Expand Down Expand Up @@ -180,7 +194,7 @@ void check_scan(sycl::queue& queue, size_t size,
.wait_and_throw();
}

host_data.template check_results<I>(size, op, op_name, with_init);
host_data.check_results(size, op, op_name, with_init);
}

/**
Expand Down Expand Up @@ -434,8 +448,6 @@ template <int D, typename T, typename U = T, typename OpT>
void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
const std::string& op_name, bool with_init) {
auto range_size = range.size();
REQUIRE(((range_size * (range_size + 1) / 2) + T(init)) <=
std::numeric_limits<T>::max());

ScanOverGroupDataStruct<T, U> host_data{range_size};
{
Expand Down

0 comments on commit d3d479a

Please sign in to comment.