35 commits
5484560
init code structure for matmul 2 bits
liqunfu Jan 30, 2025
8c1cfe1
add and pass q4dq tests for q2bit - rename file and test name later
liqunfu Jan 31, 2025
f6f22e3
some fixes
liqunfu Jan 31, 2025
3e1a951
add apis to neon and other avxs
liqunfu Feb 3, 2025
0130061
fix neon build
liqunfu Feb 3, 2025
b4aad01
disable 2bit test
liqunfu Feb 3, 2025
ff531cb
2 bit quantize to support model builder
liqunfu Mar 7, 2025
6849ea2
Merge remote-tracking branch 'msft/main' into carzh/bitnet-reverse-la…
carzh Jul 16, 2025
e85431e
fix compile errors
carzh Jul 17, 2025
9642740
resolve build failure update
carzh Jul 18, 2025
892222a
2 bits check
HectorSVC Jul 23, 2025
07b7f3f
fixed bug causing int8 tests to fail
Jul 25, 2025
5fb2edd
Merge remote-tracking branch 'origin/main' into carzh/bitnet-reverse-…
carzh Aug 7, 2025
493ebd1
lintrunner
carzh Aug 7, 2025
b4b143f
prepack wip -- not prepacking b data because dispatch to check for ml…
carzh Aug 13, 2025
534b8e6
fixed dispatch issue, added acc level 4 tests, and now running into a…
carzh Aug 15, 2025
70d6588
deep sigh
Sep 2, 2025
ad2572b
builds somehow
Sep 4, 2025
b312815
update
Sep 10, 2025
bfeac34
udpate
Sep 16, 2025
a5de108
Implement Pre Packing of qweight for tmac
vraspar Oct 1, 2025
7ff8218
Implement Pre packing for Scales and zero points
vraspar Oct 6, 2025
6d8e8ec
Transform zero points before interleaving
vraspar Oct 6, 2025
5d19daf
Initial implementation of tmac kernel config
vraspar Oct 7, 2025
c600056
Move pre packing scales and zp code to qlutgemm and use tmac_params
vraspar Oct 8, 2025
5cf99e6
update
Oct 13, 2025
f9a9b47
bug fixes
Oct 16, 2025
5687e5e
Fix bug in scale unpacking
vraspar Oct 21, 2025
6f08418
Fix issues with TMAC GEMM kernels and remove hard coded variables
vraspar Oct 28, 2025
6191aad
Fix bug in LUT table generation
vraspar Oct 31, 2025
f2de776
Fix casting issue
vraspar Nov 10, 2025
9ef6d75
add session option and clean up
vraspar Nov 13, 2025
59c0055
Refactor QNBit GEMM Implementation for AVX2
vraspar Dec 1, 2025
457cfa3
Refactor dispatch
vraspar Dec 2, 2025
bdb2982
Add test cases
vraspar Dec 2, 2025
4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/qnbitgemm.h
${MLAS_SRC_DIR}/qnbitgemm.cpp
${MLAS_SRC_DIR}/qlutgemm.h
${MLAS_SRC_DIR}/qlutgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
@@ -200,6 +202,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
@@ -646,6 +649,7 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
@@ -374,6 +374,12 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

// Use LUT-based GEMM for quantized models when available.
// Option values:
// - "0": Do not use LUT-based GEMM. [DEFAULT]
// - "1": Use LUT-based GEMM when available.
static const char* const kOrtSessionOptionsMlasLUTGemm = "mlas.use_lut_gemm";

// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
// Refer to MatMulNBits op schema for more details.
// If not provided, default is 4.
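For context on how the new kOrtSessionOptionsMlasLUTGemm key is expected to be consumed, a minimal sketch of toggling it from application code follows. Only the "mlas.use_lut_gemm" key and its "0"/"1" values come from the diff above; the model path and the rest of the session setup are placeholders.

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "lut-gemm-demo");
  Ort::SessionOptions session_options;

  // Opt in to the LUT-based GEMM path; the default ("0") leaves it disabled.
  session_options.AddConfigEntry("mlas.use_lut_gemm", "1");

  // Placeholder model; any model containing MatMulNBits nodes would exercise the new path.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}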
140 changes: 129 additions & 11 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -15,7 +15,10 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "contrib_ops/cpu/quantization/matmul_nbits_helper.h"
#include "core/platform/threadpool.h"
#include "core/util/thread_utils.h"

namespace onnxruntime {
namespace contrib {
@@ -39,12 +42,19 @@ typedef enum {
Level2, /*!< input fp16, accumulator fp16 */
Level3, /*!< input bf16, accumulator fp32 */
Level4, /*!< input int8, accumulator int32 */
Level5, /*!< input uint8, use TMAC LUT-based GEMM */
} ACCURACY_LEVEL;

// T: A data type.
template <typename T>
MLAS_QNBIT_GEMM_COMPUTE_TYPE
GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) {

// TODO(vraspar): check against session option
if (MlasIsLUTGemmAvailable(nbits, block_size)) {
return TMAC;
}

// For Fp32, only accuracy level 1 or 4 makes sense.
// non-ARM CPU converts Fp16 to Fp32.
// By converting Fp32 to Fp16, precision becomes worse. And due to the casting,
@@ -54,6 +64,7 @@ GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) {
return SQNBIT_CompInt8;
}


return SQNBIT_CompFp32;
}

@@ -100,6 +111,7 @@ class MatMulNBits final : public OpKernel {
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()},
has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()},
prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLUTGemm) == "1"},
compute_type_{GetComputeType<T1>(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
const auto& node = info.node();
auto input_defs = node.InputDefs();
@@ -116,6 +128,7 @@
"Only 2b, 4b and 8b quantization is supported for MatMulNBits op, additional bits support is planned.");
const Tensor* tensor_zero_point = nullptr;
has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point);
prefer_lut_gemm_ = true;  // TODO: temporary override for testing; honor the session option before merging
}

Status Compute(OpKernelContext* context) const override;
@@ -135,11 +148,14 @@ class MatMulNBits final : public OpKernel {
const bool has_g_idx_;
const bool has_bias_;
bool scales_are_packed_{false};
bool prefer_lut_gemm_{false};
const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_;
bool has_unquantized_zero_point_{false};
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_{};
size_t packed_b_size_{0};
IAllocatorUniquePtr<float> packed_scales_zp_{};
size_t packed_scales_zp_size_{0};
IAllocatorUniquePtr<float> scales_fp32_{};
IAllocatorUniquePtr<float> bias_fp32_{};

@@ -167,6 +183,15 @@ class MatMulNBits final : public OpKernel {
AllocatorPtr& allocator,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const;

Status ComputeBPackedLUT(const Tensor* a,
const Tensor* scales,
const Tensor* zero_points,
const Tensor* bias,
Tensor* y,
AllocatorPtr& allocator,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const;
};

template <typename T1>
@@ -175,26 +200,66 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
/*out*/ PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
is_packed = false;
if (has_g_idx_ || has_unquantized_zero_point_) {
// if (has_g_idx_ || has_unquantized_zero_point_)
// TODO: condition temporarily relaxed so MatMulNBits can be tested
if (has_g_idx_) {
return Status::OK();
}

if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
if (prefer_lut_gemm_ && !MlasIsLUTGemmAvailable(nbits_, block_size_)) {
return Status::OK();
}

if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && compute_type_ != TMAC) {
return Status::OK();
}

// Create a temporary threadpool for parallel packing
// This is used during model load time to speed up weight prepacking
std::unique_ptr<concurrency::ThreadPool> temp_threadpool;
concurrency::ThreadPool* threadpool_ptr = nullptr;

// Only create threadpool for operations that can benefit from it
if (prefer_lut_gemm_ || compute_type_ == SQNBIT_CompInt8) {
OrtThreadPoolParams tpo;
tpo.thread_pool_size = 4; // Small fixed-size pool; packing runs only once at model load
tpo.allow_spinning = false; // Don't spin during model load
tpo.auto_set_affinity = false;

temp_threadpool = concurrency::CreateThreadPool(
&Env::Default(),
tpo,
concurrency::ThreadPoolType::INTRA_OP);

threadpool_ptr = temp_threadpool.get();
}

if (input_idx == InputIndex::B) {

const Tensor* scales = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales);

packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
if (packed_b_size_ == 0) {
return Status::OK();
if (prefer_lut_gemm_) {
MlasInitLUTGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_);
packed_b_size_ = MlasLUTGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
if (packed_b_size_ == 0) {
return Status::OK();
}
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasLUTGemmPackQuantBData(N_, K_, nbits_, block_size_, static_cast<const std::byte*>(qptr), static_cast<std::byte*>(packed_b_.get()), threadpool_ptr);
} else {
packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
if (packed_b_size_ == 0) {
return Status::OK();
}
auto qptr = tensor.DataRaw();
auto scale_ptr = scales ? scales->DataRaw() : nullptr;
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr,
has_zp_input_, nullptr, threadpool_ptr);

}
auto qptr = tensor.DataRaw();
auto scale_ptr = scales ? scales->DataRaw() : nullptr;
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr,
has_zp_input_, nullptr, nullptr);
is_packed = true;
} else if (compute_type_ == SQNBIT_CompInt8) {
#ifdef MLAS_TARGET_AMD64_IX86
@@ -216,8 +281,26 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
is_packed = true;
}
#endif // MLAS_TARGET_ARM64
} else if (prefer_lut_gemm_) {
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
auto scales_ptr = tensor.Data<float>();
packed_scales_zp_size_ = MlasLUTPackScalesAndZeroPointsSize(N_, K_, block_size_, has_zp_input_);
packed_scales_zp_ = IAllocator::MakeUniquePtr<float>(alloc, packed_scales_zp_size_, true);

// TODO(vraspar): improve this logic block
if (has_zp_input_) {
const Tensor* zero_points = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points);
auto zero_points_ptr = zero_points->Data<uint8_t>();
MlasLUTPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, zero_points_ptr);
} else {
MlasLUTPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr);
}
}
}

// Threadpool will be automatically destroyed when temp_threadpool goes out of scope

return Status::OK();
}

@@ -282,6 +365,12 @@ Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*ou
is_packed = false;
}
#endif // MLAS_TARGET_AMD64_IX86
} else if (compute_type_ == TMAC) {
// TODO: handle fp16 scales
// TMAC packs scales and zero points together by interleaving them
}

return Status::OK();
@@ -293,14 +382,38 @@ Status MatMulNBits<T1>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>&
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;

if (input_idx == 1) {
if (input_idx == 1) {  // TODO(vraspar): Do we need a shared prepacked buffer for TMAC? Combine packing of weights and scales/ZP into one buffer?
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}

return Status::OK();
}

template <typename T1>
Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
const Tensor* scales,
const Tensor* zero_points,
const Tensor* bias,
Tensor* y,
AllocatorPtr& allocator,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const {
const auto* a_data = a->Data<T1>();
const auto* scales_data = scales == nullptr ? nullptr : scales->Data<T1>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* bias_data = bias == nullptr ? nullptr : bias->Data<T1>();
auto* y_data = y->MutableData<T1>();
const size_t batch_count = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
// TODO(vraspar): Should we batch it here?
//MlasInitLUTGemmKernelConfig(N, K, nbits_, block_size_, has_zp_input_);
MlasLUTGemm(a_data, block_size_, packed_b_.get(), packed_scales_zp_.get(), y_data, K, M, N, thread_pool);
return Status::OK();
}

template <typename T1>
Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
const Tensor* scales,
@@ -320,6 +433,7 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());

const size_t lda = helper.Lda(false);

IAllocatorUniquePtr<std::byte> workspace{};
@@ -760,6 +874,10 @@ Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
// If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while
// MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch()
// with B directly too.
if (prefer_lut_gemm_ && MlasIsLUTGemmAvailable(nbits_, block_size_)) {
return ComputeBPackedLUT(a, scales, zero_points, bias, y, allocator, thread_pool, helper);
}

if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper);
}
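The float path above packs scales and zero points through MlasLUTPackScalesAndZeroPoints, and the MLFloat16 branch notes that TMAC interleaves them, but the layout itself is not shown in this diff. A minimal sketch of one plausible interleaved layout follows; the contiguous scale/bias pairing, the unpacked one-byte-per-block zero points, and the folded -scale * zero_point term are illustrative assumptions, not the actual MLAS packing.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not MlasLUTPackScalesAndZeroPoints): store each block's scale
// immediately followed by a precomputed bias so the kernel can load both contiguously.
std::vector<float> PackScalesAndZeroPointsInterleaved(
    std::size_t N, std::size_t K, std::size_t block_size,
    const float* scales,           // [N * blocks_per_col]
    const uint8_t* zero_points) {  // [N * blocks_per_col], already unpacked to one byte per block
  const std::size_t blocks_per_col = (K + block_size - 1) / block_size;
  std::vector<float> packed(2 * N * blocks_per_col);
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t b = 0; b < blocks_per_col; ++b) {
      const std::size_t idx = n * blocks_per_col + b;
      packed[2 * idx + 0] = scales[idx];
      // Folding the zero point into a bias turns scale * (q - zp) into scale * q + bias.
      packed[2 * idx + 1] = -scales[idx] * static_cast<float>(zero_points[idx]);
    }
  }
  return packed;
}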
34 changes: 27 additions & 7 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -16,6 +16,15 @@
namespace onnxruntime {
namespace contrib {

template <class T, class zeroT>
void Dequantize2BitsKernelReOrder(
T* /*output*/, const uint8_t* /*quant_data*/, const T* /*scale_data*/,
const zeroT* /*zero_points*/, const int32_t* /*reorder_idx*/, int /*block_size*/,
int /*groups_per_threadblock*/, int /*total_groups*/, int /*out_rows*/, int /*out_cols*/,
int /*blockIdx_x*/, int /*threadIdx_x*/) {
// 2-bit dequantization with reordering is not implemented yet; fail loudly if this path is reached.
assert(false);
}

template <class T, class zeroT>
void Dequantize4BitsKernelReOrder(
T* output, const uint8_t* quant_data, const T* scale_data,
@@ -73,7 +82,7 @@ void Dequantize4BitsKernelReOrder(
}
}

template <typename inputT, typename zeroT>
template <typename inputT, typename zeroT, int qbits>
void DequantizeBlockwise(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input
@@ -95,24 +104,35 @@
pool, static_cast<std::ptrdiff_t>(blocks_per_grid),
[&](std::ptrdiff_t block_id) {
for (int j = 0; j < 256; j++) {
Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points,
reorder_idx, block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
if constexpr (qbits == 2) {
Dequantize2BitsKernelReOrder(output, quant_data, scales_data, zero_points,
reorder_idx, block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
} else {
Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points,
reorder_idx, block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
}
}
});
}

template void DequantizeBlockwise<float, uint8_t>(
template void DequantizeBlockwise<float, uint8_t, 2>(
float* output, const uint8_t* quant_data, const float* scales_data,
const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, uint8_t, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, float>(
template void DequantizeBlockwise<float, float, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, MLFloat16>(
template void DequantizeBlockwise<float, MLFloat16, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
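Since Dequantize2BitsKernelReOrder is still a stub, a plain reference for what block-wise 2-bit dequantization computes may help when filling it in. The sketch below ignores reorder_idx, assumes K is a multiple of 4 so each column is byte-aligned, takes zero points already unpacked to one byte per block, and uses 1 << (bits - 1) = 2 as the default zero point by analogy with the 4-bit path; none of these choices are taken from this diff.

#include <cstdint>

// Illustrative only: dequantize a 2-bit, column-wise block-quantized B of shape [N, K]
// into floats. Four 2-bit values are packed per byte, lowest bits first, and each run of
// block_size values along K shares one scale and one zero point.
void DequantizeBlockwise2Bit(float* output, const uint8_t* quant_data,
                             const float* scales, const uint8_t* zero_points,
                             int block_size, int K, int N) {
  const int blocks_per_col = (K + block_size - 1) / block_size;
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) {
      const int block = k / block_size;
      const int idx = n * K + k;  // element index; columns laid out one after another
      const int q = (quant_data[idx / 4] >> ((idx % 4) * 2)) & 0x3;  // unpack the 2-bit value
      const float scale = scales[n * blocks_per_col + block];
      const int zp = zero_points ? zero_points[n * blocks_per_col + block] : 2;
      output[idx] = scale * static_cast<float>(q - zp);
    }
  }
}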
@@ -6,7 +6,7 @@
namespace onnxruntime {
namespace contrib {

template <typename inputT, typename zeroT>
template <typename inputT, typename zeroT, int qbits = 4>
void DequantizeBlockwise(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input