Enable U8 KV caching in SDPA operator for ARM #33567
ashwins990 wants to merge 4 commits into openvinotoolkit:master
Conversation
@alvoron, could you please review?
Pull request overview
This PR enables U8 (uint8) key-value cache precision for the SDPA (Scaled Dot-Product Attention) operator on ARM architectures and provides optimized implementations using NEON and SVE instructions. The change improves performance over the reference implementation by 27% while maintaining memory efficiency through quantization, though it incurs a 2.7% overhead compared to F16 for smaller, compute-bound models.
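As background for the comments below, here is a minimal scalar sketch of the grouped u8 quantization scheme being enabled (the group layout, function name, and scale/zero-point storage are illustrative assumptions, not the PR's actual kernels):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Illustrative sketch: quantize one row of the KV cache to u8 in groups of
// `group_size` elements (assumes n is a multiple of group_size). Each group
// stores its own scale and zero-point so that dequantized = (q - zp) * scale.
void quantize_row_u8(const float* src, uint8_t* dst,
                     float* scales, float* zps,
                     size_t n, size_t group_size) {
    for (size_t g = 0; g < n / group_size; ++g) {
        const float* grp = src + g * group_size;
        float mn = grp[0];
        float mx = grp[0];
        for (size_t i = 1; i < group_size; ++i) {   // the find_minmax step
            mn = std::min(mn, grp[i]);
            mx = std::max(mx, grp[i]);
        }
        const float scale = (mx - mn) / 255.0f;
        const float zp = (scale != 0.0f) ? -mn / scale : 0.0f;
        scales[g] = scale;
        zps[g] = zp;
        for (size_t i = 0; i < group_size; ++i) {
            const float q = (scale != 0.0f) ? grp[i] / scale + zp : 0.0f;
            dst[g * group_size + i] =
                static_cast<uint8_t>(std::clamp(q, 0.0f, 255.0f));
        }
    }
}
```

As the file descriptions below indicate, the PR's NEON/SVE kernels vectorize the min/max search, the quantization itself, and the dequantizing dot-product and accumulation steps.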
Changes:
- Added U8 KV cache quantization/dequantization support with ARM NEON and SVE optimizations
- Implemented specialized dot product and accumulation functions for U8 precision with grouped quantization
- Extended CMake build configuration to include NEON_FP16 compilation target
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp | Adds U8 KV cache support with optimized SIMD implementations for dot products and value accumulation |
| src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp | Implements ARM NEON/SVE optimized min/max finding for quantization operations |
| src/plugins/intel_cpu/CMakeLists.txt | Adds NEON_FP16 architecture target for cross-compilation |
Referenced lines:

```cpp
svfloat16_t a0 = svld1_f16(pg_b16, _a + i);
svfloat16_t a1 = svld1_f16(pg_b16, _a + i + offset + svcnth());
```

The variable a1 loads from _a + i + offset + svcnth() but should load from _a + offset + i + svcnth() to maintain consistent indexing with the corresponding b1 load on line 905 and the usage pattern throughout this loop.

Suggested change:

```cpp
svfloat16_t a0 = svld1_f16(pg_b16, _a + offset + i);
svfloat16_t a1 = svld1_f16(pg_b16, _a + offset + i + svcnth());
```
Referenced lines:

```cpp
size_t offset = group_id * group_size;
float16_t group_scale = *(scale + group_id * 2);
float16_t group_zp = *(zp + group_id * 2);
while (group_id < n / group_size) {
    float16_t group_sum = 0.0f;
    i = 0;
```

Variables offset, group_scale, and group_zp are initialized before the while loop but never updated inside it. These should be moved inside the loop body after line 1090 to ensure they are recalculated for each group iteration.

Suggested change:

```cpp
while (group_id < n / group_size) {
    float16_t group_sum = 0.0f;
    i = 0;
    const size_t offset = group_id * group_size;
    const float16_t group_scale = *(scale + group_id * 2);
    const float16_t group_zp = *(zp + group_id * 2);
```
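For context on this suggestion, the corrected control flow looks roughly like the following scalar sketch (the function name and float accumulator are illustrative; the stride of 2 on scale/zp mirrors the interleaved layout in the snippet above):

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative sketch of a per-group dequantizing dot product: offset, scale,
// and zero-point must be re-read at the top of every group iteration.
float dot_u8_grouped(const float* a, const uint8_t* b,
                     const float* scale, const float* zp,
                     size_t n, size_t group_size) {
    float result = 0.0f;
    for (size_t group_id = 0; group_id < n / group_size; ++group_id) {
        const size_t offset = group_id * group_size;    // recomputed per group
        const float group_scale = scale[group_id * 2];  // interleaved layout
        const float group_zp = zp[group_id * 2];
        float group_sum = 0.0f;
        for (size_t i = 0; i < group_size; ++i) {
            group_sum += a[offset + i] * (static_cast<float>(b[offset + i]) - group_zp);
        }
        result += group_sum * group_scale;
    }
    return result;
}
```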
Referenced lines:

```cpp
if constexpr (std::is_same_v<T, ov::float16>) {
    auto v_max = vdupq_n_f16(max);
    auto v_min = vdupq_n_f16(min);
    for (; i + 8 < n; i += 8) {
```

The loop condition should be i + 8 <= n instead of i + 8 < n to process all complete 8-element vectors and be consistent with the float32 version on line 150, which uses i + 4 <= n.

Suggested change:

```cpp
for (; i + 8 <= n; i += 8) {
```
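To illustrate why the inclusive bound matters, here is a scalar stand-in for the vectorized min/max loop (the lane width of 8 mirrors the fp16 NEON path; intrinsics are omitted and a scalar tail is assumed):

```cpp
#include <algorithm>
#include <cstddef>

// With `i + 8 < n`, the last complete 8-element block is pushed into the
// scalar tail whenever n is an exact multiple of 8; `i + 8 <= n` keeps every
// complete block in the (vectorized) main loop. Assumes n >= 1.
void find_minmax_sketch(const float* src, size_t n, float& mn, float& mx) {
    mn = src[0];
    mx = src[0];
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {          // stands in for the NEON main loop
        for (size_t j = 0; j < 8; ++j) {
            mn = std::min(mn, src[i + j]);
            mx = std::max(mx, src[i + j]);
        }
    }
    for (; i < n; ++i) {                  // scalar tail for the remainder
        mn = std::min(mn, src[i]);
        mx = std::max(mx, src[i]);
    }
}
```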
Referenced hunk:

```diff
@@ -118,6 +118,57 @@ void find_minmax(const T* src, size_t n, float& min, float& max) {
hmin(v0_min);
max = _mm256_cvtss_f32(v0_max);
min = _mm256_cvtss_f32(v0_min);
#elif defined(OPENVINO_ARCH_ARM64)
```
question to @maxnick: do we need to add a comment that ARM behavior differs from x86? The ARM path uses an fp16 accumulator while the x86 path uses fp32.
A comment would definitely be helpful.
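A small standalone illustration of the accumulator difference being discussed (didactic only, assuming an AArch64 toolchain where `__fp16` is available): summing many fp16 values in an fp16 accumulator loses low-order contributions much earlier than accumulating in fp32.

```cpp
#include <cstdio>

int main() {
    // __fp16 is the AArch64 GCC/Clang storage type for IEEE fp16 (compiler extension).
    __fp16 acc16 = 0.0f;
    float acc32 = 0.0f;
    for (int i = 0; i < 4096; ++i) {
        acc16 = static_cast<__fp16>(acc16 + static_cast<__fp16>(0.9f));  // fp16-style accumulation (ARM path)
        acc32 += 0.9f;                                                   // fp32 accumulation (x86 path)
    }
    // Once the fp16 running sum passes ~2048, adding 0.9 no longer changes it,
    // while the fp32 sum stays close to the exact 3686.4.
    std::printf("fp16 acc = %f, fp32 acc = %f\n", static_cast<float>(acc16), acc32);
    return 0;
}
```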
build_jenkins
@ashwins990 could you please rebase the branch to pick up some fixes required to pass CI?
@abhijain1204fujitsu could you please cover these changes with functional tests?
Force-pushed 32fa83a to 930d7b0
Hi @alvoron,
build_jenkins |
Hi @alvoron, I believe the reason for the failure is: Is there any way to handle such scenarios, where the reference has different behaviour?
We can try to tune threshold in |
Force-pushed 930d7b0 to b10d1c6
[About]
This PR enables u8 kv cache precision for the SDPA operator and optimizes it with NEON and SVE.
It improves performance over the OSS master version [where only the reference implementation is available] by 27%.
However, we are 2.7% slower than the non-quantized f16 cache precision, due to the additional overhead of quantization and dequantization, for smaller models like TinyLlama-1.1B in the single-inference case.
The performance benefit from u8 quantization shows up only when inference is memory bound. We see speedups of around 3-5% when inferencing an int8-quantized LLama-70B model in the single-inference case.
Therefore, even though we achieve a 27% speedup over the reference implementation, we assume the general case to be compute bound and currently keep the default as F16.
As models get larger and in multi-batch scenarios, setting the kv cache precision to "u8" gives a significant boost at the inference level (see the sketch below).
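For reference, selecting the u8 KV cache at compile time would look roughly like this (a sketch assuming the standard ov::hint::kv_cache_precision property exposed by the CPU plugin; the model path is a placeholder):

```cpp
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder path

    // Request u8 KV cache precision; the default remains f16 per this PR.
    auto compiled = core.compile_model(
        model, "CPU", ov::hint::kv_cache_precision(ov::element::u8));

    auto request = compiled.create_infer_request();
    // ... set inputs and run generation as usual ...
    return 0;
}
```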
Single-inference performance on the LLAMA2-7B model on a 32c Graviton machine.
The values are in TPS [tokens per second].
This work was contributed by @ashwins990 & @abhijain1204fujitsu