// Patch summary for sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp.
//
// Hunk 1 (@@ -11,6 +11,23 @@): adds the function template
//   reduce_and_accumulate<TResult, AccessorType>(sg, sg_size, global_idy,
//                                                global_acc, local_sums, count)
//   For every i in [0, count) it
//     1. overwrites local_sums[i] with
//        reduce_over_group(sg, local_sums[i], sycl::plus<>()), and
//     2. only when global_idy % sg_size == 0, adds that reduced value to
//        global_acc[i] through a sycl::atomic_ref declared with
//        memory_order::relaxed and memory_scope::device (fetch_add).
//   The reduce_over_group call sits outside the `if`, so it is executed by
//   every caller of the helper; only the fetch_add is guarded.
//
// Hunk 2 (@@ -32,7 +49,7 @@): inside the `#ifdef SG_SZ` guard of matrix_sum's
//   kernel lambda, replaces [[intel::reqd_sub_group_size(SG_SZ)]] with
//   [[sycl::reqd_sub_group_size(SG_SZ)]].
//
// Hunk 3 (@@ -83,29 +100,10 @@): removes the two hand-written loops in
//   matrix_sum (one over NUM_ROWS using sum_local_rows/v_rows, one over
//   NUM_COLS using sum_local_cols/v_cols) and replaces them with two calls to
//   reduce_and_accumulate. The removed loops indexed with `int i`; the helper
//   indexes with `size_t i` and takes the bound as a size_t `count`.
//
// Hunk 4 (@@ -124,11 +122,7 @@): in test_get_coord_op, replaces the nested
//   loop `M[i][j] = i + j` with
//   matrix_fill(Rows, Cols, (T *)M, [](int i, int j) { return T(i + j); }).
//
// NOTE(review): matrix_fill is not defined anywhere in this patch; presumably
//   it comes from a shared test header that is already included - confirm.
// NOTE(review): the patch text below appears to have had its line breaks
//   collapsed into spaces, so it presumably will not apply as-is - confirm
//   against the original commit before using it.
// NOTE(review): the helper takes `AccessorType &global_acc`, while the call
//   sites pass v_rows / v_cols from inside a `[=]` lambda; their declarations
//   are outside this patch - verify that they bind to this parameter.
diff --git a/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp b/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp index a21f3daa47b1a..8cbd24b302171 100644 --- a/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coordinate_ops_impl.hpp @@ -11,6 +11,23 @@ template <typename T, size_t Rows, size_t Cols, layout Layout, use Use> class matrix_process; +template <typename TResult, typename AccessorType> +void reduce_and_accumulate(sub_group sg, size_t sg_size, size_t global_idy, + AccessorType &global_acc, TResult *local_sums, + size_t count) { + for (size_t i = 0; i < count; i++) { + local_sums[i] = reduce_over_group(sg, local_sums[i], sycl::plus<>()); + + // Only the subgroup leader performs the global accumulation + if (global_idy % sg_size == 0) { + sycl::atomic_ref<TResult, sycl::memory_order::relaxed, + sycl::memory_scope::device> + aref(global_acc[i]); + aref.fetch_add(local_sums[i]); + } + } +} + template <typename T, typename TResult, size_t NUM_ROWS, size_t NUM_COLS, size_t SROWS, size_t SCOLS, use Use, layout Layout, size_t VF> void matrix_sum(big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> &M, @@ -32,7 +49,7 @@ void matrix_sum(big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> &M, {1, 1 * sg_size}), [=](nd_item<2> spmd_item) #ifdef SG_SZ - [[intel::reqd_sub_group_size(SG_SZ)]] + [[sycl::reqd_sub_group_size(SG_SZ)]] #endif { // The submatrix API has to be accessed by all the workitems in a @@ -83,29 +100,10 @@ void matrix_sum(big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> &M, }); } - for (int i = 0; i < NUM_ROWS; i++) { - sum_local_rows[i] = - reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); - // only Groups leader perform the global reduction - if (global_idy % sg_size == 0) { - sycl::atomic_ref<TResult, sycl::memory_order::relaxed, - sycl::memory_scope::device> - aref(v_rows[i]); - aref.fetch_add(sum_local_rows[i]); - } - } - - for (int i = 0; i < NUM_COLS; i++) { - sum_local_cols[i] = - reduce_over_group(sg, sum_local_cols[i], 
sycl::plus<>()); - // only Groups leader perform the global reduction - if (global_idy % sg_size == 0) { - sycl::atomic_ref<TResult, sycl::memory_order::relaxed, - sycl::memory_scope::device> - aref(v_cols[i]); - aref.fetch_add(sum_local_cols[i]); - } - } + reduce_and_accumulate(sg, sg_size, global_idy, v_rows, + sum_local_rows, NUM_ROWS); + reduce_and_accumulate(sg, sg_size, global_idy, v_cols, + sum_local_cols, NUM_COLS); }); // parallel for }).wait(); } @@ -124,11 +122,7 @@ void test_get_coord_op() { TResult sum_cols[Cols] = {0}; TResult sum_cols_ref[Cols] = {0}; - for (int i = 0; i < Rows; i++) { - for (int j = 0; j < Cols; j++) { - M[i][j] = i + j; - } - } + matrix_fill(Rows, Cols, (T *)M, [](int i, int j) { return T(i + j); }); matrix_vnni<T>(Rows, Cols, *M, *Mvnni, VF); big_matrix<T, Rows / VF, Cols * VF> MM((T *)&Mvnni);