diff --git a/CMakeLists.txt b/CMakeLists.txt
index af27542..1fdf26b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -88,7 +88,7 @@ if(INTGEMM_DONT_BUILD_TESTS)
   return()
 endif()
 
-foreach(exe benchmark biasmultiply benchmark_quantizer)
+foreach(exe benchmark biasmultiply benchmark_quantizer non_mult_8)
   add_executable(${exe} benchmarks/${exe}.cc)
   target_link_libraries(${exe} intgemm)
 endforeach()
diff --git a/benchmarks/non_mult_8.cc b/benchmarks/non_mult_8.cc
new file mode 100644
index 0000000..6531f4d
--- /dev/null
+++ b/benchmarks/non_mult_8.cc
@@ -0,0 +1,149 @@
+#include "../intgemm/aligned.h"
+#include "intgemm/intgemm_config.h"
+#include "../intgemm/avx512_gemm.h"
+#include "../intgemm/sse2_gemm.h"
+#include "../intgemm/avx2_gemm.h"
+#include "../intgemm/ssse3_gemm.h"
+#include "../intgemm/intgemm.h"
+#include "../intgemm/stats.h"
+#include "../intgemm/callbacks.h"
+#include <iostream>
+#include <random>
+
+/*********************************************** util ***********************************************/
+template <class T>
+int numDigits(T number) {
+  int digits = 0;
+  if (number <= 0) {
+    digits = 1; // count the minus and take care of the zero case
+  }
+  while (number) {
+    number /= 10;
+    digits++;
+  }
+  return digits;
+}
+
+template <class intType>
+void printMat(intType * a, size_t rows, size_t cols, std::string name, int digits = 0) {
+  std::cerr << name << std::endl;
+  for (size_t i = 0; i < rows; i++) {
+    for (size_t j = 0; j < cols; j++) {
+      int numbah = (int)a[i*cols + j];
+      // Pad for nice printing
+      int mydigits = digits - numDigits(numbah);
+      for (int t = 0; t < mydigits; t++) {
+        std::cerr << ' ';
+      }
+      std::cerr << numbah << " ";
+    }
+    std::cerr << std::endl;
+  }
+  std::cerr << std::endl;
+}
+
+template <class intType>
+void toColMajor(intType *in, intType * out, size_t rows, size_t cols) {
+  for (size_t i = 0; i < rows; i++) {
+    for (size_t j = 0; j < cols; j++) {
+      out[j*rows + i] = in[i*cols + j];
+    }
+  }
+}
+
+namespace intgemm {
+
+template <class Routine>
+void prepBtst(Index width, Index B_cols, float * in = nullptr) {
+  AlignedVector<float> B(width * B_cols);
+
+  //std::mt19937 gen;
+  //std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+
+  if (in != 0) {
+    for (Index i = 0; i < width * B_cols; i++) {
+      B[i] = in[i];
+    }
+  } else {
+    for (Index i = 0; i < width * B_cols; i++) {
+      B[i] = static_cast<float>(i % 64);
+    }
+  }
+
+  float quant_mult = 1.0f;
+  AlignedVector<int8_t> B_prep(B.size());
+  //AlignedVector<int8_t> B_prep_print(B.size());
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+  printMat(B_prep.begin(), B_cols, width, "Prep Mat", 3);
+
+  //toColMajor(B_prep.begin(), B_prep_print.begin(), B_cols, width);
+  //printMat(B_prep_print.begin(), B_cols, width, "Prep Mat trans", 3);
+}
+
+void padMatrixTst(Index width, Index B_cols) {
+  AlignedVector<float> B(width * B_cols);
+  std::div_t results = std::div((int)B_cols, 8);
+
+  for (Index i = 0; i < width * B_cols; i++) {
+    B[i] = static_cast<float>(i % 64);
+  }
+
+  AlignedVector<float> padded = padMatrix(B.begin(), width, B_cols);
+  printMat(padded.begin(), width, 8, "Padded Mat", 3);
+  prepBtst<AVX512VNNI::Kernels8>(width, 8, padded.begin());
+}
+
+template <class Routine>
+void smallMultTst(Index A_rows, Index width, Index B_cols) {
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  AlignedVector<float> C(A_rows * B_cols);
+
+  for (Index i = 0; i < A.size(); i++) {
+    A[i] = static_cast<float>((i % 10) + 1);
+  }
+  for (Index i = 0; i < B.size(); i++) {
+    B[i] = static_cast<float>(i % 10);
+  }
+
+  float quant_mult = 1.0f;
+  float unquant_mult = 1.0f / (quant_mult * quant_mult);
+  AlignedVector<int8_t> A_prep(A.size());
+  AlignedVector<int8_t> B_prep(B.size());
+
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); // A is strictly positive here
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+  printMat(B_prep.begin(), B_cols, width, "Prep Mat B", 3);
+
+  Routine::Multiply8Shift((uint8_t*)A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, C.begin()));
+  printMat(C.begin(), A_rows, B_cols, "Prep Mat C", 5);
+}
+
+} // namespace intgemm
+
+int main() {
+  using namespace intgemm;
+  //prepBtst<AVX512VNNI::Kernels8>(32, 35);
+  //prepBtst<AVX512VNNI::Kernels8>(64, 9);
+  //padMatrixTst(32, 35);
+  smallMultTst<AVX512VNNI::Kernels8>(2, 64, 9);
+}
diff --git a/intgemm/aligned.h b/intgemm/aligned.h
index 6fda369..17ff014 100644
--- a/intgemm/aligned.h
+++ b/intgemm/aligned.h
@@ -39,7 +39,9 @@ template <class T> class AlignedVector {
   }
 
   AlignedVector(const AlignedVector&) = delete;
+  AlignedVector(AlignedVector&) = delete;
   AlignedVector& operator=(const AlignedVector&) = delete;
+  AlignedVector& operator=(AlignedVector&) = delete;
 
   ~AlignedVector() {
 #ifdef _MSC_VER
diff --git a/intgemm/avx512_gemm.h b/intgemm/avx512_gemm.h
index 90f67ee..d7594c8 100644
--- a/intgemm/avx512_gemm.h
+++ b/intgemm/avx512_gemm.h
@@ -254,12 +254,12 @@ struct Kernels8 {
 
   /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
   INTGEMM_AVX512BW static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
-    assert(size % 16 == 0);
+    std::div_t result = std::div((int)size, 16);
     assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
     const __m512i pos127 = _mm512_set1_epi32(127);
     const __m512i zero = _mm512_setzero_si512();
     const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
-    const float *end = input + size;
+    const float *end = input + result.quot*16; // Do the majority using AVX512
     for (; input < end; input += 16, output += 16) {
       __m512i asint = QuantizerGrab(input, quant_mult_reg);
       asint = _mm512_min_epi32(asint, pos127);
@@ -267,6 +267,9 @@ struct Kernels8 {
       asint = _mm512_max_epi32(asint, zero);
       _mm512_mask_cvtusepi32_storeu_epi8(output, 0xffff, asint);
     }
+    for (int i = 0; i < result.rem; i++) { // Fill in the gaps linearly
+      output[i] = static_cast<uint8_t>(std::min(roundf(std::max(input[i]*quant_mult, 0.0f)), 255.0f));
+    }
   }
 
   // Tile size for B; B must be a multiple of this block size.
diff --git a/intgemm/avx512vnni_gemm.h b/intgemm/avx512vnni_gemm.h
index 28e8c14..c31bbdd 100644
--- a/intgemm/avx512vnni_gemm.h
+++ b/intgemm/avx512vnni_gemm.h
@@ -83,15 +83,21 @@ struct Kernels8 : public AVX512BW::Kernels8 {
   template <typename Callback>
   INTGEMM_AVX512VNNI static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
     assert(width % sizeof(Register) == 0);
-    assert(B_cols % 8 == 0);
+    std::div_t results = std::div((int)B_cols, 8);
+    Index B_cols_trimmed = B_cols;
+    if (results.rem != 0) {
+      B_cols_trimmed = results.quot*8;
+    }
+    assert(B_cols_trimmed % 8 == 0);
     assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0);
     assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
     auto callback_impl = callbacks::CallbackImpl<CPUType::AVX512VNNI, Callback>(callback);
     const Index simd_width = width / sizeof(Register);
     Register zeros = setzero_si<Register>();
     // Go over 8 columns of B at a time.
+    Index B0_colidx = 0; // OMP can't deal with this variable being assigned outside of the loop, hence we declare it once and assign to 0 twice
 #pragma omp for
-    for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+    for (B0_colidx = 0; B0_colidx < B_cols_trimmed; B0_colidx += 8) {
       const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
       // Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.
       for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
@@ -119,20 +125,51 @@ struct Kernels8 : public AVX512BW::Kernels8 {
         callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
       }
     }
+    // Final bit, if we have a non-mult-of-eight matrix
+    if (results.rem != 0) {
+      const Register *B0_col = reinterpret_cast<const Register*>(B) + (B_cols_trimmed * width)/(sizeof(Register));
+      // Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.
+      for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
+        // Iterate over shared (inner) dimension.
+        const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width);
+        const Register *A_end = A_live + simd_width;
+        const Register *B_live = B0_col;
+        // TODO: separate first step.
+        Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
+        Register * sums[8] = {&sum0, &sum1, &sum2, &sum3, &sum4, &sum5, &sum6, &sum7};
+        for (; A_live != A_end; ++A_live, B_live += results.rem) {
+          Register a = *A_live;
+          // MultiplyAdd
+          for (int i = 0; i < results.rem; i++) {
+            VNNI8(*sums[i], a, *(B_live + i));
+          }
+        }
+        Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+        Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+        auto total = PermuteSummer(pack0123, pack4567);
+        callback_impl.RunPartial(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols), (Index)results.rem);
+      }
+    }
   }
 
   template <typename Callback>
   INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) {
     assert(width % sizeof(Register) == 0);
-    assert(B_cols % 8 == 0);
+    std::div_t results = std::div((int)B_cols, 8);
+    Index B_cols_trimmed = B_cols;
+    if (results.rem != 0) {
+      B_cols_trimmed = results.quot*8;
+    }
+    assert(B_cols_trimmed % 8 == 0);
     assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
     auto callback_impl = callbacks::CallbackImpl<CPUType::AVX512VNNI, Callback>(callback);
     Index simd_width = width / sizeof(Register);
     Register zeros = setzero_si<Register>();
     const Register a = set1_epi8<Register>(1);
     // Go over 8 columns of B at a time.
+    Index B0_colidx = 0; // OMP can't deal with this variable being assigned outside of the loop, hence we declare it once and assign to 0 twice
 #pragma omp for
-    for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+    for (B0_colidx = 0; B0_colidx < B_cols_trimmed; B0_colidx += 8) {
       const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
       const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function
       const Register *B_end = B_live + simd_width*8;
@@ -155,6 +192,25 @@ struct Kernels8 : public AVX512BW::Kernels8 {
       auto total = PermuteSummer(pack0123, pack4567);
       callback_impl.Run(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols));
     }
+    // Final bit, if we have a non-mult-of-eight matrix
+    if (results.rem != 0) {
+      const Register *B0_col = reinterpret_cast<const Register*>(B) + (B_cols_trimmed * width)/(sizeof(Register));
+      const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function
+      const Register *B_end = B_live + simd_width*results.rem;
+
+      // TODO: separate first step.
+      Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
+      Register * sums[8] = {&sum0, &sum1, &sum2, &sum3, &sum4, &sum5, &sum6, &sum7};
+      for (; B_live != B_end; B_live += results.rem) {
+        for (int i = 0; i < results.rem; i++) {
+          VNNI8(*sums[i], a, *(B_live + i));
+        }
+      }
+      Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+      Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+      auto total = PermuteSummer(pack0123, pack4567);
+      callback_impl.RunPartial(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols), (Index)results.rem);
+    }
   }
 
   constexpr static const char *const kName = "8-bit AVX512VNNI";
diff --git a/intgemm/callbacks/implementations.inl b/intgemm/callbacks/implementations.inl
index 126701d..45e263d 100644
--- a/intgemm/callbacks/implementations.inl
+++ b/intgemm/callbacks/implementations.inl
@@ -147,6 +147,18 @@ public:
     kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
   }
 
+  INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
+    // Workaround gcc 5 internal compiler error that can't read register members in debug.
+    vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+    mult_reg = unquant_mult;
+#endif
+    auto result = kernels::unquantize(input, mult_reg);
+    kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
+  }
+
 private:
   vf unquant_mult;
   UnquantizeAndWrite config;
@@ -172,6 +184,17 @@ public:
     auto result = kernels::relu<float>(kernels::unquantize(input, mult_reg));
     kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
   }
+  INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
+    // Workaround gcc 5 internal compiler error that can't read register members in debug.
+    vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+    mult_reg = unquant_mult;
+#endif
+    auto result = kernels::relu<float>(kernels::unquantize(input, mult_reg));
+    kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
+  }
 
 private:
   vf unquant_mult;
@@ -191,6 +214,11 @@ public:
     kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
   }
 
+  INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
+    auto result = kernels::add_bias_partial(input, config.bias_addr, info.col_idx, partial);
+    kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
+  }
+
 private:
   AddBiasAndWrite config;
 };
@@ -216,6 +244,18 @@ public:
     result = kernels::add_bias(result, config.bias_addr, info.col_idx);
     kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
   }
+  INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
+    // Workaround gcc 5 internal compiler error that can't read register members in debug.
+    vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+    mult_reg = unquant_mult;
+#endif
+    auto result = kernels::unquantize(input, mult_reg);
+    result = kernels::add_bias_partial(result, config.bias_addr, info.col_idx, partial);
+    kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
+  }
 private:
   vf unquant_mult;
   UnquantizeAndAddBiasAndWrite config;
@@ -243,6 +283,19 @@ public:
     result = kernels::relu<float>(result);
     kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
   }
+  INTGEMM_TARGET void RunPartial(vi input, const OutputBufferInfo& info, Index partial) {
+    // Workaround gcc 5 internal compiler error that can't read register members in debug.
+    vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+    mult_reg = unquant_mult;
+#endif
+    auto result = kernels::unquantize(input, mult_reg);
+    result = kernels::add_bias_partial(result, config.bias_addr, info.col_idx, partial);
+    result = kernels::relu<float>(result);
+    kernels::write_partial(result, config.output_addr, info.row_idx * info.cols + info.col_idx, partial);
+  }
 private:
   vf unquant_mult;
   UnquantizeAndAddBiasAndWriteRelu config;
diff --git a/intgemm/interleave.h b/intgemm/interleave.h
index 95f05ce..ca1a25e 100644
--- a/intgemm/interleave.h
+++ b/intgemm/interleave.h
@@ -1,11 +1,15 @@
 #pragma once
 
 #include "intgemm/intgemm_config.h"
+#include "aligned.h"
 #include "intrinsics.h"
 #include "types.h"
 
 #include <algorithm>
 #include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <iostream>
 
 namespace intgemm {
@@ -150,6 +154,56 @@ template <class Register> static inline void Transpose8InLane(
   r11 = tmp;
 }
 
+/**
+ * @brief Pads the last columns of a row major matrix so that its width is a multiple of 8.
+ * After preparing, we can discard the final bits that are written by writing
+ * into the output matrix at the appropriate index. Memory management is done by the
+ * AlignedVector class.
+ * @param in matrix ptr
+ * @param rows
+ * @param cols
+ * @return AlignedVector<float> containing a rows x 8 padded matrix
+ */
+static inline AlignedVector<float> padMatrix(const float * in, Index rows, Index cols) {
+  std::div_t results = std::div((int)cols, 8);
+
+  // Create a padded "right" small matrix that will contain the extra part
+  // that is non-multiple of 8. It is basically a rows*8 matrix.
+  AlignedVector<float> padded_matrix(8*rows);
+  for (unsigned int i = 0; i < 8*rows; i++) {
+    padded_matrix[i] = std::nanf("1");
+  }
+
+  // Copy the remainder of the big matrix onto the new small matrix
+  for (Index i = 0; i < rows; i++) {
+    for (int j = 0; j < results.rem; j++) {
+      padded_matrix[i*8 + j] = in[i*cols + results.quot*8 + j];
+    }
+  }
+  return padded_matrix;
+}
+
+/**
+ * @brief Shrinks a row major matrix to a width that is a multiple of 8 by dropping the last cols % 8 columns.
+ * @param in matrix ptr
+ * @param rows
+ * @param cols
+ * @return AlignedVector<float> Returns a float * containing a shrunk version of the original matrix.
+ */
+static inline AlignedVector<float> shrinkMat(const float * in, Index rows, Index cols) {
+  std::div_t results = std::div((int)cols, 8);
+  Index stride = results.quot*8;
+  AlignedVector<float> shrunk_matrix(rows*stride);
+  Index consecutive = 0;
+  for (Index i = 0; i < rows; i++) {
+    for (Index j = 0; j < stride; j++) {
+      shrunk_matrix[consecutive] = in[i*cols + j];
+      consecutive++;
+    }
+  }
+  return shrunk_matrix;
+}
+
@@ -190,6 +244,16 @@ template <class Register> static inline void Transpose8InLane(
 #define INTGEMM_PREPARE_B_8(target, QuantClass) \
 target static inline void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, Index rows, Index cols) { \
   FRegister q = set1_ps(quant_mult); \
+  /* Check if we have padding to do */ \
+  std::div_t results = std::div((int)cols, 8); \
+  AlignedVector<float> padded_matrix; \
+  AlignedVector<float> shrunk_matrix; \
+  if (results.rem != 0) { \
+    padded_matrix = padMatrix(input, rows, cols); \
+    shrunk_matrix = shrinkMat(input, rows, cols); \
+    input = shrunk_matrix.begin(); \
+    cols = results.quot*8; \
+  } \
   /* Currently all multipliers have a stride of 8 columns.*/ \
   const Index kColStride = 8; \
   assert(cols % kColStride == 0); \
@@ -215,6 +279,22 @@ target static inline void PrepareB(const float *input, int8_t *output_shadow, fl
       Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \
     } \
   } \
+  if (results.rem != 0) { \
+    /* Prepare the remainder matrix */ \
+    AlignedVector<int8_t> padded_matrix_int8(8*rows); \
+    PrepareB(padded_matrix.begin(), padded_matrix_int8.begin(), quant_mult, rows, 8); \
+    /* Copy the non-NaN values to the back of the current matrix. \
+     * That means we write every i%8 < rem width elements and skip the rest. \
+       For example, with 3 extra columns we write 3 widths, skip 5 widths, write 3 widths... */ \
+    Index consecutive = rows*cols; \
+    for (unsigned int i = 0; i < 8*rows; i++) { \
+      int consecutive_width = std::div((int)i, (int)sizeof(Register)).quot; \
+      if (consecutive_width%8 < results.rem) { \
+        output_shadow[consecutive] = padded_matrix_int8[i]; \
+        consecutive++; \
+      } \
+    } \
+  } \
 } \
 
 #define INTGEMM_PREPARE_B_16(target, QuantClass) \
diff --git a/intgemm/kernels/implementations.inl b/intgemm/kernels/implementations.inl
index 4f1b39f..5d641e9 100644
--- a/intgemm/kernels/implementations.inl
+++ b/intgemm/kernels/implementations.inl
@@ -24,26 +24,64 @@ namespace intgemm {
 namespace kernels {
 
 /*
- * Write
+ * Write. Potentially unaligned memory, hence use storeu
 */
 CPU_ATTR static inline void write(vi input, int8_t* output, Index offset) {
-  *reinterpret_cast<vi*>(output + offset) = input;
+  storeu_ps(reinterpret_cast<float*>(output + offset), *reinterpret_cast<vf*>(&input));
+  //*reinterpret_cast<vi*>(output + offset) = input;
 }
 
 CPU_ATTR static inline void write(vi input, int16_t* output, Index offset) {
-  *reinterpret_cast<vi*>(output + offset) = input;
+  storeu_ps(reinterpret_cast<float*>(output + offset), *reinterpret_cast<vf*>(&input));
+  //*reinterpret_cast<vi*>(output + offset) = input;
 }
 
 CPU_ATTR static inline void write(vi input, int* output, Index offset) {
-  *reinterpret_cast<vi*>(output + offset) = input;
+  storeu_ps(reinterpret_cast<float*>(output + offset), *reinterpret_cast<vf*>(&input));
+  //*reinterpret_cast<vi*>(output + offset) = input;
 }
 
 CPU_ATTR static inline void write(vf input, float* output, Index offset) {
-  *reinterpret_cast<vf*>(output + offset) = input;
+  storeu_ps(reinterpret_cast<float*>(output + offset), input);
+  //*reinterpret_cast<vf*>(output + offset) = input;
 }
 
 CPU_ATTR static inline void write(vd input, double* output, Index offset) {
-  *reinterpret_cast<vd*>(output + offset) = input;
+  storeu_ps(reinterpret_cast<float*>(output + offset), *reinterpret_cast<vf*>(&input));
+  //*reinterpret_cast<vd*>(output + offset) = input;
+}
+
+/*
+ * Partial (non-vector) write
+ */
+CPU_ATTR static inline void write_partial(vi input, int8_t* output, Index offset, Index partial) {
+  for (Index i = 0; i < partial; i++) {
+    *(output + offset + i) = reinterpret_cast<int8_t*>(&input)[i];
+  }
+}
+
+CPU_ATTR static inline void write_partial(vi input, int16_t* output, Index offset, Index partial) {
+  for (Index i = 0; i < partial; i++) {
+    *(output + offset + i) = reinterpret_cast<int16_t*>(&input)[i];
+  }
+}
+
+CPU_ATTR static inline void write_partial(vi input, int* output, Index offset, Index partial) {
+  for (Index i = 0; i < partial; i++) {
+    *(output + offset + i) = reinterpret_cast<int*>(&input)[i];
+  }
+}
+
+CPU_ATTR static inline void write_partial(vf input, float* output, Index offset, Index partial) {
+  for (Index i = 0; i < partial; i++) {
+    *(output + offset + i) = reinterpret_cast<float*>(&input)[i];
+  }
+}
+
+CPU_ATTR static inline void write_partial(vd input, double* output, Index offset, Index partial) {
+  for (Index i = 0; i < partial; i++) {
+    *(output + offset + i) = reinterpret_cast<double*>(&input)[i];
+  }
 }
 
 /*
@@ -88,6 +126,50 @@ CPU_ATTR static inline vd add_bias(vd input, const double* bias_addr, Index bias
   return add_pd(input, bias_term);
 }
 
+/*
+ * Partial (non-vector) bias add
+ */
+CPU_ATTR static inline vi add_bias_partial(vi input, const int8_t* bias_addr, Index bias_offset, Index partial) {
+  vi bias_term = set1_epi8<vi>(0);
+  for (Index i = 0; i < partial; i++) {
+    reinterpret_cast<int8_t*>(&bias_term)[i] = *(bias_addr + bias_offset + i);
+  }
+  return add_epi8(input, bias_term);
+}
+
+CPU_ATTR static inline vi add_bias_partial(vi input, const int16_t* bias_addr, Index bias_offset, Index partial) {
+  vi bias_term = set1_epi16<vi>(0);
+  for (Index i = 0; i < partial; i++) {
+    reinterpret_cast<int16_t*>(&bias_term)[i] = *(bias_addr + bias_offset + i);
+  }
+  return add_epi16(input, bias_term);
+}
+
+CPU_ATTR static inline vi add_bias_partial(vi input, const int* bias_addr, Index bias_offset, Index partial) {
+  vi bias_term = set1_epi32<vi>(0);
+  for (Index i = 0; i < partial; i++) {
+    reinterpret_cast<int*>(&bias_term)[i] = *(bias_addr + bias_offset + i);
+  }
+  return add_epi32(input, bias_term);
+}
+
+CPU_ATTR static inline vf add_bias_partial(vf input, const float* bias_addr, Index bias_offset, Index partial) {
+  vf bias_term = set1_ps<vf>(0);
+  for (Index i = 0; i < partial; i++) {
+    reinterpret_cast<float*>(&bias_term)[i] = *(bias_addr + bias_offset + i);
+  }
+  return add_ps(input, bias_term);
+}
+
+CPU_ATTR static inline vd add_bias_partial(vd input, const double* bias_addr, Index bias_offset, Index partial) {
+  vd bias_term = set1_pd<vd>(0);
+  for (Index i = 0; i < partial; i++) {
+    reinterpret_cast<double*>(&bias_term)[i] = *(bias_addr + bias_offset + i);
+  }
+  return add_pd(input, bias_term);
+}
+
 /*
  * ReLU
  */
diff --git a/intgemm/stats.inl b/intgemm/stats.inl
index 68a5b8e..90199be 100644
--- a/intgemm/stats.inl
+++ b/intgemm/stats.inl
@@ -11,6 +11,7 @@
 #else
 #error Included with unexpected architecture
 #endif
+#include <cstdlib>
 
 namespace intgemm {
 namespace INTGEMM_ARCH {
@@ -63,12 +64,17 @@ INTGEMM_TARGET static inline float MaxAbsolute(const float *begin_float, const f
 /* Computes the euclidean norm and returns the mean and the standard deviation. Optionally it can be the mean and standard deviation in absolute terms. */
 INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, const float *end_float, bool absolute) {
   assert(end_float > begin_float);
-  assert((end_float - begin_float) % (sizeof(FRegister) / sizeof(float)) == 0);
+  // Make sure we deal with any number of elements
   size_t num_items = end_float - begin_float;
+  constexpr size_t width = sizeof(FRegister) / sizeof(float);
+  std::ldiv_t result = std::ldiv((long)num_items, (long)width);
+
   const FRegister *begin = reinterpret_cast<const FRegister*>(begin_float);
-  const FRegister *end = reinterpret_cast<const FRegister*>(end_float);
+  const FRegister *end = reinterpret_cast<const FRegister*>(begin_float + result.quot*width);
   FRegister squares = set1_ps(0);
   FRegister sums = set1_ps(0);
+  float squares_sum = 0;
+  float normal_sums = 0;
   if (absolute) {
     const FRegister abs_mask = cast_ps(set1_epi32(kFloatAbsoluteMask));
     for (; begin != end; begin++) {
@@ -76,15 +82,25 @@ INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, con
       squares = add_ps(squares, mul_ps(vec, vec));
       sums = add_ps(sums, vec);
     }
+    for (long i = 0; i < result.rem; i++) {
+      size_t index = result.quot*width + i;
+      squares_sum += begin_float[index]*begin_float[index];
+      normal_sums += std::fabs(begin_float[index]);
+    }
   } else {
     for (; begin != end; begin++) {
       FRegister vec = *begin;
       squares = add_ps(squares, mul_ps(vec, vec));
       sums = add_ps(sums, vec);
     }
+    for (long i = 0; i < result.rem; i++) {
+      size_t index = result.quot*width + i;
+      squares_sum += begin_float[index]*begin_float[index];
+      normal_sums += begin_float[index];
+    }
   }
-  float squares_sum = AddFloat32(squares);
-  float normal_sums = AddFloat32(sums);
+  squares_sum += AddFloat32(squares);
+  normal_sums += AddFloat32(sums);
   MeanStd ret;
   ret.mean = normal_sums/num_items;
   ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean));
diff --git a/test/add127_test.cc b/test/add127_test.cc
index c31732c..42de41d 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -478,9 +478,17 @@ TEST_CASE ("Multiply AVX512F 8bit Shift vs Int", "[Add127]") {
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
 TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs Int", "[Add127]") {
   if (kCPU < CPUType::AVX512VNNI) return;
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 3, 0.0001f, 0.052f, 0.036f, 0.0001f);
   TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f);
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 9, 0.0001f, 0.05f, 0.03f, 0.0001f);
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 13, 0.0001f, 0.05f, 0.03f, 0.0001f);
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 83, 0.0001f, 0.077f, 0.032f, 0.0001f);
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 256, 270, 0.0001f, 0.61f, 0.17f, 0.0001f);
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 256, 233, 0.0001f, 0.23f, 0.06f, 0.0001f);
   TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f);
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 256, 270, 0.0001f, 0.23f, 0.06f, 0.0001f);
   TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
+  TestMultiplyShiftInt<AVX512VNNI::Kernels8>(320, 256, 211, 0.0001f, 0.27f, 0.06f, 0.0001f);
   TestMultiplyShiftInt<AVX512VNNI::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
   TestMultiplyShiftInt<AVX512VNNI::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f);
   TestMultiplyShiftInt<AVX512VNNI::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index 622ff71..0f8b361 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -147,10 +147,14 @@ TEST_CASE("QuantizeStd SSSE3", "[VectorMeanStd]") {
   if (kCPU < CPUType::SSSE3) return;
   testVectorMeanStd(64);
   testVectorMeanStd(64, true);
+  testVectorMeanStd(133);
+  testVectorMeanStd(133, true);
   testVectorMeanStd(256);
   testVectorMeanStd(256, true);
   testVectorMeanStd(2048);
   testVectorMeanStd(2048, true);
+  testVectorMeanStd(2931);
+  testVectorMeanStd(2931, true);
   testVectorMeanStd(65536);
   testVectorMeanStd(65536, true);
   testVectorMeanStd(81920);
@@ -164,10 +168,14 @@ TEST_CASE("QuantizeStd AVX2", "[VectorMeanStd]") {
   if (kCPU < CPUType::AVX2) return;
   testVectorMeanStd(64);
   testVectorMeanStd(64, true);
+  testVectorMeanStd(133);
+  testVectorMeanStd(133, true);
   testVectorMeanStd(256);
   testVectorMeanStd(256, true);
   testVectorMeanStd(2048);
   testVectorMeanStd(2048, true);
+  testVectorMeanStd(2931);
+  testVectorMeanStd(2931, true);
   testVectorMeanStd(65536);
   testVectorMeanStd(65536, true);
   testVectorMeanStd(81920);
@@ -182,10 +190,14 @@ TEST_CASE("QuantizeStd AVX512BW", "[VectorMeanStd]") {
   if (kCPU < CPUType::AVX512BW) return;
   testVectorMeanStd(64);
   testVectorMeanStd(64, true);
+  testVectorMeanStd(133);
+  testVectorMeanStd(133, true);
   testVectorMeanStd(256);
   testVectorMeanStd(256, true);
   testVectorMeanStd(2048);
   testVectorMeanStd(2048, true);
+  testVectorMeanStd(2931);
+  testVectorMeanStd(2931, true);
   testVectorMeanStd(65536);
   testVectorMeanStd(65536, true);
   testVectorMeanStd(81920);
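
Reviewer note: the following is not part of the patch. It is a minimal, standalone C++ sketch of the column-splitting idea that the patch implements in shrinkMat()/padMatrix() in interleave.h; it uses no intgemm types and the names (splitColumns, main_part, rest) are purely illustrative. A B matrix whose column count is not a multiple of 8 is treated as a multiple-of-8 block, which the existing 8-wide kernels handle unchanged, plus an 8-column padded remainder whose junk columns are later skipped via RunPartial()/write_partial().

#include <cstdio>
#include <vector>

// Split a rows x cols row-major matrix into:
//  * main_part: rows x (cols / 8 * 8), the columns the existing 8-wide kernels can process, and
//  * rest: rows x 8, holding the trailing cols % 8 columns, padded (here with zeros;
//    the patch pads with NaN and discards the padded columns after PrepareB).
static void splitColumns(const std::vector<float> &in, int rows, int cols,
                         std::vector<float> &main_part, std::vector<float> &rest) {
  const int main_cols = cols / 8 * 8;
  const int rem = cols % 8;
  main_part.assign(static_cast<size_t>(rows) * main_cols, 0.0f);
  rest.assign(static_cast<size_t>(rows) * 8, 0.0f);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < main_cols; ++c)
      main_part[r * main_cols + c] = in[r * cols + c];
    for (int c = 0; c < rem; ++c)
      rest[r * 8 + c] = in[r * cols + main_cols + c];
  }
}

int main() {
  const int rows = 2, cols = 11;              // 11 = 8 + 3, so the remainder is 3 columns
  std::vector<float> B(rows * cols);
  for (int i = 0; i < rows * cols; ++i) B[i] = static_cast<float>(i);

  std::vector<float> main_part, rest;
  splitColumns(B, rows, cols, main_part, rest);

  // An 8-wide kernel runs on main_part; only the first 3 columns of each row of
  // rest carry real data, which is what the partial callbacks honour.
  std::printf("main: %zu floats, rest: %zu floats (3 real columns per row)\n",
              main_part.size(), rest.size());
}

The diff applies the same idea inside PrepareB and then handles the leftover columns with a dedicated remainder loop in Multiply8Shift/PrepareBias that accumulates only results.rem column sums and writes them through RunPartial.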