Skip to content

Commit 5af4b92

Browse files
Extend #1172 approach to avx512
1 parent b1fa370 commit 5af4b92

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

include/xsimd/arch/xsimd_avx512bw.hpp

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,18 +316,66 @@ namespace xsimd
316316
}
317317

318318
// load
319-
template <class A, class T, class = typename std::enable_if<batch_bool<T, A>::size == 64, void>::type>
319+
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
320320
XSIMD_INLINE batch_bool<T, A> load_unaligned(bool const* mem, batch_bool<T, A>, requires_arch<avx512bw>) noexcept
321321
{
322-
__m512i bool_val = _mm512_loadu_si512((__m512i const*)mem);
323-
return _mm512_cmpgt_epu8_mask(bool_val, _mm512_setzero_si512());
322+
using mask_type = typename batch_bool<T, A>::register_type;
323+
XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 64)
324+
{
325+
__m512i bool_val = _mm512_loadu_si512((__m512i const*)mem);
326+
return (mask_type)_mm512_cmpgt_epu8_mask(bool_val, _mm512_setzero_si512());
327+
}
328+
else XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 32)
329+
{
330+
__m256i bpack = _mm256_loadu_si256((__m256i const*)mem);
331+
return (mask_type)_mm512_cmpgt_epu16_mask(_mm512_cvtepu8_epi16(bpack), _mm512_setzero_si512());
332+
}
333+
else XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 16)
334+
{
335+
__m128i bpack = _mm_loadu_si128((__m128i const*)mem);
336+
return (mask_type)_mm512_cmpgt_epu32_mask(_mm512_cvtepu8_epi32(bpack), _mm512_setzero_si512());
337+
}
338+
else XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 8)
339+
{
340+
__m128i bpack = _mm_loadl_epi64((__m128i const*)mem);
341+
return (mask_type)_mm512_cmpgt_epu64_mask(_mm512_cvtepu8_epi64(bpack), _mm512_setzero_si512());
342+
}
343+
else
344+
{
345+
assert(false && "unexpected batch size");
346+
return {};
347+
}
324348
}
325349

326-
template <class A, class T, class = typename std::enable_if<batch_bool<T, A>::size == 64, void>::type>
350+
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
327351
XSIMD_INLINE batch_bool<T, A> load_aligned(bool const* mem, batch_bool<T, A>, requires_arch<avx512bw>) noexcept
328352
{
329-
__m512i bool_val = _mm512_load_si512((__m512i const*)mem);
330-
return _mm512_cmpgt_epu8_mask(bool_val, _mm512_setzero_si512());
353+
using mask_type = typename batch_bool<T, A>::register_type;
354+
XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 64)
355+
{
356+
__m512i bool_val = _mm512_load_si512((__m512i const*)mem);
357+
return (mask_type)_mm512_cmpgt_epu8_mask(bool_val, _mm512_setzero_si512());
358+
}
359+
else XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 32)
360+
{
361+
__m256i bpack = _mm256_load_si256((__m256i const*)mem);
362+
return (mask_type)_mm512_cmpgt_epu16_mask(_mm512_cvtepu8_epi16(bpack), _mm512_setzero_si512());
363+
}
364+
else XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 16)
365+
{
366+
__m128i bpack = _mm_load_si128((__m128i const*)mem);
367+
return (mask_type)_mm512_cmpgt_epu32_mask(_mm512_cvtepu8_epi32(bpack), _mm512_setzero_si512());
368+
}
369+
else XSIMD_IF_CONSTEXPR(batch_bool<T, A>::size == 8)
370+
{
371+
__m128i bpack = _mm_loadl_epi64((__m128i const*)mem);
372+
return (mask_type)_mm512_cmpgt_epu64_mask(_mm512_cvtepu8_epi64(bpack), _mm512_setzero_si512());
373+
}
374+
else
375+
{
376+
assert(false && "unexpected batch size");
377+
return {};
378+
}
331379
}
332380

333381
// max

0 commit comments

Comments
 (0)