From cdfa7b50c409f811738202023c08f75999843610 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Tue, 7 Nov 2023 21:04:50 +0000 Subject: [PATCH] Remove more enable_if instances --- .../librapid/array/linalg/arrayMultiply.hpp | 35 +++---- .../librapid/array/linalg/transpose.hpp | 20 ++-- librapid/include/librapid/core/typetraits.hpp | 99 +++++++++---------- .../librapid/datastructures/bitset.hpp | 4 +- librapid/include/librapid/ml/activations.hpp | 30 +++--- librapid/include/librapid/simd/vecOps.hpp | 1 - 6 files changed, 88 insertions(+), 101 deletions(-) diff --git a/librapid/include/librapid/array/linalg/arrayMultiply.hpp b/librapid/include/librapid/array/linalg/arrayMultiply.hpp index 2326ad5c..6075651f 100644 --- a/librapid/include/librapid/array/linalg/arrayMultiply.hpp +++ b/librapid/include/librapid/array/linalg/arrayMultiply.hpp @@ -535,16 +535,15 @@ namespace librapid { /// \tparam T /// \param val /// \return - template::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) { - using Scalar = typename typetraits::TypeInfo>::Scalar; - return std::make_tuple(false, Scalar(1), std::forward(val)); - } - - template::value, int> = 0> + template LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) { - using Type = decltype(val.array()); - return std::make_tuple(true, val.alpha(), std::forward(val.array())); + if constexpr (IsTransposeType::value) { + using Type = decltype(val.array()); + return std::make_tuple(true, val.alpha(), std::forward(val.array())); + } else { + using Scalar = typename typetraits::TypeInfo>::Scalar; + return std::make_tuple(false, Scalar(1), std::forward(val)); + } } /// Evaluates to true if the type is a multiply type. @@ -561,15 +560,15 @@ namespace librapid { /// \tparam T /// \param val /// \return - template::value, int> = 0> + template + requires(!IsMultiplyType::value) LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto multiplyExtractor(T &&val) { using Scalar = typename typetraits::TypeInfo>::Scalar; return std::make_tuple(Scalar(1), std::forward(val)); } - template::type == detail::LibRapidType::Scalar, int> = 0> + template + requires(typetraits::TypeInfo::type == detail::LibRapidType::Scalar) LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto multiplyExtractor(detail::Function &&val) { using Type = decltype(std::get<0>(val.args())); @@ -577,9 +576,8 @@ namespace librapid { std::forward(std::get<0>(val.args()))); } - template::type == detail::LibRapidType::Scalar, int> = 0> + template + requires(typetraits::TypeInfo::type == detail::LibRapidType::Scalar) LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto multiplyExtractor(detail::Function &&val) { using Type = decltype(std::get<1>(val.args())); @@ -634,9 +632,8 @@ namespace librapid { /// StorageTypeB The storage type of the right input array. \param a The left input array. /// \param b The right input array. /// \return The dot product of the two input arrays. - template< - typename First, typename Second, - typename std::enable_if_t::value && IsArrayType::value, int> = 0> + template + requires(IsArrayType::value && IsArrayType::value) auto dot(First &&a, Second &&b) { using ScalarA = typename typetraits::TypeInfo>::Scalar; using ScalarB = typename typetraits::TypeInfo>::Scalar; diff --git a/librapid/include/librapid/array/linalg/transpose.hpp b/librapid/include/librapid/array/linalg/transpose.hpp index 179bea70..6a303068 100644 --- a/librapid/include/librapid/array/linalg/transpose.hpp +++ b/librapid/include/librapid/array/linalg/transpose.hpp @@ -438,11 +438,11 @@ namespace librapid { template class Transpose { public: - using ArrayType = TransposeType; - using BaseType = typename std::decay_t; - using Scalar = typename typetraits::TypeInfo::Scalar; - using ShapeType = typename BaseType::ShapeType; - using Backend = typename typetraits::TypeInfo::Backend; + using ArrayType = TransposeType; + using BaseType = typename std::decay_t; + using Scalar = typename typetraits::TypeInfo::Scalar; + using ShapeType = typename BaseType::ShapeType; + using Backend = typename typetraits::TypeInfo::Backend; static constexpr bool allowVectorisation = typetraits::TypeInfo::allowVectorisation; @@ -517,7 +517,8 @@ namespace librapid { template LIBRAPID_ALWAYS_INLINE void str(const fmt::formatter &format, char bracket, - char separator, const char (&formatString)[N], Ctx &ctx) const; + char separator, const char (&formatString)[N], + Ctx &ctx) const; private: ArrayType m_array; @@ -654,13 +655,14 @@ namespace librapid { template template void Transpose::str(const fmt::formatter &format, char bracket, - char separator, const char (&formatString)[N], Ctx &ctx) const { + char separator, const char (&formatString)[N], + Ctx &ctx) const { eval().str(format, bracket, separator, formatString, ctx); } }; // namespace array - template::value, int> = 0> + template + requires(typetraits::IsSizeType::value) auto transpose(T &&array, const ShapeType &axes = ShapeType()) { // If axes is empty, transpose the array in reverse order ShapeType newAxes = axes; diff --git a/librapid/include/librapid/core/typetraits.hpp b/librapid/include/librapid/core/typetraits.hpp index 28964131..c6a1fcb0 100644 --- a/librapid/include/librapid/core/typetraits.hpp +++ b/librapid/include/librapid/core/typetraits.hpp @@ -7,68 +7,65 @@ */ namespace librapid::typetraits { - template - using EnableIf = std::enable_if_t; - - template - constexpr bool IsSame = std::is_same::value; - - namespace impl { - /* - * These functions test for the presence of certain features of a type - * by providing two valid function overloads, but the preferred one - * (the one taking an integer) is only valid if the requested feature - * exists. The return type of both functions differ, and can be evaluated - * as "true" and "false" depending on the presence of the feature. - * - * This is really cool :) - */ - - template()[std::declval()])> - std::true_type testSubscript(int); - template - std::false_type testSubscript(float); - - template() + std::declval())> - std::true_type testAddition(int); - template - std::false_type testAddition(float); - - template() * std::declval())> - std::true_type testMultiplication(int); - template - std::false_type testMultiplication(float); - - template())> - std::true_type testCast(int); - template - std::false_type testCast(float); + template + constexpr bool IsSame = std::is_same::value; + + namespace impl { + /* + * These functions test for the presence of certain features of a type + * by providing two valid function overloads, but the preferred one + * (the one taking an integer) is only valid if the requested feature + * exists. The return type of both functions differ, and can be evaluated + * as "true" and "false" depending on the presence of the feature. + * + * This is really cool :) + */ + + template()[std::declval()])> + std::true_type testSubscript(int); + template + std::false_type testSubscript(float); + + template() + std::declval())> + std::true_type testAddition(int); + template + std::false_type testAddition(float); + + template() * std::declval())> + std::true_type testMultiplication(int); + template + std::false_type testMultiplication(float); + + template())> + std::true_type testCast(int); + template + std::false_type testCast(float); // Test for T::allowVectorisation (static constexpr bool) template std::true_type testAllowVectorisation(int); template std::false_type testAllowVectorisation(float); - } // namespace impl + } // namespace impl - template - struct HasSubscript : public decltype(impl::testSubscript(1)) {}; + template + struct HasSubscript : public decltype(impl::testSubscript(1)) {}; - template - struct HasAddition : public decltype(impl::testAddition(1)) {}; + template + struct HasAddition : public decltype(impl::testAddition(1)) {}; - template - struct HasMultiplication : public decltype(impl::testMultiplication(1)) {}; + template + struct HasMultiplication : public decltype(impl::testMultiplication(1)) {}; - template - struct CanCast : public decltype(impl::testCast(1)) {}; + template + struct CanCast : public decltype(impl::testCast(1)) {}; - // Detect whether a class can be default constructed - template - using TriviallyDefaultConstructible = std::is_trivially_default_constructible; + // Detect whether a class can be default constructed + template + using TriviallyDefaultConstructible = std::is_trivially_default_constructible; // Detect whether a class has a static constexpr bool member called allowVectorization template diff --git a/librapid/include/librapid/datastructures/bitset.hpp b/librapid/include/librapid/datastructures/bitset.hpp index e4a5fb40..422445d2 100644 --- a/librapid/include/librapid/datastructures/bitset.hpp +++ b/librapid/include/librapid/datastructures/bitset.hpp @@ -337,8 +337,8 @@ namespace librapid { const auto &data() const { return m_data; } auto &data() { return m_data; } - template, int> = 0> + template + requires(std::is_integral_v) LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int toInt() const { #if defined(LIBRAPID_DEBUG) static bool warned = false; diff --git a/librapid/include/librapid/ml/activations.hpp b/librapid/include/librapid/ml/activations.hpp index d0a3e802..7a12f7b4 100644 --- a/librapid/include/librapid/ml/activations.hpp +++ b/librapid/include/librapid/ml/activations.hpp @@ -131,10 +131,8 @@ namespace librapid::ml { } #if defined(LIBRAPID_HAS_OPENCL) - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::OpenCL>, int> = 0> + template + requires(std::is_same_v::Backend, backend::OpenCL>) LIBRAPID_ALWAYS_INLINE void forward(array::ArrayContainer> &dst, const Src &src) const { @@ -145,10 +143,8 @@ namespace librapid::ml { src.storage().data()); } - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::OpenCL>, int> = 0> + template + requires(std::is_same_v::Backend, backend::OpenCL>) LIBRAPID_ALWAYS_INLINE void backward(array::ArrayContainer> &dst, const Src &src) const { @@ -161,13 +157,11 @@ namespace librapid::ml { #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::CUDA>, int> = 0> - LIBRAPID_ALWAYS_INLINE void - forward(array::ArrayContainer> &dst, - const Src &src) const { + template + require(std::is_same_v::Backend, backend::CUDA>) + LIBRAPID_ALWAYS_INLINE + void forward(array::ArrayContainer> &dst, + const Src &src) const { auto temp = evaluated(src); cuda::runKernel("activations", "sigmoidActivationForward", @@ -177,10 +171,8 @@ namespace librapid::ml { temp.storage().begin()); } - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::CUDA>, int> = 0> + template + requires(std::is_same_v::Backend, backend::CUDA>) LIBRAPID_ALWAYS_INLINE void backward(array::ArrayContainer> &dst, const Src &src) const { diff --git a/librapid/include/librapid/simd/vecOps.hpp b/librapid/include/librapid/simd/vecOps.hpp index 17add3a1..d9326a32 100644 --- a/librapid/include/librapid/simd/vecOps.hpp +++ b/librapid/include/librapid/simd/vecOps.hpp @@ -16,7 +16,6 @@ namespace librapid { concept SIMD = IsSIMD::value; } // namespace typetraits -#define REQUIRE_SIMD(TYPE) typename std::enable_if_t::value, int> = 0 #define IS_FLOATING(TYPE) std::is_floating_point_v #define SIMD_OP_IMPL(OP) \