Commit 549d741
[webgpu] Optimize Attention by enhancing flash attention support (#26715)
This pull request improves the WebGPU BERT attention implementation by
enhancing FlashAttention support, generalizing tensor layout handling,
and increasing batch size flexibility. The changes focus on supporting
both BSNH and BNSH tensor layouts, enabling FlashAttention for
multi-batch scenarios, and ensuring correct broadcasting and dispatch
sizing for attention bias and batch dimensions.
Key improvements include:
**FlashAttention Support & Generalization:**
* Added support for both BSNH and BNSH tensor layouts by introducing the
`q_BNSH` parameter and updating the shader code, program classes, and kernel
logic to handle either layout correctly. This includes changes to the WGSL
template and the C++ logic for offset calculations and program instantiation;
a layout-offset sketch follows this list.
* Updated the `CanApplyFlashAttention` and `ApplyFlashAttention` logic
to allow multi-batch operation by removing the restriction to batch size
1 and ensuring present key/value tensors are always created for
FlashAttention.
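To make the layout handling concrete, here is a minimal sketch of the offset arithmetic a `q_BNSH`-style flag has to select between. The function names and signatures below are illustrative only, not the actual identifiers in the PR's WGSL template or C++ code:

```cpp
#include <cstdint>

// BSNH: [batch, seq_len, num_heads, head_size] - a token's per-head data is interleaved.
// BNSH: [batch, num_heads, seq_len, head_size] - a head's sequence data is contiguous.
inline uint64_t OffsetBSNH(uint64_t b, uint64_t s, uint64_t n,
                           uint64_t seq_len, uint64_t num_heads, uint64_t head_size) {
  return ((b * seq_len + s) * num_heads + n) * head_size;
}

inline uint64_t OffsetBNSH(uint64_t b, uint64_t n, uint64_t s,
                           uint64_t seq_len, uint64_t num_heads, uint64_t head_size) {
  return ((b * num_heads + n) * seq_len + s) * head_size;
}

// A q_BNSH-style flag picks the variant; resolving it when the shader is
// generated (rather than per element) keeps the inner loop branch-free.
inline uint64_t QOffset(bool q_BNSH, uint64_t b, uint64_t n, uint64_t s,
                        uint64_t seq_len, uint64_t num_heads, uint64_t head_size) {
  return q_BNSH ? OffsetBNSH(b, n, s, seq_len, num_heads, head_size)
                : OffsetBSNH(b, s, n, seq_len, num_heads, head_size);
}
```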
**Batch & Bias Handling:**
* Modified dispatch group size calculations and uniform variables
throughout the FlashAttention pipeline to properly account for batch
size, ensuring correct parallelization for multi-batch scenarios (see
the sketch after this list).
* Added logic to extract attention bias dimensions and pass them as
uniforms, so broadcasting is handled correctly in both the C++ compute
logic and the shader code.
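A hedged sketch of both batch/bias changes described above: the dispatch grid now spans the batch dimension, and singleton attention-bias dimensions collapse to index 0 so a bias of shape [1 or B, 1 or N, S, T] broadcasts correctly. All names here are invented for illustration, not the PR's actual uniforms or helpers:

```cpp
#include <cstdint>

struct DispatchDims { uint32_t x, y, z; };

// One workgroup per (sequence tile, head, batch entry). Before this change
// the batch dimension was effectively fixed at 1.
inline DispatchDims FlashAttentionDispatch(uint32_t batch_size, uint32_t num_heads,
                                           uint32_t seq_len, uint32_t tile_size) {
  return DispatchDims{
      (seq_len + tile_size - 1) / tile_size,  // ceil-divide the sequence into tiles
      num_heads,
      batch_size};
}

// Start offset of the (batch, head) slice in a bias of shape
// [bias_dim0, bias_dim1, q_seq_len, total_seq_len]; a dimension of size 1
// broadcasts, so its index collapses to 0.
inline uint64_t BiasSliceOffset(uint32_t b, uint32_t n,
                                uint32_t bias_dim0, uint32_t bias_dim1,
                                uint32_t q_seq_len, uint32_t total_seq_len) {
  uint64_t bb = (bias_dim0 == 1) ? 0 : b;  // broadcast over batch
  uint64_t bn = (bias_dim1 == 1) ? 0 : n;  // broadcast over heads
  return (bb * bias_dim1 + bn) * static_cast<uint64_t>(q_seq_len) * total_seq_len;
}
```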
**Other Enhancements:**
* Improved QKV format detection and generalized `CopyKVCache` to support
more format variants (see the sketch after this list).
* Updated includes and dependencies to ensure all necessary headers for
FlashAttention are present.
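For the `CopyKVCache` change, the generalization amounts to deriving copy strides from the detected QKV format rather than assuming one fixed layout. A rough sketch under that assumption; the enum and struct below are illustrative stand-ins, not onnxruntime's actual `AttentionQkvFormat` handling:

```cpp
#include <cstdint>
#include <stdexcept>

// Illustrative format tags; the real code distinguishes more variants.
enum class QkvFormat { kBSNH, kBNSH };

struct KvStrides { uint64_t batch, seq, head; };

// Element strides for walking a K or V tensor in either layout, so a single
// cache-copy routine can serve both.
inline KvStrides MakeKvStrides(QkvFormat f, uint64_t num_heads,
                               uint64_t seq_len, uint64_t head_size) {
  switch (f) {
    case QkvFormat::kBSNH:  // [B, S, N, H]
      return {seq_len * num_heads * head_size, num_heads * head_size, head_size};
    case QkvFormat::kBNSH:  // [B, N, S, H]
      return {num_heads * seq_len * head_size, head_size, seq_len * head_size};
  }
  throw std::runtime_error("unsupported QKV format");
}
```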
These changes collectively make the WebGPU BERT attention implementation
more robust, flexible, and performant across different tensor layouts
and batch sizes.
**Performance (phi-4-mm-vision.onnx):**

Before:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Attention\|AttentionProbs | 159.66 | 11.14 |
| Attention\|VxAttentionScore | 122.56 | 8.55 |
| Attention\|InPlaceSoftmax | 51.83 | 3.62 |

After:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Attention\|FlashAttention | 60.23 | 5.38 |

The single fused FlashAttention kernel replaces the three kernels above, cutting attention time in this profile from roughly 334 ms to 60 ms.
9 files changed (+185 −98) in onnxruntime/contrib_ops/webgpu/bert and onnxruntime/core/providers/webgpu.