Skip to content

arm sve #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
66 changes: 66 additions & 0 deletions primitive_data/extensions/simd/arm/sve.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
---
# Target-extension description for ARM SVE. Primitive definitions select it via
# `target_extension: "sve"` and use the lookup tables declared below
# (`intrin_tp`, `intrin_tp_full`) to spell the intrinsic names.
description: "Definition of the SIMD TargetExtension sve."
vendor: "arm"
extension_name: "sve"
# CPU feature flag(s) tied to this extension (presumably matched against the
# flags reported by lscpu -- confirm against the generator).
lscpu_flags: ["sve"]
needs_arch_flags: true
# Architecture option associated with each flag listed in `lscpu_flags`.
arch_flags: {sve: "arch=armv8-a+sve"}
# Header that declares the sv* vector types and intrinsics used below.
includes: ['<arm_sve.h>']
simdT_name: "sve"
# Predicate (mask) type used for all element widths.
simdT_mask_type: "svbool_t"
# Register type, chosen from BaseType by nested TSL_DEP_TYPE(condition, A, B)
# expressions (presumably "A if condition, else B" -- confirm macro definition):
#   integral + unsigned : size 1/2/4 -> svuint8_t/svuint16_t/svuint32_t, else svuint64_t
#   integral + signed   : size 1/2/4 -> svint8_t/svint16_t/svint32_t,   else svint64_t
#   non-integral        : size 4 -> svfloat32_t, else svfloat64_t
simdT_register_type: |-
TSL_DEP_TYPE(
std::is_integral_v< BaseType >,
TSL_DEP_TYPE(
std::is_unsigned_v< BaseType >,
TSL_DEP_TYPE(
(sizeof( BaseType ) == 1),
svuint8_t,
TSL_DEP_TYPE(
(sizeof( BaseType ) == 2),
svuint16_t,
TSL_DEP_TYPE(
(sizeof( BaseType ) == 4),
svuint32_t,
svuint64_t
)
)
),
TSL_DEP_TYPE(
(sizeof( BaseType ) == 1),
svint8_t,
TSL_DEP_TYPE(
(sizeof( BaseType ) == 2),
svint16_t,
TSL_DEP_TYPE(
(sizeof( BaseType ) == 4),
svint32_t,
svint64_t
)
)
)
),
TSL_DEP_TYPE(
(sizeof( BaseType ) == 4),
svfloat32_t,
svfloat64_t
)
)
# Integral type for a mask, chosen by the element count
# VectorSizeInBits / (sizeof(BaseType)*8): 64 -> uint64_t, 32 -> uint32_t,
# 16 -> uint16_t, anything else -> uint8_t.
# NOTE(review): every count other than 64/32/16 falls through to uint8_t, which
# only holds 8 bits -- confirm that counts above 64 (e.g. 1024-bit vectors of
# 8-bit elements) cannot occur.
simdT_integral_mask_type: |-
TSL_DEP_TYPE(
(VectorSizeInBits / (sizeof(BaseType)*8) == 64),
uint64_t,
TSL_DEP_TYPE(
VectorSizeInBits / (sizeof(BaseType)*8) == 32,
uint32_t,
TSL_DEP_TYPE(
VectorSizeInBits / (sizeof(BaseType)*8) == 16,
uint16_t,
uint8_t
)
)
)
# C type -> [type letter, element width in bits]; the letter is "u" (unsigned),
# "s" (signed) or "f" (floating point).
intrin_tp: {uint8_t: ["u", 8], uint16_t: ["u", 16], uint32_t: ["u", 32], uint64_t: ["u", 64], int8_t: ["s", 8], int16_t: ["s", 16], int32_t: ["s", 32], int64_t: ["s", 64], float: ["f", 32], double: ["f", 64]}
# C type -> complete intrinsic type suffix (letter and width joined, e.g. "u32").
intrin_tp_full: {uint8_t: "u8", uint16_t: "u16", uint32_t: "u32", uint64_t: "u64", int8_t: "s8", int16_t: "s16", int32_t: "s32", int64_t: "s64", float: "f32", double: "f64"}
# Default value for VectorSizeInBits. NOTE(review): the attribute below fixes the
# register width to VectorSizeInBits, so this presumably has to agree with the
# compiler's SVE vector-length setting -- confirm against the build flags.
simdT_default_size_in_bits: 512 #can be adapted
# Attribute applied to the register type; it pins the vector length to VectorSizeInBits.
simdT_register_type_attributes: "__attribute__((arm_sve_vector_bits(VectorSizeInBits)))"
121 changes: 115 additions & 6 deletions primitive_data/primitives/binary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,18 @@ definitions:
lscpu_flags: ['neon']
note: "is it a good idea to support bitmanipulation for floats and doubles?"
implementation: "return vreinterpretq_{{ intrin_tp_full[ctype] }}_u{{ intrin_tp[ctype][1] }}(vandq_u{{ intrin_tp[ctype][1] }}( vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a),vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b)));"
#ARM - SVE
# Integer lanes: predicated AND with an all-true predicate of the element width.
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    return svand_{{ intrin_tp_full[ctype] }}_z(all_lanes, a, b);
# Floating-point lanes: AND the raw bit patterns via the same-width unsigned type.
- target_extension: "sve"
  ctype: ["float", "double"]
  lscpu_flags: ["sve"]
  implementation: |
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    const auto lhs_bits = svreinterpret_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a);
    const auto rhs_bits = svreinterpret_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b);
    const auto and_bits = svand_u{{ intrin_tp[ctype][1] }}_z(all_lanes, lhs_bits, rhs_bits);
    return svreinterpret_{{ intrin_tp_full[ctype] }}_u{{ intrin_tp[ctype][1] }}(and_bits);
#SCALAR
- target_extension: "scalar"
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
Expand Down Expand Up @@ -229,6 +241,18 @@ definitions:
ctype: ["float", "double"]
lscpu_flags: ['neon']
implementation: "return vreinterpretq_{{ intrin_tp_full[ctype] }}_u{{ intrin_tp[ctype][1] }}(vorq_u{{ intrin_tp[ctype][1] }}( vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a),vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b)));"
#ARM - SVE
# Integer lanes: predicated OR with an all-true predicate of the element width.
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    return svorr_{{ intrin_tp_full[ctype] }}_z(all_lanes, a, b);
# Floating-point lanes: OR the raw bit patterns via the same-width unsigned type.
- target_extension: "sve"
  ctype: ["float", "double"]
  lscpu_flags: ["sve"]
  implementation: |
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    const auto lhs_bits = svreinterpret_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a);
    const auto rhs_bits = svreinterpret_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b);
    const auto or_bits = svorr_u{{ intrin_tp[ctype][1] }}_z(all_lanes, lhs_bits, rhs_bits);
    return svreinterpret_{{ intrin_tp_full[ctype] }}_u{{ intrin_tp[ctype][1] }}(or_bits);
#SCALAR
- target_extension: "scalar"
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
Expand Down Expand Up @@ -357,6 +381,18 @@ definitions:
veorq_u{{ intrin_tp[ctype][1] }}(
vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a),
vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b)));
#ARM - SVE
# Integer lanes: predicated XOR with an all-true predicate of the element width.
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    return sveor_{{ intrin_tp_full[ctype] }}_z(all_lanes, a, b);
# Floating-point lanes: XOR the raw bit patterns via the same-width unsigned type.
- target_extension: "sve"
  ctype: ["float", "double"]
  lscpu_flags: ["sve"]
  implementation: |
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    const auto lhs_bits = svreinterpret_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a);
    const auto rhs_bits = svreinterpret_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b);
    const auto xor_bits = sveor_u{{ intrin_tp[ctype][1] }}_z(all_lanes, lhs_bits, rhs_bits);
    return svreinterpret_{{ intrin_tp_full[ctype] }}_u{{ intrin_tp[ctype][1] }}(xor_bits);
#SCALAR
- target_extension: "scalar"
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
Expand Down Expand Up @@ -457,6 +493,11 @@ definitions:
vshlq_n_{{ intrin_tp[ctype][1] }}(
vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a),
vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b)));
#ARM - SVE
# Shifts every lane of `data` left by the same scalar amount `shift`.
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    return svlsl_n_{{ intrin_tp_full[ctype] }}_z(all_lanes, data, shift);
#SCALAR
- target_extension: "scalar"
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
Expand Down Expand Up @@ -550,6 +591,11 @@ definitions:
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
lscpu_flags: [ 'neon' ]
implementation: "return vshlq_{{ intrin_tp_full[ctype] }}(data, shift);"
#ARM - SVE
# Shifts each lane of `data` left by the amount held in the matching lane of `shift`.
# NOTE(review): the svlsl_s*_z intrinsics take an unsigned vector as shift operand;
# confirm the type of `shift` for the signed ctypes.
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    return svlsl_{{ intrin_tp_full[ctype] }}_z(all_lanes, data, shift);
#SCALAR
- target_extension: "scalar"
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
Expand Down Expand Up @@ -745,6 +791,20 @@ definitions:
# vshrq_n_{{ intrin_tp[ctype][1] }}(
# vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(a),
# vreinterpretq_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(b)));
#ARM - SVE
# Shifts every lane of `data` right by the scalar amount `shift`.
#   signed base type, PreserveSign   : arithmetic shift (svasr_n), sign bit is replicated
#   signed base type, !PreserveSign  : logical shift (svlsr_n) on the unsigned bit pattern
#   unsigned base type               : logical shift (svlsr_n)
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |
    if constexpr ((std::is_signed_v<typename Vec::base_type>)) {
      if constexpr (PreserveSign) {
        return svasr_n_s{{ intrin_tp[ctype][1] }}_z(svptrue_b{{ intrin_tp[ctype][1] }}(), data, shift);
      } else {
        // svlsr_n_u*_z yields an unsigned vector; reinterpret it back so the returned
        // value has the signed vector type of `data` without relying on lax vector conversions.
        return svreinterpret_s{{ intrin_tp[ctype][1] }}_u{{ intrin_tp[ctype][1] }}(
          svlsr_n_u{{ intrin_tp[ctype][1] }}_z(
            svptrue_b{{ intrin_tp[ctype][1] }}(),
            svreinterpret_u{{ intrin_tp[ctype][1] }}_s{{ intrin_tp[ctype][1] }}(data),
            shift));
      }
    } else {
      return svlsr_n_u{{ intrin_tp[ctype][1] }}_z(svptrue_b{{ intrin_tp[ctype][1] }}(), data, shift);
    }
#SCALAR
- target_extension: "scalar"
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
Expand Down Expand Up @@ -985,6 +1045,20 @@ definitions:
tmp[i] >>= shift;
}
return vld1q_{{ intrin_tp_full[ctype] }}(tmp.data());
#ARM - SVE
# Shifts each lane of `data` right using the vector forms svasr / svlsr.
#   signed base type, PreserveSign   : arithmetic shift (svasr), sign bit is replicated
#   signed base type, !PreserveSign  : logical shift (svlsr) on the unsigned bit pattern
#   unsigned base type               : logical shift (svlsr)
# NOTE(review): svasr_s*_z and svlsr_u*_z take an unsigned vector as shift operand;
# confirm the type of `shift` for the signed ctypes.
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |
    if constexpr ((std::is_signed_v<typename Vec::base_type>)) {
      if constexpr (PreserveSign) {
        return svasr_s{{ intrin_tp[ctype][1] }}_z(svptrue_b{{ intrin_tp[ctype][1] }}(), data, shift);
      } else {
        // svlsr_u*_z yields an unsigned vector; reinterpret it back so the returned
        // value has the signed vector type of `data` without relying on lax vector conversions.
        return svreinterpret_s{{ intrin_tp[ctype][1] }}_u{{ intrin_tp[ctype][1] }}(
          svlsr_u{{ intrin_tp[ctype][1] }}_z(
            svptrue_b{{ intrin_tp[ctype][1] }}(),
            svreinterpret_u{{ intrin_tp[ctype][1] }}_s{{ intrin_tp[ctype][1] }}(data),
            shift));
      }
    } else {
      return svlsr_u{{ intrin_tp[ctype][1] }}_z(svptrue_b{{ intrin_tp[ctype][1] }}(), data, shift);
    }
#SCALAR
- target_extension: "scalar"
ctype: ["int8_t", "int16_t", "int32_t", "int64_t", "uint8_t", "uint16_t", "uint32_t", "uint64_t"]
Expand Down Expand Up @@ -1179,6 +1253,11 @@ definitions:
lscpu_flags: ["neon"]
implementation: |
return details::clz<typename Vec::base_type, typename Vec::offset_base_type>(data);
#ARM - SVE
# Delegates to the shared details::clz helper, parameterised with the vector's
# base type and offset base type.
- target_extension: ["sve"]
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |
    using base_t = typename Vec::base_type;
    using offset_t = typename Vec::offset_base_type;
    return details::clz<base_t, offset_t>(data);
- target_extension: ["scalar"]
ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
lscpu_flags: []
Expand Down Expand Up @@ -1643,6 +1722,11 @@ definitions:
lscpu_flags: [ ]
implementation: "return vec;"
#ARM - NEON
#ARM - SVE
# Reduces all lanes of `vec` with bitwise OR (svorv) under an all-true predicate.
- target_extension: "sve"
  ctype: ["uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    return svorv_{{ intrin_tp_full[ctype] }}(all_lanes, vec);
#INTEL - FPGA
- target_extension: ["oneAPIfpga", "oneAPIfpgaRTL"]
ctype: ["uint8_t", "int8_t", "uint16_t", "int16_t", "uint32_t", "int32_t", "float", "uint64_t", "int64_t", "double"]
Expand Down Expand Up @@ -1695,8 +1779,12 @@ testing:
allOk &= (u.i_val == v.i_val);
}
return allOk;
}
else{
}else if constexpr(std::is_same_v<typename Vec::target_extension, tsl::sve>){
storeu<Vec>(reference_result_ptr, set1<Vec>(~0));
storeu<Vec>(test_result_ptr, inv<Vec>(vec));
test_helper.synchronize();
return test_helper.validate();
}else{
storeu<Vec>(reference_result_ptr, ~vec);
storeu<Vec>(test_result_ptr, inv<Vec>(vec));
test_helper.synchronize();
Expand Down Expand Up @@ -1733,8 +1821,12 @@ testing:
allOk &= (u.i_val == v.i_val);
}
return allOk;
}
else{
}else if constexpr(std::is_same_v<typename Vec::target_extension, tsl::sve>){
storeu<Vec>(reference_result_ptr, set1<Vec>(0));
storeu<Vec>(test_result_ptr, inv<Vec>(vec));
test_helper.synchronize();
return test_helper.validate();
}else{
storeu<Vec>(reference_result_ptr, ~vec);
storeu<Vec>(test_result_ptr, inv<Vec>(vec));
test_helper.synchronize();
Expand Down Expand Up @@ -1772,8 +1864,16 @@ testing:
u.val = data[i];
allOk &= (~u.i_val == test_i_val);
}
}
else{
}else if constexpr(std::is_same_v<typename Vec::target_extension, tsl::sve>){
T data[Vec::vector_element_count()];
for(int j = i; j < i + Vec::vector_element_count(); j++){
data[j-i] = ~test_data_ptr[j];
}
storeu<Vec>(reference_result_ptr, loadu<Vec>(data));
storeu<Vec>(test_result_ptr, inv<Vec>(vec));
test_helper.synchronize();
allOk &= test_helper.validate();
}else{
storeu<Vec>(reference_result_ptr, ~vec);
storeu<Vec>(test_result_ptr, inv<Vec>(vec));
test_helper.synchronize();
Expand Down Expand Up @@ -1824,6 +1924,15 @@ definitions:
__m128i all_ones = _mm_set1_epi32(-1);
__m128i as_int = _mm_cast{{intrin_tp_full[ctype]}}_si128(vec);
return _mm_castsi128_{{intrin_tp_full[ctype]}}(_mm_xor_si128(as_int, all_ones));
#ARM - SVE
# Integer lanes: bitwise NOT of every lane under an all-true predicate.
- target_extension: "sve"
  ctype: ["uint16_t", "int16_t", "uint32_t", "int32_t", "uint64_t", "int64_t", "uint8_t", "int8_t"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    return svnot_{{ intrin_tp_full[ctype] }}_z(all_lanes, vec);
# Floating-point lanes: invert the raw bit pattern via the same-width unsigned type.
- target_extension: "sve"
  ctype: ["float", "double"]
  lscpu_flags: ["sve"]
  implementation: |-
    const svbool_t all_lanes = svptrue_b{{ intrin_tp[ctype][1] }}();
    const auto raw_bits = svreinterpret_u{{ intrin_tp[ctype][1] }}_{{ intrin_tp_full[ctype] }}(vec);
    const auto inverted_bits = svnot_u{{ intrin_tp[ctype][1] }}_z(all_lanes, raw_bits);
    return svreinterpret_{{ intrin_tp_full[ctype] }}_u{{ intrin_tp[ctype][1] }}(inverted_bits);
#SCALAR
- target_extension: "scalar"
ctype: [ "uint8_t", "uint16_t", "uint32_t", "uint64_t", "int8_t", "int16_t", "int32_t", "int64_t"]
Expand Down
Loading