From f10fe09c3454aea8469c80111170500beba70d88 Mon Sep 17 00:00:00 2001
From: Yang Hau
Date: Sat, 27 Jan 2024 01:43:23 +0800
Subject: [PATCH] feat: Add _mm_[ceil|floor]_[pd|ps|sd|ss]

---
 sse2rvv.h      |  97 ++++++++++++++++++++++++---
 tests/impl.cpp | 178 ++++++++++++++++++++++++-------------------------
 2 files changed, 178 insertions(+), 97 deletions(-)

diff --git a/sse2rvv.h b/sse2rvv.h
index 10ba885..be3e0dc 100644
--- a/sse2rvv.h
+++ b/sse2rvv.h
@@ -45,6 +45,7 @@
 #define _sse2rvv_const const
 #endif
 
+#include <math.h>
 #include <riscv_vector.h>
 #include <stdint.h>
 #include <stdlib.h>
@@ -560,13 +561,53 @@ FORCE_INLINE __m128 _mm_castsi128_ps(__m128i a) {
   return __riscv_vreinterpret_v_i32m1_f32m1(a);
 }
 
-// FORCE_INLINE __m128d _mm_ceil_pd (__m128d a) {}
+FORCE_INLINE __m128d _mm_ceil_pd(__m128d a) {
+  // FIXME riscv round doesn't work; ceil each lane in scalar code instead
+  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
+  double arr[2];
+  const int len = 2;
+  __riscv_vse64_v_f64m1(arr, _a, len);
+  for (int i = 0; i < len; i++) {
+    arr[i] = ceil(arr[i]);
+  }
+  return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(arr, len));
+}
 
-// FORCE_INLINE __m128 _mm_ceil_ps (__m128 a) {}
+FORCE_INLINE __m128 _mm_ceil_ps(__m128 a) {
+  // FIXME riscv round doesn't work; ceil each lane in scalar code instead
+  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
+  float arr[4];
+  const int len = 4;
+  __riscv_vse32_v_f32m1(arr, _a, len);
+  for (int i = 0; i < len; i++) {
+    arr[i] = ceilf(arr[i]);
+  }
+  return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(arr, len));
+}
 
-// FORCE_INLINE __m128d _mm_ceil_sd (__m128d a, __m128d b) {}
+FORCE_INLINE __m128d _mm_ceil_sd(__m128d a, __m128d b) {
+  // FIXME riscv round doesn't work; ceil lane 0 of b in scalar code instead
+  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
+  vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
+  double arr[2];
+  const int len = 2;
+  __riscv_vse64_v_f64m1(arr, _b, len);
+  arr[0] = ceil(arr[0]);
+  vfloat64m1_t _arr = __riscv_vle64_v_f64m1(arr, 1);
+  return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, _arr, 0, 1));
+}
 
-// FORCE_INLINE __m128 _mm_ceil_ss (__m128 a, __m128 b) {}
+FORCE_INLINE __m128 _mm_ceil_ss(__m128 a, __m128 b) {
+  // FIXME riscv round doesn't work; ceil lane 0 of b in scalar code instead
+  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
+  vfloat32m1_t _b = vreinterpretq_m128_f32(b);
+  float arr[4];
+  const int len = 4;
+  __riscv_vse32_v_f32m1(arr, _b, len);
+  arr[0] = ceilf(arr[0]);
+  vfloat32m1_t _arr = __riscv_vle32_v_f32m1(arr, 1);
+  return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, _arr, 0, 1));
+}
 
 // FORCE_INLINE void _mm_clflush (void const* p) {}
 
@@ -1460,13 +1501,53 @@ FORCE_INLINE int _mm_extract_ps(__m128 a, const int imm8) {
   return (int)__riscv_vmv_x_s_i32m1_i32(a_s);
 }
 
-// FORCE_INLINE __m128d _mm_floor_pd (__m128d a) {}
+FORCE_INLINE __m128d _mm_floor_pd(__m128d a) {
+  // FIXME riscv round doesn't work; floor each lane in scalar code instead
+  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
+  double arr[2];
+  const int len = 2;
+  __riscv_vse64_v_f64m1(arr, _a, len);
+  for (int i = 0; i < len; i++) {
+    arr[i] = floor(arr[i]);
+  }
+  return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(arr, len));
+}
 
-// FORCE_INLINE __m128 _mm_floor_ps (__m128 a) {}
+FORCE_INLINE __m128 _mm_floor_ps(__m128 a) {
+  // FIXME riscv round doesn't work; floor each lane in scalar code instead
+  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
+  float arr[4];
+  const int len = 4;
+  __riscv_vse32_v_f32m1(arr, _a, len);
+  for (int i = 0; i < len; i++) {
+    arr[i] = floorf(arr[i]);
+  }
+  return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(arr, len));
+}
 
-// FORCE_INLINE __m128d _mm_floor_sd (__m128d a, __m128d b) {}
+FORCE_INLINE __m128d _mm_floor_sd(__m128d a, __m128d b) {
+  // FIXME riscv round doesn't work; floor lane 0 of b in scalar code instead
+  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
+  vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
+  double arr[2];
+  const int len = 2;
+  __riscv_vse64_v_f64m1(arr, _b, len);
+  arr[0] = floor(arr[0]);
+  vfloat64m1_t _arr = __riscv_vle64_v_f64m1(arr, 1);
+  return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, _arr, 0, 1));
+}
 
-// FORCE_INLINE __m128 _mm_floor_ss (__m128 a, __m128 b) {}
+FORCE_INLINE __m128 _mm_floor_ss(__m128 a, __m128 b) {
+  // FIXME riscv round doesn't work; floor lane 0 of b in scalar code instead
+  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
+  vfloat32m1_t _b = vreinterpretq_m128_f32(b);
+  float arr[4];
+  const int len = 4;
+  __riscv_vse32_v_f32m1(arr, _b, len);
+  arr[0] = floorf(arr[0]);
+  vfloat32m1_t _arr = __riscv_vle32_v_f32m1(arr, 1);
+  return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, _arr, 0, 1));
+}
 
 FORCE_INLINE void _mm_free(void *mem_addr) { free(mem_addr); }
 
diff --git a/tests/impl.cpp b/tests/impl.cpp
index 2e4f5f2..cae4460 100644
--- a/tests/impl.cpp
+++ b/tests/impl.cpp
@@ -9438,53 +9438,53 @@ result_t test_mm_blendv_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
 }
 
 result_t test_mm_ceil_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  // const double *_a = (const double *)impl.test_cases_float_pointer1;
-  //
-  // double dx = ceil(_a[0]);
-  // double dy = ceil(_a[1]);
-  //
-  // __m128d a = load_m128d(_a);
-  // __m128d ret = _mm_ceil_pd(a);
-  //
-  // return validate_double(ret, dx, dy);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+
+  double dx = ceil(_a[0]);
+  double dy = ceil(_a[1]);
+
+  __m128d a = load_m128d(_a);
+  __m128d ret = _mm_ceil_pd(a);
+
+  return validate_double(ret, dx, dy);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_ceil_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  // const float *_a = impl.test_cases_float_pointer1;
-  // float dx = ceilf(_a[0]);
-  // float dy = ceilf(_a[1]);
-  // float dz = ceilf(_a[2]);
-  // float dw = ceilf(_a[3]);
-  //
-  // __m128 a = _mm_load_ps(_a);
-  // __m128 c = _mm_ceil_ps(a);
-  // return validate_float(c, dx, dy, dz, dw);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const float *_a = impl.test_cases_float_pointer1;
+  float dx = ceilf(_a[0]);
+  float dy = ceilf(_a[1]);
+  float dz = ceilf(_a[2]);
+  float dw = ceilf(_a[3]);
+
+  __m128 a = _mm_load_ps(_a);
+  __m128 c = _mm_ceil_ps(a);
+  return validate_float(c, dx, dy, dz, dw);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_ceil_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  // const double *_a = (const double *)impl.test_cases_float_pointer1;
-  // const double *_b = (const double *)impl.test_cases_float_pointer2;
-  //
-  // double dx = ceil(_b[0]);
-  // double dy = _a[1];
-  //
-  // __m128d a = load_m128d(_a);
-  // __m128d b = load_m128d(_b);
-  // __m128d ret = _mm_ceil_sd(a, b);
-  //
-  // return validate_double(ret, dx, dy);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+  const double *_b = (const double *)impl.test_cases_float_pointer2;
+
+  double dx = ceil(_b[0]);
+  double dy = _a[1];
+
+  __m128d a = load_m128d(_a);
+  __m128d b = load_m128d(_b);
+  __m128d ret = _mm_ceil_sd(a, b);
+
+  return validate_double(ret, dx, dy);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_ceil_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
@@ -9896,70 +9896,70 @@ result_t test_mm_extract_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
 }
 
 result_t test_mm_floor_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  // const double *_a = (const double *)impl.test_cases_float_pointer1;
-  //
-  // double dx = floor(_a[0]);
-  // double dy = floor(_a[1]);
-  //
-  // __m128d a = load_m128d(_a);
-  // __m128d ret = _mm_floor_pd(a);
-  //
-  // return validate_double(ret, dx, dy);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+
+  double dx = floor(_a[0]);
+  double dy = floor(_a[1]);
+
+  __m128d a = load_m128d(_a);
+  __m128d ret = _mm_floor_pd(a);
+
+  return validate_double(ret, dx, dy);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_floor_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  // const float *_a = impl.test_cases_float_pointer1;
-  // float dx = floorf(_a[0]);
-  // float dy = floorf(_a[1]);
-  // float dz = floorf(_a[2]);
-  // float dw = floorf(_a[3]);
-  //
-  // __m128 a = load_m128(_a);
-  // __m128 c = _mm_floor_ps(a);
-  // return validate_float(c, dx, dy, dz, dw);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const float *_a = impl.test_cases_float_pointer1;
+  float dx = floorf(_a[0]);
+  float dy = floorf(_a[1]);
+  float dz = floorf(_a[2]);
+  float dw = floorf(_a[3]);
+
+  __m128 a = load_m128(_a);
+  __m128 c = _mm_floor_ps(a);
+  return validate_float(c, dx, dy, dz, dw);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_floor_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  // const double *_a = (const double *)impl.test_cases_float_pointer1;
-  // const double *_b = (const double *)impl.test_cases_float_pointer2;
-  //
-  // double dx = floor(_b[0]);
-  // double dy = _a[1];
-  //
-  // __m128d a = load_m128d(_a);
-  // __m128d b = load_m128d(_b);
-  // __m128d ret = _mm_floor_sd(a, b);
-  //
-  // return validate_double(ret, dx, dy);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+  const double *_b = (const double *)impl.test_cases_float_pointer2;
+
+  double dx = floor(_b[0]);
+  double dy = _a[1];
+
+  __m128d a = load_m128d(_a);
+  __m128d b = load_m128d(_b);
+  __m128d ret = _mm_floor_sd(a, b);
+
+  return validate_double(ret, dx, dy);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_floor_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  // const float *_a = impl.test_cases_float_pointer1;
-  // const float *_b = impl.test_cases_float_pointer1;
-  //
-  // float f0 = floorf(_b[0]);
-  //
-  // __m128 a = load_m128(_a);
-  // __m128 b = load_m128(_b);
-  // __m128 c = _mm_floor_ss(a, b);
-  //
-  // return validate_float(c, f0, _a[1], _a[2], _a[3]);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const float *_a = impl.test_cases_float_pointer1;
+  const float *_b = impl.test_cases_float_pointer2;
+
+  float f0 = floorf(_b[0]);
+
+  __m128 a = load_m128(_a);
+  __m128 b = load_m128(_b);
+  __m128 c = _mm_floor_ss(a, b);
+
+  return validate_float(c, f0, _a[1], _a[2], _a[3]);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_insert_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {