Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add _mm_sqrt* #43

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -2412,13 +2412,31 @@ FORCE_INLINE __m128i _mm_slli_epi64(__m128i a, int imm8) {

// FORCE_INLINE __m128i _mm_slli_si128 (__m128i a, int imm8) {}

// FORCE_INLINE __m128d _mm_sqrt_pd (__m128d a) {}
// Compute the square root of both packed double-precision (64-bit)
// floating-point elements in a and return the two results.
//
// vfsqrt.v is used instead of chaining vfrsqrt7.v and vfrec7.v: those two
// instructions are estimate instructions (about 7 bits of precision each),
// so composing them only yields a rough approximation of sqrt(a).
FORCE_INLINE __m128d _mm_sqrt_pd(__m128d a) {
  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
  // vl = 2: operate on both 64-bit lanes.
  return vreinterpretq_f64_m128d(__riscv_vfsqrt_v_f64m1(_a, 2));
}

// FORCE_INLINE __m128 _mm_sqrt_ps (__m128 a) {}
// Compute the square root of all four packed single-precision (32-bit)
// floating-point elements in a and return the four results.
//
// vfsqrt.v is used instead of chaining vfrsqrt7.v and vfrec7.v: those two
// instructions are estimate instructions (about 7 bits of precision each),
// so composing them only yields a rough approximation of sqrt(a).
FORCE_INLINE __m128 _mm_sqrt_ps(__m128 a) {
  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
  // vl = 4: operate on all four 32-bit lanes.
  return vreinterpretq_f32_m128(__riscv_vfsqrt_v_f32m1(_a, 4));
}

// FORCE_INLINE __m128d _mm_sqrt_sd (__m128d a, __m128d b) {}
// Compute the square root of the lower double-precision (64-bit) element of b,
// place it in the lower element of the result, and take the upper element of
// the result from a.
//
// vfsqrt.v is used instead of chaining vfrsqrt7.v and vfrec7.v: those two
// instructions are estimate instructions (about 7 bits of precision each),
// so composing them only yields a rough approximation of sqrt(b[0]).
FORCE_INLINE __m128d _mm_sqrt_sd(__m128d a, __m128d b) {
  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
  vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
  // vl = 1: only lane 0 is consumed below, so the upper lane of b is not
  // processed at all.
  vfloat64m1_t b_sqrt = __riscv_vfsqrt_v_f64m1(_b, 1);
  // Slide-up by 0 with vl = 1 writes lane 0 of b_sqrt over lane 0 of _a.
  // NOTE(review): keeping the upper lane of _a relies on tail elements being
  // left undisturbed by this non-policy intrinsic - confirm, or switch to the
  // _tu variant.
  return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, b_sqrt, 0, 1));
}

// FORCE_INLINE __m128 _mm_sqrt_ss (__m128 a) {}
// Compute the square root of the lowest single-precision (32-bit) element of
// a, place it in the lowest element of the result, and take the upper three
// elements of the result from a.
//
// vfsqrt.v is used instead of chaining vfrsqrt7.v and vfrec7.v: those two
// instructions are estimate instructions (about 7 bits of precision each),
// so composing them only yields a rough approximation of sqrt(a[0]).
FORCE_INLINE __m128 _mm_sqrt_ss(__m128 a) {
  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
  // vl = 1: only lane 0 is consumed below, so the upper lanes are not
  // processed at all.
  vfloat32m1_t a_sqrt = __riscv_vfsqrt_v_f32m1(_a, 1);
  // Slide-up by 0 with vl = 1 writes lane 0 of a_sqrt over lane 0 of _a.
  // NOTE(review): keeping the upper lanes of _a relies on tail elements being
  // left undisturbed by this non-policy intrinsic - confirm, or switch to the
  // _tu variant.
  return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, a_sqrt, 0, 1));
}

// FORCE_INLINE __m128i _mm_sra_epi16 (__m128i a, __m128i count) {}

Expand Down
107 changes: 53 additions & 54 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3215,39 +3215,39 @@ result_t test_mm_shuffle_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

// Checks _mm_sqrt_ps against scalar sqrt() applied to each of the four input
// lanes. Reports TEST_UNIMPL unless ENABLE_TEST_ALL is defined. `iter` is not
// used by this test.
result_t test_mm_sqrt_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
  const float *src = (const float *)impl.test_cases_float_pointer1;

  // Scalar reference values, one per lane.
  float want0 = sqrt(src[0]);
  float want1 = sqrt(src[1]);
  float want2 = sqrt(src[2]);
  float want3 = sqrt(src[3]);

  __m128 got = _mm_sqrt_ps(load_m128(src));

  // 0.1f is the error bound handed to the validator.
  return validate_float_error(got, want0, want1, want2, want3, 0.1f);
#else
  return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

// Checks _mm_sqrt_ss: lane 0 must match scalar sqrt() of the input's lane 0,
// and lanes 1-3 must match the input's lanes 1-3. Reports TEST_UNIMPL unless
// ENABLE_TEST_ALL is defined. `iter` is not used by this test.
result_t test_mm_sqrt_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
  const float *src = (const float *)impl.test_cases_float_pointer1;

  // Only the lowest lane is square-rooted; the rest pass through.
  float want0 = sqrt(src[0]);
  float want1 = src[1];
  float want2 = src[2];
  float want3 = src[3];

  __m128 got = _mm_sqrt_ss(load_m128(src));

  // 0.1f is the error bound handed to the validator.
  return validate_float_error(got, want0, want1, want2, want3, 0.1f);
#else
  return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

result_t test_mm_store_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down Expand Up @@ -6989,37 +6989,36 @@ result_t test_mm_slli_si128(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

// Checks _mm_sqrt_pd against scalar sqrt() applied to both input lanes.
// Reports TEST_UNIMPL unless ENABLE_TEST_ALL is defined. `iter` is not used by
// this test.
result_t test_mm_sqrt_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
  // NOTE(review): the buffer is named "float" but is read as doubles here -
  // confirm this reinterpretation is intended.
  const double *src = (const double *)impl.test_cases_float_pointer1;

  // Scalar reference values, one per lane.
  double want_lo = sqrt(src[0]);
  double want_hi = sqrt(src[1]);

  __m128d got = _mm_sqrt_pd(load_m128d(src));

  // 0.1f is the error bound handed to the validator.
  return validate_double_error(got, want_lo, want_hi, 0.1f);
#else
  return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

// Checks _mm_sqrt_sd: the low lane must match scalar sqrt() of b's low lane,
// and the high lane must match a's high lane. Reports TEST_UNIMPL unless
// ENABLE_TEST_ALL is defined. `iter` is not used by this test.
result_t test_mm_sqrt_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
  // NOTE(review): the buffers are named "float" but are read as doubles here -
  // confirm this reinterpretation is intended.
  const double *src_a = (const double *)impl.test_cases_float_pointer1;
  const double *src_b = (const double *)impl.test_cases_float_pointer2;

  // Low lane comes from sqrt(b[0]); high lane passes through from a.
  double want_lo = sqrt(src_b[0]);
  double want_hi = src_a[1];

  __m128d vec_a = load_m128d(src_a);
  __m128d vec_b = load_m128d(src_b);
  __m128d got = _mm_sqrt_sd(vec_a, vec_b);

  // 0.1f is the error bound handed to the validator.
  return validate_double_error(got, want_lo, want_hi, 0.1f);
#else
  return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

result_t test_mm_sra_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down