Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improves performance of search reductions for small numbers of elements #1464

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 248 additions & 3 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
Original file line number Diff line number Diff line change
@@ -3401,6 +3401,129 @@ struct LogSumExpOverAxis0TempsContigFactory

// Argmax and Argmin

/* Sequential search reduction */

template <typename argT,
typename outT,
typename ReductionOp,
typename IdxReductionOp,
typename InputOutputIterIndexerT,
typename InputRedIndexerT>
struct SequentialSearchReduction
{
private:
const argT *inp_ = nullptr;
outT *out_ = nullptr;
ReductionOp reduction_op_;
argT identity_;
IdxReductionOp idx_reduction_op_;
outT idx_identity_;
InputOutputIterIndexerT inp_out_iter_indexer_;
InputRedIndexerT inp_reduced_dims_indexer_;
size_t reduction_max_gid_ = 0;

public:
SequentialSearchReduction(const argT *inp,
outT *res,
ReductionOp reduction_op,
const argT &identity_val,
IdxReductionOp idx_reduction_op,
const outT &idx_identity_val,
InputOutputIterIndexerT arg_res_iter_indexer,
InputRedIndexerT arg_reduced_dims_indexer,
size_t reduction_size)
: inp_(inp), out_(res), reduction_op_(reduction_op),
identity_(identity_val), idx_reduction_op_(idx_reduction_op),
idx_identity_(idx_identity_val),
inp_out_iter_indexer_(arg_res_iter_indexer),
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
reduction_max_gid_(reduction_size)
{
}

void operator()(sycl::id<1> id) const
{

auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
const py::ssize_t &inp_iter_offset =
inp_out_iter_offsets_.get_first_offset();
const py::ssize_t &out_iter_offset =
inp_out_iter_offsets_.get_second_offset();

argT red_val(identity_);
outT idx_val(idx_identity_);
for (size_t m = 0; m < reduction_max_gid_; ++m) {
const py::ssize_t inp_reduction_offset =
inp_reduced_dims_indexer_(m);
const py::ssize_t inp_offset =
inp_iter_offset + inp_reduction_offset;

argT val = inp_[inp_offset];
if (val == red_val) {
idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
}
else {
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
using dpctl::tensor::type_utils::is_complex;
if constexpr (is_complex<argT>::value) {
using dpctl::tensor::math_utils::less_complex;
// less_complex always returns false for NaNs, so check
if (less_complex<argT>(val, red_val) ||
std::isnan(std::real(val)) ||
std::isnan(std::imag(val)))
{
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val < red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else {
if (val < red_val) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
}
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
using dpctl::tensor::type_utils::is_complex;
if constexpr (is_complex<argT>::value) {
using dpctl::tensor::math_utils::greater_complex;
if (greater_complex<argT>(val, red_val) ||
std::isnan(std::real(val)) ||
std::isnan(std::imag(val)))
{
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val > red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else {
if (val > red_val) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
}
}
}
out_[out_iter_offset] = idx_val;
}
};

/* = Search reduction using reduce_over_group*/

template <typename argT,
@@ -3670,7 +3793,9 @@ struct CustomSearchReduction
}
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val < local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
@@ -3714,7 +3839,9 @@ struct CustomSearchReduction
}
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val > local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
@@ -3757,7 +3884,9 @@ struct CustomSearchReduction
? local_idx
: idx_identity_;
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
// equality does not hold for NaNs, so check here
local_idx =
(red_val_over_wg == local_red_val || std::isnan(local_red_val))
@@ -3799,6 +3928,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
py::ssize_t,
const std::vector<sycl::event> &);

template <typename T1,
typename T2,
typename T3,
typename T4,
typename T5,
typename T6>
class search_seq_strided_krn;

template <typename T1,
typename T2,
typename T3,
@@ -3820,6 +3957,14 @@ template <typename T1,
bool b2>
class custom_search_over_group_temps_strided_krn;

template <typename T1,
typename T2,
typename T3,
typename T4,
typename T5,
typename T6>
class search_seq_contig_krn;

template <typename T1,
typename T2,
typename T3,
@@ -4019,6 +4164,36 @@ sycl::event search_over_group_temps_strided_impl(
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

if (reduction_nelems < wg) {
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
using ReductionIndexerT =
dpctl::tensor::offset_utils::StridedIndexer;

InputOutputIterIndexerT in_out_iter_indexer{
iter_nd, iter_arg_offset, iter_res_offset,
iter_shape_and_strides};
ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
reduction_shape_stride};

cgh.parallel_for<class search_seq_strided_krn<
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
ReductionIndexerT>>(
sycl::range<1>(iter_nelems),
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
InputOutputIterIndexerT,
ReductionIndexerT>(
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
idx_identity_val, in_out_iter_indexer, reduction_indexer,
reduction_nelems));
});

return comp_ev;
}

constexpr size_t preferred_reductions_per_wi = 4;
// max_max_wg prevents running out of resources on CPU
size_t max_wg =
@@ -4419,6 +4594,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

if (reduction_nelems < wg) {
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using InputIterIndexerT =
dpctl::tensor::offset_utils::Strided1DIndexer;
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
InputIterIndexerT, NoOpIndexerT>;
using ReductionIndexerT = NoOpIndexerT;

InputOutputIterIndexerT in_out_iter_indexer{
InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
static_cast<py::ssize_t>(reduction_nelems)},
NoOpIndexerT{}};
ReductionIndexerT reduction_indexer{};

cgh.parallel_for<class search_seq_contig_krn<
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
ReductionIndexerT>>(
sycl::range<1>(iter_nelems),
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
InputOutputIterIndexerT,
ReductionIndexerT>(
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
idx_identity_val, in_out_iter_indexer, reduction_indexer,
reduction_nelems));
});

return comp_ev;
}

constexpr size_t preferred_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg =
@@ -4801,6 +5009,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

if (reduction_nelems < wg) {
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
NoOpIndexerT, NoOpIndexerT>;
using ReductionIndexerT =
dpctl::tensor::offset_utils::Strided1DIndexer;

InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
NoOpIndexerT{}};
ReductionIndexerT reduction_indexer{
0, static_cast<py::ssize_t>(reduction_nelems),
static_cast<py::ssize_t>(iter_nelems)};

using KernelName =
class search_seq_contig_krn<argTy, resTy, ReductionOpT,
IndexOpT, InputOutputIterIndexerT,
ReductionIndexerT>;

sycl::range<1> iter_range{iter_nelems};

cgh.parallel_for<KernelName>(
iter_range,
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
InputOutputIterIndexerT,
ReductionIndexerT>(
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
idx_identity_val, in_out_iter_indexer, reduction_indexer,
reduction_nelems));
});

return comp_ev;
}

constexpr size_t preferred_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg =