diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index 0f03eccb..cfd9e6a8 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -943,6 +943,8 @@ namespace librapid { } } // namespace array + // template + namespace detail { template struct IsArrayType { diff --git a/librapid/include/librapid/array/linalg/transpose.hpp b/librapid/include/librapid/array/linalg/transpose.hpp index 6a303068..db4f5291 100644 --- a/librapid/include/librapid/array/linalg/transpose.hpp +++ b/librapid/include/librapid/array/linalg/transpose.hpp @@ -168,8 +168,59 @@ namespace librapid { _mm_storeu_pd(out + 1 * cols, _mm_mul_pd(tmp1Unpck, alphaVec)); } -# endif // LIBRAPID_MSVC -#endif // LIBRAPID_NATIVE_ARCH +# elif defined(LIBRAPID_NEON) +# define LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE 2 +# define LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE 4 + + template + LIBRAPID_ALWAYS_INLINE void transposeFloatKernel(float *__restrict out, + float *__restrict in, Alpha alpha, + int64_t cols) { + float32x4_t r0, r1, r2, r3; + float32x4_t t0, t1, t2, t3; + + r0 = vld1q_f32(&in[0 * cols]); + r1 = vld1q_f32(&in[1 * cols]); + r2 = vld1q_f32(&in[2 * cols]); + r3 = vld1q_f32(&in[3 * cols]); + + t0 = vzip1q_f32(r0, r1); + t1 = vzip2q_f32(r0, r1); + t2 = vzip1q_f32(r2, r3); + t3 = vzip2q_f32(r2, r3); + + r0 = vcombine_f32(vget_low_f32(t0), vget_low_f32(t2)); + r1 = vcombine_f32(vget_high_f32(t0), vget_high_f32(t2)); + r2 = vcombine_f32(vget_low_f32(t1), vget_low_f32(t3)); + r3 = vcombine_f32(vget_high_f32(t1), vget_high_f32(t3)); + + float32x4_t alphaVec = vdupq_n_f32(alpha); + + vst1q_f32(&out[0 * cols], vmulq_f32(r0, alphaVec)); + vst1q_f32(&out[1 * cols], vmulq_f32(r1, alphaVec)); + vst1q_f32(&out[2 * cols], vmulq_f32(r2, alphaVec)); + vst1q_f32(&out[3 * cols], vmulq_f32(r3, alphaVec)); + } + + template + LIBRAPID_ALWAYS_INLINE void transposeDoubleKernel(double *__restrict out, + double *__restrict in, Alpha alpha, + int64_t cols) { + float64x2_t r0, r1; + + r0 = vld1q_f64(&in[0 * cols]); + r1 = vld1q_f64(&in[1 * cols]); + + float64x2_t t0 = vzip1q_f64(r0, r1); + float64x2_t t1 = vzip2q_f64(r0, r1); + + float64x2_t alphaVec = vdupq_n_f64(alpha); + + vst1q_f64(&out[0 * cols], vmulq_f64(t0, alphaVec)); + vst1q_f64(&out[1 * cols], vmulq_f64(t1, alphaVec)); + } +# endif +#endif // LIBRAPID_NATIVE_ARCH // Ensure the kernel size is always defined, even if the above code doesn't define it #ifndef LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE @@ -310,7 +361,7 @@ namespace librapid { } } } -#endif // LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE > 0 +#endif // LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE > 0 } // namespace cpu #if defined(LIBRAPID_HAS_OPENCL) @@ -431,8 +482,8 @@ namespace librapid { rows)); } } // namespace cuda -#endif // LIBRAPID_HAS_CUDA - } // namespace detail +#endif // LIBRAPID_HAS_CUDA + } // namespace detail namespace array { template diff --git a/librapid/include/librapid/core/config.hpp b/librapid/include/librapid/core/config.hpp index 3ed090fc..d8cb419c 100644 --- a/librapid/include/librapid/core/config.hpp +++ b/librapid/include/librapid/core/config.hpp @@ -146,16 +146,17 @@ #endif // Instruction sets -#define ARCH_AVX512_2 10 -#define ARCH_AVX512 9 -#define ARCH_AVX2 8 -#define ARCH_AVX 7 -#define ARCH_SSE4_2 6 -#define ARCH_SSE4_1 5 -#define ARCH_SSSE3 4 -#define ARCH_SSE3 3 -#define ARCH_SSE2 2 -#define ARCH_SSE 1 +#define ARCH_AVX512_2 11 +#define ARCH_AVX512 10 +#define ARCH_AVX2 9 +#define ARCH_AVX 8 +#define ARCH_SSE4_2 7 +#define ARCH_SSE4_1 6 +#define ARCH_SSSE3 5 +#define ARCH_SSE3 4 +#define ARCH_SSE2 3 +#define ARCH_SSE 2 +#define ARCH_NEON 1 #define ARCH_NONE 0 // Instruction set detection @@ -485,4 +486,4 @@ namespace librapid::backend { // Code to be run *before* main() #include "preMain.hpp" -#endif // LIBRAPID_CORE_CONFIG_HPP \ No newline at end of file +#endif // LIBRAPID_CORE_CONFIG_HPP diff --git a/librapid/include/librapid/core/log.hpp b/librapid/include/librapid/core/log.hpp index 2d73eea2..4a327b92 100644 --- a/librapid/include/librapid/core/log.hpp +++ b/librapid/include/librapid/core/log.hpp @@ -20,7 +20,26 @@ namespace librapid::assert { (int)signature.length(), (int)strlen("ASSERTION FAILED")); - std::string formatted = fmt::format( + // std::string formatted = fmt::format( + // "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " + // "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " + // "{4:>{10}}]\n{5}\n", + // "ASSERTION FAILED", + // filename, + // signature, + // line, + // conditionString, + // formattedMessage, + // maxLen + 14, + // maxLen + 9, + // maxLen + 5, + // maxLen + 9, + // maxLen + 4); + + // fmt::print(fmt::fg(fmt::color::red), formatted); + // fmt::vprint(fmt::fg(fmt::color::red), formatted); + + fmt::print(fmt::fg(fmt::color::red), "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " "{4:>{10}}]\n{5}\n", @@ -35,8 +54,6 @@ namespace librapid::assert { maxLen + 5, maxLen + 9, maxLen + 4); - - fmt::print(fmt::fg(fmt::color::red), formatted); } throw RaiseType(formattedMessage); diff --git a/librapid/src/global.cpp b/librapid/src/global.cpp index 2bfc7acd..5f3caff8 100644 --- a/librapid/src/global.cpp +++ b/librapid/src/global.cpp @@ -55,7 +55,10 @@ namespace librapid { // OpenBLAS threading #if defined(LIBRAPID_BLAS_OPENBLAS) openblas_set_num_threads((int)numThreads); + + #if defined(_OPENMP) omp_set_num_threads((int)numThreads); + #endif // _OPENMP goto_set_num_threads((int)numThreads); setOpenBLASThreadsEnv((int)numThreads); diff --git a/librapid/src/openclConfigure.cpp b/librapid/src/openclConfigure.cpp index 17955463..3d551530 100644 --- a/librapid/src/openclConfigure.cpp +++ b/librapid/src/openclConfigure.cpp @@ -29,15 +29,6 @@ __kernel void testAddition(__global const float *a, __global const float *b, __g cl::Program program(context, sources); err = program.build(); - // if (err != CL_SUCCESS) { - // auto format = fmt::fg(fmt::color::red) | fmt::emphasis::bold; - // fmt::print(format, - // "Error compiling test program: {}\n", - // program.getBuildInfo(device)); - // fmt::print(format, "Error Code [{}]: {}\n", err, opencl::getOpenCLErrorString(err)); - // return false; - // } - // Check the build status cl_build_status buildStatus = program.getBuildInfo(device);