Skip to content

Commit 3898481

Browse files
committed
Add: sz_look_up_transform_neon
1 parent be6c93b commit 3898481

File tree

3 files changed

+85
-19
lines changed

3 files changed

+85
-19
lines changed

c/lib.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ static void sz_dispatch_table_init(void) {
219219
impl->copy = sz_copy_neon;
220220
impl->move = sz_move_neon;
221221
impl->fill = sz_fill_neon;
222+
impl->look_up_transform = sz_look_up_transform_neon;
222223

223224
impl->find = sz_find_neon;
224225
impl->rfind = sz_rfind_neon;

include/stringzilla/stringzilla.h

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,8 @@ SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length)
12511251
SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
12521252
/** @copydoc sz_fill */
12531253
SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value);
1254+
/** @copydoc sz_look_up_transform */
1255+
SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
12541256
/** @copydoc sz_find_byte */
12551257
SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
12561258
/** @copydoc sz_rfind_byte */
@@ -5780,13 +5782,11 @@ SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t
57805782

57815783
SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
57825784
sz_u128_vec_t a_vec, b_vec;
5783-
5784-
while (length >= 16) {
5785+
for (; length >= 16; a += 16, b += 16, length -= 16) {
57855786
a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a);
57865787
b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b);
57875788
uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16);
57885789
if (vmaxvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match
5789-
a += 16, b += 16, length -= 16;
57905790
}
57915791

57925792
// Handle remaining bytes
@@ -5795,19 +5795,27 @@ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
57955795
}
57965796

57975797
SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
5798-
sz_u128_vec_t src_vec;
5799-
5800-
while (length >= 16) {
5801-
src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source);
5802-
vst1q_u8((sz_u8_t *)target, src_vec.u8x16);
5803-
target += 16, source += 16, length -= 16;
5804-
}
5805-
5806-
// Handle remaining bytes
5798+
// In most cases the `source` and the `target` are not aligned, but we should
5799+
// at least make sure that writes don't touch many cache lines.
5800+
// NEON has instructions to load and store 64 bytes at once.
5801+
//
5802+
// sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
5803+
// sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less.
5804+
// for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source;
5805+
// length -= head_length;
5806+
// for (; length >= 64; target += 64, source += 64, length -= 64)
5807+
// vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source));
5808+
// for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source;
5809+
//
5810+
// Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time:
5811+
for (; length >= 16; target += 16, source += 16, length -= 16)
5812+
vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source));
58075813
if (length) sz_copy_serial(target, source, length);
58085814
}
58095815

58105816
SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
5817+
// When moving small buffers, using a small on-stack buffer as temporary storage is faster.
5818+
58115819
if (target < source || target >= source + length) {
58125820
// Non-overlapping, proceed forward
58135821
sz_copy_neon(target, source, length);
@@ -5843,6 +5851,56 @@ SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
58435851
if (length) sz_fill_serial(target, length, value);
58445852
}
58455853

5854+
/**
 *  @brief  Replaces every byte of `source` with `lut[byte]` and writes the result to `target`, using Arm NEON.
 *
 *  @param source  Input bytes to transform.
 *  @param length  Number of bytes to read from `source` and to write to `target`.
 *  @param lut     Look-up table; 256 consecutive bytes starting at `lut` are loaded.
 *  @param target  Output buffer receiving `length` bytes.
 *
 *  Inputs of 128 bytes or fewer are forwarded to `sz_look_up_transform_serial`. Longer inputs are split into
 *  a scalar head (until `target` is 16-byte aligned), a vectorized body of whole 16-byte blocks, and a scalar
 *  tail of the bytes past the last 16-byte boundary of `target`.
 */
SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {

    // If the input is tiny (especially smaller than the look-up table itself), we may end up paying
    // more for organizing the SIMD registers and changing the CPU state, than for the actual computation.
    if (length <= 128) {
        sz_look_up_transform_serial(source, length, lut, target);
        return;
    }

    sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less.
    sz_size_t tail_length = (sz_size_t)(target + length) % 16;    // 15 or less.

    // The vectorized span must be computed *before* the head loop counts `head_length` down to zero.
    // Subtracting `head_length` afterwards would subtract nothing and leave the loop bound correct
    // only thanks to the `>= 16` guard. Here `length > 128`, so the subtraction cannot underflow,
    // and the result is a multiple of 16, as both ends of the span are 16-byte aligned in `target`.
    sz_size_t body_length = length - head_length - tail_length;

    // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers.
    // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput.
    uint8x16x4_t lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0));
    uint8x16x4_t lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64));
    uint8x16x4_t lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128));
    uint8x16x4_t lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192));

    // XOR-ing a byte with the base offset of a 64-byte quarter maps the bytes of that quarter
    // into the [0, 63] range and pushes all other bytes out of it.
    uint8x16_t const offset_64_vec = vdupq_n_u8(0x40);
    uint8x16_t const offset_128_vec = vdupq_n_u8(0x80);
    uint8x16_t const offset_192_vec = vdupq_n_u8(0xc0);

    // Process the head with serial code
    for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source];

    // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction
    // to perform a 4-table lookup in a single instruction. That instruction yields zero for every index
    // outside of the 64-byte table, so for each input byte exactly one of the four lookups can be non-zero,
    // and OR-ing the four partial results together selects it.
    // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/
    for (; body_length >= 16; source += 16, target += 16, body_length -= 16) {
        uint8x16_t source_vec = vld1q_u8((sz_u8_t const *)source);
        uint8x16_t lookup_0_to_63_vec = vqtbl4q_u8(lut_0_to_63_vec, source_vec);
        uint8x16_t lookup_64_to_127_vec = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec, offset_64_vec));
        uint8x16_t lookup_128_to_191_vec = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec, offset_128_vec));
        uint8x16_t lookup_192_to_255_vec = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec, offset_192_vec));
        uint8x16_t blended_0_to_255_vec = vorrq_u8(vorrq_u8(lookup_0_to_63_vec, lookup_64_to_127_vec),
                                                   vorrq_u8(lookup_128_to_191_vec, lookup_192_to_255_vec));
        vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec);
    }

    // Process the tail with serial code
    for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source];
}
5903+
58465904
SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
58475905
sz_u64_t matches;
58485906
sz_u128_vec_t h_vec, n_vec, matches_vec;
@@ -6276,6 +6334,8 @@ SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr
62766334
sz_look_up_transform_avx512(source, length, lut, target);
62776335
#elif SZ_USE_X86_AVX2
62786336
sz_look_up_transform_avx2(source, length, lut, target);
6337+
#elif SZ_USE_ARM_NEON
6338+
sz_look_up_transform_neon(source, length, lut, target);
62796339
#else
62806340
sz_look_up_transform_serial(source, length, lut, target);
62816341
#endif

scripts/bench_memory.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ tracked_unary_functions_t copy_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o
7676
{"sz_copy_avx2" + suffix, wrap_sz(sz_copy_avx2)},
7777
#endif
7878
#if SZ_USE_ARM_SVE
79-
{"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve), true},
79+
{"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve)},
8080
#endif
8181
#if SZ_USE_ARM_NEON
8282
{"sz_copy_neon" + suffix, wrap_sz(sz_copy_neon)},
@@ -116,7 +116,7 @@ tracked_unary_functions_t fill_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o
116116
{"sz_fill_avx2", wrap_sz(sz_fill_avx2)},
117117
#endif
118118
#if SZ_USE_ARM_SVE
119-
{"sz_fill_sve", wrap_sz(sz_fill_sve), true},
119+
{"sz_fill_sve", wrap_sz(sz_fill_sve)},
120120
#endif
121121
#if SZ_USE_ARM_NEON
122122
{"sz_fill_neon", wrap_sz(sz_fill_neon)},
@@ -197,6 +197,9 @@ tracked_unary_functions_t transform_functions() {
197197
#endif
198198
#if SZ_USE_X86_AVX2
199199
{"sz_look_up_transform_avx2", wrap_sz(sz_look_up_transform_avx2)},
200+
#endif
201+
#if SZ_USE_ARM_NEON
202+
{"sz_look_up_transform_neon", wrap_sz(sz_look_up_transform_neon)},
200203
#endif
201204
};
202205
return result;
@@ -223,15 +226,17 @@ void bench_memory(std::vector<std::string_view> const &slices, sz_cptr_t dataset
223226
sz_ptr_t output_buffer_ptr) {
224227

225228
if (slices.size() == 0) return;
229+
(void)dataset_start_ptr;
230+
(void)output_buffer_ptr;
226231

227232
bench_memory(slices, copy_functions<true>(dataset_start_ptr, output_buffer_ptr));
228233
bench_memory(slices, copy_functions<false>(dataset_start_ptr, output_buffer_ptr));
229234
bench_memory(slices, fill_functions(dataset_start_ptr, output_buffer_ptr));
230-
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 1));
231-
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 8));
232-
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, SZ_CACHE_LINE_WIDTH));
233-
bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, max_shift_length));
234-
bench_memory(slices, transform_functions());
235+
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 1));
236+
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, 8));
237+
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, SZ_CACHE_LINE_WIDTH));
238+
// bench_memory(slices, move_functions(dataset_start_ptr, output_buffer_ptr, max_shift_length));
239+
// bench_memory(slices, transform_functions());
235240
}
236241

237242
int main(int argc, char const **argv) {

0 commit comments

Comments
 (0)