Commit

Adds more supported types to arithmetic reductions
Permits `float` accumulation type with 64-bit signed and unsigned integer inputs
to prevent unnecessary copies on devices that don't support double precision
ndgrigorian committed Nov 8, 2023
1 parent ddccf5c commit 375bbde
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
@@ -2806,10 +2806,12 @@ struct TypePairSupportDataForSumReductionTemps

// input int64_t
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, float>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,

// input uint64_t
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, float>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,

// input half
@@ -3077,10 +3079,12 @@ struct TypePairSupportDataForProductReductionTemps

// input int64_t
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, float>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,

// input uint64_t
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, float>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,

// input half
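
The entries above act as a compile-time lookup table: an (input type, output type) pair is supported only if it matches one of the listed rows. A minimal sketch of that pattern follows, assuming a simplified design. `PairEntry` and `SumSupport64` are hypothetical stand-ins written for illustration, not the actual `td_ns::TypePairDefinedEntry` or `TypePairSupportDataForSumReductionTemps` definitions, and only the 64-bit integer rows from this commit are reproduced.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for td_ns::TypePairDefinedEntry (an assumption, not
// the library's definition): true only when the queried (argTy, outTy) pair
// matches the listed (ArgT, OutT) row exactly.
template <typename argTy, typename ArgT, typename outTy, typename OutT>
struct PairEntry
    : std::bool_constant<std::is_same_v<argTy, ArgT> &&
                         std::is_same_v<outTy, OutT>>
{
};

// Hypothetical support table holding only the 64-bit integer rows shown in
// the diff. A pair is supported when any row matches.
template <typename argTy, typename outTy>
struct SumSupport64
{
    static constexpr bool is_defined = std::disjunction<
        // input int64_t
        PairEntry<argTy, std::int64_t, outTy, std::int64_t>,
        PairEntry<argTy, std::int64_t, outTy, float>, // row added here
        PairEntry<argTy, std::int64_t, outTy, double>,
        // input uint64_t
        PairEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
        PairEntry<argTy, std::uint64_t, outTy, float>, // row added here
        PairEntry<argTy, std::uint64_t, outTy, double>>::value;
};

// With the float rows present, 64-bit integer inputs can reduce into float.
static_assert(SumSupport64<std::int64_t, float>::is_defined);
static_assert(SumSupport64<std::uint64_t, float>::is_defined);
// Pairs that are not listed remain unsupported.
static_assert(!SumSupport64<std::int64_t, std::int32_t>::is_defined);
```

In this sketch, deleting the two `float` rows makes the first two `static_assert` lines fail, which mirrors the situation before this commit: a 64-bit integer input had no directly supported `float` output.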