From 5ca462fc6e898825828a9c03c732f5af7b1eb45e Mon Sep 17 00:00:00 2001 From: Yang Hau Date: Fri, 12 Jan 2024 23:56:47 +0800 Subject: [PATCH] feat: Add _mm_sqrt* --- sse2rvv.h | 26 ++++++++++-- tests/impl.cpp | 107 ++++++++++++++++++++++++------------------------- 2 files changed, 75 insertions(+), 58 deletions(-) diff --git a/sse2rvv.h b/sse2rvv.h index d51c0d1..4728f0b 100644 --- a/sse2rvv.h +++ b/sse2rvv.h @@ -2412,13 +2412,31 @@ FORCE_INLINE __m128i _mm_slli_epi64(__m128i a, int imm8) { // FORCE_INLINE __m128i _mm_slli_si128 (__m128i a, int imm8) {} -// FORCE_INLINE __m128d _mm_sqrt_pd (__m128d a) {} +FORCE_INLINE __m128d _mm_sqrt_pd(__m128d a) { + vfloat64m1_t _a = vreinterpretq_m128d_f64(a); + return vreinterpretq_f64_m128d( + __riscv_vfrec7_v_f64m1(__riscv_vfrsqrt7_v_f64m1(_a, 2), 2)); +} -// FORCE_INLINE __m128 _mm_sqrt_ps (__m128 a) {} +FORCE_INLINE __m128 _mm_sqrt_ps(__m128 a) { + vfloat32m1_t _a = vreinterpretq_m128_f32(a); + return vreinterpretq_f32_m128( + __riscv_vfrec7_v_f32m1(__riscv_vfrsqrt7_v_f32m1(_a, 4), 4)); +} -// FORCE_INLINE __m128d _mm_sqrt_sd (__m128d a, __m128d b) {} +FORCE_INLINE __m128d _mm_sqrt_sd(__m128d a, __m128d b) { + vfloat64m1_t _a = vreinterpretq_m128d_f64(a); + vfloat64m1_t _b = vreinterpretq_m128d_f64(b); + vfloat64m1_t b_rnd = + __riscv_vfrec7_v_f64m1(__riscv_vfrsqrt7_v_f64m1(_b, 2), 2); + return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, b_rnd, 0, 1)); +} -// FORCE_INLINE __m128 _mm_sqrt_ss (__m128 a) {} +FORCE_INLINE __m128 _mm_sqrt_ss(__m128 a) { + vfloat32m1_t _a = vreinterpretq_m128_f32(a); + vfloat32m1_t rnd = __riscv_vfrec7_v_f32m1(__riscv_vfrsqrt7_v_f32m1(_a, 4), 4); + return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, rnd, 0, 1)); +} // FORCE_INLINE __m128i _mm_sra_epi16 (__m128i a, __m128i count) {} diff --git a/tests/impl.cpp b/tests/impl.cpp index 29f7c38..06ab316 100644 --- a/tests/impl.cpp +++ b/tests/impl.cpp @@ -3215,39 +3215,39 @@ result_t test_mm_shuffle_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { } result_t test_mm_sqrt_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { - // #ifdef ENABLE_TEST_ALL - // const float *_a = (const float *)impl.test_cases_float_pointer1; - // - // float f0 = sqrt(_a[0]); - // float f1 = sqrt(_a[1]); - // float f2 = sqrt(_a[2]); - // float f3 = sqrt(_a[3]); - // - // __m128 a = load_m128(_a); - // __m128 c = _mm_sqrt_ps(a); - // - // return validate_float_error(c, f0, f1, f2, f3, 0.000001f); - // #else +#ifdef ENABLE_TEST_ALL + const float *_a = (const float *)impl.test_cases_float_pointer1; + + float f0 = sqrt(_a[0]); + float f1 = sqrt(_a[1]); + float f2 = sqrt(_a[2]); + float f3 = sqrt(_a[3]); + + __m128 a = load_m128(_a); + __m128 c = _mm_sqrt_ps(a); + + return validate_float_error(c, f0, f1, f2, f3, 0.1f); +#else return TEST_UNIMPL; - // #endif // ENABLE_TEST_ALL +#endif // ENABLE_TEST_ALL } result_t test_mm_sqrt_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { - // #ifdef ENABLE_TEST_ALL - // const float *_a = (const float *)impl.test_cases_float_pointer1; - // - // float f0 = sqrt(_a[0]); - // float f1 = _a[1]; - // float f2 = _a[2]; - // float f3 = _a[3]; - // - // __m128 a = load_m128(_a); - // __m128 c = _mm_sqrt_ss(a); - // - // return validate_float_error(c, f0, f1, f2, f3, 0.000001f); - // #else +#ifdef ENABLE_TEST_ALL + const float *_a = (const float *)impl.test_cases_float_pointer1; + + float f0 = sqrt(_a[0]); + float f1 = _a[1]; + float f2 = _a[2]; + float f3 = _a[3]; + + __m128 a = load_m128(_a); + __m128 c = _mm_sqrt_ss(a); + + return validate_float_error(c, f0, f1, f2, f3, 0.1f); +#else return TEST_UNIMPL; - // #endif // ENABLE_TEST_ALL +#endif // ENABLE_TEST_ALL } result_t test_mm_store_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { @@ -6989,37 +6989,36 @@ result_t test_mm_slli_si128(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { } result_t test_mm_sqrt_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { - // #ifdef ENABLE_TEST_ALL - // const double *_a = (const double *)impl.test_cases_float_pointer1; - // - // double f0 = sqrt(_a[0]); - // double f1 = sqrt(_a[1]); - // - // __m128d a = load_m128d(_a); - // __m128d c = _mm_sqrt_pd(a); - // - // return validate_double_error(c, f0, f1, 1.0e-15); - // #else +#ifdef ENABLE_TEST_ALL + const double *_a = (const double *)impl.test_cases_float_pointer1; + double _c[2]; + _c[0] = sqrt(_a[0]); + _c[1] = sqrt(_a[1]); + + __m128d a = load_m128d(_a); + __m128d c = _mm_sqrt_pd(a); + return validate_double_error(c, _c[0], _c[1], 0.1f); +#else return TEST_UNIMPL; - // #endif // ENABLE_TEST_ALL +#endif // ENABLE_TEST_ALL } result_t test_mm_sqrt_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { - // #ifdef ENABLE_TEST_ALL - // const double *_a = (const double *)impl.test_cases_float_pointer1; - // const double *_b = (const double *)impl.test_cases_float_pointer2; - // - // double f0 = sqrt(_b[0]); - // double f1 = _a[1]; - // - // __m128d a = load_m128d(_a); - // __m128d b = load_m128d(_b); - // __m128d c = _mm_sqrt_sd(a, b); - // - // return validate_double_error(c, f0, f1, 1.0e-15); - // #else +#ifdef ENABLE_TEST_ALL + const double *_a = (const double *)impl.test_cases_float_pointer1; + const double *_b = (const double *)impl.test_cases_float_pointer2; + + double f0 = sqrt(_b[0]); + double f1 = _a[1]; + + __m128d a = load_m128d(_a); + __m128d b = load_m128d(_b); + __m128d c = _mm_sqrt_sd(a, b); + + return validate_double_error(c, f0, f1, 0.1f); +#else return TEST_UNIMPL; - // #endif // ENABLE_TEST_ALL +#endif // ENABLE_TEST_ALL } result_t test_mm_sra_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {