Skip to content

Commit

Permalink
Merge pull request #1958 from IntelPython/resolve-gh-1944
Browse files Browse the repository at this point in the history
Add dedicated reduction kernels for sums and products of boolean arrays
  • Loading branch information
ndgrigorian authored Jan 10, 2025
2 parents a7ca491 + f9f4d6c commit 3146183
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
15 changes: 11 additions & 4 deletions dpctl/tensor/libtensor/source/reductions/prod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct TypePairSupportDataForProductReductionTemps
{

static constexpr bool is_defined = std::disjunction<
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int16_t>,
Expand Down Expand Up @@ -224,7 +225,7 @@ struct TypePairSupportDataForProductReductionTemps
outTy,
std::complex<double>>,

// fall-throug
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

Expand Down Expand Up @@ -255,7 +256,9 @@ struct ProductOverAxisTempsStridedFactory
if constexpr (TypePairSupportDataForProductReductionTemps<
srcTy, dstTy>::is_defined)
{
using ReductionOpT = sycl::multiplies<dstTy>;
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
sycl::logical_and<dstTy>,
sycl::multiplies<dstTy>>;
return dpctl::tensor::kernels::
reduction_over_group_temps_strided_impl<srcTy, dstTy,
ReductionOpT>;
Expand Down Expand Up @@ -312,7 +315,9 @@ struct ProductOverAxis1TempsContigFactory
if constexpr (TypePairSupportDataForProductReductionTemps<
srcTy, dstTy>::is_defined)
{
using ReductionOpT = sycl::multiplies<dstTy>;
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
sycl::logical_and<dstTy>,
sycl::multiplies<dstTy>>;
return dpctl::tensor::kernels::
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
ReductionOpT>;
Expand All @@ -331,7 +336,9 @@ struct ProductOverAxis0TempsContigFactory
if constexpr (TypePairSupportDataForProductReductionTemps<
srcTy, dstTy>::is_defined)
{
using ReductionOpT = sycl::multiplies<dstTy>;
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
sycl::logical_and<dstTy>,
sycl::multiplies<dstTy>>;
return dpctl::tensor::kernels::
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
ReductionOpT>;
Expand Down
15 changes: 11 additions & 4 deletions dpctl/tensor/libtensor/source/reductions/sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct TypePairSupportDataForSumReductionTemps
{

static constexpr bool is_defined = std::disjunction<
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int16_t>,
Expand Down Expand Up @@ -224,7 +225,7 @@ struct TypePairSupportDataForSumReductionTemps
outTy,
std::complex<double>>,

// fall-throug
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

Expand Down Expand Up @@ -255,7 +256,9 @@ struct SumOverAxisTempsStridedFactory
if constexpr (TypePairSupportDataForSumReductionTemps<
srcTy, dstTy>::is_defined)
{
using ReductionOpT = sycl::plus<dstTy>;
using ReductionOpT =
std::conditional_t<std::is_same_v<dstTy, bool>,
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
return dpctl::tensor::kernels::
reduction_over_group_temps_strided_impl<srcTy, dstTy,
ReductionOpT>;
Expand Down Expand Up @@ -312,7 +315,9 @@ struct SumOverAxis1TempsContigFactory
if constexpr (TypePairSupportDataForSumReductionTemps<
srcTy, dstTy>::is_defined)
{
using ReductionOpT = sycl::plus<dstTy>;
using ReductionOpT =
std::conditional_t<std::is_same_v<dstTy, bool>,
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
return dpctl::tensor::kernels::
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
ReductionOpT>;
Expand All @@ -331,7 +336,9 @@ struct SumOverAxis0TempsContigFactory
if constexpr (TypePairSupportDataForSumReductionTemps<
srcTy, dstTy>::is_defined)
{
using ReductionOpT = sycl::plus<dstTy>;
using ReductionOpT =
std::conditional_t<std::is_same_v<dstTy, bool>,
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
return dpctl::tensor::kernels::
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
ReductionOpT>;
Expand Down
14 changes: 14 additions & 0 deletions dpctl/tests/test_tensor_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,17 @@ def test_gh_1468():
a = dpt.full((2, 3, 4), 123456789, dtype=dpt.int32)
t = dpt.sum(a, dtype="f4")
assert t > 0


@pytest.mark.parametrize(
"dt", ["i1", "i2", "i4", "i8", "f2", "f4", "f8", "c8", "c16"]
)
def test_gh_1944(dt):
"See https://github.com/IntelPython/dpctl/issues/1944"
q = get_queue_or_skip()
skip_if_dtype_not_supported(dt, q)
x = dpt.asarray([-1, 1], dtype=dpt.dtype(dt), sycl_queue=q)
r = dpt.sum(x, dtype="?")
# reduction must be performed in the requested dtype
# if performed in the input type, result is False
assert r

0 comments on commit 3146183

Please sign in to comment.