Skip to content

Commit

Permalink
feat: Add _mm_load[h|l|r]*
Browse files Browse the repository at this point in the history
  • Loading branch information
howjmay committed Jan 20, 2024
1 parent d9f15b6 commit 57bfa56
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 80 deletions.
62 changes: 46 additions & 16 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ typedef union ALIGN_STRUCT(16) SIMDVec {
__riscv_vreinterpret_v_i32m1_f32m1(__riscv_vreinterpret_v_i16m1_i32m1(x))
#define vreinterpretq_i32_m128(x) __riscv_vreinterpret_v_i32m1_f32m1(x)
#define vreinterpretq_i64_m128(x) \
__riscv_vreinterpret_v_f64m1_f32m1(__riscv_vreinterpret_v_i64m1_f64m1(x))
__riscv_vreinterpret_v_i32m1_f32m1(__riscv_vreinterpret_v_i64m1_i32m1(x))
#define vreinterpretq_f32_m128(x) (x)
#define vreinterpretq_f64_m128(x) \
__riscv_vreinterpret_v_u32m1_f32m1(__riscv_vreinterpret_v_u64m1_u32m1( \
Expand Down Expand Up @@ -1685,22 +1685,21 @@ FORCE_INLINE __m128d _mm_load_pd(double const *mem_addr) {
}

FORCE_INLINE __m128d _mm_load_pd1(double const *mem_addr) {
double p[2] = {mem_addr[0], mem_addr[0]};
return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(p, 2));
return vreinterpretq_f64_m128d(__riscv_vfmv_v_f_f64m1(mem_addr[0], 2));
}

FORCE_INLINE __m128 _mm_load_ps(float const *mem_addr) {
return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(mem_addr, 4));
}

FORCE_INLINE __m128 _mm_load_ps1(float const *mem_addr) {
float p[4] = {mem_addr[0], mem_addr[0], mem_addr[0], mem_addr[0]};
return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(p, 4));
return vreinterpretq_f32_m128(__riscv_vfmv_v_f_f32m1(mem_addr[0], 4));
}

FORCE_INLINE __m128d _mm_load_sd(double const *mem_addr) {
double p[2] = {mem_addr[0], 0};
return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(p, 2));
vfloat64m1_t addr = __riscv_vle64_v_f64m1(mem_addr, 1);
vfloat64m1_t zeros = __riscv_vfmv_v_f_f64m1(0, 2);
return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(zeros, addr, 0, 1));
}

FORCE_INLINE __m128i _mm_load_si128(__m128i const *mem_addr) {
Expand All @@ -1709,8 +1708,9 @@ FORCE_INLINE __m128i _mm_load_si128(__m128i const *mem_addr) {
}

FORCE_INLINE __m128 _mm_load_ss(float const *mem_addr) {
float p[4] = {mem_addr[0], 0, 0, 0};
return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(p, 4));
vfloat32m1_t addr = __riscv_vle32_v_f32m1(mem_addr, 1);
vfloat32m1_t zeros = __riscv_vfmv_v_f_f32m1(0, 4);
return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(zeros, addr, 0, 1));
}

FORCE_INLINE __m128d _mm_load1_pd(double const *mem_addr) {
Expand All @@ -1723,19 +1723,49 @@ FORCE_INLINE __m128 _mm_load1_ps(float const *mem_addr) {

// FORCE_INLINE __m128d _mm_loaddup_pd (double const* mem_addr) {}

// FORCE_INLINE __m128d _mm_loadh_pd (__m128d a, double const* mem_addr) {}
FORCE_INLINE __m128d _mm_loadh_pd(__m128d a, double const *mem_addr) {
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
vfloat64m1_t addr = __riscv_vle64_v_f64m1(mem_addr, 1);
return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, addr, 1, 2));
}

// FORCE_INLINE __m128 _mm_loadh_pi (__m128 a, __m64 const* mem_addr) {}
FORCE_INLINE __m128 _mm_loadh_pi(__m128 a, __m64 const *mem_addr) {
vint64m1_t _a = vreinterpretq_m128_i64(a);
vint64m1_t addr = vreinterpretq_m64_i64(*mem_addr);
return vreinterpretq_i64_m128(__riscv_vslideup_vx_i64m1(_a, addr, 1, 2));
}

// FORCE_INLINE __m128i _mm_loadl_epi64 (__m128i const* mem_addr) {}
FORCE_INLINE __m128i _mm_loadl_epi64(__m128i const *mem_addr) {
vint64m1_t addr = vreinterpretq_m128i_i64(*mem_addr);
vint64m1_t zeros = __riscv_vmv_v_x_i64m1(0, 2);
return vreinterpretq_i64_m128i(__riscv_vslideup_vx_i64m1(addr, zeros, 1, 2));
}

// FORCE_INLINE __m128d _mm_loadl_pd (__m128d a, double const* mem_addr) {}
FORCE_INLINE __m128d _mm_loadl_pd(__m128d a, double const *mem_addr) {
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
vfloat64m1_t addr = __riscv_vle64_v_f64m1(mem_addr, 1);
return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, addr, 0, 1));
}

// FORCE_INLINE __m128 _mm_loadl_pi (__m128 a, __m64 const* mem_addr) {}
FORCE_INLINE __m128 _mm_loadl_pi(__m128 a, __m64 const *mem_addr) {
vint64m1_t _a = vreinterpretq_m128_i64(a);
vint64m1_t addr = vreinterpretq_m64_i64(*mem_addr);
return vreinterpretq_i64_m128(__riscv_vslideup_vx_i64m1(_a, addr, 0, 1));
}

// FORCE_INLINE __m128d _mm_loadr_pd (double const* mem_addr) {}
FORCE_INLINE __m128d _mm_loadr_pd(double const *mem_addr) {
vfloat64m1_t addr = __riscv_vle64_v_f64m1(mem_addr, 2);
vfloat64m1_t addr_high = __riscv_vslidedown_vx_f64m1(addr, 1, 2);
return vreinterpretq_f64_m128d(
__riscv_vslideup_vx_f64m1(addr_high, addr, 1, 2));
}

// FORCE_INLINE __m128 _mm_loadr_ps (float const* mem_addr) {}
FORCE_INLINE __m128 _mm_loadr_ps(float const *mem_addr) {
vuint32m1_t addr = __riscv_vle32_v_u32m1((uint32_t const *)mem_addr, 4);
vuint32m1_t vid = __riscv_vid_v_u32m1(4);
vuint32m1_t vid_rev = __riscv_vrsub_vx_u32m1(vid, 3, 4);
return vreinterpretq_u32_m128(__riscv_vrgather_vv_u32m1(addr, vid_rev, 4));
}

FORCE_INLINE __m128d _mm_loadu_pd(double const *mem_addr) {
return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(mem_addr, 2));
Expand Down
128 changes: 64 additions & 64 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2194,43 +2194,43 @@ result_t test_mm_load1_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_loadh_pi(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const float *p1 = impl.test_cases_float_pointer1;
// const float *p2 = impl.test_cases_float_pointer2;
// const __m64 *b = (const __m64 *)p2;
// __m128 a = _mm_load_ps(p1);
// __m128 c = _mm_loadh_pi(a, b);
//
// return validate_float(c, p1[0], p1[1], p2[0], p2[1]);
// #else
#ifdef ENABLE_TEST_ALL
const float *p1 = impl.test_cases_float_pointer1;
const float *p2 = impl.test_cases_float_pointer2;
const __m64 *b = (const __m64 *)p2;
__m128 a = _mm_load_ps(p1);
__m128 c = _mm_loadh_pi(a, b);

return validate_float(c, p1[0], p1[1], p2[0], p2[1]);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_loadl_pi(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const float *p1 = impl.test_cases_float_pointer1;
// const float *p2 = impl.test_cases_float_pointer2;
// __m128 a = _mm_load_ps(p1);
// const __m64 *b = (const __m64 *)p2;
// __m128 c = _mm_loadl_pi(a, b);
//
// return validate_float(c, p2[0], p2[1], p1[2], p1[3]);
// #else
#ifdef ENABLE_TEST_ALL
const float *p1 = impl.test_cases_float_pointer1;
const float *p2 = impl.test_cases_float_pointer2;
__m128 a = _mm_load_ps(p1);
const __m64 *b = (const __m64 *)p2;
__m128 c = _mm_loadl_pi(a, b);

return validate_float(c, p2[0], p2[1], p1[2], p1[3]);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_loadr_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const float *addr = impl.test_cases_float_pointer1;
//
// __m128 ret = _mm_loadr_ps(addr);
//
// return validate_float(ret, addr[3], addr[2], addr[1], addr[0]);
// #else
#ifdef ENABLE_TEST_ALL
const float *addr = impl.test_cases_float_pointer1;

__m128 ret = _mm_loadr_ps(addr);

return validate_float(ret, addr[3], addr[2], addr[1], addr[0]);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_loadu_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down Expand Up @@ -5664,55 +5664,55 @@ result_t test_mm_load1_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_loadh_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
// const double *addr = (const double *)impl.test_cases_float_pointer2;
//
// __m128d a = load_m128d(_a);
// __m128d ret = _mm_loadh_pd(a, addr);
//
// return validate_double(ret, _a[0], addr[0]);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;
const double *addr = (const double *)impl.test_cases_float_pointer2;

__m128d a = load_m128d(_a);
__m128d ret = _mm_loadh_pd(a, addr);

return validate_double(ret, _a[0], addr[0]);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_loadl_epi64(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const int64_t *addr = (const int64_t *)impl.test_cases_int_pointer1;
//
// __m128i ret = _mm_loadl_epi64((const __m128i *)addr);
//
// return validate_int64(ret, addr[0], 0);
// #else
#ifdef ENABLE_TEST_ALL
const int64_t *addr = (const int64_t *)impl.test_cases_int_pointer1;

__m128i ret = _mm_loadl_epi64((const __m128i *)addr);

return validate_int64(ret, addr[0], 0);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_loadl_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
// const double *addr = (const double *)impl.test_cases_float_pointer2;
//
// __m128d a = load_m128d(_a);
// __m128d ret = _mm_loadl_pd(a, addr);
//
// return validate_double(ret, addr[0], _a[1]);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;
const double *addr = (const double *)impl.test_cases_float_pointer2;

__m128d a = load_m128d(_a);
__m128d ret = _mm_loadl_pd(a, addr);

return validate_double(ret, addr[0], _a[1]);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_loadr_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *addr = (const double *)impl.test_cases_float_pointer1;
//
// __m128d ret = _mm_loadr_pd(addr);
//
// return validate_double(ret, addr[1], addr[0]);
// #else
#ifdef ENABLE_TEST_ALL
const double *addr = (const double *)impl.test_cases_float_pointer1;

__m128d ret = _mm_loadr_pd(addr);

return validate_double(ret, addr[1], addr[0]);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_loadu_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down

0 comments on commit 57bfa56

Please sign in to comment.