diff --git a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp index e8edd827c2..76d51cf3b6 100644 --- a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp +++ b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp @@ -273,8 +273,7 @@ inclusive_scan_base_step_blocked(sycl::queue &exec_q, outputT wg_iscan_val; if constexpr (can_use_inclusive_scan_over_group::value) - { + outputT>::value) { wg_iscan_val = sycl::inclusive_scan_over_group( it.get_group(), local_iscan.back(), scan_op, identity); } @@ -447,8 +446,7 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q, outputT wg_iscan_val; if constexpr (can_use_inclusive_scan_over_group::value) - { + outputT>::value) { wg_iscan_val = sycl::inclusive_scan_over_group( it.get_group(), local_iscan.back(), scan_op, identity); } @@ -472,35 +470,32 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q, it.barrier(sycl::access::fence_space::local_space); // convert back to blocked layout - { - { - const std::uint32_t local_offset0 = lid * n_wi; + {{const std::uint32_t local_offset0 = lid * n_wi; #pragma unroll - for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) { - slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi]; - } + for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) { + slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi]; + } - it.barrier(sycl::access::fence_space::local_space); + it.barrier(sycl::access::fence_space::local_space); } } { - const std::uint32_t block_offset = - sgroup_id * sgSize * n_wi + lane_id; + const std::uint32_t block_offset = sgroup_id * sgSize * n_wi + lane_id; #pragma unroll - for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) { - const std::uint32_t m_wi_scaled = m_wi * sgSize; - const std::size_t out_id = inp_id0 + m_wi_scaled; - if (out_id < acc_nelems) { - output[out_iter_offset + out_indexer(out_id)] = - slm_iscan_tmp[block_offset + m_wi_scaled]; - } - } + for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) { + const std::uint32_t m_wi_scaled = m_wi * sgSize; + const std::size_t out_id = inp_id0 + m_wi_scaled; + if (out_id < acc_nelems) { + output[out_iter_offset + out_indexer(out_id)] = + slm_iscan_tmp[block_offset + m_wi_scaled]; } - }); - }); + } + } +}); +}); - return inc_scan_phase1_ev; +return inc_scan_phase1_ev; } template &depends = {}) { + // For small stride use striped load/store. + // Threshold value chosen experimentally. if (s1 <= 16) { return inclusive_scan_base_step_striped< inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,