
Enable U8 KV caching in SDPA operator for ARM #33567

Open

ashwins990 wants to merge 4 commits into openvinotoolkit:master from MonakaResearch:ARM-SVE-Quant_u8-support-for-SDPA

Conversation

@ashwins990 (Contributor) commented Jan 13, 2026

[About]

This PR enables u8 KV cache precision for the SDPA operator and optimizes it with NEON and SVE kernels.

  • Improves performance over the OSS master version (where only the reference implementation is available) by 27%.

  • However, for smaller models such as TinyLlama-1.1B in the single-inference case, we are 2.7% slower than the non-quantized f16 cache precision due to the additional overhead of quantization and dequantization.

  • The benefit of u8 quantization shows up only when inference is more memory bound: we see speedups of around 3-5% when inferencing an int8-quantized Llama-70B model in the single-inference case.

  • Therefore, even though we achieve a 27% speedup over the reference implementation, we assume the general case to be compute bound and currently keep the default as F16.

  • As models get larger, and in multi-batch scenarios, setting the KV cache precision to "u8" gives a significant boost at the inference level (see the illustrative sketch at the end of this description).

| OSS ref impl - u8 | This PR         |
| ----------------- | --------------- |
| 10.8 tokens/sec   | 13.7 tokens/sec |

Single-inference performance on a Llama2-7B model on a 32-core Graviton machine. The values are in TPS (tokens per second).

This work was contributed by @ashwins990 & @abhijain1204fujitsu.
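For illustration, below is a minimal scalar sketch of per-group u8 quantization of one KV-cache row, assuming one (scale, zero-point) pair per group. It is not the PR's NEON/SVE kernel code; the function and parameter names are made up for this example.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Illustrative only: per-group asymmetric u8 quantization of one KV-cache row.
// Each group of `group_size` floats gets its own scale and zero point.
void quantize_row_u8_grouped(const float* src, uint8_t* dst, size_t n, size_t group_size,
                             float* scales, float* zero_points) {
    for (size_t g = 0; g * group_size < n; ++g) {
        const size_t begin = g * group_size;
        const size_t end = std::min(begin + group_size, n);
        // Find the per-group min/max (the PR vectorizes this step with NEON/SVE).
        float lo = src[begin], hi = src[begin];
        for (size_t i = begin; i < end; ++i) {
            lo = std::min(lo, src[i]);
            hi = std::max(hi, src[i]);
        }
        const float scale = (hi - lo) / 255.0f;
        const float zp = (scale == 0.0f) ? 0.0f : -lo / scale;
        scales[g] = scale;
        zero_points[g] = zp;
        // Quantize: q = clamp(round(x / scale + zp), 0, 255).
        for (size_t i = begin; i < end; ++i) {
            const float q = (scale == 0.0f) ? zp : src[i] / scale + zp;
            dst[i] = static_cast<uint8_t>(std::lround(std::clamp(q, 0.0f, 255.0f)));
        }
        // The read path dequantizes with the inverse mapping: x ≈ (q - zp) * scale.
    }
}
```

The extra min/max scan on write plus the dequantization on the read path is the overhead mentioned in the bullets above; it only pays off once the attention step becomes memory bound.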

@ashwins990 ashwins990 requested review from a team as code owners January 13, 2026 05:06
@github-actions github-actions bot added the category: CPU (OpenVINO CPU plugin) and category: build (OpenVINO cmake script / infra) labels Jan 13, 2026
@sys-openvino-ci sys-openvino-ci added the ExternalPR External contributor label Jan 13, 2026
@maxnick maxnick added the platform: arm OpenVINO on ARM / ARM64 label Jan 13, 2026
@maxnick maxnick requested a review from alvoron January 13, 2026 09:01
@maxnick maxnick requested a review from Copilot January 13, 2026 09:01
@maxnick (Contributor) commented Jan 13, 2026

@alvoron, could you please review?

Copilot AI left a comment

Pull request overview

This PR enables U8 (uint8) key-value cache precision for the SDPA (Scaled Dot-Product Attention) operator on ARM architectures and provides optimized implementations using NEON and SVE instructions. The change improves performance over the reference implementation by 27% while maintaining memory efficiency through quantization, though it incurs a 2.7% overhead compared to F16 for smaller, compute-bound models.

Changes:

  • Added U8 KV cache quantization/dequantization support with ARM NEON and SVE optimizations
  • Implemented specialized dot product and accumulation functions for U8 precision with grouped quantization
  • Extended CMake build configuration to include NEON_FP16 compilation target

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp | Adds U8 KV cache support with optimized SIMD implementations for dot products and value accumulation |
| src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp | Implements ARM NEON/SVE optimized min/max finding for quantization operations |
| src/plugins/intel_cpu/CMakeLists.txt | Adds NEON_FP16 architecture target for cross-compilation |


Comment on lines 902 to 903
svfloat16_t a0 = svld1_f16(pg_b16, _a + i);
svfloat16_t a1 = svld1_f16(pg_b16, _a + i + offset + svcnth());

Copilot AI Jan 13, 2026


The variable a1 loads from _a + i + offset + svcnth() but should load from _a + offset + i + svcnth() to maintain consistent indexing with the corresponding b1 load on line 905 and the usage pattern throughout this loop.

Suggested change:
- svfloat16_t a0 = svld1_f16(pg_b16, _a + i);
- svfloat16_t a1 = svld1_f16(pg_b16, _a + i + offset + svcnth());
+ svfloat16_t a0 = svld1_f16(pg_b16, _a + offset + i);
+ svfloat16_t a1 = svld1_f16(pg_b16, _a + offset + i + svcnth());

Comment on lines 1085 to 1090
size_t offset = group_id * group_size;
float16_t group_scale = *(scale + group_id * 2);
float16_t group_zp = *(zp + group_id * 2);
while (group_id < n / group_size) {
float16_t group_sum = 0.0f;
i = 0;

Copilot AI Jan 13, 2026


Variables offset, group_scale, and group_zp are initialized before the while loop but never updated inside it. These should be moved inside the loop body after line 1090 to ensure they are recalculated for each group iteration.

Suggested change:
- size_t offset = group_id * group_size;
- float16_t group_scale = *(scale + group_id * 2);
- float16_t group_zp = *(zp + group_id * 2);
- while (group_id < n / group_size) {
-     float16_t group_sum = 0.0f;
-     i = 0;
+ while (group_id < n / group_size) {
+     float16_t group_sum = 0.0f;
+     i = 0;
+     const size_t offset = group_id * group_size;
+     const float16_t group_scale = *(scale + group_id * 2);
+     const float16_t group_zp = *(zp + group_id * 2);

if constexpr (std::is_same_v<T, ov::float16>) {
auto v_max = vdupq_n_f16(max);
auto v_min = vdupq_n_f16(min);
for (; i + 8 < n; i += 8) {

Copilot AI Jan 13, 2026


The loop condition should be i + 8 <= n instead of i + 8 < n to process all complete 8-element vectors and be consistent with the float32 version on line 150 which uses i + 4 <= n.

Suggested change:
- for (; i + 8 < n; i += 8) {
+ for (; i + 8 <= n; i += 8) {
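For example, with n = 16 and i starting at 0, the condition i + 8 < n stops after the first iteration and leaves elements 8..15 to the scalar tail loop, whereas i + 8 <= n processes both full 8-element vectors.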

A contributor commented:
@maxnick, @alvoron: code resolving the issues suggested by Copilot has been pushed.
Kindly review the PR.

@@ -118,6 +118,57 @@ void find_minmax(const T* src, size_t n, float& min, float& max) {
hmin(v0_min);
max = _mm256_cvtss_f32(v0_max);
min = _mm256_cvtss_f32(v0_min);
#elif defined(OPENVINO_ARCH_ARM64)
A contributor commented:
Question to @maxnick: do we need to add a comment noting that the ARM behavior differs from x86? The ARM path uses an fp16 accumulator, while the x86 path uses fp32.

A contributor replied:
A comment would definitely be helpful.
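For context, below is a minimal sketch of a NEON fp16 min/max scan of the kind discussed here, assuming __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is available and n > 0; it is illustrative only and not the PR's exact kernel. Note the accumulators stay in fp16, unlike the AVX2 path in the hunk above, which reduces in fp32.

```cpp
#include <arm_neon.h>
#include <cstddef>

// Illustrative NEON fp16 min/max scan (not the PR's exact code).
void find_minmax_f16_sketch(const float16_t* src, size_t n, float16_t& min, float16_t& max) {
    float16x8_t v_min = vdupq_n_f16(src[0]);
    float16x8_t v_max = vdupq_n_f16(src[0]);
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        const float16x8_t v = vld1q_f16(src + i);
        v_min = vminq_f16(v_min, v);
        v_max = vmaxq_f16(v_max, v);
    }
    // Horizontal reductions of the vector accumulators.
    min = vminvq_f16(v_min);
    max = vmaxvq_f16(v_max);
    // Scalar tail for the last n % 8 elements.
    for (; i < n; ++i) {
        min = src[i] < min ? src[i] : min;
        max = src[i] > max ? src[i] : max;
    }
}
```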

@maxnick (Contributor) commented Jan 23, 2026

build_jenkins

@maxnick maxnick added this to the 2026.1 milestone Jan 23, 2026
@alvoron (Contributor) commented Jan 23, 2026

@ashwins990 could you please rebase the branch to pick up some fixes required to pass CI?

@alvoron (Contributor) commented Jan 26, 2026

@abhijain1204fujitsu could you please cover these changes by functional tests?

@abhijain1204fujitsu abhijain1204fujitsu force-pushed the ARM-SVE-Quant_u8-support-for-SDPA branch from 32fa83a to 930d7b0 Compare January 27, 2026 03:51
@abhijain1204fujitsu (Contributor) commented:
Hi @alvoron,
I have pushed changes to rebase the PR and resolve the CI issues.
Kindly complete the review and merge the PR.

@alvoron (Contributor) commented Jan 27, 2026

build_jenkins

@ashwins990 (Contributor, Author) commented:
> @abhijain1204fujitsu could you please cover these changes by functional tests?

Hi @alvoron,
Enabling the test case for the SDPA u8 KV cache results in failing tests. I have enabled it here.

I believe the reason for the failure is: with FP16 inference precision and a u8 KV cache, we expect some values to overflow. We allow this initially and account for it with the detail::handle_inf_value function. While the end-to-end inference output is as expected, the numerical output of this operator differs between F32 and F16 inference precision, which causes the test case to fail when enabled.

Is there any way to handle such scenarios, where the reference has different behaviour?
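[Editorial note: IEEE half precision has a maximum finite value of 65504, so intermediate values that remain finite under F32 inference can round to infinity under F16; once the inf-handling path is taken, the per-operator outputs can legitimately diverge between the two inference precisions even though the end-to-end result stays acceptable.]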

@alvoron (Contributor) commented Jan 28, 2026

> Is there any way to handle such scenarios, where the reference has different behaviour?

We can try to tune the threshold in ConcatSDPTest::SetUp() in the m_forceKVU8 branch:

rel_threshold = 0.05f;
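For reference, relaxing rel_threshold loosens a relative-error comparison of the kind sketched below; this is a simplified illustration, not the actual OpenVINO test-utils comparator, and the function name is made up.

```cpp
#include <algorithm>
#include <cmath>

// Simplified relative-error check: a larger rel_threshold tolerates a larger
// deviation of `actual` from `expected`, proportional to |expected|.
bool within_rel_threshold(float actual, float expected, float rel_threshold) {
    const float denom = std::max(std::abs(expected), 1e-6f);  // guard near-zero references
    return std::abs(actual - expected) <= rel_threshold * denom;
}
```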

@abhijain1204fujitsu abhijain1204fujitsu force-pushed the ARM-SVE-Quant_u8-support-for-SDPA branch from 930d7b0 to b10d1c6 Compare February 5, 2026 07:26

Labels

category: build (OpenVINO cmake script / infra), category: CPU (OpenVINO CPU plugin), ExternalPR (External contributor), platform: arm (OpenVINO on ARM / ARM64)


5 participants