Skip to content

Commit ce5914f

Browse files
Simplify write-out kernels in topk implementation (avoid recomputing gid)
1 parent d4f5aa4 commit ce5914f

File tree

1 file changed

+14
-14
lines changed
  • dpctl/tensor/libtensor/include/kernels/sorting

1 file changed

+14
-14
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -417,17 +417,17 @@ sycl::event topk_merge_impl(
417417
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
418418

419419
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
420-
std::size_t gid = id[0];
420+
const std::size_t gid = id[0];
421421

422-
std::size_t iter_gid = gid / k;
423-
std::size_t axis_gid = gid - (iter_gid * k);
422+
const std::size_t iter_gid = gid / k;
423+
const std::size_t axis_gid = gid - (iter_gid * k);
424424

425-
std::size_t src_idx = iter_gid * alloc_len + axis_gid;
426-
std::size_t dst_idx = iter_gid * k + axis_gid;
425+
const std::size_t src_idx = iter_gid * alloc_len + axis_gid;
426+
const std::size_t dst_idx = gid;
427427

428-
auto res_ind = index_data[src_idx];
428+
const auto res_ind = index_data[src_idx];
429429
vals_tp[dst_idx] = arg_tp[res_ind];
430-
inds_tp[dst_idx] = res_ind % axis_nelems;
430+
inds_tp[dst_idx] = (res_ind % axis_nelems);
431431
});
432432
});
433433

@@ -529,17 +529,17 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
529529
using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
530530

531531
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
532-
std::size_t gid = id[0];
532+
const std::size_t gid = id[0];
533533

534-
std::size_t iter_gid = gid / k;
535-
std::size_t axis_gid = gid - (iter_gid * k);
534+
const std::size_t iter_gid = gid / k;
535+
const std::size_t axis_gid = gid - (iter_gid * k);
536536

537-
std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
538-
std::size_t dst_idx = iter_gid * k + axis_gid;
537+
const std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
538+
const std::size_t dst_idx = gid;
539539

540-
IndexTy res_ind = tmp_tp[src_idx];
540+
const IndexTy res_ind = tmp_tp[src_idx];
541541
vals_tp[dst_idx] = arg_tp[res_ind];
542-
inds_tp[dst_idx] = res_ind % axis_nelems;
542+
inds_tp[dst_idx] = (res_ind % axis_nelems);
543543
});
544544
});
545545

0 commit comments

Comments
 (0)