diff --git a/tests/group_functions/group_scan.h b/tests/group_functions/group_scan.h index d7c81d1d5..593ace793 100644 --- a/tests/group_functions/group_scan.h +++ b/tests/group_functions/group_scan.h @@ -25,6 +25,9 @@ template class joint_scan_group_kernel; +// This should never be higher than std::numeric_limits::max() for the +// smallest type tested. Currently, the smallest type tested is +// char/int8_t, so it shouldn't be higher than 127. constexpr int init = 42; constexpr size_t test_size = 12; @@ -54,14 +57,25 @@ auto joint_exclusive_scan_helper(Group group, T* v_begin, T* v_end, op); } -template +template struct JointScanDataStruct { - JointScanDataStruct(size_t range_size) + JointScanDataStruct(size_t range_size, OpT op, bool with_init) : ref_input(range_size), res(range_size * 4, U(-1)) { std::iota(ref_input.begin(), ref_input.end(), T(1)); + if constexpr (std::is_same_v> || + std::is_same_v>) { + auto identity = sycl::known_identity_v; + auto acc = with_init ? I{init} : identity; + for (size_t i = 0; i < range_size; ++i) { + I tmp = op(I(acc), I(ref_input[i])); + if (tmp > std::numeric_limits::max()) { + ref_input[i] = identity; + } + acc = op(acc, ref_input[i]); + } + } } - template void check_results(size_t range_size, OpT op, const std::string& op_name, bool with_init) { CHECK(end[0]); @@ -128,7 +142,7 @@ template void check_scan(sycl::queue& queue, size_t size, sycl::nd_range executionRange, OpT op, const std::string& op_name, bool with_init) { - JointScanDataStruct host_data{size}; + JointScanDataStruct host_data{size, op, with_init}; { sycl::buffer ref_input_sycl = host_data.create_ref_input_buffer(); sycl::buffer res_sycl = host_data.create_res_buffer(); @@ -180,7 +194,7 @@ void check_scan(sycl::queue& queue, size_t size, .wait_and_throw(); } - host_data.template check_results(size, op, op_name, with_init); + host_data.check_results(size, op, op_name, with_init); } /** @@ -434,8 +448,6 @@ template void check_scan_over_group(sycl::queue& queue, sycl::range range, OpT op, const std::string& op_name, bool with_init) { auto range_size = range.size(); - REQUIRE(((range_size * (range_size + 1) / 2) + T(init)) <= - std::numeric_limits::max()); ScanOverGroupDataStruct host_data{range_size}; {