Skip to content

Commit

Permalink
Add a comment about constant choice
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk committed Dec 18, 2024
1 parent dd9a873 commit 60f60c6
Showing 1 changed file with 21 additions and 24 deletions.
45 changes: 21 additions & 24 deletions dpctl/tensor/libtensor/include/kernels/accumulators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ inclusive_scan_base_step_blocked(sycl::queue &exec_q,

outputT wg_iscan_val;
if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
outputT>::value)
{
outputT>::value) {
wg_iscan_val = sycl::inclusive_scan_over_group(
it.get_group(), local_iscan.back(), scan_op, identity);
}
Expand Down Expand Up @@ -447,8 +446,7 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,

outputT wg_iscan_val;
if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
outputT>::value)
{
outputT>::value) {
wg_iscan_val = sycl::inclusive_scan_over_group(
it.get_group(), local_iscan.back(), scan_op, identity);
}
Expand All @@ -472,35 +470,32 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
it.barrier(sycl::access::fence_space::local_space);

// convert back to blocked layout
{
{
const std::uint32_t local_offset0 = lid * n_wi;
{{const std::uint32_t local_offset0 = lid * n_wi;
#pragma unroll
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
}
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
}

it.barrier(sycl::access::fence_space::local_space);
it.barrier(sycl::access::fence_space::local_space);
}
}

{
const std::uint32_t block_offset =
sgroup_id * sgSize * n_wi + lane_id;
const std::uint32_t block_offset = sgroup_id * sgSize * n_wi + lane_id;
#pragma unroll
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
const std::uint32_t m_wi_scaled = m_wi * sgSize;
const std::size_t out_id = inp_id0 + m_wi_scaled;
if (out_id < acc_nelems) {
output[out_iter_offset + out_indexer(out_id)] =
slm_iscan_tmp[block_offset + m_wi_scaled];
}
}
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
const std::uint32_t m_wi_scaled = m_wi * sgSize;
const std::size_t out_id = inp_id0 + m_wi_scaled;
if (out_id < acc_nelems) {
output[out_iter_offset + out_indexer(out_id)] =
slm_iscan_tmp[block_offset + m_wi_scaled];
}
});
});
}
}
});
});

return inc_scan_phase1_ev;
return inc_scan_phase1_ev;
}

template <typename inputT,
Expand Down Expand Up @@ -530,6 +525,8 @@ inclusive_scan_base_step(sycl::queue &exec_q,
std::size_t &acc_groups,
const std::vector<sycl::event> &depends = {})
{
// For small stride use striped load/store.
// Threshold value chosen experimentally.
if (s1 <= 16) {
return inclusive_scan_base_step_striped<
inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
Expand Down

0 comments on commit 60f60c6

Please sign in to comment.