webgpu / nbitmm support for bias and weight_index #26392
Conversation
Pull Request Overview
Adds WebGPU support for bias and weight_index parameters to N-bit matrix multiplication operations, enabling features like stacked weights and bias addition in quantized operations.
Key changes:
- Extended matmul_nbits operations to support optional bias parameter across multiple implementations (DP4A, wide tile, subgroup matrix)
- Added weight_index uniform variable to enable weight stacking and offset computation in quantized matmul
- Refactored Apple-specific shader generation to use WGSL templates instead of inline string concatenation
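For readers unfamiliar with the op, here is a hedged, scalar reference sketch (plain C++, not the WebGPU kernels; nibble unpacking and layouts are simplified, and the zero point is passed explicitly rather than defaulted) of the semantics being extended: y = a · dequant(b)ᵀ + bias, with b and scales taken from one weight_index slice of a stacked weight tensor.

```cpp
// Reference sketch only, not the onnxruntime kernels.
#include <cstdint>
#include <vector>

// a:      (M, K) activations, row-major
// b_q:    (N, K) quantized weights for ONE weight_index slice, one value per entry here
// scales: (N, K/block_size) per-block scales for the same slice
// bias:   (N), optional; broadcast over the output columns
// y:      (M, N)
void MatMulNBitsReference(const std::vector<float>& a, const std::vector<uint8_t>& b_q,
                          const std::vector<float>& scales, const float* bias,
                          int M, int N, int K, int block_size, uint8_t zero_point,
                          std::vector<float>& y) {
  const int blocks_per_col = K / block_size;
  y.assign(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        const float scale = scales[n * blocks_per_col + k / block_size];
        const float w = (static_cast<int>(b_q[n * K + k]) - zero_point) * scale;  // dequantize
        acc += a[m * K + k] * w;
      }
      y[m * N + n] = acc + (bias ? bias[n] : 0.0f);  // bias addition enabled by this PR
    }
  }
}
```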
Reviewed Changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| matmul_4bits_test.cc | Added test case for WebGPU with bias support |
| subgroup_matrix_matmul_nbits_apple.wgsl.template | New template file consolidating Apple shader generation with bias support |
| subgroup_matrix_matmul_nbits.h | Added has_bias and weight_idx parameters to program interface |
| subgroup_matrix_matmul_nbits.cc | Refactored to use template system and added bias/weight_index support |
| matmul_nbits_zero_pt.wgsl.template | Added has_bias parameter declaration |
| matmul_nbits_wide_tile.wgsl.template | Implemented bias addition and weight offset calculations |
| matmul_nbits.h | Added has_bias parameter and exposed ApplyMatMulNBits function |
| matmul_nbits.cc | Removed bias constraint, added ApplyMatMulNBits function with extensive documentation |
| dp4a_matmul_small_m.wgsl.template | Added bias support with offset calculations |
| dp4a_matmul_nbits.h | Added has_bias and weight_index parameters to program interfaces |
| dp4a_matmul_nbits.cc | Integrated bias support across DP4A implementations |
| dp4a_matmul_common.wgsl.template | Added has_bias parameter declaration |
| dp4a_matmul.wgsl.template | Implemented bias addition with vectorized operations |
In the current change, it seems that ApplyMatMulNBits computes a against a single weight index of b. For the QMoE case, that only computes one expert. I remember you said you need 4 selected experts, so will ApplyMatMulNBits be called 4 times to get the up-projection results? Why not generate the result directly by calling ApplyMatMulNBits once?
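To make the question concrete, here is a hedged sketch of the per-expert calling pattern being asked about. The stand-in types and the function signature are hypothetical, not the real ApplyMatMulNBits declaration; the point is only that each call passes a different weight_index into the stacked weights.

```cpp
// Hypothetical sketch: one matmul call per selected expert.
#include <array>
#include <cstdint>
#include <iostream>

struct Tensor {};          // stand-ins for the real onnxruntime types
struct ComputeContext {};

// Stub standing in for ApplyMatMulNBits; the real one dispatches a WebGPU program.
bool ApplyMatMulNBitsSketch(const Tensor& /*a*/, const Tensor& /*stacked_b*/,
                            const Tensor& /*scales*/, const Tensor* /*bias*/,
                            uint32_t weight_index, ComputeContext& /*ctx*/, Tensor& /*y*/) {
  std::cout << "matmul for expert " << weight_index << "\n";
  return true;
}

int main() {
  Tensor a, b, scales, y0, y1, y2, y3;
  ComputeContext ctx;
  std::array<Tensor*, 4> outputs{&y0, &y1, &y2, &y3};
  std::array<uint32_t, 4> top_k_experts{5, 12, 20, 31};  // 4 selected of 32 experts
  for (size_t i = 0; i < outputs.size(); ++i) {
    // weight_index selects the expert's slice of the stacked weights; bias may be null.
    ApplyMatMulNBitsSketch(a, b, scales, /*bias=*/nullptr, top_k_experts[i], ctx, *outputs[i]);
  }
}
```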
  let zero = mm_read_zero(0, 0, uniforms.N, uniforms.zero_blocks_per_col);
- let own_scale_b = scales_b.getByOffset(0);
+ let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K / uniforms.block_size);
+ let own_scale_b = scales_b.getByOffset(b_scale_offset);
I think for single_scale_weights, you can directly use let own_scale_b = scales_b.getByOffset(uniforms.weight_idx);.
done
 * @param accuracy_level Accuracy level influencing the choice of optimized kernel.
 * @param nbits Number of bits used for quantization.
 * @param context Compute context for WebGPU, providing device-specific information and execution facilities.
 * @param y Pointer to the output tensor that will hold the result.
Please add a description for the weight_index parameter, which specifies which batch index of b participates in the calculation.
done
 *
 * @param a Pointer to the left-hand side (activation) tensor.
 * @param b Pointer to the quantized weight tensor.
 * @param scales Pointer to the tensor containing scaling factors for quantization.
I assume b's shape will be (weight_batch, N, k_blocks, blob_size) instead of (N, k_blocks, blob_size). For the MatMulNBits operator, weight_batch is 1, and for custom scenarios like QMoE, weight_batch is num_experts. So it would be good to have some description of this parameter in case others modify this file in the future.
Similarly for scales, its shape is (weight_batch, N) instead of (N).
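As an aside, a hedged illustration of the element offsets a stacked layout implies (the helper names here are mine, not onnxruntime's). The scale-offset formula mirrors b_scale_offset in the dp4a template above; the b-slice offset follows from the (weight_batch, N, k_blocks, blob_size) layout described in this comment.

```cpp
// Illustrative offset arithmetic for stacked weights; not library code.
#include <cstdint>
#include <iostream>

// Start of slice `weight_index` inside b laid out as (weight_batch, N, k_blocks, blob_size).
uint64_t BSliceOffset(uint64_t weight_index, uint64_t N, uint64_t k_blocks, uint64_t blob_size) {
  return weight_index * N * k_blocks * blob_size;
}

// Matches uniforms.weight_idx * uniforms.N * (uniforms.K / uniforms.block_size) in the template.
uint64_t ScaleSliceOffset(uint64_t weight_index, uint64_t N, uint64_t K, uint64_t block_size) {
  return weight_index * N * (K / block_size);
}

int main() {
  // Example: 4-bit weights, N = 2048, K = 1024, block_size = 32, blob_size = 16 bytes.
  const uint64_t k_blocks = 1024 / 32;
  std::cout << BSliceOffset(2, 2048, k_blocks, 16) << " "
            << ScaleSliceOffset(2, 2048, 1024, 32) << "\n";
}
```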
 * @param a Pointer to the left-hand side (activation) tensor.
 * @param b Pointer to the quantized weight tensor.
 * @param scales Pointer to the tensor containing scaling factors for quantization.
 * @param zero_points Pointer to the zero-point tensor for quantization; must be of type uint8 if provided.
For zero points, weight_index is currently not used, since QMoE uses symmetric quantization. Please add a comment here that weight_batch is not supported in zero_points.
done.
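A hedged sketch of the constraint being documented (the function name is illustrative, not part of the PR): zero_points carries no weight_batch dimension, so a stacked-weight call only makes sense with symmetric quantization.

```cpp
// Illustrative guard only, not library code.
#include <cstdint>

bool WeightIndexCompatibleWithZeroPoints(uint32_t weight_index, const void* zero_points) {
  // weight_batch is not supported in zero_points; QMoE uses symmetric quantization,
  // so zero_points must be null whenever weight_index addresses a stacked slice.
  return weight_index == 0 || zero_points == nullptr;
}
```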
4});
}
if (has_bias) {
  program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
Suggested change:
- program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
+ program.AddInput({bias, ProgramTensorMetadataDependency::None});
done
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
if (has_bias) {
  program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
Suggested change:
- program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
+ program.AddInput({bias, ProgramTensorMetadataDependency::None});
done
#if !single_scale_weights
  let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size;
- let scale_b = scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx);
+ let scale_b = scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx + b_scale_offset);
Please also update scale_b in the single_scale_weights path at line 48 (let scale_b = scales_b.getByOffset(uniforms.weight_idx);).
done
}

let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
nit: Since b_weight_offset and b_scale_offset are constant values, would it be better to calculate them on the CPU and write them into uniforms? In the shader, we could then always read them from the uniforms.
We could do that, but I intentionally did not because there are 32 experts, so we'd need to compile 32 shaders.
But I was thinking that for weight_idx == 0 I could use a #if in the template and 'const b_weight_offset = 0', so everything that is not QMoE would benefit from the const, and for QMoE we'd only need to compile 2 shaders.
I made a change to use a const for the weight_idx-related offsets when weight_idx == 0, so only QMoE takes a tiny hit for the weight_idx.
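A hedged host-side sketch of the two-variant approach described here (plain C++ string generation, not the actual template/ProgramBase machinery): when weight_idx == 0 the offsets become WGSL consts, otherwise they are derived from the uniforms, so only the QMoE path compiles the second variant. The emitted expressions match the offsets quoted from the templates above.

```cpp
// Illustrative codegen sketch; the real change is done with #if in the WGSL template.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

std::string EmitWeightOffsets(uint32_t weight_index) {
  std::ostringstream wgsl;
  if (weight_index == 0) {
    // Non-QMoE callers: compile-time constants, no extra uniform reads, one shader variant.
    wgsl << "const b_weight_offset = 0u;\n"
            "const b_scale_offset = 0u;\n";
  } else {
    // QMoE: a second shader variant that derives the offsets from uniforms at runtime.
    wgsl << "let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;\n"
            "let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K / uniforms.block_size);\n";
  }
  return wgsl.str();
}

int main() {
  std::cout << EmitWeightOffsets(0) << "---\n" << EmitWeightOffsets(3);
}
```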
The current code changes are clean. No more questions for the nbitmm.
If you have, say, 1000 tokens, it is most likely that we need to run all experts; for generation we'd run 4.
Does that mean that for generation we can do a specific optimization to calculate the four experts in one ApplyMatMulNBits call?
#endif
#else
#if has_bias
  // TODO: wanted to use vec4 for bias but for some reason that fails ut. Later.
To use vec4 for the bias, you need to make sure N % 4 == 0, or it will be very complicated to rearrange the data to get the correct vec4 values.
we can sit it out
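For the record, a hedged sketch (plain C++, illustrative only, not the kernel code) of the constraint mentioned above: a vec4 bias load covers four output columns at a time, so it is only safe when N % 4 == 0; otherwise the scalar path is the simple fallback.

```cpp
// Illustration of the N % 4 == 0 requirement for vectorized bias loads.
#include <cstdint>
#include <vector>

void AddBiasScalar(std::vector<float>& y_row, const std::vector<float>& bias) {
  for (size_t n = 0; n < y_row.size(); ++n) {
    y_row[n] += bias[n];  // safe for any N
  }
}

bool CanUseVec4Bias(uint32_t N) {
  // A vec4 load reads bias[4*i .. 4*i+3]; if N is not a multiple of 4 the last group
  // straddles the end of the tensor unless the data is padded or re-arranged.
  return N % 4 == 0;
}
```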
Add support for bias and weight_index, move subgroup_matrix_matmul_nbits to a template, and make the program callable from other ops.