From 5484560d5fe44a058652ad3523ae252c9b58dc30 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Wed, 29 Jan 2025 19:11:36 -0800 Subject: [PATCH 01/33] init code structure for matmul 2 bits Signed-off-by: Liqun Fu --- cmake/onnxruntime_mlas.cmake | 2 + .../cpu/quantization/matmul_nbits.cc | 39 +- .../cpu/quantization/matmul_nbits_impl.cc | 35 +- .../cpu/quantization/matmul_nbits_impl.h | 2 +- onnxruntime/core/mlas/inc/mlas_q4.h | 2 +- onnxruntime/core/mlas/lib/q4_dq.cpp | 363 ++++++++++++------ onnxruntime/core/mlas/lib/qnbitgemm.cpp | 124 ++++-- onnxruntime/core/mlas/lib/qnbitgemm.h | 10 + .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 86 +++++ .../mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h | 52 +++ .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 10 + .../test/contrib_ops/matmul_4bits_test.cc | 261 +++++++------ .../test/mlas/bench/bench_qnbitgemm.cpp | 4 +- .../test/mlas/unittest/test_blockq4.cpp | 2 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 37 +- .../test/optimizer/graph_transform_test.cc | 2 +- 16 files changed, 720 insertions(+), 311 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index ed3ad89247975..90667e488ffe8 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -182,6 +182,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp @@ -586,6 +587,7 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp ) if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) set(mlas_platform_srcs_avx2 diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index c3e43f897c509..f9e0795b6dbfe 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -111,8 +111,8 @@ class MatMulNBits final : public OpKernel { has_unquantized_zero_point_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; } - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(nbits_ == 2 || nbits_ == 4, + "Only 2 and 4b quantization is supported for MatMulNBits op, additional bits support is planned."); const Tensor* tensor_zero_point = nullptr; has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point); } @@ -436,17 +436,30 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - static_cast(zero_points_data), // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or 
row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); + // dequantize b, only 2 and 4b quantization is supported for now + if (this->nbits_ == 2) { + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else if (this->nbits_ == 4) { + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } } else { ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index 6a19a741c3028..dd3d1fd9ac2cc 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -16,6 +16,15 @@ namespace onnxruntime { namespace contrib { +template +void Dequantize2BitsKernelReOrder( + T* /*output*/, const uint8_t* /*quant_data*/, const T* /*scale_data*/, + const zeroT* /*zero_points*/, const int32_t* /*reorder_idx*/, int /*block_size*/, + int /*groups_per_threadblock*/, int /*total_groups*/, int /*out_rows*/, int /*out_cols*/, + int /*blockIdx_x*/, int /*threadIdx_x*/) { + assert(false); +} + template void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, @@ -73,7 +82,7 @@ void Dequantize4BitsKernelReOrder( } } -template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input @@ -95,24 +104,36 @@ void DequantizeBlockwise( pool, static_cast(blocks_per_grid), [&](std::ptrdiff_t block_id) { for (int j = 0; j < 256; j++) { - Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, - reorder_idx, block_size, groups_per_threadblock, - total_groups, N, K, static_cast(block_id), j); + if constexpr (qbits == 2) { + Dequantize2BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } else { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } } }); } -template void DequantizeBlockwise( +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + + +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, const 
int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const float* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h index 5061ac5c800a6..b875048cbc585 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -6,7 +6,7 @@ namespace onnxruntime { namespace contrib { -template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index aec14070ffd55..80db68750799b 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -277,9 +277,9 @@ MlasBlockwiseQuantizedShape( * * If the qbits or block_size values are unsupported the output sizes will be zero. */ +template void MLASCALL MlasBlockwiseQuantizedBufferSizes( - int qbits, int block_size, bool columnwise, int rows, diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 015d69de68766..acc3cdd651751 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -402,7 +402,8 @@ template < struct BlockwiseQuantizer { // To support other qbits, need to add bit packing code for // storing to dst and zero points - static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(qbits == 4 || qbits == 2, "Only 4b block quantization is supported!"); + //static_assert(qbits != 2 || Columnwise, "Only support Columnwise in qbits == 2 case."); using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; @@ -480,7 +481,7 @@ struct BlockwiseQuantizer { thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)(BitsTraits::kMid)); const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -521,40 +522,68 @@ struct BlockwiseQuantizer { } } - // !! 4b specific code as we need to pack 2 4b numbers into one byte + // !! 
qbits specific code as we need to pack 2 4b numbers into one byte if (zero_points != nullptr) { - const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; + const int32_t meta_idx = meta_col * ((row_blks + 1) / BitsTraits::kPackSize) + meta_row / BitsTraits::kPackSize; + if constexpr (qbits == 4) { zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } else if constexpr (qbits == 2) { + zero_points[meta_idx] = (zp_bytes[0] & 0x3) | ((zp_bytes[1] & 0x3) << 2) | + ((zp_bytes[2] & 0x3) << 4) | ((zp_bytes[3] & 0x3) << 6); + } else { + static_assert(false && "only support qbits of 4 and 2"); + } } - for (int32_t j = c; j < c_end; ++j) { + for (int32_t j = c; j < c_end; ++j) { // this does not work if j runs more then 1 because zp_bytes is indexed by i. const int32_t meta_c = j / QuantBlk::kColumn; - for (int32_t i = r; i < r_end; i += 2) { + for (int32_t i = r; i < r_end; i += BitsTraits::kPackSize) { const int32_t meta_r = i / QuantBlk::kRow; const float scale = static_cast(scales[meta_c * row_blks + meta_r]); const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - const int8_t zp = zp_bytes[meta_r & 1]; - const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; - - const float v0 = static_cast(src[i * leadingDimension + j]); - const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), - 0.0f, BitsTraits::kMaxFp); - - uint8_t vi1 = (uint8_t)zp; - if (i + 1 < r_end) { - float reciprocal_scale1 = reciprocal_scale; - if constexpr (QuantBlk::kRow == 1) { - const float scale1 = - static_cast(scales[meta_c * row_blks + meta_r + 1]); - reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + if constexpr (qbits == 4) { + const int8_t zp = zp_bytes[meta_r & 1]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), 0.0f, BitsTraits::kMaxFp); + + uint8_t vi1 = (uint8_t)zp1; + if (i + 1 < r_end) { + float reciprocal_scale1 = reciprocal_scale; + if constexpr (QuantBlk::kRow == 1) { + const float scale1 = + static_cast(scales[meta_c * row_blks + meta_r + 1]); + reciprocal_scale1 = scale1 ? 
1.0f / scale1 : 0.0f; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, + BitsTraits::kMaxFp); + } + dst[j * q_rows + i / BitsTraits::kPackSize] = (vi0 & 0xf) | (vi1 << 4); + } else { + const int8_t zp0 = zp_bytes[(i / QuantBlk::kRow) & 3]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 3]; + const int8_t zp2 = zp_bytes[((i + 2) / QuantBlk::kRow) & 3]; + const int8_t zp3 = zp_bytes[((i + 3) / QuantBlk::kRow) & 3]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp0), 0.0f, BitsTraits::kMaxFp); + uint8_t vi1 = 0, vi2 = 0, vi3 = 0; + if (i + 1 < r_end) { + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale + zp1), 0.0f, BitsTraits::kMaxFp); + } + if (i + 2 < r_end) { + const float v2 = static_cast(src[(i + 2) * leadingDimension + j]); + vi2 = (uint8_t)std::clamp(roundf(v2 * reciprocal_scale + zp2), 0.0f, BitsTraits::kMaxFp); } - const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); - vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, - BitsTraits::kMaxFp); + if (i + 3 < r_end) { + const float v3 = static_cast(src[(i + 3) * leadingDimension + j]); + vi3 = (uint8_t)std::clamp(roundf(v3 * reciprocal_scale + zp3), 0.0f, BitsTraits::kMaxFp); + } + dst[j * q_rows + i / BitsTraits::kPackSize] = (vi0 & 0x03) | ((vi1 & 0x03) << 2) | ((vi2 & 0x03) << 4) | ((vi3 & 0x03) << 6); } - - // !! 4b specific code - dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); } } }); @@ -587,6 +616,8 @@ struct BlockwiseQuantizer { const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + constexpr int pack_size = BitsTraits::kPackSize; + int q_rows, q_cols; quantizedShape(rows, columns, q_rows, q_cols); @@ -605,37 +636,78 @@ struct BlockwiseQuantizer { for (int32_t j = c; j < c_end; ++j) { const int32_t meta_col = j / QuantBlk::kColumn; - // !! 4b specific code + // !! 2 and 4b specific code // the whole loop is 4b specific due to sub 8 bit packing // and unpacking. We can potentially make this qbits generic // by wraping the packing/unpacking code like cutlass::Array - for (int32_t i = r; i < r_end; i += 2) { + for (int32_t i = r; i < r_end; i += pack_size) { const int32_t meta_row = i / QuantBlk::kRow; const float scale0 = static_cast(scales[meta_col * row_blks + meta_row]); - const int zp_pair = - (zero_points == nullptr) - ? 0x88 - : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); - - const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; - const float v0 = (static_cast(vi0) - zp0) * scale0; - - dst[j * rows + i] = static_cast(v0); - if ((i + 1) < r_end) { - float scale1 = scale0; - int zp1 = zp0; - if constexpr (QuantBlk::kRow == 1) { - scale1 = - static_cast(scales[meta_col * row_blks + meta_row + 1]); - zp1 = (zp_pair >> 4) & 0xf; + if constexpr (qbits == 4) { + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_col * ((row_blks + 1) / pack_size) + meta_row / pack_size]; + const int zp0 = (meta_row & 1) ? 
(zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = + static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); + } + } else { + const int zp_quad = zero_points[meta_col * ((row_blks + 3) / pack_size) + meta_row / pack_size]; + int zp = 0; + const int meta_row_mod = meta_row % 4; + switch (meta_row_mod) { + case 0: + zp = zp_quad & 0x3; + break; + case 1: + zp = (zp_quad >> 2) & 0x3; + break; + case 2: + zp = (zp_quad >> 4) & 0x3; + break; + case 3: + zp = (zp_quad >> 6) & 0x3; + break; + } + + const uint8_t& weight = weights[j * q_rows + i / pack_size]; + const uint8_t vi0 = weight & 0x3; + const float v0 = (static_cast(vi0) - zp) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + const uint8_t vi1 = (weight >> 2) & 0x3; + const float v1 = (static_cast(vi1) - zp) * scale0; + dst[j * rows + (i + 1)] = static_cast(v1); + } + if ((i + 2) < r_end) { + const uint8_t vi2 = (weight >> 4) & 0x3; + const float v2 = (static_cast(vi2) - zp) * scale0; + dst[j * rows + (i + 2)] = static_cast(v2); + } + if ((i + 3) < r_end) { + const uint8_t vi3 = (weight >> 6) & 0x3; + const float v3 = (static_cast(vi3) - zp) * scale0; + dst[j * rows + (i + 3)] = static_cast(v3); } - const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; - const float v1 = (static_cast(vi1) - zp1) * scale1; - dst[j * rows + (i + 1)] = static_cast(v1); } } } @@ -1450,8 +1522,17 @@ MlasBlockwiseQuantizedShape( int& q_cols ); -template -void +template void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols +); + +template void MlasBlockwiseQuantizedShape( int block_size, bool columnwise, @@ -1461,9 +1542,9 @@ MlasBlockwiseQuantizedShape( int& q_cols ); +template void MLASCALL MlasBlockwiseQuantizedBufferSizes( - int qbits, int block_size, bool columnwise, int rows, @@ -1478,72 +1559,70 @@ MlasBlockwiseQuantizedBufferSizes( *q_zero_point_size_in_bytes = 0; } - if (qbits == 4) { - switch (block_size) { - case 16: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 32: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 64: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 128: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, 
q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 256: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; - default: - // Only block size 16, 32, 64, 128, 256 are supported. - break; - } + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
+ break; } } @@ -1620,8 +1699,29 @@ MlasQuantizeBlockwise( } } -template -void +template void MLASCALL +MlasBlockwiseQuantizedBufferSizes<2>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template void MLASCALL +MlasBlockwiseQuantizedBufferSizes<4>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template void MlasQuantizeBlockwise( uint8_t* dst, float* scales, @@ -1635,6 +1735,20 @@ MlasQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool +); + template void MlasQuantizeBlockwise( @@ -1730,6 +1844,19 @@ MlasDequantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index f064a8e1d6a78..7e7baf137f604 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -33,6 +33,7 @@ enum QNBitGemmVariant { HQNBitGemmVariant_BitWidth4_CompFp16, HQNBitGemmVariant_BitWidth4_CompInt8, + SQNBitGemmVariant_BitWidth2_CompInt8, // End of valid variants // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. 
@@ -47,16 +48,24 @@ GetQNBitGemmVariant( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - if (BlkBitWidth == 4 && - (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == SQNBIT_CompFp32) { - return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == HQNBIT_CompFp16) { - return HQNBitGemmVariant_BitWidth4_CompFp16; - } else if (ComputeType == SQNBIT_CompInt8) { - return SQNBitGemmVariant_BitWidth4_CompInt8; - } else if (ComputeType == HQNBIT_CompInt8) { - return HQNBitGemmVariant_BitWidth4_CompInt8; + if (BlkBitWidth == 4) { + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256) { + if (ComputeType == SQNBIT_CompFp32) { + return SQNBitGemmVariant_BitWidth4_CompFp32; + } else if (ComputeType == HQNBIT_CompFp16) { + return HQNBitGemmVariant_BitWidth4_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { + return SQNBitGemmVariant_BitWidth4_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQNBitGemmVariant_BitWidth4_CompInt8; + } + } + } else if (BlkBitWidth == 2) { + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256) { + if (ComputeType == SQNBIT_CompInt8) + { + return SQNBitGemmVariant_BitWidth2_CompInt8; + } } } @@ -89,11 +98,14 @@ MlasIsQNBitGemmAvailable( Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } - case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + case SQNBitGemmVariant_BitWidth4_CompInt8: { return (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } + case SQNBitGemmVariant_BitWidth2_CompInt8: { + return (Dispatch->SQ2BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr); + } default: { return false; } @@ -120,14 +132,17 @@ QNBitGemmPerGemmWorkspaceSize( if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + } else if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->Q2BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); } + return 0; } size_t QNBitGemmPerGemmWorkspaceAlignment( - size_t BlkBitWidth, + size_t /*BlkBitWidth*/, size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) @@ -137,7 +152,8 @@ QNBitGemmPerGemmWorkspaceAlignment( return 1; } - if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { + // alignment is the same w.r.t. BlkBitWidth. + if (/*BlkBitWidth == 4 && */Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } @@ -204,6 +220,12 @@ MlasQNBitGemmPackQuantBDataSize( ); } + if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q2BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); + } + return 0; } @@ -269,9 +291,9 @@ MlasQNBitGemmPackQuantBData( ThreadPool ); } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. - //assert(QuantBScale == nullptr); - //assert(QuantBZeroPoint == nullptr); + // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. 
+ // assert(QuantBScale == nullptr); + // assert(QuantBZeroPoint == nullptr); Dispatch->SQ4BitGemmPackQuantBData( N, K, @@ -283,6 +305,19 @@ MlasQNBitGemmPackQuantBData( ); return; } + } else if (BlkBitWidth == 2) { + if (Dispatch->SQ2BitGemmPackQuantBData != nullptr) { + Dispatch->SQ2BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + return; + } } } @@ -507,6 +542,20 @@ HQ4BitGemm_CompFp16( } } +void +SQ2BitGemm_CompInt8( + const size_t /*BlkLen*/, + const size_t /*K*/, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const /*DataParams*/, + void* const /*PerGemmWorkspace*/, + const size_t /*RangeStartM*/, + const size_t /*RangeCountM*/, + const size_t /*RangeStartN*/, + const size_t /*RangeCountN*/ +) +{ +} + void SQ4BitGemm_CompInt8( const size_t BlkLen, @@ -639,6 +688,7 @@ SQ4BitGemm_CompInt8( template void InitializeWorkspace_CompInt8( + size_t BlkBitWidth, size_t M, size_t N, size_t K, @@ -653,6 +703,7 @@ InitializeWorkspace_CompInt8( template <> void InitializeWorkspace_CompInt8( + size_t BlkBitWidth, size_t M, size_t N, size_t K, @@ -667,26 +718,14 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + // TODO: THIS is temporary: in case of BlkBitWidth == 2 we want to force use QuantizeARow even if + // QuantizeARowComputeBlkSum_CompInt8 is available. + const auto QuantizeARow2 = BlkBitWidth == 2 ? nullptr : GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
- if (QuantizeARow) { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); - } else { + if (QuantizeARow2) { MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; const float* ARowPtr = data.A; @@ -704,12 +743,26 @@ InitializeWorkspace_CompInt8( QuantARowBlkSum += BlockCountK; } }); + } else { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); } } template <> void InitializeWorkspace_CompInt8( + size_t BlkBitWidth, size_t M, size_t N, size_t K, @@ -720,6 +773,7 @@ InitializeWorkspace_CompInt8( size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool ) { + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); MLAS_UNREFERENCED_PARAMETER(M); MLAS_UNREFERENCED_PARAMETER(N); MLAS_UNREFERENCED_PARAMETER(K); @@ -733,6 +787,7 @@ InitializeWorkspace_CompInt8( template using InitializeWorkspaceFn = std::function; default: return nullptr; @@ -797,6 +853,8 @@ GetQNBitGemm(QNBitGemmVariant variant) return SQ4BitGemm_CompFp32; case SQNBitGemmVariant_BitWidth4_CompInt8: return SQ4BitGemm_CompInt8; + case SQNBitGemmVariant_BitWidth2_CompInt8: + return SQ2BitGemm_CompInt8; default: return nullptr; } @@ -849,7 +907,7 @@ MlasQNBitGemmBatch( if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( - M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + BlkBitWidth, M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index eb3d0b44ae3de..c0dd11e2444ee 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -100,6 +100,13 @@ struct MLAS_QNBIT_GEMM_DISPATCH { Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + // TODO: rename Q4BitGemmPackQuantBDataSize_Fn to QNBitGemmPackQuantBDataSize_Fn + // because its signature shall be the same regardness of bit width. + // or has bit width as an argument so we only need one function. + // this same applied to Q4BitGemmPackQuantBData_Fn, Q4BitGemmPerGemmWorkspaceSize_Fn, + // SQ2BitGemmKernel_CompInt8. + Q4BitGemmPackQuantBDataSize_Fn* Q2BitGemmPackQuantBDataSize = nullptr; + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). 
*/ typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, @@ -113,6 +120,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* SQ2BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, @@ -152,6 +160,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { ); Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; + Q4BitGemmPerGemmWorkspaceSize_Fn* Q2BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -342,6 +351,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { ); SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; + SQ4BitGemmKernel_CompInt8_Fn* SQ2BitGemmKernel_CompInt8 = nullptr; /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp new file mode 100644 index 0000000000000..6c1a133609f69 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -0,0 +1,86 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx2.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx2. + +--*/ + +#include +#include +#include + +#include "qnbitgemm.h" + +size_t +Q2BitGemmPackQuantBDataSize( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +void SQ2BitGemmPackQuantBData( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* /*QuantBDataBegin*/, + std::byte* /*PackedQuantBDataBegin*/, + MLAS_THREADPOOL* /*ThreadPool*/ +) +{ +} + +size_t +Q2BitGemmPerGemmWorkspaceSize( + size_t /*M*/, + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +size_t +SQ2BitGemmKernel_CompInt8_avx2( + size_t /*BlkLen*/, + const std::byte* /*QuantA*/, + const std::byte* /*QuantBData*/, + const float* /*QuantBScale*/, + const std::byte* /*QuantBZeroPoint*/, + float* /*C*/, + size_t /*CountM*/, + size_t /*CountN*/, + size_t /*CountK*/, + size_t /*BlockCountK*/, + size_t /*ldc*/, + const float* /*Bias*/ +) +{ + return 0; +} + +void +QuantizeARow_CompInt8( + size_t /*BlkLen*/, + const float* /*A*/, + size_t /*CountK*/, + std::byte* /*QuantA*/ +) +{ +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h new file mode 100644 index 0000000000000..5e8aefb792265 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h @@ -0,0 +1,52 @@ +#pragma once +#include "qnbitgemm.h" + +size_t Q2BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +void +SQ2BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +size_t +Q2BitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +size_t +SQ2BitGemmKernel_CompInt8_avx2( + size_t 
BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +); + +void QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 81615da46aa2e..fe9720fd7e383 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -29,6 +29,8 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" +#include "sqnbitgemm_bitnet_kernel_avx2.h" + void MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) { @@ -1346,6 +1348,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 9bf08c6350833..d6940dc2cf367 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -32,8 +32,10 @@ namespace test { namespace { -constexpr int QBits = 4; +constexpr int Q2Bits = 2; +constexpr int Q4Bits = 4; +template void QuantizeDequantize(std::vector& raw_vals, std::vector& quant_vals, std::vector& scales, @@ -44,7 +46,7 @@ void QuantizeDequantize(std::vector& raw_vals, auto& ortenv = **ort_env.get(); onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); - MlasQuantizeBlockwise( + MlasQuantizeBlockwise( quant_vals.data(), scales.data(), zp != nullptr ? 
zp->data() : nullptr, @@ -57,7 +59,7 @@ void QuantizeDequantize(std::vector& raw_vals, tp); // Note that raw_vals is NxK after dequant - MlasDequantizeBlockwise( + MlasDequantizeBlockwise( raw_vals.data(), // dequantized output quant_vals.data(), // quantized input scales.data(), // quantization scales @@ -95,7 +97,7 @@ std::ostream& operator<<(std::ostream& os, const TestOptions& opts) { << ", has_bias:" << opts.has_bias; } -template +template void RunTest(const TestOptions& opts, std::vector>&& explicit_eps = {}) { SCOPED_TRACE(opts); @@ -121,12 +123,12 @@ void RunTest(const TestOptions& opts, #endif int q_rows, q_cols; - MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, + MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, static_cast(K), static_cast(N), q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(QBits, static_cast(opts.block_size), /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes(static_cast(opts.block_size), /* columnwise */ true, static_cast(K), static_cast(N), q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); @@ -134,7 +136,7 @@ void RunTest(const TestOptions& opts, std::vector scales(q_scale_size); std::vector zp(q_zp_size_in_bytes); - QuantizeDequantize(input1_f_vals, + QuantizeDequantize(input1_f_vals, input1_vals, scales, opts.has_zero_point ? &zp : nullptr, @@ -175,7 +177,7 @@ void RunTest(const TestOptions& opts, test.AddAttribute("K", K); test.AddAttribute("N", N); test.AddAttribute("block_size", opts.block_size); - test.AddAttribute("bits", QBits); + test.AddAttribute("bits", qbits); test.AddAttribute("accuracy_level", opts.accuracy_level); if constexpr (use_float16) { @@ -267,7 +269,7 @@ void RunTest(const TestOptions& opts, } // namespace -template +template void TestMatMulNBitsTyped() { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; @@ -282,24 +284,27 @@ void TestMatMulNBitsTyped() { base_opts.output_rel_error = 0.02f; } + if constexpr (qbits == 4) { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { TestOptions opts = base_opts; opts.has_zero_point = true; - RunTest(opts); + RunTest(opts); } #if !defined(USE_DML) && !defined(USE_WEBGPU) + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; - RunTest(opts); + RunTest(opts); } + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; @@ -316,80 +321,84 @@ void TestMatMulNBitsTyped() { // only enabled for CPU EP for now std::vector> explicit_eps; explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - RunTest(opts, std::move(explicit_eps)); + RunTest(opts, std::move(explicit_eps)); } { TestOptions opts = base_opts; opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); + RunTest(opts); } #endif // !defined(USE_DML) && !defined(USE_WEBGPU) } TEST(MatMulNBits, Float32_Accuracy0) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - 
TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float32_Accuracy1) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float32_Accuracy4) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); +} + +TEST(MatMulNBits, DISABLED_Float32_Accuracy4_Q2) { + TestMatMulNBitsTyped(); } #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64) @@ -397,68 +406,68 @@ TEST(MatMulNBits, Float32_Accuracy4) { // Actual and expected difference is over 0.01 with DmlExecutionProvider. // Skip the tests instead of raising the tolerance to make is pass. 
TEST(MatMulNBits, Float16_Accuracy2) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_Accuracy0) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_Accuracy4) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } #endif #endif diff --git a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp index 64d229889214b..a511664407af0 100644 --- a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp @@ -31,8 +31,8 @@ void RunQNBitGemmBenchmark(size_t BlkLen, } size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes( - BlkBitWidth, static_cast(BlkLen), /* columnwise */ true, + 
MlasBlockwiseQuantizedBufferSizes( + static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index f75002f715154..11e5cec1f1e69 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -53,7 +53,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns, + MlasBlockwiseQuantizedBufferSizes<4>(block_size, columnwise, rows, columns, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index e22018ae2877f..365137d466256 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -142,20 +142,37 @@ class MlasSQNBitGemmTest : public MlasTestBase { const float b_scale = QuantBScale[n * BlockCountK + k_blk]; - static_assert(BlkBitWidth == 4, "only implemented for 4-bit quantized B"); + uint8_t b_zp = 0; + if constexpr (BlkBitWidth == 4) { + b_zp = 8; + } else if constexpr (BlkBitWidth == 2) { + assert(QuantBZeroPoint && "zero point input is needed for BlkBitWidth == 2"); + } else { + static_assert(false, "only implemented for 2- and 4-bit quantized B"); + } - uint8_t b_zp = 8; + int pack_size = 8 / BlkBitWidth; if (QuantBZeroPoint != nullptr) { - const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / 2) + k_blk / 2]; - b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / pack_size) + k_blk / pack_size]; + if constexpr (BlkBitWidth == 4) { + b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); + } else if constexpr (BlkBitWidth == 2) { + int shift = (k_blk & 3) * 2; + b_zp = (b_zp_byte >> shift) & 0x03; + } } int32_t qsum = 0; for (size_t kk = 0; kk < k_blk_len; ++kk) { const int8_t qa = QuantAData[m * BlockCountK * BlkLen + k + kk]; - const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / 2]; - const int8_t qb = ((kk & 1) == 1 ? (qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp; + const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / pack_size]; + int8_t qb = 0; + if constexpr (BlkBitWidth == 4) { + qb = ((kk & 1) == 1 ? 
(qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp; + } else if constexpr (BlkBitWidth == 2) { + qb = ((qb_byte >> ((kk & 3) * 2)) & 0x03) - b_zp; + } qsum += qa * qb; } @@ -246,7 +263,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { uint8_t* QuantBZeroPoint = nullptr; { size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes(BlkLen, /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); @@ -422,13 +439,17 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 64>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 128>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 256>::RegisterShortExecuteTests(); - return count; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e069f6ef2432a..f097521ddc21a 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -7930,7 +7930,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsBiasFusion) { q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(qbits, block_size, /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes(block_size, /* columnwise */ true, K, N, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); From 8c1cfe11d3cc150db5427242ae6c27a1e5748cd4 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 30 Jan 2025 16:36:41 -0800 Subject: [PATCH 02/33] add and pass q4dq tests for q2bit - rename file and test name later Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/q4_dq.cpp | 99 ++++- .../test/mlas/unittest/test_blockq4.cpp | 387 +++++++++++++----- 2 files changed, 361 insertions(+), 125 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index acc3cdd651751..39d921a76fac4 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -481,7 +481,11 @@ struct BlockwiseQuantizer { thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)(BitsTraits::kMid)); + if constexpr (qbits == 2) + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)2); + if constexpr (qbits == 4) + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -524,14 +528,13 @@ struct BlockwiseQuantizer { // !! 
qbits specific code as we need to pack 2 4b numbers into one byte if (zero_points != nullptr) { - const int32_t meta_idx = meta_col * ((row_blks + 1) / BitsTraits::kPackSize) + meta_row / BitsTraits::kPackSize; if constexpr (qbits == 4) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / BitsTraits::kPackSize) + meta_row / BitsTraits::kPackSize; zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); } else if constexpr (qbits == 2) { + const int32_t meta_idx = meta_col * ((row_blks + 3) / BitsTraits::kPackSize) + meta_row / BitsTraits::kPackSize; zero_points[meta_idx] = (zp_bytes[0] & 0x3) | ((zp_bytes[1] & 0x3) << 2) | ((zp_bytes[2] & 0x3) << 4) | ((zp_bytes[3] & 0x3) << 6); - } else { - static_assert(false && "only support qbits of 4 and 2"); } } @@ -670,7 +673,8 @@ struct BlockwiseQuantizer { dst[j * rows + (i + 1)] = static_cast(v1); } } else { - const int zp_quad = zero_points[meta_col * ((row_blks + 3) / pack_size) + meta_row / pack_size]; + const int zp_quad = (zero_points == nullptr) ? + 0xAA : zero_points[meta_col * ((row_blks + 3) / pack_size) + meta_row / pack_size]; int zp = 0; const int meta_row_mod = meta_row % 4; switch (meta_row_mod) { @@ -730,19 +734,35 @@ struct BlockwiseQuantizer { * @tparam signed_quant quantized type is signed */ template -struct BlockwiseQDQQuantizer; - -template -struct BlockwiseQDQQuantizer { +struct BlockwiseQDQQuantizer { static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) { - return (val >> (idx << 2)) & 0xF; + if constexpr (qbits == 2) { + return (val >> (idx << 1)) & 0x3; + } else if constexpr (qbits == 4) { + return (val >> (idx << 2)) & 0xF; + } } static MLAS_FORCEINLINE uint8_t SetElem(uint8_t val, int32_t idx, uint8_t dst) { - auto shift = idx << 2; - return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + if constexpr (qbits == 2) { + auto shift = idx << 1; + return ((val & 0x3) << shift) | (dst & (~(0x3 << shift))); + } else if constexpr (qbits == 4) { + auto shift = idx << 2; + return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } + } + + template + static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1, uint8_t v2, uint8_t v3) + { + if constexpr (add2) { + return ((v0 & 0x3) ^ 2) | (((v1 & 0x3) ^ 2) << 2) | (((v2 & 0x3) ^ 2) << 4) | (((v3 & 0x3) ^ 2) << 6); + } else { + return (v0 & 0x3) | ((v1 & 0x3) << 2) | ((v2 & 0x3) << 4) | ((v3 & 0x3) << 6); + } } template @@ -1491,7 +1511,7 @@ MlasBlockwiseQuantizedShape( template void -MlasBlockwiseQuantMetaShape( +MlasBlockwiseQuantMetaShape( int block_size, bool columnwise, int rows, @@ -1500,6 +1520,16 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); +template void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols +); + template void MlasBlockwiseQuantMetaShape( @@ -1901,6 +1931,19 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template bool +MlasQDQQuantizeBlockwise( + const float* src, + float* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( const MLAS_FP16* src, @@ -1940,6 +1983,36 @@ MlasQDQTransposeBlockwiseQuantized( } } +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int 
quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template void MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index 11e5cec1f1e69..fbe9e8b5f0d98 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -19,6 +19,9 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" +constexpr int Q2Bits = 2; +constexpr int Q4Bits = 4; + class MlasBlockwiseQdqTest : public MlasTestBase { private: MatrixGuardBuffer FpBuf; @@ -36,6 +39,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MatrixGuardBuffer QDQTransposedOutputScales; MatrixGuardBuffer QDQTransposedOutputOffsets; + template void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); float* transposed = FpBuf2.GetBuffer(rows * columns, true); @@ -46,41 +50,79 @@ class MlasBlockwiseQdqTest : public MlasTestBase { int meta_rows; int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); int q_rows; int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes<4>(block_size, columnwise, rows, columns, + MlasBlockwiseQuantizedBufferSizes(block_size, columnwise, rows, columns, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); uint8_t* qdq_weights = QDQOutputElements.GetBuffer((rows * columns + 1) / 2, true); uint8_t* qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); - int v = 7; - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; + int pack_size = 8 / qbits; + int v; + if constexpr (qbits == 2) { + v = 1; + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += pack_size) { + int idx = c * q_rows + r / pack_size; + uint8_t v0 = static_cast(v); + v = (v + 1) % 4; + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 1) % 4; + if (v == 3) { + v = (v + 1) % 4; + } + } + uint8_t v2 = 0; + if (r + 2 < rows) { + v2 = static_cast(v); + v = (v + 1) % 4; + if (v == 3) { + v = (v + 1) % 4; + } + } + uint8_t v3 = 0; + if (r + 3 < rows) { + v3 = static_cast(v); + v = (v + 1) % 4; + if (v == 3) { + v = (v + 1) % 4; + } + } + elements[idx] = (v3 << 6) | (v2 << 4) | (v1 << 2) | v0; } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); + } + } else if constexpr(qbits == 4) { + v = 7; + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += 2) { + int idx = c * q_rows + r / 2; + uint8_t v0 = static_cast(v); v = (v 
+ 5) % 16; if (v == 11 || v == 7 || v == 3) { // making the cycle 13 instead of 16, avoiding same values in a row v = (v + 5) % 16; } - } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } - elements[idx] = (v1 << 4) | v0; + elements[idx] = (v1 << 4) | v0; + } } } @@ -91,30 +133,57 @@ class MlasBlockwiseQdqTest : public MlasTestBase { uint8_t* qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); uint8_t* qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); if (zp) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; + if constexpr (qbits == 2) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += pack_size) { + int idx = c * ((meta_rows + 3) / pack_size) + r / pack_size; + uint8_t v0 = static_cast(v); + v = (v + 1) % 4; + uint8_t v1 = 0; + if (r + 1 < meta_rows) { + v1 = static_cast(v); + v = (v + 1) % 4; + } + uint8_t v2 = 0; + if (r + 2 < meta_rows) { + v2 = static_cast(v); + v = (v + 1) % 4; + } + uint8_t v3 = 0; + if (r + 3 < meta_rows) { + v3 = static_cast(v); + v = (v + 1) % 4; + } + zp[idx] = (v3 << 6) | (v2 << 4) | (v1 << 2) | v0; } - uint8_t v1 = 0; - if (r + 1 < meta_rows) { - v1 = static_cast(v); + } + } + else if constexpr (qbits == 4) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += 2) { + int idx = c * ((meta_rows + 1) / 2) + r / 2; + uint8_t v0 = static_cast(v); v = (v + 5) % 16; if (v == 11 || v == 7 || v == 3) { // making the cycle 13 instead of 16, avoiding same values in a row v = (v + 5) % 16; } + uint8_t v1 = 0; + if (r + 1 < meta_rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + zp[idx] = (v1 << 4) | v0; } - zp[idx] = (v1 << 4) | v0; } } } - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, + MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); MlasTranspose(dequant_buf, transposed, columns, rows); @@ -123,48 +192,79 @@ class MlasBlockwiseQdqTest : public MlasTestBase { float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); uint8_t* o_zp = symmetric ? 
nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, + MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr); - if (columnwise) { - bool signed_quant = MlasQDQQuantizeBlockwise( - transposed, qdq_scales, qdq_zp, qdq_weights, - true, rows, columns, block_size, threadpool_ptr); + if constexpr (qbits == 4) { + if (columnwise) { + bool signed_quant = MlasQDQQuantizeBlockwise( + transposed, qdq_scales, qdq_zp, qdq_weights, + true, rows, columns, block_size, threadpool_ptr); - ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; - if (symmetric) { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); + if (symmetric) { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); - } else { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); + } else { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + } } } - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) + if constexpr (qbits == 2) { + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += pack_size) { + int idx = c * q_rows + r / pack_size; + ASSERT_EQ(o_elements[idx] & 0x3, elements[idx] & 0x3) << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (r + 1 < rows) { + ASSERT_EQ((o_elements[idx] >> 2) & 0x3, (elements[idx] >> 2) & 0x3) + << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 2 < rows) { + ASSERT_EQ((o_elements[idx] >> 4) & 0x3, (elements[idx] >> 4) & 0x3) + << ", index=[" << r + 2 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 3 < rows) { + ASSERT_EQ((o_elements[idx] >> 6) & 0x3, (elements[idx] >> 6) & 0x3) + << ", index=[" << r + 3 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } - - if (r + 1 < rows) { - ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + } + } else if constexpr (qbits == 4) { + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += 2) { + int idx = c * q_rows + r / 2; + ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) + << ", index=[" << r << "x" << c 
<< "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) + if constexpr (qbits == 4) { + if (columnwise) { + ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) + << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + } + if (r + 1 < rows) { + ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if constexpr (qbits == 4) { + if (columnwise) { + ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) + << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + } } } } @@ -177,34 +277,63 @@ class MlasBlockwiseQdqTest : public MlasTestBase { << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_scales_T[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if constexpr (qbits == 4) { + if (columnwise) { + ASSERT_EQ(qdq_scales_T[idx], scales[idx]) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } } } if (symmetric) return; - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) + + if constexpr (qbits == 2) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += pack_size) { + int idx = c * ((meta_rows + 3) / pack_size) + r / pack_size; + ASSERT_EQ(o_zp[idx] & 0x3, zp[idx] & 0x3) << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (r + 1 < meta_rows) { + ASSERT_EQ((o_zp[idx] >> 2) & 0x3, (zp[idx] >> 2) & 0x3) + << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 2 < meta_rows) { + ASSERT_EQ((o_zp[idx] >> 4) & 0x3, (zp[idx] >> 4) & 0x3) + << ", index=" << r + 2 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 3 < meta_rows) { + ASSERT_EQ((o_zp[idx] >> 6) & 0x3, (zp[idx] >> 6) & 0x3) + << ", index=" << r + 3 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } - if (r + 1 < meta_rows) { - ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", 
shape=[" << rows << "x" << columns + } + } else if constexpr (qbits == 4) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += 2) { + int idx = c * ((meta_rows + 1) / 2) + r / 2; + ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) + ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 1 < meta_rows) { + ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (columnwise) { + ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) + << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } } } @@ -217,44 +346,78 @@ class MlasBlockwiseQdqTest : public MlasTestBase { return suite_name.c_str(); } - void ExecuteShort(void) override { - Test(20, 1, 32, true, false); - Test(20, 1, 32, true, true); - Test(1, 20, 32, false, false); - Test(1, 20, 32, false, true); - Test(52, 1, 32, true, false); - Test(52, 1, 32, true, true); - Test(1, 52, 32, false, false); - Test(1, 52, 32, false, true); - Test(20, 3, 32, true, false); - Test(20, 3, 32, true, true); - Test(3, 20, 32, false, false); - Test(3, 20, 32, false, true); - Test(52, 3, 32, true, false); - Test(52, 3, 32, true, true); - Test(3, 52, 32, false, false); - Test(3, 52, 32, false, true); - Test(52, 3, 64, true, false); - Test(52, 3, 64, true, true); - Test(3, 52, 64, false, false); - Test(3, 52, 64, false, true); - Test(32 * 9 + 17, 41, 32, true, false); - Test(32 * 9 + 17, 41, 32, true, true); - Test(41, 32 * 9 + 17, 32, false, false); - Test(41, 32 * 9 + 17, 32, false, true); - Test(32 * 9 + 17, 41, 64, true, false); - Test(32 * 9 + 17, 41, 64, true, true); - Test(41, 32 * 9 + 17, 64, false, false); - Test(41, 32 * 9 + 17, 64, false, true); - Test(32 * 15 + 17, 63, 128, true, false); - Test(32 * 15 + 17, 63, 128, true, true); - Test(63, 32 * 15 + 17, 128, false, false); - Test(63, 32 * 15 + 17, 128, false, true); - - Test(256, 256, 32, true, false); - Test(256, 256, 32, true, true); - Test(256, 256, 32, false, false); - Test(256, 256, 32, false, true); + void ExecuteShort(void) { + // only support columnwise = true with qbits=2 + Test(20, 1, 32, true, false); + Test(20, 1, 32, true, true); + //Test(1, 20, 32, false, false); + //Test(1, 20, 32, false, true); + Test(52, 1, 32, true, false); + Test(52, 1, 32, true, true); + //Test(1, 52, 32, false, false); + //Test(1, 52, 32, false, true); + Test(20, 3, 32, true, false); + Test(20, 3, 32, true, true); + //Test(3, 20, 32, false, false); + //Test(3, 20, 32, false, true); + Test(52, 3, 32, true, false); + Test(52, 3, 32, true, true); + //Test(3, 52, 32, false, false); + //Test(3, 52, 32, false, true); + Test(52, 3, 64, true, false); + Test(52, 3, 64, true, true); + //Test(3, 52, 64, false, false); + //Test(3, 52, 64, false, true); + Test(32 * 9 + 17, 41, 32, true, false); + Test(32 * 9 + 17, 41, 32, true, true); + //Test(41, 32 * 9 + 17, 32, false, false); + //Test(41, 32 * 9 + 17, 32, false, true); + Test(32 * 
9 + 17, 41, 64, true, false); + Test(32 * 9 + 17, 41, 64, true, true); + //Test(41, 32 * 9 + 17, 64, false, false); + //Test(41, 32 * 9 + 17, 64, false, true); + Test(32 * 15 + 17, 63, 128, true, false); + Test(32 * 15 + 17, 63, 128, true, true); + //Test(63, 32 * 15 + 17, 128, false, false); + //Test(63, 32 * 15 + 17, 128, false, true); + + Test(20, 1, 32, true, false); + Test(20, 1, 32, true, true); + Test(1, 20, 32, false, false); + Test(1, 20, 32, false, true); + Test(52, 1, 32, true, false); + Test(52, 1, 32, true, true); + Test(1, 52, 32, false, false); + Test(1, 52, 32, false, true); + Test(20, 3, 32, true, false); + Test(20, 3, 32, true, true); + Test(3, 20, 32, false, false); + Test(3, 20, 32, false, true); + Test(52, 3, 32, true, false); + Test(52, 3, 32, true, true); + Test(3, 52, 32, false, false); + Test(3, 52, 32, false, true); + Test(52, 3, 64, true, false); + Test(52, 3, 64, true, true); + Test(3, 52, 64, false, false); + Test(3, 52, 64, false, true); + Test(32 * 9 + 17, 41, 32, true, false); + Test(32 * 9 + 17, 41, 32, true, true); + Test(41, 32 * 9 + 17, 32, false, false); + Test(41, 32 * 9 + 17, 32, false, true); + Test(32 * 9 + 17, 41, 64, true, false); + Test(32 * 9 + 17, 41, 64, true, true); + Test(41, 32 * 9 + 17, 64, false, false); + Test(41, 32 * 9 + 17, 64, false, true); + Test(32 * 15 + 17, 63, 128, true, false); + Test(32 * 15 + 17, 63, 128, true, true); + Test(63, 32 * 15 + 17, 128, false, false); + Test(63, 32 * 15 + 17, 128, false, true); + + Test(256, 256, 32, true, false); + Test(256, 256, 32, true, true); + Test(256, 256, 32, false, false); + Test(256, 256, 32, false, true); } MlasBlockwiseQdqTest() = default; From f6f22e30d5e777ccc196957e8870a64de9f476ec Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 30 Jan 2025 22:41:16 -0800 Subject: [PATCH 03/33] some fixes Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 5 +- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 50 +++++++++++++------ .../test/contrib_ops/matmul_4bits_test.cc | 11 ++-- .../test/mlas/unittest/test_sqnbitgemm.cpp | 12 +++-- 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 7e7baf137f604..096c795b4e1c5 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -554,6 +554,7 @@ SQ2BitGemm_CompInt8( const size_t /*RangeCountN*/ ) { + // TODO: implement this to call 2bit t-mac kernel } void @@ -920,7 +921,7 @@ MlasQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; @@ -991,7 +992,7 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != 
nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 6c1a133609f69..1d7a1ce73e1d9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -15,45 +15,63 @@ Module Name: --*/ -#include -#include -#include - #include "qnbitgemm.h" +#include "sqnbitgemm_q8_block.h" size_t Q2BitGemmPackQuantBDataSize( - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - return 0; + // TODO: This code shall change according to T-Mac. + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 2; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; } void SQ2BitGemmPackQuantBData( size_t /*N*/, size_t /*K*/, size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/, const std::byte* /*QuantBDataBegin*/, std::byte* /*PackedQuantBDataBegin*/, MLAS_THREADPOOL* /*ThreadPool*/ ) { + // TODO: need implementation } size_t Q2BitGemmPerGemmWorkspaceSize( - size_t /*M*/, - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - return 0; + MLAS_UNREFERENCED_PARAMETER(N); + + switch (ComputeType) { + case SQNBIT_CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + // QuantData + Scale + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } } size_t diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index d6940dc2cf367..bfd682ae3918f 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -97,7 +97,7 @@ std::ostream& operator<<(std::ostream& os, const TestOptions& opts) { << ", has_bias:" << opts.has_bias; } -template +template void RunTest(const TestOptions& opts, std::vector>&& explicit_eps = {}) { SCOPED_TRACE(opts); @@ -284,8 +284,7 @@ void TestMatMulNBitsTyped() { base_opts.output_rel_error = 0.02f; } - if constexpr (qbits == 4) - { + if constexpr (qbits == 4) { TestOptions opts = base_opts; RunTest(opts); } @@ -297,15 +296,13 @@ void TestMatMulNBitsTyped() { } #if !defined(USE_DML) && !defined(USE_WEBGPU) - if constexpr (qbits == 4) - { + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; RunTest(opts); } - if constexpr (qbits == 4) - { + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; opts.has_bias = true; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 365137d466256..26f02466be450 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp 
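
For reference, the test_sqnbitgemm.cpp hunk that follows reads 2-bit B exactly the way the 4-bit path does, only with four weights and four zero points per byte. A minimal sketch of that lookup, under assumptions of this sketch rather than the MLAS API: column-major packed data, K a multiple of BlkLen (so no padding bytes), one float scale per (column, block), and a default zero point of 2 when none is stored (the names below are illustrative).

#include <cstddef>
#include <cstdint>

// Reference lookup of one 2-bit element of B, as the test reference does it.
float DequantizeB2BitElement(const uint8_t* quant_b_data,        // N * (K / 4) bytes
                             const float* quant_b_scale,         // N * BlockCountK floats
                             const uint8_t* quant_b_zero_point,  // N * ((BlockCountK + 3) / 4) bytes, may be null
                             size_t K, size_t BlkLen, size_t k, size_t n) {
    const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;
    const size_t k_blk = k / BlkLen;

    // four 2-bit weights per byte along K
    const uint8_t data_byte = quant_b_data[n * (K / 4) + k / 4];
    const int qb = (data_byte >> ((k & 3) * 2)) & 0x03;

    int zp = 2;  // default 2-bit zero point when B was quantized symmetrically
    if (quant_b_zero_point != nullptr) {
        // four 2-bit zero points per byte along the block dimension
        const uint8_t zp_byte = quant_b_zero_point[n * ((BlockCountK + 3) / 4) + k_blk / 4];
        zp = (zp_byte >> ((k_blk & 3) * 2)) & 0x03;
    }

    return quant_b_scale[n * BlockCountK + k_blk] * static_cast<float>(qb - zp);
}
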
+++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -146,17 +146,18 @@ class MlasSQNBitGemmTest : public MlasTestBase { if constexpr (BlkBitWidth == 4) { b_zp = 8; } else if constexpr (BlkBitWidth == 2) { - assert(QuantBZeroPoint && "zero point input is needed for BlkBitWidth == 2"); + b_zp = 2; } else { static_assert(false, "only implemented for 2- and 4-bit quantized B"); } int pack_size = 8 / BlkBitWidth; if (QuantBZeroPoint != nullptr) { - const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / pack_size) + k_blk / pack_size]; if constexpr (BlkBitWidth == 4) { + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / pack_size) + k_blk / pack_size]; b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); } else if constexpr (BlkBitWidth == 2) { + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 3) / pack_size) + k_blk / pack_size]; int shift = (k_blk & 3) * 2; b_zp = (b_zp_byte >> shift) & 0x03; } @@ -396,6 +397,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - //count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); From 3e1a951448fb37664a4f8d41e994b4142ea98978 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Mon, 3 Feb 2025 12:24:40 -0800 Subject: [PATCH 04/33] add apis to neon and other avxs Signed-off-by: Liqun Fu --- .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 62 +++++++++++++++++++ .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 2 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 9 +++ .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 10 +++ .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 9 +++ .../test/mlas/unittest/test_sqnbitgemm.cpp | 2 - 6 files changed, 92 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index d05de64e68ec8..b12e2358d77bd 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -167,6 +167,61 @@ Q4BitGemmPerGemmWorkspaceAlignment( } } +size_t +Q2BitGemmPackQuantBDataSize( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +void +SQ2BitGemmPackQuantBData( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* /*QuantBDataBegin*/, + std::byte* /*PackedQuantBDataBegin*/, + MLAS_THREADPOOL* /*ThreadPool*/ +) +{ +} + +size_t +Q2BitGemmPerGemmWorkspaceSize( + size_t /*M*/, + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +size_t +SQ2BitGemmKernel_CompInt8_avx2( + size_t /*BlkLen*/, + const std::byte* /*QuantA*/, + const std::byte* /*QuantBData*/, + const float* /*QuantBScale*/, + const std::byte* /*QuantBZeroPoint*/, + float* /*C*/, + size_t /*CountM*/, + size_t /*CountN*/, + size_t /*CountK*/, + size_t /*BlockCountK*/, + size_t /*ldc*/, + const float* /*Bias*/ +) +{ + return 0; +} + } // namespace } // namespace sqnbitgemm_neon @@ -197,5 +252,12 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; #endif 
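
The NEON 2-bit entry points above are still stubs returning 0, while the AVX2 version from the previous patch sizes the SQNBIT_CompInt8 workspace as M * BlockCountK * Q8BlkSize(BlkLen): every K-block of every row of A is quantized on the fly into one "Q8 block" holding a float scale followed by BlkLen int8 values. A self-contained sketch of that arithmetic; Q8BlkSizeSketch and PerGemmWorkspaceSizeSketch are illustrative stand-ins, and the real Q8BlkSize helper may add alignment padding not shown here.

#include <cstddef>
#include <cstdint>

// One Q8 block: a float scale plus BlkLen quantized int8 values.
constexpr size_t Q8BlkSizeSketch(size_t BlkLen) {
    return sizeof(float) + BlkLen * sizeof(int8_t);
}

// Workspace for block-quantizing A in the CompInt8 path.
constexpr size_t PerGemmWorkspaceSizeSketch(size_t M, size_t K, size_t BlkLen) {
    const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;  // MlasDivRoundup(K, BlkLen)
    return M * BlockCountK * Q8BlkSizeSketch(BlkLen);      // one Q8 block per (row, K-block)
}
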
// MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 1d7a1ce73e1d9..d6d104967e3a7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -90,6 +90,7 @@ SQ2BitGemmKernel_CompInt8_avx2( const float* /*Bias*/ ) { + // reference SQ4BitGemmKernel_CompInt8_avx2 return 0; } @@ -101,4 +102,5 @@ QuantizeARow_CompInt8( std::byte* /*QuantA*/ ) { + // shall be similar to QuantizeARow_CompInt8_avx2 without blksum related code. } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index fe9720fd7e383..56c54cf9befb4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1375,5 +1375,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + // change funcions if implementation are different from avx2 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index b4e25d4e4040a..d07ba72d1ed8b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -32,6 +32,7 @@ Module Name: // #include "sqnbitgemm_kernel_avx_common_fp32.h" +#include "sqnbitgemm_bitnet_kernel_avx2.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32_avx512( @@ -368,5 +369,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; + // change funcions if implementation are different from avx2 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index a4468bb906bbc..83fba19c1702d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -27,6 +27,7 @@ Module Name: #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" +#include "sqnbitgemm_bitnet_kernel_avx2.h" MLAS_FORCEINLINE void 
SQ4BitGemmM1Kernel_CompFp32( @@ -353,5 +354,13 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; + // change funcions if implementation are different from avx2 + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 26f02466be450..d849118aae7ef 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -147,8 +147,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { b_zp = 8; } else if constexpr (BlkBitWidth == 2) { b_zp = 2; - } else { - static_assert(false, "only implemented for 2- and 4-bit quantized B"); } int pack_size = 8 / BlkBitWidth; From 013006100158dd5ef8f0ec662716d67008c1ecf5 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Mon, 3 Feb 2025 12:50:04 -0800 Subject: [PATCH 05/33] fix neon build Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index b12e2358d77bd..6fcc530ff11a8 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -252,12 +252,12 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + d.Q2BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = sqnbitgemm_neon::SQ2BitGemmPackQuantBData; - d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + d.Q2BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q2BitGemmPerGemmWorkspaceSize; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + d.SQ2BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; return d; }(); From b4aad0134c3d1cb7f2e43e05fa299abfa14eb3c5 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Mon, 3 Feb 2025 14:05:51 -0800 Subject: [PATCH 06/33] disable 2bit test Signed-off-by: Liqun Fu --- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 1 + onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index bfd682ae3918f..468243791e25a 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -394,6 +394,7 @@ TEST(MatMulNBits, Float32_Accuracy4) { TestMatMulNBitsTyped(); } +// TODO: enable and add more tests for 2bit development. 
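
The disabled 2-bit test below drives MatMulNBits with the same column-wise blockwise layout as the 4-bit tests, just with four values per byte. A sketch of the buffer shapes that layout implies for an arbitrary qbits, written with ceil-division so it covers 2-bit and 4-bit alike; the exact padding behaviour of the MLAS size helpers may differ at the edges, and the function name here is illustrative.

#include <cstddef>

void BlockwiseBufferShapesSketch(int qbits, int block_size, int rows /*K*/, int columns /*N*/,
                                 size_t& data_bytes, size_t& scale_count, size_t& zero_point_bytes) {
    const int pack_size = 8 / qbits;                            // values per byte: 4 for 2-bit, 2 for 4-bit
    const int k_blocks = (rows + block_size - 1) / block_size;  // blocks along K
    data_bytes = static_cast<size_t>(columns) * k_blocks * (block_size / pack_size);
    scale_count = static_cast<size_t>(columns) * k_blocks;      // one scale per (block, column)
    zero_point_bytes = static_cast<size_t>(columns) * ((k_blocks + pack_size - 1) / pack_size);
}

MlasBlockwiseQuantizedBufferSizes in the hunks earlier in this series reports the same three quantities, which is why its bit width moved from a runtime argument into a template parameter.
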
TEST(MatMulNBits, DISABLED_Float32_Accuracy4_Q2) { TestMatMulNBitsTyped(); } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index d849118aae7ef..fee0eacc246dd 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -443,8 +443,9 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); From ff531cbeec0135a201e9cc4de1c0e60231e185b4 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Fri, 7 Mar 2025 14:03:57 -0800 Subject: [PATCH 07/33] 2 bit quantize to support model builder Signed-off-by: Liqun Fu --- .../python/onnxruntime_pybind_quant.cc | 28 +++- .../quantization/matmul_4bits_quantizer.py | 16 +- .../python/tools/quantization/quantize.py | 4 +- .../models/llama/convert_to_onnx.py | 3 + .../models/phi2/convert_to_onnx.py | 1 + .../test/python/quantization/op_test_utils.py | 3 +- .../quantization/test_op_matmul_4bits.py | 140 +++++++++++------- .../test_quantizeblockwise_4bits.py | 103 +++++++++---- 8 files changed, 204 insertions(+), 94 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 51a52af1b151e..f582a58d0734a 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -35,9 +35,10 @@ namespace py = pybind11; using namespace onnxruntime; template -void QuantizeMatMul4BitsBlockwise( +void QuantizeMatMulNBitsBlockwise( py::array_t dst, // shape: [ N, block_per_K, block_blob_size ] py::array_t src, // shape: [K, N] + int32_t bits, py::array_t scale, // shape: [N, block_per_K] py::array_t zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2] int32_t block_size, @@ -53,7 +54,23 @@ void QuantizeMatMul4BitsBlockwise( py::buffer_info scale_buf = scale.request(); py::buffer_info zp_buf = zero_points.request(); - MlasQuantizeBlockwise( + if (bits == 2) { + if constexpr (std::is_same::value) { + assert(false); + } + MlasQuantizeBlockwise( + reinterpret_cast(dst_buf.ptr), + reinterpret_cast(scale_buf.ptr), + is_symmetric ? nullptr : reinterpret_cast(zp_buf.ptr), + reinterpret_cast(src_buf.ptr), + block_size, + true, + K, + N, + N, + tp.get()); + } else if (bits == 4) { + MlasQuantizeBlockwise( reinterpret_cast(dst_buf.ptr), reinterpret_cast(scale_buf.ptr), is_symmetric ? 
nullptr : reinterpret_cast(zp_buf.ptr), @@ -64,6 +81,9 @@ void QuantizeMatMul4BitsBlockwise( N, N, tp.get()); + } else { + assert(false); + } } template @@ -126,8 +146,8 @@ void QuantizeMatMulBnb4Blockwise( } void CreateQuantPybindModule(py::module& m) { - m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); - m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); + m.def("quantize_matmul_nbits", &QuantizeMatMulNBitsBlockwise); + m.def("quantize_matmul_nbits", &QuantizeMatMulNBitsBlockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise); diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index b4ee5074754dc..b4d09d73fdd43 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -18,7 +18,7 @@ from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version -from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits +from onnxruntime.capi._pybind_state import quantize_matmul_nbits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel @@ -185,6 +185,7 @@ def __init__( class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, + bits: int = 4, block_size: int = 128, is_symmetric: bool = False, accuracy_level: int | None = None, @@ -221,7 +222,7 @@ def __init__( ) self.block_size = block_size self.is_symmetric = is_symmetric - self.bits = 4 + self.bits = bits self.accuracy_level = accuracy_level @@ -742,8 +743,8 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.nd packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) - quantize_matmul_4bits( - packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + quantize_matmul_nbits( + packed, fp32weight, self.config.bits, scales, zero_point, block_size, cols, rows, self.config.is_symmetric ) else: packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") @@ -800,7 +801,7 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis rows, cols = b_ndarray.shape kwargs["K"] = rows kwargs["N"] = cols - kwargs["bits"] = 4 + kwargs["bits"] = self.config.bits kwargs["block_size"] = self.config.block_size if self.config.accuracy_level is not None: kwargs["accuracy_level"] = self.config.accuracy_level @@ -1090,6 +1091,7 @@ class MatMul4BitsQuantizer: def __init__( self, model: ModelProto | str, + bits: int, block_size: int = 128, is_symmetric: bool = False, accuracy_level: int | None = None, @@ -1104,6 +1106,7 @@ def __init__( nodes_to_exclude = [] self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) self.model_path = model if isinstance(model, str) else None + self.bits = bits self.block_size = block_size self.is_symmetric = is_symmetric self.accuracy_level = accuracy_level @@ -1113,6 +1116,7 @@ def __init__( if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( + bits=bits, block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level, @@ -1433,6 +1437,7 @@ def parse_args(): ) elif 
args.quant_method == "default": quant_config = DefaultWeightOnlyQuantConfig( + bits=args.bits, block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, @@ -1469,6 +1474,7 @@ def parse_args(): quant = MatMul4BitsQuantizer( model=model, + bits=args.bits, accuracy_level=args.accuracy_level, nodes_to_exclude=args.nodes_to_exclude, nodes_to_include=args.nodes_to_include, diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 27221f9445c30..91ff848ba3c09 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -907,12 +907,12 @@ def quantize( extra_options=quant_config.extra_options, ) else: - # training package doesn't has quantize_matmul_4bits, avoid global import + # training package doesn't has quantize_matmul_nbits, avoid global import from .matmul_4bits_quantizer import MatMul4BitsQuantizer, WeightOnlyQuantConfig if isinstance(quant_config, WeightOnlyQuantConfig): model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load(model_input) - quant = MatMul4BitsQuantizer(model, algo_config=quant_config) + quant = MatMul4BitsQuantizer(model, bits=4, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_output, True) else: diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 89fd613ecbbc2..b3f0aa4e4c766 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -636,6 +636,8 @@ def get_args(): blockwise_group = parser.add_argument_group("blockwise (4-bit quantization)") + parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") + blockwise_group.add_argument( "--block_size", required=False, @@ -957,6 +959,7 @@ def main(): model = onnx.load_model(fp_path, load_external_data=True) quant = MatMul4BitsQuantizer( model=model, + bits=args.bits, block_size=args.block_size, is_symmetric=True, accuracy_level=args.int4_accuracy_level, diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index 8083778423241..1e0a083bdd744 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -162,6 +162,7 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): assert self.precision == Precision.INT4 quant = MatMul4BitsQuantizer( model=optimizer.model, + bits=4, block_size=self.block_size, is_symmetric=True, accuracy_level=self.accuracy_level, diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index e329b4da38f67..2aa2c2f3b3101 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -465,6 +465,7 @@ def check_model_correctness( inputs, rtol=1e-2, atol=0.05, + skip_onnx_reference_evaluator=False, providers=None, dynamic=False, is_gemm=False, @@ -481,7 +482,7 @@ def check_model_correctness( with open(model_path_origin, "rb") as f: model_onnx = onnx.load(f) ops_set = {node.op_type for node in model_onnx.graph.node} - check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", 
"Transpose"}) + check_reference_evaluator = not skip_onnx_reference_evaluator and not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"}) check_target_evaluator = False with open(model_path_to_check, "rb") as f: diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index ed0c65cba78ac..7e212db612c38 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -16,7 +16,7 @@ from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type from onnxruntime.quantization import quant_utils - +from parameterized import parameterized class TestOpMatMul4Bits(unittest.TestCase): @classmethod @@ -27,26 +27,42 @@ def setUpClass(cls): def tearDownClass(cls): cls._tmp_model_dir.cleanup() - def fill_int4_data(self, shape: int | tuple[int, ...], symmetric: bool) -> np.ndarray: + def fill_nbits_data(self, shape: int | tuple[int, ...], bits: int, symmetric: bool) -> np.ndarray: line = np.zeros(shape) line = line.reshape(-1) - if symmetric: - v = -2.0 - for i in range(line.shape[0]): - if v == 0 or v == -3 or v == 3: + if bits == 2: + if symmetric: + v = -1.0 + for i in range(line.shape[0]): + line[i] = v v += 1 - line[i] = v - v += 1 - if v >= 8: - v = -8 - else: - v = 0.0 - for i in range(line.shape[0]): - line[i] = v - v += 1 - if v >= 16: - v = 0 + if v >= 2: + v = -2 + else: + v = 0.0 + for i in range(line.shape[0]): + line[i] = v + v += 1 + if v >= 4: + v = 0 + elif bits == 4: + if symmetric: + v = -2.0 + for i in range(line.shape[0]): + if v == 0 or v == -3 or v == 3: + v += 1 + line[i] = v + v += 1 + if v >= 8: + v = -8 + else: + v = 0.0 + for i in range(line.shape[0]): + line[i] = v + v += 1 + if v >= 16: + v = 0 return line.reshape(shape) @@ -67,7 +83,7 @@ def input_feeds( dr = TestDataFeeds(input_data_list) return dr - def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> None: + def construct_model_matmul(self, output_model_path: str, nbits: int, symmetric: bool) -> None: # (input) # | # MatMul @@ -80,7 +96,7 @@ def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> Non def make_matmul( input_name, weight_shape: int | tuple[int, ...], weight_name: str, output_name: str, node_name: str ): - weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) + weight_data = self.fill_nbits_data(weight_shape, nbits, symmetric).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) return onnx.helper.make_node( "MatMul", @@ -120,6 +136,7 @@ def make_matmul( def construct_model_gather( self, output_model_path: str, + nbits: int, symmetric: bool, tdata: TensorProto.DataType, tind: TensorProto.DataType, @@ -138,7 +155,7 @@ def construct_model_gather( def make_gather( indices_name, data_shape: int | tuple[int, ...], data_name: str, output_name: str, node_name: str ): - weight_data = self.fill_int4_data(data_shape, symmetric).astype( + weight_data = self.fill_nbits_data(data_shape, nbits, symmetric).astype( np.float32 if tdata == TensorProto.FLOAT else np.float16 ) initializers.append(onnx.numpy_helper.from_array(weight_data, name=data_name)) @@ -180,6 +197,7 @@ def quant_test( self, model_fp32_path: str, data_reader: TestDataFeeds, + bits: int, block_size: int, is_symmetric: bool, quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, @@ -199,13 +217,14 
@@ def quant_test( model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( + bits=bits, block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format, op_types_to_quantize=op_types_to_quantize, quant_axes=quant_axes, ) - quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, bits=bits, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) @@ -239,7 +258,9 @@ def quant_test( data_reader.rewind() try: - check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next(), rtol, atol) + skip_onnx_reference_evaluator = True if bits==2 else False + check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next(), rtol, atol, + skip_onnx_reference_evaluator=skip_onnx_reference_evaluator) except Exception as exception: if "4b quantization not yet supported on this hardware platform!" in exception.args[0]: # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception @@ -252,6 +273,7 @@ def quant_test_with_algo( algorithm: str, model_fp32_path: str, data_reader: TestDataFeeds, + bits: int, block_size: int, is_symmetric: bool, ): @@ -274,7 +296,7 @@ def quant_test_with_algo( algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size) model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, bits, block_size, is_symmetric, algo_config=algo_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) @@ -292,100 +314,106 @@ def quant_test_with_algo( else: raise exception + @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) - def test_quantize_matmul_int4_symmetric(self): + def test_quantize_matmul_nbits_symmetric(self, bits): np.random.seed(13) model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=True) + self.construct_model_matmul(model_fp32_path, 2, symmetric=True) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, 32, True) + self.quant_test(model_fp32_path, data_reader, bits, 32, True) + @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) - def test_quantize_matmul_int4_offsets(self): + def test_quantize_matmul_nbits_offsets(self, bits): model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=False) + self.construct_model_matmul(model_fp32_path, bits, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, 32, False) + self.quant_test(model_fp32_path, data_reader, bits, 32, False) + @parameterized.expand([(2,), (4,)]) 
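
The tests below are parameterized over bits = 2 and 4 and all funnel through the same per-block affine quantization. Sketched here in C++ to match the MLAS-side code elsewhere in this series: widen [vmin, vmax] to include zero, map it onto [0, 2^bits - 1], and take the rounded image of zero as the zero point. Clamping to the full unsigned range and the function name are assumptions of the sketch, not the quantizer's exact code.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

void ChooseBlockQuantParamsSketch(const float* block, size_t len, int bits,
                                  float& scale, uint8_t& zero_point) {
    const int max_level = (1 << bits) - 1;  // 3 for 2-bit, 15 for 4-bit
    float vmin = 0.0f, vmax = 0.0f;         // start at 0 so the range always contains zero
    for (size_t i = 0; i < len; ++i) {
        vmin = std::min(vmin, block[i]);
        vmax = std::max(vmax, block[i]);
    }
    scale = (vmax - vmin) / static_cast<float>(max_level);
    const float zp_fp = (scale != 0.0f) ? (-vmin / scale) : 0.0f;
    const int zp = std::min(max_level, std::max(0, static_cast<int>(std::round(zp_fp))));
    zero_point = static_cast<uint8_t>(zp);
}
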
@unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) - def test_quantize_gather_int4_symmetric(self): + def test_quantize_gather_nbits_symmetric(self, bits): np.random.seed(13) model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("gather_fp32_symmetric.onnx").absolute()) - self.construct_model_gather(model_fp32_path, True, TensorProto.FLOAT, TensorProto.INT32) + self.construct_model_gather(model_fp32_path, bits, True, TensorProto.FLOAT, TensorProto.INT32) data_reader = self.input_feeds(1, {"input": (100, 1000)}, -545, 535, np.int32) # cover rounding error - self.quant_test(model_fp32_path, data_reader, 32, True, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) + self.quant_test(model_fp32_path, data_reader, bits, 32, True, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) + @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) - def test_quantize_gather_int4_offsets(self): + def test_quantize_gather_nbits_offsets(self, bits): model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("gather_fp32_offset.onnx").absolute()) - self.construct_model_gather(model_fp32_path, False, TensorProto.FLOAT16, TensorProto.INT64) + self.construct_model_gather(model_fp32_path, bits, False, TensorProto.FLOAT16, TensorProto.INT64) data_reader = self.input_feeds(1, {"input": (100, 1000)}, -545, 535, np.int64) # cover rounding error - self.quant_test(model_fp32_path, data_reader, 32, False, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) + self.quant_test(model_fp32_path, data_reader, bits, 32, False, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) + @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) - def test_quantize_matmul_int4_symmetric_qdq(self): + def test_quantize_matmul_nbits_symmetric_qdq(self, bits): np.random.seed(13) model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=True) + self.construct_model_matmul(model_fp32_path, bits, symmetric=True) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ) + self.quant_test(model_fp32_path, data_reader, bits, 32, True, quant_utils.QuantFormat.QDQ) + @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) - def test_quantize_matmul_int4_offsets_qdq(self): + def test_quantize_matmul_nbits_offsets_qdq(self, bits): model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=False) + self.construct_model_matmul(model_fp32_path, bits, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, 32, 
False, quant_utils.QuantFormat.QDQ) + self.quant_test(model_fp32_path, data_reader, bits, 32, False, quant_utils.QuantFormat.QDQ) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) def test_quantize_matmul_int4_using_rtn_algo(self): if not find_spec("neural_compressor"): self.skipTest("skip test_smooth_quant since neural_compressor is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=False) + self.construct_model_matmul(model_fp32_path, 4, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False) + self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 4, 32, False) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) def test_quantize_matmul_int4_using_gptq_algo(self): if not find_spec("neural_compressor"): self.skipTest("skip test_smooth_quant since neural_compressor is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=False) + self.construct_model_matmul(model_fp32_path, 4, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False) + self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 4, 32, False) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) def test_quantize_matmul_int4_using_hqq_algo(self): if not find_spec("torch"): self.skipTest("skip test_hqq_quant since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=False) + self.construct_model_matmul(model_fp32_path, 4, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False) + self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 4, 32, False) if __name__ == "__main__": diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py index 97931acf03f42..30068f6026884 100644 --- a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py @@ -24,12 +24,15 @@ def dequantize_blockwise_4bits(quant_values, scale, zero_point, valid_len): return quant_float -def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): +def quantize_blockwise_nbits_ref(matrix_float: npt.ArrayLike, bits: int, block_size: int, is_symmetric: bool): if len(matrix_float.shape) != 2: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = matrix_float.shape - blob_size = block_size // 2 + pack_size = 8 // bits + default_zp = 1 << (bits - 
1) + quant_max = (1 << bits) - 1 + blob_size = block_size // pack_size k_blocks = (rows + block_size - 1) // block_size padded_rows = k_blocks * block_size pad_len = padded_rows - rows @@ -39,7 +42,11 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols, k_blocks), dtype=matrix_float_padded.dtype) - zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") + + if bits == 2: + zero_point = np.full((cols, (k_blocks + 1) // pack_size), 0b10101010, dtype="uint8") + elif bits == 4: + zero_point = np.full((cols, (k_blocks + 1) // pack_size), 0b10001000, dtype="uint8") matrix_float_padded = np.transpose(matrix_float_padded) for n in range(cols): @@ -47,38 +54,82 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i if is_symmetric: amax_idx = np.argmax(np.abs(matrix_float_padded[n, k_id : k_id + block_size])) bmax = np.float32(matrix_float_padded[n, k_id + amax_idx]) - scale = bmax / (-8.0) - zp = 8 + scale = bmax / default_zp + zp = default_zp else: vmin = np.min(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) vmax = np.max(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) vmin = min(vmin, 0.0) vmax = max(vmax, 0.0) - scale = (vmax - vmin) / ((1 << 4) - 1) + scale = (vmax - vmin) / ((1 << bits) - 1) zero_point_fp = vmin if scale != 0.0: zero_point_fp = 0.0 - vmin / scale - zp = min(15, max(0, round(zero_point_fp))) + zp = min(quant_max, max(0, round(zero_point_fp))) reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 block_idx = k_id // block_size scales[n, block_idx] = scale - zp_pair = zero_point[n, block_idx // 2] - zero_point[n, block_idx // 2] = ( - ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) - ) - - blk_int0 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), - 0, - 15, - ).astype("uint8") - blk_int1 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 2] * reciprocal_scale + zp)), - 0, - 15, - ).astype("uint8") - packed[n, block_idx] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) + if bits == 2: + zero_point_index = block_idx // 4 + zp_quad = zero_point[n, zero_point_index] + quad_index = block_idx & 3 + if quad_index == 0: + zero_point[n, zero_point_index] = (zp_quad & 0b11111100) | zp + elif quad_index == 1: + zero_point[n, zero_point_index] = ((zp_quad & 0b11110011) | (zp << 2)) + elif quad_index == 2: + zero_point[n, zero_point_index] = ((zp_quad & 0b11001111) | (zp << 4)) + elif quad_index == 3: + zero_point[n, zero_point_index] = ((zp_quad & 0b00111111) | (zp << 6)) + else: + raise ValueError("Unexpected quad index for 2-bit zero point packing!") + + blk_int0 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int1 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int2 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 2 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int3 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 3 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + + packed[n, block_idx] = np.bitwise_or( 
np.bitwise_or(blk_int0, np.left_shift(blk_int1, 2)), + np.left_shift(np.bitwise_or(blk_int2, np.left_shift(blk_int3, 2)), 4)) + elif bits == 4: + zero_point_index = block_idx // 2 + zp_pair = zero_point[n, zero_point_index] + if (block_idx & 1) == 0: + zero_point[n, zero_point_index] = (zp_pair & 0xF0) | zp + else: + zero_point[n, zero_point_index] = (zp_pair & 0x0F) | (zp << 4) + + blk_int0 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") + blk_int1 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") + packed[n, block_idx] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) + else: + raise ValueError("Unsupported bits for blockwise quantization!") return (packed, scales, zero_point) @@ -92,15 +143,15 @@ def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int packed = np.zeros((cols, k_blocks, block_size // 2), dtype="uint8") scales = np.zeros((cols, k_blocks), dtype=matrix_float.dtype) zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") - from onnxruntime.capi._pybind_state import quantize_matmul_4bits + from onnxruntime.capi._pybind_state import quantize_matmul_nbits - quantize_matmul_4bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) + quantize_matmul_nbits(packed, matrix_float, 4, scales, zero_point, block_size, cols, rows, is_symmetric) return (packed, scales, zero_point) class TestQuantizeBlockwise4Bits(unittest.TestCase): @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" ) def test_quantize_blockwise_4bits(self): for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: From e85431e5546a39c604a02e7e76ed6978cde886aa Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 17 Jul 2025 14:56:38 -0700 Subject: [PATCH 08/33] fix compile errors --- .../cpu/quantization/matmul_nbits.cc | 1 - onnxruntime/core/mlas/lib/q4_dq.cpp | 2 - onnxruntime/core/mlas/lib/qnbitgemm.cpp | 40 +++++-------------- onnxruntime/core/mlas/lib/qnbitgemm.h | 2 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 17 +++----- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 9 ++--- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 9 ++--- .../python/onnxruntime_pybind_quant.cc | 19 --------- .../test/contrib_ops/matmul_4bits_test.cc | 1 - .../test/mlas/unittest/test_blockq4.cpp | 5 ++- 10 files changed, 26 insertions(+), 79 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index cd4b811983a4b..1fea7ee416f12 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -471,7 +471,6 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, static_cast(N_), // number of columns in quantized input thread_pool); } else if (this->nbits_ == 4) { - if (nbits_ == 4) { MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 67671b8d51626..b127671dcb517 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -600,8 +600,6 @@ struct 
BlockwiseQuantizer { const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; - constexpr int pack_size = BitsTraits::kPackSize; - int q_rows, q_cols; quantizedShape(rows, columns, q_rows, q_cols); constexpr int32_t kPackSize = BitsTraits::kPackSize; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 4db58451001e2..bc803657aae48 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -108,11 +108,15 @@ MlasIsQNBitGemmAvailable( } case SQNBitGemmVariant_BitWidth2_CompInt8: { return (Dispatch->SQ2BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr); + } + case SQ8BitGemmVariant_CompInt8: { return Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr && Dispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr && + Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr; } default: { return false; + } } } @@ -145,7 +149,7 @@ QNBitGemmPerGemmWorkspaceSize( size_t QNBitGemmPerGemmWorkspaceAlignment( - size_t /*BlkBitWidth*/, + size_t BlkBitWidth, size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) @@ -232,7 +236,7 @@ MlasQNBitGemmPackQuantBDataSize( if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPackQuantBDataSize != nullptr) { return Dispatch->Q2BitGemmPackQuantBDataSize( - N, K, BlkLen, ComputeType + N, K, BlkLen, HasZeroPoint, ComputeType ); } @@ -327,6 +331,7 @@ MlasQNBitGemmPackQuantBData( ThreadPool ); return; + } } else if (BlkBitWidth == 8) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -906,19 +911,6 @@ InitializeWorkspace_CompInt8( QuantARowBlkSum += BlockCountK; } }); - } else { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); } } @@ -1088,14 +1080,7 @@ MlasQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; - PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); - ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); - } else if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; @@ -1173,14 +1158,7 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + 
gemm_i * PerGemmWorkspaceStride; - if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; - - PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); - } else if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index 2ca20b45bf5b9..5214ea61127b5 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -102,6 +102,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH { Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + Q4BitGemmPackQuantBDataSize_Fn* Q2BitGemmPackQuantBDataSize = nullptr; + /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( size_t N, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index c361f9da34533..144beda003328 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1446,6 +1446,9 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; + d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; @@ -1460,11 +1463,6 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; - - d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; @@ -1480,6 +1478,9 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; + d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; @@ 
-1490,12 +1491,6 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; - // change funcions if implementation are different from avx2 - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; - - d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 221bee3c3d17c..7d0c0fbd8ee0a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -478,6 +478,9 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; + d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512; @@ -492,12 +495,6 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; - // change funcions if implementation are different from avx2 - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; - - d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 226e3e62f05b1..e4832621442eb 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -463,6 +463,9 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; + d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512vnni; @@ -477,12 +480,6 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; - // change funcions if implementation are different from avx2 - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; - - d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; return d; diff --git 
a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 016c68c960095..b83d7cc4cb6e8 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -53,22 +53,6 @@ void QuantizeMatMulNBitsBlockwise( py::buffer_info scale_buf = scale.request(); py::buffer_info zp_buf = zero_points.request(); - if (qbits == 2) { - if constexpr (std::is_same::value) { - assert(false); - } - MlasQuantizeBlockwise( - reinterpret_cast(dst_buf.ptr), - reinterpret_cast(scale_buf.ptr), - is_symmetric ? nullptr : reinterpret_cast(zp_buf.ptr), - reinterpret_cast(src_buf.ptr), - block_size, - true, - K, - N, - N, - tp.get()); - } else if (qbits == 4 || qbits == 8) { MlasQuantizeBlockwise( reinterpret_cast(dst_buf.ptr), reinterpret_cast(scale_buf.ptr), @@ -80,9 +64,6 @@ void QuantizeMatMulNBitsBlockwise( N, N, tp.get()); - } else { - assert(false); - } } template diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index fe712979d31bd..0bd1ab63c2d77 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -493,7 +493,6 @@ TEST(MatMulNBits, LegacyShape) { TestMatMulNBitsTyped(); } -#endif #endif #endif diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index 20f9779798479..a6d8bcd4a34e6 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -212,10 +212,10 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } } - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, + MlasDequantizeBlockwise(input, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); - MlasTranspose(dequant_buf, transposed, columns, rows); + MlasTranspose(input, transposed, columns, rows); uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true); float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); @@ -304,6 +304,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } } } + } for (int c = 0; c < meta_cols; c++) { for (int r = 0; r < meta_rows; r++) { From 96427403722bf6c9f520c1973577fe3ca85baadc Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 18 Jul 2025 14:27:56 -0700 Subject: [PATCH 09/33] resolve build failure update --- .../test/contrib_ops/matmul_4bits_test.cc | 297 +++++++-------- .../test/mlas/unittest/test_blockq4.cpp | 360 ++++-------------- 2 files changed, 218 insertions(+), 439 deletions(-) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 0bd1ab63c2d77..f58d0b8a51ba7 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -34,10 +34,8 @@ namespace test { namespace { -constexpr int Q2Bits = 2; -constexpr int Q4Bits = 4; +constexpr int QBits = 4; -template void QuantizeDequantize(std::vector& raw_vals, std::vector& quant_vals, std::vector& scales, @@ -48,7 +46,7 @@ void QuantizeDequantize(std::vector& raw_vals, auto& ortenv = **ort_env.get(); onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); - MlasQuantizeBlockwise( + MlasQuantizeBlockwise( quant_vals.data(), scales.data(), zp != nullptr ? 
zp->data() : nullptr, @@ -61,7 +59,7 @@ void QuantizeDequantize(std::vector& raw_vals, tp); // Note that raw_vals is NxK after dequant - MlasDequantizeBlockwise( + MlasDequantizeBlockwise( raw_vals.data(), // dequantized output quant_vals.data(), // quantized input scales.data(), // quantization scales @@ -101,7 +99,7 @@ struct TestOptions { << ", has_bias:" << opts.has_bias; } -template +template void RunTest(const TestOptions& opts, std::vector>&& explicit_eps = {}) { SCOPED_TRACE(opts); @@ -125,29 +123,18 @@ void RunTest(const TestOptions& opts, std::vector input0_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); -#if 0 // for Debugging - std::vector input1_f_vals_trans(N * K); - MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); -#endif - - int q_rows, q_cols; - MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, - static_cast(K), static_cast(N), - q_rows, q_cols); - - size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(static_cast(opts.block_size), /* columnwise */ true, - static_cast(K), static_cast(N), - q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); int64_t k_blocks = (K + opts.block_size - 1) / opts.block_size; int64_t blob_size = (opts.block_size * QBits + 7) / 8; + size_t q_scale_size = static_cast(N * k_blocks); + size_t q_data_size_in_bytes = static_cast(N * k_blocks * blob_size); // packed as UInt4x2 const int64_t zero_point_blob_size = (k_blocks * QBits + 7) / 8; + size_t q_zp_size_in_bytes = static_cast(N * zero_point_blob_size); // packed as UInt4x2 std::vector input1_vals(q_data_size_in_bytes); std::vector scales(q_scale_size); std::vector zp(q_zp_size_in_bytes); - QuantizeDequantize(input1_f_vals, + QuantizeDequantize(input1_f_vals, input1_vals, scales, opts.has_zero_point ? 
&zp : nullptr, @@ -178,7 +165,7 @@ void RunTest(const TestOptions& opts, test.AddAttribute("K", K); test.AddAttribute("N", N); test.AddAttribute("block_size", opts.block_size); - test.AddAttribute("bits", qbits); + test.AddAttribute("bits", QBits); test.AddAttribute("accuracy_level", opts.accuracy_level); if constexpr (std::is_same_v) { @@ -285,8 +272,9 @@ void RunTest(const TestOptions& opts, } // namespace -template -void TestMatMulNBitsTyped() { +template +void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, + std::optional rel_error = std::nullopt) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; @@ -309,25 +297,25 @@ void TestMatMulNBitsTyped() { base_opts.output_rel_error = 0.02f; } - if constexpr (qbits == 4) { + { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { TestOptions opts = base_opts; opts.has_zero_point = true; - RunTest(opts); + RunTest(opts); } #if !defined(USE_DML) && !defined(USE_WEBGPU) - if constexpr (qbits == 4) { + { TestOptions opts = base_opts; opts.has_g_idx = true; - RunTest(opts); + RunTest(opts); } - if constexpr (qbits == 4) { + { TestOptions opts = base_opts; opts.has_g_idx = true; opts.has_bias = true; @@ -343,13 +331,14 @@ void TestMatMulNBitsTyped() { // only enabled for CPU EP for now std::vector> explicit_eps; explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - RunTest(opts, std::move(explicit_eps)); + RunTest(opts, std::move(explicit_eps)); } { TestOptions opts = base_opts; - opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); + opts.has_zero_point = true; + opts.zp_is_4bit = false; + RunTest(opts); } #endif // !defined(USE_DML) && !defined(USE_WEBGPU) } @@ -357,134 +346,141 @@ void TestMatMulNBitsTyped() { #if !defined(USE_OPENVINO) TEST(MatMulNBits, Float32_Accuracy0) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float32_Accuracy1) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // 
TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float32_Accuracy4) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); } -// TODO: enable and add more tests for 2bit development. -TEST(MatMulNBits, DISABLED_Float32_Accuracy4_Q2) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64) +#if !defined(USE_DML) +// Actual and expected difference is over 0.01 with DmlExecutionProvider. +// Skip the tests instead of raising the tolerance to make is pass. 
+TEST(MatMulNBits, Float16_Accuracy2) { + TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_Accuracy0) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_Accuracy4) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); + // TestMatMulNBitsTyped(); } TEST(MatMulNBits, LegacyShape) { @@ -493,6 +489,7 @@ TEST(MatMulNBits, LegacyShape) { TestMatMulNBitsTyped(); } +#endif #endif #endif @@ -530,18 +527,20 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (std::is_same_v) { #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_WEBGPU execution_providers.push_back(DefaultWebGpuExecutionProvider()); -#endif - RunTest(opts, std::move(execution_providers)); +#endif } else { #ifdef USE_ROCM 
execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index a6d8bcd4a34e6..4302e36db75cd 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -20,9 +20,6 @@ Module Name: #include "mlas_q4.h" #include "core/mlas/lib/mlasi.h" -constexpr int Q2Bits = 2; -constexpr int Q4Bits = 4; - template int GetElem(int v, int idx) { return (v >> (qbits * idx)) & ((1 << qbits) - 1); @@ -65,7 +62,6 @@ class MlasBlockwiseQdqTest : public MlasTestBase { return std::abs(va - vb) < err + std::abs(va) * rel; } - template void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { constexpr int packSize = 8 / qbits; T* input = FpBuf.GetFilledBuffer(rows * columns, [this](T* start, size_t size) { @@ -81,153 +77,49 @@ class MlasBlockwiseQdqTest : public MlasTestBase { int meta_rows; int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); int q_rows; int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; MlasBlockwiseQuantizedBufferSizes(block_size, columnwise, rows, columns, - q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - - uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); - uint8_t* qdq_weights = QDQOutputElements.GetBuffer((rows * columns + 1) / 2, true); - uint8_t* qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); - - int pack_size = 8 / qbits; - int v; - if constexpr (qbits == 2) { - v = 1; - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += pack_size) { - int idx = c * q_rows + r / pack_size; - uint8_t v0 = static_cast(v); - v = (v + 1) % 4; - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 1) % 4; - if (v == 3) { - v = (v + 1) % 4; - } - } - uint8_t v2 = 0; - if (r + 2 < rows) { - v2 = static_cast(v); - v = (v + 1) % 4; - if (v == 3) { - v = (v + 1) % 4; - } - } - uint8_t v3 = 0; - if (r + 3 < rows) { - v3 = static_cast(v); - v = (v + 1) % 4; - if (v == 3) { - v = (v + 1) % 4; - } - } - elements[idx] = (v3 << 6) | (v2 << 4) | (v1 << 2) | v0; - } - } - } else if constexpr(qbits == 4) { - v = 7; - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - elements[idx] = (v1 << 4) | v0; - } - } + uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); // after quantize + uint8_t* qdq_weights; + uint8_t* qdq_weights_T; + if constexpr (qbits == 4) { + qdq_weights = QDQOutputElements.GetBuffer((rows * columns + packSize - 1) / packSize, true); + qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); } T* scales = 
InputScales.GetBuffer(q_scale_size, true); uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true); - uint8_t* qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); - uint8_t* qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); - if (zp) { - if constexpr (qbits == 2) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += pack_size) { - int idx = c * ((meta_rows + 3) / pack_size) + r / pack_size; - uint8_t v0 = static_cast(v); - v = (v + 1) % 4; - uint8_t v1 = 0; - if (r + 1 < meta_rows) { - v1 = static_cast(v); - v = (v + 1) % 4; - } - uint8_t v2 = 0; - if (r + 2 < meta_rows) { - v2 = static_cast(v); - v = (v + 1) % 4; - } - uint8_t v3 = 0; - if (r + 3 < meta_rows) { - v3 = static_cast(v); - v = (v + 1) % 4; - } - zp[idx] = (v3 << 6) | (v2 << 4) | (v1 << 2) | v0; - } - } - } - else if constexpr (qbits == 4) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < meta_rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - zp[idx] = (v1 << 4) | v0; - } - } - } + T* qdq_scales; + T* qdq_scales_T; + uint8_t* qdq_zp; + uint8_t* qdq_zp_T; + if constexpr (qbits == 4) { + qdq_scales = QDQOutputScales.GetBuffer(scale_size, true); + qdq_scales_T = QDQTransposedOutputScales.GetBuffer(q_scale_size, true); + qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); + qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); } - MlasDequantizeBlockwise(input, elements, scales, zp, block_size, - columnwise, rows, columns, threadpool_ptr); - - MlasTranspose(input, transposed, columns, rows); + MlasQuantizeBlockwise(elements, scales, zp, input, block_size, + columnwise, rows, columns, columns, threadpool_ptr); - uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true); - float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); - uint8_t* o_zp = symmetric ? 
nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); + MlasDequantizeBlockwise(dequant, elements, scales, zp, block_size, + columnwise, rows, columns, threadpool_ptr); - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, - columnwise, rows, columns, columns, threadpool_ptr); + MlasTranspose(dequant, transposed, columns, rows, threadpool_ptr); if constexpr (qbits == 4) { if (columnwise) { - bool signed_quant = MlasQDQQuantizeBlockwise( - transposed, qdq_scales, qdq_zp, qdq_weights, + bool signed_quant = MlasQDQQuantizeBlockwise( + input, qdq_scales, qdq_zp, qdq_weights, true, rows, columns, block_size, threadpool_ptr); ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; @@ -238,7 +130,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { true, rows, columns, block_size, threadpool_ptr); } else { - MlasQDQTransposeBlockwiseQuantized( + MlasQDQTransposeBlockwiseQuantized( qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, true, rows, columns, block_size, threadpool_ptr); } @@ -264,47 +156,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { << ", qdq index=[" << r + l << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; } - if (r + 2 < rows) { - ASSERT_EQ((o_elements[idx] >> 4) & 0x3, (elements[idx] >> 4) & 0x3) - << ", index=[" << r + 2 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - if (r + 3 < rows) { - ASSERT_EQ((o_elements[idx] >> 6) & 0x3, (elements[idx] >> 6) & 0x3) - << ", index=[" << r + 3 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } } } - } else if constexpr (qbits == 4) { - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if constexpr (qbits == 4) { - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - } - if (r + 1 < rows) { - ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if constexpr (qbits == 4) { - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - } - } - } - } - } for (int c = 0; c < meta_cols; c++) { for (int r = 0; r < meta_rows; r++) { @@ -315,53 +168,14 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } } - if (symmetric) return; - - if constexpr (qbits == 2) { + if (symmetric) return; for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += pack_size) { - int idx = c * ((meta_rows + 3) / pack_size) + r / pack_size; - ASSERT_EQ(o_zp[idx] & 0x3, zp[idx] & 0x3) - << ", 
index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (r + 1 < meta_rows) { - ASSERT_EQ((o_zp[idx] >> 2) & 0x3, (zp[idx] >> 2) & 0x3) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - if (r + 2 < meta_rows) { - ASSERT_EQ((o_zp[idx] >> 4) & 0x3, (zp[idx] >> 4) & 0x3) - << ", index=" << r + 2 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - if (r + 3 < meta_rows) { - ASSERT_EQ((o_zp[idx] >> 6) & 0x3, (zp[idx] >> 6) & 0x3) - << ", index=" << r + 3 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - } - } - } else if constexpr (qbits == 4) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - if (r + 1 < meta_rows) { - ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + for (int r = 0; r < meta_rows; r += packSize) { + int idx = c * ((meta_rows + packSize - 1) / packSize) + r / packSize; + for (int l = 0; l < packSize && r + l < meta_rows; ++l) { + ASSERT_EQ(GetElem(qdq_zp_T[idx], l), GetElem(zp[idx], l)) + << ", qdq index=" << r + l << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } } } } @@ -374,78 +188,44 @@ class MlasBlockwiseQdqTest : public MlasTestBase { return suite_name.c_str(); } - void ExecuteShort(void) { - // only support columnwise = true with qbits=2 - Test(20, 1, 32, true, false); - Test(20, 1, 32, true, true); - //Test(1, 20, 32, false, false); - //Test(1, 20, 32, false, true); - Test(52, 1, 32, true, false); - Test(52, 1, 32, true, true); - //Test(1, 52, 32, false, false); - //Test(1, 52, 32, false, true); - Test(20, 3, 32, true, false); - Test(20, 3, 32, true, true); - //Test(3, 20, 32, false, false); - //Test(3, 20, 32, false, true); - Test(52, 3, 32, true, false); - Test(52, 3, 32, true, true); - //Test(3, 52, 32, false, false); - //Test(3, 52, 32, false, true); - Test(52, 3, 64, true, false); - Test(52, 3, 64, true, true); - //Test(3, 52, 64, false, false); - //Test(3, 52, 64, false, true); - Test(32 * 9 + 17, 41, 32, true, false); - Test(32 * 9 + 17, 41, 32, true, true); - //Test(41, 32 * 9 + 17, 32, false, false); - //Test(41, 32 * 9 + 17, 32, false, true); - Test(32 * 9 + 17, 41, 64, true, false); - Test(32 * 9 + 17, 41, 64, true, true); - //Test(41, 32 * 9 + 17, 64, false, false); - 
//Test(41, 32 * 9 + 17, 64, false, true); - Test(32 * 15 + 17, 63, 128, true, false); - Test(32 * 15 + 17, 63, 128, true, true); - //Test(63, 32 * 15 + 17, 128, false, false); - //Test(63, 32 * 15 + 17, 128, false, true); - - Test(20, 1, 32, true, false); - Test(20, 1, 32, true, true); - Test(1, 20, 32, false, false); - Test(1, 20, 32, false, true); - Test(52, 1, 32, true, false); - Test(52, 1, 32, true, true); - Test(1, 52, 32, false, false); - Test(1, 52, 32, false, true); - Test(20, 3, 32, true, false); - Test(20, 3, 32, true, true); - Test(3, 20, 32, false, false); - Test(3, 20, 32, false, true); - Test(52, 3, 32, true, false); - Test(52, 3, 32, true, true); - Test(3, 52, 32, false, false); - Test(3, 52, 32, false, true); - Test(52, 3, 64, true, false); - Test(52, 3, 64, true, true); - Test(3, 52, 64, false, false); - Test(3, 52, 64, false, true); - Test(32 * 9 + 17, 41, 32, true, false); - Test(32 * 9 + 17, 41, 32, true, true); - Test(41, 32 * 9 + 17, 32, false, false); - Test(41, 32 * 9 + 17, 32, false, true); - Test(32 * 9 + 17, 41, 64, true, false); - Test(32 * 9 + 17, 41, 64, true, true); - Test(41, 32 * 9 + 17, 64, false, false); - Test(41, 32 * 9 + 17, 64, false, true); - Test(32 * 15 + 17, 63, 128, true, false); - Test(32 * 15 + 17, 63, 128, true, true); - Test(63, 32 * 15 + 17, 128, false, false); - Test(63, 32 * 15 + 17, 128, false, true); - - Test(256, 256, 32, true, false); - Test(256, 256, 32, true, true); - Test(256, 256, 32, false, false); - Test(256, 256, 32, false, true); + void ExecuteShort(void) override { + Test(20, 1, 32, true, false); + Test(20, 1, 32, true, true); + Test(1, 20, 32, false, false); + Test(1, 20, 32, false, true); + Test(52, 1, 32, true, false); + Test(52, 1, 32, true, true); + Test(1, 52, 32, false, false); + Test(1, 52, 32, false, true); + Test(20, 3, 32, true, false); + Test(20, 3, 32, true, true); + Test(3, 20, 32, false, false); + Test(3, 20, 32, false, true); + Test(52, 3, 32, true, false); + Test(52, 3, 32, true, true); + Test(3, 52, 32, false, false); + Test(3, 52, 32, false, true); + Test(52, 3, 64, true, false); + Test(52, 3, 64, true, true); + Test(3, 52, 64, false, false); + Test(3, 52, 64, false, true); + Test(32 * 9 + 17, 41, 32, true, false); + Test(32 * 9 + 17, 41, 32, true, true); + Test(41, 32 * 9 + 17, 32, false, false); + Test(41, 32 * 9 + 17, 32, false, true); + Test(32 * 9 + 17, 41, 64, true, false); + Test(32 * 9 + 17, 41, 64, true, true); + Test(41, 32 * 9 + 17, 64, false, false); + Test(41, 32 * 9 + 17, 64, false, true); + Test(32 * 15 + 17, 63, 128, true, false); + Test(32 * 15 + 17, 63, 128, true, true); + Test(63, 32 * 15 + 17, 128, false, false); + Test(63, 32 * 15 + 17, 128, false, true); + + Test(256, 256, 32, true, false); + Test(256, 256, 32, true, true); + Test(256, 256, 32, false, false); + Test(256, 256, 32, false, true); } MlasBlockwiseQdqTest() = default; From 892222a45f1078dcfbbd9370f3454b636998b030 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 23 Jul 2025 15:10:04 -0700 Subject: [PATCH 10/33] 2 bits check --- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h index 80a360ebb1b29..9a5f449989dd9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h @@ -31,7 +31,7 @@ Status CheckInputs(const T* /*activation*/, // 
group_index : (K) or (k_blocks * block_size), or null // bias : (N), or null // Note that scales and zero_points can be 1D for backward compatibility. - if (bits != 4 && bits != 8) { + if (bits != 2 && bits != 4 && bits != 8) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bits should be 2, 4 or 8, got ", bits); } From 07b7f3fcfcdf3bf88e1dc11f3d55d4c5161d2380 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 25 Jul 2025 14:24:22 -0700 Subject: [PATCH 11/33] fixed bug causing int8 tests to fail --- onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index e4832621442eb..d4fe05c157c0e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -481,6 +481,5 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; return d; }(); From 493ebd152cfbbb99a368b789572aef2848ebcbf5 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 7 Aug 2025 11:21:08 -0700 Subject: [PATCH 12/33] lintrunner --- .../cpu/quantization/matmul_nbits_impl.cc | 9 +- .../cpu/quantization/matmul_nbits_impl.h | 2 +- .../quantization/matmul_nbits_quantizer.py | 2 - .../test/mlas/unittest/test_sqnbitgemm.cpp | 10 +- .../test/python/quantization/op_test_utils.py | 4 +- .../quantization/test_op_matmul_4bits.py | 138 +++++++----------- .../test_quantizeblockwise_4bits.py | 103 ++++--------- 7 files changed, 94 insertions(+), 174 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index dd3d1fd9ac2cc..50bcbe4a177c5 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -106,12 +106,12 @@ void DequantizeBlockwise( for (int j = 0; j < 256; j++) { if constexpr (qbits == 2) { Dequantize2BitsKernelReOrder(output, quant_data, scales_data, zero_points, - reorder_idx, block_size, groups_per_threadblock, - total_groups, N, K, static_cast(block_id), j); + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); } else { Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, - reorder_idx, block_size, groups_per_threadblock, - total_groups, N, K, static_cast(block_id), j); + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); } } }); @@ -122,7 +122,6 @@ template void DequantizeBlockwise( const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); - template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h index b875048cbc585..be77ec03d006b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -6,7 +6,7 @@ namespace onnxruntime { namespace contrib { 
-template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index cfc2d9ea08327..e8540b4567f3e 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -1259,7 +1259,6 @@ def __init__( if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( - bits=bits, block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level, @@ -1582,7 +1581,6 @@ def parse_args(): ) elif args.quant_method == "default": quant_config = DefaultWeightOnlyQuantConfig( - bits=args.bits, block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 5bf9f6c064fe1..47002dd7eea72 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -444,11 +444,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - //count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); - //count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); - //count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); - //count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 16>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index ecfe03e2eccc2..03b03e25d1923 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -482,7 +482,9 @@ def check_model_correctness( with open(model_path_origin, "rb") as f: model_onnx = onnx.load(f) ops_set = {node.op_type for node in model_onnx.graph.node} - check_reference_evaluator = not skip_onnx_reference_evaluator and not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"}) + check_reference_evaluator = not skip_onnx_reference_evaluator and not ( + ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"} + ) check_target_evaluator = False with open(model_path_to_check, "rb") as f: diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index e54e26cfcce21..a019260a11670 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -16,7 +16,7 @@ from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type from onnxruntime.quantization import quant_utils -from parameterized import parameterized + class TestOpMatMul4Bits(unittest.TestCase): @classmethod @@ -27,42 
+27,26 @@ def setUpClass(cls): def tearDownClass(cls): cls._tmp_model_dir.cleanup() - def fill_nbits_data(self, shape: int | tuple[int, ...], bits: int, symmetric: bool) -> np.ndarray: + def fill_int4_data(self, shape: int | tuple[int, ...], symmetric: bool) -> np.ndarray: line = np.zeros(shape) line = line.reshape(-1) - if bits == 2: - if symmetric: - v = -1.0 - for i in range(line.shape[0]): - line[i] = v - v += 1 - if v >= 2: - v = -2 - else: - v = 0.0 - for i in range(line.shape[0]): - line[i] = v - v += 1 - if v >= 4: - v = 0 - elif bits == 4: - if symmetric: - v = -2.0 - for i in range(line.shape[0]): - if v == 0 or v == -3 or v == 3: - v += 1 - line[i] = v + if symmetric: + v = -2.0 + for i in range(line.shape[0]): + if v == 0 or v == -3 or v == 3: v += 1 - if v >= 8: - v = -8 - else: - v = 0.0 - for i in range(line.shape[0]): - line[i] = v - v += 1 - if v >= 16: - v = 0 + line[i] = v + v += 1 + if v >= 8: + v = -8 + else: + v = 0.0 + for i in range(line.shape[0]): + line[i] = v + v += 1 + if v >= 16: + v = 0 return line.reshape(shape) @@ -83,7 +67,7 @@ def input_feeds( dr = TestDataFeeds(input_data_list) return dr - def construct_model_matmul(self, output_model_path: str, nbits: int, symmetric: bool) -> None: + def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> None: # (input) # | # MatMul @@ -96,7 +80,7 @@ def construct_model_matmul(self, output_model_path: str, nbits: int, symmetric: def make_matmul( input_name, weight_shape: int | tuple[int, ...], weight_name: str, output_name: str, node_name: str ): - weight_data = self.fill_nbits_data(weight_shape, nbits, symmetric).astype(np.float32) + weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) return onnx.helper.make_node( "MatMul", @@ -136,7 +120,6 @@ def make_matmul( def construct_model_gather( self, output_model_path: str, - nbits: int, symmetric: bool, tdata: TensorProto.DataType, tind: TensorProto.DataType, @@ -155,7 +138,7 @@ def construct_model_gather( def make_gather( indices_name, data_shape: int | tuple[int, ...], data_name: str, output_name: str, node_name: str ): - weight_data = self.fill_nbits_data(data_shape, nbits, symmetric).astype( + weight_data = self.fill_int4_data(data_shape, symmetric).astype( np.float32 if tdata == TensorProto.FLOAT else np.float16 ) initializers.append(onnx.numpy_helper.from_array(weight_data, name=data_name)) @@ -197,7 +180,6 @@ def quant_test( self, model_fp32_path: str, data_reader: TestDataFeeds, - bits: int, block_size: int, is_symmetric: bool, quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, @@ -216,8 +198,7 @@ def quant_test( from onnxruntime.quantization import matmul_nbits_quantizer # noqa: PLC0415 model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( - bits=bits, + quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig( block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format, @@ -258,9 +239,7 @@ def quant_test( data_reader.rewind() try: - skip_onnx_reference_evaluator = True if bits==2 else False - check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next(), rtol, atol, - skip_onnx_reference_evaluator=skip_onnx_reference_evaluator) + check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next(), rtol, atol) except Exception as exception: if "4b quantization 
not yet supported on this hardware platform!" in exception.args[0]: # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception @@ -273,7 +252,6 @@ def quant_test_with_algo( algorithm: str, model_fp32_path: str, data_reader: TestDataFeeds, - bits: int, block_size: int, is_symmetric: bool, ): @@ -317,76 +295,70 @@ def quant_test_with_algo( else: raise exception - @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) - def test_quantize_matmul_nbits_symmetric(self, bits): + def test_quantize_matmul_int4_symmetric(self): np.random.seed(13) model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, 2, symmetric=True) + self.construct_model_matmul(model_fp32_path, symmetric=True) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, bits, 32, True) + self.quant_test(model_fp32_path, data_reader, 32, True) - @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) - def test_quantize_matmul_nbits_offsets(self, bits): + def test_quantize_matmul_int4_offsets(self): model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, bits, symmetric=False) + self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, bits, 32, False) + self.quant_test(model_fp32_path, data_reader, 32, False) - @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) - def test_quantize_gather_nbits_symmetric(self, bits): + def test_quantize_gather_int4_symmetric(self): np.random.seed(13) model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("gather_fp32_symmetric.onnx").absolute()) - self.construct_model_gather(model_fp32_path, bits, True, TensorProto.FLOAT, TensorProto.INT32) + self.construct_model_gather(model_fp32_path, True, TensorProto.FLOAT, TensorProto.INT32) data_reader = self.input_feeds(1, {"input": (100, 1000)}, -545, 535, np.int32) # cover rounding error - self.quant_test(model_fp32_path, data_reader, bits, 32, True, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) + self.quant_test(model_fp32_path, data_reader, 32, True, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) - @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) - def test_quantize_gather_nbits_offsets(self, bits): + def test_quantize_gather_int4_offsets(self): model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("gather_fp32_offset.onnx").absolute()) - self.construct_model_gather(model_fp32_path, bits, False, TensorProto.FLOAT16, 
TensorProto.INT64) + self.construct_model_gather(model_fp32_path, False, TensorProto.FLOAT16, TensorProto.INT64) data_reader = self.input_feeds(1, {"input": (100, 1000)}, -545, 535, np.int64) # cover rounding error - self.quant_test(model_fp32_path, data_reader, bits, 32, False, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) + self.quant_test(model_fp32_path, data_reader, 32, False, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5) - @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) - def test_quantize_matmul_nbits_symmetric_qdq(self, bits): + def test_quantize_matmul_int4_symmetric_qdq(self): np.random.seed(13) model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, bits, symmetric=True) + self.construct_model_matmul(model_fp32_path, symmetric=True) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, bits, 32, True, quant_utils.QuantFormat.QDQ) + self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ) - @parameterized.expand([(2,), (4,)]) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) - def test_quantize_matmul_nbits_offsets_qdq(self, bits): + def test_quantize_matmul_int4_offsets_qdq(self): model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, bits, symmetric=False) + self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test(model_fp32_path, data_reader, bits, 32, False, quant_utils.QuantFormat.QDQ) + self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) def test_quantize_matmul_int4_using_rtn_algo(self): if not find_spec("neural_compressor"): @@ -394,12 +366,12 @@ def test_quantize_matmul_int4_using_rtn_algo(self): if not find_spec("torch"): self.skipTest("skip test_quantize_matmul_int4_using_rtn_algo since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, 4, symmetric=False) + self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 4, 32, False) + self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) def test_quantize_matmul_int4_using_gptq_algo(self): if not find_spec("neural_compressor"): @@ -407,20 +379,20 @@ def test_quantize_matmul_int4_using_gptq_algo(self): if not find_spec("torch"): self.skipTest("skip 
test_quantize_matmul_int4_using_gptq_algo since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, 4, symmetric=False) + self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 4, 32, False) + self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False) @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) def test_quantize_matmul_int4_using_hqq_algo(self): if not find_spec("torch"): self.skipTest("skip test_hqq_quant since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, 4, symmetric=False) + self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) - self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 4, 32, False) + self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False) if __name__ == "__main__": diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py index e92758c2694ab..2a78b7bd1900a 100644 --- a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py @@ -24,15 +24,12 @@ def dequantize_blockwise_4bits(quant_values, scale, zero_point, valid_len): return quant_float -def quantize_blockwise_nbits_ref(matrix_float: npt.ArrayLike, bits: int, block_size: int, is_symmetric: bool): +def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): if len(matrix_float.shape) != 2: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = matrix_float.shape - pack_size = 8 // bits - default_zp = 1 << (bits - 1) - quant_max = ((1 << bits) - 1) if is_symmetric else ((1 << (bits -1)) - 1) - blob_size = block_size // pack_size + blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size padded_rows = k_blocks * block_size pad_len = padded_rows - rows @@ -42,11 +39,7 @@ def quantize_blockwise_nbits_ref(matrix_float: npt.ArrayLike, bits: int, block_s packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols, k_blocks), dtype=matrix_float_padded.dtype) - - if bits == 2: - zero_point = np.full((cols, (k_blocks + 1) // pack_size), 0b10101010, dtype="uint8") - elif bits == 4: - zero_point = np.full((cols, (k_blocks + 1) // pack_size), 0b10001000, dtype="uint8") + zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") matrix_float_padded = np.transpose(matrix_float_padded) for n in range(cols): @@ -54,82 +47,38 @@ def quantize_blockwise_nbits_ref(matrix_float: npt.ArrayLike, bits: int, block_s if is_symmetric: amax_idx = np.argmax(np.abs(matrix_float_padded[n, k_id : k_id + block_size])) bmax = np.float32(matrix_float_padded[n, k_id + amax_idx]) - scale = bmax / default_zp - zp = default_zp + scale = bmax / (-8.0) + zp = 8 else: vmin = np.min(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) vmax = np.max(np.float32(matrix_float_padded[n, k_id : 
k_id + block_size])) vmin = min(vmin, 0.0) vmax = max(vmax, 0.0) - scale = (vmax - vmin) / ((1 << bits) - 1) + scale = (vmax - vmin) / ((1 << 4) - 1) zero_point_fp = vmin if scale != 0.0: zero_point_fp = 0.0 - vmin / scale - zp = min(quant_max, max(0, round(zero_point_fp))) + zp = min(15, max(0, round(zero_point_fp))) reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 block_idx = k_id // block_size scales[n, block_idx] = scale - if bits == 2: - zero_point_index = block_idx // 4 - zp_quad = zero_point[n, zero_point_index] - quad_index = block_idx & 3 - if quad_index == 0: - zero_point[n, zero_point_index] = (zp_quad & 0b11111100) | zp - elif quad_index == 1: - zero_point[n, zero_point_index] = ((zp_quad & 0b11110011) | (zp << 2)) - elif quad_index == 3: - zero_point[n, zero_point_index] = ((zp_quad & 0b11001111) | (zp << 4)) - elif quad_index == 3: - zero_point[n, zero_point_index] = ((zp_quad & 0b00111111) | (zp << 6)) - else: - raise ValueError("Unsupported bits for blockwise quantization!") - - blk_int0 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 4] * reciprocal_scale + zp)), - 0, - 3, - ).astype("uint8") - blk_int1 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 4] * reciprocal_scale + zp)), - 0, - 3, - ).astype("uint8") - blk_int2 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id + 2 : k_id + block_size : 4] * reciprocal_scale + zp)), - 0, - 3, - ).astype("uint8") - blk_int3 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id + 3 : k_id + block_size : 4] * reciprocal_scale + zp)), - 0, - 3, - ).astype("uint8") - - packed[n, block_idx] = np.bitwise_or( - np.bitwise_or(blk_int0, np.left_shift(blk_int1, 2)), - np.left_shift(np.bitwise_or(blk_int2, np.left_shift(blk_int3, 2)), 4)) - elif bits == 4: - zero_point_index = block_idx // 2 - zp_pair = zero_point[n, zero_point_index] - if (block_idx & 1) == 0: - zero_point[n, zero_point_index] = (zp_pair & 0xF0) | zp - else: - zero_point[n, zero_point_index] = (zp_pair & 0x0F) | (zp << 4) - - blk_int0 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), - 0, - 15, - ).astype("uint8") - blk_int1 = np.clip( - np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 2] * reciprocal_scale + zp)), - 0, - 15, - ).astype("uint8") - packed[n, block_idx] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) - else: - raise ValueError("Unsupported bits for blockwise quantization!") + zp_pair = zero_point[n, block_idx // 2] + zero_point[n, block_idx // 2] = ( + ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) + ) + + blk_int0 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") + blk_int1 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") + packed[n, block_idx] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) return (packed, scales, zero_point) @@ -143,15 +92,15 @@ def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int packed = np.zeros((cols, k_blocks, block_size // 2), dtype="uint8") scales = np.zeros((cols, k_blocks), dtype=matrix_float.dtype) zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") - from onnxruntime.capi._pybind_state import quantize_matmul_nbits, quantize_matmul_4bits # noqa: PLC0415 + from 
onnxruntime.capi._pybind_state import quantize_matmul_4bits # noqa: PLC0415 - quantize_matmul_nbits(packed, matrix_float, 4, scales, zero_point, block_size, cols, rows, is_symmetric) + quantize_matmul_4bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) return (packed, scales, zero_point) class TestQuantizeBlockwise4Bits(unittest.TestCase): @unittest.skipIf( - find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_nbits" + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) def test_quantize_blockwise_4bits(self): for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: From b4b143fc6a56103f6d99204fffe624b80dc9d4ec Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 13 Aug 2025 13:12:34 -0700 Subject: [PATCH 13/33] prepack wip -- not prepacking b data because dispatch to check for mlas kernel not implemented for fp32. Also, I need to write the packing logic for the scales as well. --- .../cpu/quantization/matmul_nbits.cc | 1 + onnxruntime/core/mlas/lib/qnbitgemm.cpp | 3 + .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 151 ++++++++++++++++-- 3 files changed, 146 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index e371e89776cc1..871e8e8e4bf46 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -39,6 +39,7 @@ typedef enum { Level2, /*!< input fp16, accumulator fp16 */ Level3, /*!< input bf16, accumulator fp32 */ Level4, /*!< input int8, accumulator int32 */ + Level5, /*!< input uint8, use TMAC LUT approach TODO: fix this comment*/ } ACCURACY_LEVEL; // T: A data type. diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index bc803657aae48..7fd33fc849772 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -35,6 +35,7 @@ enum QNBitGemmVariant { SQ8BitGemmVariant_CompInt8, SQNBitGemmVariant_BitWidth2_CompInt8, + SQ2BitGemmVariant_CompFp32, // TODO: determine if this makes sense // End of valid variants // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. @@ -53,6 +54,8 @@ GetQNBitGemmVariant( if (BlkBitWidth == 2) { if (ComputeType == SQNBIT_CompInt8) { return SQNBitGemmVariant_BitWidth2_CompInt8; + } else if (ComputeType == SQNBIT_CompFp32) { + return SQ2BitGemmVariant_CompFp32; } } else if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompFp32) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index d6d104967e3a7..856695b970be7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -17,6 +17,7 @@ Module Name: #include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" +#include size_t Q2BitGemmPackQuantBDataSize( @@ -28,7 +29,6 @@ Q2BitGemmPackQuantBDataSize( { // TODO: This code shall change according to T-Mac. 
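For orientation, the size this helper returns for 2-bit blocks reduces to N * BlockCountK * (BlkLen / 4), since four 2-bit values share a byte (that is what MlasQNBitBlkDataSizeInBytes(2, BlkLen) evaluates to). A minimal sketch of that arithmetic; the Example* shape values below are hypothetical and chosen only to make the numbers concrete:

#include <cstddef>

// Illustration only: hypothetical shapes, not values taken from this patch.
constexpr std::size_t ExampleN = 4096;
constexpr std::size_t ExampleK = 4096;
constexpr std::size_t ExampleBlkLen = 64;

constexpr std::size_t ExampleBlockCountK = (ExampleK + ExampleBlkLen - 1) / ExampleBlkLen;  // MlasDivRoundup(K, BlkLen) == 64
constexpr std::size_t ExampleBlkDataSize = ExampleBlkLen / 4;                               // 16 bytes per 2-bit block
constexpr std::size_t ExamplePackedSize = ExampleN * ExampleBlockCountK * ExampleBlkDataSize;
static_assert(ExamplePackedSize == 4u * 1024u * 1024u, "4 MiB of packed 2-bit B data for this example shape");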
MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - constexpr size_t BlkBitWidth = 2; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -37,16 +37,148 @@ Q2BitGemmPackQuantBDataSize( } void SQ2BitGemmPackQuantBData( - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, + size_t N, + size_t K, + size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/, - const std::byte* /*QuantBDataBegin*/, - std::byte* /*PackedQuantBDataBegin*/, - MLAS_THREADPOOL* /*ThreadPool*/ -) + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) { - // TODO: need implementation + // T-MAC like configuration (approved): + // bits=2, g=4, ngroups_per_elem=8/g=2, simd_n_in=16, simd_n_out=8, bm=512, kfactor=16 + constexpr int bits = 2; + constexpr int g = 4; + constexpr int ngroups_per_elem = 8 / g; // 2 + constexpr int simd_n_in = 16; + constexpr int simd_n_out = 8; + constexpr int bm = 512; // tune as needed; must be multiple of bits and mgroup + constexpr int kfactor = 16; // tune as needed; must divide K/g per block + + // Basic checks + MLAS_UNREFERENCED_PARAMETER(K); + assert(BlkLen % g == 0); + assert((BlkLen / g) % kfactor == 0); + const int mgroup = ngroups_per_elem * simd_n_in; // 32 + assert(bm % mgroup == 0); + assert(bm % bits == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(bits, BlkLen); // BlkLen/4 bytes + + const int m_block = bm / bits; // number of original rows (columns of B) per tile + assert(N % static_cast(m_block) == 0); + const size_t tiles_in_m = N / static_cast(m_block); + + const int K_over_g = static_cast(BlkLen / g); + + // We write destination in block-major layout: for each k-block, its N columns packed contiguously. + // Per (k_blk, tile) we produce a chunk of size m_block * BlkDataSize bytes. 
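To make the tile geometry concrete: with the configuration stated above (bits = 2, g = 4, ngroups_per_elem = 2, simd_n_in = 16, simd_n_out = 8, bm = 512, kfactor = 16), the derived sizes work out as in the sketch below. ExampleBlkLen is an assumed value and the asserts mirror the checks in this function; each (k-block, tile) chunk then carries one quantization block for 256 columns of B, stored contiguously.

#include <cstddef>

constexpr int bits = 2;
constexpr int g = 4;
constexpr int ngroups_per_elem = 8 / g;               // 2 four-bit groups per packed byte
constexpr int simd_n_in = 16;
constexpr int bm = 512;
constexpr int kfactor = 16;

constexpr int mgroup = ngroups_per_elem * simd_n_in;  // 32
constexpr int m_block = bm / bits;                    // 256 original columns of B per tile
constexpr std::size_t ExampleBlkLen = 64;             // assumed block length for the example
constexpr std::size_t tile_chunk_bytes = m_block * (ExampleBlkLen / 4);  // 256 * 16 = 4096 bytes

static_assert(bm % mgroup == 0, "bm must be a multiple of mgroup");
static_assert((ExampleBlkLen / g) % kfactor == 0, "K/g groups per block must be a multiple of kfactor");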
+ const size_t tile_chunk_bytes = static_cast(m_block) * BlkDataSize; // = m_block * BlkLen/4 + + const size_t Iterations = BlockCountK * tiles_in_m; + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t k_blk = static_cast(tid) / tiles_in_m; + const size_t tile_idx = static_cast(tid) % tiles_in_m; + + // Temporary buffers per tile + // buf2: size = (m_block * bits) * (BlkLen/g) + // tilechunk: size = m_block * BlkLen/4 bytes + std::vector buf2(static_cast(m_block) * bits * K_over_g, 0); + std::vector tilechunk(tile_chunk_bytes, 0); + + // Stage 1: build buf2 (bit-planes grouped along K by g) + for (int im = 0; im < m_block; ++im) { + const size_t n_col = tile_idx * static_cast(m_block) + static_cast(im); + const size_t src_block_offset = n_col * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* src_block = QuantBDataBegin + src_block_offset; + + for (int ik = 0; ik < static_cast(BlkLen); ++ik) { + const int byte_idx = ik >> 2; // ik/4 + const int lane = ik & 3; // ik%4 + const uint8_t src_byte = static_cast(src_block[byte_idx]); + const uint8_t v = static_cast((src_byte >> (lane * bits)) & 0x3u); + + const int ik_g = ik / g; + const int shft_left = ik % g; // 0..3 + for (int ib = 0; ib < bits; ++ib) { + const size_t idx = static_cast(im) * bits * K_over_g + static_cast(ib) * K_over_g + static_cast(ik_g); + buf2[idx] = static_cast(buf2[idx] + (((v >> ib) & 0x1u) << shft_left)); + } + } + } + + // Precompute reshape/transpose factors (use K' = BlkLen) + const int c0_fac2 = K_over_g; + const int c0_fac1 = simd_n_out * c0_fac2; + const int c0_fac0 = bits * c0_fac1; + + const int c1_nb2 = K_over_g; + const int c1_nb1 = simd_n_in * c1_nb2; + const int c1_nb0 = ngroups_per_elem * c1_nb1; + const int c1_fac2 = K_over_g; + const int c1_fac1 = ngroups_per_elem * c1_fac2; + const int c1_fac0 = simd_n_in * c1_fac1; + + const int c2_nb4 = kfactor; + const int c2_nb3 = (K_over_g / kfactor) * c2_nb4; + const int c2_nb2 = ngroups_per_elem * c2_nb3; + const int c2_nb1 = simd_n_in * c2_nb2; + const int c2_nb0 = (bm / mgroup) * c2_nb1; + const int c2_fac3 = simd_n_in * ngroups_per_elem; + const int c2_fac2 = kfactor * c2_fac3; + const int c2_fac1 = (bm / mgroup) * c2_fac2; + const int c2_fac0 = (K_over_g / kfactor) * c2_fac1; + + // Stage 2: multi-reshape/transpose into tilechunk + for (int im = 0; im < m_block; ++im) { + for (int ib = 0; ib < bits; ++ib) { + for (int ik = 0; ik < K_over_g; ++ik) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + int new_im = im / simd_n_out; + int new_isno = im % simd_n_out; + int new_ib = ib; + int new_ik = ik; + int new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + int new_ing = (new_idx % c1_nb0) / c1_nb1; + int new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + int new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + int new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + new_ik * c2_fac1 + new_ibm * c2_fac2 + new_ikf * c2_fac3 + new_isni * 
ngroups_per_elem + new_ing; + + // Collapse ngroups into byte by left-shifting lanes of g + const size_t src_idx = static_cast(im) * bits * K_over_g + static_cast(ib) * K_over_g + static_cast(ik); + const uint8_t v = buf2[src_idx]; + const size_t dst_idx = static_cast(new_idx / ngroups_per_elem); + tilechunk[dst_idx] = static_cast(tilechunk[dst_idx] + (v << (new_ing * g))); + } + } + } + + // Store the tile chunk into destination + std::byte* dst_block_base = PackedQuantBDataBegin + k_blk * (N * BlkDataSize); + std::byte* tile_dest = dst_block_base + tile_idx * tile_chunk_bytes; + // copy bytes + for (size_t i = 0; i < tile_chunk_bytes; ++i) { + tile_dest[i] = static_cast(tilechunk[i]); + } + } + ); } size_t @@ -94,6 +226,7 @@ SQ2BitGemmKernel_CompInt8_avx2( return 0; } +// TODO: do we need this..? void QuantizeARow_CompInt8( size_t /*BlkLen*/, From 534b8e6d18099b4a49fd96a911696e64335c583f Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 15 Aug 2025 15:29:11 -0700 Subject: [PATCH 14/33] fixed dispatch issue, added acc level 4 tests, and now running into assert issue with the data shuffling in prepack --- .../cpu/quantization/matmul_nbits.cc | 4 ++++ onnxruntime/core/mlas/lib/qnbitgemm.cpp | 5 +---- .../test/contrib_ops/matmul_2bits_test.cc | 17 +++++++++++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 871e8e8e4bf46..f8930e03a46c4 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -317,6 +317,10 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); auto* y_data = y->MutableData(); + // TODO: add the logic for generating lookup table here -- for now we can assume that + // 2 bits + acc level 4 = use look up table but in the future adapt so that we use a mamtulnbits attr to decide + // if we want to do lut generation + const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); const size_t N = static_cast(helper.N()); diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 7fd33fc849772..c710681539c4c 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -35,7 +35,6 @@ enum QNBitGemmVariant { SQ8BitGemmVariant_CompInt8, SQNBitGemmVariant_BitWidth2_CompInt8, - SQ2BitGemmVariant_CompFp32, // TODO: determine if this makes sense // End of valid variants // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. 
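The final "collapse ngroups into byte" step of the packer above is easier to read in scalar form. This is a sketch of the byte layout it produces (the helper names are mine, not from the patch); it is the same layout the AVX2 lookup kernel later peels apart with its 0x0F mask and 4-bit shift:

#include <cstdint>

// Two g = 4 bit-groups share one byte (ngroups_per_elem = 2): group 0 occupies the low
// nibble, group 1 the high nibble, matching v << (new_ing * g) with new_ing in {0, 1}.
inline uint8_t PackTwoGroups(uint8_t v_group0, uint8_t v_group1) {
    return static_cast<uint8_t>((v_group0 & 0x0F) | ((v_group1 & 0x0F) << 4));
}

// The matching unpack: recover both 4-bit group values from a packed byte.
inline void UnpackTwoGroups(uint8_t packed, uint8_t& v_group0, uint8_t& v_group1) {
    v_group0 = packed & 0x0F;
    v_group1 = static_cast<uint8_t>((packed >> 4) & 0x0F);
}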
@@ -54,8 +53,6 @@ GetQNBitGemmVariant( if (BlkBitWidth == 2) { if (ComputeType == SQNBIT_CompInt8) { return SQNBitGemmVariant_BitWidth2_CompInt8; - } else if (ComputeType == SQNBIT_CompFp32) { - return SQ2BitGemmVariant_CompFp32; } } else if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompFp32) { @@ -110,7 +107,7 @@ MlasIsQNBitGemmAvailable( (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } case SQNBitGemmVariant_BitWidth2_CompInt8: { - return (Dispatch->SQ2BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr); + return (Dispatch->SQ2BitGemmKernel_CompInt8 != nullptr); // TODO: originally also checked for the existence of Dispatch->QuantizeARow_CompInt8 which for some reason dispatched as null } case SQ8BitGemmVariant_CompInt8: { return Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr && diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 884922ec5c098..de267a29803b1 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -260,6 +260,23 @@ TEST(MatMulNBits, Float32_2Bits_Accuracy0) { TestMatMul2BitsTyped(); TestMatMul2BitsTyped(); } + +TEST(MatMulNBits, Float32_2Bits_Accuracy4) { + // Currently, only fallback option enabled for 2bit datatypes + // where the 2bits are dequantized to fp32 + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); +} } // namespace test } // namespace onnxruntime From 70d658898eaf4808ea13f84059b26d0809f44730 Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Tue, 2 Sep 2025 20:58:23 +0000 Subject: [PATCH 15/33] deep sigh --- cmake/onnxruntime_mlas.cmake | 2 + .../cpu/quantization/matmul_nbits.cc | 8 + onnxruntime/core/mlas/inc/mlas_qnbit.h | 23 ++ onnxruntime/core/mlas/lib/mlasi.h | 5 + onnxruntime/core/mlas/lib/platform.cpp | 2 + onnxruntime/core/mlas/lib/q4_dq.cpp | 35 --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 36 +++ onnxruntime/core/mlas/lib/qlutgemm.h | 29 ++ .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 258 +++++++++++++++++- .../test/contrib_ops/matmul_2bits_test.cc | 24 +- 10 files changed, 361 insertions(+), 61 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/qlutgemm.cpp create mode 100644 onnxruntime/core/mlas/lib/qlutgemm.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index d2dad6bdb06f9..34ed6901f8e4e 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp ${MLAS_SRC_DIR}/qnbitgemm.h ${MLAS_SRC_DIR}/qnbitgemm.cpp + ${MLAS_SRC_DIR}/qlutgemm.h + ${MLAS_SRC_DIR}/qlutgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/cast.cpp diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index f8930e03a46c4..9246644fabf57 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -55,6 +55,10 @@ GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { return SQNBIT_CompInt8; } + if (accuracy_level_attr == 
static_cast(Level5) && MlasIsTMACAvailable(nbits, block_size, SQNBIT_CompInt8)) { + return TMAC; + } + return SQNBIT_CompFp32; } @@ -320,6 +324,10 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, // TODO: add the logic for generating lookup table here -- for now we can assume that // 2 bits + acc level 4 = use look up table but in the future adapt so that we use a mamtulnbits attr to decide // if we want to do lut generation + if (compute_type_ == TMAC) { + // call lut gen somehow + MlasTmacInitializeTable(); + } const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 3627989609737..165e425cbf4a7 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -32,6 +32,7 @@ typedef enum { BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ + TMAC } MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** @@ -221,3 +222,25 @@ MlasQNBitGemmScalesPacked( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool HasZeroPoint ); + +/** + * @brief Determines whether the TMAC LUT optimization path is available on the current platform + * for the provided quantization parameters. + * + * This API is used by higher-level ops to choose the TMAC path. It complements + * MlasIsQNBitGemmAvailable by querying availability of the LUT-based strategy. + */ +bool MLASCALL +MlasIsTMACAvailable( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +/** + * @brief Initializes any global tables required by TMAC LUT kernels. + * + * Returns true if initialization succeeded or was unnecessary. + */ +bool MLASCALL +MlasTmacInitializeTable(); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index a099bcf8438fe..403a301b82eac 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1211,6 +1211,10 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +struct MLAS_QNBIT_LUT_GEMM_DISPATCH; + +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLUTGenKernelAvx2; + // // Rotary embedding dispatch structure. // @@ -1400,6 +1404,7 @@ struct MLAS_PLATFORM { const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_LUT_GEMM_DISPATCH* LUTGenKernel{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3256dadb856d3..2413144919cdb 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -411,6 +411,8 @@ Return Value: this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; this->RopeDispatch = &MlasRopeDispatchAvx2; + // TODO: check if this really goes here or if there are other platform reqs that we need to fulfill + this->LUTGenKernel = &MlasLUTGenKernelAvx2; // // Check if the processor supports Hybrid core architecture. 
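The LUT path is registered the same way as the other MLAS kernel families: a per-ISA dispatch struct hangs off the platform object, and the availability check reduces to whether a dispatch was registered. A minimal self-contained sketch of that pattern follows, with stand-in names and an assumed member signature; the actual struct is MLAS_QNBIT_LUT_GEMM_DISPATCH, filled in by qlutgemm.h later in the series, and the immediately-invoked lambda mirrors how MlasLUTGenKernelAvx2 itself is defined.

#include <cstdint>

// Stand-in for the dispatch table; the member shape here is an assumption for illustration.
struct LutGemmDispatchSketch {
    void (*GenerateLUT)(int32_t group_size, int8_t* lut, float* b, float* scales, float* biases) = nullptr;
};

// One instance per ISA, built once (the patch uses the same immediately-invoked lambda idiom).
const LutGemmDispatchSketch LutGemmDispatchAvx2Sketch = []() {
    LutGemmDispatchSketch d;
    d.GenerateLUT = nullptr;  // real code points this at the AVX2 LUT constructor
    return d;
}();

// Roughly what MlasIsTMACAvailable boils down to at this stage: is a LUT dispatch present?
inline bool IsLutPathAvailableSketch(const LutGemmDispatchSketch* dispatch) {
    return dispatch != nullptr && dispatch->GenerateLUT != nullptr;
}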
diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index b127671dcb517..820a3ca762319 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -1452,17 +1452,6 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); -template -void -MlasBlockwiseQuantMetaShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& meta_rows, - int& meta_cols - ); - template void MlasBlockwiseQuantMetaShape( int block_size, @@ -1539,16 +1528,6 @@ MlasBlockwiseQuantizedShape( int& q_cols ); -template void -MlasBlockwiseQuantizedShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& q_rows, - int& q_cols -); - template void MlasBlockwiseQuantizedShape( int block_size, @@ -1818,20 +1797,6 @@ MlasQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); -template void -MlasQuantizeBlockwise( - uint8_t* dst, - float* scales, - uint8_t* zero_points, - const float* src, - int block_size, - bool columnwise, - int rows, - int columns, - int leading_dimension, - MLAS_THREADPOOL* thread_pool -); - template void MlasQuantizeBlockwise( diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp new file mode 100644 index 0000000000000..1b283c6c67539 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -0,0 +1,36 @@ +/*++ + +// TODO: finish filling this out + +module includes kernel functions for generating LUT for T-MAC GEMM optimization strategy. +*/ + +#include "qlutgemm.h" + +bool MLASCALL MlasIsTMACAvailable( + size_t /*BlkBitWidth*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + const auto* Dispatch = GetMlasPlatform().LUTGenKernel; + return Dispatch != nullptr; + // TODO: once you add the kernel for lut matmul itself, add switch case that handles the variant + // and checks that the variant exists +} + +bool MLASCALL MlasTmacInitializeTable( + size_t BlkLen, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* qlut, + size_t CountN, + size_t countK, + size_t BlockStrideQuantB, + const float* Bias +) { + const auto* Dispatch = GetMlasPlatform().LUTGenKernel; + + return false; +} diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h new file mode 100644 index 0000000000000..86c2e1cb3812e --- /dev/null +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -0,0 +1,29 @@ +// TODO: fill out abstract for this file +// Base off of qnbitgemm.h + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +typedef +void(MLAS_QNBIT_GEMM_LUT_GEN)( + int32_t group_size, + int8_t* lut, + onnxruntime::MLFloat16* b, + onnxruntime::MLFloat16* scales, + onnxruntime::MLFloat16* biases +); + + +// +// Kernel dispatch structure. +// +// NOTE: This name must match the forward declaration in mlasi.h: +// struct MLAS_QNBIT_LUT_GEMM_DISPATCH; +// Keep it minimal for now; extend with function pointers as kernels are added. +struct MLAS_QNBIT_LUT_GEMM_DISPATCH { + // Intentionally empty placeholder; add members as needed. 
+ MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr; + +}; \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 856695b970be7..3e433908924da 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -16,8 +16,11 @@ Module Name: --*/ #include "qnbitgemm.h" +#include "qlutgemm.h" #include "sqnbitgemm_q8_block.h" #include +// AVX2 intrinsics +#include size_t Q2BitGemmPackQuantBDataSize( @@ -206,24 +209,135 @@ Q2BitGemmPerGemmWorkspaceSize( } } +// pass in LUT for size_t SQ2BitGemmKernel_CompInt8_avx2( - size_t /*BlkLen*/, - const std::byte* /*QuantA*/, - const std::byte* /*QuantBData*/, - const float* /*QuantBScale*/, - const std::byte* /*QuantBZeroPoint*/, - float* /*C*/, - size_t /*CountM*/, - size_t /*CountN*/, - size_t /*CountK*/, - size_t /*BlockCountK*/, - size_t /*ldc*/, - const float* /*Bias*/ + size_t BlkLen, // group + const std::byte* QuantA, + const std::byte* QuantBData, // we pass in the LUT here + const float* QuantBScale, // LUT scales + const std::byte* QuantBZeroPoint, // LUT zero points + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t /*BlockCountK*/, // number of k blocks of length blklen?? + size_t /*ldc*/, // leading dimension for c (unused for CountN==1 path) + const float* /*Bias*/ // bias per output col for c ) { - // reference SQ4BitGemmKernel_CompInt8_avx2 - return 0; + // Implement qgemm_lut_int8_g4 (AVX2 path) for Bits=2, g=4, ActK=16, CountN == 1, K % 16 == 0. + // Notes: + // - This uses the same A/LUT/scales/biases layout assumptions as tmac's tbl.cpp AVX2 path. + // - C is updated in the same lane order as tmac (tile-local contiguous), which is fine for CountN==1. + + constexpr int Bits = 2; + constexpr int ActK = 16; + MLAS_UNREFERENCED_PARAMETER(BlkLen); + + // Preconditions we support in this initial implementation. + if (CountN != 1 || (CountK % ActK) != 0) { + return 0; // not handled + } + + const uint8_t* a = reinterpret_cast(QuantA); + const int8_t* lut = reinterpret_cast(QuantBData); + const float* lut_scales = QuantBScale; // one per kk-chunk (ActK) + const float* lut_biases = reinterpret_cast(QuantBZeroPoint); // one per kk-chunk (ActK) + float* c = C; + + // Process rows in groups of 32 as in tmac AVX2 path (i iterates 16 over m/2). 
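In scalar terms, each step of the inner loop below does the following per selector byte; the vector code performs it 16 bytes at a time with _mm256_shuffle_epi8 and accumulates in int16 lanes first, which stays well below overflow for ActK = 16 additions of int8 entries. The helper and accumulator names are illustrative:

#include <cstdint>

// One selector byte carries two 4-bit indices into the 16-entry int8 LUT for activation
// group k: the low nibble feeds the "bottom" accumulator, the high nibble the "top" one.
inline void LutLookupTwoRows(const int8_t lut_k[16], uint8_t selector,
                             int32_t& acc_bottom, int32_t& acc_top) {
    acc_bottom += lut_k[selector & 0x0F];          // bottom-nibble index
    acc_top    += lut_k[(selector >> 4) & 0x0F];   // top-nibble index
}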
+ size_t rows_handled = (CountM / 32) * 32; + if (rows_handled == 0) { + return 0; + } + + const __m128i vec_mask = _mm_set1_epi8(0x0f); + + for (size_t i = 0; i < rows_handled / 2; i += 16) { + __m256 vec_c0{}, vec_c1{}, vec_c2{}, vec_c3{}; + bool c_initialized = false; + float partial_sum = -0.0f; + + for (size_t kk = 0; kk < CountK; kk += ActK) { + // Accumulators for this kk-chunk: sum 16 int8 lookups across ActK into 4x8 lanes + __m128i acc_lo_low = _mm_setzero_si128(); + __m128i acc_lo_high = _mm_setzero_si128(); + __m128i acc_hi_low = _mm_setzero_si128(); + __m128i acc_hi_high = _mm_setzero_si128(); + + for (int k = 0; k < ActK; ++k) { + // Load 16 LUT entries for this k (indices 0..15) + const __m128i vec_lut_k = _mm_loadu_si128(reinterpret_cast(lut + (kk + k) * 16)); + // Load 16 selector bytes for bottom/top nibbles from A for this (i-block, k) + const __m128i vec_as = _mm_loadu_si128(reinterpret_cast(a + i * CountK + (kk + k) * 16)); + const __m128i vec_a_bot = _mm_and_si128(vec_as, vec_mask); + const __m128i vec_a_top = _mm_and_si128(_mm_srli_epi16(vec_as, 4), vec_mask); + + // Shuffle-gather from LUT using bottom and top nibble indices + const __m256i vec_lut_dup = _mm256_set_m128i(vec_lut_k, vec_lut_k); + const __m256i vec_a_bt = _mm256_set_m128i(vec_a_top, vec_a_bot); + const __m256i vec_v = _mm256_shuffle_epi8(vec_lut_dup, vec_a_bt); // 32 int8 results + + // Split to 2x16 and sign-extend to int16 + const __m128i v_bot8 = _mm256_castsi256_si128(vec_v); + const __m128i v_top8 = _mm256_extracti128_si256(vec_v, 1); + + const __m256i vb16 = _mm256_cvtepi8_epi16(v_bot8); + const __m256i vt16 = _mm256_cvtepi8_epi16(v_top8); + + const __m128i vb16_low = _mm256_castsi256_si128(vb16); + const __m128i vb16_high = _mm256_extracti128_si256(vb16, 1); + const __m128i vt16_low = _mm256_castsi256_si128(vt16); + const __m128i vt16_high = _mm256_extracti128_si256(vt16, 1); + + acc_lo_low = _mm_add_epi16(acc_lo_low, vb16_low); + acc_lo_high = _mm_add_epi16(acc_lo_high, vb16_high); + acc_hi_low = _mm_add_epi16(acc_hi_low, vt16_low); + acc_hi_high = _mm_add_epi16(acc_hi_high, vt16_high); + } + + // Convert to float vectors (4 groups of 8) + const __m256 vec_v_low_low = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_lo_low)); + const __m256 vec_v_low_high = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_lo_high)); + const __m256 vec_v_high_low = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_hi_low)); + const __m256 vec_v_high_high = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_hi_high)); + + float lut_s = lut_scales[kk / ActK]; + float lut_b = lut_biases ? 
lut_biases[kk / ActK] : 0.0f; + partial_sum += lut_b; + + // Apply per-bit-group bias pattern: add bias only when (ib % Bits == 0) + auto fma_with_bias = [&](const __m256& vs, size_t ib) { + if ((ib % Bits) == 0) { + return _mm256_fmadd_ps(vs, _mm256_set1_ps(lut_s), _mm256_set1_ps(lut_b)); + } else { + return _mm256_mul_ps(vs, _mm256_set1_ps(lut_s)); + } + }; + + if (!c_initialized) { + vec_c0 = fma_with_bias(vec_v_low_low, (i / 4)); + vec_c1 = fma_with_bias(vec_v_low_high, (i / 4 + 1)); + vec_c2 = fma_with_bias(vec_v_high_low, (i / 4 + 2)); + vec_c3 = fma_with_bias(vec_v_high_high, (i / 4 + 3)); + c_initialized = true; + } else { + vec_c0 = _mm256_add_ps(vec_c0, fma_with_bias(vec_v_low_low, (i / 4))); + vec_c1 = _mm256_add_ps(vec_c1, fma_with_bias(vec_v_low_high, (i / 4 + 1))); + vec_c2 = _mm256_add_ps(vec_c2, fma_with_bias(vec_v_high_low, (i / 4 + 2))); + vec_c3 = _mm256_add_ps(vec_c3, fma_with_bias(vec_v_high_high, (i / 4 + 3))); + } + } // kk + + // Store back to C in tmac lane order: 8 floats x 4 groups + _mm256_storeu_ps(c + i * 2, vec_c0); + _mm256_storeu_ps(c + i * 2 + 8, vec_c1); + _mm256_storeu_ps(c + i * 2 + 16, vec_c2); + _mm256_storeu_ps(c + i * 2 + 24, vec_c3); + } + + return rows_handled; } // TODO: do we need this..? @@ -236,4 +350,120 @@ QuantizeARow_CompInt8( ) { // shall be similar to QuantizeARow_CompInt8_avx2 without blksum related code. + // we don't need this function -- remove from dispatch? } + +// based on lut_ctor_g4_int8_impl +void +GenerateLUT_avx2( + int32_t group_size, + int8_t* lut, + onnxruntime::MLFloat16* b, + onnxruntime::MLFloat16* scales, + onnxruntime::MLFloat16* biases +) { + // Helper to horizontally add all 8 lanes of a __m256 + auto addv_ps = [](const __m256 v) -> float { + __m128 res = _mm256_extractf128_ps(v, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(v)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); + }; + + // Read scale (already computed elsewhere) and prepare its reciprocal. + const float scale_f = static_cast(scales[0]); + const float t_scale = scale_f != 0.0f ? (1.0f / scale_f) : 0.0f; + + // Accumulate bias across blocks of 32 (matches tmac layout: 4 interleaved streams of 8) + float bias_acc = 0.0f; + + // Temporary buffers for converted floats + float tmp[32]; + float b0[8], b1[8], b2[8], b3[8]; + + // We produce 16 vectors per 32-wide chunk, then pack to int8 and store + // Each block of 32 half values contributes 32 int8 entries per LUT row (16 entries x 2 halves) arranged like tmac + for (int kblk = 0; kblk < group_size / 32; ++kblk) { + // Convert 32 halfs to float + const onnxruntime::MLFloat16* base = b + kblk * 32; + for (int i = 0; i < 32; ++i) tmp[i] = static_cast(base[i]); + + // De-interleave to 4 streams of 8 + for (int i = 0; i < 8; ++i) { + b0[i] = tmp[i * 4 + 0]; + b1[i] = tmp[i * 4 + 1]; + b2[i] = tmp[i * 4 + 2]; + b3[i] = tmp[i * 4 + 3]; + } + + __m256 vec_b0 = _mm256_loadu_ps(b0); + __m256 vec_b1 = _mm256_loadu_ps(b1); + __m256 vec_b2 = _mm256_loadu_ps(b2); + __m256 vec_b3 = _mm256_loadu_ps(b3); + + __m256 vec_lut[16]; + + // Build odd indices 1..15: b0 +/- b1 +/- b2 +/- b3 depending on bits of g + for (int g = 1; g < 16; g += 2) { + __m256 v = vec_b0; + v = (g & 0b0010) ? _mm256_add_ps(v, vec_b1) : _mm256_sub_ps(v, vec_b1); + v = (g & 0b0100) ? _mm256_add_ps(v, vec_b2) : _mm256_sub_ps(v, vec_b2); + v = (g & 0b1000) ? 
_mm256_add_ps(v, vec_b3) : _mm256_sub_ps(v, vec_b3); + vec_lut[g] = v; + } + + // Even indices are negatives of mirrored odd indices + for (int g = 0; g < 16; g += 2) { + vec_lut[g] = _mm256_sub_ps(_mm256_setzero_ps(), vec_lut[15 - g]); + } + + // Accumulate bias from entry 0 (before scaling) + bias_acc += addv_ps(vec_lut[0]); + + // Apply inverse scale + const __m256 vs = _mm256_set1_ps(t_scale); + for (int g = 0; g < 16; ++g) { + vec_lut[g] = _mm256_mul_ps(vec_lut[g], vs); + } + + // Round to nearest, pack to int8 with saturate, and shuffle into the final lane order + __m256i vec_qlut[4]; + const __m256i shuf = _mm256_setr_epi8( + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + + for (int g = 0; g < 4; ++g) { + __m256i i0 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 0], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i1 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 1], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i2 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 2], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i3 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 3], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + __m256i i8 = _mm256_packs_epi16(i0, i2); + vec_qlut[g] = _mm256_shuffle_epi8(i8, shuf); + } + + // Store 8 lanes x 4 rows for this 32-wide block + int32_t* qlut_i32 = reinterpret_cast(lut); + for (int lane = 0; lane < 8; ++lane) { + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + lane * 4 + g] = _mm256_extract_epi32(vec_qlut[g], lane); + } + } + } + + // Write back bias and leave scale as-is + biases[0] = onnxruntime::MLFloat16(bias_acc); + // scales[0] unchanged + return; +} + +// Kernel dispatch structure definition. 
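The table built above has a simple scalar description per group of four activation values: entry i sums b[j] with a plus sign where bit j of i is set and a minus sign otherwise, and the even entries are obtained by mirroring the odd ones, exactly as the vector code does. The sketch below leaves out the 1/scale scaling, the round-and-saturate to int8, and the lane shuffle applied afterwards:

// Scalar reference: entry[i] = sum_j (bit j of i ? +b[j] : -b[j]) for j = 0..3.
inline void BuildGroupLutReference(const float b[4], float entry[16]) {
    for (int i = 1; i < 16; i += 2) {   // odd indices: +b[0], then +/- b[1..3] from bits 1..3
        float v = b[0];
        v += (i & 0b0010) ? b[1] : -b[1];
        v += (i & 0b0100) ? b[2] : -b[2];
        v += (i & 0b1000) ? b[3] : -b[3];
        entry[i] = v;
    }
    for (int i = 0; i < 16; i += 2) {   // even indices: negate the mirrored odd entry
        entry[i] = -entry[15 - i];
    }
}

With that convention, the 4-bit index the packer builds per weight bit-plane selects its partial sum directly from the table.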
+ +const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLUTGenKernelAvx2 = []() { + MLAS_QNBIT_LUT_GEMM_DISPATCH d; + d.GenerateLUT = GenerateLUT_avx2; + return d; +}(); \ No newline at end of file diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index de267a29803b1..9010e15cf56e5 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -264,18 +264,18 @@ TEST(MatMulNBits, Float32_2Bits_Accuracy0) { TEST(MatMulNBits, Float32_2Bits_Accuracy4) { // Currently, only fallback option enabled for 2bit datatypes // where the 2bits are dequantized to fp32 - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); } } // namespace test } // namespace onnxruntime From ad2572b42805040ba70079fece20b05452fc182c Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Thu, 4 Sep 2025 23:34:05 +0000 Subject: [PATCH 16/33] builds somehow --- .../cpu/quantization/matmul_nbits.cc | 6 ++-- onnxruntime/core/mlas/inc/mlas_qnbit.h | 9 +++-- onnxruntime/core/mlas/lib/qlutgemm.cpp | 36 +++++++++++-------- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 34 +++++++++++++----- 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 9246644fabf57..7a45707da3078 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -55,7 +55,7 @@ GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { return SQNBIT_CompInt8; } - if (accuracy_level_attr == static_cast(Level5) && MlasIsTMACAvailable(nbits, block_size, SQNBIT_CompInt8)) { + if (accuracy_level_attr == static_cast(Level5) && MlasIsTMACAvailable(nbits, block_size)) { return TMAC; } @@ -321,12 +321,14 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const auto* bias_data = bias == nullptr ? 
nullptr : bias->Data(); auto* y_data = y->MutableData(); + IAllocatorUniquePtr lut{}; + // TODO: add the logic for generating lookup table here -- for now we can assume that // 2 bits + acc level 4 = use look up table but in the future adapt so that we use a mamtulnbits attr to decide // if we want to do lut generation if (compute_type_ == TMAC) { // call lut gen somehow - MlasTmacInitializeTable(); + MlasTmacInitializeTable(block_size_, packed_b_.get(), scales_data, lut.get()); } const size_t batch_count = helper.OutputOffsets().size(); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 165e425cbf4a7..b3d81aae73ed3 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -233,8 +233,7 @@ MlasQNBitGemmScalesPacked( bool MLASCALL MlasIsTMACAvailable( size_t BlkBitWidth, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + size_t BlkLen ); /** @@ -243,4 +242,8 @@ MlasIsTMACAvailable( * Returns true if initialization succeeded or was unnecessary. */ bool MLASCALL -MlasTmacInitializeTable(); +MlasTmacInitializeTable(size_t BlkLen, + const void* QuantBData, // B in MLFloat16 (per your layout) + const float* QuantBScale, // scale(s) in float + void* qlut // destination LUT buffer (int8 data) +); diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 1b283c6c67539..a9a940b6ff294 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -9,28 +9,36 @@ module includes kernel functions for generating LUT for T-MAC GEMM optimization bool MLASCALL MlasIsTMACAvailable( size_t /*BlkBitWidth*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ + size_t BlkLen ) { const auto* Dispatch = GetMlasPlatform().LUTGenKernel; - return Dispatch != nullptr; - // TODO: once you add the kernel for lut matmul itself, add switch case that handles the variant - // and checks that the variant exists + return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now } bool MLASCALL MlasTmacInitializeTable( size_t BlkLen, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* qlut, - size_t CountN, - size_t countK, - size_t BlockStrideQuantB, - const float* Bias + const void* QuantBData, // B in MLFloat16 (per your layout) + const float* QuantBScale, // scale(s) in float + void* qlut ) { const auto* Dispatch = GetMlasPlatform().LUTGenKernel; + if (!Dispatch || !Dispatch->GenerateLUT) return false; - return false; + // Cast target LUT buffer to int8, and prepare half-precision inputs + auto* lut_i8 = reinterpret_cast(qlut); + auto* b_half = const_cast( + reinterpret_cast(QuantBData)); + + // Convert the first float scale to half (adjust if you have more) + onnxruntime::MLFloat16 s16(QuantBScale[0]); + onnxruntime::MLFloat16 b16(0.0f); // output bias goes here // TODO: pass the biases here + + // Call the dispatch + Dispatch->GenerateLUT(static_cast(BlkLen), lut_i8, b_half, &s16, &b16); + + // If you need the bias value elsewhere, read it from b16 + // float bias_f = static_cast(b16); + + return true; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 3e433908924da..e16a731c5a036 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -26,16 +26,14 @@ size_t 
Q2BitGemmPackQuantBDataSize( size_t N, size_t K, - size_t BlkLen, + size_t /*BlkLen*/, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { // TODO: This code shall change according to T-Mac. MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - constexpr size_t BlkBitWidth = 2; - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PackedQuantBDataSize = N * K / 8; return PackedQuantBDataSize; } @@ -447,10 +445,30 @@ GenerateLUT_avx2( // Store 8 lanes x 4 rows for this 32-wide block int32_t* qlut_i32 = reinterpret_cast(lut); - for (int lane = 0; lane < 8; ++lane) { - for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + lane * 4 + g] = _mm256_extract_epi32(vec_qlut[g], lane); - } + + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 0 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 0); + } + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 1 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 1); + } + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 2 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 2); + } + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 3 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 3); + } + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 4 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 4); + } + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 5 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 5); + } + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 6 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 6); + } + for (int g = 0; g < 4; ++g) { + qlut_i32[kblk * 32 + 7 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 7); } } From b312815d8a894c9f1222b9d610f07ce388991f86 Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Wed, 10 Sep 2025 22:18:31 +0000 Subject: [PATCH 17/33] update --- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | 4 +++- onnxruntime/core/mlas/lib/qlutgemm.cpp | 5 +++-- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 7a45707da3078..60de4ff8121da 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -180,7 +180,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); is_packed = false; - if (has_g_idx_ || has_unquantized_zero_point_) { + // if (has_g_idx_ || has_unquantized_zero_point_) { + // TODO: this part modified so i can test matmulnbits + if (has_g_idx_) { return Status::OK(); } diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index a9a940b6ff294..5c048b91017b8 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -9,11 +9,12 @@ module includes kernel functions for generating LUT for T-MAC GEMM optimization bool MLASCALL MlasIsTMACAvailable( size_t /*BlkBitWidth*/, - size_t BlkLen + size_t /*BlkLen*/ ) { const auto* Dispatch = GetMlasPlatform().LUTGenKernel; - return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now + return Dispatch != nullptr; + // return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now } bool MLASCALL MlasTmacInitializeTable( diff --git 
a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index c710681539c4c..655fe0e650c28 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -51,8 +51,8 @@ GetQNBitGemmVariant( { if ((BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { if (BlkBitWidth == 2) { - if (ComputeType == SQNBIT_CompInt8) { - return SQNBitGemmVariant_BitWidth2_CompInt8; + if (ComputeType == TMAC) { + return SQNBitGemmVariant_BitWidth2_CompInt8; // TODO: rename this kernel } } else if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompFp32) { From bfeac34be56d7adfa269ebd3794364989cf980c0 Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Tue, 16 Sep 2025 16:58:49 +0000 Subject: [PATCH 18/33] udpate --- .../cpu/quantization/matmul_nbits.cc | 47 +++- onnxruntime/core/mlas/inc/mlas_qnbit.h | 5 +- onnxruntime/core/mlas/lib/qlutgemm.cpp | 18 +- onnxruntime/core/mlas/lib/qlutgemm.h | 7 +- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 236 +++++++++++------- 5 files changed, 196 insertions(+), 117 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 60de4ff8121da..9b0118c430f2e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -181,7 +181,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All ORT_UNUSED_PARAMETER(prepacked_weights); is_packed = false; // if (has_g_idx_ || has_unquantized_zero_point_) { - // TODO: this part modified so i can test matmulnbits + // TODO: this part modified so i can test ek atmulnbits if (has_g_idx_) { return Status::OK(); } @@ -325,18 +325,51 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, IAllocatorUniquePtr lut{}; + const size_t batch_count = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + // TODO: add the logic for generating lookup table here -- for now we can assume that // 2 bits + acc level 4 = use look up table but in the future adapt so that we use a mamtulnbits attr to decide // if we want to do lut generation if (compute_type_ == TMAC) { // call lut gen somehow - MlasTmacInitializeTable(block_size_, packed_b_.get(), scales_data, lut.get()); + // Create a mutable copy of scales since MlasTmacInitializeTable modifies the scales + float* scales_float = nullptr; + IAllocatorUniquePtr scales_copy; + + if (std::is_same::value) { + // For float scales, create a copy + const auto* float_scales = reinterpret_cast(scales_data); + size_t scales_size = static_cast(scales->Shape().Size()); + scales_copy = IAllocator::MakeUniquePtr(allocator, scales_size, false); + std::copy(float_scales, float_scales + scales_size, scales_copy.get()); + scales_float = scales_copy.get(); + } else { + // For MLFloat16, use pre-converted scales if available, otherwise convert + if (scales_fp32_) { + // Create a copy of the pre-converted scales + size_t scales_size = static_cast(scales->Shape().Size()); + scales_copy = IAllocator::MakeUniquePtr(allocator, scales_size, false); + std::copy(scales_fp32_.get(), scales_fp32_.get() + scales_size, scales_copy.get()); + scales_float = scales_copy.get(); + } else { + // Convert MLFloat16 to float + const auto* half_scales = reinterpret_cast(scales_data); + size_t scales_size = static_cast(scales->Shape().Size()); + scales_copy = 
IAllocator::MakeUniquePtr(allocator, scales_size, false); + MlasConvertHalfToFloatBuffer(half_scales, scales_copy.get(), scales_size); + scales_float = scales_copy.get(); + } + } + + MlasTmacInitializeTable(block_size_, + packed_b_.get(), + scales_float, + static_cast(K), + lut.get()); } - - const size_t batch_count = helper.OutputOffsets().size(); - const size_t M = static_cast(helper.M()); - const size_t N = static_cast(helper.N()); - const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index b3d81aae73ed3..764f4d75a87d7 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -243,7 +243,8 @@ MlasIsTMACAvailable( */ bool MLASCALL MlasTmacInitializeTable(size_t BlkLen, - const void* QuantBData, // B in MLFloat16 (per your layout) - const float* QuantBScale, // scale(s) in float + void* QuantBData, // B in MLFloat16 (per your layout) + float* QuantBScale, // scale(s) in float + int K, // K dimension void* qlut // destination LUT buffer (int8 data) ); diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 5c048b91017b8..df793611f0a64 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -17,26 +17,28 @@ bool MLASCALL MlasIsTMACAvailable( // return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now } +// TODO: also pass in a biases reference bool MLASCALL MlasTmacInitializeTable( size_t BlkLen, - const void* QuantBData, // B in MLFloat16 (per your layout) - const float* QuantBScale, // scale(s) in float + void* QuantBData, // B in MLFloat16 (per your layout) + float* QuantBScale, // scale(s) in float + int K, void* qlut ) { + // base on lut_ctor_int8_g4 const auto* Dispatch = GetMlasPlatform().LUTGenKernel; if (!Dispatch || !Dispatch->GenerateLUT) return false; // Cast target LUT buffer to int8, and prepare half-precision inputs auto* lut_i8 = reinterpret_cast(qlut); - auto* b_half = const_cast( - reinterpret_cast(QuantBData)); + auto* b_float = reinterpret_cast(QuantBData); - // Convert the first float scale to half (adjust if you have more) - onnxruntime::MLFloat16 s16(QuantBScale[0]); - onnxruntime::MLFloat16 b16(0.0f); // output bias goes here // TODO: pass the biases here + const int num_groups = static_cast(K / BlkLen); + + float* biases = new float[num_groups](); // Call the dispatch - Dispatch->GenerateLUT(static_cast(BlkLen), lut_i8, b_half, &s16, &b16); + Dispatch->GenerateLUT(static_cast(BlkLen), lut_i8, b_float, QuantBScale, biases, K); // If you need the bias value elsewhere, read it from b16 // float bias_f = static_cast(b16); diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h index 86c2e1cb3812e..f3396069c989a 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.h +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -10,9 +10,10 @@ typedef void(MLAS_QNBIT_GEMM_LUT_GEN)( int32_t group_size, int8_t* lut, - onnxruntime::MLFloat16* b, - onnxruntime::MLFloat16* scales, - onnxruntime::MLFloat16* biases + float* b, + float* scales, + float* biases, + int K ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index e16a731c5a036..663e4a95c0fd6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ 
b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -22,6 +22,21 @@ Module Name: // AVX2 intrinsics #include +static inline float _mm256_addv_ps(const __m256 v) { + __m128 res = _mm256_extractf128_ps(v, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(v)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// Conditional pragma unroll for compiler compatibility +#if defined(__INTEL_COMPILER) || defined(__clang__) +#define PRAGMA_UNROLL _Pragma("unroll") +#else +#define PRAGMA_UNROLL +#endif + size_t Q2BitGemmPackQuantBDataSize( size_t N, @@ -338,144 +353,171 @@ SQ2BitGemmKernel_CompInt8_avx2( return rows_handled; } -// TODO: do we need this..? -void -QuantizeARow_CompInt8( - size_t /*BlkLen*/, - const float* /*A*/, - size_t /*CountK*/, - std::byte* /*QuantA*/ -) -{ - // shall be similar to QuantizeARow_CompInt8_avx2 without blksum related code. - // we don't need this function -- remove from dispatch? +void partial_max_g4_int8_k8(float* lut_scales, float* b) { + const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); + __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1); + __m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0); + __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1); + __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2); + __m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3); + __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3)); + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + float scales = _mm_cvtss_f32(max4) / 127; + *lut_scales = std::max(*lut_scales, scales); } -// based on lut_ctor_g4_int8_impl -void -GenerateLUT_avx2( - int32_t group_size, - int8_t* lut, - onnxruntime::MLFloat16* b, - onnxruntime::MLFloat16* scales, - onnxruntime::MLFloat16* biases +void lut_ctor_g4_int8_impl( + int32_t group_size, + int8_t* qlut, + float* b, + float* lut_scales, + float* lut_biases ) { - // Helper to horizontally add all 8 lanes of a __m256 - auto addv_ps = [](const __m256 v) -> float { - __m128 res = _mm256_extractf128_ps(v, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(v)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); - }; - - // Read scale (already computed elsewhere) and prepare its reciprocal. - const float scale_f = static_cast(scales[0]); - const float t_scale = scale_f != 0.0f ? 
(1.0f / scale_f) : 0.0f; - - // Accumulate bias across blocks of 32 (matches tmac layout: 4 interleaved streams of 8) - float bias_acc = 0.0f; - - // Temporary buffers for converted floats - float tmp[32]; - float b0[8], b1[8], b2[8], b3[8]; - - // We produce 16 vectors per 32-wide chunk, then pack to int8 and store - // Each block of 32 half values contributes 32 int8 entries per LUT row (16 entries x 2 halves) arranged like tmac - for (int kblk = 0; kblk < group_size / 32; ++kblk) { - // Convert 32 halfs to float - const onnxruntime::MLFloat16* base = b + kblk * 32; - for (int i = 0; i < 32; ++i) tmp[i] = static_cast(base[i]); - - // De-interleave to 4 streams of 8 - for (int i = 0; i < 8; ++i) { - b0[i] = tmp[i * 4 + 0]; - b1[i] = tmp[i * 4 + 1]; - b2[i] = tmp[i * 4 + 2]; - b3[i] = tmp[i * 4 + 3]; - } + const int act_k = group_size; // we assume K == group_size for now - __m256 vec_b0 = _mm256_loadu_ps(b0); - __m256 vec_b1 = _mm256_loadu_ps(b1); - __m256 vec_b2 = _mm256_loadu_ps(b2); - __m256 vec_b3 = _mm256_loadu_ps(b3); + __m256 vec_lut[16]; + float biases = 0.0; + const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); + float scales = *lut_scales; + float t_scales = scales ? 1.0f / scales : 0.0f; - __m256 vec_lut[16]; + for (int k = 0; k < act_k / 32; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1); + __m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1); - // Build odd indices 1..15: b0 +/- b1 +/- b2 +/- b3 depending on bits of g +PRAGMA_UNROLL for (int g = 1; g < 16; g += 2) { - __m256 v = vec_b0; - v = (g & 0b0010) ? _mm256_add_ps(v, vec_b1) : _mm256_sub_ps(v, vec_b1); - v = (g & 0b0100) ? _mm256_add_ps(v, vec_b2) : _mm256_sub_ps(v, vec_b2); - v = (g & 0b1000) ? 
_mm256_add_ps(v, vec_b3) : _mm256_sub_ps(v, vec_b3); - vec_lut[g] = v; + vec_lut[g] = vec_b0; + if (g & 0b0010) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b1); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b1); + } + if (g & 0b0100) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b2); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b2); + } + if (g & 0b1000) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b3); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b3); + } } - - // Even indices are negatives of mirrored odd indices +PRAGMA_UNROLL for (int g = 0; g < 16; g += 2) { - vec_lut[g] = _mm256_sub_ps(_mm256_setzero_ps(), vec_lut[15 - g]); + vec_lut[g] = -vec_lut[15 - g]; } - // Accumulate bias from entry 0 (before scaling) - bias_acc += addv_ps(vec_lut[0]); + biases += _mm256_addv_ps(vec_lut[0]); - // Apply inverse scale - const __m256 vs = _mm256_set1_ps(t_scale); +PRAGMA_UNROLL for (int g = 0; g < 16; ++g) { - vec_lut[g] = _mm256_mul_ps(vec_lut[g], vs); + vec_lut[g] = _mm256_mul_ps(vec_lut[g], _mm256_set1_ps(t_scales)); } - // Round to nearest, pack to int8 with saturate, and shuffle into the final lane order __m256i vec_qlut[4]; - const __m256i shuf = _mm256_setr_epi8( - 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, - 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); - - for (int g = 0; g < 4; ++g) { + const __m256i shuf = _mm256_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); +PRAGMA_UNROLL + for (int g = 0; g < 4; g += 1) { __m256i i0 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 0], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); __m256i i1 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 1], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); __m256i i2 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 2], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); __m256i i3 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 3], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - i0 = _mm256_packs_epi32(i0, i1); - i2 = _mm256_packs_epi32(i2, i3); - __m256i i8 = _mm256_packs_epi16(i0, i2); - vec_qlut[g] = _mm256_shuffle_epi8(i8, shuf); + i0 = _mm256_packs_epi32(i0, i1); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32(i2, i3); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16(i0, i2); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + vec_qlut[g] = _mm256_shuffle_epi8(i0, shuf); // 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 } - // Store 8 lanes x 4 rows for this 32-wide block - int32_t* qlut_i32 = reinterpret_cast(lut); - + int32_t* qlut_i32 = reinterpret_cast(qlut); +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 0 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 0); + qlut_i32[k * 32 + 0 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 0); } +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 1 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 1); + qlut_i32[k * 32 + 1 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 1); } +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 2 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 2); + qlut_i32[k * 32 + 2 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 2); } +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 3 * 4 + g] = 
_mm256_extract_epi32(vec_qlut[g], 3); + qlut_i32[k * 32 + 3 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 3); } +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 4 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 4); + qlut_i32[k * 32 + 4 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 4); } +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 5 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 5); + qlut_i32[k * 32 + 5 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 5); } +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 6 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 6); + qlut_i32[k * 32 + 6 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 6); } +PRAGMA_UNROLL for (int g = 0; g < 4; ++g) { - qlut_i32[kblk * 32 + 7 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 7); + qlut_i32[k * 32 + 7 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 7); } } - // Write back bias and leave scale as-is - biases[0] = onnxruntime::MLFloat16(bias_acc); - // scales[0] unchanged - return; + *lut_scales = scales; + *lut_biases = biases; + +} + + +// based on lut_ctor_g4_int8_impl +void +GenerateLUT_avx2( + int32_t group_size, + int8_t* lut, + float* b, + float* scales, + float* biases, + int K +) { + const int kk_outer_max = K / group_size; + + for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { + // compute partial max - directly reset scale to 0.0 + scales[kk_outer] = 0.0f; + for (int32_t k_outer = 0; k_outer < group_size / 32; ++k_outer) { + partial_max_g4_int8_k8(&scales[kk_outer], &b[(kk_outer * group_size) + (k_outer * 32)]); + } + } + + for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { + lut_ctor_g4_int8_impl(group_size, (&(lut[(k_outer_1 * group_size * 4)])), (&(b[(k_outer_1 * group_size)])), (&(scales[k_outer_1])), (&(biases[k_outer_1]))); + } + +} + +// try adding this back in: + +void +QuantizeARow_CompInt8( + size_t /*BlkLen*/, + const float* /*A*/, + size_t /*CountK*/, + std::byte* /*QuantA*/ +) { + // Not implemented yet. } // Kernel dispatch structure definition. 
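Note on the patch above: the AVX2 body of lut_ctor_g4_int8_impl is easier to audit against a scalar reference. The sketch below is illustrative only — BuildLutForGroup and LutScaleForGroup are hypothetical helper names, not MLAS APIs — and assumes the same g = 4 grouping: per group of four activations it builds the 16 signed-sum table entries, takes the bias from entry 0, and picks the per-group scale so every entry fits in int8.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Per-group LUT scale, mirroring partial_max_g4_int8_k8: the largest possible
// |entry| is |b0|+|b1|+|b2|+|b3|, so dividing by (that max / 127) keeps every
// quantized entry inside the int8 range. (Hypothetical helper, not MLAS.)
static float LutScaleForGroup(const float b[4], float prev_scale)
{
    const float abs_sum = std::fabs(b[0]) + std::fabs(b[1]) + std::fabs(b[2]) + std::fabs(b[3]);
    return std::max(prev_scale, abs_sum / 127.0f);
}

// Scalar reference for one group of g = 4 activations, mirroring
// lut_ctor_g4_int8_impl: entry t of the 16-entry table is
//   (bit0(t) ? +b0 : -b0) + (bit1(t) ? +b1 : -b1) + (bit2(t) ? +b2 : -b2) + (bit3(t) ? +b3 : -b3),
// entry 0 (= -(b0+b1+b2+b3)) is accumulated into the LUT bias before scaling,
// and every entry is rounded and saturated to int8. (Hypothetical helper.)
static void BuildLutForGroup(const float b[4], float scale, int8_t lut[16], float& bias_acc)
{
    const float t_scale = (scale != 0.0f) ? 1.0f / scale : 0.0f;
    for (int t = 0; t < 16; ++t) {
        float v = 0.0f;
        for (int j = 0; j < 4; ++j) {
            v += ((t >> j) & 1) ? b[j] : -b[j];
        }
        if (t == 0) {
            bias_acc += v;  // matches biases += _mm256_addv_ps(vec_lut[0])
        }
        const float q = std::nearbyint(v * t_scale);
        lut[t] = static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
    }
}

The vectorized kernel processes eight such groups per 32-element block and interleaves the resulting int8 entries into the lane order the GEMM kernel expects; the scalar form above only captures the arithmetic, not that layout.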
From a5de1080790e1195f6dd509d48fdf34d4bab04da Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 1 Oct 2025 14:44:28 -0700 Subject: [PATCH 19/33] Implement Pre Packing of qweight for tmac --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 2 +- onnxruntime/core/mlas/lib/qnbitgemm.h | 10 +- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 224 +++++++++--------- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 4 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 2 +- 6 files changed, 124 insertions(+), 120 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 655fe0e650c28..474216bdda3dd 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -236,7 +236,7 @@ MlasQNBitGemmPackQuantBDataSize( if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPackQuantBDataSize != nullptr) { return Dispatch->Q2BitGemmPackQuantBDataSize( - N, K, BlkLen, HasZeroPoint, ComputeType + N, K, BlkLen, ComputeType ); } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index 5214ea61127b5..a231255c9fd16 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -100,9 +100,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH { MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); + // TODO:: just use Q4BitGemmPackQuantBDataSize if extra params are not needed in future + typedef size_t(Q2BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; - Q4BitGemmPackQuantBDataSize_Fn* Q2BitGemmPackQuantBDataSize = nullptr; + Q2BitGemmPackQuantBDataSize_Fn* Q2BitGemmPackQuantBDataSize = nullptr; /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 663e4a95c0fd6..7dc330d250626 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -46,10 +46,15 @@ Q2BitGemmPackQuantBDataSize( ) { // TODO: This code shall change according to T-Mac. + // Modify based on tmac compute type if needed. MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - const size_t PackedQuantBDataSize = N * K / 8; - return PackedQuantBDataSize; + // const size_t PackedQuantBDataSize = N * K / 8; + constexpr size_t BlkBitWidth = 2; + constexpr size_t g = 4; // group size + const size_t ngroups_per_elem = 8 / g; + const size_t PackedQuantBDataSize = (N * BlkBitWidth) * (K / g / ngroups_per_elem); + return PackedQuantBDataSize; // 1048576 } void SQ2BitGemmPackQuantBData( @@ -62,15 +67,17 @@ void SQ2BitGemmPackQuantBData( MLAS_THREADPOOL* ThreadPool ) { + //decompose W into w1,... 
w_bits create temp buffer buf2 of size N * bits * (K/g) + // T-MAC like configuration (approved): // bits=2, g=4, ngroups_per_elem=8/g=2, simd_n_in=16, simd_n_out=8, bm=512, kfactor=16 - constexpr int bits = 2; - constexpr int g = 4; - constexpr int ngroups_per_elem = 8 / g; // 2 - constexpr int simd_n_in = 16; - constexpr int simd_n_out = 8; - constexpr int bm = 512; // tune as needed; must be multiple of bits and mgroup - constexpr int kfactor = 16; // tune as needed; must divide K/g per block + constexpr size_t bits = 2; + constexpr size_t g = 4; + constexpr size_t ngroups_per_elem = 8 / g; // 2 + constexpr size_t simd_n_in = 16; + constexpr size_t simd_n_out = 8; + constexpr size_t bm = 256; // tune as needed; must be multiple of bits and mgroup + constexpr size_t kfactor = 16; // tune as needed; must divide K/g per block // Basic checks MLAS_UNREFERENCED_PARAMETER(K); @@ -80,119 +87,106 @@ void SQ2BitGemmPackQuantBData( assert(bm % mgroup == 0); assert(bm % bits == 0); + uint8_t * buf = new uint8_t[N * bits * (K / g)]; + memset(buf, 0, N * bits * (K / g)); + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(bits, BlkLen); // BlkLen/4 bytes + const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + for (size_t ik = 0; ik < K; ++ik) { + size_t idx = (im * K + ik); + size_t num_elem_per_byte = 8 / bits; + size_t elem_idx = idx % num_elem_per_byte; + + uint8_t v = ((const uint8_t *)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); + + for (size_t ib =0; ib < bits; ++ib) { + size_t new_ik = ik / g; + size_t shft_left = ik % g; + buf[im * bits * K / g + ib * K /g + new_ik] += ((v >> ib) & 1) << shft_left; + } + } + } + ); - const int m_block = bm / bits; // number of original rows (columns of B) per tile - assert(N % static_cast(m_block) == 0); - const size_t tiles_in_m = N / static_cast(m_block); + // Now buf contains the bit planes grouped by g along K + // Next, we need to do a multi-reshape/transpose into the final layout - const int K_over_g = static_cast(BlkLen / g); - // We write destination in block-major layout: for each k-block, its N columns packed contiguously. - // Per (k_blk, tile) we produce a chunk of size m_block * BlkDataSize bytes. - const size_t tile_chunk_bytes = static_cast(m_block) * BlkDataSize; // = m_block * BlkLen/4 + const size_t c0_fac2 = K / g; + const size_t c0_fac1 = simd_n_out * c0_fac2; + const size_t c0_fac0 = bits * c0_fac1; - const size_t Iterations = BlockCountK * tiles_in_m; + const size_t c1_nb2 = K / g; + const size_t c1_nb1 = simd_n_in * c1_nb2; + const size_t c1_nb0 = ngroups_per_elem * c1_nb1; + const size_t c1_fac2 = K / g; + const size_t c1_fac1 = ngroups_per_elem * c1_fac2; + const size_t c1_fac0 = simd_n_in * c1_fac1; + + + const size_t c2_nb4 = kfactor; + const size_t c2_nb3 = K / g / kfactor * c2_nb4; + const size_t c2_nb2 = ngroups_per_elem * c2_nb3; + const size_t c2_nb1 = simd_n_in * c2_nb2; + const size_t c2_nb0 = bm / mgroup * c2_nb1; + const size_t c2_fac3 = simd_n_in * ngroups_per_elem; + const size_t c2_fac2 = kfactor * c2_fac3; + const size_t c2_fac1 = bm / mgroup * c2_fac2; + const size_t c2_fac0 = K / g / kfactor * c2_fac1; + + const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed? 
MlasTrySimpleParallel( ThreadPool, Iterations, [&](ptrdiff_t tid) { - const size_t k_blk = static_cast(tid) / tiles_in_m; - const size_t tile_idx = static_cast(tid) % tiles_in_m; - - // Temporary buffers per tile - // buf2: size = (m_block * bits) * (BlkLen/g) - // tilechunk: size = m_block * BlkLen/4 bytes - std::vector buf2(static_cast(m_block) * bits * K_over_g, 0); - std::vector tilechunk(tile_chunk_bytes, 0); - - // Stage 1: build buf2 (bit-planes grouped along K by g) - for (int im = 0; im < m_block; ++im) { - const size_t n_col = tile_idx * static_cast(m_block) + static_cast(im); - const size_t src_block_offset = n_col * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* src_block = QuantBDataBegin + src_block_offset; - - for (int ik = 0; ik < static_cast(BlkLen); ++ik) { - const int byte_idx = ik >> 2; // ik/4 - const int lane = ik & 3; // ik%4 - const uint8_t src_byte = static_cast(src_block[byte_idx]); - const uint8_t v = static_cast((src_byte >> (lane * bits)) & 0x3u); - - const int ik_g = ik / g; - const int shft_left = ik % g; // 0..3 - for (int ib = 0; ib < bits; ++ib) { - const size_t idx = static_cast(im) * bits * K_over_g + static_cast(ib) * K_over_g + static_cast(ik_g); - buf2[idx] = static_cast(buf2[idx] + (((v >> ib) & 0x1u) << shft_left)); - } + size_t im = static_cast(tid); + for (size_t ib = 0; ib < bits; ib++) { + for (size_t ik = 0; ik < K / g; ik++) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + size_t new_im = im / simd_n_out; + size_t new_isno = im % simd_n_out; + size_t new_ib = ib; + size_t new_ik = ik; + size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + size_t new_ing = (new_idx % c1_nb0) / c1_nb1; + size_t new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + size_t new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + + new_ik * c2_fac1 + + new_ibm * c2_fac2 + + new_ikf * c2_fac3 + + new_isni * ngroups_per_elem + + new_ing; + new_idx = new_idx / ngroups_per_elem; + size_t buf_idx = im * bits * K / g + ib * K / g + ik; + uint8_t buf_val = buf[buf_idx]; + + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + PackedQuantBDataBegin[new_idx] = static_cast( + static_cast(PackedQuantBDataBegin[new_idx]) + + (buf_val << (new_ing * g))); } } - - // Precompute reshape/transpose factors (use K' = BlkLen) - const int c0_fac2 = K_over_g; - const int c0_fac1 = simd_n_out * c0_fac2; - const int c0_fac0 = bits * c0_fac1; - - const int c1_nb2 = K_over_g; - const int c1_nb1 = simd_n_in * c1_nb2; - const int c1_nb0 = ngroups_per_elem * c1_nb1; - const int c1_fac2 = K_over_g; - const int c1_fac1 = ngroups_per_elem * c1_fac2; - const int c1_fac0 = simd_n_in * c1_fac1; - - const int c2_nb4 = kfactor; - const int c2_nb3 = (K_over_g / kfactor) * c2_nb4; - const int c2_nb2 = ngroups_per_elem * c2_nb3; - const int c2_nb1 = simd_n_in * c2_nb2; - const int c2_nb0 = (bm / mgroup) * 
c2_nb1; - const int c2_fac3 = simd_n_in * ngroups_per_elem; - const int c2_fac2 = kfactor * c2_fac3; - const int c2_fac1 = (bm / mgroup) * c2_fac2; - const int c2_fac0 = (K_over_g / kfactor) * c2_fac1; - - // Stage 2: multi-reshape/transpose into tilechunk - for (int im = 0; im < m_block; ++im) { - for (int ib = 0; ib < bits; ++ib) { - for (int ik = 0; ik < K_over_g; ++ik) { - // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) - int new_im = im / simd_n_out; - int new_isno = im % simd_n_out; - int new_ib = ib; - int new_ik = ik; - int new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; - - // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) - new_im = new_idx / c1_nb0; - int new_ing = (new_idx % c1_nb0) / c1_nb1; - int new_isni = (new_idx % c1_nb1) / c1_nb2; - new_ik = (new_idx % c1_nb2); - new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; - - // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) - new_im = new_idx / c2_nb0; - int new_ibm = (new_idx % c2_nb0) / c2_nb1; - new_isni = (new_idx % c2_nb1) / c2_nb2; - new_ing = (new_idx % c2_nb2) / c2_nb3; - new_ik = (new_idx % c2_nb3) / c2_nb4; - int new_ikf = (new_idx % c2_nb4); - new_idx = new_im * c2_fac0 + new_ik * c2_fac1 + new_ibm * c2_fac2 + new_ikf * c2_fac3 + new_isni * ngroups_per_elem + new_ing; - - // Collapse ngroups into byte by left-shifting lanes of g - const size_t src_idx = static_cast(im) * bits * K_over_g + static_cast(ib) * K_over_g + static_cast(ik); - const uint8_t v = buf2[src_idx]; - const size_t dst_idx = static_cast(new_idx / ngroups_per_elem); - tilechunk[dst_idx] = static_cast(tilechunk[dst_idx] + (v << (new_ing * g))); - } - } - } - - // Store the tile chunk into destination - std::byte* dst_block_base = PackedQuantBDataBegin + k_blk * (N * BlkDataSize); - std::byte* tile_dest = dst_block_base + tile_idx * tile_chunk_bytes; - // copy bytes - for (size_t i = 0; i < tile_chunk_bytes; ++i) { - tile_dest[i] = static_cast(tilechunk[i]); - } } ); } @@ -222,7 +216,7 @@ Q2BitGemmPerGemmWorkspaceSize( } } -// pass in LUT for +// pass in LUT for size_t SQ2BitGemmKernel_CompInt8_avx2( size_t BlkLen, // group @@ -414,7 +408,9 @@ PRAGMA_UNROLL } PRAGMA_UNROLL for (int g = 0; g < 16; g += 2) { - vec_lut[g] = -vec_lut[15 - g]; + //vec_lut[g] = -vec_lut[15 - g]; + const __m256 neg_mask = _mm256_set1_ps(-0.0f); // all lanes have sign bit set + vec_lut[g] = _mm256_xor_ps(vec_lut[15 - g], neg_mask); } biases += _mm256_addv_ps(vec_lut[0]); @@ -483,7 +479,7 @@ PRAGMA_UNROLL // based on lut_ctor_g4_int8_impl -void +void GenerateLUT_avx2( int32_t group_size, int8_t* lut, @@ -526,4 +522,4 @@ const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLUTGenKernelAvx2 = []() { MLAS_QNBIT_LUT_GEMM_DISPATCH d; d.GenerateLUT = GenerateLUT_avx2; return d; -}(); \ No newline at end of file +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 144beda003328..6d5d39abd039c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1446,7 +1446,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; - d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + 
d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; @@ -1478,7 +1478,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 7d0c0fbd8ee0a..51f1ef79c4898 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -478,7 +478,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; - d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index d4fe05c157c0e..7ff1000d267f9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -463,7 +463,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; - d.Q2BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<2>; + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; From 7ff8218e460344c4d3fcf9de319a2e320628f59a Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 6 Oct 2025 14:34:10 -0700 Subject: [PATCH 20/33] Implement Pre packing for Scales and zero points --- .../cpu/quantization/matmul_nbits.cc | 27 ++++++++- onnxruntime/core/mlas/inc/mlas_qnbit.h | 18 +++++- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 59 ++++++++++++++++++- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 3 +- 4 files changed, 103 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 9b0118c430f2e..53605104f0ecd 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -145,6 +145,8 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + IAllocatorUniquePtr packed_scales_zp_{}; + size_t packed_scales_zp_size_{0}; IAllocatorUniquePtr scales_fp32_{}; IAllocatorUniquePtr bias_fp32_{}; @@ -223,6 +225,23 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All is_packed = true; } #endif // MLAS_TARGET_ARM64 + } else if (compute_type_ == TMAC) { + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto scales_ptr = tensor.Data(); + if (has_zp_input_) { + 
const Tensor* zero_points = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); + auto zero_points_ptr = zero_points->Data(); + + packed_scales_zp_size_ = N_ * K_ / block_size_ * 2; + packed_scales_zp_ = IAllocator::MakeUniquePtr(alloc, packed_scales_zp_size_, true); + MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, zero_points_ptr); + } else { + packed_scales_zp_size_ = N_ * K_ / block_size_; + packed_scales_zp_ = IAllocator::MakeUniquePtr(alloc, packed_scales_zp_size_, true); + MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_,has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr); + } + } } return Status::OK(); @@ -289,6 +308,12 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou is_packed = false; } #endif // MLAS_TARGET_AMD64_IX86 + } else if (compute_type_ == TMAC) { + //TODO:: handle fp16 scales + // TMAC packs scales and zero points together by interleaving them + + + } return Status::OK(); @@ -363,7 +388,7 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, scales_float = scales_copy.get(); } } - + MlasTmacInitializeTable(block_size_, packed_b_.get(), scales_float, diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 764f4d75a87d7..55f33648b3ca8 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -223,6 +223,22 @@ MlasQNBitGemmScalesPacked( bool HasZeroPoint ); +/** + * @brief Packs the scales and zero points into a format that the TMAC kernel expects. + */ +void MLASCALL +MlasTMACPackScalesAndZeroPoints( + size_t N, + size_t K, + size_t BitWidth, + size_t BlkLen, + bool HasZeroPoint, + float* PackedQuantBZPBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint +); + + /** * @brief Determines whether the TMAC LUT optimization path is available on the current platform * for the provided quantization parameters. @@ -242,7 +258,7 @@ MlasIsTMACAvailable( * Returns true if initialization succeeded or was unnecessary. */ bool MLASCALL -MlasTmacInitializeTable(size_t BlkLen, +MlasTmacInitializeTable(size_t BlkLen, void* QuantBData, // B in MLFloat16 (per your layout) float* QuantBScale, // scale(s) in float int K, // K dimension diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 474216bdda3dd..c7c6c601791ab 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -319,7 +319,7 @@ MlasQNBitGemmPackQuantBData( ); return; } - } else if (BlkBitWidth == 2) { + } else if (BlkBitWidth == 2) { // TODO:: might switch to for TMAC type if other 2-bit kernels like i2s are added if (Dispatch->SQ2BitGemmPackQuantBData != nullptr) { Dispatch->SQ2BitGemmPackQuantBData( N, @@ -375,6 +375,63 @@ MlasQNBitGemmScalesPacked( return false; } + +void MlasTMACPackScalesAndZeroPoints( + size_t N, + size_t K, + size_t BitWidth, + size_t BlkLen, + bool HasZeroPoint, + float* PackedQuantBZPBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint +) +{ + // TODO: Need tmac config so we don't hardcode here. 
+ constexpr size_t bits = 2; + constexpr size_t g = 4; + constexpr size_t ngroups_per_elem = 8 / g; // 2 + constexpr size_t simd_n_in = 16; + constexpr size_t simd_n_out = 8; + constexpr size_t bm = 256; // tune as needed; must be multiple of bits and mgroup + constexpr size_t kfactor = 16; // tune as needed; must divide K/g per block + constexpr size_t num_elem_per_byte = 8 / bits; + + + for (size_t im = 0; im < N ; im += 1) { + for (size_t ik = 0; ik < K; ik += BlkLen) { + size_t idx = (im * K + ik) / BlkLen; + float scale = QuantBScale[idx]; + float zp; + if (HasZeroPoint) { + // zp are two bit packed + size_t elem_idx = idx % num_elem_per_byte; + uint8_t v = QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits) - 1; + zp = static_cast(v); + } + + size_t nb1 = K / BlkLen; + size_t nb0 = bm / BitWidth * nb1; + size_t new_im = idx / nb0; + size_t new_ibm = (idx % nb0) / nb1; + size_t new_ik = (idx % nb1); + + if (HasZeroPoint) { + size_t new_isimd = new_ibm % simd_n_out; + size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; + size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; + + PackedQuantBZPBegin[new_idx_scale] = scale; + PackedQuantBZPBegin[new_idx_zero] = zp; + } else { + size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; + PackedQuantBZPBegin[new_idx] = scale; + } + } + } +} + namespace { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 7dc330d250626..36f07ac30d2e4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -93,7 +93,7 @@ void SQ2BitGemmPackQuantBData( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(bits, BlkLen); // BlkLen/4 bytes const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed - + MlasTrySimpleParallel( ThreadPool, Iterations, [&](ptrdiff_t tid) { @@ -189,6 +189,7 @@ void SQ2BitGemmPackQuantBData( } } ); + delete[] buf; } size_t From 6d8e8ece749b5003ef844c6e3806ab3540f6bea9 Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 6 Oct 2025 15:08:44 -0700 Subject: [PATCH 21/33] Transform zero points before interleaving --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index c7c6c601791ab..c04c732e1223c 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -408,6 +408,12 @@ void MlasTMACPackScalesAndZeroPoints( size_t elem_idx = idx % num_elem_per_byte; uint8_t v = QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits) - 1; zp = static_cast(v); + + // Note: TMAC does this during model conversion. Since, we follow ORT format, we need to do it here. + // This seems gptq quantization specific. + // We should either use different op than matmul_nbits or add attribute to matmul_nbits to indicate this. 
+ zp = zp - (1 << (bits - 1)) - 1; // make it signed + zp = zp * scale; // store scale * zp } size_t nb1 = K / BlkLen; From 5d19daf0111a921038ec7329e058865d93a648d2 Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 7 Oct 2025 15:25:23 -0700 Subject: [PATCH 22/33] Initial implementation of tmac kernel config --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 82 +++++++++++++++++++ onnxruntime/core/mlas/lib/qlutgemm.h | 23 +++++- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 17 ++-- 3 files changed, 113 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index df793611f0a64..f167276ee2095 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -4,9 +4,91 @@ module includes kernel functions for generating LUT for T-MAC GEMM optimization strategy. */ +#include +#include +#include #include "qlutgemm.h" +/** T-MAC GEMM kernel Config */ +static std::unordered_map tmac_kernel_configs; + + + + +const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits, size_t block_size) { + std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits); + if (tmac_kernel_configs.count(key)) { + return tmac_kernel_configs[key]; + } + + MlasTMACKernelParams params; + params.g = 4; + params.ngroups_per_elem = 8 / params.g; + params.simd_n_in = 16; + params.simd_n_out = 8; + params.chunk_n = 8; + + params.bits = nbits; + params.q_group_size = block_size; + + if (block_size % 64 == 0) { + params.act_group_size = 64; + } else if (block_size % 32 == 0) { + params.act_group_size = 32; + } else { + // throw error + ORT_THROW("Unsupported activation group size: ", block_size);; + } + params.actk = params.act_group_size / params.g; + + //search space + std::vector bms; + if (nbits == 1 || nbits == 2 || nbits == 4) { + bms = {256, 512, 1024, 2048, 320, 640, 1280}; + } else if (nbits == 3) { + bms = {192, 384, 576, 758}; + } + + std::vector bns = {8, 16, 32, 64}; + std::vector kfactors = {8, 16}; + + double min_time = 1e9; + + // TODO: add profile based policy + int threads = std::thread::hardware_concurrency(); + + float smallest_penalty = 1e9; + for (int bm: bms) { + if (M % (bm/nbits) != 0 || bm % nbits != 0) { + continue; + } + size_t num_tiles = M/ (bm/nbits); + size_t num_groups = (num_tiles + threads - 1) / threads; + float penalty = 0.1 * num_groups + (num_groups - 1.0 * num_tiles / threads) / num_groups; + if (penalty < smallest_penalty) { + smallest_penalty = penalty; + params.bm = bm; + } + } + + size_t largest_kfactor = 0; + for (size_t kfactor: kfactors) { + if ((kfactor < params.actk) || (kfactor * params.g > params.q_group_size)) { + continue; + } + if (kfactor > largest_kfactor) { + largest_kfactor = kfactor; + params.kfactor = kfactor; + } + } + + tmac_kernel_configs[key] = params; + return tmac_kernel_configs[key]; +} + + + bool MLASCALL MlasIsTMACAvailable( size_t /*BlkBitWidth*/, size_t /*BlkLen*/ diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h index f3396069c989a..bc642955dfce3 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.h +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -6,6 +6,27 @@ #include "mlas_qnbit.h" #include "mlasi.h" + +/** + * @brief Parameters for TMAC kernel + */ +struct MlasTMACKernelParams { + size_t g; + size_t ngroups_per_elem; + size_t q_group_size; + size_t act_group_size; + + size_t kfactor; + size_t bits; + size_t actk; + size_t bm; + size_t simd_n_in; + size_t simd_n_out; + size_t 
chunk_n; +}; + +const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits, size_t block_size); + typedef void(MLAS_QNBIT_GEMM_LUT_GEN)( int32_t group_size, @@ -27,4 +48,4 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH { // Intentionally empty placeholder; add members as needed. MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr; -}; \ No newline at end of file +}; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 36f07ac30d2e4..241bd47ebad2a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -70,14 +70,15 @@ void SQ2BitGemmPackQuantBData( //decompose W into w1,... w_bits create temp buffer buf2 of size N * bits * (K/g) // T-MAC like configuration (approved): - // bits=2, g=4, ngroups_per_elem=8/g=2, simd_n_in=16, simd_n_out=8, bm=512, kfactor=16 - constexpr size_t bits = 2; - constexpr size_t g = 4; - constexpr size_t ngroups_per_elem = 8 / g; // 2 - constexpr size_t simd_n_in = 16; - constexpr size_t simd_n_out = 8; - constexpr size_t bm = 256; // tune as needed; must be multiple of bits and mgroup - constexpr size_t kfactor = 16; // tune as needed; must divide K/g per block + // bits=2, g=4, ngroups_per_elem=8/g=2, simd_n_in=16, simd_n_out=8, bm=256, kfactor=16 + const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); + const size_t bits = 2; + const size_t g = tmac_params.g; + const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; + const size_t simd_n_in = tmac_params.simd_n_in; + const size_t simd_n_out = tmac_params.simd_n_out; + const size_t bm = tmac_params.bm; + const size_t kfactor = tmac_params.kfactor; // Basic checks MLAS_UNREFERENCED_PARAMETER(K); From c6000562e403e44d4ccb0169800e6d71778e4116 Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 8 Oct 2025 11:23:03 -0700 Subject: [PATCH 23/33] Move pre packing scales and zp code to qlutgemm and use tmac_params --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 61 ++++++++++++++++++++++++ onnxruntime/core/mlas/lib/qnbitgemm.cpp | 63 ------------------------- 2 files changed, 61 insertions(+), 63 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index f167276ee2095..ba9fe2c4af7d8 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -87,6 +87,67 @@ const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits return tmac_kernel_configs[key]; } +void MlasTMACPackScalesAndZeroPoints( + size_t N, + size_t K, + size_t BitWidth, + size_t BlkLen, + bool HasZeroPoint, + float* PackedQuantBZPBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint +) +{ + const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); + const size_t bits = tmac_params.bits; + const size_t g = tmac_params.g; + const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; + const size_t simd_n_in = tmac_params.simd_n_in; + const size_t simd_n_out = tmac_params.simd_n_out; + const size_t bm = tmac_params.bm; + const size_t kfactor = tmac_params.kfactor; + const size_t num_elem_per_byte = 8 / bits; + + + for (size_t im = 0; im < N ; im += 1) { + for (size_t ik = 0; ik < K; ik += BlkLen) { + size_t idx = (im * K + ik) / BlkLen; + float scale = QuantBScale[idx]; + float zp; + if (HasZeroPoint) { + // zp are two bit packed + size_t elem_idx = idx % num_elem_per_byte; + uint8_t v = QuantBZeroPoint[idx / 
num_elem_per_byte] >> (elem_idx * bits) & (1 << bits) - 1; + zp = static_cast(v); + + // Note: TMAC does this during model conversion. Since, we follow ORT format, we need to do it here. + // This seems gptq quantization specific. + // We should either use different op than matmul_nbits or add attribute to matmul_nbits to indicate this. + zp = zp - (1 << (bits - 1)) - 1; // make it signed + zp = zp * scale; // store scale * zp + } + + size_t nb1 = K / BlkLen; + size_t nb0 = bm / BitWidth * nb1; + size_t new_im = idx / nb0; + size_t new_ibm = (idx % nb0) / nb1; + size_t new_ik = (idx % nb1); + + if (HasZeroPoint) { + size_t new_isimd = new_ibm % simd_n_out; + size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; + size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; + + PackedQuantBZPBegin[new_idx_scale] = scale; + PackedQuantBZPBegin[new_idx_zero] = zp; + } else { + size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; + PackedQuantBZPBegin[new_idx] = scale; + } + } + } +} bool MLASCALL MlasIsTMACAvailable( diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index c04c732e1223c..c9a9e1b7b9ba8 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -375,69 +375,6 @@ MlasQNBitGemmScalesPacked( return false; } - -void MlasTMACPackScalesAndZeroPoints( - size_t N, - size_t K, - size_t BitWidth, - size_t BlkLen, - bool HasZeroPoint, - float* PackedQuantBZPBegin, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint -) -{ - // TODO: Need tmac config so we don't hardcode here. - constexpr size_t bits = 2; - constexpr size_t g = 4; - constexpr size_t ngroups_per_elem = 8 / g; // 2 - constexpr size_t simd_n_in = 16; - constexpr size_t simd_n_out = 8; - constexpr size_t bm = 256; // tune as needed; must be multiple of bits and mgroup - constexpr size_t kfactor = 16; // tune as needed; must divide K/g per block - constexpr size_t num_elem_per_byte = 8 / bits; - - - for (size_t im = 0; im < N ; im += 1) { - for (size_t ik = 0; ik < K; ik += BlkLen) { - size_t idx = (im * K + ik) / BlkLen; - float scale = QuantBScale[idx]; - float zp; - if (HasZeroPoint) { - // zp are two bit packed - size_t elem_idx = idx % num_elem_per_byte; - uint8_t v = QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits) - 1; - zp = static_cast(v); - - // Note: TMAC does this during model conversion. Since, we follow ORT format, we need to do it here. - // This seems gptq quantization specific. - // We should either use different op than matmul_nbits or add attribute to matmul_nbits to indicate this. 
- zp = zp - (1 << (bits - 1)) - 1; // make it signed - zp = zp * scale; // store scale * zp - } - - size_t nb1 = K / BlkLen; - size_t nb0 = bm / BitWidth * nb1; - size_t new_im = idx / nb0; - size_t new_ibm = (idx % nb0) / nb1; - size_t new_ik = (idx % nb1); - - if (HasZeroPoint) { - size_t new_isimd = new_ibm % simd_n_out; - size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; - size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; - size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; - - PackedQuantBZPBegin[new_idx_scale] = scale; - PackedQuantBZPBegin[new_idx_zero] = zp; - } else { - size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; - PackedQuantBZPBegin[new_idx] = scale; - } - } - } -} - namespace { From 5cf99e6be1396ec43180a7846f293704fec6b1de Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Mon, 13 Oct 2025 20:41:02 +0000 Subject: [PATCH 24/33] update --- .../cpu/quantization/matmul_nbits.cc | 39 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 24 +- onnxruntime/core/mlas/lib/qlutgemm.cpp | 185 +++++- onnxruntime/core/mlas/lib/qlutgemm.h | 18 +- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 569 +++++++++++++----- .../mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h | 42 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 6 - .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 3 - .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 1 - 9 files changed, 641 insertions(+), 246 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 53605104f0ecd..2eb322abab94f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -348,8 +348,6 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const auto* bias_data = bias == nullptr ? 
nullptr : bias->Data(); auto* y_data = y->MutableData(); - IAllocatorUniquePtr lut{}; - const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); const size_t N = static_cast(helper.N()); @@ -359,41 +357,8 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, // 2 bits + acc level 4 = use look up table but in the future adapt so that we use a mamtulnbits attr to decide // if we want to do lut generation if (compute_type_ == TMAC) { - // call lut gen somehow - // Create a mutable copy of scales since MlasTmacInitializeTable modifies the scales - float* scales_float = nullptr; - IAllocatorUniquePtr scales_copy; - - if (std::is_same::value) { - // For float scales, create a copy - const auto* float_scales = reinterpret_cast(scales_data); - size_t scales_size = static_cast(scales->Shape().Size()); - scales_copy = IAllocator::MakeUniquePtr(allocator, scales_size, false); - std::copy(float_scales, float_scales + scales_size, scales_copy.get()); - scales_float = scales_copy.get(); - } else { - // For MLFloat16, use pre-converted scales if available, otherwise convert - if (scales_fp32_) { - // Create a copy of the pre-converted scales - size_t scales_size = static_cast(scales->Shape().Size()); - scales_copy = IAllocator::MakeUniquePtr(allocator, scales_size, false); - std::copy(scales_fp32_.get(), scales_fp32_.get() + scales_size, scales_copy.get()); - scales_float = scales_copy.get(); - } else { - // Convert MLFloat16 to float - const auto* half_scales = reinterpret_cast(scales_data); - size_t scales_size = static_cast(scales->Shape().Size()); - scales_copy = IAllocator::MakeUniquePtr(allocator, scales_size, false); - MlasConvertHalfToFloatBuffer(half_scales, scales_copy.get(), scales_size); - scales_float = scales_copy.get(); - } - } - - MlasTmacInitializeTable(block_size_, - packed_b_.get(), - scales_float, - static_cast(K), - lut.get()); + MlasTmac(a_data, block_size_, packed_b_.get(), scales_data, y_data, K, M, N, thread_pool); + return Status::OK(); } const size_t lda = helper.Lda(false); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 55f33648b3ca8..2dd73f49fc9c8 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -253,14 +253,20 @@ MlasIsTMACAvailable( ); /** - * @brief Initializes any global tables required by TMAC LUT kernels. + * @brief Executes TMAC compute * - * Returns true if initialization succeeded or was unnecessary. + * This function handles generating the look up tables and accumulating the matmul results. + * Results will be stored in C. 
*/ -bool MLASCALL -MlasTmacInitializeTable(size_t BlkLen, - void* QuantBData, // B in MLFloat16 (per your layout) - float* QuantBScale, // scale(s) in float - int K, // K dimension - void* qlut // destination LUT buffer (int8 data) -); +void MLASCALL +MlasTmac( + const void* A, + size_t BlkLen, + const void* QuantBData, + const float* QuantBScale, + void* C, + int K, + int M, + int N, + MLAS_THREADPOOL* threadpool +); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index ba9fe2c4af7d8..068c9fa371327 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -53,8 +53,6 @@ const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits std::vector bns = {8, 16, 32, 64}; std::vector kfactors = {8, 16}; - double min_time = 1e9; - // TODO: add profile based policy int threads = std::thread::hardware_concurrency(); @@ -100,12 +98,8 @@ void MlasTMACPackScalesAndZeroPoints( { const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); const size_t bits = tmac_params.bits; - const size_t g = tmac_params.g; - const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; - const size_t simd_n_in = tmac_params.simd_n_in; const size_t simd_n_out = tmac_params.simd_n_out; const size_t bm = tmac_params.bm; - const size_t kfactor = tmac_params.kfactor; const size_t num_elem_per_byte = 8 / bits; @@ -117,7 +111,7 @@ void MlasTMACPackScalesAndZeroPoints( if (HasZeroPoint) { // zp are two bit packed size_t elem_idx = idx % num_elem_per_byte; - uint8_t v = QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits) - 1; + uint8_t v = (QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits)) - 1; zp = static_cast(v); // Note: TMAC does this during model conversion. Since, we follow ORT format, we need to do it here. 
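Before the MlasTmac driver below dispatches ComputeGemm over weight tiles, it may help to spell out the identity that lets a 16-entry signed-sum table replace the 2-bit inner product. The following is a scalar illustration only — DotViaLut is a hypothetical name, and it operates on the unquantized float table — whereas the real kernel consumes the int8-quantized tables together with the lut_scales/lut_biases generated per activation row and the pre-scaled zero points packed by MlasTMACPackScalesAndZeroPoints.

#include <cstdint>

// Scalar illustration (not MLAS code): for one group of g = 4 activations a[j]
// and 2-bit weights qw[j] in [0, 3], each bit plane of the weights forms a
// 4-bit index into the signed-sum table, and the quantized dot product is
// recovered from two lookups plus the table's entry 0.
static float DotViaLut(const float a[4], const uint8_t qw[4], const float lut[16])
{
    // lut[t] = sum_j (bit_j(t) ? +a[j] : -a[j]); in particular lut[0] = -(a0+a1+a2+a3).
    float acc = 0.0f;
    for (int plane = 0; plane < 2; ++plane) {      // 2-bit weights -> 2 bit planes
        uint8_t idx = 0;
        for (int j = 0; j < 4; ++j) {
            idx |= static_cast<uint8_t>(((qw[j] >> plane) & 1) << j);
        }
        // lut[idx] = 2*sum_j a[j]*bit_j(qw[j]) - sum_j a[j], so the plane partial is:
        const float plane_sum = 0.5f * (lut[idx] - lut[0]);
        acc += static_cast<float>(1 << plane) * plane_sum;
    }
    return acc;  // equals sum_j a[j] * qw[j]
}

Applying the weight scale and subtracting zero_point * sum_j a[j] then gives scale * sum_j a[j] * (qw[j] - zp), which is consistent with the packing step storing scale * zp and the LUT constructor tracking -sum_j a[j] as its bias; how those constants are folded inside ComputeGemm is left to the kernel itself.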
@@ -160,31 +154,172 @@ bool MLASCALL MlasIsTMACAvailable( // return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now } -// TODO: also pass in a biases reference -bool MLASCALL MlasTmacInitializeTable( +size_t CalculateLUTSize(int k, int m, size_t group_size) { + return k * m * group_size; +} + +void MLASCALL MlasTmac( + const void* A, size_t BlkLen, - void* QuantBData, // B in MLFloat16 (per your layout) - float* QuantBScale, // scale(s) in float + const void* QuantBData, // Quantized weights (B matrix) + const float* QuantBScale, // scale(s) for quantized weights + void* C, int K, - void* qlut + int M, // batch size (number of rows in activation) + int N, + MLAS_THREADPOOL* threadpool ) { - // base on lut_ctor_int8_g4 + // adapted from ggml_backend_tmac_mul_mat const auto* Dispatch = GetMlasPlatform().LUTGenKernel; - if (!Dispatch || !Dispatch->GenerateLUT) return false; + if (!Dispatch || !Dispatch->GenerateLUT) { + ORT_THROW("TMAC not supported in this configuration."); + } - // Cast target LUT buffer to int8, and prepare half-precision inputs - auto* lut_i8 = reinterpret_cast(qlut); - auto* b_float = reinterpret_cast(QuantBData); + const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); + size_t lut_size = CalculateLUTSize(K, M, tmac_params.bits); + auto lut_buffer = std::make_unique(lut_size); - const int num_groups = static_cast(K / BlkLen); + const size_t lut_meta_size = (K / BlkLen) * M; + auto biases_float = std::make_unique(lut_meta_size); + auto scales_float = std::make_unique(lut_meta_size); - float* biases = new float[num_groups](); + const auto* a_float = reinterpret_cast(A); // Activation data - // Call the dispatch - Dispatch->GenerateLUT(static_cast(BlkLen), lut_i8, b_float, QuantBScale, biases, K); + const int num_groups = static_cast(K / BlkLen); - // If you need the bias value elsewhere, read it from b16 - // float bias_f = static_cast(b16); + // Parallelize over M (batch dimension) + // Each iteration processes one row of the activation matrix + MlasTrySimpleParallel( + threadpool, + static_cast(M), + [&](ptrdiff_t ine11) { + const size_t row_offset = static_cast(ine11) * K; + const size_t lut_offset = static_cast(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT + const size_t scale_bias_offset = static_cast(ine11) * num_groups; + + // Call the dispatch function for this row + Dispatch->GenerateLUT( + static_cast(BlkLen), + reinterpret_cast(lut_buffer.get()) + lut_offset, // Output LUT for this row + const_cast(a_float + row_offset), // Input activation for this row + scales_float.get() + scale_bias_offset, // Scales for this row + biases_float.get() + scale_bias_offset, // Biases for this row + K + ); + } + ); + + // all relevant LUT's have been generated + // equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line + const size_t bm = tmac_params.bm; // TODO: hardcoding for now + const size_t bits = tmac_params.bits; + + // TODO: fix the below 4 + // Matrix multiplication: Output[N×M] = QuantBData[N×K] × Weights[K×M] + const size_t OutputRows = M; // Number of output features + const size_t OutputCols = N; // Batch size + const size_t NumTiles = M * bits / bm; + + const size_t ChunkSize0 = M / NumTiles; + const size_t ChunkSize1 = tmac_params.chunk_n; // process one batch item at a time + +// In llama.cpp terminology (note the swap!): +// ne0 = M (output features, called "n" in llama.cpp) +// ne1 = N (batch size, called "m" in llama.cpp) + + // Calculate number of chunks in each 
dimension + const size_t nchunk0 = (OutputRows + ChunkSize0 - 1) / ChunkSize0; // Should equal NumTiles + const size_t nchunk1 = (OutputCols + ChunkSize1 - 1) / ChunkSize1; + const size_t total_chunks = nchunk0 * nchunk1; + + // Pre-calculate sizes for offset calculations + const size_t w_size = OutputRows * K * bits / 8; + const size_t w_chunk_size = w_size / NumTiles; + + // Determine weight-scale layout. These should be provided by the caller or inferred from the packed weights. + // For now we default to per-group symmetric quantization (no zero-point, not one-scale). + bool one_scale = false; // TODO: expose this as a function parameter if needed + bool has_zero_point = false; // TODO: expose this as a function parameter if needed + + // Total number of scale (float) entries for the whole weight matrix: + // - if one_scale: single global scale (1) + // - otherwise: number of quantization groups = (M * K / BlkLen) + // and if zero-points are present each group stores (scale, zero_point) -> *2 + const size_t groups_total = static_cast(M) * static_cast(K) / BlkLen; + const size_t scales_size_total = one_scale ? 1 : (groups_total * (has_zero_point ? 2 : 1)); + + // n_tile_num == NumTiles (number of M tiles) + const size_t n_tile_num = NumTiles; + + // Per-tile scales size = total scales size divided evenly across tiles. + // If one_scale is true we do not advance the scales pointer per tile, so set per tile size to 0 + size_t scales_size_per_tile = 0; + if (!one_scale) { + if (scales_size_total % n_tile_num != 0) { + // Sanity: scales should partition evenly across tiles. If they don't, choose floor division + // and document that callers must layout scales accordingly. + // Prefer to error loudly in debug builds. + fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tile_num=%zu; using floor division.\n", scales_size_total, n_tile_num); + } + scales_size_per_tile = scales_size_total / n_tile_num; + } - return true; -} + // Note: when one_scale == true, callers should pass a pointer to a single scale value (scales_offset=0 will be used) + + // Cast to appropriate types + const auto* packed_weights = reinterpret_cast(QuantBData); + const int8_t* lut_i8 = reinterpret_cast(lut_buffer.get()); + + // lut_scales_size is the number of scale values per batch item (= K / BlkLen) + const size_t lut_scales_size = static_cast(K) / BlkLen; + + // Parallelize over the 2D chunk grid + MlasTrySimpleParallel( + threadpool, + total_chunks, + [&](ptrdiff_t current_chunk) { + // Decompose linear chunk index into 2D coordinates + const size_t ith0 = current_chunk % nchunk0; // Chunk in dimension 0 (output rows) + const size_t ith1 = current_chunk / nchunk0; // Chunk in dimension 1 (batch) + + // Calculate ranges for this chunk + const size_t ir0_start = ChunkSize0 * ith0; + const size_t ir0_end = std::min(ir0_start + ChunkSize0, OutputRows); + + const size_t ir1_start = ChunkSize1 * ith1; + const size_t ir1_end = std::min(ir1_start + ChunkSize1, OutputCols); + + // Process all tiles in dimension 0 for this chunk + for (size_t ichunk0 = ir0_start / ChunkSize0; ichunk0 < ir0_end / ChunkSize0; ichunk0++) { + // Calculate weight offsets + const size_t w_offset = ichunk0 * w_chunk_size; + const size_t scales_offset = ichunk0 * scales_size_per_tile; + + // Process all batch items in this chunk + for (size_t ine11 = ir1_start; ine11 < ir1_end; ine11++) { + // Calculate LUT offsets for this batch item + const size_t qlut_offset = K * ine11 * 4; + const size_t lut_scales_offset = lut_scales_size * 
ine11; + + // Calculate output offset + const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0; + + // Call the dispatch function to compute this tile + Dispatch->ComputeGemm( + const_cast(reinterpret_cast(packed_weights + w_offset)), // Weight tile + QuantBScale + scales_offset, // Weight scales for this tile + const_cast(reinterpret_cast(lut_i8 + qlut_offset)), // LUT for this batch row + scales_float.get() + lut_scales_offset, // LUT scales + biases_float.get() + lut_scales_offset, // LUT biases + reinterpret_cast(C) + dst_offset, // Output location + static_cast(bm), // bm + static_cast(K), // K dimension + static_cast(M), // K dimension + static_cast(N), // N dimension (batch size) + BlkLen // Weight quantization group size + ); + } + } + } + ); +} \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h index bc642955dfce3..277c56feac9b0 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.h +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -31,12 +31,26 @@ typedef void(MLAS_QNBIT_GEMM_LUT_GEN)( int32_t group_size, int8_t* lut, - float* b, + const float* b, float* scales, float* biases, int K ); +typedef +void(MLAS_QNBIT_LUT_GEMM_COMPUTE)( + const void* A, + const void* a_scales, + const void* LUT, + const void* LUT_Scales, + const void* LUT_Biases, + void* C, + int bm, + int K, + int M, // batch size (number of rows in activation) + int N, + size_t BlkLen +); // // Kernel dispatch structure. @@ -48,4 +62,6 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH { // Intentionally empty placeholder; add members as needed. MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr; + MLAS_QNBIT_LUT_GEMM_COMPUTE* ComputeGemm = nullptr; + }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 241bd47ebad2a..8c716c39a2eb6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -15,13 +15,16 @@ Module Name: --*/ -#include "qnbitgemm.h" -#include "qlutgemm.h" -#include "sqnbitgemm_q8_block.h" +#include +#include #include // AVX2 intrinsics #include +#include "qnbitgemm.h" +#include "qlutgemm.h" +#include "sqnbitgemm_q8_block.h" + static inline float _mm256_addv_ps(const __m256 v) { __m128 res = _mm256_extractf128_ps(v, 1); res = _mm_add_ps(res, _mm256_castps256_ps128(v)); @@ -37,6 +40,135 @@ static inline float _mm256_addv_ps(const __m256 v) { #define PRAGMA_UNROLL #endif +// Helper macros for extracting and widening vectors +#define extract_low_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_castsi256_si128(v)) +#define extract_high_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_extracti128_si256(v, 1)) +#define extract_low_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_castsi256_si128(v)) +#define extract_high_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_extracti128_si256(v, 1)) + + +// Template classes for accumulation +template +struct SignedHalvingAdder { + SignedHalvingAdder adder; + __m256i lhs = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) { + if (k < N / 2) { + adder.push(v, k); + if (k == N / 2 - 1) { + lhs = adder.get(); + } + } else { + adder.push(v, k - N / 2); + if (k == N - 1) { + lhs = _mm256_avg_epu8(lhs, adder.get()); + } + } + } + + inline __m256i get() { + return lhs; + } + + inline __m256i get_low() { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() { + return extract_high_epi8_epi16(lhs); + } +}; + +template <> +struct 
SignedHalvingAdder<2> { + __m256i lhs = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) { + if (k == 0) { + lhs = v; + } else { + lhs = _mm256_avg_epu8(lhs, v); + } + } + + inline __m256i get() { + return lhs; + } + + inline __m256i get_low() { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() { + return extract_high_epi8_epi16(lhs); + } +}; + +template +struct SignedWideningAdder { + __m256i lhs_low = _mm256_setzero_si256(); + __m256i lhs_high = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) { + if (k == 0) { + lhs_low = extract_low_epi8_epi16(v); + lhs_high = extract_high_epi8_epi16(v); + } else { + lhs_low = _mm256_add_epi16(lhs_low, extract_low_epi8_epi16(v)); + lhs_high = _mm256_add_epi16(lhs_high, extract_high_epi8_epi16(v)); + } + } + + inline __m256i get_low() { + return lhs_low; + } + + inline __m256i get_high() { + return lhs_high; + } +}; + +template +using SignedAdder = typename std::conditional, SignedWideningAdder>::type; + +// Template for computing log2 at compile time +template +struct mylog2 { + enum { + value = 1 + mylog2::value + }; +}; + +template <> +struct mylog2<0> { + enum { + value = -1 + }; +}; + +// Template for computing bias scale at compile time +template +constexpr int get_bias_scale() { + // The bias scale will be added to the first bit + // 15 = (1/2 + 1 + 2 + 4) / (1/2) + // 7 = (1/2 + 1 + 2) / (1/2) + // 3 = (1/2 + 1) / (1/2) + // 1 = (1/2) / (1/2) + // if constexpr (bits == 4) { + // return 15; + // } else if constexpr (bits == 3) { + // return 7; + // } else if constexpr (bits == 2) { + // return 3; + // } else if constexpr (bits == 1) { + // return 1; + // } else { + // return 0; + // } + return 3; +} + size_t Q2BitGemmPackQuantBDataSize( size_t N, @@ -91,8 +223,6 @@ void SQ2BitGemmPackQuantBData( uint8_t * buf = new uint8_t[N * bits * (K / g)]; memset(buf, 0, N * bits * (K / g)); - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(bits, BlkLen); // BlkLen/4 bytes const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed MlasTrySimpleParallel( @@ -218,138 +348,7 @@ Q2BitGemmPerGemmWorkspaceSize( } } -// pass in LUT for -size_t -SQ2BitGemmKernel_CompInt8_avx2( - size_t BlkLen, // group - const std::byte* QuantA, - const std::byte* QuantBData, // we pass in the LUT here - const float* QuantBScale, // LUT scales - const std::byte* QuantBZeroPoint, // LUT zero points - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t /*BlockCountK*/, // number of k blocks of length blklen?? - size_t /*ldc*/, // leading dimension for c (unused for CountN==1 path) - const float* /*Bias*/ // bias per output col for c -) -{ - // Implement qgemm_lut_int8_g4 (AVX2 path) for Bits=2, g=4, ActK=16, CountN == 1, K % 16 == 0. - // Notes: - // - This uses the same A/LUT/scales/biases layout assumptions as tmac's tbl.cpp AVX2 path. - // - C is updated in the same lane order as tmac (tile-local contiguous), which is fine for CountN==1. - - constexpr int Bits = 2; - constexpr int ActK = 16; - MLAS_UNREFERENCED_PARAMETER(BlkLen); - - // Preconditions we support in this initial implementation. 
- if (CountN != 1 || (CountK % ActK) != 0) { - return 0; // not handled - } - - const uint8_t* a = reinterpret_cast(QuantA); - const int8_t* lut = reinterpret_cast(QuantBData); - const float* lut_scales = QuantBScale; // one per kk-chunk (ActK) - const float* lut_biases = reinterpret_cast(QuantBZeroPoint); // one per kk-chunk (ActK) - float* c = C; - - // Process rows in groups of 32 as in tmac AVX2 path (i iterates 16 over m/2). - size_t rows_handled = (CountM / 32) * 32; - if (rows_handled == 0) { - return 0; - } - - const __m128i vec_mask = _mm_set1_epi8(0x0f); - - for (size_t i = 0; i < rows_handled / 2; i += 16) { - __m256 vec_c0{}, vec_c1{}, vec_c2{}, vec_c3{}; - bool c_initialized = false; - float partial_sum = -0.0f; - - for (size_t kk = 0; kk < CountK; kk += ActK) { - // Accumulators for this kk-chunk: sum 16 int8 lookups across ActK into 4x8 lanes - __m128i acc_lo_low = _mm_setzero_si128(); - __m128i acc_lo_high = _mm_setzero_si128(); - __m128i acc_hi_low = _mm_setzero_si128(); - __m128i acc_hi_high = _mm_setzero_si128(); - - for (int k = 0; k < ActK; ++k) { - // Load 16 LUT entries for this k (indices 0..15) - const __m128i vec_lut_k = _mm_loadu_si128(reinterpret_cast(lut + (kk + k) * 16)); - // Load 16 selector bytes for bottom/top nibbles from A for this (i-block, k) - const __m128i vec_as = _mm_loadu_si128(reinterpret_cast(a + i * CountK + (kk + k) * 16)); - const __m128i vec_a_bot = _mm_and_si128(vec_as, vec_mask); - const __m128i vec_a_top = _mm_and_si128(_mm_srli_epi16(vec_as, 4), vec_mask); - - // Shuffle-gather from LUT using bottom and top nibble indices - const __m256i vec_lut_dup = _mm256_set_m128i(vec_lut_k, vec_lut_k); - const __m256i vec_a_bt = _mm256_set_m128i(vec_a_top, vec_a_bot); - const __m256i vec_v = _mm256_shuffle_epi8(vec_lut_dup, vec_a_bt); // 32 int8 results - - // Split to 2x16 and sign-extend to int16 - const __m128i v_bot8 = _mm256_castsi256_si128(vec_v); - const __m128i v_top8 = _mm256_extracti128_si256(vec_v, 1); - - const __m256i vb16 = _mm256_cvtepi8_epi16(v_bot8); - const __m256i vt16 = _mm256_cvtepi8_epi16(v_top8); - - const __m128i vb16_low = _mm256_castsi256_si128(vb16); - const __m128i vb16_high = _mm256_extracti128_si256(vb16, 1); - const __m128i vt16_low = _mm256_castsi256_si128(vt16); - const __m128i vt16_high = _mm256_extracti128_si256(vt16, 1); - - acc_lo_low = _mm_add_epi16(acc_lo_low, vb16_low); - acc_lo_high = _mm_add_epi16(acc_lo_high, vb16_high); - acc_hi_low = _mm_add_epi16(acc_hi_low, vt16_low); - acc_hi_high = _mm_add_epi16(acc_hi_high, vt16_high); - } - - // Convert to float vectors (4 groups of 8) - const __m256 vec_v_low_low = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_lo_low)); - const __m256 vec_v_low_high = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_lo_high)); - const __m256 vec_v_high_low = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_hi_low)); - const __m256 vec_v_high_high = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(acc_hi_high)); - - float lut_s = lut_scales[kk / ActK]; - float lut_b = lut_biases ? 
lut_biases[kk / ActK] : 0.0f; - partial_sum += lut_b; - - // Apply per-bit-group bias pattern: add bias only when (ib % Bits == 0) - auto fma_with_bias = [&](const __m256& vs, size_t ib) { - if ((ib % Bits) == 0) { - return _mm256_fmadd_ps(vs, _mm256_set1_ps(lut_s), _mm256_set1_ps(lut_b)); - } else { - return _mm256_mul_ps(vs, _mm256_set1_ps(lut_s)); - } - }; - - if (!c_initialized) { - vec_c0 = fma_with_bias(vec_v_low_low, (i / 4)); - vec_c1 = fma_with_bias(vec_v_low_high, (i / 4 + 1)); - vec_c2 = fma_with_bias(vec_v_high_low, (i / 4 + 2)); - vec_c3 = fma_with_bias(vec_v_high_high, (i / 4 + 3)); - c_initialized = true; - } else { - vec_c0 = _mm256_add_ps(vec_c0, fma_with_bias(vec_v_low_low, (i / 4))); - vec_c1 = _mm256_add_ps(vec_c1, fma_with_bias(vec_v_low_high, (i / 4 + 1))); - vec_c2 = _mm256_add_ps(vec_c2, fma_with_bias(vec_v_high_low, (i / 4 + 2))); - vec_c3 = _mm256_add_ps(vec_c3, fma_with_bias(vec_v_high_high, (i / 4 + 3))); - } - } // kk - - // Store back to C in tmac lane order: 8 floats x 4 groups - _mm256_storeu_ps(c + i * 2, vec_c0); - _mm256_storeu_ps(c + i * 2 + 8, vec_c1); - _mm256_storeu_ps(c + i * 2 + 16, vec_c2); - _mm256_storeu_ps(c + i * 2 + 24, vec_c3); - } - - return rows_handled; -} - -void partial_max_g4_int8_k8(float* lut_scales, float* b) { +void partial_max_g4_int8_k8(float* lut_scales, const float* b) { const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1); __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1); @@ -371,7 +370,7 @@ void partial_max_g4_int8_k8(float* lut_scales, float* b) { void lut_ctor_g4_int8_impl( int32_t group_size, int8_t* qlut, - float* b, + const float* b, float* lut_scales, float* lut_biases ) { @@ -410,9 +409,7 @@ PRAGMA_UNROLL } PRAGMA_UNROLL for (int g = 0; g < 16; g += 2) { - //vec_lut[g] = -vec_lut[15 - g]; - const __m256 neg_mask = _mm256_set1_ps(-0.0f); // all lanes have sign bit set - vec_lut[g] = _mm256_xor_ps(vec_lut[15 - g], neg_mask); + vec_lut[g] = -vec_lut[15 - g]; } biases += _mm256_addv_ps(vec_lut[0]); @@ -485,7 +482,7 @@ void GenerateLUT_avx2( int32_t group_size, int8_t* lut, - float* b, + const float* b, float* scales, float* biases, int K @@ -506,7 +503,288 @@ GenerateLUT_avx2( } -// try adding this back in: +inline void tbl_g4_int8_float_gather_bit2_impl(int32_t m, float* C_global, float* CBits, float* C) { + constexpr int32_t bits = 2; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + PRAGMA_UNROLL + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8; + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (float)5.000000e-01f) + + (CBits[cse_var_2 + bit_offset_1]); + } + } + + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + PRAGMA_UNROLL + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } +} + +// When FastAggregation is enabled, FastAggregationK = ActK +// zero_points is merged into scales to maintain API +template +inline int32_t tbl_g4_int8_float_update_impl(int32_t m, float* c, const int8_t* lut, const uint8_t* a, const float* scales, const float* lut_scales, const float* lut_biases) { + const __m128i vec_mask 
= _mm_set1_epi8(0x0f); + __m128i vec_lut[K]; + + PRAGMA_UNROLL + for (int k = 0; k < K; k++) { + vec_lut[k] = _mm_loadu_si128(reinterpret_cast(lut + k * 16)); + } + + SignedAdder adder; + for (int i = 0; i < m / 2; i += 16) { + __m256 vec_c0, vec_c1, vec_c2, vec_c3; + + float partial_sum = -0.0f; + PRAGMA_UNROLL + for (int kk = 0; kk < K; kk += ActK) { + PRAGMA_UNROLL + for (int k = 0; k < ActK; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + __m128i vec_as = _mm_loadu_si128(reinterpret_cast(a + i * K + (kk + k) * 16)); + __m128i vec_a_bot = _mm_and_si128(vec_as, vec_mask); + __m128i vec_a_top = _mm_and_si128(_mm_srli_epi16(vec_as, 4), vec_mask); + + __m256i vec_lut_ = _mm256_set_m128i(vec_lut[kk + k], vec_lut[kk + k]); + __m256i vec_a = _mm256_set_m128i(vec_a_top, vec_a_bot); + __m256i vec_v = _mm256_shuffle_epi8(vec_lut_, vec_a); + adder.push(vec_v, k); + } + + __m256 vec_v_low_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_low())); + __m256 vec_v_low_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_low())); + __m256 vec_v_high_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_high())); + __m256 vec_v_high_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_high())); + + float lut_s = lut_scales[kk / ActK]; + float lut_b = lut_biases[kk / ActK]; + + partial_sum += lut_b; + + if (FastAggregation) { + lut_s = lut_s * ActK; + lut_b -= lut_s * (mylog2::value / 4 * get_bias_scale()); + } + +#define lut_fma(vs, ib) \ + ((ib) % Bits) ? (_mm256_mul_ps((vs), _mm256_set1_ps(lut_s))) \ + : (_mm256_fmadd_ps((vs), _mm256_set1_ps(lut_s), _mm256_set1_ps(lut_b))) + if (kk == 0) { + vec_c0 = lut_fma(vec_v_low_low, (i / 4 )); + vec_c1 = lut_fma(vec_v_low_high, (i / 4 + 1)); + vec_c2 = lut_fma(vec_v_high_low, (i / 4 + 2)); + vec_c3 = lut_fma(vec_v_high_high, (i / 4 + 3)); + } else { + vec_c0 = _mm256_add_ps(vec_c0, lut_fma(vec_v_low_low, (i / 4 ))); + vec_c1 = _mm256_add_ps(vec_c1, lut_fma(vec_v_low_high, (i / 4 + 1))); + vec_c2 = _mm256_add_ps(vec_c2, lut_fma(vec_v_high_low, (i / 4 + 2))); + vec_c3 = _mm256_add_ps(vec_c3, lut_fma(vec_v_high_high, (i / 4 + 3))); + } +#undef lut_fma + } + + if (ZeroPoint) { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4 ) / Bits) * 16); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16); + vec_c0 = _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2)); + vec_c1 = _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8)); + vec_c2 = _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16)); + vec_c3 = _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24)); + __m256 vec_z0 = _mm256_loadu_ps(scales + ((i / 4 ) / Bits) * 16 + 8); + __m256 vec_z1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16 + 8); + __m256 vec_z2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16 + 8); + __m256 vec_z3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16 + 8); + partial_sum *= 2; +#define add_zero(cs, zs, ib) \ + ((ib) % Bits) ? 
((cs)) \ + : (_mm256_fmadd_ps((zs), _mm256_set1_ps(partial_sum), (cs))) + _mm256_storeu_ps(c + i * 2, add_zero(vec_c0, vec_z0, (i / 4 ))); + _mm256_storeu_ps(c + i * 2 + 8, add_zero(vec_c1, vec_z1, (i / 4 + 1))); + _mm256_storeu_ps(c + i * 2 + 16, add_zero(vec_c2, vec_z2, (i / 4 + 2))); + _mm256_storeu_ps(c + i * 2 + 24, add_zero(vec_c3, vec_z3, (i / 4 + 3))); +#undef add_zero + } else if (OneScale) { + float single_scale = scales[0]; + __m256 vec_s = _mm256_set1_ps(single_scale); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s, _mm256_loadu_ps(c + i * 2 + 24))); + } else { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4 ) / Bits) * 8); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 8); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 8); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 8); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24))); + } + } + + return 0; +} + +int32_t tbl_int32_reset(int32_t m, int32_t* c) { + memset(c, 0, m * sizeof(int32_t)); + return 0; +} + +// based on qgemm_lut_int8_g4 +// Simplified version with hardcoded configuration for 2-bit quantization +void TMACComputeGemm_avx2( + const void* A, // Quantized packed weights + const void* Scales, // Weight scales (and optionally zero-points) + const void* LUT, // Pre-computed quantized lookup table + const void* LUT_Scales, // LUT scales from activation quantization + const void* LUT_Biases, // LUT biases from activation quantization + void* C, // Output buffer + int bm, // Bit-rows tile size (typically 512 for 2-bit) + int K, + int m, + int N, + size_t BlkLen // Weight quantization group size (q_group_size) + ) { + // Validate batch size + if (N != 1) { + throw std::runtime_error("N > 1 is not supported yet"); + } + + // ==================== CONFIGURATION ==================== + // Fixed parameters for this kernel implementation + bool has_zero_point = true; // Whether weights have zero-points (interleaved with scales) + bool one_scale = false; // Whether using single global scale for all weights + constexpr int bits = 2; // 2-bit quantization + constexpr int g = 4; // Packing group size + constexpr int ngroups_per_elem = 2; // 8 / g = 2 + constexpr int kfactor = 16; // K-dimension blocking factor + constexpr bool has_scale = true; // Always use weight scales + + // Parameters derived from inputs + const int q_group_size = static_cast(BlkLen); // Weight quant group size + const int act_group_size = static_cast(BlkLen); // Activation group size (same as weight) + const int actk = act_group_size / g; // CRITICAL: = 16 for BlkLen=64, NOT BlkLen! 
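+    // Worked example (illustrative): with BlkLen = 64 and g = 4, act_group_size = 64 and
+    // actk = 64 / 4 = 16, which selects the actk == 16 kernel variants below; with
+    // BlkLen = 32, actk = 8 and the actk == 8 variants are chosen instead.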
+ + // Validate configuration + assert(bm % bits == 0); + assert(K % (kfactor * g) == 0); + assert(BlkLen % g == 0); + + // Validate configuration + assert(bm % bits == 0); + assert(K % (kfactor * g) == 0); + assert(BlkLen % g == 0); + + // ==================== ALLOCATE BUFFERS ==================== + // Use float for now (can be changed to _Float16 if needed) + + float* CBits = new float[bm]; + float* C_global = new float[m]; + + // Reset accumulator buffer to zero + tbl_int32_reset(bm * sizeof(float) / sizeof(int32_t), + reinterpret_cast(CBits)); + + // ==================== CALCULATE LOOP PARAMETERS ==================== + const int32_t k_outer_max = K / (kfactor * g); + const int32_t scale_gs = q_group_size / (kfactor * g); + + // Calculate bit shift for scale indexing + int32_t scale_idx_shfr = 0; + if (scale_gs == 1) { + scale_idx_shfr = 0; + } else if (scale_gs == 2) { + scale_idx_shfr = 1; + } else if (scale_gs == 4) { + scale_idx_shfr = 2; + } else if (scale_gs == 8) { + scale_idx_shfr = 3; + } else { + fprintf(stderr, "q_group_size=%d, kfactor=%d, g=%d\n", q_group_size, kfactor, g); + fprintf(stderr, "Unsupported scale group size over kfactor. Expected {1,2,4,8}, got %d.\n", scale_gs); + throw std::runtime_error("Invalid scale group size configuration"); + } + + // ==================== MAIN COMPUTATION LOOP ==================== + for (int32_t k_outer = 0; k_outer < k_outer_max; k_outer++) { + // Calculate pointers for this K-outer iteration + const uint8_t* a = reinterpret_cast(A) + k_outer * bm * kfactor / ngroups_per_elem; + + // Calculate scales pointer based on configuration + const float* scales = one_scale ? + reinterpret_cast(Scales) : // Single global scale + (has_zero_point ? + reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m * 2 : // Scale + zero_point pairs + reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m); // Scales only + + // Calculate LUT pointers + const int8_t* lut = reinterpret_cast(LUT) + k_outer * kfactor * (1 << g); // 2^g = 16 for g=4 + const float* lut_scales = reinterpret_cast(LUT_Scales) + + (k_outer * kfactor * g / act_group_size); + const float* lut_biases = reinterpret_cast(LUT_Biases) + + (k_outer * kfactor * g / act_group_size); + + // Select appropriate kernel template based on configuration + // For standard 2-bit, kfactor=16, BlkLen=64: actk = 64/4 = 16 + if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } + // actk == 8 variants (for BlkLen=32) + else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), 
CBits, lut, a, scales, lut_scales, lut_biases); + } + // kfactor == 8 variants + else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases); + } else { + // No matching kernel template found + // ORT_THROW("No matching kernel: has_scale=%d, kfactor=%d, bits=%d, actk=%d, has_zero_point=%d, one_scale=%d\n", + // has_scale, kfactor, bits, actk, has_zero_point, one_scale); + ORT_THROW("Reached else case"); + } + } + + // ==================== GATHER RESULTS ==================== + // Gather bit-plane results into final output + // Only support 2-bit in this implementation + tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, reinterpret_cast(C)); + + // ==================== CLEANUP ==================== + delete[] C_global; + delete[] CBits; +} void QuantizeARow_CompInt8( @@ -515,13 +793,16 @@ QuantizeARow_CompInt8( size_t /*CountK*/, std::byte* /*QuantA*/ ) { - // Not implemented yet. + // placeholder so that dispatch doesn't break + // TODO: figure out a way that we can omit this altogether } + // Kernel dispatch structure definition. const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLUTGenKernelAvx2 = []() { MLAS_QNBIT_LUT_GEMM_DISPATCH d; d.GenerateLUT = GenerateLUT_avx2; + d.ComputeGemm = TMACComputeGemm_avx2; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h index 5e8aefb792265..a12abc76acd3d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h @@ -28,25 +28,27 @@ Q2BitGemmPerGemmWorkspaceSize( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); -size_t -SQ2BitGemmKernel_CompInt8_avx2( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float* Bias +void +GenerateLUT_avx2( + int32_t group_size, + int8_t lut, + const float* b, + float* scales, + float* biases, + int K ); -void QuantizeARow_CompInt8( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA -); +void +TMACComputeGemm_avx2( + const void* A, + const void* a_scales, + const void* LUT, + const void* LUT_Scales, + const void* LUT_Biases, + void* C, + int bm, + int K, + int M, + int N, + size_t BlkLen +); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 6d5d39abd039c..ef2c52c0d219d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1463,9 +1463,6 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = 
QuantizeARow_CompInt8; - return d; }(); @@ -1491,8 +1488,5 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; - return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 51f1ef79c4898..02d4092b411f4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -495,8 +495,5 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; - return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 7ff1000d267f9..c24d7ffacaa0c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -480,6 +480,5 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; - d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; return d; }(); From f9a9b47d93c6639639e8dc675046e197a4aa7920 Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Thu, 16 Oct 2025 19:49:12 +0000 Subject: [PATCH 25/33] bug fixes --- .../cpu/quantization/matmul_nbits.cc | 38 ++++++++++++++++--- onnxruntime/core/mlas/lib/qlutgemm.cpp | 23 +++++------ .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 2 +- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 2eb322abab94f..af56e4cdfae36 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -16,6 +16,8 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" +#include "core/platform/threadpool.h" +#include "core/util/thread_utils.h" namespace onnxruntime { namespace contrib { @@ -182,15 +184,39 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); is_packed = false; - // if (has_g_idx_ || has_unquantized_zero_point_) { + // if (has_g_idx_ || has_unquantized_zero_point_) // TODO: this part modified so i can test ek atmulnbits if (has_g_idx_) { return Status::OK(); } - if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && compute_type_ != TMAC) { + return Status::OK(); + } + if (compute_type_ == TMAC && !MlasIsTMACAvailable(nbits_, block_size_)) { return Status::OK(); } + + // Create a temporary threadpool for parallel packing + // This is used during model load time to speed up weight prepacking + std::unique_ptr temp_threadpool; + concurrency::ThreadPool* 
threadpool_ptr = nullptr; + + // Only create threadpool for operations that can benefit from it + if (compute_type_ == TMAC || compute_type_ == SQNBIT_CompInt8) { + OrtThreadPoolParams tpo; + tpo.thread_pool_size = 4; // Use default (typically number of cores) + tpo.allow_spinning = false; // Don't spin during model load + tpo.auto_set_affinity = false; + + temp_threadpool = concurrency::CreateThreadPool( + &Env::Default(), + tpo, + concurrency::ThreadPoolType::INTRA_OP); + + threadpool_ptr = temp_threadpool.get(); + } + if (input_idx == InputIndex::B) { const Tensor* scales = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); @@ -203,7 +229,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All auto scale_ptr = scales ? scales->DataRaw() : nullptr; packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, - has_zp_input_, nullptr, nullptr); + has_zp_input_, nullptr, threadpool_ptr); is_packed = true; } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 @@ -239,10 +265,12 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } else { packed_scales_zp_size_ = N_ * K_ / block_size_; packed_scales_zp_ = IAllocator::MakeUniquePtr(alloc, packed_scales_zp_size_, true); - MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_,has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr); + MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr); } } } + + // Threadpool will be automatically destroyed when temp_threadpool goes out of scope return Status::OK(); } @@ -800,7 +828,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. 
- if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) || (compute_type_ == TMAC && MlasIsTMACAvailable(nbits_, block_size_))) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } } diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 068c9fa371327..1aed867c1841b 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -147,7 +147,7 @@ void MlasTMACPackScalesAndZeroPoints( bool MLASCALL MlasIsTMACAvailable( size_t /*BlkBitWidth*/, size_t /*BlkLen*/ -) +) // TODO: fix the below to use smthg besides the gen kernel { const auto* Dispatch = GetMlasPlatform().LUTGenKernel; return Dispatch != nullptr; @@ -176,16 +176,17 @@ void MLASCALL MlasTmac( } const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); - size_t lut_size = CalculateLUTSize(K, M, tmac_params.bits); + size_t lut_size = CalculateLUTSize(K, M, tmac_params.g); auto lut_buffer = std::make_unique(lut_size); - const size_t lut_meta_size = (K / BlkLen) * M; + const size_t lut_scales_size_meta = 64; + const size_t lut_meta_size = 64 * M * tmac_params.g; // TODO: 64 should be stored as lut_scales_size auto biases_float = std::make_unique(lut_meta_size); auto scales_float = std::make_unique(lut_meta_size); const auto* a_float = reinterpret_cast(A); // Activation data - const int num_groups = static_cast(K / BlkLen); + // const int num_groups = static_cast(K / BlkLen); // Parallelize over M (batch dimension) // Each iteration processes one row of the activation matrix @@ -195,7 +196,7 @@ void MLASCALL MlasTmac( [&](ptrdiff_t ine11) { const size_t row_offset = static_cast(ine11) * K; const size_t lut_offset = static_cast(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT - const size_t scale_bias_offset = static_cast(ine11) * num_groups; + const size_t scale_bias_offset = static_cast(ine11) * lut_scales_size_meta; // Call the dispatch function for this row Dispatch->GenerateLUT( @@ -216,11 +217,11 @@ void MLASCALL MlasTmac( // TODO: fix the below 4 // Matrix multiplication: Output[N×M] = QuantBData[N×K] × Weights[K×M] - const size_t OutputRows = M; // Number of output features - const size_t OutputCols = N; // Batch size - const size_t NumTiles = M * bits / bm; + const size_t OutputRows = N; // Number of output features + const size_t OutputCols = M; // Batch size + const size_t NumTiles = 8; // hardcoding -- TODO: should be moved to tmac kernel config - const size_t ChunkSize0 = M / NumTiles; + const size_t ChunkSize0 = N / NumTiles; const size_t ChunkSize1 = tmac_params.chunk_n; // process one batch item at a time // In llama.cpp terminology (note the swap!): @@ -233,7 +234,7 @@ void MLASCALL MlasTmac( const size_t total_chunks = nchunk0 * nchunk1; // Pre-calculate sizes for offset calculations - const size_t w_size = OutputRows * K * bits / 8; + const size_t w_size = N * K * bits / 8; const size_t w_chunk_size = w_size / NumTiles; // Determine weight-scale layout. These should be provided by the caller or inferred from the packed weights. 
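For reference, a minimal sketch of the per-tile bookkeeping computed around this hunk, assuming per-group quantization rather than the one_scale case (names are illustrative; out_features denotes the dimension the packed weights are tiled over, and both quantities are assumed to divide evenly across tiles, as the surrounding code warns when they do not):

    #include <cstddef>

    struct TmacTileSizes {
        size_t w_chunk_size;     // bytes of packed 2-bit weights owned by one tile
        size_t scales_per_tile;  // float entries (scale, or scale + zero-point) per tile
    };

    static TmacTileSizes ComputeTmacTileSizes(size_t out_features, size_t K, size_t bits,
                                              size_t blk_len, size_t n_tiles, bool has_zero_point) {
        const size_t w_size = out_features * K * bits / 8;              // total packed weight bytes
        const size_t groups = out_features * K / blk_len;               // quantization groups
        const size_t scales_total = groups * (has_zero_point ? 2 : 1);  // zero points stored next to scales
        return TmacTileSizes{w_size / n_tiles, scales_total / n_tiles};
    }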
@@ -315,7 +316,7 @@ void MLASCALL MlasTmac( static_cast(bm), // bm static_cast(K), // K dimension static_cast(M), // K dimension - static_cast(N), // N dimension (batch size) + 1, BlkLen // Weight quantization group size ); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 8c716c39a2eb6..c85456551bb92 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -656,7 +656,7 @@ void TMACComputeGemm_avx2( int N, size_t BlkLen // Weight quantization group size (q_group_size) ) { - // Validate batch size + // // Validate batch size if (N != 1) { throw std::runtime_error("N > 1 is not supported yet"); } From 5687e5e66c2b753e2a9b18975aa89efd173eb138 Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 21 Oct 2025 13:37:14 -0700 Subject: [PATCH 26/33] Fix bug in scale unpacking --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 1aed867c1841b..cf5d52307ada4 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -111,7 +111,7 @@ void MlasTMACPackScalesAndZeroPoints( if (HasZeroPoint) { // zp are two bit packed size_t elem_idx = idx % num_elem_per_byte; - uint8_t v = (QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits)) - 1; + uint8_t v = QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits) - 1; zp = static_cast(v); // Note: TMAC does this during model conversion. Since, we follow ORT format, we need to do it here. From 6f08418639b0f4a97ad251eaf90036c8b2937d39 Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 28 Oct 2025 16:53:45 -0700 Subject: [PATCH 27/33] Fix issues with TMAC GEMM kernels and remove hard coded variables --- .../cpu/quantization/matmul_nbits.cc | 26 ++- onnxruntime/core/mlas/inc/mlas_qnbit.h | 28 ++- onnxruntime/core/mlas/lib/qlutgemm.cpp | 207 +++++++++++++----- onnxruntime/core/mlas/lib/qlutgemm.h | 35 +-- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 6 +- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 98 +++++---- .../test/mlas/unittest/test_sqlutgemm.cpp | 45 ++++ .../test/mlas/unittest/test_sqnbitgemm.cpp | 60 +++-- 8 files changed, 358 insertions(+), 147 deletions(-) create mode 100644 onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index af56e4cdfae36..51ed04da617af 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -201,21 +201,24 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // This is used during model load time to speed up weight prepacking std::unique_ptr temp_threadpool; concurrency::ThreadPool* threadpool_ptr = nullptr; - + // Only create threadpool for operations that can benefit from it if (compute_type_ == TMAC || compute_type_ == SQNBIT_CompInt8) { OrtThreadPoolParams tpo; tpo.thread_pool_size = 4; // Use default (typically number of cores) tpo.allow_spinning = false; // Don't spin during model load tpo.auto_set_affinity = false; - + temp_threadpool = concurrency::CreateThreadPool( &Env::Default(), tpo, concurrency::ThreadPoolType::INTRA_OP); - + threadpool_ptr = temp_threadpool.get(); } + if (compute_type_ == TMAC) { + 
InitTMACKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); + } if (input_idx == InputIndex::B) { const Tensor* scales = nullptr; @@ -254,22 +257,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } else if (compute_type_ == TMAC) { if (input_idx == InputIndex::scales && packed_b_ != nullptr) { auto scales_ptr = tensor.Data(); + packed_scales_zp_size_ = MlasTMACPackQuantScalesAndZeroPointsSize(N_, K_, block_size_, has_zp_input_); + packed_scales_zp_ = IAllocator::MakeUniquePtr(alloc, packed_scales_zp_size_, true); + + // TODO(vraspar): improve this logic block if (has_zp_input_) { const Tensor* zero_points = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); auto zero_points_ptr = zero_points->Data(); - - packed_scales_zp_size_ = N_ * K_ / block_size_ * 2; - packed_scales_zp_ = IAllocator::MakeUniquePtr(alloc, packed_scales_zp_size_, true); MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, zero_points_ptr); } else { - packed_scales_zp_size_ = N_ * K_ / block_size_; - packed_scales_zp_ = IAllocator::MakeUniquePtr(alloc, packed_scales_zp_size_, true); MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr); } - } + } } - + // Threadpool will be automatically destroyed when temp_threadpool goes out of scope return Status::OK(); @@ -384,8 +386,10 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, // TODO: add the logic for generating lookup table here -- for now we can assume that // 2 bits + acc level 4 = use look up table but in the future adapt so that we use a mamtulnbits attr to decide // if we want to do lut generation + + // TODO(vraspar): Should we batch it here? if (compute_type_ == TMAC) { - MlasTmac(a_data, block_size_, packed_b_.get(), scales_data, y_data, K, M, N, thread_pool); + MlasTmac(a_data, block_size_, packed_b_.get(), packed_scales_zp_.get(), y_data, K, M, N, thread_pool); return Status::OK(); } const size_t lda = helper.Lda(false); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 2dd73f49fc9c8..c83cb5b30bb3b 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -223,6 +223,18 @@ MlasQNBitGemmScalesPacked( bool HasZeroPoint ); +/** + * @brief Gets the size in float of the packed quantized B scales and zero points. + */ + +size_t MLASCALL +MlasTMACPackQuantScalesAndZeroPointsSize( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint +); + /** * @brief Packs the scales and zero points into a format that the TMAC kernel expects. 
*/ @@ -252,6 +264,14 @@ MlasIsTMACAvailable( size_t BlkLen ); +void MLASCALL +InitTMACKernelConfig( + size_t M, + size_t N, + size_t nbits, + size_t block_size, + bool has_zp_point); + /** * @brief Executes TMAC compute * @@ -262,11 +282,11 @@ void MLASCALL MlasTmac( const void* A, size_t BlkLen, - const void* QuantBData, - const float* QuantBScale, + const void* QuantBData, + const float* QuantBScale, void* C, int K, - int M, + int M, int N, MLAS_THREADPOOL* threadpool -); \ No newline at end of file +); diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index cf5d52307ada4..a42e6a030eabd 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -10,16 +10,24 @@ module includes kernel functions for generating LUT for T-MAC GEMM optimization #include "qlutgemm.h" +#include + /** T-MAC GEMM kernel Config */ static std::unordered_map tmac_kernel_configs; - - - -const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits, size_t block_size) { +const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits) { std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits); if (tmac_kernel_configs.count(key)) { return tmac_kernel_configs[key]; + } else { + ORT_THROW("T-MAC kernel parameters not initialized for M=", M, ", N=", N, ", nbits=", nbits); + } +} + +void InitTMACKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zp_point) { + std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits); + if (tmac_kernel_configs.count(key)) { + return; } MlasTMACKernelParams params; @@ -53,10 +61,11 @@ const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits std::vector bns = {8, 16, 32, 64}; std::vector kfactors = {8, 16}; - // TODO: add profile based policy + // TODO(vraspar): add profile based policy int threads = std::thread::hardware_concurrency(); float smallest_penalty = 1e9; + params.bm = bms[0]; for (int bm: bms) { if (M % (bm/nbits) != 0 || bm % nbits != 0) { continue; @@ -71,6 +80,7 @@ const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits } size_t largest_kfactor = 0; + params.kfactor = kfactors[0]; for (size_t kfactor: kfactors) { if ((kfactor < params.actk) || (kfactor * params.g > params.q_group_size)) { continue; @@ -81,8 +91,13 @@ const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits } } + params.n_tiles_num = M * params.bits / params.bm; + params.has_scale = true; // TODO(vraspar): TMAC supports only scale for now + params.has_zero_point = has_zp_point; + params.one_scale = false; //TODO(vraspar): support one scale case for bitnet + tmac_kernel_configs[key] = params; - return tmac_kernel_configs[key]; + return; } void MlasTMACPackScalesAndZeroPoints( @@ -96,7 +111,7 @@ void MlasTMACPackScalesAndZeroPoints( const uint8_t* QuantBZeroPoint ) { - const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); + const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, BitWidth); const size_t bits = tmac_params.bits; const size_t simd_n_out = tmac_params.simd_n_out; const size_t bm = tmac_params.bm; @@ -111,6 +126,7 @@ void MlasTMACPackScalesAndZeroPoints( if (HasZeroPoint) { // zp are two bit packed size_t elem_idx = idx % num_elem_per_byte; + // TODO(vraspar): logically correct but not readable uint8_t v = QuantBZeroPoint[idx / num_elem_per_byte] >> (elem_idx * bits) & (1 << bits) - 
1; zp = static_cast(v); @@ -121,11 +137,21 @@ void MlasTMACPackScalesAndZeroPoints( zp = zp * scale; // store scale * zp } + // TODO(vraspar): fix when k < BlkLen and nb1 is 0 size_t nb1 = K / BlkLen; - size_t nb0 = bm / BitWidth * nb1; - size_t new_im = idx / nb0; - size_t new_ibm = (idx % nb0) / nb1; - size_t new_ik = (idx % nb1); + size_t nb0 = bm / bits * nb1; + + size_t new_im, new_ibm, new_ik; + if (nb1 == 0) { + new_im = 0; + new_ibm = 0; + new_ik = 0; + + } else { + new_im = idx / nb0; + new_ibm = (idx % nb0) / nb1; + new_ik = (idx % nb1); + } if (HasZeroPoint) { size_t new_isimd = new_ibm % simd_n_out; @@ -147,15 +173,41 @@ void MlasTMACPackScalesAndZeroPoints( bool MLASCALL MlasIsTMACAvailable( size_t /*BlkBitWidth*/, size_t /*BlkLen*/ -) // TODO: fix the below to use smthg besides the gen kernel +) // TODO(Vraspar): fix the below to use smthg besides the gen kernel, add ComputeGemm { const auto* Dispatch = GetMlasPlatform().LUTGenKernel; return Dispatch != nullptr; // return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now } -size_t CalculateLUTSize(int k, int m, size_t group_size) { - return k * m * group_size; +size_t MLASCALL MlasTMACPackQuantScalesAndZeroPointsSize( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint +) +{ + // TODO(vraspar): support one scale case + if (HasZeroPoint) { + return N * K / BlkLen * 2; + } else { + return N * K / BlkLen; + } +} + +size_t +CalculateLUTBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params) { + constexpr size_t kAllockAligment = 64; + const size_t lut_scales_size = k / tmac_params.act_group_size; + + + size_t wsize = k * m * 4 * sizeof(int8_t); // 4 bytes per k element for 2-bit LUT + wsize += lut_scales_size * m * 2 * sizeof(float); // scales + biases + + wsize = ((wsize - 1) / kAllockAligment + 1) * kAllockAligment; + + // TODO(vrapar): add temp buffer for FP16 + return wsize; } void MLASCALL MlasTmac( @@ -175,14 +227,41 @@ void MLASCALL MlasTmac( ORT_THROW("TMAC not supported in this configuration."); } - const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); - size_t lut_size = CalculateLUTSize(K, M, tmac_params.g); - auto lut_buffer = std::make_unique(lut_size); - const size_t lut_scales_size_meta = 64; - const size_t lut_meta_size = 64 * M * tmac_params.g; // TODO: 64 should be stored as lut_scales_size - auto biases_float = std::make_unique(lut_meta_size); - auto scales_float = std::make_unique(lut_meta_size); + + /** TODO(vraspar): The biases_float and scales float values don't make sense + * FP 16 + * QLUT K(ne10) x M(ne11) x 4 bytes + * Scales: lut_scales_size * M * 2 bytes + * Biases: lut_scales_size * M * 2 bytes + * Needs FP 16 conversion Buffer: max(K, N) * M * 2 bytes + * + * FP 32 + * QLUT K x M x 4 bytes + * Scales: lut_scales_size * M * 4 bytes + * Biases: lut_scales_size * M * 4 bytes + * + * Currently, we only support FP32, add FP16 support later which requires conversion buffer + * + * LUT Buffer for FP32 : K * M * 4 * sizeof(uint8_t) bytes + lut_scale_size * m * 2 * sizeof(float) bytes + allignment + * + */ + + // n_tiles_num = m * bits / bm; + + // TODO(vraspar): support other bitwidths + const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2); + const size_t lut_scales_size = K / tmac_params.act_group_size; + size_t lut_buffer_size = CalculateLUTBufferSize(N, K, M, tmac_params); + + // make buffer of lut_buffer_size bytes + // TODO(vraspar): other way to do it + auto lut_buffer = 
std::make_unique(lut_buffer_size); + + int8_t* qlut = reinterpret_cast(lut_buffer.get()); + float* lut_scales = reinterpret_cast(qlut + K * M * 4); // after lut + float* lut_biases = reinterpret_cast(lut_scales + lut_scales_size * M); // after scales + const auto* a_float = reinterpret_cast(A); // Activation data @@ -190,38 +269,50 @@ void MLASCALL MlasTmac( // Parallelize over M (batch dimension) // Each iteration processes one row of the activation matrix + // TODO(vraspar): Ideally we have to do block parallelism here + MlasTrySimpleParallel( - threadpool, + threadpool, static_cast(M), [&](ptrdiff_t ine11) { const size_t row_offset = static_cast(ine11) * K; const size_t lut_offset = static_cast(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT - const size_t scale_bias_offset = static_cast(ine11) * lut_scales_size_meta; + const size_t scale_bias_offset = static_cast(ine11) * lut_scales_size; // Call the dispatch function for this row + // ggml_tmac_mul_mat_task_init Dispatch->GenerateLUT( - static_cast(BlkLen), - reinterpret_cast(lut_buffer.get()) + lut_offset, // Output LUT for this row - const_cast(a_float + row_offset), // Input activation for this row - scales_float.get() + scale_bias_offset, // Scales for this row - biases_float.get() + scale_bias_offset, // Biases for this row - K + const_cast(a_float + row_offset), // Input activation for this row + qlut + lut_offset, // Output LUT for this row + lut_scales + scale_bias_offset, // Scales for this row + lut_biases + scale_bias_offset, // Biases for this row + M, + N, + K, + tmac_params.act_group_size ); } ); // all relevant LUT's have been generated // equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line + + const size_t n_tiles_num = tmac_params.n_tiles_num; + assert(N % n_tiles_num == 0); + const size_t bm = tmac_params.bm; // TODO: hardcoding for now - const size_t bits = tmac_params.bits; + const size_t bits = tmac_params.bits; + + // Pre-calculate sizes for offset calculations + const size_t w_size = N * K * bits / 8; + const size_t w_chunk_size = w_size / n_tiles_num; // TODO: fix the below 4 // Matrix multiplication: Output[N×M] = QuantBData[N×K] × Weights[K×M] const size_t OutputRows = N; // Number of output features const size_t OutputCols = M; // Batch size - const size_t NumTiles = 8; // hardcoding -- TODO: should be moved to tmac kernel config - const size_t ChunkSize0 = N / NumTiles; + const size_t ChunkSize0 = N / n_tiles_num; const size_t ChunkSize1 = tmac_params.chunk_n; // process one batch item at a time // In llama.cpp terminology (note the swap!): @@ -233,46 +324,41 @@ void MLASCALL MlasTmac( const size_t nchunk1 = (OutputCols + ChunkSize1 - 1) / ChunkSize1; const size_t total_chunks = nchunk0 * nchunk1; - // Pre-calculate sizes for offset calculations - const size_t w_size = N * K * bits / 8; - const size_t w_chunk_size = w_size / NumTiles; + // TODO(vraspar): support one_scale case // Determine weight-scale layout. These should be provided by the caller or inferred from the packed weights. // For now we default to per-group symmetric quantization (no zero-point, not one-scale). 
- bool one_scale = false; // TODO: expose this as a function parameter if needed - bool has_zero_point = false; // TODO: expose this as a function parameter if needed // Total number of scale (float) entries for the whole weight matrix: // - if one_scale: single global scale (1) // - otherwise: number of quantization groups = (M * K / BlkLen) // and if zero-points are present each group stores (scale, zero_point) -> *2 const size_t groups_total = static_cast(M) * static_cast(K) / BlkLen; - const size_t scales_size_total = one_scale ? 1 : (groups_total * (has_zero_point ? 2 : 1)); + const size_t scales_size_total = MlasTMACPackQuantScalesAndZeroPointsSize( + static_cast(N), + static_cast(K), + BlkLen, + tmac_params.has_zero_point + ); - // n_tile_num == NumTiles (number of M tiles) - const size_t n_tile_num = NumTiles; // Per-tile scales size = total scales size divided evenly across tiles. // If one_scale is true we do not advance the scales pointer per tile, so set per tile size to 0 size_t scales_size_per_tile = 0; - if (!one_scale) { - if (scales_size_total % n_tile_num != 0) { - // Sanity: scales should partition evenly across tiles. If they don't, choose floor division - // and document that callers must layout scales accordingly. - // Prefer to error loudly in debug builds. - fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tile_num=%zu; using floor division.\n", scales_size_total, n_tile_num); - } - scales_size_per_tile = scales_size_total / n_tile_num; + + if (scales_size_total % n_tiles_num != 0) { + // Sanity: scales should partition evenly across tiles. If they don't, choose floor division + // and document that callers must layout scales accordingly. + // Prefer to error loudly in debug builds. + fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tiles_num=%zu; using floor division.\n", scales_size_total, n_tiles_num); } + scales_size_per_tile = scales_size_total / n_tiles_num; + // Note: when one_scale == true, callers should pass a pointer to a single scale value (scales_offset=0 will be used) // Cast to appropriate types const auto* packed_weights = reinterpret_cast(QuantBData); - const int8_t* lut_i8 = reinterpret_cast(lut_buffer.get()); - - // lut_scales_size is the number of scale values per batch item (= K / BlkLen) - const size_t lut_scales_size = static_cast(K) / BlkLen; // Parallelize over the 2D chunk grid MlasTrySimpleParallel( @@ -306,21 +392,22 @@ void MLASCALL MlasTmac( const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0; // Call the dispatch function to compute this tile + // Note M and N are swapped in TMAC terminology + // TODO(vrapsar): fix this M and N swapp mess Dispatch->ComputeGemm( - const_cast(reinterpret_cast(packed_weights + w_offset)), // Weight tile + packed_weights + w_offset, // Weight tile QuantBScale + scales_offset, // Weight scales for this tile - const_cast(reinterpret_cast(lut_i8 + qlut_offset)), // LUT for this batch row - scales_float.get() + lut_scales_offset, // LUT scales - biases_float.get() + lut_scales_offset, // LUT biases + qlut + qlut_offset, // LUT for this batch row + lut_scales + lut_scales_offset, // LUT scales + lut_biases + lut_scales_offset, // LUT biases reinterpret_cast(C) + dst_offset, // Output location - static_cast(bm), // bm static_cast(K), // K dimension - static_cast(M), // K dimension - 1, + static_cast(N), // K dimension + static_cast(M), BlkLen // Weight quantization group size ); } } } ); -} \ No newline at end of file +} diff --git 
a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h index 277c56feac9b0..f2d024e117a1c 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.h +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -23,29 +23,38 @@ struct MlasTMACKernelParams { size_t simd_n_in; size_t simd_n_out; size_t chunk_n; + size_t n_tiles_num; + + + bool has_scale; + bool has_zero_point; + bool one_scale; + + }; -const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits, size_t block_size); +const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits); typedef void(MLAS_QNBIT_GEMM_LUT_GEN)( - int32_t group_size, - int8_t* lut, - const float* b, - float* scales, - float* biases, - int K + const float * b, + int8_t * qlut, + float* lut_scales, + float* lut_biases, + size_t M, + size_t K, + size_t N, + size_t act_group_size ); typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)( - const void* A, - const void* a_scales, - const void* LUT, - const void* LUT_Scales, - const void* LUT_Biases, + const uint8_t* weights, + const float* scales, + const int8_t* LUT, + const float* LUT_Scales, + const float* LUT_Biases, void* C, - int bm, int K, int M, // batch size (number of rows in activation) int N, diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index c9a9e1b7b9ba8..4f26dbed5f49a 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -52,7 +52,7 @@ GetQNBitGemmVariant( if ((BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { if (BlkBitWidth == 2) { if (ComputeType == TMAC) { - return SQNBitGemmVariant_BitWidth2_CompInt8; // TODO: rename this kernel + return SQNBitGemmVariant_BitWidth2_CompInt8; // TODO(vraspar): rename this kernel } } else if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompFp32) { @@ -88,6 +88,10 @@ MlasIsQNBitGemmAvailable( return false; } + if (ComputeType == TMAC) { + return MlasIsTMACAvailable(BlkBitWidth, BlkLen); + } + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index c85456551bb92..a3f7023e6e7fc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -203,13 +203,13 @@ void SQ2BitGemmPackQuantBData( // T-MAC like configuration (approved): // bits=2, g=4, ngroups_per_elem=8/g=2, simd_n_in=16, simd_n_out=8, bm=256, kfactor=16 - const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2, BlkLen); + const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2); const size_t bits = 2; const size_t g = tmac_params.g; const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; const size_t simd_n_in = tmac_params.simd_n_in; const size_t simd_n_out = tmac_params.simd_n_out; - const size_t bm = tmac_params.bm; + const size_t bm = tmac_params.bm; const size_t kfactor = tmac_params.kfactor; // Basic checks @@ -349,6 +349,7 @@ Q2BitGemmPerGemmWorkspaceSize( } void partial_max_g4_int8_k8(float* lut_scales, const float* b) { + // TODO(vraspar): add support for arm neon const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1); __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1); @@ -367,14 +368,13 @@ void partial_max_g4_int8_k8(float* lut_scales, const float* b) { *lut_scales = 
std::max(*lut_scales, scales); } -void lut_ctor_g4_int8_impl( - int32_t group_size, +inline void lut_ctor_g4_int8_impl( + int32_t act_k, int8_t* qlut, const float* b, float* lut_scales, float* lut_biases ) { - const int act_k = group_size; // we assume K == group_size for now __m256 vec_lut[16]; float biases = 0.0; @@ -409,7 +409,9 @@ PRAGMA_UNROLL } PRAGMA_UNROLL for (int g = 0; g < 16; g += 2) { - vec_lut[g] = -vec_lut[15 - g]; + //vec_lut[g] = -vec_lut[15 - g]; + const __m256 neg_mask = _mm256_set1_ps(-0.0f); // all lanes have sign bit set + vec_lut[g] = _mm256_xor_ps(vec_lut[15 - g], neg_mask); } biases += _mm256_addv_ps(vec_lut[0]); @@ -480,25 +482,27 @@ PRAGMA_UNROLL // based on lut_ctor_g4_int8_impl void GenerateLUT_avx2( - int32_t group_size, - int8_t* lut, - const float* b, - float* scales, - float* biases, - int K + const float* b, + int8_t* qlut, + float* lut_scales, + float* lut_biases, + size_t M, + size_t K, + size_t N, + size_t act_group_size ) { - const int kk_outer_max = K / group_size; + const size_t kk_outer_max = K / act_group_size; for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { // compute partial max - directly reset scale to 0.0 - scales[kk_outer] = 0.0f; - for (int32_t k_outer = 0; k_outer < group_size / 32; ++k_outer) { - partial_max_g4_int8_k8(&scales[kk_outer], &b[(kk_outer * group_size) + (k_outer * 32)]); + lut_scales[kk_outer] = 0.0f; // partial max reset + for (int32_t k_outer = 0; k_outer 1 is not supported yet"); } + // get kernel config + const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(M, K, 2); + + + // ==================== CONFIGURATION ==================== // Fixed parameters for this kernel implementation - bool has_zero_point = true; // Whether weights have zero-points (interleaved with scales) - bool one_scale = false; // Whether using single global scale for all weights - constexpr int bits = 2; // 2-bit quantization - constexpr int g = 4; // Packing group size - constexpr int ngroups_per_elem = 2; // 8 / g = 2 - constexpr int kfactor = 16; // K-dimension blocking factor - constexpr bool has_scale = true; // Always use weight scales + bool has_zero_point = tmac_params.has_zero_point; // Whether weights have zero-points (interleaved with scales) + bool one_scale = tmac_params.one_scale; // Whether using single global scale for all weights + + const int bits = tmac_params.bits; // 2-bit quantization + const int g = tmac_params.g; // Packing group size + const int ngroups_per_elem = tmac_params.ngroups_per_elem; // 8 / g = 2 + const int kfactor = tmac_params.kfactor; // K-dimension blocking factor + + const bool has_scale = tmac_params.has_scale; // Always use weight scales // Parameters derived from inputs - const int q_group_size = static_cast(BlkLen); // Weight quant group size - const int act_group_size = static_cast(BlkLen); // Activation group size (same as weight) - const int actk = act_group_size / g; // CRITICAL: = 16 for BlkLen=64, NOT BlkLen! + const int q_group_size = tmac_params.q_group_size; // Weight quant group size + const int act_group_size = tmac_params.act_group_size; // Activation group size (same as weight) + const int actk = tmac_params.actk; // CRITICAL: = 16 for BlkLen=64, NOT BlkLen! 
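// Illustrative sketch (not part of the diff above): a worked example of the derived
// quantities used in this configuration block, for the parameters quoted earlier in
// this file (bits=2, g=4, simd_n_in=16, bm=256) and BlkLen = act_group_size = 64.
// The point of the "CRITICAL" comment is that the inner LUT indexing uses
// act_group_size / g, not BlkLen itself.
#include <cstddef>

constexpr size_t bits = 2;
constexpr size_t g = 4;                        // activations grouped 4 at a time
constexpr size_t ngroups_per_elem = 8 / g;     // 2 groups packed per weight byte
constexpr size_t simd_n_in = 16;
constexpr size_t bm = 256;
constexpr size_t act_group_size = 64;          // == BlkLen in this example

constexpr size_t actk = act_group_size / g;              // 16, NOT BlkLen
constexpr size_t mgroup = ngroups_per_elem * simd_n_in;  // 32
constexpr size_t m = bm / bits;                          // 128 output rows per tile

static_assert(actk == 16, "for BlkLen=64 the inner LUT block covers 16 groups");
static_assert(bm % mgroup == 0 && bm % bits == 0, "tile constraints checked by the kernel");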
+ + const int bm = tmac_params.bm; + int m = bm / bits; // Validate configuration assert(bm % bits == 0); @@ -693,7 +706,7 @@ void TMACComputeGemm_avx2( float* C_global = new float[m]; // Reset accumulator buffer to zero - tbl_int32_reset(bm * sizeof(float) / sizeof(int32_t), + tbl_int32_reset(bm * sizeof(float) / sizeof(int32_t), reinterpret_cast(CBits)); // ==================== CALCULATE LOOP PARAMETERS ==================== @@ -719,20 +732,20 @@ void TMACComputeGemm_avx2( // ==================== MAIN COMPUTATION LOOP ==================== for (int32_t k_outer = 0; k_outer < k_outer_max; k_outer++) { // Calculate pointers for this K-outer iteration - const uint8_t* a = reinterpret_cast(A) + k_outer * bm * kfactor / ngroups_per_elem; + const uint8_t* a = A + k_outer * bm * kfactor / ngroups_per_elem; // Calculate scales pointer based on configuration - const float* scales = one_scale ? + const float* scales = one_scale ? reinterpret_cast(Scales) : // Single global scale - (has_zero_point ? + (has_zero_point ? reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m * 2 : // Scale + zero_point pairs reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m); // Scales only // Calculate LUT pointers const int8_t* lut = reinterpret_cast(LUT) + k_outer * kfactor * (1 << g); // 2^g = 16 for g=4 - const float* lut_scales = reinterpret_cast(LUT_Scales) + + const float* lut_scales = reinterpret_cast(LUT_Scales) + (k_outer * kfactor * g / act_group_size); - const float* lut_biases = reinterpret_cast(LUT_Biases) + + const float* lut_biases = reinterpret_cast(LUT_Biases) + (k_outer * kfactor * g / act_group_size); // Select appropriate kernel template based on configuration @@ -779,6 +792,7 @@ void TMACComputeGemm_avx2( // ==================== GATHER RESULTS ==================== // Gather bit-plane results into final output // Only support 2-bit in this implementation + // TODO(vraspar): extend to other bit-widths tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, reinterpret_cast(C)); // ==================== CLEANUP ==================== diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp new file mode 100644 index 0000000000000..505f634317489 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp @@ -0,0 +1,45 @@ +// /*++ + +// Copyright (c) Microsoft Corporation. All rights reserved. + +// Licensed under the MIT License. + +// Module Name: + +// test_sqlutgemm.h + +// Abstract: + +// Tests for MLAS T-MAC quantized GEMM. 
+ +// --*/ + +// #include "test_util.h" +// #include "mlas_q4.h" +// #include "mlas_qnbit.h" + + +// static size_t MlasQLUTGemmTestAllShortExecuteTests() { +// size_t tests_registered = 0; + +// tests_registered += MlasQLUTGemmShortExecuteTest<2,16>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<2,32>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<2,64>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<2,128>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<2,256>::RegisterShortExecuteTests(); + +// tests_registered += MlasQLUTGemmShortExecuteTest<4,16>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<4,32>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<4,64>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<4,128>::RegisterShortExecuteTests(); +// tests_registered += MlasQLUTGemmShortExecuteTest<4,256>::RegisterShortExecuteTests(); + +// return tests_registered; +// } + +// static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { +// if (is_short_execute) { +// return MlasQLUTGemmTestAllShortExecuteTests(); +// } +// return 0; +// }); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 47002dd7eea72..426715b7138df 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -24,6 +24,8 @@ static constexpr const char* ComputeTypeName(MLAS_QNBIT_GEMM_COMPUTE_TYPE Comput return "Fp32"; case SQNBIT_CompInt8: return "Int8"; + case TMAC: + return "TMAC"; default: return "unknown"; } @@ -50,6 +52,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { MatrixGuardBuffer BufferC; MatrixGuardBuffer BufferCReference; + // TMAC LUT related buffers + MatrixGuardBuffer BufferPackedQuantScalesZP; + void CallGemm(size_t M, size_t N, size_t K, @@ -286,6 +291,10 @@ class MlasSQNBitGemmTest : public MlasTestBase { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } + if (ComputeType == TMAC) { + InitTMACKernelConfig(N, K, BlkBitWidth, BlkLen, QuantBZeroPoint != nullptr); + } + void* PackedQuantBDataWorkspace = nullptr; if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, !Symmetric, ComputeType); PackedQuantBDataSize > 0) { @@ -296,18 +305,34 @@ class MlasSQNBitGemmTest : public MlasTestBase { GetMlasThreadPool()); } - CallGemm(M, N, K, - A, /* lda */ K, - QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, - Bias, - C, /* ldc */ N, - Workspace, - ComputeType, - Threadpool); + float* PackedQuantScalesZPWorkspace = nullptr; + if (ComputeType == TMAC) { + bool has_zp_input = QuantBZeroPoint != nullptr; + const auto PackedQuantScalesZPSize = MlasTMACPackQuantScalesAndZeroPointsSize(N, K, BlkBitWidth, has_zp_input); + PackedQuantScalesZPWorkspace = BufferPackedQuantScalesZP.GetBuffer(PackedQuantScalesZPSize); + MlasTMACPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, has_zp_input, PackedQuantScalesZPWorkspace, + QuantBScale, QuantBZeroPoint); + } + + if (ComputeType == TMAC) { + MlasTmac(A, BlkLen, QuantBData, PackedQuantScalesZPWorkspace, C, K, M, N, Threadpool); + + } else { + CallGemm(M, N, K, + A, /* lda */ K, + QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, + Bias, + C, /* ldc */ N, + Workspace, + ComputeType, + 
Threadpool); + } + if (ComputeType == SQNBIT_CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else if (ComputeType == SQNBIT_CompInt8) { + } else if (ComputeType == SQNBIT_CompInt8 || ComputeType == TMAC) { + // use same reference implementation for TMAC as CompInt8 CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else { FAIL() << "Test is not implemented for compute type " @@ -362,6 +387,9 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - // count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); - // count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); - // count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); - // count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); + // TODO(vraspar): enable these test for 2bit development and also add 3 bit test for TMAC + count += SQNBitGemmShortExecuteTest<2, 16>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); From 6191aadb19f73324ee3d594edd1b2cab9a81bdfe Mon Sep 17 00:00:00 2001 From: vraspar Date: Fri, 31 Oct 2025 10:47:19 -0700 Subject: [PATCH 28/33] Fix bug in LUT table generation --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index a42e6a030eabd..aaec0392141af 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -287,8 +287,8 @@ void MLASCALL MlasTmac( lut_scales + scale_bias_offset, // Scales for this row lut_biases + scale_bias_offset, // Biases for this row M, - N, K, + N, tmac_params.act_group_size ); } From f2de7764b2b8c91e20670dafcbb4ccd379022bf6 Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 10 Nov 2025 14:15:14 -0800 Subject: [PATCH 29/33] Fix casting issue --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index aaec0392141af..5a174679a17c6 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -400,10 +400,10 @@ void MLASCALL MlasTmac( qlut + qlut_offset, // LUT for this batch row lut_scales + lut_scales_offset, // LUT scales lut_biases + lut_scales_offset, // LUT biases - reinterpret_cast(C) + dst_offset, // Output location + reinterpret_cast(C) + dst_offset, // Output location static_cast(K), // K dimension static_cast(N), // K dimension - static_cast(M), + static_cast(1), BlkLen // Weight quantization group size ); } From 9ef6d75fbb30b6b4f592c3495cdafdd8319b6f57 Mon Sep 17 00:00:00 2001 From: vraspar Date: Thu, 13 Nov 2025 14:53:11 -0800 Subject: [PATCH 30/33] add session option and clean up --- .../onnxruntime_session_options_config_keys.h | 6 + .../cpu/quantization/matmul_nbits.cc | 81 ++++--- onnxruntime/core/mlas/inc/mlas_qnbit.h | 35 +++- onnxruntime/core/mlas/lib/qlutgemm.cpp | 198 
++++++++++++++++-- onnxruntime/core/mlas/lib/qlutgemm.h | 2 +- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 4 +- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 5 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 8 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 4 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 4 +- 10 files changed, 279 insertions(+), 68 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 314cf76cc8044..9ca77f62c6cd8 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -374,6 +374,12 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; +// Use LUT based GEMM for quantized models when available. +// Option values: +// - "0": Do not use LUT based GEMM. [DEFAULT] +// - "1": Use LUT based GEMM when available. +static const char* const kOrtSessionOptionsMlasLUTGemm = "mlas.use_lut_gemm"; + // When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. // Refer to MatMulNBits op schema for more details. // If not provided, default is 4. diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 51ed04da617af..ceb3250c0b440 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -15,6 +15,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" #include "core/platform/threadpool.h" #include "core/util/thread_utils.h" @@ -48,6 +49,12 @@ typedef enum { template MLAS_QNBIT_GEMM_COMPUTE_TYPE GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + + // TODO(vraspar): check against session option + if (MlasIsLUTGemmAvailable(nbits, block_size)) { + return TMAC; + } + // For Fp32, only accuracy level 1 or 4 makes sense. // non-ARM CPU converts Fp16 to Fp32. // By converting Fp32 to Fp16, precision becomes worse. 
And due to the casting, @@ -57,9 +64,6 @@ GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { return SQNBIT_CompInt8; } - if (accuracy_level_attr == static_cast(Level5) && MlasIsTMACAvailable(nbits, block_size)) { - return TMAC; - } return SQNBIT_CompFp32; } @@ -107,6 +111,7 @@ class MatMulNBits final : public OpKernel { nbits_{narrow(info.GetAttr("bits"))}, has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()}, has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()}, + prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLUTGemm) == "1"}, compute_type_{GetComputeType(nbits_, block_size_, info.GetAttr("accuracy_level"))} { const auto& node = info.node(); auto input_defs = node.InputDefs(); @@ -123,6 +128,7 @@ class MatMulNBits final : public OpKernel { "Only 2b, 4b and 8b quantization is supported for MatMulNBits op, additional bits support is planned."); const Tensor* tensor_zero_point = nullptr; has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point); + prefer_lut_gemm_ = true; } Status Compute(OpKernelContext* context) const override; @@ -142,6 +148,7 @@ class MatMulNBits final : public OpKernel { const bool has_g_idx_; const bool has_bias_; bool scales_are_packed_{false}; + bool prefer_lut_gemm_{false}; const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_; bool has_unquantized_zero_point_{false}; const bool column_wise_quant_{true}; @@ -176,6 +183,15 @@ class MatMulNBits final : public OpKernel { AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, const MatMulComputeHelper& helper) const; + + Status ComputeBPackedLUT(const Tensor* a, + const Tensor* scales, + const Tensor* zero_points, + const Tensor* bias, + Tensor* y, + AllocatorPtr& allocator, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const; }; template @@ -190,10 +206,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } - if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && compute_type_ != TMAC) { + if (prefer_lut_gemm_ && !MlasIsLUTGemmAvailable(nbits_, block_size_)) { return Status::OK(); } - if (compute_type_ == TMAC && !MlasIsTMACAvailable(nbits_, block_size_)) { + + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && compute_type_ != TMAC) { return Status::OK(); } @@ -203,7 +220,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All concurrency::ThreadPool* threadpool_ptr = nullptr; // Only create threadpool for operations that can benefit from it - if (compute_type_ == TMAC || compute_type_ == SQNBIT_CompInt8) { + if (prefer_lut_gemm_ || compute_type_ == SQNBIT_CompInt8) { OrtThreadPoolParams tpo; tpo.thread_pool_size = 4; // Use default (typically number of cores) tpo.allow_spinning = false; // Don't spin during model load @@ -216,23 +233,33 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All threadpool_ptr = temp_threadpool.get(); } - if (compute_type_ == TMAC) { - InitTMACKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); - } if (input_idx == InputIndex::B) { + const Tensor* scales = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); - packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); - if (packed_b_size_ == 0) { - return Status::OK(); + if (prefer_lut_gemm_) { + 
MlasInitLUTGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); + packed_b_size_ = MlasLUTGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasLUTGemmPackQuantBData(N_, K_, nbits_, block_size_, static_cast(qptr), static_cast(packed_b_.get()), threadpool_ptr); + } else { + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + auto scale_ptr = scales ? scales->DataRaw() : nullptr; + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, + has_zp_input_, nullptr, threadpool_ptr); + } - auto qptr = tensor.DataRaw(); - auto scale_ptr = scales ? scales->DataRaw() : nullptr; - packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, - has_zp_input_, nullptr, threadpool_ptr); is_packed = true; } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 @@ -254,10 +281,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All is_packed = true; } #endif // MLAS_TARGET_ARM64 - } else if (compute_type_ == TMAC) { + } else if (prefer_lut_gemm_) { if (input_idx == InputIndex::scales && packed_b_ != nullptr) { auto scales_ptr = tensor.Data(); - packed_scales_zp_size_ = MlasTMACPackQuantScalesAndZeroPointsSize(N_, K_, block_size_, has_zp_input_); + packed_scales_zp_size_ = MlasLUTPackScalesAndZeroPointsSize(N_, K_, block_size_, has_zp_input_); packed_scales_zp_ = IAllocator::MakeUniquePtr(alloc, packed_scales_zp_size_, true); // TODO(vraspar): improve this logic block @@ -265,9 +292,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All const Tensor* zero_points = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); auto zero_points_ptr = zero_points->Data(); - MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, zero_points_ptr); + MlasLUTPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, zero_points_ptr); } else { - MlasTMACPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr); + MlasLUTPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr); } } } @@ -355,7 +382,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; - if (input_idx == 1) { + if (input_idx == 1) { //TODO(vraspar): DO we need shared Prepacked buffer for TMAC, combine packing of weights + scales/ZP into one buffer ??? used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } @@ -388,8 +415,8 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, // if we want to do lut generation // TODO(vraspar): Should we batch it here? 
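// Illustrative sketch (not part of the diff above): enabling the LUT path from user
// code via the session option added in onnxruntime_session_options_config_keys.h
// (kOrtSessionOptionsMlasLUTGemm == "mlas.use_lut_gemm"). The model path is a
// placeholder, and note that the MatMulNBits constructor in this revision still sets
// prefer_lut_gemm_ = true unconditionally, so the option is effectively overridden
// until that line is removed.
#include <onnxruntime_cxx_api.h>

int main() {
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "lut-gemm-demo");
    Ort::SessionOptions so;
    // "1" requests the LUT-based GEMM when MlasIsLUTGemmAvailable() reports support;
    // "0" (the default) keeps the existing QNBit GEMM kernels.
    so.AddConfigEntry("mlas.use_lut_gemm", "1");
    Ort::Session session(env, ORT_TSTR("model_with_matmulnbits.onnx"), so);
    return 0;
}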
- if (compute_type_ == TMAC) { - MlasTmac(a_data, block_size_, packed_b_.get(), packed_scales_zp_.get(), y_data, K, M, N, thread_pool); + if (prefer_lut_gemm_) { + MlasLUTGemm(a_data, block_size_, packed_b_.get(), packed_scales_zp_.get(), y_data, K, M, N, thread_pool); return Status::OK(); } const size_t lda = helper.Lda(false); @@ -832,7 +859,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. - if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) || (compute_type_ == TMAC && MlasIsTMACAvailable(nbits_, block_size_))) { + /* if (MlasIsLUTGemmAvailable(nbits_, block_size_) && prefer_lut_gemm_) { + return ComputeBPackedLUT(a, scales, zero_points, bias, y, allocator, thread_pool, helper); + }*/ + + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) || (prefer_lut_gemm_ && MlasIsLUTGemmAvailable(nbits_, block_size_))) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index c83cb5b30bb3b..fa3b6d843f2fa 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -223,12 +223,35 @@ MlasQNBitGemmScalesPacked( bool HasZeroPoint ); +size_t MLASCALL +MlasLUTGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +); + + +void MLASCALL +MlasLUTGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + + /** * @brief Gets the size in float of the packed quantized B scales and zero points. */ size_t MLASCALL -MlasTMACPackQuantScalesAndZeroPointsSize( +MlasLUTPackScalesAndZeroPointsSize( size_t N, size_t K, size_t BlkLen, @@ -239,10 +262,10 @@ MlasTMACPackQuantScalesAndZeroPointsSize( * @brief Packs the scales and zero points into a format that the TMAC kernel expects. */ void MLASCALL -MlasTMACPackScalesAndZeroPoints( +MlasLUTPackScalesAndZeroPoints( size_t N, size_t K, - size_t BitWidth, + size_t BlkBitWidth, size_t BlkLen, bool HasZeroPoint, float* PackedQuantBZPBegin, @@ -259,13 +282,13 @@ MlasTMACPackScalesAndZeroPoints( * MlasIsQNBitGemmAvailable by querying availability of the LUT-based strategy. */ bool MLASCALL -MlasIsTMACAvailable( +MlasIsLUTGemmAvailable( size_t BlkBitWidth, size_t BlkLen ); void MLASCALL -InitTMACKernelConfig( +MlasInitLUTGemmKernelConfig( size_t M, size_t N, size_t nbits, @@ -279,7 +302,7 @@ InitTMACKernelConfig( * Results will be stored in C. 
*/ void MLASCALL -MlasTmac( +MlasLUTGemm( const void* A, size_t BlkLen, const void* QuantBData, diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 5a174679a17c6..836d1814c4651 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -15,7 +15,7 @@ module includes kernel functions for generating LUT for T-MAC GEMM optimization /** T-MAC GEMM kernel Config */ static std::unordered_map tmac_kernel_configs; -const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits) { +const MlasTMACKernelParams& MlasGetLUTGemmKernelParams(size_t M, size_t N, size_t nbits) { std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits); if (tmac_kernel_configs.count(key)) { return tmac_kernel_configs[key]; @@ -24,7 +24,7 @@ const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits } } -void InitTMACKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zp_point) { +void MlasInitLUTGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zp_point) { std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits); if (tmac_kernel_configs.count(key)) { return; @@ -100,10 +100,176 @@ void InitTMACKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, b return; } -void MlasTMACPackScalesAndZeroPoints( + +size_t MlasLUTGemmPackQuantBDataSize( size_t N, size_t K, - size_t BitWidth, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); + const MlasTMACKernelParams& tmac_params = MlasGetLUTGemmKernelParams(N, K, BlkBitWidth); + const size_t PackedQuantBDataSize = (N * BlkBitWidth) * (K / tmac_params.g / tmac_params.ngroups_per_elem); + return PackedQuantBDataSize; +} + + +void +MlasLUTGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + // decompose W into w1,... 
w_bits create temp buffer buf2 of size N * bits * (K/g) + const MlasTMACKernelParams& tmac_params = MlasGetLUTGemmKernelParams(N, K, BlkBitWidth); + const size_t bits = tmac_params.bits; + const size_t g = tmac_params.g; + const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; + const size_t simd_n_in = tmac_params.simd_n_in; + const size_t simd_n_out = tmac_params.simd_n_out; + const size_t bm = tmac_params.bm; + const size_t kfactor = tmac_params.kfactor; + + assert(BlkLen % g == 0); + assert((BlkLen / g) % kfactor == 0); + + const int mgroup = ngroups_per_elem * simd_n_in; // 32 + assert(bm % mgroup == 0); + assert(bm % bits == 0); + + uint8_t* buf = new uint8_t[N * bits * (K / g)]; + memset(buf, 0, N * bits * (K / g)); + + const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + for (size_t ik = 0; ik < K; ++ik) { + size_t idx = (im * K + ik); + size_t num_elem_per_byte = 8 / bits; + size_t elem_idx = idx % num_elem_per_byte; + + uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); + + for (size_t ib = 0; ib < bits; ++ib) { + size_t new_ik = ik / g; + size_t shft_left = ik % g; + buf[im * bits * K / g + ib * K / g + new_ik] += ((v >> ib) & 1) << shft_left; + } + } + } + ); + + // Now buf contains the bit planes grouped by g along K + // Next, we need to do a multi-reshape/transpose into the final layout + + const size_t c0_fac2 = K / g; + const size_t c0_fac1 = simd_n_out * c0_fac2; + const size_t c0_fac0 = bits * c0_fac1; + + const size_t c1_nb2 = K / g; + const size_t c1_nb1 = simd_n_in * c1_nb2; + const size_t c1_nb0 = ngroups_per_elem * c1_nb1; + const size_t c1_fac2 = K / g; + const size_t c1_fac1 = ngroups_per_elem * c1_fac2; + const size_t c1_fac0 = simd_n_in * c1_fac1; + + const size_t c2_nb4 = kfactor; + const size_t c2_nb3 = K / g / kfactor * c2_nb4; + const size_t c2_nb2 = ngroups_per_elem * c2_nb3; + const size_t c2_nb1 = simd_n_in * c2_nb2; + const size_t c2_nb0 = bm / mgroup * c2_nb1; + const size_t c2_fac3 = simd_n_in * ngroups_per_elem; + const size_t c2_fac2 = kfactor * c2_fac3; + const size_t c2_fac1 = bm / mgroup * c2_fac2; + const size_t c2_fac0 = K / g / kfactor * c2_fac1; + + const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed? 
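// Illustrative note (not part of the diff above): the zero-fill does appear to be
// required -- the transpose loop below read-modify-writes each packed destination
// byte, folding ngroups_per_elem bit-plane groups into the same element, so stale
// data in PackedQuantBDataBegin would corrupt the packed weights.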
+ + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + for (size_t ib = 0; ib < bits; ib++) { + for (size_t ik = 0; ik < K / g; ik++) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + size_t new_im = im / simd_n_out; + size_t new_isno = im % simd_n_out; + size_t new_ib = ib; + size_t new_ik = ik; + size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + size_t new_ing = (new_idx % c1_nb0) / c1_nb1; + size_t new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + size_t new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + + new_ik * c2_fac1 + + new_ibm * c2_fac2 + + new_ikf * c2_fac3 + + new_isni * ngroups_per_elem + + new_ing; + new_idx = new_idx / ngroups_per_elem; + size_t buf_idx = im * bits * K / g + ib * K / g + ik; + uint8_t buf_val = buf[buf_idx]; + + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + PackedQuantBDataBegin[new_idx] = static_cast( + static_cast(PackedQuantBDataBegin[new_idx]) + + (buf_val << (new_ing * g)) + ); + } + } + } + ); + delete[] buf; + +} + + +size_t MLASCALL +MlasLUTPackScalesAndZeroPointsSize( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint +) +{ + // TODO(vraspar): support one scale case + if (HasZeroPoint) { + return N * K / BlkLen * 2; + } else { + return N * K / BlkLen; + } +} + + +void MlasLUTPackScalesAndZeroPoints( + size_t N, + size_t K, + size_t BlkBitWidth, size_t BlkLen, bool HasZeroPoint, float* PackedQuantBZPBegin, @@ -111,7 +277,7 @@ void MlasTMACPackScalesAndZeroPoints( const uint8_t* QuantBZeroPoint ) { - const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, BitWidth); + const MlasTMACKernelParams& tmac_params = MlasGetLUTGemmKernelParams(N, K, BlkBitWidth); const size_t bits = tmac_params.bits; const size_t simd_n_out = tmac_params.simd_n_out; const size_t bm = tmac_params.bm; @@ -170,7 +336,7 @@ void MlasTMACPackScalesAndZeroPoints( } -bool MLASCALL MlasIsTMACAvailable( +bool MLASCALL MlasIsLUTGemmAvailable( size_t /*BlkBitWidth*/, size_t /*BlkLen*/ ) // TODO(Vraspar): fix the below to use smthg besides the gen kernel, add ComputeGemm @@ -180,20 +346,6 @@ bool MLASCALL MlasIsTMACAvailable( // return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now } -size_t MLASCALL MlasTMACPackQuantScalesAndZeroPointsSize( - size_t N, - size_t K, - size_t BlkLen, - bool HasZeroPoint -) -{ - // TODO(vraspar): support one scale case - if (HasZeroPoint) { - return N * K / BlkLen * 2; - } else { - return N * K / BlkLen; - } -} size_t CalculateLUTBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params) { @@ -210,7 +362,7 @@ CalculateLUTBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& return wsize; } -void MLASCALL MlasTmac( +void MLASCALL MlasLUTGemm( const void* A, size_t BlkLen, const void* QuantBData, // Quantized weights (B 
matrix) @@ -250,7 +402,7 @@ void MLASCALL MlasTmac( // n_tiles_num = m * bits / bm; // TODO(vraspar): support other bitwidths - const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2); + const MlasTMACKernelParams& tmac_params = MlasGetLUTGemmKernelParams(N, K, 2); const size_t lut_scales_size = K / tmac_params.act_group_size; size_t lut_buffer_size = CalculateLUTBufferSize(N, K, M, tmac_params); @@ -334,7 +486,7 @@ void MLASCALL MlasTmac( // - otherwise: number of quantization groups = (M * K / BlkLen) // and if zero-points are present each group stores (scale, zero_point) -> *2 const size_t groups_total = static_cast(M) * static_cast(K) / BlkLen; - const size_t scales_size_total = MlasTMACPackQuantScalesAndZeroPointsSize( + const size_t scales_size_total = MlasLUTPackScalesAndZeroPointsSize( static_cast(N), static_cast(K), BlkLen, diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h index f2d024e117a1c..4a7d31e17894b 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.h +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -33,7 +33,7 @@ struct MlasTMACKernelParams { }; -const MlasTMACKernelParams& GetTMACKernelParams(size_t M, size_t N, size_t nbits); +const MlasTMACKernelParams& MlasGetLUTGemmKernelParams(size_t M, size_t N, size_t nbits); typedef void(MLAS_QNBIT_GEMM_LUT_GEN)( diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 4f26dbed5f49a..983e843f5412c 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -88,9 +88,6 @@ MlasIsQNBitGemmAvailable( return false; } - if (ComputeType == TMAC) { - return MlasIsTMACAvailable(BlkBitWidth, BlkLen); - } const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); @@ -238,6 +235,7 @@ MlasQNBitGemmPackQuantBDataSize( ); } + // This would be for non LUT based 2-bit gemm kernel if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPackQuantBDataSize != nullptr) { return Dispatch->Q2BitGemmPackQuantBDataSize( N, K, BlkLen, ComputeType diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index a3f7023e6e7fc..9b1288e7ae5f3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -203,7 +203,7 @@ void SQ2BitGemmPackQuantBData( // T-MAC like configuration (approved): // bits=2, g=4, ngroups_per_elem=8/g=2, simd_n_in=16, simd_n_out=8, bm=256, kfactor=16 - const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(N, K, 2); + const MlasTMACKernelParams& tmac_params = MlasGetLUTGemmKernelParams(N, K, 2); const size_t bits = 2; const size_t g = tmac_params.g; const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; @@ -332,6 +332,7 @@ Q2BitGemmPerGemmWorkspaceSize( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { + // TODO(vraspar): Wht was this function needed? 
MLAS_UNREFERENCED_PARAMETER(N); switch (ComputeType) { @@ -665,7 +666,7 @@ void TMACComputeGemm_avx2( } // get kernel config - const MlasTMACKernelParams& tmac_params = GetTMACKernelParams(M, K, 2); + const MlasTMACKernelParams& tmac_params = MlasGetLUTGemmKernelParams(M, K, 2); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index ef2c52c0d219d..1c3e6b284e7d3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1446,8 +1446,8 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + //d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + //d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; @@ -1475,8 +1475,8 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + //d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + //d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 02d4092b411f4..89664eb2ebb01 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -478,9 +478,9 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + //d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + //d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index c24d7ffacaa0c..5f708f811c6e1 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -463,9 +463,9 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; - d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + //d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + //d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; 
d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512vnni; From 59c0055051e6d517f5e7ee677044c42a02c43fa3 Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 1 Dec 2025 13:45:18 -0800 Subject: [PATCH 31/33] Refactor QNBit GEMM Implementation for AVX2 --- cmake/onnxruntime_mlas.cmake | 4 +- .../cpu/quantization/matmul_nbits.cc | 39 ++-- onnxruntime/core/mlas/lib/platform.cpp | 2 +- onnxruntime/core/mlas/lib/qlutgemm.cpp | 1 + onnxruntime/core/mlas/lib/qnbitgemm.cpp | 21 -- onnxruntime/core/mlas/lib/qnbitgemm.h | 16 +- .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 62 +----- .../mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h | 54 ------ .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 4 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 1 - .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 1 - ...vx2.cpp => sqnbitgemm_lut_kernel_avx2.cpp} | 180 ------------------ .../mlas/lib/sqnbitgemm_lut_kernel_avx2.h | 27 +++ 13 files changed, 70 insertions(+), 342 deletions(-) delete mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h rename onnxruntime/core/mlas/lib/{sqnbitgemm_bitnet_kernel_avx2.cpp => sqnbitgemm_lut_kernel_avx2.cpp} (79%) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 34ed6901f8e4e..892543c105f5f 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -202,7 +202,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp @@ -649,7 +649,7 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index ceb3250c0b440..ea4d6dd7e7240 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -390,6 +390,30 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& return Status::OK(); } +template +Status MatMulNBits::ComputeBPackedLUT(const Tensor* a, + const Tensor* scales, + const Tensor* zero_points, + const Tensor* bias, + Tensor* y, + AllocatorPtr& allocator, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const { + const auto* a_data = a->Data(); + const auto* scales_data = scales == nullptr ? nullptr : scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* bias_data = bias == nullptr ? 
nullptr : bias->Data(); + auto* y_data = y->MutableData(); + const size_t batch_count = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + // TODO(vraspar): Should we batch it here? + //MlasInitLUTGemmKernelConfig(N, K, nbits_, block_size_, has_zp_input_); + MlasLUTGemm(a_data, block_size_, packed_b_.get(), packed_scales_zp_.get(), y_data, K, M, N, thread_pool); + return Status::OK(); +} + template Status MatMulNBits::ComputeBPacked(const Tensor* a, const Tensor* scales, @@ -410,15 +434,6 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t N = static_cast(helper.N()); const size_t K = static_cast(helper.K()); - // TODO: add the logic for generating lookup table here -- for now we can assume that - // 2 bits + acc level 4 = use look up table but in the future adapt so that we use a mamtulnbits attr to decide - // if we want to do lut generation - - // TODO(vraspar): Should we batch it here? - if (prefer_lut_gemm_) { - MlasLUTGemm(a_data, block_size_, packed_b_.get(), packed_scales_zp_.get(), y_data, K, M, N, thread_pool); - return Status::OK(); - } const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; @@ -859,11 +874,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. - /* if (MlasIsLUTGemmAvailable(nbits_, block_size_) && prefer_lut_gemm_) { + if (prefer_lut_gemm_ && MlasIsLUTGemmAvailable(nbits_, block_size_)) { return ComputeBPackedLUT(a, scales, zero_points, bias, y, allocator, thread_pool, helper); - }*/ + } - if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) || (prefer_lut_gemm_ && MlasIsLUTGemmAvailable(nbits_, block_size_))) { + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } } diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 2413144919cdb..db6882798dcc8 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -411,7 +411,7 @@ Return Value: this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; this->RopeDispatch = &MlasRopeDispatchAvx2; - // TODO: check if this really goes here or if there are other platform reqs that we need to fulfill + // TODO(vraspar): check if this really goes here or if there are other platform reqs that we need to fulfill this->LUTGenKernel = &MlasLUTGenKernelAvx2; // diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index 836d1814c4651..d9ef450bf5d54 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -344,6 +344,7 @@ bool MLASCALL MlasIsLUTGemmAvailable( const auto* Dispatch = GetMlasPlatform().LUTGenKernel; return Dispatch != nullptr; // return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now + // add check for M, N sizes } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 983e843f5412c..151b7878caeb3 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -234,14 +234,6 @@ MlasQNBitGemmPackQuantBDataSize( N, K, BlkLen, HasZeroPoint, ComputeType ); } - - // This would be for non 
LUT based 2-bit gemm kernel - if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPackQuantBDataSize != nullptr) { - return Dispatch->Q2BitGemmPackQuantBDataSize( - N, K, BlkLen, ComputeType - ); - } - return 0; } @@ -321,19 +313,6 @@ MlasQNBitGemmPackQuantBData( ); return; } - } else if (BlkBitWidth == 2) { // TODO:: might switch to for TMAC type if other 2-bit kernels like i2s are added - if (Dispatch->SQ2BitGemmPackQuantBData != nullptr) { - Dispatch->SQ2BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBDataAndOrBlkSumWorkspace), - ThreadPool - ); - return; - } } else if (BlkBitWidth == 8) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index a231255c9fd16..543d903dd3ebf 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -101,16 +101,16 @@ struct MLAS_QNBIT_GEMM_DISPATCH { ); // TODO:: just use Q4BitGemmPackQuantBDataSize if extra params are not needed in future - typedef size_t(Q2BitGemmPackQuantBDataSize_Fn)( - size_t N, - size_t K, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType - ); + // typedef size_t(Q2BitGemmPackQuantBDataSize_Fn)( + // size_t N, + // size_t K, + // size_t BlkLen, + // MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + // ); Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; - Q2BitGemmPackQuantBDataSize_Fn* Q2BitGemmPackQuantBDataSize = nullptr; + // Q2BitGemmPackQuantBDataSize_Fn* Q2BitGemmPackQuantBDataSize = nullptr; /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). 
*/ typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( @@ -136,7 +136,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; - Q4BitGemmPackQuantBData_Fn* SQ2BitGemmPackQuantBData = nullptr; + // Q4BitGemmPackQuantBData_Fn* SQ2BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index ee8e4bb0216d2..bad607cd586fd 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -261,62 +261,6 @@ QNBitGemmPerGemmWorkspaceAlignment( } } } - -size_t -Q2BitGemmPackQuantBDataSize( - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ -) -{ - return 0; -} - -void -SQ2BitGemmPackQuantBData( - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, - const std::byte* /*QuantBDataBegin*/, - std::byte* /*PackedQuantBDataBegin*/, - MLAS_THREADPOOL* /*ThreadPool*/ -) -{ -} - -size_t -Q2BitGemmPerGemmWorkspaceSize( - size_t /*M*/, - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ -) -{ - return 0; -} - -size_t -SQ2BitGemmKernel_CompInt8_avx2( - size_t /*BlkLen*/, - const std::byte* /*QuantA*/, - const std::byte* /*QuantBData*/, - const float* /*QuantBScale*/, - const std::byte* /*QuantBZeroPoint*/, - float* /*C*/, - size_t /*CountM*/, - size_t /*CountN*/, - size_t /*CountK*/, - size_t /*BlockCountK*/, - size_t /*ldc*/, - const float* /*Bias*/ -) -{ - return 0; -} - } // namespace bool @@ -377,10 +321,10 @@ GetMlasQNBitGemmDispatchNeon( d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 - d.Q2BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q2BitGemmPackQuantBDataSize; - d.SQ2BitGemmPackQuantBData = sqnbitgemm_neon::SQ2BitGemmPackQuantBData; + // d.Q2BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q2BitGemmPackQuantBDataSize; + // d.SQ2BitGemmPackQuantBData = sqnbitgemm_neon::SQ2BitGemmPackQuantBData; - d.Q2BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q2BitGemmPerGemmWorkspaceSize; + // d.Q2BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q2BitGemmPerGemmWorkspaceSize; d.SQ2BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ2BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h deleted file mode 100644 index a12abc76acd3d..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once -#include "qnbitgemm.h" - -size_t Q2BitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType -); - -void -SQ2BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -); - -size_t -Q2BitGemmPerGemmWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType -); - -void -GenerateLUT_avx2( - int32_t group_size, - int8_t lut, - const float* b, - float* scales, - float* biases, - int K -); - -void 
-TMACComputeGemm_avx2( - const void* A, - const void* a_scales, - const void* LUT, - const void* LUT_Scales, - const void* LUT_Biases, - void* C, - int bm, - int K, - int M, - int N, - size_t BlkLen -); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 1c3e6b284e7d3..470376321e5cc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -29,8 +29,6 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" -#include "sqnbitgemm_bitnet_kernel_avx2.h" - void MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) { @@ -1476,7 +1474,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; //d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; - //d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + //d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 89664eb2ebb01..b37efd6434730 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -32,7 +32,6 @@ Module Name: // #include "sqnbitgemm_kernel_avx_common_fp32.h" -#include "sqnbitgemm_bitnet_kernel_avx2.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32_avx512( diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 5f708f811c6e1..ed4882d62d0a5 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -27,7 +27,6 @@ Module Name: #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" -#include "sqnbitgemm_bitnet_kernel_avx2.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32( diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp similarity index 79% rename from onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp rename to onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index 9b1288e7ae5f3..0576a0077d6f9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -169,186 +169,6 @@ constexpr int get_bias_scale() { return 3; } -size_t -Q2BitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - // TODO: This code shall change according to T-Mac. - // Modify based on tmac compute type if needed. 
- MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - - // const size_t PackedQuantBDataSize = N * K / 8; - constexpr size_t BlkBitWidth = 2; - constexpr size_t g = 4; // group size - const size_t ngroups_per_elem = 8 / g; - const size_t PackedQuantBDataSize = (N * BlkBitWidth) * (K / g / ngroups_per_elem); - return PackedQuantBDataSize; // 1048576 -} - -void SQ2BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - //decompose W into w1,... w_bits create temp buffer buf2 of size N * bits * (K/g) - - // T-MAC like configuration (approved): - // bits=2, g=4, ngroups_per_elem=8/g=2, simd_n_in=16, simd_n_out=8, bm=256, kfactor=16 - const MlasTMACKernelParams& tmac_params = MlasGetLUTGemmKernelParams(N, K, 2); - const size_t bits = 2; - const size_t g = tmac_params.g; - const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; - const size_t simd_n_in = tmac_params.simd_n_in; - const size_t simd_n_out = tmac_params.simd_n_out; - const size_t bm = tmac_params.bm; - const size_t kfactor = tmac_params.kfactor; - - // Basic checks - MLAS_UNREFERENCED_PARAMETER(K); - assert(BlkLen % g == 0); - assert((BlkLen / g) % kfactor == 0); - const int mgroup = ngroups_per_elem * simd_n_in; // 32 - assert(bm % mgroup == 0); - assert(bm % bits == 0); - - uint8_t * buf = new uint8_t[N * bits * (K / g)]; - memset(buf, 0, N * bits * (K / g)); - - const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - size_t im = static_cast(tid); - for (size_t ik = 0; ik < K; ++ik) { - size_t idx = (im * K + ik); - size_t num_elem_per_byte = 8 / bits; - size_t elem_idx = idx % num_elem_per_byte; - - uint8_t v = ((const uint8_t *)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); - - for (size_t ib =0; ib < bits; ++ib) { - size_t new_ik = ik / g; - size_t shft_left = ik % g; - buf[im * bits * K / g + ib * K /g + new_ik] += ((v >> ib) & 1) << shft_left; - } - } - } - ); - - // Now buf contains the bit planes grouped by g along K - // Next, we need to do a multi-reshape/transpose into the final layout - - - const size_t c0_fac2 = K / g; - const size_t c0_fac1 = simd_n_out * c0_fac2; - const size_t c0_fac0 = bits * c0_fac1; - - const size_t c1_nb2 = K / g; - const size_t c1_nb1 = simd_n_in * c1_nb2; - const size_t c1_nb0 = ngroups_per_elem * c1_nb1; - const size_t c1_fac2 = K / g; - const size_t c1_fac1 = ngroups_per_elem * c1_fac2; - const size_t c1_fac0 = simd_n_in * c1_fac1; - - - const size_t c2_nb4 = kfactor; - const size_t c2_nb3 = K / g / kfactor * c2_nb4; - const size_t c2_nb2 = ngroups_per_elem * c2_nb3; - const size_t c2_nb1 = simd_n_in * c2_nb2; - const size_t c2_nb0 = bm / mgroup * c2_nb1; - const size_t c2_fac3 = simd_n_in * ngroups_per_elem; - const size_t c2_fac2 = kfactor * c2_fac3; - const size_t c2_fac1 = bm / mgroup * c2_fac2; - const size_t c2_fac0 = K / g / kfactor * c2_fac1; - - const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); - memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed? 
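    // The c0_*/c1_*/c2_* strides computed above fold the three reshape/transpose steps
    // quoted in the Python-style comments inside the loop below into one index remap,
    // so each 4-bit group index can be written straight to its final position. The
    // trailing division by ngroups_per_elem converts an element index into a byte
    // offset, and the shift-by-(new_ing * g) accumulation packs the two non-overlapping
    // 4-bit indices that end up sharing a byte.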
- - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - size_t im = static_cast(tid); - for (size_t ib = 0; ib < bits; ib++) { - for (size_t ik = 0; ik < K / g; ik++) { - // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) - size_t new_im = im / simd_n_out; - size_t new_isno = im % simd_n_out; - size_t new_ib = ib; - size_t new_ik = ik; - size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; - - // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) - new_im = new_idx / c1_nb0; - size_t new_ing = (new_idx % c1_nb0) / c1_nb1; - size_t new_isni = (new_idx % c1_nb1) / c1_nb2; - new_ik = (new_idx % c1_nb2); - new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; - - // # 0 1 2 3 4 5 - // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) - new_im = new_idx / c2_nb0; - size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; - new_isni = (new_idx % c2_nb1) / c2_nb2; - new_ing = (new_idx % c2_nb2) / c2_nb3; - new_ik = (new_idx % c2_nb3) / c2_nb4; - size_t new_ikf = (new_idx % c2_nb4); - new_idx = new_im * c2_fac0 + - new_ik * c2_fac1 + - new_ibm * c2_fac2 + - new_ikf * c2_fac3 + - new_isni * ngroups_per_elem + - new_ing; - new_idx = new_idx / ngroups_per_elem; - size_t buf_idx = im * bits * K / g + ib * K / g + ik; - uint8_t buf_val = buf[buf_idx]; - - // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) - PackedQuantBDataBegin[new_idx] = static_cast( - static_cast(PackedQuantBDataBegin[new_idx]) + - (buf_val << (new_ing * g))); - } - } - } - ); - delete[] buf; -} - -size_t -Q2BitGemmPerGemmWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - // TODO(vraspar): Wht was this function needed? 
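    // As the switch below spells out: the per-GEMM workspace is only needed for
    // SQNBIT_CompInt8, where it holds the blockwise int8 quantization of A (quantized
    // data plus one scale per block), i.e. M * BlockCountK * Q8BlkSize(BlkLen) bytes;
    // all other compute types get by with no workspace.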
- MLAS_UNREFERENCED_PARAMETER(N); - - switch (ComputeType) { - case SQNBIT_CompInt8: { - // workspace buffer is used for block quantization of A to int8 - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - // QuantData + Scale - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); - return PerGemmWorkspaceSize; - } - default: { - return 0; - } - } -} - void partial_max_g4_int8_k8(float* lut_scales, const float* b) { // TODO(vraspar): add support for arm neon const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h new file mode 100644 index 0000000000000..5b206e296626c --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h @@ -0,0 +1,27 @@ +#pragma once +#include "qnbitgemm.h" + +void +GenerateLUT_avx2( + int32_t group_size, + int8_t lut, + const float* b, + float* scales, + float* biases, + int K +); + +void +TMACComputeGemm_avx2( + const void* A, + const void* a_scales, + const void* LUT, + const void* LUT_Scales, + const void* LUT_Biases, + void* C, + int bm, + int K, + int M, + int N, + size_t BlkLen +); From 457cfa37e1f1b27546cde2e7efa87a33f2ddac5e Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 2 Dec 2025 13:32:54 -0800 Subject: [PATCH 32/33] Refactor dispatch --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 3 ++- onnxruntime/core/mlas/lib/qlutgemm.h | 2 +- onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index d9ef450bf5d54..b4cc6eb36199e 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -512,6 +512,7 @@ void MLASCALL MlasLUTGemm( // Cast to appropriate types const auto* packed_weights = reinterpret_cast(QuantBData); + float* act_output = reinterpret_cast(C); // Parallelize over the 2D chunk grid MlasTrySimpleParallel( @@ -553,7 +554,7 @@ void MLASCALL MlasLUTGemm( qlut + qlut_offset, // LUT for this batch row lut_scales + lut_scales_offset, // LUT scales lut_biases + lut_scales_offset, // LUT biases - reinterpret_cast(C) + dst_offset, // Output location + act_output + dst_offset, // Output location static_cast(K), // K dimension static_cast(N), // K dimension static_cast(1), diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h index 4a7d31e17894b..c2de2e7b9f021 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.h +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -54,7 +54,7 @@ void(MLAS_QNBIT_LUT_GEMM_COMPUTE)( const int8_t* LUT, const float* LUT_Scales, const float* LUT_Biases, - void* C, + float* C, int K, int M, // batch size (number of rows in activation) int N, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp index 0576a0077d6f9..d31f38a7745d6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -474,7 +474,7 @@ void TMACComputeGemm_avx2( const int8_t* LUT, // Pre-computed quantized lookup table const float* LUT_Scales, // LUT scales from activation quantization const float* LUT_Biases, // LUT biases from activation quantization - void* C, // Output buffer + float* C, // Output buffer int K, int M, int N, @@ -614,7 +614,7 @@ void TMACComputeGemm_avx2( // Gather bit-plane results into final output // 
Only support 2-bit in this implementation // TODO(vraspar): extend to other bit-widths - tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, reinterpret_cast(C)); + tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, C); // ==================== CLEANUP ==================== delete[] C_global; From bdb298235f221ff756dc65655ce40c6cdd76a4f2 Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 2 Dec 2025 15:44:40 -0800 Subject: [PATCH 33/33] Add test cases --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 21 +- .../test/mlas/unittest/test_sqlutgemm.cpp | 326 ++++++++++++++++-- .../test/mlas/unittest/test_sqnbitgemm.cpp | 60 +--- 3 files changed, 326 insertions(+), 81 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index b4cc6eb36199e..49256c197c197 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -337,14 +337,23 @@ void MlasLUTPackScalesAndZeroPoints( bool MLASCALL MlasIsLUTGemmAvailable( - size_t /*BlkBitWidth*/, - size_t /*BlkLen*/ + size_t BlkBitWidth, + size_t BlkLen ) // TODO(Vraspar): fix the below to use smthg besides the gen kernel, add ComputeGemm { - const auto* Dispatch = GetMlasPlatform().LUTGenKernel; - return Dispatch != nullptr; - // return Dispatch != nullptr && BlkLen == 4; // only support group sizes of 4 for now - // add check for M, N sizes + if (GetMlasPlatform().LUTGenKernel == nullptr) { + return false; + } + + if (BlkBitWidth != 2) { + return false; + } + + if (BlkLen % 32 != 0) { + return false; + } + + return true; } diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp index 505f634317489..bf58a71aef1e4 100644 --- a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp @@ -1,45 +1,309 @@ -// /*++ +/*++ -// Copyright (c) Microsoft Corporation. All rights reserved. +Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +Licensed under the MIT License. -// Module Name: +Module Name: -// test_sqlutgemm.h + test_sqlutgemm.cpp -// Abstract: +Abstract: -// Tests for MLAS T-MAC quantized GEMM. + Tests for MLAS LUT-based n-bit GEMM (TMAC/LUT path) for 2-bit. -// --*/ +--*/ -// #include "test_util.h" -// #include "mlas_q4.h" -// #include "mlas_qnbit.h" +#include "test_util.h" +#include "mlas_qnbit.h" +#include "mlas_q4.h" +// Generic template to future-proof for different bit widths; instantiate with 2 for now. +template +class MlasSQLutGemm2BitTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferQuantBData; + MatrixGuardBuffer BufferQuantBZeroPoint; + MatrixGuardBuffer BufferQuantBScale; + MatrixGuardBuffer BufferPackedQuantB; + MatrixGuardBuffer BufferPackedScalesZP; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; -// static size_t MlasQLUTGemmTestAllShortExecuteTests() { -// size_t tests_registered = 0; + void QuantizeB(size_t K, size_t N, + const float* B, + uint8_t*& qdata, + float*& qscale, + uint8_t*& qzp, + bool symmetric) { + size_t q_data_bytes = 0, q_scale_size = 0, q_zp_bytes = 0; + MlasBlockwiseQuantizedBufferSizes(BlkLen, /*columnwise*/ true, + static_cast(K), static_cast(N), + q_data_bytes, q_scale_size, &q_zp_bytes); + qdata = BufferQuantBData.GetBuffer(q_data_bytes); + qscale = BufferQuantBScale.GetBuffer(q_scale_size); + qzp = symmetric ? 
nullptr : BufferQuantBZeroPoint.GetBuffer(q_zp_bytes); -// tests_registered += MlasQLUTGemmShortExecuteTest<2,16>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<2,32>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<2,64>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<2,128>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<2,256>::RegisterShortExecuteTests(); + MlasQuantizeBlockwise(qdata, qscale, qzp, + B, BlkLen, + /*columnwise*/ true, + static_cast(K), static_cast(N), + static_cast(N), + GetMlasThreadPool()); + } -// tests_registered += MlasQLUTGemmShortExecuteTest<4,16>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<4,32>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<4,64>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<4,128>::RegisterShortExecuteTests(); -// tests_registered += MlasQLUTGemmShortExecuteTest<4,256>::RegisterShortExecuteTests(); + void ReferenceDequantFp32(size_t M, size_t N, size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C) { + MatrixGuardBuffer deqBbuf; + float* DeqB = deqBbuf.GetBuffer(K * N); + MlasDequantizeBlockwise( + DeqB, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /*columnwise*/ true, + static_cast(K), static_cast(N), GetMlasThreadPool()); -// return tests_registered; -// } + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + const float* a = A + m * K; + const float* b = DeqB + n * K; + float sum = Bias ? Bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + sum += a[k] * b[k]; + } + C[m * N + n] = sum; + } + } + } -// static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { -// if (is_short_execute) { -// return MlasQLUTGemmTestAllShortExecuteTests(); -// } -// return 0; -// }); + void ReferenceInt8(size_t M, size_t N, size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C) { + // Reference path equivalent to CompInt8 for SQ: quantize A to int8 per block and accumulate with unpacked 2-bit B + const size_t BlockCountK = (K + BlkLen - 1) / BlkLen; + + MatrixGuardBuffer qa_buf; + MatrixGuardBuffer a_scales_buf; + int8_t* QA = qa_buf.GetBuffer(M * BlockCountK * BlkLen); + float* AScales = a_scales_buf.GetBuffer(M * BlockCountK); + + for (size_t m = 0; m < M; ++m) { + for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) { + const size_t local_blk_len = std::min(K - k, BlkLen); + float amax = 0.0f; + for (size_t kk = 0; kk < local_blk_len; ++kk) { + amax = std::max(amax, fabsf(A[m * K + k + kk])); + } + constexpr float rmax = (1 << 7) - 1; + float scale = amax / rmax; + float inv = scale != 0.0f ? 1.0f / scale : 0.0f; + AScales[m * BlockCountK + k_blk] = scale; + for (size_t kk = 0; kk < BlkLen; ++kk) { + float q = roundf((k + kk < K ? A[m * K + k + kk] : 0.0f) * inv); + QA[m * BlockCountK * BlkLen + k + kk] = static_cast(std::clamp(q, -128.0f, 127.0f)); + } + } + } + + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = Bias ? 
Bias[n] : 0.0f; + for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) { + const size_t k_blk_len = std::min(K - k, BlkLen); + const float a_scale = AScales[m * BlockCountK + k_blk]; + const float b_scale = QuantBScale[n * BlockCountK + k_blk]; + uint8_t b_zp = (BlkBitWidth == 4 ? 8 : (BlkBitWidth == 2 ? 2 : 0)); // symmetric default + if (QuantBZeroPoint) { + const int pack = 8 / BlkBitWidth; + uint8_t zp_byte = QuantBZeroPoint[n * ((BlockCountK + 3) / pack) + k_blk / pack]; + if constexpr (BlkBitWidth == 2) { + int shift = (k_blk & 3) * 2; + b_zp = (zp_byte >> shift) & 0x03; + } else if constexpr (BlkBitWidth == 4) { + b_zp = (k_blk & 1) ? (zp_byte >> 4) : (zp_byte & 0x0F); + } + } + int32_t qsum = 0; + for (size_t kk = 0; kk < k_blk_len; ++kk) { + const int8_t qa = QA[m * BlockCountK * BlkLen + k + kk]; + const int pack = 8 / BlkBitWidth; // entries per byte + const size_t idx = (n * BlockCountK * BlkLen + k + kk) / pack; + const uint8_t qb_byte = QuantBData[idx]; + int8_t qb = 0; + if constexpr (BlkBitWidth == 2) { + qb = static_cast((qb_byte >> ((kk & 3) * 2)) & 0x03); + } else if constexpr (BlkBitWidth == 4) { + qb = static_cast((kk & 1) ? (qb_byte >> 4) : (qb_byte & 0x0F)); + } + qb -= static_cast(b_zp); + qsum += static_cast(qa) * static_cast(qb); + } + sum += static_cast(qsum) * a_scale * b_scale; + } + C[m * N + n] = sum; + } + } + } + + public: + void Test(size_t M, size_t N, size_t K, bool with_threadpool, bool symmetric, bool with_bias) { + MLAS_THREADPOOL* tp = with_threadpool ? GetMlasThreadPool() : nullptr; + + const float* A = BufferA.GetBuffer(K * M); + const float* B = BufferB.GetBuffer(N * K); + + const float* Bias = nullptr; + if (with_bias) { + Bias = BufferBias.GetBuffer(N); + } + + // Quantize B to BlkBitWidth-bit blockwise + uint8_t* qB = nullptr; + float* sB = nullptr; + uint8_t* zpB = nullptr; + QuantizeB(K, N, B, qB, sB, zpB, symmetric); + + // Initialize LUT config and pack B/scales/zp + MlasInitLUTGemmKernelConfig(N, K, /*nbits*/ BlkBitWidth, BlkLen, /*has_zp*/ zpB != nullptr); + + void* packedB = nullptr; + size_t packedBSize = MlasLUTGemmPackQuantBDataSize(N, K, /*nbits*/ BlkBitWidth, BlkLen, /*has_zp*/ zpB != nullptr, TMAC); + if (packedBSize > 0) { + packedB = BufferPackedQuantB.GetBuffer(packedBSize); + MlasLUTGemmPackQuantBData(N, K, /*nbits*/ BlkBitWidth, BlkLen, + static_cast(reinterpret_cast(qB)), + static_cast(packedB), tp); + } + + size_t packedSZSize = MlasLUTPackScalesAndZeroPointsSize(N, K, BlkLen, /*has_zp*/ zpB != nullptr); + if (packedSZSize > 0) { + float* packedSZ = BufferPackedScalesZP.GetBuffer(packedSZSize); + MlasLUTPackScalesAndZeroPoints(N, K, /*nbits*/ BlkBitWidth, BlkLen, /*has_zp*/ zpB != nullptr, + packedSZ, sB, zpB); + } + + float* C = BufferC.GetBuffer(N * M, true); + float* CRef = BufferCReference.GetBuffer(N * M, true); + + // Execute LUT GEMM + MlasLUTGemm(A, BlkLen, + static_cast(packedB), + BufferPackedScalesZP.GetBuffer(packedSZSize), + C, + K, M, N, + tp); + + // Reference implementation (int8-style accumulation) + ReferenceInt8(M, N, K, A, qB, sB, zpB, Bias, CRef); + + // Cross-check via explicit dequantization + FP32 GEMM + MatrixGuardBuffer CRefDeqBuf; + float* CRefDeq = CRefDeqBuf.GetBuffer(N * M, true); + ReferenceDequantFp32(M, N, K, A, qB, sB, zpB, Bias, CRefDeq); + + // Compare results + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t idx = m * N + n; + ASSERT_TRUE(CloseEnough(C[idx], CRef[idx])) + << "Expected: " << CRef[idx] << " Actual: " << C[idx] + << "@[" << m << "x" << n << 
"], M=" << M << ", N=" << N << ", K=" << K; + ASSERT_TRUE(CloseEnough(C[idx], CRefDeq[idx])) + << "DequantRef mismatch. Expected: " << CRefDeq[idx] << " Actual: " << C[idx] + << "@[" << m << "x" << n << "], M=" << M << ", N=" << N << ", K=" << K; + } + } + } + + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQLutGemm2Bit") + "BlkLen" + std::to_string(BlkLen); + return suite_name.c_str(); + } +}; + +// Fixture to register parameterized tests quickly +template +class SQLutGemm2BitShortExecuteTest : public MlasTestFixture> { + public: + explicit SQLutGemm2BitShortExecuteTest(size_t M, size_t N, size_t K, + bool with_threadpool, bool symmetric, bool with_bias) + : M_(M), N_(N), K_(K), with_threadpool_(with_threadpool), symmetric_(symmetric), with_bias_(with_bias) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(M_, N_, K_, with_threadpool_, symmetric_, with_bias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, bool with_threadpool, bool symmetric, bool with_bias) { + if (!MlasIsLUTGemmAvailable(BlkBitWidth, BlkLen)) { + return 0; + } + if (M < BlkLen || K < BlkLen || N < BlkLen) { + return 0; + } + + std::stringstream ss; + ss << (with_threadpool ? "Threaded" : "SingleThread") + << "/isSymmetric" << symmetric + << "/M" << M << "xN" << N << "xK" << K + << "/hasBias" << with_bias; + auto name = ss.str(); + + testing::RegisterTest( + MlasSQLutGemm2BitTest::GetTestSuiteName(), + name.c_str(), + nullptr, + name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture>* { + return new SQLutGemm2BitShortExecuteTest(M, N, K, with_threadpool, symmetric, with_bias); + }); + return 1; + } + + static size_t RegisterAll() { + size_t count = 0; + for (bool with_threadpool : {false, true}) { + for (bool symmetric : {false, true}) { + for (size_t b = 256; b <= 512; b <<= 1) { + count += RegisterSingleTest(b, b, b, with_threadpool, symmetric, false); + count += RegisterSingleTest(b, b, b, with_threadpool, symmetric, true); + } + count += RegisterSingleTest(64, 128, 128, with_threadpool, symmetric, false); + count += RegisterSingleTest(128, 256, 256, with_threadpool, symmetric, true); + } + } + return count; + } + + private: + size_t M_, N_, K_; + bool with_threadpool_, symmetric_, with_bias_; +}; + +static size_t SQLutGemmRegisterAll() { + size_t count = 0; + // Instantiate only 2-bit for now + count += SQLutGemm2BitShortExecuteTest<2, 16>::RegisterAll(); + count += SQLutGemm2BitShortExecuteTest<2, 32>::RegisterAll(); + count += SQLutGemm2BitShortExecuteTest<2, 64>::RegisterAll(); + count += SQLutGemm2BitShortExecuteTest<2, 128>::RegisterAll(); + return count; +} + +static UNUSED_VARIABLE bool lut_added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return SQLutGemmRegisterAll(); + } + return 0; + }); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 426715b7138df..47002dd7eea72 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -24,8 +24,6 @@ static constexpr const char* ComputeTypeName(MLAS_QNBIT_GEMM_COMPUTE_TYPE Comput return "Fp32"; case SQNBIT_CompInt8: return "Int8"; - case TMAC: - return "TMAC"; default: return "unknown"; } @@ -52,9 +50,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { MatrixGuardBuffer BufferC; MatrixGuardBuffer BufferCReference; - // TMAC LUT related buffers - MatrixGuardBuffer 
BufferPackedQuantScalesZP; - void CallGemm(size_t M, size_t N, size_t K, @@ -291,10 +286,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } - if (ComputeType == TMAC) { - InitTMACKernelConfig(N, K, BlkBitWidth, BlkLen, QuantBZeroPoint != nullptr); - } - void* PackedQuantBDataWorkspace = nullptr; if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, !Symmetric, ComputeType); PackedQuantBDataSize > 0) { @@ -305,34 +296,18 @@ class MlasSQNBitGemmTest : public MlasTestBase { GetMlasThreadPool()); } - float* PackedQuantScalesZPWorkspace = nullptr; - if (ComputeType == TMAC) { - bool has_zp_input = QuantBZeroPoint != nullptr; - const auto PackedQuantScalesZPSize = MlasTMACPackQuantScalesAndZeroPointsSize(N, K, BlkBitWidth, has_zp_input); - PackedQuantScalesZPWorkspace = BufferPackedQuantScalesZP.GetBuffer(PackedQuantScalesZPSize); - MlasTMACPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, has_zp_input, PackedQuantScalesZPWorkspace, - QuantBScale, QuantBZeroPoint); - } - - if (ComputeType == TMAC) { - MlasTmac(A, BlkLen, QuantBData, PackedQuantScalesZPWorkspace, C, K, M, N, Threadpool); - - } else { - CallGemm(M, N, K, - A, /* lda */ K, - QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, - Bias, - C, /* ldc */ N, - Workspace, - ComputeType, - Threadpool); - } - + CallGemm(M, N, K, + A, /* lda */ K, + QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, + Bias, + C, /* ldc */ N, + Workspace, + ComputeType, + Threadpool); if (ComputeType == SQNBIT_CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else if (ComputeType == SQNBIT_CompInt8 || ComputeType == TMAC) { - // use same reference implementation for TMAC as CompInt8 + } else if (ComputeType == SQNBIT_CompInt8) { CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else { FAIL() << "Test is not implemented for compute type " @@ -387,9 +362,6 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); + // TODO: enable these test for 2bit development. + // count += SQNBitGemmShortExecuteTest<2, 16>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + // count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests();
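For readers tracing the 2-bit weight layout that the packing code earlier in this series builds on, the following standalone sketch (illustrative only, not part of the patch; the names, toy sizes, and plain `main` driver are all assumptions) reproduces just the first stage: splitting each 2-bit weight into its two bit planes and collecting g = 4 consecutive plane bits into one 4-bit LUT index per row, plane, and group.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    constexpr size_t bits = 2, g = 4;
    constexpr size_t N = 2, K = 8;  // toy shape; real shapes are much larger

    // Raw 2-bit weights, 4 per byte, K-major per row (matches the unpacked source layout).
    std::vector<uint8_t> quant_b((N * K * bits) / 8, 0);
    for (size_t i = 0; i < N * K; ++i) {
        uint8_t v = static_cast<uint8_t>(i % 4);  // some arbitrary 2-bit values
        quant_b[i / 4] |= static_cast<uint8_t>(v << ((i % 4) * bits));
    }

    // One 4-bit group index per (row, bit plane, group of g along K).
    std::vector<uint8_t> planes(N * bits * (K / g), 0);
    for (size_t n = 0; n < N; ++n) {
        for (size_t k = 0; k < K; ++k) {
            size_t idx = n * K + k;
            uint8_t v = (quant_b[idx / 4] >> ((idx % 4) * bits)) & 0x3;
            for (size_t b = 0; b < bits; ++b) {
                // Bit b of this weight lands in plane b, group k / g, position k % g.
                planes[n * bits * (K / g) + b * (K / g) + k / g] |=
                    static_cast<uint8_t>(((v >> b) & 1) << (k % g));
            }
        }
    }

    for (uint8_t p : planes) std::printf("%x ", static_cast<unsigned>(p));
    std::printf("\n");
    return 0;
}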