Skip to content

Commit 5459084

Browse files
committed
Fixes for array op vector:
1 parent e168bc9 commit 5459084

File tree

10 files changed

+263
-122
lines changed

10 files changed

+263
-122
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,9 @@ add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/xsimd")
435435
target_compile_definitions(fmt PUBLIC FMT_HEADER_ONLY)
436436
target_link_libraries(${module_name} PUBLIC fmt xsimd)
437437

438+
# Disable "parameter passing for argument of type ... changed to match ..." for xsimd
439+
target_compile_options(xsimd INTERFACE -Wno-psabi)
440+
438441
if (${LIBRAPID_USE_MULTIPREC})
439442
# Load MPIR
440443
find_package(MPIR QUIET)

librapid/include/librapid/array/arrayContainer.hpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,13 @@ namespace librapid {
5656
static constexpr bool supportsArithmetic = TypeInfo<Scalar>::supportsArithmetic;
5757
static constexpr bool supportsLogical = TypeInfo<Scalar>::supportsLogical;
5858
static constexpr bool supportsBinary = TypeInfo<Scalar>::supportsBinary;
59-
static constexpr bool allowVectorisation = TypeInfo<Scalar>::packetWidth > 1;
59+
static constexpr bool allowVectorisation = []() {
60+
if constexpr (typetraits::HasAllowVectorisation<TypeInfo<Scalar>>::value) {
61+
return TypeInfo<Scalar>::allowVectorisation;
62+
} else {
63+
return TypeInfo<Scalar>::packetWidth > 1;
64+
}
65+
}();
6066

6167
#if defined(LIBRAPID_HAS_CUDA)
6268
static constexpr cudaDataType_t CudaType = TypeInfo<Scalar>::CudaType;
@@ -67,15 +73,6 @@ namespace librapid {
6773
static constexpr int64_t canMemcpy = false;
6874
};
6975

70-
/// Evaluates as true if the input type is an ArrayContainer instance
71-
/// \tparam T Input type
72-
template<typename T>
73-
struct IsArrayContainer : std::false_type {};
74-
75-
template<typename ShapeType, typename StorageScalar>
76-
struct IsArrayContainer<array::ArrayContainer<ShapeType, StorageScalar>> : std::true_type {
77-
};
78-
7976
LIBRAPID_DEFINE_AS_TYPE(typename StorageScalar,
8077
array::ArrayContainer<Shape COMMA StorageScalar>);
8178

@@ -686,7 +683,7 @@ namespace librapid {
686683
sizeof...(Indices),
687684
m_shape.ndim());
688685

689-
int dim = 0;
686+
int dim = 0;
690687
int64_t index = 0;
691688
for (int64_t i : {indices...}) {
692689
LIBRAPID_ASSERT(

librapid/include/librapid/array/assignOps.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,18 @@ namespace librapid {
2424
using Function = detail::Function<descriptor::Trivial, Functor_, Args...>;
2525
using Scalar =
2626
typename array::ArrayContainer<ShapeType_, Storage<StorageScalar>>::Scalar;
27-
constexpr int64_t packetWidth = typetraits::TypeInfo<Scalar>::packetWidth;
2827
constexpr bool allowVectorisation =
2928
typetraits::TypeInfo<
3029
detail::Function<descriptor::Trivial, Functor_, Args...>>::allowVectorisation &&
3130
Function::argsAreSameType;
31+
// constexpr int64_t packetWidth = typetraits::TypeInfo<Scalar>::packetWidth;
32+
constexpr int64_t packetWidth = []() {
33+
if constexpr (allowVectorisation) {
34+
return typetraits::TypeInfo<Scalar>::packetWidth;
35+
} else {
36+
return 1;
37+
}
38+
}();
3239

3340
const int64_t size = function.shape().size();
3441
const int64_t vectorSize = size - (size % packetWidth);

librapid/include/librapid/array/function.hpp

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,60 @@ namespace librapid {
3535
}
3636
}
3737

38+
// Normally, we want to use the scalar type of the input. This said, there are a few edge
39+
// cases where it is necessary to use the type itself.
40+
41+
// Default
42+
template<typename T>
43+
struct ScalarTypeHelper {
44+
using Type = typename TypeInfo<std::decay_t<T>>::Scalar;
45+
};
46+
47+
// Vectors
48+
template<typename T, uint64_t NumDims>
49+
struct ScalarTypeHelper<Vector<T, NumDims>> {
50+
using Type = Vector<T, NumDims>;
51+
};
52+
53+
// Once we have the correct scalar types, we need to check if the result is a lazy-evaluated
54+
// function. If so, we need to extract the actual return type from the function.
55+
56+
// Default
57+
template<typename T>
58+
struct ReturnTypeHelper {
59+
using Type = T;
60+
};
61+
62+
// Binary vector operation
63+
template<typename LHS, typename RHS, typename Op>
64+
struct ReturnTypeHelper<vectorDetail::BinaryVecOp<LHS, RHS, Op>> {
65+
using IntermediateType = vectorDetail::BinaryVecOp<LHS, RHS, Op>;
66+
using Type = decltype(std::declval<IntermediateType>().eval());
67+
};
68+
69+
// Unary vector operation
70+
template<typename Val, typename Op>
71+
struct ReturnTypeHelper<vectorDetail::UnaryVecOp<Val, Op>> {
72+
using IntermediateType = vectorDetail::UnaryVecOp<Val, Op>;
73+
using Type = decltype(std::declval<IntermediateType>().eval());
74+
};
75+
3876
template<typename desc, typename Functor_, typename... Args>
3977
struct TypeInfo<::librapid::detail::Function<desc, Functor_, Args...>> {
4078
static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction;
41-
using Scalar = decltype(std::declval<Functor_>()(
42-
std::declval<typename TypeInfo<std::decay_t<Args>>::Scalar>()...));
43-
using Packet = typename TypeInfo<Scalar>::Packet;
44-
using Backend = decltype(commonBackend<Args...>());
79+
// using Scalar = decltype(std::declval<Functor_>()(
80+
// std::declval<typename TypeInfo<std::decay_t<Args>>::Scalar>()...));
81+
82+
// using Scalar = decltype(std::declval<Functor_>()(
83+
// std::declval<typename ScalarTypeHelper<Args>::Type>()...));
84+
85+
using TempScalar = decltype(std::declval<Functor_>()(
86+
std::declval<typename ScalarTypeHelper<Args>::Type>()...));
87+
88+
using Scalar = typename ReturnTypeHelper<TempScalar>::Type;
89+
90+
using Packet = typename TypeInfo<Scalar>::Packet;
91+
using Backend = decltype(commonBackend<Args...>());
4592
using ShapeType =
4693
typename detail::ShapeTypeHelper<typename TypeInfo<Args>::ShapeType...>::Type;
4794

@@ -88,9 +135,11 @@ namespace librapid {
88135
return obj.scalar(index);
89136
}
90137

91-
template<typename T, typename std::enable_if_t<typetraits::TypeInfo<T>::type ==
92-
::librapid::detail::LibRapidType::Scalar,
93-
int> = 0>
138+
template<typename T,
139+
typename std::enable_if_t<
140+
typetraits::TypeInfo<T>::type == ::librapid::detail::LibRapidType::Scalar ||
141+
typetraits::TypeInfo<T>::type == ::librapid::detail::LibRapidType::Vector,
142+
int> = 0>
94143
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t) {
95144
return obj;
96145
}

librapid/include/librapid/array/operations.hpp

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,12 @@ namespace librapid {
523523
} // namespace typetraits
524524

525525
namespace detail {
526+
template<typename T, LibRapidType... ValidTypes>
527+
constexpr bool isType() {
528+
constexpr LibRapidType t = typetraits::TypeInfo<std::decay_t<T>>::type;
529+
return ((t == ValidTypes) || ...);
530+
}
531+
526532
template<typename VAL>
527533
constexpr bool isArrayOp() {
528534
return (typetraits::IsArrayContainer<std::decay_t<VAL>>::value ||
@@ -531,18 +537,63 @@ namespace librapid {
531537

532538
template<typename LHS, typename RHS>
533539
constexpr bool isArrayOpArray() {
534-
return (typetraits::TypeInfo<std::decay_t<LHS>>::type != LibRapidType::Scalar) &&
535-
(typetraits::TypeInfo<std::decay_t<RHS>>::type != LibRapidType::Scalar) &&
536-
typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
537-
typetraits::IsLibRapidType<std::decay_t<RHS>>::value;
540+
// return (typetraits::TypeInfo<std::decay_t<LHS>>::type != LibRapidType::Scalar) &&
541+
// (typetraits::TypeInfo<std::decay_t<RHS>>::type != LibRapidType::Scalar) &&
542+
// typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
543+
// typetraits::IsLibRapidType<std::decay_t<RHS>>::value;
544+
545+
// ArrayContainer,
546+
// ArrayFunction,
547+
// GeneralArrayView
548+
549+
constexpr bool lhsIsValid = isType<LHS,
550+
LibRapidType::ArrayContainer,
551+
LibRapidType::ArrayFunction,
552+
LibRapidType::GeneralArrayView>();
553+
554+
constexpr bool rhsIsValid = isType<RHS,
555+
LibRapidType::ArrayContainer,
556+
LibRapidType::ArrayFunction,
557+
LibRapidType::GeneralArrayView>();
558+
559+
return lhsIsValid && rhsIsValid;
538560
}
539561

540562
template<typename LHS, typename RHS>
541563
constexpr bool isArrayOpWithScalar() {
542-
return (typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
543-
typetraits::TypeInfo<std::decay_t<RHS>>::type == LibRapidType::Scalar) ||
544-
(typetraits::TypeInfo<std::decay_t<LHS>>::type == LibRapidType::Scalar &&
545-
typetraits::IsLibRapidType<std::decay_t<RHS>>::value);
564+
// // We allow operations with vectors here
565+
// return (typetraits::IsLibRapidType<std::decay_t<LHS>>::value &&
566+
// (
567+
// (typetraits::TypeInfo<std::decay_t<RHS>>::type == LibRapidType::Scalar) ||
568+
// (typetraits::TypeInfo<std::decay_t<RHS>>::type == LibRapidType::Vector)
569+
// )
570+
// ) ||
571+
// (
572+
// (
573+
// (typetraits::TypeInfo<std::decay_t<LHS>>::type == LibRapidType::Scalar) ||
574+
// (typetraits::TypeInfo<std::decay_t<LHS>>::type == LibRapidType::Vector)
575+
// ) &&
576+
// typetraits::IsLibRapidType<std::decay_t<RHS>>::value);
577+
578+
constexpr bool lhsIsArray = isType<LHS,
579+
LibRapidType::ArrayContainer,
580+
LibRapidType::ArrayFunction,
581+
LibRapidType::GeneralArrayView>();
582+
583+
constexpr bool rhsIsArray = isType<RHS,
584+
LibRapidType::ArrayContainer,
585+
LibRapidType::ArrayFunction,
586+
LibRapidType::GeneralArrayView>();
587+
588+
constexpr bool lhsIsScalar = isType<LHS,
589+
LibRapidType::Scalar,
590+
LibRapidType::Vector>();
591+
592+
constexpr bool rhsIsScalar = isType<RHS,
593+
LibRapidType::Scalar,
594+
LibRapidType::Vector>();
595+
596+
return (lhsIsArray ^ rhsIsArray) && (lhsIsScalar ^ rhsIsScalar);
546597
}
547598
} // namespace detail
548599

librapid/include/librapid/array/shape.hpp

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -776,34 +776,19 @@ namespace librapid {
776776
using Type = VectorShape;
777777
};
778778

779-
template<>
780-
struct ShapeTypeHelperImpl<Shape, std::false_type> {
781-
using Type = Shape;
779+
template<typename NonFalseType>
780+
struct ShapeTypeHelperImpl<NonFalseType, std::false_type> {
781+
using Type = NonFalseType;
782782
};
783783

784-
template<>
785-
struct ShapeTypeHelperImpl<std::false_type, Shape> {
784+
template<typename NonFalseType>
785+
struct ShapeTypeHelperImpl<std::false_type, NonFalseType> {
786786
using Type = Shape;
787787
};
788788

789789
template<>
790-
struct ShapeTypeHelperImpl<MatrixShape, std::false_type> {
791-
using Type = MatrixShape;
792-
};
793-
794-
template<>
795-
struct ShapeTypeHelperImpl<std::false_type, MatrixShape> {
796-
using Type = MatrixShape;
797-
};
798-
799-
template<>
800-
struct ShapeTypeHelperImpl<VectorShape, std::false_type> {
801-
using Type = VectorShape;
802-
};
803-
804-
template<>
805-
struct ShapeTypeHelperImpl<std::false_type, VectorShape> {
806-
using Type = VectorShape;
790+
struct ShapeTypeHelperImpl<std::false_type, std::false_type> {
791+
using Type = VectorShape; // Fastest
807792
};
808793

809794
template<typename... Args>

librapid/include/librapid/core/forward.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ namespace librapid {
2828
class ArrayContainer;
2929
}
3030

31+
namespace typetraits {
32+
/// Evaluates as true if the input type is an ArrayContainer instance
33+
/// \tparam T Input type
34+
template<typename T>
35+
struct IsArrayContainer : std::false_type {};
36+
37+
template<typename ShapeType, typename StorageScalar>
38+
struct IsArrayContainer<array::ArrayContainer<ShapeType, StorageScalar>> : std::true_type {
39+
};
40+
}
41+
3142
namespace detail {
3243
/// \brief Identifies which type of function is being used
3344
namespace descriptor {

librapid/include/librapid/core/typetraits.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ namespace librapid::typetraits {
4646
std::true_type testCast(int);
4747
template<typename From, typename To>
4848
std::false_type testCast(float);
49+
50+
// Test for T::allowVectorisation (static constexpr bool)
51+
template<typename T, typename = decltype(T::allowVectorisation)>
52+
std::true_type testAllowVectorisation(int);
53+
template<typename T>
54+
std::false_type testAllowVectorisation(float);
4955
} // namespace impl
5056

5157
template<typename T, typename Index = int64_t>
@@ -63,6 +69,10 @@ namespace librapid::typetraits {
6369
// Detect whether a class can be default constructed
6470
template<class T>
6571
using TriviallyDefaultConstructible = std::is_trivially_default_constructible<T>;
72+
73+
// Detect whether a class has a static constexpr bool member called allowVectorization
74+
template<typename T>
75+
struct HasAllowVectorisation : public decltype(impl::testAllowVectorisation<T>(1)) {};
6676
} // namespace librapid::typetraits
6777

6878
#endif // LIBRAPID_CORE_TYPETRAITS_HPP

librapid/include/librapid/math/vectorForward.hpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@ namespace librapid {
1717

1818
template<typename Storage0, typename Storage1>
1919
auto vectorStorageTypeMerger() {
20-
using Scalar0 = typename typetraits::TypeInfo<Storage0>::Scalar;
21-
using Scalar1 = typename typetraits::TypeInfo<Storage1>::Scalar;
20+
using Scalar0 = typename typetraits::TypeInfo<Storage0>::Scalar;
21+
using Scalar1 = typename typetraits::TypeInfo<Storage1>::Scalar;
2222
static constexpr uint64_t packetWidth0 = typetraits::TypeInfo<Scalar0>::packetWidth;
2323
static constexpr uint64_t packetWidth1 = typetraits::TypeInfo<Scalar1>::packetWidth;
24-
if constexpr (packetWidth0 > 1 && packetWidth1 > 1) {
24+
if constexpr (typetraits::TypeInfo<Storage0>::type == detail::LibRapidType::Scalar) {
25+
return Storage1 {};
26+
} else if constexpr (typetraits::TypeInfo<Storage1>::type ==
27+
detail::LibRapidType::Scalar) {
28+
return Storage0 {};
29+
} else if constexpr (packetWidth0 > 1 && packetWidth1 > 1) {
2530
return SimdVectorStorage<typename Storage0::Scalar, Storage0::dims> {};
2631
} else {
2732
return GenericVectorStorage<typename Storage0::Scalar, Storage0::dims> {};
@@ -50,7 +55,9 @@ namespace librapid {
5055
return static_cast<Derived &>(*this);
5156
}
5257

53-
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const { return derived(); }
58+
// LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE virtual Derived eval() const {
59+
// return derived();
60+
// }
5461

5562
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE virtual IndexTypeConst
5663
operator[](int64_t index) const {
@@ -119,8 +126,16 @@ namespace librapid {
119126
LIBRAPID_ALWAYS_INLINE void assign(Vector<Scalar, N> &dst, const UnaryVecOp<Val, Op> &src);
120127
} // namespace vectorDetail
121128

122-
template<typename ScalarType, uint64_t NumDims>
123-
class Vector;
129+
namespace typetraits {
130+
LIBRAPID_DEFINE_AS_TYPE(typename ScalarType COMMA uint64_t NumDims,
131+
Vector<ScalarType COMMA NumDims>);
132+
133+
LIBRAPID_DEFINE_AS_TYPE(typename LHS COMMA typename RHS COMMA typename Op,
134+
vectorDetail::BinaryVecOp<LHS COMMA RHS COMMA Op>);
135+
136+
LIBRAPID_DEFINE_AS_TYPE(typename Val COMMA typename Op,
137+
vectorDetail::UnaryVecOp<Val COMMA Op>);
138+
} // namespace typetraits
124139
} // namespace librapid
125140

126141
#endif // LIBRAPID_MATH_VECTOR_FORWARD_HPP

0 commit comments

Comments
 (0)