Conversation

@qjia7 (Contributor) commented Dec 3, 2025

This pull request improves the WebGPU BERT attention implementation: it adds FlashAttention support for both BSNH and BNSH tensor layouts, enables FlashAttention for multi-batch scenarios, and ensures correct broadcasting and dispatch sizing for the attention bias and batch dimensions.

Key improvements include:

FlashAttention Support & Generalization:

  • Added support for both BSNH and BNSH tensor layouts by introducing the q_BNSH parameter and updating the shader code, program classes, and kernel logic to handle either layout correctly. This includes changes in the WGSL template and the C++ logic for offset calculations and program instantiation (see the sketch after these bullet lists). [1] [2] [3] [4] [5] [6] [7] [8]

  • Updated the `CanApplyFlashAttention` and `ApplyFlashAttention` logic to allow multi-batch operation by removing the restriction to batch size 1 and ensuring present key/value tensors are always created for FlashAttention. [1] [2] [3]

Batch & Bias Handling:

  • Modified dispatch group size calculations and uniform variables throughout the FlashAttention pipeline to properly account for batch size, ensuring correct parallelization for multi-batch scenarios. [1] [2] [3] [4] [5] [6] [7]

  • Added logic to extract and pass attention bias dimensions as uniforms for correct broadcasting in both the compute and shader code. [1] [2] [3] [4] [5]

Other Enhancements:

  • Improved handling of QKV format detection and generalized the code to support more format variants in `CopyKVCache`.

  • Updated includes and dependencies to ensure all necessary headers for FlashAttention are present.

These changes collectively make the WebGPU BERT attention implementation more robust, flexible, and performant across different tensor layouts and batch sizes.
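
For readers who want the shape math spelled out, here is a minimal Python sketch of the three mechanisms above. The helper names (`qkv_offset`, `dispatch_groups`, `bias_offset`) are illustrative stand-ins, not the actual C++/WGSL identifiers from this PR:

```python
def qkv_offset(b, n, s, num_heads, seq_len, head_size, q_BNSH):
    """Start offset of the head vector (b, n, s, :) in a flat Q/K/V buffer."""
    if q_BNSH:  # layout [batch, num_heads, seq_len, head_size]
        return ((b * num_heads + n) * seq_len + s) * head_size
    else:       # layout [batch, seq_len, num_heads, head_size]
        return ((b * seq_len + s) * num_heads + n) * head_size

def dispatch_groups(batch_size, num_heads, new_seq_len, tile_size):
    """One workgroup per (batch, head, query tile), so batch_size > 1 parallelizes."""
    num_tiles = (new_seq_len + tile_size - 1) // tile_size  # ceiling division
    return batch_size * num_heads * num_tiles

def bias_offset(b, n, bias_batch, bias_heads, seq_len, total_seq_len):
    """Start of the (b, n) slice of an attention bias [B or 1, N or 1, S, T].

    Taking each leading index modulo its bias dimension broadcasts size-1 dims.
    """
    return ((b % bias_batch) * bias_heads + (n % bias_heads)) * seq_len * total_seq_len
```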

phi-4-mm-vision.onnx

Before

| Kernel | Time (ms) | Percentage (%) |
| --- | ---: | ---: |
| Attention\|AttentionProbs | 159.66 | 11.14 |
| Attention\|VxAttentionScore | 122.56 | 8.55 |
| Attention\|InPlaceSoftmax | 51.83 | 3.62 |

After

| Kernel | Time (ms) | Percentage (%) |
| --- | ---: | ---: |
| Attention\|FlashAttention | 60.23 | 5.38 |

@xenova (Contributor) commented Dec 3, 2025

Cool! I'm busy trying to fix a bug with GQA (#25966), and this will help with correctness checks!

(FYI, you missed a spot here.)

@xenova (Contributor) commented Dec 3, 2025

Flash Attention seems to introduce significant errors in the results of my test suite. Here are the results and reproduction:

  1. Before this PR (FA still failing, non-FA succeeding)
Testing on Providers: ['CPUExecutionProvider', 'WebGpuExecutionProvider']
=================================================================================

--- Testing Provider: CPUExecutionProvider ---
[CPU] Prefill_ColdStart    | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 2.38e-07)
[CPU] Decode_Early         | In:1 Past:16 Total:17 -> ✅ PASS (Diff: 0.00e+00)
[CPU] Decode_Deep          | In:1 Past:64 Total:65 -> ✅ PASS (Diff: 0.00e+00)
[CPU] Speculative_Dec      | In:4 Past:20 Total:24 -> ✅ PASS (Diff: 1.79e-07)
[CPU] Batch_Prefill        | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 1.79e-07)
[CPU] Batch_Decode         | In:1 Past:32 Total:33 -> ✅ PASS (Diff: 0.00e+00)
🎉 CPUExecutionProvider: ALL SCENARIOS PASSED.

--- Testing Provider: WebGpuExecutionProvider ---
[WebGpu] Prefill_ColdStart    | In:16 Past:0 Total:16 -> ❌ FAIL (Diff: 0.99350)
[WebGpu] Decode_Early         | In:1 Past:16 Total:17 -> ❌ FAIL (Diff: 0.65833)
[WebGpu] Decode_Deep          | In:1 Past:64 Total:65 -> ❌ FAIL (Diff: 0.54569)
[WebGpu] Speculative_Dec      | In:4 Past:20 Total:24 -> ❌ FAIL (Diff: 0.67574)
[WebGpu] Batch_Prefill        | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 2.38e-07)
[WebGpu] Batch_Decode         | In:1 Past:32 Total:33 -> ✅ PASS (Diff: 0.00e+00)
⚠️ WebGpuExecutionProvider: FAILURES DETECTED.
  2. After this PR (FA still failing)
Testing on Providers: ['CPUExecutionProvider', 'WebGpuExecutionProvider']
=================================================================================

--- Testing Provider: CPUExecutionProvider ---
[CPU] Prefill_ColdStart    | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 2.38e-07)
[CPU] Decode_Early         | In:1 Past:16 Total:17 -> ✅ PASS (Diff: 0.00e+00)
[CPU] Decode_Deep          | In:1 Past:64 Total:65 -> ✅ PASS (Diff: 0.00e+00)
[CPU] Speculative_Dec      | In:4 Past:20 Total:24 -> ✅ PASS (Diff: 1.79e-07)
[CPU] Batch_Prefill        | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 1.79e-07)
[CPU] Batch_Decode         | In:1 Past:32 Total:33 -> ✅ PASS (Diff: 0.00e+00)
🎉 CPUExecutionProvider: ALL SCENARIOS PASSED.

--- Testing Provider: WebGpuExecutionProvider ---
[WebGpu] Prefill_ColdStart    | In:16 Past:0 Total:16 -> ❌ FAIL (Diff: 0.99350)
[WebGpu] Decode_Early         | In:1 Past:16 Total:17 -> ❌ FAIL (Diff: 0.65833)
[WebGpu] Decode_Deep          | In:1 Past:64 Total:65 -> ❌ FAIL (Diff: 0.54569)
[WebGpu] Speculative_Dec      | In:4 Past:20 Total:24 -> ❌ FAIL (Diff: 0.67574)
[WebGpu] Batch_Prefill        | In:16 Past:0 Total:16 -> ❌ FAIL (Diff: 0.99769)
[WebGpu] Batch_Decode         | In:1 Past:32 Total:33 -> ❌ FAIL (Diff: 0.65021)
⚠️ WebGpuExecutionProvider: FAILURES DETECTED.

Reproduction script:
```python
import numpy as np
from onnx import helper, TensorProto
import onnxruntime as ort
import dataclasses

# ==========================================
# 0. Test Harness Configuration
# ==========================================
@dataclasses.dataclass
class TestConfig:
    name: str
    batch_size: int
    seq_len: int        # Number of tokens to process NOW
    past_seq_len: int   # Number of tokens already in cache
    max_seq_len: int = 128

    @property
    def total_seq_len(self):
        return self.past_seq_len + self.seq_len

def create_session(model_def, provider):
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 3

    try:
        return ort.InferenceSession(
            model_def.SerializeToString(),
            sess_options=sess_options,
            providers=[provider]
        )
    except Exception as e:
        print(f"⚠️ Failed to create session for {provider}: {e}")
        return None

# ==========================================
# 1. The Core Comparison Function
# ==========================================
def run_test_case(cfg: TestConfig, provider: str):
    print(f"[{provider.replace('ExecutionProvider', '')}] {cfg.name: <20} | In:{cfg.seq_len} Past:{cfg.past_seq_len} Total:{cfg.total_seq_len}", end="")

    # Constants
    NUM_HEADS = 4
    HEAD_SIZE = 32
    HIDDEN_SIZE = NUM_HEADS * HEAD_SIZE

    # Shapes
    query_shape = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    kv_input_shape = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    past_shape_gqa = [cfg.batch_size, NUM_HEADS, cfg.max_seq_len, HEAD_SIZE]
    past_shape_mha = [cfg.batch_size, NUM_HEADS, cfg.past_seq_len, HEAD_SIZE]

    # ----------------------------------------
    # A. Build GQA Model (Full Buffer Mode)
    # ----------------------------------------
    gqa_node = helper.make_node(
        'GroupQueryAttention',
        inputs=['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seq_len'],
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='GQA_Node',
        do_rotary=0,
        kv_num_heads=NUM_HEADS,
        num_heads=NUM_HEADS,
        scale=1.0/np.sqrt(HEAD_SIZE),
    )

    gqa_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('seqlens_k', TensorProto.INT32, [cfg.batch_size]),
        helper.make_tensor_value_info('total_seq_len', TensorProto.INT32, []),
    ]

    gqa_graph = helper.make_graph([gqa_node], 'gqa-test', gqa_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    gqa_model = helper.make_model(gqa_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # B. Build MHA Model (Sliced Mode)
    # ----------------------------------------
    mha_node = helper.make_node(
        'MultiHeadAttention',
        inputs=['query', 'key', 'value', '', '', '', 'past_key', 'past_value', 'past_seq_len'],
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='MHA_Node',
        num_heads=NUM_HEADS,
        unidirectional=1,
        scale=1.0/np.sqrt(HEAD_SIZE),
    )

    mha_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_seq_len', TensorProto.INT32, [1]),
    ]

    mha_graph = helper.make_graph([mha_node], 'mha-test', mha_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    mha_model = helper.make_model(mha_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # C. Data Generation
    # ----------------------------------------
    np.random.seed(42 + cfg.seq_len + cfg.past_seq_len)

    query = np.random.rand(*query_shape).astype(np.float32)
    key = np.random.rand(*kv_input_shape).astype(np.float32)
    value = np.random.rand(*kv_input_shape).astype(np.float32)

    past_key_full = np.random.rand(*past_shape_gqa).astype(np.float32)
    past_value_full = np.random.rand(*past_shape_gqa).astype(np.float32)

    past_key_sliced = past_key_full[:, :, :cfg.past_seq_len, :]
    past_value_sliced = past_value_full[:, :, :cfg.past_seq_len, :]

    total_seq_len_scalar = np.array(cfg.total_seq_len, dtype=np.int32)
    past_seq_len_scalar = np.array([cfg.past_seq_len], dtype=np.int32)

    seqlens_k_arr = np.array([cfg.total_seq_len - 1] * cfg.batch_size, dtype=np.int32)

    # ----------------------------------------
    # D. Execution
    # ----------------------------------------
    # Create Sessions
    sess_gqa = create_session(gqa_model, provider)
    sess_mha = create_session(mha_model, provider)

    if sess_gqa is None or sess_mha is None:
        print(" ... SKIP (Session Init Failed)")
        return False

    # Run GQA
    res_gqa = sess_gqa.run(['output'], {
        'query': query, 'key': key, 'value': value,
        'past_key': past_key_full, 'past_value': past_value_full,
        'seqlens_k': seqlens_k_arr,
        'total_seq_len': total_seq_len_scalar
    })

    # Run MHA
    res_mha = sess_mha.run(['output'], {
        'query': query, 'key': key, 'value': value,
        'past_key': past_key_sliced, 'past_value': past_value_sliced,
        'past_seq_len': past_seq_len_scalar
    })

    # ----------------------------------------
    # E. Validation
    # ----------------------------------------
    out_gqa = res_gqa[0]
    out_mha = res_mha[0]

    diff = np.abs(out_gqa - out_mha)
    max_diff = diff.max()

    base_tol = 1e-5
    if max_diff < base_tol:
        print(f" -> ✅ PASS (Diff: {max_diff:.2e})")
        return True
    else:
        print(f" -> ❌ FAIL (Diff: {max_diff:.5f})")
        return False

# ==========================================
# 2. Main Execution Loop
# ==========================================
test_scenarios = [
    TestConfig(name="Prefill_ColdStart", batch_size=1, seq_len=16, past_seq_len=0),
    TestConfig(name="Decode_Early", batch_size=1, seq_len=1, past_seq_len=16),
    TestConfig(name="Decode_Deep", batch_size=1, seq_len=1, past_seq_len=64),
    TestConfig(name="Speculative_Dec", batch_size=1, seq_len=4, past_seq_len=20),
    TestConfig(name="Batch_Prefill", batch_size=4, seq_len=16, past_seq_len=0),
    TestConfig(name="Batch_Decode", batch_size=4, seq_len=1, past_seq_len=32),
]

# Detect Providers
available = ort.get_available_providers()
target_providers = ['CPUExecutionProvider']

# Check for WebGPU
if 'WebGpuExecutionProvider' in available:
    target_providers.append('WebGpuExecutionProvider')
else:
    print("⚠️ WebGpuExecutionProvider not found in this environment. Skipping GPU tests.")

print(f"Testing on Providers: {target_providers}")
print("=================================================================================")

for provider in target_providers:
    print(f"\n--- Testing Provider: {provider} ---")
    all_passed = True
    for config in test_scenarios:
        if not run_test_case(config, provider):
            all_passed = False

    if all_passed:
        print(f"🎉 {provider}: ALL SCENARIOS PASSED.")
    else:
        print(f"⚠️ {provider}: FAILURES DETECTED.")

@qjia7 (Contributor, Author) commented Dec 4, 2025

> --- Testing Provider: WebGpuExecutionProvider ---
> [WebGpu] Prefill_ColdStart | In:16 Past:0 Total:16 -> ❌ FAIL (Diff: 0.99350)
> [WebGpu] Decode_Early | In:1 Past:16 Total:17 -> ❌ FAIL (Diff: 0.65833)
> [WebGpu] Decode_Deep | In:1 Past:64 Total:65 -> ❌ FAIL (Diff: 0.54569)
> [WebGpu] Speculative_Dec | In:4 Past:20 Total:24 -> ❌ FAIL (Diff: 0.67574)
> [WebGpu] Batch_Prefill | In:16 Past:0 Total:16 -> ❌ FAIL (Diff: 0.99769)
> [WebGpu] Batch_Decode | In:1 Past:32 Total:33 -> ❌ FAIL (Diff: 0.65021)
> ⚠️ WebGpuExecutionProvider: FAILURES DETECTED.

I can reproduce them. Will investigate the reason and fix them in this PR. Thanks for reporting.

@xenova (Contributor) commented Dec 4, 2025

> I can reproduce them. Will investigate the reason and fix them in this PR. Thanks for reporting.

Thanks so much! :) I kept doing a deep-dive, and if you set

    NUM_HEADS = 2
    HEAD_SIZE = 8
    HIDDEN_SIZE = NUM_HEADS * HEAD_SIZE

you can get a very small minimal reproduction. What happens is that only half of the values in the last dimension are calculated correctly.


--- Testing Provider: WebGpuExecutionProvider ---
[WebGpu] Prefill_ColdStart    | In:3 Past:0 Total:3--------------------------------------------------
out_gqa=array([[[0.8407924 , 0.43310025, 0.7227552 , 0.66863596, 0.5237495 ,
         0.29788446, 0.57098573, 0.5739879 , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.4267663 , 0.25000745, 0.39151034, 0.51773685, 0.3504236 ,
         0.21153568, 0.39998588, 0.4887765 , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.49323374, 0.32594058, 0.640296  , 0.56106925, 0.54462355,
         0.43403137, 0.51209587, 0.44897917, 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ]]], dtype=float32)
--------------------------------------------------
out_mha=array([[[0.8407924 , 0.43310025, 0.7227552 , 0.66863596, 0.5237495 ,
         0.29788446, 0.57098573, 0.5739879 , 0.04442158, 0.45326525,
         0.9900071 , 0.22671595, 0.47324735, 0.07885394, 0.96738887,
         0.40757734],
        [0.4267663 , 0.25000745, 0.39151034, 0.51773685, 0.3504236 ,
         0.21153568, 0.39998588, 0.4887765 , 0.05752607, 0.39885843,
         0.5645982 , 0.3044865 , 0.60672575, 0.480561  , 0.94821554,
         0.25906754],
        [0.49323374, 0.32594058, 0.640296  , 0.56106925, 0.54462355,
         0.43403137, 0.51209587, 0.44897917, 0.14777677, 0.29784158,
         0.47044748, 0.4040054 , 0.5328809 , 0.57943636, 0.92094123,
         0.4904349 ]]], dtype=float32)
--------------------------------------------------
 -> ❌ FAIL (Diff: 0.99001)
⚠️ WebGpuExecutionProvider: FAILURES DETECTED.
==============================

As you can see, the first half is correct, but the second half is all zeros.

@qjia7 (Contributor, Author) commented Dec 4, 2025

@xenova Your failing cases should be fixed by 53a944a. Please give it a try. Thanks.

@xenova (Contributor) commented Dec 4, 2025

--- Testing Provider: CPUExecutionProvider ---
[CPU] Prefill_ColdStart    | In:3 Past:0 Total:3 -> ✅ PASS (Diff: 0.00e+00)
[CPU] Prefill_ColdStart    | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 1.19e-07)
[CPU] Decode_Early         | In:1 Past:16 Total:17 -> ✅ PASS (Diff: 0.00e+00)
[CPU] Decode_Deep          | In:1 Past:64 Total:65 -> ✅ PASS (Diff: 0.00e+00)
[CPU] Speculative_Dec      | In:4 Past:20 Total:24 -> ✅ PASS (Diff: 1.19e-07)
[CPU] Batch_Prefill        | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 1.79e-07)
[CPU] Batch_Decode         | In:1 Past:32 Total:33 -> ✅ PASS (Diff: 0.00e+00)
🎉 CPUExecutionProvider: ALL SCENARIOS PASSED.

--- Testing Provider: WebGpuExecutionProvider ---
[WebGpu] Prefill_ColdStart    | In:3 Past:0 Total:3 -> ✅ PASS (Diff: 0.00e+00)
[WebGpu] Prefill_ColdStart    | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 0.00e+00)
[WebGpu] Decode_Early         | In:1 Past:16 Total:17 -> ✅ PASS (Diff: 0.00e+00)
[WebGpu] Decode_Deep          | In:1 Past:64 Total:65 -> ✅ PASS (Diff: 0.00e+00)
[WebGpu] Speculative_Dec      | In:4 Past:20 Total:24 -> ✅ PASS (Diff: 0.00e+00)
[WebGpu] Batch_Prefill        | In:16 Past:0 Total:16 -> ✅ PASS (Diff: 0.00e+00)
[WebGpu] Batch_Decode         | In:1 Past:32 Total:33 -> ✅ PASS (Diff: 0.00e+00)
🎉 WebGpuExecutionProvider: ALL SCENARIOS PASSED.

Amazing! Thanks so much @qjia7

@xenova (Contributor) left a review comment:

Tested on a bunch of other cases too; all tests pass!

@xenova (Contributor) commented Dec 4, 2025

I kept digging and designing test cases that pass on CPU but fail on WebGPU.

Full reproduction script:
```python
from typing import Optional
import numpy as np
from onnx import helper, TensorProto
import onnxruntime as ort
import dataclasses

# ==========================================
# 0. Test Harness Configuration
# ==========================================
@dataclasses.dataclass
class TestConfig:
    name: str
    batch_size: int
    seq_len: int        # Number of tokens to process NOW
    past_seq_len: int   # Number of tokens already in cache
    max_seq_len: int = 128
    num_heads: int = 2
    kv_num_heads: int = 2
    head_size: int = 8

    # New parameters for extended testing
    do_rotary: int = 0
    rotary_interleaved: int = 0
    local_window_size: int = -1
    softcap: float = 0.0
    use_rotary_cache: bool = False
    custom_scale: Optional[float] = None

    @property
    def total_seq_len(self):
        return self.past_seq_len + self.seq_len

    @property
    def is_gqa_specific(self):
        return (self.do_rotary > 0 or
                self.local_window_size != -1 or
                self.softcap > 0.0 or
                self.use_rotary_cache)

def create_session(model_def, provider):
    sess_options = ort.SessionOptions()
    # sess_options.log_severity_level = 0

    try:
        return ort.InferenceSession(
            model_def.SerializeToString(),
            sess_options=sess_options,
            providers=[provider]
        )
    except Exception as e:
        print(f"⚠️ Failed to create session for {provider}: {e}")
        return None

# ==========================================
# 1. The Core Comparison Function
# ==========================================
def run_test_case(cfg: TestConfig, providers: list[str]):
    print(f"{cfg.name: <20} | In:{cfg.seq_len} Past:{cfg.past_seq_len} Total:{cfg.total_seq_len} H:{cfg.num_heads} KV:{cfg.kv_num_heads}", end="")

    # Constants
    NUM_HEADS = cfg.num_heads
    KV_NUM_HEADS = cfg.kv_num_heads
    HEAD_SIZE = cfg.head_size
    HIDDEN_SIZE = NUM_HEADS * HEAD_SIZE
    KV_HIDDEN_SIZE = KV_NUM_HEADS * HEAD_SIZE

    SCALE = cfg.custom_scale if cfg.custom_scale is not None else 1.0/np.sqrt(HEAD_SIZE)

    # Shapes
    query_shape = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    kv_input_shape = [cfg.batch_size, cfg.seq_len, KV_HIDDEN_SIZE]
    kv_input_shape_mha = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    past_shape_gqa = [cfg.batch_size, KV_NUM_HEADS, cfg.past_seq_len, HEAD_SIZE]
    past_shape_mha = [cfg.batch_size, NUM_HEADS, cfg.past_seq_len, HEAD_SIZE]

    # Rotary Cache Shape
    cache_shape = [cfg.max_seq_len, HEAD_SIZE // 2] if cfg.use_rotary_cache else []

    # ----------------------------------------
    # A. Build GQA Model (Full Buffer Mode)
    # ----------------------------------------
    gqa_inputs = ['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seq_len']
    if cfg.use_rotary_cache:
        gqa_inputs.extend(['cos_cache', 'sin_cache'])

    gqa_node = helper.make_node(
        'GroupQueryAttention',
        inputs=gqa_inputs,
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='GQA_Node',
        do_rotary=cfg.do_rotary,
        kv_num_heads=KV_NUM_HEADS,
        num_heads=NUM_HEADS,
        scale=SCALE,
        rotary_interleaved=cfg.rotary_interleaved,
        local_window_size=cfg.local_window_size,
        softcap=cfg.softcap,
    )

    gqa_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('seqlens_k', TensorProto.INT32, [cfg.batch_size]),
        helper.make_tensor_value_info('total_seq_len', TensorProto.INT32, []),
    ]

    if cfg.use_rotary_cache:
        gqa_inputs_info.extend([
            helper.make_tensor_value_info('cos_cache', TensorProto.FLOAT, cache_shape),
            helper.make_tensor_value_info('sin_cache', TensorProto.FLOAT, cache_shape),
        ])

    gqa_graph = helper.make_graph([gqa_node], 'gqa-test', gqa_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    gqa_model = helper.make_model(gqa_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # B. Build MHA Model (Sliced Mode)
    # ----------------------------------------
    mha_node = helper.make_node(
        'MultiHeadAttention',
        inputs=['query', 'key', 'value', '', '', '', 'past_key', 'past_value', 'past_seq_len'],
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='MHA_Node',
        num_heads=NUM_HEADS,
        unidirectional=1,
        scale=SCALE,
    )

    mha_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape_mha),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape_mha),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_seq_len', TensorProto.INT32, [1]),
    ]

    mha_graph = helper.make_graph([mha_node], 'mha-test', mha_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    mha_model = helper.make_model(mha_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # C. Data Generation
    # ----------------------------------------
    np.random.seed(42 + cfg.seq_len + cfg.past_seq_len)

    query = np.random.rand(*query_shape).astype(np.float32)
    key = np.random.rand(*kv_input_shape).astype(np.float32)
    value = np.random.rand(*kv_input_shape).astype(np.float32)

    past_key_full = np.random.rand(*past_shape_gqa).astype(np.float32)
    past_value_full = np.random.rand(*past_shape_gqa).astype(np.float32)

    cos_cache = np.random.rand(*cache_shape).astype(np.float32) if cfg.use_rotary_cache else None
    sin_cache = np.random.rand(*cache_shape).astype(np.float32) if cfg.use_rotary_cache else None

    # Prepare MHA inputs (Repeat KV if necessary)
    n_rep = NUM_HEADS // KV_NUM_HEADS

    def repeat_kv(x, n_rep):
        if n_rep == 1: return x
        if x.ndim == 3: # [B, S, H_kv] -> [B, S, H_q]
            b, s, h_kv = x.shape
            head_size = h_kv // KV_NUM_HEADS
            x = x.reshape(b, s, KV_NUM_HEADS, head_size)
            x = np.repeat(x, n_rep, axis=2)
            return x.reshape(b, s, NUM_HEADS * head_size)
        elif x.ndim == 4: # [B, H_kv, S, D] -> [B, H_q, S, D]
            return np.repeat(x, n_rep, axis=1)
        return x

    key_mha = repeat_kv(key, n_rep)
    value_mha = repeat_kv(value, n_rep)

    past_key_sliced = past_key_full[:, :, :cfg.past_seq_len, :]
    past_value_sliced = past_value_full[:, :, :cfg.past_seq_len, :]

    past_key_mha = repeat_kv(past_key_sliced, n_rep)
    past_value_mha = repeat_kv(past_value_sliced, n_rep)

    total_seq_len_scalar = np.array(cfg.total_seq_len, dtype=np.int32)
    past_seq_len_scalar = np.array([cfg.past_seq_len], dtype=np.int32)

    seqlens_k_arr = np.array([cfg.total_seq_len - 1] * cfg.batch_size, dtype=np.int32)

    # ----------------------------------------
    # D. Execution
    # ----------------------------------------
    results = {}

    for provider in providers:
        # Create Sessions
        sess_gqa = create_session(gqa_model, provider)
        sess_mha = create_session(mha_model, provider)

        if sess_gqa is None or sess_mha is None:
            print(f" ... SKIP ({provider} Init Failed)")
            continue

        # Run GQA
        try:
            feed_gqa = {
                'query': query, 'key': key, 'value': value,
                'past_key': past_key_full, 'past_value': past_value_full,
                'seqlens_k': seqlens_k_arr,
                'total_seq_len': total_seq_len_scalar
            }
            if cfg.use_rotary_cache:
                feed_gqa['cos_cache'] = cos_cache
                feed_gqa['sin_cache'] = sin_cache

            res_gqa = sess_gqa.run(['output'], feed_gqa)
            results[f"{provider}_GQA"] = res_gqa[0]
        except Exception as e:
            print(f" ... ERR ({provider} GQA: {e})")

        # Run MHA
        if not cfg.is_gqa_specific:
            try:
                res_mha = sess_mha.run(['output'], {
                    'query': query, 'key': key_mha, 'value': value_mha,
                    'past_key': past_key_mha, 'past_value': past_value_mha,
                    'past_seq_len': past_seq_len_scalar
                })
                results[f"{provider}_MHA"] = res_mha[0]
            except Exception as e:
                print(f" ... ERR ({provider} MHA: {e})")

    # ----------------------------------------
    # E. Validation
    # ----------------------------------------
    if not results:
        print(" -> ⚠️ NO RESULTS")
        return False

    # Determine baseline
    if cfg.is_gqa_specific:
        baseline_key = "CPUExecutionProvider_GQA"
    else:
        baseline_key = "CPUExecutionProvider_MHA"

    if baseline_key not in results:
        # Fallback to the first available key if preferred baseline is missing
        if results:
            baseline_key = list(results.keys())[0]

    baseline = results[baseline_key]
    passed = True
    max_diff_global = 0.0

    failures = []

    for key, val in results.items():
        if key == baseline_key: continue

        diff = np.abs(baseline - val)
        max_diff = diff.max()
        max_diff_global = max(max_diff_global, max_diff)

        base_tol = 1e-4 # Slightly relaxed for cross-device
        if max_diff > base_tol:
            passed = False
            failures.append(f"{key} (Diff: {max_diff:.2e})")

    if passed:
        print(f" -> ✅ PASS (Max Diff vs {baseline_key}: {max_diff_global:.2e})")
        return True
    else:
        print(f" -> ❌ FAIL vs {baseline_key}: {', '.join(failures)}")
        return False

# ==========================================
# 2. Main Execution Loop
# ==========================================
test_scenarios = [
    TestConfig(name="Rotary Interleaved", batch_size=1, seq_len=4, past_seq_len=0, max_seq_len=128, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1, local_window_size=-1, softcap=0.0, use_rotary_cache=True, custom_scale=0.25),
    TestConfig(name="Rotary_Window", batch_size=1, seq_len=16, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4),
    TestConfig(name="All_Features", batch_size=1, seq_len=8, past_seq_len=4, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4, softcap=50.0, custom_scale=1.0),
    TestConfig(name="Rotary_Interleaved", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1),
    TestConfig(name="Rotary_Half", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=0),
]

# Detect Providers
available = ort.get_available_providers()
target_providers = ['CPUExecutionProvider']

# Check for WebGPU
if 'WebGpuExecutionProvider' in available:
    target_providers.append('WebGpuExecutionProvider')
else:
    print("⚠️ WebGpuExecutionProvider not found in this environment. Skipping GPU tests.")

print(f"Testing on Providers: {target_providers}")
print("=================================================================================")

all_passed = True
for config in test_scenarios:
    if not run_test_case(config, target_providers):
        all_passed = False

if all_passed:
    print("\n🎉 ALL SCENARIOS PASSED ACROSS ALL PROVIDERS.")
else:
    print("\n⚠️ FAILURES DETECTED.")
    TestConfig(name="Rotary Interleaved", batch_size=1, seq_len=4, past_seq_len=0, max_seq_len=128, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1, local_window_size=-1, softcap=0.0, use_rotary_cache=True, custom_scale=0.25),
    TestConfig(name="Rotary_Window", batch_size=1, seq_len=16, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4),
    TestConfig(name="All_Features", batch_size=1, seq_len=8, past_seq_len=4, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4, softcap=50.0, custom_scale=1.0),
    TestConfig(name="Rotary_Interleaved", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1),
    TestConfig(name="Rotary_Half", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=0),

The first produces incorrect results (it always fails when do_rotary=1, rotary_interleaved=1), and the last 4 cause segmentation faults. FWIW, these also failed before this PR, so they need not block it from being merged.

@qjia7 (Contributor, Author) commented Dec 8, 2025

    TestConfig(name="Rotary Interleaved", batch_size=1, seq_len=4, past_seq_len=0, max_seq_len=128, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1, local_window_size=-1, softcap=0.0, use_rotary_cache=True, custom_scale=0.25),
    TestConfig(name="Rotary_Window", batch_size=1, seq_len=16, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4),
    TestConfig(name="All_Features", batch_size=1, seq_len=8, past_seq_len=4, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4, softcap=50.0, custom_scale=1.0),
    TestConfig(name="Rotary_Interleaved", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1),
    TestConfig(name="Rotary_Half", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=0),

The first produces incorrect results (always fails when do_rotary=1, rotary_interleaved=1), and the last 4 cause segmentation faults. FWIW, these also failed before this PR, so it may not need to block this PR from being merged.

Reproduced locally. It seems that q_rotary and k_rotary are not calculated correctly. I will fix them in a separate PR. The last 4 tests are not using the rotary cache? At least for WebGPU, when do_rotary is true, cos_cache and sin_cache are required.

@qjia7 qjia7 marked this pull request as ready for review December 8, 2025 12:48
@qjia7 qjia7 requested review from fs-eire and guschmue December 8, 2025 12:51
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Dec 8, 2025
@xenova (Contributor) commented Dec 8, 2025

> Reproduced locally. It seems that q_rotary and k_rotary are not calculated correctly. I will fix them in a separate PR.

Thanks!

> The last 4 tests are not using the rotary cache? At least for WebGPU, when do_rotary is true, cos_cache and sin_cache are required.

I believe so, yes. However, on CPU it produces an output, while on WebGPU it runs into a segmentation fault (we should either throw an error or produce the same results as CPU, imo).

@qjia7 qjia7 merged commit 549d741 into main Dec 9, 2025
91 checks passed
@qjia7 qjia7 deleted the attention_opt branch December 9, 2025 02:10
guschmue pushed a commit that referenced this pull request Dec 9, 2025
### Description

This PR fixes the last tests that were failing in
#26715 (comment),
where rotary_interleaved=1 in the GQA kernel. The root cause was that the
`rotary_interleaved` parameter was not being propagated correctly, so it
always defaulted to 0 in `FusedQKRotaryEmbeddingProgram`.
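
For context, here is a rough NumPy sketch (assuming the standard rotary-embedding conventions, not the actual WGSL kernel) of what `rotary_interleaved` selects. With the parameter stuck at its default of 0, interleaved models were silently rotated with the split-halves pairing instead:

```python
import numpy as np

def apply_rotary(x, cos, sin, interleaved):
    """x: [..., head_size]; cos/sin: [..., head_size // 2]."""
    out = np.empty_like(x)
    if interleaved:
        # Rotate adjacent pairs (x0, x1), (x2, x3), ...
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
    else:
        # Rotate split halves (x_i, x_{i + head_size // 2})
        h = x.shape[-1] // 2
        x1, x2 = x[..., :h], x[..., h:]
        out[..., :h] = x1 * cos - x2 * sin
        out[..., h:] = x1 * sin + x2 * cos
    return out
```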

```
Testing on Providers: ['CPUExecutionProvider', 'WebGpuExecutionProvider']
=================================================================================
Prefill_ColdStart    | In:3 Past:0 Total:3 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Prefill_ColdStart    | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Decode_Early         | In:1 Past:16 Total:17 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Decode_Deep          | In:1 Past:64 Total:65 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Speculative_Dec      | In:4 Past:20 Total:24 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Batch_Prefill        | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Batch_Decode         | In:1 Past:32 Total:33 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
GQA_Prefill          | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Decode           | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Batch_Dec        | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Prefill          | In:32 Past:0 Total:32 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Decode           | In:1 Past:32 Total:33 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_MHA          | In:1 Past:16 Total:17 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_GQA          | In:1 Past:16 Total:17 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Odd_SeqLen           | In:7 Past:13 Total:20 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Odd_Heads            | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
HighHeads_MHA        | In:1 Past:32 Total:33 H:32 KV:32 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_GQA        | In:1 Past:32 Total:33 H:32 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_MQA        | In:1 Past:32 Total:33 H:32 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgCtx_Prefill        | In:128 Past:0 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 4.17e-07)
LgCtx_Decode         | In:1 Past:127 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.98e-07)
TinyHead_MHA         | In:4 Past:4 Total:8 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
TinyHead_GQA         | In:4 Past:4 Total:8 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.49e-07)
LgHead_MHA           | In:2 Past:2 Total:4 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
LgHead_GQA           | In:2 Past:2 Total:4 H:4 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Ratio_5_1            | In:1 Past:10 Total:11 H:5 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_2            | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_3            | In:1 Past:10 Total:11 H:6 KV:3 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Ratio_12_4           | In:1 Past:10 Total:11 H:12 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Zero_Past            | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Single_Token_Prefill | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Rotary_Cache_Test    | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary               | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Small         | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Large         | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)
Window_Decode        | In:1 Past:20 Total:21 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Softcap_Enabled      | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 5.98e-05)
Scale_0.5            | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Rotary_Interleaved   | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary               | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Half          | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary Interleaved 2 | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Window        | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)

🎉 ALL SCENARIOS PASSED ACROSS ALL PROVIDERS.
```

### Motivation and Context

cc @qjia7 @guschmue