Skip to content

Commit

Permalink
Merge pull request #250 from LibRapid/test
Browse files Browse the repository at this point in the history
Bug Fixes
  • Loading branch information
Pencilcaseman authored Oct 18, 2023

Verified

This commit was signed with the committer’s verified signature.
zoedberg Zoe Faltibà
2 parents f345704 + df0bab4 commit 20d4bab
Showing 13 changed files with 455 additions and 249 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -435,6 +435,9 @@ add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/xsimd")
target_compile_definitions(fmt PUBLIC FMT_HEADER_ONLY)
target_link_libraries(${module_name} PUBLIC fmt xsimd)

# Disable "parameter passing for argument of type ... changed to match ..." for xsimd
target_compile_options(xsimd INTERFACE $<$<OR:$<CXX_COMPILER_ID:GNU>,$<CXX_COMPILER_ID:Clang>>:-Wno-psabi>)

if (${LIBRAPID_USE_MULTIPREC})
# Load MPIR
find_package(MPIR QUIET)
19 changes: 8 additions & 11 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
@@ -56,7 +56,13 @@ namespace librapid {
static constexpr bool supportsArithmetic = TypeInfo<Scalar>::supportsArithmetic;
static constexpr bool supportsLogical = TypeInfo<Scalar>::supportsLogical;
static constexpr bool supportsBinary = TypeInfo<Scalar>::supportsBinary;
static constexpr bool allowVectorisation = TypeInfo<Scalar>::packetWidth > 1;
static constexpr bool allowVectorisation = []() {
if constexpr (typetraits::HasAllowVectorisation<TypeInfo<Scalar>>::value) {
return TypeInfo<Scalar>::allowVectorisation;
} else {
return TypeInfo<Scalar>::packetWidth > 1;
}
}();

#if defined(LIBRAPID_HAS_CUDA)
static constexpr cudaDataType_t CudaType = TypeInfo<Scalar>::CudaType;
@@ -67,15 +73,6 @@ namespace librapid {
static constexpr int64_t canMemcpy = false;
};

/// Evaluates as true if the input type is an ArrayContainer instance
/// \tparam T Input type
template<typename T>
struct IsArrayContainer : std::false_type {};

template<typename ShapeType, typename StorageScalar>
struct IsArrayContainer<array::ArrayContainer<ShapeType, StorageScalar>> : std::true_type {
};

LIBRAPID_DEFINE_AS_TYPE(typename StorageScalar,
array::ArrayContainer<Shape COMMA StorageScalar>);

@@ -686,7 +683,7 @@ namespace librapid {
sizeof...(Indices),
m_shape.ndim());

int dim = 0;
int dim = 0;
int64_t index = 0;
for (int64_t i : {indices...}) {
LIBRAPID_ASSERT(
38 changes: 32 additions & 6 deletions librapid/include/librapid/array/assignOps.hpp
Original file line number Diff line number Diff line change
@@ -24,11 +24,17 @@ namespace librapid {
using Function = detail::Function<descriptor::Trivial, Functor_, Args...>;
using Scalar =
typename array::ArrayContainer<ShapeType_, Storage<StorageScalar>>::Scalar;
constexpr int64_t packetWidth = typetraits::TypeInfo<Scalar>::packetWidth;
constexpr bool allowVectorisation =
typetraits::TypeInfo<
detail::Function<descriptor::Trivial, Functor_, Args...>>::allowVectorisation &&
Function::argsAreSameType;
constexpr int64_t packetWidth = []() {
if constexpr (allowVectorisation) {
return typetraits::TypeInfo<Scalar>::packetWidth;
} else {
return 1;
}
}();

const int64_t size = function.shape().size();
const int64_t vectorSize = size - (size % packetWidth);
@@ -74,13 +80,21 @@ namespace librapid {
using Scalar =
typename array::ArrayContainer<ShapeType_,
FixedStorage<StorageScalar, StorageSize...>>::Scalar;
constexpr int64_t packetWidth = typetraits::TypeInfo<Scalar>::packetWidth;
constexpr int64_t elements = ::librapid::product<StorageSize...>();
constexpr int64_t vectorSize = elements - (elements % packetWidth);

constexpr bool allowVectorisation =
typetraits::TypeInfo<
detail::Function<descriptor::Trivial, Functor_, Args...>>::allowVectorisation &&
Function::argsAreSameType;
constexpr int64_t packetWidth = []() {
if constexpr (allowVectorisation) {
return typetraits::TypeInfo<Scalar>::packetWidth;
} else {
return 1;
}
}();

constexpr int64_t elements = ::librapid::product<StorageSize...>();
constexpr int64_t vectorSize = elements - (elements % packetWidth);

// Ensure the function can actually be assigned to the array container
static_assert(
@@ -124,12 +138,18 @@ namespace librapid {
using Function = detail::Function<descriptor::Trivial, Functor_, Args...>;
using Scalar =
typename array::ArrayContainer<ShapeType_, Storage<StorageScalar>>::Scalar;
constexpr size_t packetWidth = typetraits::TypeInfo<Scalar>::packetWidth;

constexpr bool allowVectorisation =
typetraits::TypeInfo<
detail::Function<descriptor::Trivial, Functor_, Args...>>::allowVectorisation &&
Function::argsAreSameType;
constexpr int64_t packetWidth = []() {
if constexpr (allowVectorisation) {
return typetraits::TypeInfo<Scalar>::packetWidth;
} else {
return 1;
}
}();

const size_t size = function.size();
const size_t vectorSize = size - (size % packetWidth);
@@ -180,12 +200,18 @@ namespace librapid {
using Scalar =
typename array::ArrayContainer<ShapeType_,
FixedStorage<StorageScalar, StorageSize...>>::Scalar;
constexpr int64_t packetWidth = typetraits::TypeInfo<Scalar>::packetWidth;

constexpr bool allowVectorisation =
typetraits::TypeInfo<
detail::Function<descriptor::Trivial, Functor_, Args...>>::allowVectorisation &&
Function::argsAreSameType;
constexpr int64_t packetWidth = []() {
if constexpr (allowVectorisation) {
return typetraits::TypeInfo<Scalar>::packetWidth;
} else {
return 1;
}
}();

constexpr int64_t size = ::librapid::product<StorageSize...>();
constexpr int64_t vectorSize = size - (size % packetWidth);
63 changes: 53 additions & 10 deletions librapid/include/librapid/array/function.hpp
Original file line number Diff line number Diff line change
@@ -35,13 +35,60 @@ namespace librapid {
}
}

// Normally, we want to use the scalar type of the input. This said, there are a few edge
// cases where it is necessary to use the type itself.

// Default
template<typename T>
struct ScalarTypeHelper {
using Type = typename TypeInfo<std::decay_t<T>>::Scalar;
};

// Vectors
template<typename T, uint64_t NumDims>
struct ScalarTypeHelper<Vector<T, NumDims>> {
using Type = Vector<T, NumDims>;
};

// Once we have the correct scalar types, we need to check if the result is a lazy-evaluated
// function. If so, we need to extract the actual return type from the function.

// Default
template<typename T>
struct ReturnTypeHelper {
using Type = T;
};

// Binary vector operation
template<typename LHS, typename RHS, typename Op>
struct ReturnTypeHelper<vectorDetail::BinaryVecOp<LHS, RHS, Op>> {
using IntermediateType = vectorDetail::BinaryVecOp<LHS, RHS, Op>;
using Type = decltype(std::declval<IntermediateType>().eval());
};

// Unary vector operation
template<typename Val, typename Op>
struct ReturnTypeHelper<vectorDetail::UnaryVecOp<Val, Op>> {
using IntermediateType = vectorDetail::UnaryVecOp<Val, Op>;
using Type = decltype(std::declval<IntermediateType>().eval());
};

template<typename desc, typename Functor_, typename... Args>
struct TypeInfo<::librapid::detail::Function<desc, Functor_, Args...>> {
static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction;
using Scalar = decltype(std::declval<Functor_>()(
std::declval<typename TypeInfo<std::decay_t<Args>>::Scalar>()...));
using Packet = typename TypeInfo<Scalar>::Packet;
using Backend = decltype(commonBackend<Args...>());
// using Scalar = decltype(std::declval<Functor_>()(
// std::declval<typename TypeInfo<std::decay_t<Args>>::Scalar>()...));

// using Scalar = decltype(std::declval<Functor_>()(
// std::declval<typename ScalarTypeHelper<Args>::Type>()...));

using TempScalar = decltype(std::declval<Functor_>()(
std::declval<typename ScalarTypeHelper<Args>::Type>()...));

using Scalar = typename ReturnTypeHelper<TempScalar>::Type;

using Packet = typename TypeInfo<Scalar>::Packet;
using Backend = decltype(commonBackend<Args...>());
using ShapeType =
typename detail::ShapeTypeHelper<typename TypeInfo<Args>::ShapeType...>::Type;

@@ -81,16 +128,12 @@ namespace librapid {
return Packet(obj);
}

template<typename T, typename std::enable_if_t<typetraits::TypeInfo<T>::type !=
::librapid::detail::LibRapidType::Scalar,
int> = 0>
template<typename T, typename std::enable_if_t<detail::IsArrayType<T>::val, int> = 0>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t index) {
return obj.scalar(index);
}

template<typename T, typename std::enable_if_t<typetraits::TypeInfo<T>::type ==
::librapid::detail::LibRapidType::Scalar,
int> = 0>
template<typename T, typename std::enable_if_t<!detail::IsArrayType<T>::val, int> = 0>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t) {
return obj;
}
96 changes: 80 additions & 16 deletions librapid/include/librapid/array/operations.hpp
Original file line number Diff line number Diff line change
@@ -67,20 +67,37 @@
return std::get<0>(args).shape(); \
}

//#define LIBRAPID_BINARY_SHAPE_EXTRACTOR \
// template<typename First, typename Second> \
// LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShapeImpl( \
// const std::tuple<First, Second> &tup) { \
// if constexpr (TypeInfo<std::decay_t<First>>::type != detail::LibRapidType::Scalar && \
// TypeInfo<std::decay_t<Second>>::type != detail::LibRapidType::Scalar) { \
// LIBRAPID_ASSERT(std::get<0>(tup).shape() == std::get<1>(tup).shape(), \
// "Shapes must match for binary operations"); \
// return std::get<0>(tup).shape(); \
// } else if constexpr (TypeInfo<std::decay_t<First>>::type == \
// detail::LibRapidType::Scalar) { \
// return std::get<1>(tup).shape(); \
// } else { \
// return std::get<0>(tup).shape(); \
// } \
// } \

#define LIBRAPID_BINARY_SHAPE_EXTRACTOR \
template<typename First, typename Second> \
LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShapeImpl( \
const std::tuple<First, Second> &tup) { \
if constexpr (TypeInfo<std::decay_t<First>>::type != detail::LibRapidType::Scalar && \
TypeInfo<std::decay_t<Second>>::type != detail::LibRapidType::Scalar) { \
LIBRAPID_ASSERT(std::get<0>(tup).shape() == std::get<1>(tup).shape(), \
"Shapes must match for binary operations"); \
if constexpr (IsArrayType<std::decay_t<First>>::value) { \
if constexpr (IsArrayType<std::decay_t<Second>>::value) { \
LIBRAPID_ASSERT(std::get<0>(tup).shape() == std::get<1>(tup).shape(), \
"Shapes must match for binary operations"); \
return std::get<0>(tup).shape(); \
} \
return std::get<0>(tup).shape(); \
} else if constexpr (TypeInfo<std::decay_t<First>>::type == \
detail::LibRapidType::Scalar) { \
} else if constexpr (IsArrayType<std::decay_t<Second>>::value) { \
return std::get<1>(tup).shape(); \
} else { \
return std::get<0>(tup).shape(); \
} \
} \
\
@@ -523,6 +540,12 @@ namespace librapid {
} // namespace typetraits

namespace detail {
template<typename T, LibRapidType... ValidTypes>
constexpr bool isType() {
constexpr LibRapidType t = typetraits::TypeInfo<std::decay_t<T>>::type;
return ((t == ValidTypes) || ...);
}

template<typename VAL>
constexpr bool isArrayOp() {
return (typetraits::IsArrayContainer<std::decay_t<VAL>>::value ||
@@ -531,18 +554,59 @@ namespace librapid {

template<typename LHS, typename RHS>
constexpr bool isArrayOpArray() {
return (typetraits::TypeInfo<std::decay_t<LHS>>::type != LibRapidType::Scalar) &&
(typetraits::TypeInfo<std::decay_t<RHS>>::type != LibRapidType::Scalar) &&
typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
typetraits::IsLibRapidType<std::decay_t<RHS>>::value;
// return (typetraits::TypeInfo<std::decay_t<LHS>>::type != LibRapidType::Scalar) &&
// (typetraits::TypeInfo<std::decay_t<RHS>>::type != LibRapidType::Scalar) &&
// typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
// typetraits::IsLibRapidType<std::decay_t<RHS>>::value;

// ArrayContainer,
// ArrayFunction,
// GeneralArrayView

constexpr bool lhsIsValid = isType<LHS,
LibRapidType::ArrayContainer,
LibRapidType::ArrayFunction,
LibRapidType::GeneralArrayView>();

constexpr bool rhsIsValid = isType<RHS,
LibRapidType::ArrayContainer,
LibRapidType::ArrayFunction,
LibRapidType::GeneralArrayView>();

return lhsIsValid && rhsIsValid;
}

template<typename LHS, typename RHS>
constexpr bool isArrayOpWithScalar() {
return (typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
typetraits::TypeInfo<std::decay_t<RHS>>::type == LibRapidType::Scalar) ||
(typetraits::TypeInfo<std::decay_t<LHS>>::type == LibRapidType::Scalar &&
typetraits::IsLibRapidType<std::decay_t<RHS>>::value);
// // We allow operations with vectors here
// return (typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
// (
// (typetraits::TypeInfo<std::decay_t<RHS>>::type == LibRapidType::Scalar) ||
// (typetraits::TypeInfo<std::decay_t<RHS>>::type == LibRapidType::Vector)
// )
// ) ||
// (
// (
// (typetraits::TypeInfo<std::decay_t<LHS>>::type == LibRapidType::Scalar) ||
// (typetraits::TypeInfo<std::decay_t<LHS>>::type == LibRapidType::Vector)
// ) &&
// typetraits::IsLibRapidType<std::decay_t<RHS>>::value);

constexpr bool lhsIsArray = isType<LHS,
LibRapidType::ArrayContainer,
LibRapidType::ArrayFunction,
LibRapidType::GeneralArrayView>();

constexpr bool rhsIsArray = isType<RHS,
LibRapidType::ArrayContainer,
LibRapidType::ArrayFunction,
LibRapidType::GeneralArrayView>();

constexpr bool lhsIsScalar = isType<LHS, LibRapidType::Scalar, LibRapidType::Vector>();

constexpr bool rhsIsScalar = isType<RHS, LibRapidType::Scalar, LibRapidType::Vector>();

return (lhsIsArray ^ rhsIsArray) && (lhsIsScalar ^ rhsIsScalar);
}
} // namespace detail

29 changes: 7 additions & 22 deletions librapid/include/librapid/array/shape.hpp
Original file line number Diff line number Diff line change
@@ -776,34 +776,19 @@ namespace librapid {
using Type = VectorShape;
};

template<>
struct ShapeTypeHelperImpl<Shape, std::false_type> {
using Type = Shape;
template<typename NonFalseType>
struct ShapeTypeHelperImpl<NonFalseType, std::false_type> {
using Type = NonFalseType;
};

template<>
struct ShapeTypeHelperImpl<std::false_type, Shape> {
template<typename NonFalseType>
struct ShapeTypeHelperImpl<std::false_type, NonFalseType> {
using Type = Shape;
};

template<>
struct ShapeTypeHelperImpl<MatrixShape, std::false_type> {
using Type = MatrixShape;
};

template<>
struct ShapeTypeHelperImpl<std::false_type, MatrixShape> {
using Type = MatrixShape;
};

template<>
struct ShapeTypeHelperImpl<VectorShape, std::false_type> {
using Type = VectorShape;
};

template<>
struct ShapeTypeHelperImpl<std::false_type, VectorShape> {
using Type = VectorShape;
struct ShapeTypeHelperImpl<std::false_type, std::false_type> {
using Type = VectorShape; // Fastest
};

template<typename... Args>
Loading
Oops, something went wrong.

0 comments on commit 20d4bab

Please sign in to comment.