From 632d226b4b094268e19345ce489daeb2921eed6f Mon Sep 17 00:00:00 2001
From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
Date: Tue, 12 Dec 2023 03:01:45 +0000
Subject: [PATCH] Add: AVX-512 implementations for substring search

---
 src/avx2.c   |   4 +-
 src/avx512.c | 189 +++++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 161 insertions(+), 32 deletions(-)

diff --git a/src/avx2.c b/src/avx2.c
index 72fdee3a..5bd111ee 100644
--- a/src/avx2.c
+++ b/src/avx2.c
@@ -98,8 +98,8 @@ SZ_PUBLIC sz_cptr_t sz_find_3byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_
     n_parts.u8s[1] = n[1];
     n_parts.u8s[2] = n[2];
 
-    // This implementation is more complex than the `sz_find_4byte_avx2`, as we are going to
-    // match only 3 bytes within each 4-byte word.
+    // This implementation is more complex than the `sz_find_4byte_avx2`,
+    // as we are going to match only 3 bytes within each 4-byte word.
     sz_u64_parts_t mask_parts;
     mask_parts.u64 = 0;
     mask_parts.u8s[0] = mask_parts.u8s[1] = mask_parts.u8s[2] = 0xFF, mask_parts.u8s[3] = 0;
diff --git a/src/avx512.c b/src/avx512.c
index d6f38040..67b048e7 100644
--- a/src/avx512.c
+++ b/src/avx512.c
@@ -71,7 +71,6 @@ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr
 SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
     __m512i a_vec, b_vec;
     __mmask64 mask;
-    sz_size_t loaded_length;
 
 sz_equal_avx512_cycle:
     if (length < 64) {
@@ -93,47 +92,177 @@ SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length)
     }
 }
 
-SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t haystack_length, sz_cptr_t needle) {
+SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
+    __m512i h_vec, n_vec = _mm512_set1_epi8(n[0]);
+    __mmask64 mask;
+
+sz_find_byte_avx512_cycle:
+    if (h_length < 64) {
+        mask = mask_up_to(h_length);
+        h_vec = _mm512_maskz_loadu_epi8(mask, h);
+        // Reuse the same `mask` variable to store the positions of matching bytes.
+        mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec, n_vec);
+        if (mask) return h + sz_ctz64(mask);
+    }
+    else {
+        h_vec = _mm512_loadu_epi8(h);
+        mask = _mm512_cmpeq_epi8_mask(h_vec, n_vec);
+        if (mask) return h + sz_ctz64(mask);
+        h += 64, h_length -= 64;
+        if (h_length) goto sz_find_byte_avx512_cycle;
+    }
+    return NULL;
+}
 
-    __m512i needle_vec = _mm512_set1_epi8(*needle);
-    __m512i haystack_vec;
+SZ_PUBLIC sz_cptr_t sz_find_2byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
 
-    // Calculate alignment offset
-    sz_size_t unaligned_prefix_length = 64ul - ((sz_size_t)haystack & 63ul);
+    sz_u64_parts_t n_parts;
+    n_parts.u64 = 0;
+    n_parts.u8s[0] = n[0];
+    n_parts.u8s[1] = n[1];
 
-    // Handle unaligned prefix
-    if (unaligned_prefix_length > 0 && haystack_length >= unaligned_prefix_length) {
-        haystack_vec = _mm512_maskz_loadu_epi8(mask_up_to(unaligned_prefix_length), haystack);
-        __mmask64 matches = _mm512_cmpeq_epu8_mask(haystack_vec, needle_vec);
-        if (matches != 0) return haystack + sz_ctz64(matches);
+    __m512i h0_vec, h1_vec, n_vec = _mm512_set1_epi16(n_parts.u16s[0]);
+    __mmask64 mask;
+    __mmask32 matches0, matches1;
 
-        haystack += unaligned_prefix_length;
-        haystack_length -= unaligned_prefix_length;
-    }
+sz_find_2byte_avx512_cycle:
+    if (h_length < 2) { return NULL; }
+    else if (h_length < 66) {
+        // A tail of up to 65 bytes fits in two masked loads.
+        mask = mask_up_to(h_length);
+        h0_vec = _mm512_maskz_loadu_epi8(mask, h);
+        // The shifted load gets a shorter mask, to avoid reading past the end.
+        h1_vec = _mm512_maskz_loadu_epi8(mask_up_to(h_length - 1), h + 1);
+        matches0 = _mm512_mask_cmpeq_epi16_mask(mask, h0_vec, n_vec);
+        matches1 = _mm512_mask_cmpeq_epi16_mask(mask, h1_vec, n_vec);
+        if (matches0 | matches1)
+            return h + sz_ctz64(_pdep_u64(matches0, 0x5555555555555555) | //
+                                _pdep_u64(matches1, 0xAAAAAAAAAAAAAAAA));
+        return NULL;
+    }
+    else {
+        h0_vec = _mm512_loadu_epi8(h);
+        h1_vec = _mm512_loadu_epi8(h + 1);
+        matches0 = _mm512_cmpeq_epi16_mask(h0_vec, n_vec);
+        matches1 = _mm512_cmpeq_epi16_mask(h1_vec, n_vec);
+        // https://lemire.me/blog/2018/01/08/how-fast-can-you-bit-interleave-32-bit-integers/
+        if (matches0 | matches1)
+            return h + sz_ctz64(_pdep_u64(matches0, 0x5555555555555555) | //
+                                _pdep_u64(matches1, 0xAAAAAAAAAAAAAAAA));
+        h += 64, h_length -= 64;
+        goto sz_find_2byte_avx512_cycle;
+    }
+}
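+// Note on the return statements above: `matches0` and `matches1` are word-granular
+// masks, where bit `i` of `matches0` marks a match starting at byte `2 * i`, and
+// bit `i` of `matches1` marks a match starting at byte `2 * i + 1`. `_pdep_u64`
+// deposits one mask into the even bits and the other into the odd bits, producing
+// a byte-granular mask, so `sz_ctz64` yields the offset of the first match.
+// The functions below extend the same trick to four masks with a stride of 4.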
 
-    // Main aligned loop
-    while (haystack_length >= 64) {
-        haystack_vec = _mm512_load_epi32(haystack);
-        __mmask64 matches = _mm512_cmpeq_epu8_mask(haystack_vec, needle_vec);
-        if (matches != 0) return haystack + sz_ctz64(matches);
+SZ_PUBLIC sz_cptr_t sz_find_4byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
 
-        haystack += 64;
-        haystack_length -= 64;
-    }
+    sz_u64_parts_t n_parts;
+    n_parts.u64 = 0;
+    n_parts.u8s[0] = n[0];
+    n_parts.u8s[1] = n[1];
+    n_parts.u8s[2] = n[2];
+    n_parts.u8s[3] = n[3];
 
-    // Handle remaining bytes
-    if (haystack_length > 0) {
-        haystack_vec = _mm512_maskz_loadu_epi8(mask_up_to(haystack_length), haystack);
-        __mmask64 matches = _mm512_cmpeq_epu8_mask(haystack_vec, needle_vec);
-        if (matches != 0) return haystack + sz_ctz64(matches);
-    }
+    __m512i h0_vec, h1_vec, h2_vec, h3_vec, n_vec = _mm512_set1_epi32(n_parts.u32s[0]);
+    __mmask64 mask;
+    __mmask16 matches0, matches1, matches2, matches3;
+
+sz_find_4byte_avx512_cycle:
+    if (h_length < 4) { return NULL; }
+    else if (h_length < 68) {
+        // A tail of up to 67 bytes fits in four masked loads.
+        mask = mask_up_to(h_length);
+        h0_vec = _mm512_maskz_loadu_epi8(mask, h);
+        // Shifted loads get shorter masks, to avoid reading past the end.
+        h1_vec = _mm512_maskz_loadu_epi8(mask_up_to(h_length - 1), h + 1);
+        h2_vec = _mm512_maskz_loadu_epi8(mask_up_to(h_length - 2), h + 2);
+        h3_vec = _mm512_maskz_loadu_epi8(mask_up_to(h_length - 3), h + 3);
+        matches0 = _mm512_mask_cmpeq_epi32_mask(mask, h0_vec, n_vec);
+        matches1 = _mm512_mask_cmpeq_epi32_mask(mask, h1_vec, n_vec);
+        matches2 = _mm512_mask_cmpeq_epi32_mask(mask, h2_vec, n_vec);
+        matches3 = _mm512_mask_cmpeq_epi32_mask(mask, h3_vec, n_vec);
+        if (matches0 | matches1 | matches2 | matches3)
+            return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
+                                _pdep_u64(matches1, 0x2222222222222222) | //
+                                _pdep_u64(matches2, 0x4444444444444444) | //
+                                _pdep_u64(matches3, 0x8888888888888888));
+        return NULL;
+    }
+    else {
+        h0_vec = _mm512_loadu_epi8(h);
+        h1_vec = _mm512_loadu_epi8(h + 1);
+        h2_vec = _mm512_loadu_epi8(h + 2);
+        h3_vec = _mm512_loadu_epi8(h + 3);
+        matches0 = _mm512_cmpeq_epi32_mask(h0_vec, n_vec);
+        matches1 = _mm512_cmpeq_epi32_mask(h1_vec, n_vec);
+        matches2 = _mm512_cmpeq_epi32_mask(h2_vec, n_vec);
+        matches3 = _mm512_cmpeq_epi32_mask(h3_vec, n_vec);
+        if (matches0 | matches1 | matches2 | matches3)
+            return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
+                                _pdep_u64(matches1, 0x2222222222222222) | //
+                                _pdep_u64(matches2, 0x4444444444444444) | //
+                                _pdep_u64(matches3, 0x8888888888888888));
+        h += 64, h_length -= 64;
+        goto sz_find_4byte_avx512_cycle;
+    }
+}
 
-    return NULL;
-}
+SZ_PUBLIC sz_cptr_t sz_find_3byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
+
+    sz_u64_parts_t n_parts;
+    n_parts.u64 = 0;
+    n_parts.u8s[0] = n[0];
+    n_parts.u8s[1] = n[1];
+    n_parts.u8s[2] = n[2];
+
+    __m512i h0_vec, h1_vec, h2_vec, h3_vec, n_vec = _mm512_set1_epi32(n_parts.u32s[0]);
+    __mmask64 mask;
+    __mmask16 matches0, matches1, matches2, matches3;
+
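+    // Same load scheme as in `sz_find_4byte_avx512`, except every load applies
+    // the mask 0x7777777777777777, zeroing the 4th byte of every 4-byte word.
+    // The 4th byte of `n_vec` is zero as well, so a full 32-bit comparison
+    // matches exactly the three meaningful bytes.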
+sz_find_3byte_avx512_cycle:
+    if (h_length < 3) { return NULL; }
+    else if (h_length < 67) {
+        // A tail of up to 66 bytes fits in four masked loads.
+        mask = mask_up_to(h_length);
+        // This implementation is more complex than the `sz_find_4byte_avx512`,
+        // as we are going to match only 3 bytes within each 4-byte word.
+        h0_vec = _mm512_maskz_loadu_epi8(mask & 0x7777777777777777, h);
+        // Shifted loads get shorter masks, to avoid reading past the end.
+        h1_vec = _mm512_maskz_loadu_epi8(mask_up_to(h_length - 1) & 0x7777777777777777, h + 1);
+        h2_vec = _mm512_maskz_loadu_epi8(mask_up_to(h_length - 2) & 0x7777777777777777, h + 2);
+        h3_vec = _mm512_maskz_loadu_epi8(mask_up_to(h_length - 3) & 0x7777777777777777, h + 3);
+        matches0 = _mm512_mask_cmpeq_epi32_mask(mask, h0_vec, n_vec);
+        matches1 = _mm512_mask_cmpeq_epi32_mask(mask, h1_vec, n_vec);
+        matches2 = _mm512_mask_cmpeq_epi32_mask(mask, h2_vec, n_vec);
+        matches3 = _mm512_mask_cmpeq_epi32_mask(mask, h3_vec, n_vec);
+        if (matches0 | matches1 | matches2 | matches3)
+            return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
+                                _pdep_u64(matches1, 0x2222222222222222) | //
+                                _pdep_u64(matches2, 0x4444444444444444) | //
+                                _pdep_u64(matches3, 0x8888888888888888));
+        return NULL;
+    }
+    else {
+        h0_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h);
+        h1_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h + 1);
+        h2_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h + 2);
+        h3_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h + 3);
+        matches0 = _mm512_cmpeq_epi32_mask(h0_vec, n_vec);
+        matches1 = _mm512_cmpeq_epi32_mask(h1_vec, n_vec);
+        matches2 = _mm512_cmpeq_epi32_mask(h2_vec, n_vec);
+        matches3 = _mm512_cmpeq_epi32_mask(h3_vec, n_vec);
+        if (matches0 | matches1 | matches2 | matches3)
+            return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
+                                _pdep_u64(matches1, 0x2222222222222222) | //
+                                _pdep_u64(matches2, 0x4444444444444444) | //
+                                _pdep_u64(matches3, 0x8888888888888888));
+        h += 64, h_length -= 64;
+        goto sz_find_3byte_avx512_cycle;
+    }
+}
 
-SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t haystack_length, sz_cptr_t needle,
-                                   sz_size_t needle_length) {
-    return sz_find_serial(haystack, haystack_length, needle, needle_length);
+SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
+    switch (n_length) {
+    case 1: return sz_find_byte_avx512(h, h_length, n);
+    case 2: return sz_find_2byte_avx512(h, h_length, n);
+    case 3: return sz_find_3byte_avx512(h, h_length, n);
+    case 4: return sz_find_4byte_avx512(h, h_length, n);
+    default: return sz_find_serial(h, h_length, n, n_length);
+    }
 }
 
 /**
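
A quick smoke test for the new dispatch, not part of the patch: each needle length
below exercises one of the specialized kernels, and the 5-byte needle takes the
serial fallback. This is a minimal sketch; it assumes the `sz_*` declarations above
are visible through the library's public header (the `stringzilla.h` include name
is an assumption) and that the host CPU supports AVX-512 and BMI2.

    #include <assert.h>
    #include <string.h>

    #include "stringzilla.h" // assumed public header exposing the `sz_*` API

    int main(void) {
        char const *haystack = "abcabcdabcde";
        sz_size_t length = strlen(haystack);
        assert(sz_find_avx512(haystack, length, "d", 1) == haystack + 6);     // byte kernel
        assert(sz_find_avx512(haystack, length, "bc", 2) == haystack + 1);    // 2-byte kernel
        assert(sz_find_avx512(haystack, length, "bcd", 3) == haystack + 4);   // 3-byte kernel
        assert(sz_find_avx512(haystack, length, "bcde", 4) == haystack + 8);  // 4-byte kernel
        assert(sz_find_avx512(haystack, length, "abcde", 5) == haystack + 7); // serial fallback
        return 0;
    }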