diff --git a/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp b/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp index a21f3daa47b1a..8cbd24b302171 100644 --- a/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp @@ -11,6 +11,23 @@ template class matrix_process; +template +void reduce_and_accumulate(sub_group sg, size_t sg_size, size_t global_idy, + AccessorType &global_acc, TResult *local_sums, + size_t count) { + for (size_t i = 0; i < count; i++) { + local_sums[i] = reduce_over_group(sg, local_sums[i], sycl::plus<>()); + + // Only the subgroup leader performs the global accumulation + if (global_idy % sg_size == 0) { + sycl::atomic_ref + aref(global_acc[i]); + aref.fetch_add(local_sums[i]); + } + } +} + template void matrix_sum(big_matrix &M, @@ -32,7 +49,7 @@ void matrix_sum(big_matrix &M, {1, 1 * sg_size}), [=](nd_item<2> spmd_item) #ifdef SG_SZ - [[intel::reqd_sub_group_size(SG_SZ)]] + [[sycl::reqd_sub_group_size(SG_SZ)]] #endif { // The submatrix API has to be accessed by all the workitems in a @@ -83,29 +100,10 @@ void matrix_sum(big_matrix &M, }); } - for (int i = 0; i < NUM_ROWS; i++) { - sum_local_rows[i] = - reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); - // only Groups leader perform the global reduction - if (global_idy % sg_size == 0) { - sycl::atomic_ref - aref(v_rows[i]); - aref.fetch_add(sum_local_rows[i]); - } - } - - for (int i = 0; i < NUM_COLS; i++) { - sum_local_cols[i] = - reduce_over_group(sg, sum_local_cols[i], sycl::plus<>()); - // only Groups leader perform the global reduction - if (global_idy % sg_size == 0) { - sycl::atomic_ref - aref(v_cols[i]); - aref.fetch_add(sum_local_cols[i]); - } - } + reduce_and_accumulate(sg, sg_size, global_idy, v_rows, + sum_local_rows, NUM_ROWS); + reduce_and_accumulate(sg, sg_size, global_idy, v_cols, + sum_local_cols, NUM_COLS); }); // parallel for }).wait(); } @@ -124,11 +122,7 @@ void test_get_coord_op() { TResult sum_cols[Cols] = {0}; TResult sum_cols_ref[Cols] = {0}; - for (int i = 0; i < Rows; i++) { - for (int j = 0; j < Cols; j++) { - M[i][j] = i + j; - } - } + matrix_fill(Rows, Cols, (T *)M, [](int i, int j) { return T(i + j); }); matrix_vnni(Rows, Cols, *M, *Mvnni, VF); big_matrix MM((T *)&Mvnni);