@@ -24,6 +24,8 @@
 import triton
 import triton.language as tl
 
+from vllm.utils import is_navi
+
 torch_dtype: tl.constexpr = torch.float16
 
 
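Note: the new `from vllm.utils import is_navi` import gates every RDNA-specific path in this patch. As a minimal sketch of what such a helper could look like (an illustrative assumption; the actual implementation of `vllm.utils.is_navi` may differ), Navi/RDNA GPUs report ROCm arch names starting with "gfx1":

    import torch

    def is_navi() -> bool:
        # Hypothetical re-implementation, for illustration only.
        # ROCm builds of PyTorch expose the AMD arch string as
        # gcnArchName, e.g. "gfx90a" (CDNA2) or "gfx1100" (RDNA3);
        # "Navi" (RDNA) arch names start with "gfx1".
        if torch.version.hip is None or not torch.cuda.is_available():
            return False
        arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
        return arch.startswith("gfx1")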
@@ -217,88 +219,80 @@ def _attn_fwd_inner(
     return acc, l_i, m_i
 
 
-@triton.autotune(
-    configs=[
+def get_cdna_autotune_configs():
+    return [
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 64,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 64,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 1,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": True,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': True
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 64,
-                "BLOCK_N": 64,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 64,
+                'BLOCK_N': 64,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 32,
-                "BLOCK_N": 32,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         # TODO: This config fails with head_size not pow2 with data mismatches.
         # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
         #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
@@ -314,8 +308,93 @@ def _attn_fwd_inner(
         # num_stages=1,
         # num_warps=4,
         # ),
-    ],
-    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_rdna_autotune_configs():
+    return [
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 4,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 2,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # # Fall-back config.
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 1,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_autotune_configs():
+    if is_navi():
+        return get_rdna_autotune_configs()
+    else:
+        return get_cdna_autotune_configs()
+
+
+autotune_configs, autotune_keys = get_autotune_configs()
+
+
+@triton.autotune(
+    configs=autotune_configs,
+    key=autotune_keys,
+    use_cuda_graph=True,
 )
 @triton.jit
 def attn_fwd(
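Note the shape of this refactor: each `get_*_autotune_configs()` helper returns a `(configs, keys)` pair, so the autotune key list travels with the config set it belongs to, and the hardware dispatch happens once at import time. The RDNA list also uses much smaller tiles (32x32 and 32x16 with `num_warps=2`) than the CDNA list (up to 256x64 with `num_warps=8`), which plausibly reflects RDNA's wave32 execution and smaller per-CU resources. A self-contained sketch of the same pattern on a toy kernel (names are illustrative, not vLLM's):

    import torch
    import triton
    import triton.language as tl

    def get_configs():
        # Stand-in for the is_navi() dispatch above: pick a smaller
        # tile on ROCm, a larger one elsewhere (illustrative only).
        if torch.version.hip is not None:
            cfgs = [triton.Config({'BLOCK': 64}, num_stages=1, num_warps=2)]
        else:
            cfgs = [triton.Config({'BLOCK': 256}, num_stages=1, num_warps=8)]
        return cfgs, ['n_elements']

    configs, keys = get_configs()

    @triton.autotune(configs=configs, key=keys)
    @triton.jit
    def scale_kernel(x_ptr, out_ptr, n_elements, s, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * s, mask=mask)

    # Usage: the autotuner supplies BLOCK, so the grid reads it from META.
    # grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK']),)
    # scale_kernel[grid](x, out, x.numel(), 2.0)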
@@ -833,6 +912,10 @@ def check_and_convert(t, scale):
     p_descale = 1.0 / p_scale
     o_descale = 1.0 / o_scale
 
+    if is_navi():
+        max_seqlens_q = 0
+        max_seqlens_k = 0
+
     attn_fwd[grid](
         q,
         k,
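The launch-side change above zeroes `max_seqlens_q`/`max_seqlens_k` on Navi right before the kernel call. The hunk does not show why; one plausible reading (an assumption, not confirmed by this diff) is that the kernel treats a zero maximum as "unknown" and derives exact per-sequence lengths from the `cu_seqlens` offsets instead of relying on a padded-maximum fast path:

    # Hypothetical kernel-side counterpart (not shown in this diff):
    # if MAX_SEQLENS_Q == 0:
    #     seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start  # exact length
    # else:
    #     seqlen_q = MAX_SEQLENS_Q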