Skip to content

Commit

Permalink
Remove more enable_if instances
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Nov 7, 2023
1 parent 1a5469e commit cdfa7b5
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 101 deletions.
35 changes: 16 additions & 19 deletions librapid/include/librapid/array/linalg/arrayMultiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,16 +535,15 @@ namespace librapid {
/// \tparam T
/// \param val
/// \return
template<typename T, typename std::enable_if_t<!IsTransposeType<T>::value, int> = 0>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) {
using Scalar = typename typetraits::TypeInfo<std::decay_t<T>>::Scalar;
return std::make_tuple(false, Scalar(1), std::forward<T>(val));
}

template<typename T, typename std::enable_if_t<IsTransposeType<T>::value, int> = 0>
template<typename T>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) {
using Type = decltype(val.array());
return std::make_tuple(true, val.alpha(), std::forward<Type>(val.array()));
if constexpr (IsTransposeType<T>::value) {
using Type = decltype(val.array());
return std::make_tuple(true, val.alpha(), std::forward<Type>(val.array()));
} else {
using Scalar = typename typetraits::TypeInfo<std::decay_t<T>>::Scalar;
return std::make_tuple(false, Scalar(1), std::forward<T>(val));
}
}

/// Evaluates to true if the type is a multiply type.
Expand All @@ -561,25 +560,24 @@ namespace librapid {
/// \tparam T
/// \param val
/// \return
template<typename T, typename std::enable_if_t<!IsMultiplyType<T>::value, int> = 0>
template<typename T>
requires(!IsMultiplyType<T>::value)
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto multiplyExtractor(T &&val) {
using Scalar = typename typetraits::TypeInfo<std::decay_t<T>>::Scalar;
return std::make_tuple(Scalar(1), std::forward<T>(val));
}

template<typename Descriptor, typename Arr, typename Scalar,
typename std::enable_if_t<
typetraits::TypeInfo<Scalar>::type == detail::LibRapidType::Scalar, int> = 0>
template<typename Descriptor, typename Arr, typename Scalar>
requires(typetraits::TypeInfo<Scalar>::type == detail::LibRapidType::Scalar)
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto
multiplyExtractor(detail::Function<Descriptor, detail::Multiply, Arr, Scalar> &&val) {
using Type = decltype(std::get<0>(val.args()));
return std::make_tuple(std::get<1>(val.args()),
std::forward<Type>(std::get<0>(val.args())));
}

template<typename Descriptor, typename Arr, typename Scalar,
typename std::enable_if_t<
typetraits::TypeInfo<Scalar>::type == detail::LibRapidType::Scalar, int> = 0>
template<typename Descriptor, typename Arr, typename Scalar>
requires(typetraits::TypeInfo<Scalar>::type == detail::LibRapidType::Scalar)
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto
multiplyExtractor(detail::Function<Descriptor, detail::Multiply, Scalar, Arr> &&val) {
using Type = decltype(std::get<1>(val.args()));
Expand Down Expand Up @@ -634,9 +632,8 @@ namespace librapid {
/// StorageTypeB The storage type of the right input array. \param a The left input array.
/// \param b The right input array.
/// \return The dot product of the two input arrays.
template<
typename First, typename Second,
typename std::enable_if_t<IsArrayType<First>::value && IsArrayType<Second>::value, int> = 0>
template<typename First, typename Second>
requires(IsArrayType<First>::value && IsArrayType<Second>::value)
auto dot(First &&a, Second &&b) {
using ScalarA = typename typetraits::TypeInfo<std::decay_t<First>>::Scalar;
using ScalarB = typename typetraits::TypeInfo<std::decay_t<Second>>::Scalar;
Expand Down
20 changes: 11 additions & 9 deletions librapid/include/librapid/array/linalg/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,11 +438,11 @@ namespace librapid {
template<typename TransposeType>
class Transpose {
public:
using ArrayType = TransposeType;
using BaseType = typename std::decay_t<TransposeType>;
using Scalar = typename typetraits::TypeInfo<BaseType>::Scalar;
using ShapeType = typename BaseType::ShapeType;
using Backend = typename typetraits::TypeInfo<BaseType>::Backend;
using ArrayType = TransposeType;
using BaseType = typename std::decay_t<TransposeType>;
using Scalar = typename typetraits::TypeInfo<BaseType>::Scalar;
using ShapeType = typename BaseType::ShapeType;
using Backend = typename typetraits::TypeInfo<BaseType>::Backend;

static constexpr bool allowVectorisation =
typetraits::TypeInfo<Scalar>::allowVectorisation;
Expand Down Expand Up @@ -517,7 +517,8 @@ namespace librapid {

template<typename T, typename Char, size_t N, typename Ctx>
LIBRAPID_ALWAYS_INLINE void str(const fmt::formatter<T, Char> &format, char bracket,
char separator, const char (&formatString)[N], Ctx &ctx) const;
char separator, const char (&formatString)[N],
Ctx &ctx) const;

private:
ArrayType m_array;
Expand Down Expand Up @@ -654,13 +655,14 @@ namespace librapid {
template<typename TransposeType>
template<typename T, typename Char, size_t N, typename Ctx>
void Transpose<TransposeType>::str(const fmt::formatter<T, Char> &format, char bracket,
char separator, const char (&formatString)[N], Ctx &ctx) const {
char separator, const char (&formatString)[N],
Ctx &ctx) const {
eval().str(format, bracket, separator, formatString, ctx);
}
}; // namespace array

template<typename T, typename ShapeType = MatrixShape,
typename std::enable_if_t<typetraits::IsSizeType<ShapeType>::value, int> = 0>
template<typename T, typename ShapeType = MatrixShape>
requires(typetraits::IsSizeType<ShapeType>::value)
auto transpose(T &&array, const ShapeType &axes = ShapeType()) {
// If axes is empty, transpose the array in reverse order
ShapeType newAxes = axes;
Expand Down
99 changes: 48 additions & 51 deletions librapid/include/librapid/core/typetraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,68 +7,65 @@
*/

namespace librapid::typetraits {
template<bool Cond, typename T = int>
using EnableIf = std::enable_if_t<Cond, T>;

template<typename A, typename B>
constexpr bool IsSame = std::is_same<A, B>::value;

namespace impl {
/*
* These functions test for the presence of certain features of a type
* by providing two valid function overloads, but the preferred one
* (the one taking an integer) is only valid if the requested feature
* exists. The return type of both functions differ, and can be evaluated
* as "true" and "false" depending on the presence of the feature.
*
* This is really cool :)
*/

template<typename T, typename Index,
typename = decltype(std::declval<T &>()[std::declval<Index>()])>
std::true_type testSubscript(int);
template<typename T, typename Index>
std::false_type testSubscript(float);

template<typename T, typename V,
typename = decltype(std::declval<T &>() + std::declval<V &>())>
std::true_type testAddition(int);
template<typename T, typename V>
std::false_type testAddition(float);

template<typename T, typename V,
typename = decltype(std::declval<T &>() * std::declval<V &>())>
std::true_type testMultiplication(int);
template<typename T, typename V>
std::false_type testMultiplication(float);

template<typename From, typename To, typename = decltype((From)std::declval<From &>())>
std::true_type testCast(int);
template<typename From, typename To>
std::false_type testCast(float);
template<typename A, typename B>
constexpr bool IsSame = std::is_same<A, B>::value;

namespace impl {
/*
* These functions test for the presence of certain features of a type
* by providing two valid function overloads, but the preferred one
* (the one taking an integer) is only valid if the requested feature
* exists. The return type of both functions differ, and can be evaluated
* as "true" and "false" depending on the presence of the feature.
*
* This is really cool :)
*/

template<typename T, typename Index,
typename = decltype(std::declval<T &>()[std::declval<Index>()])>
std::true_type testSubscript(int);
template<typename T, typename Index>
std::false_type testSubscript(float);

template<typename T, typename V,
typename = decltype(std::declval<T &>() + std::declval<V &>())>
std::true_type testAddition(int);
template<typename T, typename V>
std::false_type testAddition(float);

template<typename T, typename V,
typename = decltype(std::declval<T &>() * std::declval<V &>())>
std::true_type testMultiplication(int);
template<typename T, typename V>
std::false_type testMultiplication(float);

template<typename From, typename To, typename = decltype((From)std::declval<From &>())>
std::true_type testCast(int);
template<typename From, typename To>
std::false_type testCast(float);

// Test for T::allowVectorisation (static constexpr bool)
template<typename T, typename = decltype(T::allowVectorisation)>
std::true_type testAllowVectorisation(int);
template<typename T>
std::false_type testAllowVectorisation(float);
} // namespace impl
} // namespace impl

template<typename T, typename Index = int64_t>
struct HasSubscript : public decltype(impl::testSubscript<T, Index>(1)) {};
template<typename T, typename Index = int64_t>
struct HasSubscript : public decltype(impl::testSubscript<T, Index>(1)) {};

template<typename T, typename V = T>
struct HasAddition : public decltype(impl::testAddition<T, V>(1)) {};
template<typename T, typename V = T>
struct HasAddition : public decltype(impl::testAddition<T, V>(1)) {};

template<typename T, typename V = T>
struct HasMultiplication : public decltype(impl::testMultiplication<T, V>(1)) {};
template<typename T, typename V = T>
struct HasMultiplication : public decltype(impl::testMultiplication<T, V>(1)) {};

template<typename From, typename To>
struct CanCast : public decltype(impl::testCast<From, To>(1)) {};
template<typename From, typename To>
struct CanCast : public decltype(impl::testCast<From, To>(1)) {};

// Detect whether a class can be default constructed
template<class T>
using TriviallyDefaultConstructible = std::is_trivially_default_constructible<T>;
// Detect whether a class can be default constructed
template<class T>
using TriviallyDefaultConstructible = std::is_trivially_default_constructible<T>;

// Detect whether a class has a static constexpr bool member called allowVectorization
template<typename T>
Expand Down
4 changes: 2 additions & 2 deletions librapid/include/librapid/datastructures/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ namespace librapid {
const auto &data() const { return m_data; }
auto &data() { return m_data; }

template<typename Integer = ElementType,
typename std::enable_if_t<std::is_integral_v<Integer>, int> = 0>
template<typename Integer = ElementType>
requires(std::is_integral_v<Integer>)
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int toInt() const {
#if defined(LIBRAPID_DEBUG)
static bool warned = false;
Expand Down
30 changes: 11 additions & 19 deletions librapid/include/librapid/ml/activations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,8 @@ namespace librapid::ml {
}

#if defined(LIBRAPID_HAS_OPENCL)
template<
typename ShapeType, typename StorageScalar, typename Src,
typename std::enable_if_t<
std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::OpenCL>, int> = 0>
template<typename ShapeType, typename StorageScalar, typename Src>
requires(std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::OpenCL>)
LIBRAPID_ALWAYS_INLINE void
forward(array::ArrayContainer<ShapeType, OpenCLStorage<StorageScalar>> &dst,
const Src &src) const {
Expand All @@ -145,10 +143,8 @@ namespace librapid::ml {
src.storage().data());
}

template<
typename ShapeType, typename StorageScalar, typename Src,
typename std::enable_if_t<
std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::OpenCL>, int> = 0>
template<typename ShapeType, typename StorageScalar, typename Src>
requires(std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::OpenCL>)
LIBRAPID_ALWAYS_INLINE void
backward(array::ArrayContainer<ShapeType, OpenCLStorage<StorageScalar>> &dst,
const Src &src) const {
Expand All @@ -161,13 +157,11 @@ namespace librapid::ml {
#endif // LIBRAPID_HAS_OPENCL

#if defined(LIBRAPID_HAS_CUDA)
template<
typename ShapeType, typename StorageScalar, typename Src,
typename std::enable_if_t<
std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::CUDA>, int> = 0>
LIBRAPID_ALWAYS_INLINE void
forward(array::ArrayContainer<ShapeType, CudaStorage<StorageScalar>> &dst,
const Src &src) const {
template<typename ShapeType, typename StorageScalar, typename Src>
require(std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::CUDA>)
LIBRAPID_ALWAYS_INLINE
void forward(array::ArrayContainer<ShapeType, CudaStorage<StorageScalar>> &dst,
const Src &src) const {
auto temp = evaluated(src);
cuda::runKernel<StorageScalar, StorageScalar>("activations",
"sigmoidActivationForward",
Expand All @@ -177,10 +171,8 @@ namespace librapid::ml {
temp.storage().begin());
}

template<
typename ShapeType, typename StorageScalar, typename Src,
typename std::enable_if_t<
std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::CUDA>, int> = 0>
template<typename ShapeType, typename StorageScalar, typename Src>
requires(std::is_same_v<typename typetraits::TypeInfo<Src>::Backend, backend::CUDA>)
LIBRAPID_ALWAYS_INLINE void
backward(array::ArrayContainer<ShapeType, CudaStorage<StorageScalar>> &dst,
const Src &src) const {
Expand Down
1 change: 0 additions & 1 deletion librapid/include/librapid/simd/vecOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace librapid {
concept SIMD = IsSIMD<T>::value;
} // namespace typetraits

#define REQUIRE_SIMD(TYPE) typename std::enable_if_t<typetraits::IsSIMD<TYPE>::value, int> = 0
#define IS_FLOATING(TYPE) std::is_floating_point_v<TYPE>

#define SIMD_OP_IMPL(OP) \
Expand Down

0 comments on commit cdfa7b5

Please sign in to comment.