Add iota kernel #1946

Merged: 15 commits, Dec 27, 2024

Changes from 7 commits
dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp (16 changes: 12 additions & 4 deletions)
@@ -33,6 +33,7 @@

#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/sorting/search_sorted_detail.hpp"
#include "kernels/sorting/sort_utils.hpp"

namespace dpctl
{
@@ -811,20 +812,27 @@ sycl::event stable_argsort_axis1_contig_impl(

const size_t total_nelems = iter_nelems * sort_nelems;

using dpctl::tensor::kernels::sort_utils_detail::iota_impl;

using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;

#if 1
sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
exec_q, res_tp, total_nelems, depends);

#else
sycl::event populate_indexed_data_ev =
exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

const sycl::range<1> range{total_nelems};

using KernelName =
populate_index_data_krn<argTy, IndexTy, ValueComp>;

cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
cgh.parallel_for<IotaKernelName>(range, [=](sycl::id<1> id) {
size_t i = id[0];
res_tp[i] = static_cast<IndexTy>(i);
});
});
#endif

// Sort segments of the array
sycl::event base_sort_ev =
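The hunk above replaces the hand-written index-populating parallel_for in stable_argsort_axis1_contig_impl with the shared iota_impl helper introduced in sort_utils.hpp; the previous submission is kept under the disabled #else branch. A minimal usage sketch of the helper, with a hypothetical kernel-name tag, queue, and allocation (illustrative only, not part of the diff):

#include <cstddef>
#include <cstdint>
#include <vector>

#include <sycl/sycl.hpp>

#include "kernels/sorting/sort_utils.hpp"

// Hypothetical kernel-name tag for this example.
class example_iota_krn;

void fill_indices_example(sycl::queue &q)
{
    constexpr std::size_t n = 1024;
    std::int64_t *idx = sycl::malloc_device<std::int64_t>(n, q);

    using dpctl::tensor::kernels::sort_utils_detail::iota_impl;

    // Writes 0, 1, ..., n - 1 into idx; the empty braced list means no
    // dependency events.
    sycl::event ev = iota_impl<example_iota_krn, std::int64_t>(q, idx, n, {});
    ev.wait();

    sycl::free(idx, q);
}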
dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp (117 changes: 44 additions & 73 deletions)
@@ -38,6 +38,7 @@
#include <sycl/sycl.hpp>

#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/sorting/sort_utils.hpp"
#include "utils/sycl_alloc_utils.hpp"

namespace dpctl
@@ -1256,9 +1257,7 @@ struct subgroup_radix_sort
const uint16_t id = wi * block_size + i;
if (id < n)
values[i] = std::move(
this_input_arr[iter_val_offset +
static_cast<std::size_t>(
id)]);
this_input_arr[iter_val_offset + id]);
}

while (true) {
@@ -1272,8 +1271,7 @@
// counting phase
auto pcounter =
get_accessor_pointer(counter_acc) +
static_cast<std::size_t>(wi) +
iter_counter_offset;
(wi + iter_counter_offset);

// initialize counters
#pragma unroll
@@ -1348,19 +1346,15 @@ struct subgroup_radix_sort

// scan contiguous numbers
uint16_t bin_sum[bin_count];
bin_sum[0] =
counter_acc[iter_counter_offset +
static_cast<std::size_t>(
wi * bin_count)];
const std::size_t counter_offset0 =
iter_counter_offset + wi * bin_count;
bin_sum[0] = counter_acc[counter_offset0];

#pragma unroll
for (uint16_t i = 1; i < bin_count; ++i)
bin_sum[i] =
bin_sum[i - 1] +
counter_acc
[iter_counter_offset +
static_cast<std::size_t>(
wi * bin_count + i)];
counter_acc[counter_offset0 + i];

sycl::group_barrier(ndit.get_group());

@@ -1374,10 +1368,7 @@
// add to local sum, generate exclusive scan result
#pragma unroll
for (uint16_t i = 0; i < bin_count; ++i)
counter_acc[iter_counter_offset +
static_cast<std::size_t>(
wi * bin_count + i +
1)] =
counter_acc[counter_offset0 + i + 1] =
sum_scan + bin_sum[i];

if (wi == 0)
@@ -1407,10 +1398,8 @@ struct subgroup_radix_sort
if (r < n) {
// move the values to source range and
// destroy the values
this_output_arr
[iter_val_offset +
static_cast<std::size_t>(r)] =
std::move(values[i]);
this_output_arr[iter_val_offset + r] =
std::move(values[i]);
}
}

@@ -1422,8 +1411,7 @@
for (uint16_t i = 0; i < block_size; ++i) {
const uint16_t r = indices[i];
if (r < n)
exchange_acc[iter_exchange_offset +
static_cast<std::size_t>(r)] =
exchange_acc[iter_exchange_offset + r] =
std::move(values[i]);
}

@@ -1435,8 +1423,7 @@
if (id < n)
values[i] = std::move(
exchange_acc[iter_exchange_offset +
static_cast<std::size_t>(
id)]);
id]);
}

sycl::group_barrier(ndit.get_group());
@@ -1601,11 +1588,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
using CountT = std::uint32_t;

// memory for storing count and offset values
CountT *count_ptr =
sycl::malloc_device<CountT>(n_iters * n_counts, exec_q);
if (nullptr == count_ptr) {
throw std::runtime_error("Could not allocate USM-device memory");
}
auto count_owner =
dpctl::tensor::alloc_utils::smart_malloc_device<CountT>(
n_iters * n_counts, exec_q);

CountT *count_ptr = count_owner.get();

constexpr std::uint32_t zero_radix_iter{0};

@@ -1618,25 +1605,17 @@
n_counts, count_ptr, proj_op,
is_ascending, depends);

sort_ev = exec_q.submit([=](sycl::handler &cgh) {
cgh.depends_on(sort_ev);
const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task(
[ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); });
});
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {sort_ev}, count_owner);

return sort_ev;
}

ValueT *tmp_arr =
sycl::malloc_device<ValueT>(n_iters * n_to_sort, exec_q);
if (nullptr == tmp_arr) {
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
sycl_free_noexcept(count_ptr, exec_q);
throw std::runtime_error("Could not allocate USM-device memory");
}
auto tmp_arr_owner =
dpctl::tensor::alloc_utils::smart_malloc_device<ValueT>(
n_iters * n_to_sort, exec_q);

ValueT *tmp_arr = tmp_arr_owner.get();

// iterations per each bucket
assert("Number of iterations must be even" && radix_iters % 2 == 0);
@@ -1670,17 +1649,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
}
}

sort_ev = exec_q.submit([=](sycl::handler &cgh) {
cgh.depends_on(sort_ev);

const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([ctx, count_ptr, tmp_arr]() {
sycl_free_noexcept(tmp_arr, ctx);
sycl_free_noexcept(count_ptr, ctx);
});
});
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {sort_ev}, tmp_arr_owner, count_owner);
}

return sort_ev;
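The allocation changes above replace raw sycl::malloc_device calls, explicit nullptr checks, and host_task-based sycl_free_noexcept cleanup with RAII-style owners. A hedged sketch of the pattern, using only the calls visible in this diff and otherwise hypothetical names; the owner returned by smart_malloc_device is assumed to behave like a unique_ptr-style wrapper over USM-device memory:

sycl::event scratch_lifetime_sketch(sycl::queue &exec_q,
                                    std::size_t n,
                                    const std::vector<sycl::event> &depends)
{
    // RAII owner of USM-device scratch memory (element type chosen for the
    // example only).
    auto tmp_owner =
        dpctl::tensor::alloc_utils::smart_malloc_device<float>(n, exec_q);
    float *tmp = tmp_owner.get(); // raw pointer handed to kernels

    // Some computation using the scratch buffer (placeholder kernel).
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.parallel_for(sycl::range<1>{n},
                         [=](sycl::id<1> id) { tmp[id[0]] = 0.0f; });
    });

    // Deallocation of the owned buffer is scheduled after comp_ev completes;
    // the returned event tracks the asynchronous free.
    return dpctl::tensor::alloc_utils::async_smart_free(exec_q, {comp_ev},
                                                        tmp_owner);
}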
@@ -1782,31 +1752,38 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;

const std::size_t total_nelems = iter_nelems * sort_nelems;
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
IndexTy *workspace = sycl::malloc_device<IndexTy>(
padded_total_nelems + total_nelems, exec_q);
auto workspace_owner =
dpctl::tensor::alloc_utils::smart_malloc_device<IndexTy>(total_nelems,
exec_q);

if (nullptr == workspace) {
throw std::runtime_error("Could not allocate workspace on device");
}
// get raw USM pointer
IndexTy *workspace = workspace_owner.get();

using IdentityProjT = radix_sort_details::IdentityProj;
using IndexedProjT =
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};

using IotaKernelName = radix_argsort_iota_krn<argTy, IndexTy>;

#if 1
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;

sycl::event iota_ev = iota_impl<IotaKernelName, IndexTy>(
exec_q, workspace, total_nelems, depends);
#else

sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using KernelName = radix_argsort_iota_krn<argTy, IndexTy>;

cgh.parallel_for<KernelName>(
cgh.parallel_for<IotaKernelName>(
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
size_t i = id[0];
IndexTy sort_id = static_cast<IndexTy>(i);
workspace[i] = sort_id;
});
});
#endif

sycl::event radix_sort_ev =
radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
Expand All @@ -1825,14 +1802,8 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
});
});

sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(map_back_ev);

const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
});
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {map_back_ev}, workspace_owner);

return cleanup_ev;
}
dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp (103 changes: 103 additions & 0 deletions)
@@ -0,0 +1,103 @@
//=== sort_utils.hpp - Implementation of utility kernels for sorting ---*-C++-*--/===//
//
// Data Parallel Control (dpctl)
//
// Copyright 2020-2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines utility kernels used by tensor sort/argsort operations.
//===----------------------------------------------------------------------===//

#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

#include <sycl/sycl.hpp>

namespace dpctl
{
namespace tensor
{
namespace kernels
{
namespace sort_utils_detail
{

namespace syclexp = sycl::ext::oneapi::experimental;

template <class KernelName, typename T>
sycl::event iota_impl(sycl::queue &exec_q,
T *data,
std::size_t nelems,
const std::vector<sycl::event> &dependent_events)
{
constexpr std::uint32_t lws = 256;
constexpr std::uint32_t n_wi = 4;
const std::size_t n_groups = (nelems + n_wi * lws - 1) / (n_wi * lws);

sycl::range<1> gRange{n_groups * lws};
sycl::range<1> lRange{lws};
sycl::nd_range<1> ndRange{gRange, lRange};

sycl::event e = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(dependent_events);
cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> it) {
const std::size_t gid = it.get_global_id();
const auto &sg = it.get_sub_group();
const std::uint32_t lane_id = sg.get_local_id()[0];

const std::size_t offset = (gid - lane_id) * n_wi;
const std::uint32_t max_sgSize = sg.get_max_local_range()[0];

std::array<T, n_wi> stripe{};
#pragma unroll
for (std::uint32_t i = 0; i < n_wi; ++i) {
stripe[i] = T(offset + lane_id + i * max_sgSize);
}

if (offset + n_wi * max_sgSize < nelems) {
constexpr auto group_ls_props = syclexp::properties{
syclexp::data_placement_striped
// , syclexp::full_group
};

auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&data[offset]);

syclexp::group_store(sg, sycl::span<T, n_wi>{&stripe[0], n_wi},
out_multi_ptr, group_ls_props);
}
else {
for (std::size_t idx = offset + lane_id; idx < nelems;
idx += max_sgSize)
{
data[idx] = T(idx);
}
}
});
});

return e;
}

} // end of namespace sort_utils_detail
} // end of namespace kernels
} // end of namespace tensor
} // end of namespace dpctl
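For reference, iota_impl decomposes the output into work-groups of lws = 256 work-items, each producing n_wi = 4 indices, so one group covers 1024 elements. On the fast path each sub-group writes its values with group_store and striped data placement; a trailing partial stripe falls back to the strided scalar loop. An illustrative host-side model of what one sub-group's fast-path store writes (hypothetical helper, not part of the file):

#include <cstddef>
#include <cstdint>

// Emulates one sub-group's fast-path group_store with striped placement:
// lane `lane_id` stores n_wi values, one into each stripe of width sg_size,
// so every element in [offset, offset + n_wi * sg_size) receives its own
// global index.
template <typename T>
void emulate_subgroup_store(T *data,
                            std::size_t offset,
                            std::uint32_t sg_size, // == max_sgSize
                            std::uint32_t n_wi)
{
    for (std::uint32_t lane_id = 0; lane_id < sg_size; ++lane_id) {
        for (std::uint32_t i = 0; i < n_wi; ++i) {
            const std::size_t pos = offset + lane_id + i * sg_size;
            data[pos] = static_cast<T>(pos); // same value as stripe[i]
        }
    }
}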