Add: AVX-512 implementations for substring search
ashvardanian committed Dec 12, 2023
1 parent 66cd1cf commit 632d226
Showing 2 changed files with 161 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/avx2.c
@@ -98,8 +98,8 @@ SZ_PUBLIC sz_cptr_t sz_find_3byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_
n_parts.u8s[1] = n[1];
n_parts.u8s[2] = n[2];

-// This implementation is more complex than the `sz_find_4byte_avx2`, as we are going to
-// match only 3 bytes within each 4-byte word.
// This implementation is more complex than the `sz_find_4byte_avx2`,
// as we are going to match only 3 bytes within each 4-byte word.
sz_u64_parts_t mask_parts;
mask_parts.u64 = 0;
mask_parts.u8s[0] = mask_parts.u8s[1] = mask_parts.u8s[2] = 0xFF, mask_parts.u8s[3] = 0;
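// Here `mask_parts.u64` becomes 0x0000000000FFFFFF (the `u64` field was zero-initialized above):
// the low three bytes of a 4-byte word are kept and the fourth is dropped, presumably so the
// 32-bit comparisons further down can ignore the byte that is not part of the 3-byte needle.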
189 changes: 159 additions & 30 deletions src/avx512.c
@@ -71,7 +71,6 @@ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr
SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
__m512i a_vec, b_vec;
__mmask64 mask;
-sz_size_t loaded_length;

sz_equal_avx512_cycle:
if (length < 64) {
@@ -93,47 +92,177 @@ SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length)
}
}

-SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t haystack_length, sz_cptr_t needle) {
SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
__m512i h_vec, n_vec = _mm512_set1_epi8(n[0]);
__mmask64 mask;

sz_find_byte_avx512_cycle:
if (h_length < 64) {
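// Tail of the haystack: assuming `mask_up_to(h_length)` sets the lowest `h_length` bits of the
// k-mask, the masked load below only touches bytes inside the haystack, and the masked compare
// is restricted to the same range.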
mask = mask_up_to(h_length);
h_vec = _mm512_maskz_loadu_epi8(mask, h);
// Reuse the same `mask` variable to mark the positions where the needle byte matches
mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec, n_vec);
if (mask) return h + sz_ctz64(mask);
}
else {
h_vec = _mm512_loadu_epi8(h);
mask = _mm512_cmpeq_epi8_mask(h_vec, n_vec);
if (mask) return h + sz_ctz64(mask);
h += 64, h_length -= 64;
if (h_length) goto sz_find_byte_avx512_cycle;
}
return NULL;
}

-__m512i needle_vec = _mm512_set1_epi8(*needle);
-__m512i haystack_vec;
SZ_PUBLIC sz_cptr_t sz_find_2byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

-// Calculate alignment offset
-sz_size_t unaligned_prefix_length = 64ul - ((sz_size_t)haystack & 63ul);
sz_u64_parts_t n_parts;
n_parts.u64 = 0;
n_parts.u8s[0] = n[0];
n_parts.u8s[1] = n[1];

-// Handle unaligned prefix
-if (unaligned_prefix_length > 0 && haystack_length >= unaligned_prefix_length) {
-haystack_vec = _mm512_maskz_loadu_epi8(mask_up_to(unaligned_prefix_length), haystack);
-__mmask64 matches = _mm512_cmpeq_epu8_mask(haystack_vec, needle_vec);
-if (matches != 0) return haystack + sz_ctz64(matches);
__m512i h0_vec, h1_vec, n_vec = _mm512_set1_epi16(n_parts.u16s[0]);
__mmask64 mask;
__mmask32 matches0, matches1;

-haystack += unaligned_prefix_length;
-haystack_length -= unaligned_prefix_length;
sz_find_2byte_avx512_cycle:
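// The two loads in each branch are offset by one byte: `h0_vec` starts at `h` and its 16-bit
// lanes catch matches at even offsets, `h1_vec` starts at `h + 1` and catches matches at odd
// offsets, since a 16-bit equality compare only sees the aligned 2-byte lanes of each register.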
if (h_length < 2) { return NULL; }
else if (h_length < 66) {
mask = mask_up_to(h_length);
h0_vec = _mm512_maskz_loadu_epi8(mask, h);
h1_vec = _mm512_maskz_loadu_epi8(mask, h + 1);
matches0 = _mm512_mask_cmpeq_epi16_mask(mask, h0_vec, n_vec);
matches1 = _mm512_mask_cmpeq_epi16_mask(mask, h1_vec, n_vec);
if (matches0 | matches1)
return h + sz_ctz64(_pdep_u64(matches0, 0x5555555555555555) | //
_pdep_u64(matches1, 0xAAAAAAAAAAAAAAAA));
return NULL;
}
else {
h0_vec = _mm512_loadu_epi8(h);
h1_vec = _mm512_loadu_epi8(h + 1);
matches0 = _mm512_cmpeq_epi16_mask(h0_vec, n_vec);
matches1 = _mm512_cmpeq_epi16_mask(h1_vec, n_vec);
// https://lemire.me/blog/2018/01/08/how-fast-can-you-bit-interleave-32-bit-integers/
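// `_pdep_u64` deposits the bits of `matches0` into the even bit positions (the 0x5555... pattern)
// and the bits of `matches1` into the odd positions (the 0xAAAA... pattern), so bit k of the
// combined word is set exactly when a 2-byte match starts at byte offset k from `h`, and
// `sz_ctz64` returns the offset of the leftmost match.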
if (matches0 | matches1)
return h + sz_ctz64(_pdep_u64(matches0, 0x5555555555555555) | //
_pdep_u64(matches1, 0xAAAAAAAAAAAAAAAA));
h += 64, h_length -= 64;
goto sz_find_2byte_avx512_cycle;
}
}

-// Main aligned loop
-while (haystack_length >= 64) {
-haystack_vec = _mm512_load_epi32(haystack);
-__mmask64 matches = _mm512_cmpeq_epu8_mask(haystack_vec, needle_vec);
-if (matches != 0) return haystack + sz_ctz64(matches);
SZ_PUBLIC sz_cptr_t sz_find_4byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

-haystack += 64;
-haystack_length -= 64;
-}
sz_u64_parts_t n_parts;
n_parts.u64 = 0;
n_parts.u8s[0] = n[0];
n_parts.u8s[1] = n[1];
n_parts.u8s[2] = n[2];
n_parts.u8s[3] = n[3];

-// Handle remaining bytes
-if (haystack_length > 0) {
-haystack_vec = _mm512_maskz_loadu_epi8(mask_up_to(haystack_length), haystack);
-__mmask64 matches = _mm512_cmpeq_epu8_mask(haystack_vec, needle_vec);
-if (matches != 0) return haystack + sz_ctz64(matches);
__m512i h0_vec, h1_vec, h2_vec, h3_vec, n_vec = _mm512_set1_epi32(n_parts.u32s[0]);
__mmask64 mask;
__mmask16 matches0, matches1, matches2, matches3;

sz_find_4byte_avx512_cycle:
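// Same scheme as the 2-byte search, with four views shifted by 0..3 bytes so the aligned 32-bit
// compares cover every possible start offset of the 4-byte needle. The 0x1111.../0x2222.../0x4444.../0x8888...
// deposit masks then place each view's match bits at positions congruent to 0, 1, 2 and 3 modulo 4,
// so `sz_ctz64` of their union is again the byte offset of the first match.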
if (h_length < 4) { return NULL; }
else if (h_length < 68) {
mask = mask_up_to(h_length);
h0_vec = _mm512_maskz_loadu_epi8(mask, h);
h1_vec = _mm512_maskz_loadu_epi8(mask, h + 1);
h2_vec = _mm512_maskz_loadu_epi8(mask, h + 2);
h3_vec = _mm512_maskz_loadu_epi8(mask, h + 3);
matches0 = _mm512_mask_cmpeq_epi32_mask(mask, h0_vec, n_vec);
matches1 = _mm512_mask_cmpeq_epi32_mask(mask, h1_vec, n_vec);
matches2 = _mm512_mask_cmpeq_epi32_mask(mask, h2_vec, n_vec);
matches3 = _mm512_mask_cmpeq_epi32_mask(mask, h3_vec, n_vec);
if (matches0 | matches1 | matches2 | matches3)
return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
_pdep_u64(matches1, 0x2222222222222222) | //
_pdep_u64(matches2, 0x4444444444444444) | //
_pdep_u64(matches3, 0x8888888888888888));
return NULL;
}
else {
h0_vec = _mm512_loadu_epi8(h);
h1_vec = _mm512_loadu_epi8(h + 1);
h2_vec = _mm512_loadu_epi8(h + 2);
h3_vec = _mm512_loadu_epi8(h + 3);
matches0 = _mm512_cmpeq_epi32_mask(h0_vec, n_vec);
matches1 = _mm512_cmpeq_epi32_mask(h1_vec, n_vec);
matches2 = _mm512_cmpeq_epi32_mask(h2_vec, n_vec);
matches3 = _mm512_cmpeq_epi32_mask(h3_vec, n_vec);
if (matches0 | matches1 | matches2 | matches3)
return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
_pdep_u64(matches1, 0x2222222222222222) | //
_pdep_u64(matches2, 0x4444444444444444) | //
_pdep_u64(matches3, 0x8888888888888888));
h += 64, h_length -= 64;
goto sz_find_4byte_avx512_cycle;
}
}

-return NULL;
SZ_PUBLIC sz_cptr_t sz_find_3byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

sz_u64_parts_t n_parts;
n_parts.u64 = 0;
n_parts.u8s[0] = n[0];
n_parts.u8s[1] = n[1];
n_parts.u8s[2] = n[2];

__m512i h0_vec, h1_vec, h2_vec, h3_vec, n_vec = _mm512_set1_epi32(n_parts.u32s[0]);
__mmask64 mask;
__mmask16 matches0, matches1, matches2, matches3;

sz_find_3byte_avx512_cycle:
if (h_length < 3) { return NULL; }
else if (h_length < 67) {
mask = mask_up_to(h_length);
// This implementation is more complex than the `sz_find_4byte_avx512`,
// as we are going to match only 3 bytes within each 4-byte word.
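// Every fourth bit of the 0x7777... mask is clear, so byte 3 of each 4-byte lane is loaded as
// zero; since the top byte of the broadcast needle word is also zero, the 32-bit equality
// compare effectively matches just the three needle bytes.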
h0_vec = _mm512_maskz_loadu_epi8(mask & 0x7777777777777777, h);
h1_vec = _mm512_maskz_loadu_epi8(mask & 0x7777777777777777, h + 1);
h2_vec = _mm512_maskz_loadu_epi8(mask & 0x7777777777777777, h + 2);
h3_vec = _mm512_maskz_loadu_epi8(mask & 0x7777777777777777, h + 3);
matches0 = _mm512_mask_cmpeq_epi32_mask(mask, h0_vec, n_vec);
matches1 = _mm512_mask_cmpeq_epi32_mask(mask, h1_vec, n_vec);
matches2 = _mm512_mask_cmpeq_epi32_mask(mask, h2_vec, n_vec);
matches3 = _mm512_mask_cmpeq_epi32_mask(mask, h3_vec, n_vec);
if (matches0 | matches1 | matches2 | matches3)
return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
_pdep_u64(matches1, 0x2222222222222222) | //
_pdep_u64(matches2, 0x4444444444444444) | //
_pdep_u64(matches3, 0x8888888888888888));
return NULL;
}
else {
h0_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h);
h1_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h + 1);
h2_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h + 2);
h3_vec = _mm512_maskz_loadu_epi8(0x7777777777777777, h + 3);
matches0 = _mm512_cmpeq_epi32_mask(h0_vec, n_vec);
matches1 = _mm512_cmpeq_epi32_mask(h1_vec, n_vec);
matches2 = _mm512_cmpeq_epi32_mask(h2_vec, n_vec);
matches3 = _mm512_cmpeq_epi32_mask(h3_vec, n_vec);
if (matches0 | matches1 | matches2 | matches3)
return h + sz_ctz64(_pdep_u64(matches0, 0x1111111111111111) | //
_pdep_u64(matches1, 0x2222222222222222) | //
_pdep_u64(matches2, 0x4444444444444444) | //
_pdep_u64(matches3, 0x8888888888888888));
h += 64, h_length -= 64;
goto sz_find_3byte_avx512_cycle;
}
}

-SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t haystack_length, sz_cptr_t needle,
-sz_size_t needle_length) {
-return sz_find_serial(haystack, haystack_length, needle, needle_length);
SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
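// Dispatch on the needle length: needles of 1 to 4 bytes use the specialized kernels above,
// while longer needles currently fall back to the serial implementation.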
switch (n_length) {
case 1: return sz_find_byte_avx512(h, h_length, n);
case 2: return sz_find_2byte_avx512(h, h_length, n);
case 3: return sz_find_3byte_avx512(h, h_length, n);
case 4: return sz_find_4byte_avx512(h, h_length, n);
default: return sz_find_serial(h, h_length, n, n_length);
}
}

/**
