Skip to content

Commit

Permalink
Optimize base64 decoder slightly for 3-7x speedup - {DRQS 175668871} …
Browse files Browse the repository at this point in the history
…(#4797)
  • Loading branch information
Cameron Desrochers authored and GitHub Enterprise committed Jul 23, 2024
1 parent e862d6b commit d9f0b33
Showing 1 changed file with 312 additions and 1 deletion.
313 changes: 312 additions & 1 deletion groups/bdl/bdlde/bdlde_base64decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,13 +435,23 @@ BSLS_IDENT("$Id: $")

#include <bslmf_assert.h>

#include <bsls_alignedbuffer.h>
#include <bsls_assert.h>
#include <bsls_deprecatefeature.h>
#include <bsls_performancehint.h>
#include <bsls_review.h>
#include <bsls_types.h>

#include <bsl_cstring.h>
#include <bsl_cstdint.h>
#include <bsl_iostream.h>

#ifdef __SSE4_2__
#include <emmintrin.h>
#include <smmintrin.h>
#include <tmmintrin.h>
#endif

namespace BloombergLP {
namespace bdlde {

Expand Down Expand Up @@ -860,6 +870,304 @@ int Base64Decoder::convert(OUTPUT_ITERATOR out,
return e_ERROR_STATE == d_state ? -1 : d_bitsInStack / 8;
}

template<>
inline
int Base64Decoder::convert<char *, const char *>(
char *out,
int *numOut,
int *numIn,
const char *begin,
const char *end,
int maxNumOut)
{
BSLS_ASSERT(numOut);
BSLS_ASSERT(numIn);

if (BSLS_PERFORMANCEHINT_PREDICT_UNLIKELY(
e_ERROR_STATE == d_state || e_DONE_STATE == d_state)) {
int rv = e_DONE_STATE == d_state ? -2 : -1;
d_state = e_ERROR_STATE;
*numOut = 0;
*numIn = 0;
return rv; // RETURN
}

int numEmitted = 0;

// Emit as many output bytes as possible.

if (BSLS_PERFORMANCEHINT_PREDICT_UNLIKELY(8 <= d_bitsInStack)) {
while (8 <= d_bitsInStack && numEmitted != maxNumOut) {
d_bitsInStack -= 8;
*out = static_cast<char>((d_stack >> d_bitsInStack) & 0xff);
++out;
++numEmitted;
}
}

// Consume as many input bytes as possible.

const char *originalBegin = begin;

if (BSLS_PERFORMANCEHINT_PREDICT_LIKELY(e_INPUT_STATE == d_state)) {
if (BSLS_PERFORMANCEHINT_PREDICT_LIKELY(d_bitsInStack == 0)) {
// Optimize for common case
#ifdef __SSE4_2__
// Load 16-byte slices of LUT. Note that the entire 256-byte LUT
// is *not* loaded, but only the middle slices that are non-ff.
const __m128i *alphabetSlices =
reinterpret_cast<const __m128i *>(d_alphabet_p);
__m128i lut5 = _mm_loadu_si128(alphabetSlices + 7);
__m128i lut4 = _mm_loadu_si128(alphabetSlices + 6);
__m128i lut3 = _mm_loadu_si128(alphabetSlices + 5);
__m128i lut2 = _mm_loadu_si128(alphabetSlices + 4);
__m128i lut1 = _mm_loadu_si128(alphabetSlices + 3);
__m128i lut0 = _mm_loadu_si128(alphabetSlices + 2);

// Heavily inspired by techniques outlined in
// http://0x80.pl/notesen/2016-01-17-sse-base64-decoding.html

// xor LUT fragments together for pshufb-xor chaining below.
lut5 = _mm_xor_si128(lut5, lut4);
lut4 = _mm_xor_si128(lut4, lut3);
lut3 = _mm_xor_si128(lut3, lut2);
lut2 = _mm_xor_si128(lut2, lut1);
lut1 = _mm_xor_si128(lut1, lut0);

while (end - begin >= 16 && static_cast<unsigned>(numEmitted + 12)
<= static_cast<unsigned>(maxNumOut)) {
// Load 16 base64 characters (will eventually be transformed
// into 12 bytes)
__m128i x = _mm_loadu_si128(
reinterpret_cast<const __m128i *>(begin));

// Offset indexes to match first LUT slice at offset 0x20
x = _mm_subs_epi8(x, _mm_set1_epi8(0x20));

// If indexes were < 0x20, 'x' will contain negative values
// which we will check for later (minimum bounds check)
__m128i tooSmall = x;

// Using the characters as indexes, look up the corresponding
// values from the LUT. If an index is non-negative, only its
// low 4 bits are considered. If an index is negative, 0 is
// returned for its lookup value.
__m128i decoded = _mm_shuffle_epi8(lut0, x);

// Advance to the next LUT slice. Note that if the previous
// slice was the correct one for a given index, the index will
// become negative after this, resulting in subsequent lookups
// simply xor-ing 0 (harmless no-ops).
x = _mm_subs_epi8(x, _mm_set1_epi8(0x10));

// Perform the next lookup using the same low 4 bits of each
// non-negative index. The result is then xor-ed with the
// previous lookup result. For negative indices, this is a
// no-op, while for non-negative indices, the xor with the
// previous LUT slice value cancels out the xor-ing done to the
// LUT slices above the loop, leaving the original value from
// this LUT slice.
decoded = _mm_xor_si128(decoded, _mm_shuffle_epi8(lut1, x));

// Continue to advance to each LUT slice
x = _mm_subs_epi8(x, _mm_set1_epi8(0x10));
decoded = _mm_xor_si128(decoded, _mm_shuffle_epi8(lut2, x));
x = _mm_subs_epi8(x, _mm_set1_epi8(0x10));
decoded = _mm_xor_si128(decoded, _mm_shuffle_epi8(lut3, x));
x = _mm_subs_epi8(x, _mm_set1_epi8(0x10));
decoded = _mm_xor_si128(decoded, _mm_shuffle_epi8(lut4, x));
x = _mm_subs_epi8(x, _mm_set1_epi8(0x10));
decoded = _mm_xor_si128(decoded, _mm_shuffle_epi8(lut5, x));
x = _mm_subs_epi8(x, _mm_set1_epi8(0x10));

// At this point, the indexes in 'x' should be negative, as
// we've exhausted all populated LUT slices. If any are not,
// that indicates the maximum bounds check failed.

// Check the minimum and maximum bounds were respected, as well
// as for any 'ff' values loaded from LUT slices themselves.
if (BSLS_PERFORMANCEHINT_PREDICT_UNLIKELY(
!_mm_testz_si128(tooSmall | decoded | ~x,
_mm_set1_epi8(0x80)))) {
// Unknown char; could be error or could be a character to
// ignore; either way fall back to regular decoding
break;
}

// 'decoded' currently contains dwords layed out like
// |00aaaaaa|00bbbbbb|00cccccc|00dddddd|. Convert to
// |0000aaaa aabbbbbb|0000cccc ccdddddd| with a multiply-add.
decoded = _mm_maddubs_epi16(decoded, _mm_set1_epi16(0x0140));

// Convert to final form of
// |00000000 aaaaaabb bbbbcccc ccdddddd| with another multiply-
// add. Note that each triplet of values is aligned to a byte
// boundary following this operation.
decoded = _mm_madd_epi16(decoded, _mm_set1_epi32(0x00011000));

// Take care of endianness and last four one-byte gaps by
// explicitly selecting each byte we want in order.
__m128i selection = _mm_set_epi64(
reinterpret_cast<__m64>(0xffffffff0c0d0e08ull),
reinterpret_cast<__m64>(0x090a040506000102ull));
decoded = _mm_shuffle_epi8(decoded, selection);

// Store the result
memcpy(out, &decoded, 12);

begin += 16;
numEmitted += 12;
out += 12;
}
#endif
while (end - begin >= 4 && static_cast<unsigned>(numEmitted + 3)
<= static_cast<unsigned>(maxNumOut)) {
bsls::AlignedBuffer<4, 4> inBuffer;
uint8_t *in = reinterpret_cast<uint8_t *>(inBuffer.buffer());
memcpy(in, begin, 4);

uint8_t x[4];
x[0] = static_cast<uint8_t>(d_alphabet_p[in[0]]);
x[1] = static_cast<uint8_t>(d_alphabet_p[in[1]]);
x[2] = static_cast<uint8_t>(d_alphabet_p[in[2]]);
x[3] = static_cast<uint8_t>(d_alphabet_p[in[3]]);

uint32_t x4;
memcpy(&x4, x, sizeof(x4));
if (BSLS_PERFORMANCEHINT_PREDICT_UNLIKELY(x4 & 0x80808080u)) {
// Unknown char; could be error or could be a character to
// ignore; either way fall back to char-by-char decoding
break;
}

out[0] = static_cast<char>((x[0] << 2) | (x[1] >> 4));
out[1] = static_cast<char>((x[1] << 4) | (x[2] >> 2));
out[2] = static_cast<char>((x[2] << 6) | (x[3] >> 0));

begin += 4;
numEmitted += 3;
out += 3;
}
}

while (18 >= d_bitsInStack && begin != end) {
const unsigned char byte = static_cast<unsigned char>(*begin);

++begin;

unsigned char converted = static_cast<unsigned char>(
d_alphabet_p[byte]);

if (converted < 64) {
d_stack = (d_stack << 6) | converted;
d_bitsInStack += 6;
if (8 <= d_bitsInStack && numEmitted != maxNumOut) {
d_bitsInStack -= 8;
*out = static_cast<char>(
(d_stack >> d_bitsInStack) & 0xff);
++out;
++numEmitted;
}
}
else if (!d_ignorable_p[byte]) {
if ('=' == byte && d_isPadded) {
const int residual = residualBits(
d_outputLength + numEmitted);
// 'residual' is 0, 6, 12, or 18.
//: o If it's 0, that's an error since no '=' should be
//: needed.
//:
//: o If it's 6, that's an error because an incomplete
//: byte has been input.
//:
//: o 12 means 2 bytes have been read, meaning we have to
//: do 1 byte of output (which we may have already done).
//: The low-order 4 bits of stack should either be
//: 0 or the stack should be empty.
//:
//: o 18 means 3 bytes have been read, meaning we have to
//: do 2 bytes of output (some or all of which we may
//: have already done). The low-order 2 bits of stack
//: should either be 0 or the stack should be empty.

const int leftOver = residual % 8;
d_state = 0 != (d_stack & ((1 << leftOver) - 1))
? e_ERROR_STATE
: 12 == residual
? e_NEED_EQUAL_STATE
: 18 == residual
? e_SOFT_DONE_STATE
: e_ERROR_STATE;
d_stack >>= leftOver;
d_bitsInStack -= leftOver;
}
else {
d_state = e_ERROR_STATE;
}
break;
}
}
}

if (e_NEED_EQUAL_STATE == d_state) {
BSLS_ASSERT(d_isPadded);

while (begin != end) {
const unsigned char byte = static_cast<unsigned char>(*begin);

++begin;

if (!d_ignorable_p[byte]) {
if ('=' == byte) {
d_state = e_SOFT_DONE_STATE;
}
else {
d_state = e_ERROR_STATE;
}
break;
}
}
}
if (BSLS_PERFORMANCEHINT_PREDICT_UNLIKELY(e_SOFT_DONE_STATE == d_state
&& begin != end)) {
do {
const unsigned char byte = static_cast<unsigned char>(*begin);

++begin;

if (BSLS_PERFORMANCEHINT_PREDICT_UNLIKELY(!d_ignorable_p[byte])) {
d_state = e_ERROR_STATE;
break;
}
} while (begin != end);
}

*numIn = begin - originalBegin;
*numOut = numEmitted;
d_outputLength += numEmitted;

return e_ERROR_STATE == d_state ? -1 : d_bitsInStack / 8;
}

template<>
inline
int Base64Decoder::convert<unsigned char *, const unsigned char *>(
unsigned char *out,
int *numOut,
int *numIn,
const unsigned char *begin,
const unsigned char *end,
int maxNumOut)
{
return convert(reinterpret_cast<char *>(out),
numOut,
numIn,
reinterpret_cast<const char *>(begin),
reinterpret_cast<const char *>(end),
maxNumOut);
}


template <class OUTPUT_ITERATOR>
int Base64Decoder::endConvert(OUTPUT_ITERATOR out)
{
Expand All @@ -878,8 +1186,11 @@ int Base64Decoder::endConvert(OUTPUT_ITERATOR out,
if (!d_isPadded && e_INPUT_STATE == d_state) {
const int residual = residualBits(d_outputLength);
const int leftOver = residual % 8;
if (6 == residual || 0 != (d_stack & ((1 << leftOver) - 1))) {
if (BSLS_PERFORMANCEHINT_PREDICT_UNLIKELY(6 == residual ||
0 != (d_stack & ((1 << leftOver) - 1)))) {
d_state = e_ERROR_STATE;
*numOut = 0;
return -1; // RETURN
}
else {
d_stack >>= leftOver;
Expand Down

0 comments on commit d9f0b33

Please sign in to comment.