Skip to content

Commit

Permalink
Merge pull request #458 from howjmay/vqdmlal_high_lane
Browse files Browse the repository at this point in the history
feat: Add vqdmlal_high_lane_[s16|s32]
  • Loading branch information
howjmay authored Jul 25, 2024
2 parents b5f1a82 + 7de1c40 commit 0cfa50b
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 10 deletions.
24 changes: 18 additions & 6 deletions neon2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -8993,15 +8993,15 @@ FORCE_INLINE uint64x2_t vmlal_high_laneq_u32(uint64x2_t a, uint32x4_t b, uint32x
__riscv_vwmaccu_vv_u64m2(__riscv_vlmul_ext_v_u64m1_u64m2(a), b_high, c_dup, 2));
}

FORCE_INLINE int32x4_t vqdmlal_lane_s16(int32x4_t a, int16x4_t b, int16x4_t c, const int __d) {
vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, __d, 4);
// Signed saturating doubling multiply-add long, by lane.
// For i in [0, 4): result[i] = sat32(a[i] + sat32(2 * b[i] * c[lane])).
//   a    - 32-bit accumulator (4 elements)
//   b    - 16-bit multiplicand (4 elements)
//   c    - 16-bit vector supplying the scalar multiplier
//   lane - index of the element of c to broadcast; expected to be in 0..3
FORCE_INLINE int32x4_t vqdmlal_lane_s16(int32x4_t a, int16x4_t b, int16x4_t c, const int lane) {
  // Broadcast c[lane] across the four active elements.
  vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, lane, 4);
  // A widened 16x16 product always fits in 32 bits, so this step cannot overflow.
  vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b, c_dup, 4));
  // Doubling overflows only for INT16_MIN * INT16_MIN; the saturating self-add clamps that case
  // to INT32_MAX instead of wrapping.
  vint32m1_t bc_mulx2 = __riscv_vsadd_vv_i32m1(bc_mul, bc_mul, 4);
  // The accumulate must saturate as well; a plain vadd would wrap on overflow.
  return __riscv_vsadd_vv_i32m1(a, bc_mulx2, 4);
}

FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, const int __d) {
vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, __d, 2);
FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, const int lane) {
vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, lane, 2);
vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b, c_dup, 2));
vint64m1_t bc_mulx2 = __riscv_vmul_vx_i64m1(bc_mul, 2, 2);
return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2);
Expand All @@ -9011,9 +9011,21 @@ FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, c

// FORCE_INLINE int64_t vqdmlals_lane_s32(int64_t a, int32_t b, int32x2_t v, const int lane);

// FORCE_INLINE int32x4_t vqdmlal_high_lane_s16(int32x4_t a, int16x8_t b, int16x4_t v, const int lane);
// Signed saturating doubling multiply-add long, upper half of b, by lane.
// For i in [0, 4): result[i] = sat32(a[i] + sat32(2 * b[i + 4] * c[lane])).
//   a    - 32-bit accumulator (4 elements)
//   b    - 16-bit multiplicand (8 elements); only elements 4..7 are used
//   c    - 16-bit vector supplying the scalar multiplier
//   lane - index of the element of c to broadcast; expected to be in 0..3
FORCE_INLINE int32x4_t vqdmlal_high_lane_s16(int32x4_t a, int16x8_t b, int16x4_t c, const int lane) {
  // Slide the upper four elements of b down to positions 0..3.
  vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8);
  // Broadcast c[lane] across the four active elements.
  vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, lane, 4);
  // A widened 16x16 product always fits in 32 bits, so this step cannot overflow.
  vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b_high, c_dup, 4));
  // Doubling overflows only for INT16_MIN * INT16_MIN; the saturating self-add clamps that case
  // to INT32_MAX instead of wrapping.
  vint32m1_t bc_mulx2 = __riscv_vsadd_vv_i32m1(bc_mul, bc_mul, 4);
  // The accumulate must saturate as well; a plain vadd would wrap on overflow.
  return __riscv_vsadd_vv_i32m1(a, bc_mulx2, 4);
}

// FORCE_INLINE int64x2_t vqdmlal_high_lane_s32(int64x2_t a, int32x4_t b, int32x2_t v, const int lane);
// Signed saturating doubling multiply-add long, upper half of b, by lane.
// For i in [0, 2): result[i] = sat64(a[i] + sat64(2 * b[i + 2] * c[lane])).
//   a    - 64-bit accumulator (2 elements)
//   b    - 32-bit multiplicand (4 elements); only elements 2..3 are used
//   c    - 32-bit vector supplying the scalar multiplier
//   lane - index of the element of c to broadcast; expected to be in 0..1
FORCE_INLINE int64x2_t vqdmlal_high_lane_s32(int64x2_t a, int32x4_t b, int32x2_t c, const int lane) {
  // Slide the upper two elements of b down to positions 0..1.
  vint32m1_t b_high = __riscv_vslidedown_vx_i32m1(b, 2, 4);
  // Broadcast c[lane] across the two active elements.
  vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, lane, 2);
  // A widened 32x32 product always fits in 64 bits, so this step cannot overflow.
  vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b_high, c_dup, 2));
  // Doubling overflows only for INT32_MIN * INT32_MIN; the saturating self-add clamps that case
  // to INT64_MAX instead of wrapping.
  vint64m1_t bc_mulx2 = __riscv_vsadd_vv_i64m1(bc_mul, bc_mul, 2);
  // The accumulate must saturate as well; a plain vadd would wrap on overflow.
  return __riscv_vsadd_vv_i64m1(a, bc_mulx2, 2);
}

// FORCE_INLINE int32x4_t vqdmlal_laneq_s16(int32x4_t a, int16x4_t b, int16x8_t v, const int lane);

Expand Down
54 changes: 52 additions & 2 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31892,9 +31892,59 @@ result_t test_vqdmlalh_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {

result_t test_vqdmlals_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }

result_t test_vqdmlal_high_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
// Checks vqdmlal_high_lane_s16 against a scalar reference built from sat_add and sat_dmull.
// Inputs come from the three integer test-case buffers on impl; the iter argument is not used.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined, otherwise TEST_SUCCESS once every
// lane-index check has passed.
result_t test_vqdmlal_high_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1;
const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2;
const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3;
int32x4_t a = vld1q_s32(_a);
int16x8_t b = vld1q_s16(_b);
int16x4_t c = vld1_s16(_c);
;
int32x4_t d;
int32_t _d[4];
// For one lane index IDX: build the four expected values from the upper half of _b (offset 4)
// and _c[IDX], run the intrinsic with the same lane index, then compare all four elements.
#define TEST_IMPL(IDX) \
for (int i = 0; i < 4; i++) { \
_d[i] = sat_add(_a[i], sat_dmull(_b[i + 4], _c[IDX])); \
} \
d = vqdmlal_high_lane_s16(a, b, c, IDX); \
CHECK_RESULT(validate_int32(d, _d[0], _d[1], _d[2], _d[3]))

// IMM_4_ITER is defined elsewhere; presumably it expands TEST_IMPL once per lane index 0..3.
IMM_4_ITER
#undef TEST_IMPL

result_t test_vqdmlal_high_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
return TEST_SUCCESS;
#else
return TEST_UNIMPL;
#endif // ENABLE_TEST_ALL
}

// Checks vqdmlal_high_lane_s32 against a scalar reference built from sat_add and sat_dmull.
// Inputs come from the three integer test-case buffers on impl; the iter argument is not used.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined, otherwise TEST_SUCCESS once every
// lane-index check has passed.
result_t test_vqdmlal_high_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifndef ENABLE_TEST_ALL
  return TEST_UNIMPL;
#else
  const int64_t *acc_in = (int64_t *)impl.test_cases_int_pointer1;
  const int32_t *mul_in = (int32_t *)impl.test_cases_int_pointer2;
  const int32_t *lane_in = (int32_t *)impl.test_cases_int_pointer3;

  int64_t expected[2];
  int64x2_t actual;

  int64x2_t acc = vld1q_s64(acc_in);
  int32x4_t mul = vld1q_s32(mul_in);
  int32x2_t lane_src = vld1_s32(lane_in);

// For one lane index IDX: the reference uses the upper half of mul_in (elements 2 and 3)
// together with lane_in[IDX]; the intrinsic is run with the same lane index and compared.
#define TEST_IMPL(IDX)                                                     \
  expected[0] = sat_add(acc_in[0], sat_dmull(mul_in[2], lane_in[IDX]));    \
  expected[1] = sat_add(acc_in[1], sat_dmull(mul_in[3], lane_in[IDX]));    \
  actual = vqdmlal_high_lane_s32(acc, mul, lane_src, IDX);                 \
  CHECK_RESULT(validate_int64(actual, expected[0], expected[1]))

  // IMM_2_ITER is defined elsewhere; presumably it expands TEST_IMPL once per lane index 0..1.
  IMM_2_ITER
#undef TEST_IMPL

  return TEST_SUCCESS;
#endif  // ENABLE_TEST_ALL
}

result_t test_vqdmlal_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }

Expand Down
4 changes: 2 additions & 2 deletions tests/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1950,8 +1950,8 @@
_(vqdmlal_lane_s32) \
/*_(vqdmlalh_lane_s16) */ \
/*_(vqdmlals_lane_s32) */ \
/*_(vqdmlal_high_lane_s16) */ \
/*_(vqdmlal_high_lane_s32) */ \
_(vqdmlal_high_lane_s16) \
_(vqdmlal_high_lane_s32) \
/*_(vqdmlal_laneq_s16) */ \
/*_(vqdmlal_laneq_s32) */ \
/*_(vqdmlalh_laneq_s16) */ \
Expand Down

0 comments on commit 0cfa50b

Please sign in to comment.